import numpy as np
import openmdao.api as om

class SellarMDA(om.Group):
    """Top-level group assembling the Sellar multidisciplinary analysis.

    The two coupled disciplines are placed in a ``cycle`` sub-group whose
    variables are all promoted; the objective and the two constraints are
    added next to it.

    Options
    -------
    strategy : str
        Selects how the linear (derivative) system is solved. Defaults to
        ``'bgs_matfree'``.

        * ``'direct'``: all partials are given and the coupled cycle is
          solved in one shot with a DirectSolver.
        * ``'bgs_auto'``: all partials are given, the cycle uses
          LinearBlockGS and each discipline gets its own DirectSolver
          (OpenMDAO "automatically inverts" dResidual/dOutput).
        * ``'bgs_custom'``: all partials are given, the cycle uses
          LinearBlockGS and each discipline relies on its user-written
          ``solve_linear`` to "invert" dResidual/dOutput.
        * ``'bgs_matfree'``: precomputing partials is optional, the cycle
          uses LinearBlockGS and each discipline relies on ``solve_linear``
          and ``apply_linear``.

        Any value other than ``'direct'`` selects LinearBlockGS for the cycle.
        Author's note: if DirectSolver is used, the ``apply_linear`` methods
        must be commented out.
    verbosity : int
        0 displays nothing, 1 displays method calls, 2 displays method
        outputs. Defaults to 0.
    """

    def initialize(self):
        declare = self.options.declare
        declare('strategy', default='bgs_matfree')
        declare('verbosity', default=0)

    def setup(self):
        opts = self.options
        strategy = opts['strategy']
        verbosity = opts['verbosity']

        # Coupled disciplines, grouped so they can share one solver pair.
        cycle = self.add_subsystem('cycle', om.Group(), promotes=['*'])
        cycle.add_subsystem(
            'd1',
            SellarDis1(strategy=strategy, verbosity=verbosity),
            promotes_inputs=['x', 'z', 'y2'],
            promotes_outputs=['y1'],
        )
        cycle.add_subsystem(
            'd2',
            SellarDis2(strategy=strategy, verbosity=verbosity),
            promotes_inputs=['z', 'y1'],
            promotes_outputs=['y2'],
        )

        cycle.set_input_defaults('x', 1.0)
        cycle.set_input_defaults('z', np.array([5.0, 2.0]))

        # The nonlinear coupling is always converged with block Gauss-Seidel.
        cycle.nonlinear_solver = om.NonlinearBlockGS()

        # Linear solver for the cycle as a whole.
        if strategy == 'direct':
            cycle_linear_solver = om.DirectSolver()
        else:
            cycle_linear_solver = om.LinearBlockGS()
        cycle.linear_solver = cycle_linear_solver

        # Linear solvers for the individual disciplines.
        if strategy == 'bgs_auto':
            for discipline in (cycle.d1, cycle.d2):
                discipline.linear_solver = om.DirectSolver()

        # Objective function.
        self.add_subsystem(
            'obj',
            ObjFun(verbosity=verbosity),
            promotes_inputs=['x', 'z', 'y1', 'y2'],
            promotes_outputs=['obj'],
        )
        # Constraints.
        self.add_subsystem('cons1', Cons1(verbosity=verbosity), promotes=['y1', 'con1'])
        self.add_subsystem('cons2', Cons2(verbosity=verbosity), promotes=['y2', 'con2'])


