#!/usr/bin/env python

################################################################################
#									       #
# 	Module for reading a trajectory ( r(t), v(t)) from VASP output files   #
#									       #
################################################################################
#									       #
# 	F. Nattino							       #
################################################################################
#									       #
#	Last Update: 4-01-2013                                                 #
################################################################################

import sys, os, gzip
from math import *

class Trajectory:
    """Trajectory r(t), v(t) of a VASP run, assembled from POSCAR/XDATCAR/CONTCAR.

    Intended call order (each step relies on state set by the previous ones):
      1. ReadPOSCAR()   -- lattice, atom list, positions and velocities of step 0;
      2. ReadXDATCAR()  -- positions of the following steps (may be called again
                           for a series of XDATCAR files; maxnt accumulates);
      3. ExtractRealPositionsAndVelocities(dt) -- unwrap positions across the
                           periodic boundaries, central-difference velocities,
                           Cartesian copies;
      4. ReadCONTCAR()  -- velocities of the last step.

    For every stored time step, rCar/rDir/vCar/vDir hold one [x, y, z] list per
    atom, in Cartesian and direct (fractional) coordinates respectively.
    """

    # Spellings accepted for the Coordinates argument of the FindClosestImage* methods.
    _COORDINATE_VALUES = ('Dir', 'Direct', 'direct', 'dir', 'D', 'd',
                          'C', 'c', 'Cart', 'Cartesian', 'cartesian', 'cart')

#********************************************************************************************************************

    def __init__(self, VASPVersion=5, Verbosity=1):
        """Create an empty trajectory.

        VASPVersion -- 4 or 5; selects the layout of the POSCAR/XDATCAR/CONTCAR headers.
        Verbosity   -- 1 prints progress output, 0 keeps this module silent.
        """
        if Verbosity:
            print("\n Reading trajectory from VASP output files (VASP version: %s ). \n" % (VASPVersion,))

        self.VASPVersion = VASPVersion  # VASP version
        self.Verbosity = Verbosity      # 1 => print output, 0 => no output from this module

        self.maxnt = 0   # Number of time steps read from XDATCAR file(s)
        self.dt = 0.     # Timestep in femtoseconds

        self.Atoms = []  # One entry (species label) per atom

        self.SelectiveDynamics = False  # POSCAR has a "Selective dynamics" line
        self.POSCARCart = True          # POSCAR positions are Cartesian (False: direct)

        self.rCar = []   # Positions in Cartesian coordinates, one list per time step
        self.rDir = []   # Positions in direct coordinates, one list per time step
        self.vCar = []   # Velocities in Cartesian coordinates, one list per time step
        self.vDir = []   # Velocities in direct coordinates, one list per time step

        self.lines = ''  # Lines of the file read last (see ReadFile)

        self.basis = []  # Basis read from POSCAR (rows), scaled by the scaling factor
        self.dasis = []  # Inverse of the basis

        self.FirstXDATCAR = True  # See ReadXDATCAR

#********************************************************************************************************************

    def ReadFile(self, FileName):
        """Read FileName into self.lines (list of lines, newlines kept).

        Names ending in '.gz' or '.tgz' are opened through gzip.open; a '.tgz'
        is only gunzipped, not un-tarred. Byte lines (as returned by gzip under
        Python 3) are decoded so the parsers always work on text.
        """
        if os.path.splitext(FileName)[1] in ('.gz', '.tgz'):
            Opener = gzip.open
        else:
            Opener = open
        # 'with' closes the handle even if reading fails.
        with Opener(FileName) as Handle:
            Lines = Handle.readlines()
        if Lines and not isinstance(Lines[0], str):
            Lines = [Line.decode('latin-1') for Line in Lines]
        self.lines = Lines

        if self.Verbosity:
            print(" File: %s open and read." % (FileName,))

