#!/usr/bin/env python
import os
import subprocess

import numpy as np
from scipy.constants import physical_constants as phc

from ase import Atoms
from ase import geometry
from ase.build import fcc111, molecule, add_adsorbate
from ase.constraints import FixAtoms, FixedPlane
from ase.dimer import DimerControl, MinModeAtoms, MinModeTranslate
from ase.io import read, write
from ase.optimize import BFGS
from ase.vibrations import Vibrations
from ase.visualize import view

from RuNNerCalculator import RuNNerCalculator

# Physical constants and unit-conversion factors taken from scipy's CODATA
# table. Every physical_constants entry is a (value, unit, uncertainty)
# tuple; only the value (index 0) is kept.
atmas = phc['atomic mass constant'][0]              # kg
eV2J = phc['electron volt-joule relationship'][0]   # J per eV
Bc = phc['Boltzmann constant'][0]                   # J/K
hbar = phc['Planck constant'][0] / (2 * np.pi)      # reduced Planck constant, J s
H2J = phc['Hartree energy'][0]                      # J per Hartree
A2B = 1e-10 / phc['Bohr radius'][0]                 # Bohr per Angstrom
H2eV = phc['Hartree energy in eV'][0]               # eV per Hartree
# eV per particle -> kJ/mol: (J per eV) * (particles per mol) / (J per kJ)
eV2kJmol = eV2J * phc['Avogadro constant'][0] / 1000.

#### Surface ####################################
# Build the slab with the external helper script, which presumably writes the
# POSCAR file that is read right after. This is done in a very, very dirty way
# but it works (mostly).
# subprocess.run waits for the script just like os.popen(...).read() did, but
# check=True raises CalledProcessError on a non-zero exit instead of silently
# continuing with a stale or missing POSCAR. stdout is discarded, as before.
subprocess.run(['./slab-poscar.py', '5', '1', '3', '3', '2', '2'],
               check=True, stdout=subprocess.DEVNULL)
#subprocess.run(['./slab-poscar.py', '5', '3', '3', '1', '1', '1'], check=True, stdout=subprocess.DEVNULL)
atoms = read('POSCAR')
atoms.set_pbc([True, True, True])
Cell = atoms.get_cell()

# Atoms lying on the z = 0 face of the cell are moved up by the cell height
# Cell[2][2] (for an atom exactly at 0 this is the same as setting z to
# Cell[2][2]). A tiny tolerance replaces the exact float comparison so values
# such as 1e-16 coming out of the POSCAR conversion are caught too.
on_bottom_face = np.abs(atoms.positions[:, 2]) < 1e-8
atoms.positions[on_bottom_face, 2] += Cell[2][2]

#################################################

######################### Set up calculator ########################
# Assigning Atoms.calc replaces the deprecated Atoms.set_calculator().
atoms.calc = RuNNerCalculator()

######################### Add molecule #############################
# Rotate H2 by 90 degrees about x (around its centre of mass) -- presumably to
# lay it parallel to the surface -- and add it with add_adsorbate at height 7.5
# and in-plane position (0, 0). The two H atoms end up as the last two atoms.
H2 = molecule('H2')
H2.rotate(90, 'x', center='COM')
add_adsorbate(atoms, H2, 7.5, (0., 0.))

######################### Defining constraints #####################
mask = [atom.symbol == 'Pt' for atom in atoms]   # Metal surface is not allowed to relax
fixlayers = FixAtoms(mask=mask)                  # Freeze every Pt atom selected above
# The two H atoms (indices -2 and -1) may only move within a plane whose
# normal is z, i.e. their z coordinates are held fixed.
fixplane1 = FixedPlane(-2, (0, 0, 1))
fixplane2 = FixedPlane(-1, (0, 0, 1))
atoms.set_constraint([fixlayers, fixplane1, fixplane2])

# Obtain asymptotic energy: relax H2 under the constraints above. E_asymp is
# the reference that the transition-state energies are reported against.
dyn = BFGS(atoms, trajectory='optm.traj')        # Setting up the optimizer routine
converged = dyn.run(fmax=0.001, steps=200)
if not converged:
    # Do not silently use a reference energy from an unconverged geometry.
    print('WARNING: asymptotic BFGS relaxation did not reach fmax=0.001 '
          'within 200 steps; E_asymp may be inaccurate.')