class SellarDis1(om.ImplicitComponent):
    """Sellar discipline 1, written as an implicit component.

    The residual is ``R = y1 - (z[0]**2 + z[1] + x - 0.2*y2)``.

    Options
    -------
    strategy
        When equal to ``'bgs_matfree'`` no partials are declared or filled;
        otherwise all partials of ``y1`` are declared and set in ``linearize``.
    verbosity
        Values above 0 print the name of each method as it is called; values
        above 1 additionally print the reverse-mode quantities.
    """

    def initialize(self):
        for option_name in ('strategy', 'verbosity'):
            self.options.declare(option_name)

    def setup(self):
        # Cache the options on the instance; the other methods read these.
        self.strategy = self.options['strategy']
        self.verb = self.options['verbosity']
        # Global design variable
        self.add_input('z', val=np.zeros(2))
        # Local design variable
        self.add_input('x', val=0.)
        # Coupling parameter
        self.add_input('y2', val=1.0)
        # Coupling output
        self.add_output('y1', val=1.0)

    def setup_partials(self):
        if self.strategy != 'bgs_matfree':
            self.declare_partials('y1', '*')

    def solve_nonlinear(self, inputs, outputs):
        """Set ``y1`` to the value that makes the residual zero."""
        if self.verb > 0:
            print("SellarDis1 - solve_nonlinear")
        z, x, y2 = inputs['z'], inputs['x'], inputs['y2']
        outputs['y1'] = z[0]**2 + z[1] + x - 0.2*y2

    def apply_nonlinear(self, inputs, outputs, residuals):
        """Write the residual, then always raise ``RuntimeError``."""
        if self.verb > 0:
            print("SellarDis1 - apply_nonlinear")
        z, x, y2 = inputs['z'], inputs['x'], inputs['y2']
        residuals['y1'] = outputs['y1'] - (z[0]**2 + z[1] + x - 0.2*y2)
        # This method unconditionally raises once the residual has been set.
        raise RuntimeError("Not allowed!")

    def linearize(self, inputs, outputs, partials):
        """Fill the residual partials unless running matrix-free."""
        if self.verb > 0:
            print("SellarDis1 - linearize")
        if self.strategy == 'bgs_matfree':
            return
        partials['y1', 'y1'] = 1.0
        partials['y1', 'x'] = -1.0
        partials['y1', 'y2'] = 0.2
        partials['y1', 'z'] = np.array([-2*inputs['z'][0], -1.0])

    def apply_linear(self, inputs, outputs, d_inputs, d_outputs, d_residuals, mode):
        """Accumulate Jacobian-vector ('fwd') or transposed ('rev') products."""
        if self.verb > 0:
            print("SellarDis1 - apply_linear")
        if 'y1' not in d_residuals:
            return

        if mode == 'fwd':
            if 'y1' in d_outputs:
                d_residuals['y1'] += 1.0 * d_outputs['y1']
            if 'z' in d_inputs:
                dres_dz = np.array([-2*inputs['z'][0], -1.0])
                d_residuals['y1'] += dres_dz.dot(d_inputs['z'])
            if 'x' in d_inputs:
                d_residuals['y1'] += -1.0 * d_inputs['x']
            if 'y2' in d_inputs:
                d_residuals['y1'] += 0.2 * d_inputs['y2']
        elif mode == 'rev':
            show_values = self.verb > 1
            if 'y1' in d_outputs:
                d_outputs['y1'] += 1.0 * d_residuals['y1']
                if show_values:
                    print('d_res[y1] =', d_residuals['y1'], '-> d_out[y1] =', d_outputs['y1'])
            if 'z' in d_inputs:
                dres_dz = np.array([-2*inputs['z'][0], -1.0])
                d_inputs['z'] += dres_dz * d_residuals['y1']
                if show_values:
                    print('d_res[y1] =', d_residuals['y1'], '-> d_in[z] =', d_inputs['z'])
            if 'x' in d_inputs:
                d_inputs['x'] += -1.0 * d_residuals['y1']
                if show_values:
                    print('d_res[y1] =', d_residuals['y1'], '-> d_in[x] =', d_inputs['x'])
            if 'y2' in d_inputs:
                d_inputs['y2'] += 0.2 * d_residuals['y1']
                if show_values:
                    print('d_res[y1] =', d_residuals['y1'], '-> d_in[y2] =', d_inputs['y2'])

    def solve_linear(self, d_outputs, d_residuals, mode):
        """Apply the inverse of dR/dy1, which is the identity here."""
        if self.verb > 0:
            print("SellarDis1 - solve_linear")
        if mode == 'fwd':
            d_outputs['y1'] = 1.0 * d_residuals['y1']
        elif mode == 'rev':
            d_residuals['y1'] = 1.0 * d_outputs['y1']
            if self.verb > 1:
                print('d_out[y1] =', d_outputs['y1'], '-> d_res[y1] =', d_residuals['y1'])