#********************************************************************************************************************

    def _CheckVersion(self):
        """Raise ValueError unless VASPVersion selects a supported file layout (4 or 5)."""
        if self.VASPVersion not in (4, 5):
            raise ValueError("Unsupported VASPVersion %r: expected 4 or 5" % (self.VASPVersion,))

    def _FirstChar(self, Index):
        """First non-blank character of self.lines[Index] ('' for a blank line)."""
        return self.lines[Index].strip()[:1]

    def _ReadVectors(self, Start, Count):
        """Parse Count consecutive lines, from self.lines[Start] on, as [x, y, z] floats.

        Only the first three fields of each line are used, so trailing
        selective-dynamics flags are ignored. A truncated file raises IndexError.
        """
        Vectors = []
        for Index in range(Start, Start + Count):
            Fields = self.lines[Index].split()
            Vectors.append([float(Fields[0]), float(Fields[1]), float(Fields[2])])
        return Vectors

    def _Transform(self, Vector, Matrix):
        """Row vector times 3x3 matrix: result[k] = sum_j Vector[j] * Matrix[j][k]."""
        return [Vector[0] * Matrix[0][k] + Vector[1] * Matrix[1][k] + Vector[2] * Matrix[2][k]
                for k in range(3)]

    def _ToDirect(self, Vector):
        """Cartesian -> direct coordinates, using the inverse basis self.dasis."""
        return self._Transform(Vector, self.dasis)

    def _ToCartesian(self, Vector):
        """Direct -> Cartesian coordinates, using self.basis (one basis vector per row)."""
        return self._Transform(Vector, self.basis)

    def _PrintMatrix(self, Matrix):
        """Print the rows of a 3x3 matrix in the module's fixed-width format."""
        for Row in Matrix:
            print(" % 10.7f % 10.7f % 10.7f " % tuple(Row))

#********************************************************************************************************************

    def ReadPOSCAR(self, Filename='POSCAR'):
        """Read lattice, atoms, and the step-0 positions and velocities from a POSCAR.

        Sets basis (scaled by the universal scaling factor), dasis (its inverse),
        Atoms, SelectiveDynamics and POSCARCart, and appends one entry to rDir,
        vCar and vDir (and to rCar when the positions are Cartesian).
        Raises SystemExit when the coordinate-mode line is not recognised and
        ValueError for an unsupported VASPVersion.
        """
        self._CheckVersion()
        self.ReadFile(Filename)

        # Line 0 is a free comment; line 1 holds the universal scaling factor.
        ScalingFactor = float(self.lines[1].split()[0])
        RawBasis = self._ReadVectors(2, 3)
        if ScalingFactor < 0.0:
            # VASP convention: a negative scaling factor is the target cell volume.
            ScalingFactor = (-ScalingFactor / abs(self._Determinant3x3(RawBasis))) ** (1.0 / 3.0)
        self.basis = [[ScalingFactor * x for x in Row] for Row in RawBasis]

        if self.Verbosity:
            print("\n The basis is:")
            self._PrintMatrix(self.basis)

        self.dasis = self.Invert3x3Matrix(self.basis)

        if self.Verbosity:
            print("\n The Inverse of the basis is:")
            self._PrintMatrix(self.dasis)

        # VASP 5 has a line of species names before the per-species counts;
        # VASP 4 has only the counts, so the atoms are labelled 'X'.
        if self.VASPVersion == 5:
            Names = self.lines[5].split()
            Counts = self.lines[6].split()
            Counter = 7
        else:
            Counts = self.lines[5].split()
            Names = ['X'] * len(Counts)
            Counter = 6
        for Name, Count in zip(Names, Counts):
            self.Atoms.extend([Name] * int(Count))
        NAtoms = len(self.Atoms)

        if self.Verbosity:
            print("\n Reading %d atoms." % NAtoms)

        # Optional "Selective dynamics" line: only its first letter is checked.
        self.SelectiveDynamics = self._FirstChar(Counter) in ('S', 's')
        if self.SelectiveDynamics:
            Counter += 1
            if self.Verbosity:
                print(" Selective dynamics")

        # Coordinate-mode line: Direct, or Cartesian (C/K).
        Mode = self._FirstChar(Counter)
        if Mode in ('D', 'd'):
            self.POSCARCart = False
            if self.Verbosity:
                print(" POSCAR in direct coordinates.")
        elif Mode in ('C', 'c', 'K', 'k'):
            self.POSCARCart = True
            if self.Verbosity:
                print(" POSCAR in cartesian coordinates.")
        else:
            if self.Verbosity:
                print(" problems in reading POSCAR's format. stop")
            # sys.exit() raises the same SystemExit as quit(), which only exists when 'site' is loaded.
            sys.exit()
        Counter += 1

        # Positions: one line per atom.
        Positions = self._ReadVectors(Counter, NAtoms)
        Counter += NAtoms
        if self.POSCARCart:
            # Cartesian POSCAR positions are scaled by the scaling factor as well.
            Positions = [[ScalingFactor * x for x in r] for r in Positions]
            self.rCar.append(Positions)
            self.rDir.append([self._ToDirect(r) for r in Positions])
        else:
            self.rDir.append(Positions)

        # One separator line, then one velocity line per atom (stored as Cartesian).
        Counter += 1
        Velocities = self._ReadVectors(Counter, NAtoms)
        self.vCar.append(Velocities)
        self.vDir.append([self._ToDirect(v) for v in Velocities])