E_asymp = atoms.get_potential_energy()

# Dimer-method setup: only the H atoms take part in the minimum-mode search.
dimermask = [atom.symbol == 'H' for atom in atoms]
d_control = DimerControl(initial_eigenmode_method='displacement',
                         displacement_method='vector',
                         logfile=None,
                         mask=dimermask)
d_atoms = MinModeAtoms(atoms, d_control)

from ase.constraints import FixConstraint, Filter
class FixComIndex(FixConstraint):
    """Constraint that fixes the x and y center of mass of selected atoms.

    Only the atoms chosen through ``indices`` or ``mask`` are involved, and
    only their x and y components are constrained; motion along z is left
    untouched, so two degrees of freedom are removed. Positions are corrected
    by a rigid in-plane shift of the selected atoms; forces by subtracting the
    mass-weighted (Lagrange multiplier) part that would move their x/y center
    of mass.

    References

    https://pubs.acs.org/doi/abs/10.1021/jp9722824

    """

    def __init__(self, indices=None, mask=None):
        """Constrain the x/y center of mass of the chosen atoms.

        Parameters
        ----------
        indices : list of int
           Indices of the atoms whose x/y center of mass is fixed.
        mask : list of bool
           One boolean per atom indicating whether the atom belongs to the
           constrained group.

        Raises
        ------
        ValueError
           If neither or both of ``indices`` and ``mask`` are given, if
           ``indices`` contains duplicates or is not one-dimensional, or if
           the selection is empty.

        Examples
        --------
        Fix the x/y center of mass of the last two atoms:

        >>> c = FixComIndex(indices=[-2, -1])
        >>> atoms.set_constraint(c)

        Fix the x/y center of mass of all hydrogen atoms:

        >>> c = FixComIndex(mask=[s == 'H' for s in atoms.get_chemical_symbols()])
        >>> atoms.set_constraint(c)
        """

        if indices is None and mask is None:
            raise ValueError('Use "indices" or "mask".')
        if indices is not None and mask is not None:
            raise ValueError('Use only one of "indices" and "mask".')

        if mask is not None:
            indices = np.arange(len(mask))[np.asarray(mask, bool)]
        else:
            # Check for duplicates. Note: a negative and a non-negative index
            # naming the same atom (e.g. -1 and len(atoms)-1) are not caught.
            srt = np.sort(indices)
            if (np.diff(srt) == 0).any():
                raise ValueError(
                    'FixComIndex: The indices array contained duplicates. '
                    'Perhaps you wanted to specify a mask instead, but '
                    'forgot the mask= keyword.')
        self.index = np.asarray(indices, int)

        if self.index.ndim != 1:
            raise ValueError('Wrong argument to FixComIndex class!')
        if self.index.size == 0:
            # An empty group has zero total mass: the corrections below would
            # divide by zero and fill positions/forces with NaN.
            raise ValueError('FixComIndex: no atoms selected.')

        # Kept as an attribute for older ASE versions that read it directly.
        self.removed_dof = 2

    def get_removed_dof(self, atoms):
        """Return the number of removed degrees of freedom (x and y of the COM)."""
        return self.removed_dof

    def adjust_positions(self, atoms, new):
        """Restore, in place, the x/y center of mass in proposed positions ``new``.

        Every selected atom is shifted by the same in-plane vector so that the
        group's mass-weighted x/y center equals the one of the current
        ``atoms`` positions. z components are not changed.
        """
        masses = atoms.get_masses()[self.index]
        total_mass = masses.sum()
        old_cm = np.dot(masses, atoms.get_positions()[self.index]) / total_mass
        new_cm = np.dot(masses, new[self.index]) / total_mass
        d = old_cm - new_cm
        new[self.index, :2] += d[:2]

    def adjust_forces(self, atoms, forces):
        """Remove, in place, the force components that move the x/y center of mass.

        With lb = sum_i(m_i F_i) / sum_i(m_i**2) per Cartesian direction,
        m_i * lb is subtracted from the x and y forces of each selected atom,
        so afterwards sum_i(m_i F_i) is zero in x and y.
        """
        m = atoms.get_masses()[self.index]
        lb = np.dot(m, forces[self.index]) / np.dot(m, m)
        forces[self.index, :2] -= np.outer(m, lb[:2])

    def copy(self):
        """Return an independent copy of this constraint.

        Built directly from the stored indices because ``todict()`` below does
        not describe this class, so a copy rebuilt from it would be a
        different constraint.
        """
        return type(self)(indices=self.index.copy())

    def __repr__(self):
        return 'FixComIndex(indices={})'.format(self.index.tolist())

    def todict(self):
        # NOTE(review): this labels the constraint as ASE's 'FixCom' with no
        # arguments and does not record self.index, so anything rebuilt from
        # this dict (e.g. a trajectory read back from disk) is a different
        # constraint. Left unchanged because ASE presumably only reconstructs
        # constraint names it knows about -- confirm before renaming.
        return {'name': 'FixCom',
                'kwargs': {}}