class SellarDis2(om.ImplicitComponent):
    """Sellar discipline 2, written as an implicit component.

    The residual is ``R = y2 - (sqrt(|y1|) + z[0] + z[1])``, where the
    absolute value comes from the sign check applied to ``y1``.

    Options
    -------
    strategy
        When equal to ``'bgs_matfree'`` no partials are declared or filled;
        otherwise all partials of ``y2`` are declared and set in ``linearize``.
    verbosity
        Values above 0 print the name of each method as it is called; values
        above 1 additionally print the reverse-mode quantities.
    """
    def __sign_check(self, y1):
        # Note: this may cause some issues. However, y1 is constrained to be
        # above 3.16, so lets just let it converge, and the optimizer will
        # throw it out
        return y1 if y1.real >= 0.0 else -y1

    def __dres_dy1(self, y1):
        """Return dR/dy1 for the sign-checked square root.

        The nonlinear methods evaluate ``sqrt(y1)`` for ``y1 >= 0`` and
        ``sqrt(-y1)`` otherwise, so the derivative of the residual must switch
        sign with ``y1``. Differentiating ``sqrt(y1)`` directly would yield
        NaN for negative real ``y1`` and disagree with the residual.
        The derivative remains singular at ``y1 == 0``.
        """
        if y1.real >= 0.0:
            return -0.5 * y1**(-0.5)
        return 0.5 * (-y1)**(-0.5)

    def initialize(self):
        self.options.declare('strategy')
        self.options.declare('verbosity')

    def setup(self):
        # Cache the options on the instance; the other methods read these.
        self.strategy = self.options['strategy']
        self.verb = self.options['verbosity']
        self.add_input('z', val=np.zeros(2)) # Global Design Variable
        self.add_input('y1', val=1.0) # Coupling parameter
        self.add_output('y2', val=1.0) # Coupling output

    def setup_partials(self):
        if not self.strategy == 'bgs_matfree':
            self.declare_partials('y2', '*')

    def solve_nonlinear(self, inputs, outputs):
        """Set ``y2`` to the value that makes the residual zero."""
        if self.verb > 0: print("SellarDis2 - solve_nonlinear")
        outputs['y2'] = self.__sign_check(inputs['y1'])**.5 + inputs['z'][0] + inputs['z'][1]

    def apply_nonlinear(self, inputs, outputs, residuals):
        """Write the residual, then always raise ``RuntimeError``."""
        if self.verb > 0: print("SellarDis2 - apply_nonlinear")
        residuals['y2'] = outputs['y2'] - (self.__sign_check(inputs['y1'])**.5 + inputs['z'][0] + inputs['z'][1])
        # This method unconditionally raises once the residual has been set.
        raise RuntimeError("Not allowed!")

    def linearize(self, inputs, outputs, partials):
        """Fill the residual partials unless running matrix-free."""
        if self.verb > 0: print("SellarDis2 - linearize")
        if not self.strategy == 'bgs_matfree':
            partials['y2', 'z'] = np.array([-1.0, -1.0])
            # Sign-aware so the Jacobian matches the sign-checked residual.
            partials['y2', 'y1'] = self.__dres_dy1(inputs['y1'])
            partials['y2', 'y2'] = 1.0

    def apply_linear(self, inputs, outputs, d_inputs, d_outputs, d_residuals, mode):
        """Accumulate Jacobian-vector ('fwd') or transposed ('rev') products."""
        if self.verb > 0: print("SellarDis2 - apply_linear")
        if mode == 'fwd':
            if 'y2' in d_residuals:
                if 'y2' in d_outputs:
                    d_residuals['y2'] += 1.0 * d_outputs['y2']
                if 'z' in d_inputs:
                    d_residuals['y2'] += np.array([-1.0, -1.0]).dot(d_inputs['z'])
                if 'y1' in d_inputs:
                    d_residuals['y2'] += self.__dres_dy1(inputs['y1']) * d_inputs['y1']
        elif mode == 'rev':
            if 'y2' in d_residuals:
                if 'y2' in d_outputs:
                    d_outputs['y2'] += 1.0 * d_residuals['y2']
                    if self.verb > 1: print('d_res[y2] =', d_residuals['y2'], '-> d_out[y2] =', d_outputs['y2'])
                if 'z' in d_inputs:
                    d_inputs['z'] += np.array([-1.0, -1.0]) * d_residuals['y2']
                    if self.verb > 1: print('d_res[y2] =', d_residuals['y2'], '-> d_in[z] =', d_inputs['z'])
                if 'y1' in d_inputs:
                    d_inputs['y1'] += self.__dres_dy1(inputs['y1']) * d_residuals['y2']
                    if self.verb > 1: print('d_res[y2] =', d_residuals['y2'], '-> d_in[y1] =', d_inputs['y1'])

    def solve_linear(self, d_outputs, d_residuals, mode):
        """Apply the inverse of dR/dy2, which is the identity here."""
        if self.verb > 0: print("SellarDis2 - solve_linear")
        if mode == 'fwd':
            d_outputs['y2'] = 1.0 * d_residuals['y2']
        elif mode == 'rev':
            d_residuals['y2'] = 1.0 * d_outputs['y2']
            if self.verb > 1: print( 'd_out[y2] =', d_outputs['y2'], '-> d_res[y2] =', d_residuals['y2'])