#********************************************************************************************************************

    def ReadXDATCAR(self, Filename='XDATCAR'):
        """Append the direct-coordinate configurations of an XDATCAR to rDir.

        Must be called after ReadPOSCAR (the atom count is needed to split the
        file into configurations). maxnt is increased by the number of
        configurations found in the file.
        Raises RuntimeError if no atoms are known yet and ValueError for an
        unsupported VASPVersion.
        """
        self._CheckVersion()
        NAtoms = len(self.Atoms)
        if NAtoms == 0:
            raise RuntimeError("ReadPOSCAR must be called before ReadXDATCAR")

        self.ReadFile(Filename)

        # Number of header lines depends on the VASP version.
        HeaderLines = 7 if self.VASPVersion == 5 else 5

        # Each configuration is one title line followed by one line per atom.
        StepsInFile = (len(self.lines) - HeaderLines) // (NAtoms + 1)
        self.maxnt += StepsInFile
        if self.Verbosity:
            print(" Timesteps = %d found in %s" % (StepsInFile, Filename))

        # First position line of the first configuration (skip its title line).
        Counter = HeaderLines + 1

        # With a Cartesian POSCAR the first configuration is already stored in rDir,
        # so skip it -- only in the first XDATCAR of a series (later runs are
        # restarted from CONTCARs, which are in direct coordinates).
        # NOTE(review): with a direct POSCAR the first configuration is not skipped
        # although ReadPOSCAR already stored rDir[0]; confirm this is intended.
        if self.POSCARCart and self.FirstXDATCAR:
            Counter += NAtoms + 1
            StepsInFile -= 1
            self.FirstXDATCAR = False

        for _ in range(StepsInFile):
            self.rDir.append(self._ReadVectors(Counter, NAtoms))
            Counter += NAtoms + 1

#********************************************************************************************************************

    def ExtractRealPositionsAndVelocities(self, dt):
        """Unwrap positions, compute velocities, and build the Cartesian copies.

        dt -- time between consecutive stored configurations.

        rDir[1:maxnt] is modified in place so that every step-to-step move in
        direct coordinates lies in [-0.5, 0.5), which undoes periodic wrapping.
        Velocities for steps 1..maxnt-2 are (r[i+1] - r[i-1]) / (2 dt) on the
        unwrapped positions. The step-0 velocity comes from ReadPOSCAR and the
        last one from ReadCONTCAR. rCar and vCar are then filled from rDir and vDir.
        """
        self.dt = dt
        NAtoms = len(self.Atoms)

        # Unwrap positions step by step (Previous is already unwrapped).
        for i in range(1, self.maxnt):
            Previous = self.rDir[i - 1]
            Current = self.rDir[i]
            for j in range(NAtoms):
                for k in range(3):
                    Step = Current[j][k] - Previous[j][k]
                    # Minimum-image value of the step, in [-0.5, 0.5).
                    Current[j][k] = Previous[j][k] + ((Step + 100.5) % 1.0 - 0.5)

        # Central differences, excluding the first and last time step. The
        # positions are already unwrapped, so the two-step displacement is not
        # re-wrapped: doing so flips the sign when it exceeds half a cell.
        for i in range(1, self.maxnt - 1):
            Before = self.rDir[i - 1]
            After = self.rDir[i + 1]
            self.vDir.append([[(After[j][k] - Before[j][k]) * 0.5 / dt for k in range(3)]
                              for j in range(NAtoms)])

        # With a Cartesian POSCAR, rCar[0] was already stored by ReadPOSCAR.
        First = 1 if self.POSCARCart else 0
        for i in range(First, self.maxnt):
            self.rCar.append([self._ToCartesian(r) for r in self.rDir[i]])

        for i in range(1, self.maxnt - 1):
            self.vCar.append([self._ToCartesian(v) for v in self.vDir[i]])

#********************************************************************************************************************

    def ReadCONTCAR(self, Filename='CONTCAR'):
        """Read the last-step velocities from a CONTCAR and append them to vCar and vDir.

        Uses Atoms, SelectiveDynamics and dasis set by ReadPOSCAR to locate and
        convert the velocity block. Raises ValueError for an unsupported VASPVersion.
        """
        self._CheckVersion()
        self.ReadFile(Filename)

        NAtoms = len(self.Atoms)
        # First velocity line: header plus separator lines (9 for VASP 5, 8 for
        # VASP 4), one more for a "Selective dynamics" line, plus the positions.
        Counter = (9 if self.VASPVersion == 5 else 8) + NAtoms
        if self.SelectiveDynamics:
            Counter += 1

        Velocities = self._ReadVectors(Counter, NAtoms)
        self.vCar.append(Velocities)
        self.vDir.append([self._ToDirect(v) for v in Velocities])

