import os
import numpy as np
from ase.units import kJ,mol
from ase.calculators.calculator import FileIOCalculator
from simtk.openmm import app
import simtk.openmm as mm
from simtk import unit



class mbpol(FileIOCalculator):
    """ASE calculator that evaluates water energies/forces with the mbpol
    OpenMM force field.

    The structure handed to the calculator is repeated ``(X, Y, Z)`` times,
    written to a PDB file (every atom is emitted as part of an ``HOH``
    residue) and evaluated with OpenMM on the ``Reference`` platform using
    PME.  The reported energy is the supercell energy divided by
    ``X * Y * Z``; the reported forces are those of the first ``len(atoms)``
    atoms of the supercell.

    The PDB writer assumes atoms are ordered ``O, H, H`` molecule by
    molecule, and the force extraction assumes OpenMM adds exactly one extra
    particle after each group of three atoms.  Only the diagonal of the cell
    is used as the periodic box.
    """

    implemented_properties = ['energy', 'forces']

    # One PDB record per atom: serial, atom name, residue number, x, y, z,
    # occupancy, temperature factor.
    _PDB_LINE = 'HETATM%5d%3s   HOH%6d%12.4f%8.4f%8.4f  %3.2f  %3.2f \n'

    def __init__(self, X=1,Y=1,Z=1,restart=None, ignore_bad_restart_file=False,
                 label=os.curdir, atoms=None, cubes=None, radmul=None,
                 tier=None, mbpol_command=None,
                 outfilename=None, **kwargs):
        """Create the calculator.

        X, Y, Z: supercell repetitions applied before every evaluation.
        restart, ignore_bad_restart_file, label, atoms, **kwargs: forwarded
            to ``FileIOCalculator.__init__``.
        cubes, radmul, tier, mbpol_command, outfilename: accepted for
            interface compatibility; they are not used by this class.
        """
        FileIOCalculator.__init__(self, restart, ignore_bad_restart_file,
                                  label, atoms,
                                  command=self.command,
                                  **kwargs)
        self.X = X
        self.Y = Y
        self.Z = Z

    def set_atoms(self, atoms):
        """Store *atoms* as the structure used by later evaluations."""
        self.atoms = atoms

    def read(self, label):
        """Do nothing; no state is restored from disk."""

    def read_results(self):
        """Fill ``self.results`` by running the OpenMM evaluation."""
        self.read_energy_force()

    def write_input(self, atoms, properties=None, system_changes=None,
                    ghosts=None, scaled=False):
        """Write the ``(X, Y, Z)`` supercell of *atoms* to ``ice.pdb``.

        ``properties``, ``system_changes``, ``ghosts`` and ``scaled`` are
        accepted but not used.
        """
        ice = atoms.copy()
        ice = ice * (self.X, self.Y, self.Z)
        self._write_pdb(ice, 'ice.pdb')

    def _write_pdb(self, ice, filename):
        """Write *ice* to *filename* as HETATM records followed by ENDMDL.

        Oxygen atoms start a new residue number; hydrogens are named ``H1``
        when their index modulo 3 is 1 and ``H2`` otherwise, which matches
        an ``O, H, H`` atom ordering.
        """
        symbols = ice.get_chemical_symbols()
        positions = ice.get_positions()
        molecule = 0
        # Text mode: the records are str, which a binary handle rejects on
        # Python 3.  The context manager closes the file even on error.
        with open(filename, 'w') as fileobj:
            for index, symbol in enumerate(symbols):
                if symbol == "O":
                    molecule += 1
                    name = "O"
                elif index % 3 == 1:
                    name = "H1"
                else:
                    name = "H2"
                x, y, z = positions[index]
                fileobj.write(self._PDB_LINE % (index + 1, name, molecule,
                                                x, y, z, 1, 0))
            fileobj.write('ENDMDL\n')

    def _evaluate(self, ice, pdb_filename):
        """Return the OpenMM ``State`` (energy and forces) for *pdb_filename*.

        Topology and positions are read from the PDB file; the periodic box
        and the cutoff (half the shortest box edge) come from the diagonal
        of ``ice``'s cell.
        """
        # Imported locally, as in the original methods; aliased so it cannot
        # be confused with this class, which has the same name.
        import mbpol as mbpol_plugin

        # Locate mbpol.xml next to the plugin module.  Using the directory
        # avoids producing "mbpol.xmlc" when __file__ ends in ".pyc".
        xml_path = os.path.join(os.path.dirname(mbpol_plugin.__file__),
                                'mbpol.xml')

        lattice = ice.get_cell() * 0.1  # Angstrom -> nanometer
        edges = (lattice[0][0], lattice[1][1], lattice[2][2])
        cutoff = min(edges) / 2

        pdb = app.PDBFile(pdb_filename)
        forcefield = app.ForceField(xml_path)
        modeller = app.Modeller(pdb.topology, pdb.positions)
        modeller.addExtraParticles(forcefield)
        topology = modeller.getTopology()
        positions = modeller.getPositions()
        topology.setUnitCellDimensions(edges * unit.nanometer)

        system = forcefield.createSystem(topology,
                                         nonbondedMethod=app.PME,
                                         nonbondedCutoff=cutoff*unit.nanometer,
                                         ewaldErrorTolerance=1e-08)
        integrator = mm.VerletIntegrator(1*unit.femtoseconds)
        platform = mm.Platform.getPlatformByName('Reference')
        simulation = app.Simulation(topology, system, integrator, platform)
        simulation.context.setPositions(positions)
        simulation.context.computeVirtualSites()
        simulation.step(0)
        return simulation.context.getState(getForces=True, getEnergy=True)

    @staticmethod
    def _energy_ev(state):
        """Return the potential energy of *state* as a plain float in eV."""
        # Strip the OpenMM unit first; multiplying a Quantity by a float
        # would otherwise yield another Quantity rather than a number.
        energy = state.getPotentialEnergy()
        return energy.value_in_unit(unit.kilojoule_per_mole) * kJ / mol

    def read_energy_force(self):
        """Evaluate ``ice.pdb`` and store energy and forces in ``self.results``.

        ``results['energy']`` is the supercell energy divided by
        ``X * Y * Z``; ``results['forces']`` has shape ``(len(atoms), 3)``.
        Requires ``ice.pdb`` to have been written by ``write_input``.
        """
        atoms = self.atoms
        ice = atoms.copy()
        ice = ice * (self.X, self.Y, self.Z)

        state = self._evaluate(ice, "ice.pdb")
        energy = self._energy_ev(state) / (self.X * self.Y * self.Z)

        raw = state.getForces(asNumpy=True).value_in_unit(
            unit.kilojoule_per_mole / unit.nanometer)
        raw = np.array(raw)
        # Drop every fourth particle (indices 3, 7, ...): the extra particle
        # added per molecule, which has no counterpart in the ASE structure.
        extra = list(range(3, raw.shape[0], 4))
        # kJ/mol/nm -> eV/Angstrom
        forces = np.delete(raw, extra, axis=0) * kJ / mol * 0.1

        # Keep only the atoms of the first periodic copy.
        self.results['energy'] = energy
        self.results['forces'] = np.array(forces[:len(atoms)], dtype=float)

    def get_stress(self, atoms, voigt=True):
        """Estimate the diagonal stress by central finite differences.

        For each axis the supercell is strained by ``+/- 1e-06`` along that
        axis (atoms scaled with the cell), written to ``ice1.pdb`` /
        ``ice2.pdb`` and re-evaluated.  Only the three normal components are
        computed; the shear components are left at zero.

        atoms: ignored; ``self.atoms`` is used.
        voigt: if true (default) return the 6-vector
            ``[xx, yy, zz, yz, xz, xy]``, otherwise the 3x3 matrix.

        NOTE(review): positions reach OpenMM through a PDB file with four
        decimals, so for typical coordinates a relative displacement of 1e-06
        may be lost to rounding and only the box change takes effect -
        confirm whether this is intended.
        """
        ice = self.atoms.copy()
        ice.constraints = []
        ice = ice * (self.X, self.Y, self.Z)

        d = 1e-06
        stress = np.zeros((3, 3), dtype=float)
        cell = ice.cell.copy()
        V = ice.get_volume()

        for i in range(3):
            strain = np.eye(3)

            strain[i, i] += d
            ice.set_cell(np.dot(cell, strain), scale_atoms=True)
            self._write_pdb(ice, 'ice1.pdb')
            eplus = self._energy_ev(self._evaluate(ice, "ice1.pdb"))

            strain[i, i] -= 2 * d
            ice.set_cell(np.dot(cell, strain), scale_atoms=True)
            self._write_pdb(ice, 'ice2.pdb')
            eminus = self._energy_ev(self._evaluate(ice, "ice2.pdb"))

            stress[i, i] = (eplus - eminus) / (2 * d * V)

            # Restore the unstrained supercell before the next axis.
            ice.set_cell(cell, scale_atoms=True)

        if voigt:
            return stress.flat[[0, 4, 8, 5, 2, 1]]
        return stress