# Earlier ways of choosing the starting H2 centre, kept for reference:
#Cell_minimum = Cell[:2,:2]
#Cell_minimum[1] /= 3.
#POS = np.dot( [X, Y], Cell_minimum )

#for POS in [[26.399692,	0.]]:
# In-plane (x, y) starting positions of the H2 centre, one transition-state
# search per row. Uncomment rows to scan further sites.
for POS in np.array([
        # [9.3349486316472152, 1.4160467837259223],
        # [7.0117112910680595, 0.],
        [4.6664608213598395, 1.4160467837259223],
]):
    # Place the two H atoms 1 Ang apart along y, centred on POS, at z = 0.
    atoms[-2].position = [POS[0], POS[1] + 0.5, 0.]
    atoms[-1].position = [POS[0], POS[1] - 0.5, 0.]

    # Replace the FixedPlane constraints used for E_asymp: Pt stays frozen and
    # the x/y centre of mass of the two H atoms is fixed; z is now free.
    del atoms.constraints
    fixcom = FixComIndex([-2, -1])
    atoms.set_constraint([fixlayers, fixcom])

    # Set up the initial displacement vector for the dimer calculation:
    # mode 0 of a vibrational analysis restricted to the two H atoms.
    vib = Vibrations(atoms, indices=[-2, -1])
    vib.clean()     # remove results cached by a previous run
    vib.run()
    vib.summary()
    displacement_vector = vib.get_mode(0)
    # NOTE(review): d_atoms is created once before this loop, so with more
    # than one POS row the same MinModeAtoms object is reused between searches.
    d_atoms.displace(displacement_vector=displacement_vector)

    dim_rlx = MinModeTranslate(d_atoms, trajectory='dimer_method.traj',
                               logfile=None)
    dim_rlx.run(fmax=0.001)
    E_ts = atoms.get_potential_energy()

    # vib = Vibrations(atoms, indices=[-2,-1])
    # vib.clean()
    # vib.run()
    # vib.summary()
    # vib.clean()

    # H2 bond vector, bond length and orientation at the end of the search.
    v_H2 = atoms[-2].position - atoms[-1].position
    r_H2 = np.linalg.norm(v_H2)
    # Polar angle: between the bond axis and the z axis.
    theta = geometry.get_angles([v_H2], [[0.0, 0.0, 1.0]])
    # Azimuthal angle: direction of the bond's projection onto the xy plane,
    # measured from x, in [0, 360). The angle between v_H2 itself and the x
    # axis (computed previously) only equals this when theta is 90 degrees.
    phi = np.degrees(np.arctan2(v_H2[1], v_H2[0])) % 360.
    POS = (atoms[-2].position + atoms[-1].position) / 2.

    print('X (Ang), Y (Ang), E_TS (kJ/mol), Z (Ang), r (Ang), Theta (deg.), Phi (deg.)')
    print('{:8.6f} {:8.6f} {:8.6f} {:8.6f} {:8.6f} {:9.6f} {:9.6f}\n'.format(
        POS[0], POS[1], (E_ts - E_asymp) * eV2kJmol, POS[2], r_H2,
        theta[0], phi))