#********************************************************************************************************************

    def _Determinant3x3(self, Matrix):
        """Determinant of a 3x3 matrix given as a list of rows."""
        return (Matrix[0][0] * (Matrix[1][1] * Matrix[2][2] - Matrix[1][2] * Matrix[2][1])
                - Matrix[0][1] * (Matrix[1][0] * Matrix[2][2] - Matrix[1][2] * Matrix[2][0])
                + Matrix[0][2] * (Matrix[1][0] * Matrix[2][1] - Matrix[1][1] * Matrix[2][0]))

    def Invert3x3Matrix(self, Matrix):
        """Return the inverse of a 3x3 matrix (list of rows) via the adjugate.

        Raises ZeroDivisionError if the matrix is singular.
        """
        Det = self._Determinant3x3(Matrix)
        if Det == 0.0:
            raise ZeroDivisionError("Invert3x3Matrix: matrix is singular")
        InvDet = 1.0 / Det

        Inverse = [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]
        for i in range(3):
            a, b = (i + 1) % 3, (i + 2) % 3
            for j in range(3):
                c, d = (j + 1) % 3, (j + 2) % 3
                # Transposed cofactor: Inverse[j][i] = C[i][j] / Det.
                Inverse[j][i] = InvDet * (Matrix[a][c] * Matrix[b][d] - Matrix[a][d] * Matrix[b][c])
        return Inverse

#********************************************************************************************************************

    def _CheckCoordinates(self, Coordinates, Caller):
        """Print a message and exit (SystemExit) if Coordinates is not an accepted spelling."""
        if Coordinates not in self._COORDINATE_VALUES:
            print("Wrong values assigned to Coordinates in %s. stop" % Caller)
            sys.exit()

    def _ClosestImage(self, AtomA, AtomB, Coordinates):
        """Periodic image of AtomB closest (Cartesian distance) to AtomA.

        Both atoms are given in direct coordinates. The search is centred on the
        shift that brings each direct-coordinate difference to about half a cell,
        so unwrapped positions work too, and the 27 neighbouring shifts are
        compared to handle skewed cells. The image is returned in direct
        coordinates if Coordinates starts with 'd'/'D', in Cartesian otherwise.
        """
        Centre = [round(AtomA[k] - AtomB[k]) for k in range(3)]
        dSmallest = float('inf')
        ClosestB = [0., 0., 0.]
        for ix in (-1, 0, 1):
            for iy in (-1, 0, 1):
                for iz in (-1, 0, 1):
                    Image = [AtomB[0] + (Centre[0] + ix),
                             AtomB[1] + (Centre[1] + iy),
                             AtomB[2] + (Centre[2] + iz)]
                    Delta = self._ToCartesian([AtomA[k] - Image[k] for k in range(3)])
                    Dist = sqrt(Delta[0] ** 2. + Delta[1] ** 2. + Delta[2] ** 2.)
                    if Dist < dSmallest:
                        dSmallest = Dist
                        ClosestB = Image
        if Coordinates[0] in ('d', 'D'):
            return ClosestB
        return self._ToCartesian(ClosestB)

    def FindClosestImage(self, Step, NumAtomA, NumAtomB, Coordinates='C'):
        """Return the periodic image of atom NumAtomB closest to atom NumAtomA at time step Step.

        Both atoms are taken from rDir[Step]. Coordinates selects the output
        system: 'D'/'Dir'/'Direct' (and lower-case forms) for direct, 'C'/'Cart'/
        'Cartesian' (and lower-case forms, the default) for Cartesian.
        """
        self._CheckCoordinates(Coordinates, 'FindClosestImage')
        Configuration = self.rDir[Step]
        return self._ClosestImage(Configuration[NumAtomA], Configuration[NumAtomB], Coordinates)

#********************************************************************************************************************

    def FindClosestImageToExternal(self, Step, AtomA, NumAtomB, Coordinates='C'):
        """Return the periodic image of atom NumAtomB (from rDir[Step]) closest to point AtomA.

        AtomA is external to the trajectory and must be given in direct
        coordinates. Coordinates selects the output system as in FindClosestImage.
        """
        self._CheckCoordinates(Coordinates, 'FindClosestImageToExternal')
        return self._ClosestImage(AtomA, self.rDir[Step][NumAtomB], Coordinates)