class ObjFun(om.ExplicitComponent):
    """Objective function ``obj = x**2 + z[1] + y1 + exp(-y2)``.

    The ``verbosity`` option controls printing: values above 0 print the
    name of each method as it is called, values above 1 also print the
    computed partial derivatives.
    """

    def initialize(self):
        self.options.declare('verbosity')

    def setup(self):
        self.verb = self.options['verbosity']
        self.add_input('x', val=0.0)
        self.add_input('z', val=np.zeros(2))
        for coupling_var in ('y1', 'y2'):
            self.add_input(coupling_var)
        self.add_output('obj')

    def setup_partials(self):
        self.declare_partials('obj', '*')

    def compute(self, inputs, outputs):
        """Evaluate the objective from the current inputs."""
        if self.verb > 0:
            print("ObjFun - compute")
        x, z = inputs['x'], inputs['z']
        y1, y2 = inputs['y1'], inputs['y2']
        outputs['obj'] = x**2 + z[1] + y1 + np.exp(-y2)

    def compute_partials(self, inputs, partials):
        """Fill the analytic derivatives of ``obj`` w.r.t. every input."""
        if self.verb > 0:
            print("ObjFun - compute_partials")
        partials['obj', 'x'] = 2 * inputs['x']
        partials['obj', 'z'] = np.array([0.0, 1.0])
        partials['obj', 'y1'] = 1.0
        partials['obj', 'y2'] = -np.exp(-inputs['y2'])
        if self.verb > 1:
            for wrt in ('x', 'z', 'y1', 'y2'):
                print(f'd_obj/d_{wrt} =', partials['obj', wrt])

class Cons1(om.ExplicitComponent):
    """Constraint ``con1 = 3.16 - y1``.

    With the ``verbosity`` option above 0, each method prints its name when
    called.
    """

    def initialize(self):
        self.options.declare('verbosity')

    def setup(self):
        self.verb = self.options['verbosity']
        self.add_input('y1')
        self.add_output('con1')

    def setup_partials(self):
        self.declare_partials('con1', 'y1')

    def compute(self, inputs, outputs):
        """Evaluate the constraint value."""
        if self.verb > 0:
            print("Cons1 - compute")
        y1 = inputs['y1']
        outputs['con1'] = 3.16 - y1

    def compute_partials(self, inputs, partials):
        """Set the constant derivative d(con1)/d(y1)."""
        if self.verb > 0:
            print("Cons1 - compute_partials")
        partials['con1', 'y1'] = -1.0

class Cons2(om.ExplicitComponent):
    """Constraint ``con2 = y2 - 24.0``.

    With the ``verbosity`` option above 0, each method prints its name when
    called.
    """

    def initialize(self):
        self.options.declare('verbosity')

    def setup(self):
        self.verb = self.options['verbosity']
        self.add_input('y2')
        self.add_output('con2')

    def setup_partials(self):
        self.declare_partials('con2', 'y2')

    def compute(self, inputs, outputs):
        """Evaluate the constraint value."""
        if self.verb > 0:
            print("Cons2 - compute")
        y2 = inputs['y2']
        outputs['con2'] = y2 - 24.0

    def compute_partials(self, inputs, partials):
        """Set the constant derivative d(con2)/d(y2)."""
        if self.verb > 0:
            print("Cons2 - compute_partials")
        partials['con2', 'y2'] = 1.0
