#!/usr/bin/env python3

import os
import random
import numpy as np

# Inclusive [first, last] ranges used by the job loop at the end of this script:
# NSurf / NMol select which pre-computed surface / molecule run directory is read,
# tSurf / tMol select which snapshot inside that run is drawn at random.
NSurf = [0, 9]
tSurf = [0, 999]
NMol = [0, 9]
tMol = [0, 999]
# Number of beads; one position and one velocity file per bead is read for every run,
# and one frame per bead is written for every job.
NBeads = 24

# Parameters handed to SetupVelocityDistribution: stream velocity and velocity width (m/s).
v0 = 2899.
alpha = 216.
# Atomic masses (amu) in the order the molecule's atoms appear in the input files;
# the output loop labels the first atom C and the remaining four H.
Masses = [ 12.000, 1.008, 1.008, 1.008, 1.008 ]
Mass = sum( Masses )

# NOTE(review): initialZ is not referenced anywhere in the visible part of this script -
# confirm whether the molecule was meant to be placed at this height.
initialZ = 6.5
# Cell vectors as rows; the job loop multiplies fractional coordinates with this matrix
# to obtain a random lateral shift. The lengths match the angstrom cell written in the output headers.
Cell = np.array(
[[8.5609344674269, 0., 0.],
[-4.2804672337134, 7.4139867289255, 0.],
[0., 0., 22.3408205496026]]
)

def _ReadLines( path ):
	"""Return all lines of the text file at `path`; the file is closed before returning."""
	with open( path, 'r' ) as handle:
		return handle.readlines()

def _LoadRuns( template, runs, nbeads=None ):
	"""
	Read the trajectory files of every run in the inclusive range runs[0]..runs[1].

	`template` is a str.format pattern taking the run number and, when `nbeads`
	is given, the bead number as second argument.
	Returns a list with one entry per run (entry 0 belongs to run runs[0]):
	the list of lines of the file when `nbeads` is None, otherwise a list
	with one list of lines per bead.
	"""
	data = []
	for run in range( runs[0], runs[1]+1 ):
		if nbeads is None:
			data.append( _ReadLines( template.format( run ) ) )
		else:
			data.append( [ _ReadLines( template.format( run, bead ) ) for bead in range(nbeads) ] )
	return data

# Per-bead surface data, indexed as [run][bead][line]
PosSurfDat = _LoadRuns( 'DynamicsINPUTS/Surface/{:02d}/simulation.pos_bead_{:02d}.xyz', NSurf, NBeads )
VelSurfDat = _LoadRuns( 'DynamicsINPUTS/Surface/{:02d}/simulation.vel_bead_{:02d}.xyz', NSurf, NBeads )
# Disabled: the surface centroid data is not used further down in this script.
#PosSurfCentroidDat = []
#VelSurfCentroidDat = []
#for i in range(NSurf):
#	PosSurfCentroidDat.append( open('DynamicsINPUTS/Surface/{:02d}/simulation.pos_centroid.xyz'.format( i ), 'r').readlines() )
#	VelSurfCentroidDat.append( open('DynamicsINPUTS/Surface/{:02d}/simulation.vel_centroid.xyz'.format( i ), 'r').readlines() )

# Per-bead molecule data, indexed as [run][bead][line]
PosMolDat = _LoadRuns( 'DynamicsINPUTS/Molecule/{:02d}/simulation.pos_bead_{:02d}.xyz', NMol, NBeads )
VelMolDat = _LoadRuns( 'DynamicsINPUTS/Molecule/{:02d}/simulation.vel_bead_{:02d}.xyz', NMol, NBeads )
# Centroid molecule data, indexed as [run][line]
PosMolCentroidDat = _LoadRuns( 'DynamicsINPUTS/Molecule/{:02d}/simulation.pos_centroid.xyz', NMol )
VelMolCentroidDat = _LoadRuns( 'DynamicsINPUTS/Molecule/{:02d}/simulation.vel_centroid.xyz', NMol )

# Conversion factors from NIST Database
# Naming: a2b is the factor that multiplies a value in unit a to express it in unit b
# (e.g. mass_in_amu * amu2kg, energy_in_J * J2eV, as used in the velocity-distribution code).
# Only J2eV, amu2kg and ev2kjmol are referenced by the code visible in this script.
au2Ang=0.52917721092
amu2au=1822.88839
fs2au=41.341373337
au2eV=27.21138505
J2eV=6.24181 * 10**18
Hz2eV=4.13558 * 10**-15
amu2kg=1.66053892 * 10**-27
# presumably wavenumbers (cm^-1) to Joule - confirm
wn2J=1.98630 * 10**-23
eV2wn=8065.54445
ev2kjmol=96.4853365
# NOTE(review): the trailing digits "29" of H look like the CODATA 2010 uncertainty
# in 6.62606957(29)e-34 appended to the value - confirm before using H.
H=6.6260695729 * 10**-34       # Planck's Constant ( m^2 * kg / s )
KB=1.3806503 * 10**-23         # Boltzmann Constant ( presumably J / K - confirm )

#****** V E L O C I T Y   D I S T R I B U T I O N *************************************************
#
#      F. Nattino, 05/2011
#
#     Modified to python by N. Gerrits, 01/2019
#
#     Select the incidence energy according to the flux-weighted velocity distribution,
#      as in Michelsen et al., J. Chem. Phys. 94, 7502 (1991)
#    
#     G(E)dE = ( 1 / N ) * E * exp[ - 4Es * ( SQRT(E) - SQRT(Es) ) ** 2 / DelEs**2 ] dE
#
#     DelEs / Es = 2 DelVs / Vs = 2 / S
#
#     N = (DelEs / 2 S)**2 * { ( 1 + S**2 ) * exp( -S**2) + 
#         SQRT(PI) * S * ( 1.5 + S**2 )[ 1 + ERF(S) ] }
#
#     Units: Vs and DelVs in m/s, Etransmax in eV; the mass is passed in amu and converted to kg internally
#
#     USE: Constants (J2eV and ev2kjmol) and Random

def SetupVelocityDistribution(v0, alpha, mass):
	"""
	Set up the flux-weighted translational energy distribution.

	Parameters:
		v0:    stream velocity (m/s)
		alpha: velocity width (m/s)
		mass:  molecular mass (amu)

	Returns:
		The class `vel`, used as a plain attribute container. Besides the
		settings below it carries Estream, DelEstream, SqEstream (energy at the
		stream velocity, its width and its square root, in eV), the
		normalization factor FacNor and the maximum Gmax of the normalized
		distribution, which SelectVelocityFromDistribution uses for
		rejection sampling.

	Side effects (when vel.output is true):
		Writes the normalized distribution to 'VelocityDistribution.dat' in the
		current directory and prints a summary.

	Raises:
		SystemExit(1) if the numerically integrated norm on [0, Emax) deviates
		from 1 by more than 1e-4.
	"""
	from math import erf

	class vel:
		"""Translational information"""
		pass
	# Translational parameters
	vel.dist	= True	# Use a Velocity distribution for normal translational energy
	vel.stream	= v0	# Stream Velocity (m/s)
	vel.width	= alpha	# Velocity width (m/s)
	vel.Emax	= 4.5	# Maximum translational energy considered (eV)
	vel.shift	= [True, 0.0410]	# Shift velocity distribution. If true, by how much (eV)
	vel.output	= True	# If True, the velocity distribution will be written to a file (VelocityDistribution.dat) and a summary to the screen,  and a normalization check is performed
	vel.mass	= mass	# Molecular mass (amu)

	vel.Estream = 0.5 * vel.mass*amu2kg * vel.stream**2. * J2eV
	vel.DelEstream = 2. * vel.Estream * vel.width / vel.stream
	vel.SqEstream = np.sqrt( vel.Estream )

	srma = vel.stream / vel.width
	# Normalization factor
	vel.FacNor = ( vel.DelEstream /( 2. * srma) )**2. * (( 1. + srma**2. )* np.exp( -1. * srma**2. ) + np.sqrt( np.pi ) * srma * ( 1.5 + srma**2. ) * ( 1. + erf( srma ) ) )
	# Energy at which the distribution has its maximum, and the value of the normalized distribution there
	EGmax = (( vel.DelEstream**2. ) + 2. * ( vel.Estream**2. ) + 2. * vel.Estream * np.sqrt( ( vel.Estream**2. ) + ( vel.DelEstream ** 2.)))/(4. * vel.Estream)
	Expon = 4.* vel.Estream * ( np.sqrt(EGmax) - vel.SqEstream )**2. / vel.DelEstream**2.
	Expfac = np.exp( -1.* Expon)

	vel.Gmax = EGmax / vel.FacNor * Expfac

	# Distribution is checked and an output is written
	if vel.output:
		StepIntegration = 100000
		DelE = vel.Emax / float( StepIntegration )
		Eaver = 0.
		Dnorm = 0.
		# The context manager guarantees the file is flushed and closed, also on errors
		with open('VelocityDistribution.dat', 'w') as Edist:
			# The header needs its own line, otherwise the first data row ends up inside the comment
			Edist.write( '# Normalized Translational Energy Distribution. E in eV.\n' )
			# Rectangle-rule integration of G(E) and E*G(E) on [0, Emax)
			for i in range( StepIntegration ):
				E = float(i) * DelE
				Expon = 4. * vel.Estream * (np.sqrt( E ) - vel.SqEstream)**2. / vel.DelEstream **2.
				Expfac = np.exp( - Expon )
				GE = Expfac * E
				Eaver += GE * E * DelE
				Dnorm += GE * DelE
				Edist.write('{:f}\t{:f}\n'.format( E, (GE / vel.FacNor) ) )

		Eaver = Eaver / vel.FacNor
		Dnorm = Dnorm / vel.FacNor

		if vel.shift[0]: print( ' Energy is shifted by {:f} eV'.format( vel.shift[1] ) )
		print( ' Energy corresponding to Stream Velocity:   E0 (eV) = {:f}'.format( vel.Estream ) )
		print( ' DeltaE (eV) = {:f}'.format( vel.DelEstream ) )
		print( ' Mass (amu) = {:f}'.format( vel.mass ) )
		print( ' Energy at Maximum Distribution (eV) = {:f}'.format( EGmax ) )
		print( ' The average collision energy is (eV) = {:f}'.format( Eaver ) )
		print( ' The average collision energy is (kJ/mol) = {:f}'.format(  Eaver * ev2kjmol ) )
		print( ' Verify normalization of the distribution: {:12.10f}'.format( Dnorm ) )

		if (abs( Dnorm - 1. ) > 1.E-4 ):
			print( 'WARNING: Increase Etransmax' )
			print( ' Normalization of the distribution: {:f}'.format( Dnorm ) )
			# Same exception as exit(1), but does not rely on the site module providing `exit`
			raise SystemExit(1)

	return vel

def SelectVelocityFromDistribution( vel ):
	"""
	Draw one translational velocity (m/s) for the container returned by
	SetupVelocityDistribution.

	With vel.dist switched off the energy at the stream velocity is used;
	otherwise an energy in [0, vel.Emax) is drawn by rejection sampling
	against the normalized distribution, whose maximum is vel.Gmax.
	If vel.shift[0] is set, vel.shift[1] (eV) is added to the energy before
	it is converted to a velocity.
	"""
	if vel.dist == False:
		Etrans = vel.Estream
	else:
		while True:
			# Candidate energy, uniform in [0, Emax)
			Etrans = vel.Emax * np.random.random_sample()
			offset = np.sqrt( Etrans ) - vel.SqEstream
			exponent = 4. * vel.Estream * offset**2. / vel.DelEstream**2.
			density = Etrans / vel.FacNor * np.exp( -exponent )
			# Accept the candidate once a uniform draw in [0, Gmax) does not exceed its density
			threshold = vel.Gmax * np.random.random_sample()
			if not threshold > density:
				break

	if vel.shift[0] == True:
		Etrans += vel.shift[1]

	# E = 1/2 m v^2, with E converted from eV to J and m from amu to kg
	return np.sqrt( 2. * Etrans / J2eV / ( vel.mass * amu2kg ) )

def RotationalMatrix_ZYZ( phi, theta, psi ):
	"""
	Return the 3x3 rotation matrix for the Euler angles (phi, theta, psi), in radians.

	scipy.spatial.transform.Rotation would do the same, but it is only
	available in SciPy 1.2 and newer, so the matrix is written out by hand.

	ZYZ convention:
	   counterclockwise rotation about z axis by phi (xyz -> x'y'z');
	   counterclockwise rotation about y' axis by theta (x'y'z' -> x''y''z'');
	   counterclockwise rotation about z'' axis by psi (x''y''z'' -> XYZ ).
	"""
	cPhi, sPhi = np.cos(phi), np.sin(phi)
	cTheta, sTheta = np.cos(theta), np.sin(theta)
	cPsi, sPsi = np.cos(psi), np.sin(psi)

	return np.array( [
		[ cTheta*cPhi*cPsi - sPhi*sPsi, cPhi*sPsi + cTheta*cPsi*sPhi, -1.*cPsi*sTheta ],
		[ -1.*sPhi*cPsi - cTheta*sPsi*cPhi, cPsi*cPhi - cTheta*sPhi*sPsi, sPsi*sTheta ],
		[ sTheta*cPhi, sTheta*sPhi, cTheta ],
	] )

def ComputeCOM( Pos, Masses ):
	"""
	Return the mass-weighted mean position as an array of three components.

	Pos is indexed as Pos[:, axis] (one row per atom), Masses holds one
	mass per atom in the same order.
	"""
	weighted = [ sum( Pos[:,axis] * Masses ) for axis in range(3) ]
	return np.array( weighted, dtype=float ) / sum( Masses )

def ComputeCOMVelocity( Vel, Masses ):
	"""
	Return the mass-weighted mean velocity as an array of three components.

	Vel is indexed as Vel[:, axis] (one row per atom); the result is in the
	same units as the velocities supplied.
	"""
	momentum = [ sum( Vel[:,axis] * Masses ) for axis in range(3) ]
	return np.array( momentum, dtype=float ) / sum( Masses )

def ComputeInertiaTensor( positions, masses ):
	"""
	Return the 3x3 inertia tensor of point masses at the given positions.

	positions[k] holds the x, y, z coordinates of atom k and masses[k] its
	mass; the tensor is taken about the origin of those coordinates.
	"""
	tensor = np.zeros((3,3))

	for atom, mass in enumerate( masses ):
		r = positions[atom]
		for row in range(3):
			for col in range(3):
				if row != col:
					# Off-diagonal element: minus the product of the two coordinates
					tensor[row][col] -= mass*r[row]*r[col]
					continue
				# Diagonal element: squares of the two other coordinates,
				# i.e. (1,2) for row 0, (2,0) for row 1 and (0,1) for row 2
				a = (row + 1) % 3
				b = (row + 2) % 3
				tensor[row][col] += mass*( r[a]**2.0 + r[b]**2.0 )

	return tensor

def ComputeAngularMomentum( pos, vel, masses ):
	"""
	Return the total angular momentum, sum over atoms of m * (r x v), as an
	array of three components, for the atoms listed in masses.
	"""
	total = np.zeros(3)

	for atom, mass in enumerate( masses ):
		x, y, z = pos[atom][0], pos[atom][1], pos[atom][2]
		vx, vy, vz = vel[atom][0], vel[atom][1], vel[atom][2]
		total[0] += mass*( y*vz - z*vy )
		total[1] += mass*( z*vx - x*vz )
		total[2] += mass*( x*vy - y*vx )

	return total

def ComputeAngularVelocity( Inertia_, AngularMom ):
	"""
	Return the product of the 3x3 matrix Inertia_ with the 3-vector AngularMom.

	The caller in this script passes the inverse inertia tensor, so the
	result is the angular velocity.
	"""
	omega = np.zeros(3)

	for row in range(3):
		omega[row] = sum( ( Inertia_[row,col] * AngularMom[col] for col in range(3) ), 0. )

	return omega

def ComputeRotationalVelocity( Pos, Vel, Masses ):
	"""
	Return, for every atom, the velocity omega x r of the rigid rotation that
	carries the angular momentum of the supplied positions and velocities.

	omega is obtained by applying the inverse inertia tensor to the angular
	momentum; the result has one row of three components per entry of Masses.
	"""
	momentum = ComputeAngularMomentum( Pos, Vel, Masses )
	inverseInertia = np.linalg.inv( ComputeInertiaTensor( Pos, Masses ) )
	wx, wy, wz = ComputeAngularVelocity( inverseInertia, momentum )

	natoms = len(Masses)
	RotationalVel = np.zeros((natoms,3))
	for atom in range(natoms):
		x, y, z = Pos[atom][0], Pos[atom][1], Pos[atom][2]
		# Cross product omega x r, component by component
		RotationalVel[atom,0] = wy*z - wz*y
		RotationalVel[atom,1] = wz*x - wx*z
		RotationalVel[atom,2] = wx*y - wy*x

	return RotationalVel

vel = SetupVelocityDistribution(v0, alpha, Mass)	# Rather set up the velocity distribution only once

# Layout of one snapshot in the input trajectories, as implied by the line indices used below:
# 52 lines per snapshot, 45 surface atoms starting at line 2 and 5 molecule atoms starting at line 47.
_FrameLines = 52
_SurfFirstLine = 2
_NSurfAtoms = 45
_MolFirstLine = 47
_NMolAtoms = 5
_MolLabels = ( ' C', ' H', ' H', ' H', ' H' )	# Label written for each molecule atom, in file order

def _ReadAtoms( lines, first, count ):
	"""Parse columns 1-3 of `count` consecutive lines, starting at index `first`, into a (count, 3) float array."""
	rows = []
	for line in lines[first:first + count]:
		fields = line.split()	# Split every line only once
		rows.append( [ float(fields[1]), float(fields[2]), float(fields[3]) ] )
	return np.array( rows )

def _WriteFrame( out, header, SurfXYZ, MolXYZ ):
	"""Write one 50-atom frame (surface atoms followed by the molecule) to the open file `out`."""
	out.write('50\n')
	out.write( header )
	for row in SurfXYZ:
		out.write( "      Pt  {:.5e}  {:.5e}  {:.5e}\n".format( *row ) )
	for label, row in zip( _MolLabels, MolXYZ ):
		out.write( "      {:s}  {:.5e}  {:.5e}  {:.5e}\n".format( label, *row ) )

# xFIRSTJOBx and xLASTJOBx are placeholders that have to be replaced by integers before this script is run
for i in range(xFIRSTJOBx, xLASTJOBx + 1):
	JobDir = '{0:06d}'.format( i )
	os.makedirs( JobDir, exist_ok=True )	# No separate existence check needed

	# Random selection of the surface and molecular snapshots
	NSurface = random.randint( NSurf[0], NSurf[1] )
	tSurface = random.randint( tSurf[0], tSurf[1] )
	NMolecule = random.randint( NMol[0], NMol[1] )
	tMolecule = random.randint( tMol[0], tMol[1] )
	# The data lists are filled starting at the first run number, so the
	# run number has to be taken relative to it when used as a list index
	SurfRun = NSurface - NSurf[0]
	MolRun = NMolecule - NMol[0]
	SurfStart = tSurface*_FrameLines + _SurfFirstLine
	MolStart = tMolecule*_FrameLines + _MolFirstLine

	# A random rotation in the ZYZ convention according to J=0
	# NOTE(review): Theta is drawn uniformly in [0, pi); an isotropic orientation would
	# need cos(Theta) to be uniform instead - confirm that this is intended.
	Phi = 2. * np.pi * np.random.random_sample()
	Psi = 2. * np.pi * np.random.random_sample()
	Theta = np.pi * np.random.random_sample()
	RotationalMatrix = RotationalMatrix_ZYZ( Phi, Theta, Psi )

	dXY = np.dot( [np.random.random_sample(), np.random.random_sample(), 0.], Cell )	# Make sure that the entire supercell is sampled

	# Ensure that the centroid COM contains only internal motion, i.e., the COM and rotational velocity are zero
	PosMolCentroid = _ReadAtoms( PosMolCentroidDat[MolRun], MolStart, _NMolAtoms )
	VelMolCentroid = _ReadAtoms( VelMolCentroidDat[MolRun], MolStart, _NMolAtoms )
	VelMolCOM = ComputeCOMVelocity( VelMolCentroid, Masses )
	# NOTE(review): the rotational velocity is computed from the positions and velocities as read,
	# i.e. about the coordinate origin and including any COM drift - confirm that the molecule
	# trajectories are centred, otherwise COM-relative quantities should be used here.
	RotationalVelMol = ComputeRotationalVelocity( PosMolCentroid, VelMolCentroid, Masses )
	COM = ComputeCOM( PosMolCentroid, Masses )

	# Obtain the COM velocity from the velocity distribution
	vCOM = SelectVelocityFromDistribution( vel )

	# The context manager closes both output files even if a bead fails to be processed
	with open('{:06d}/simulation.pos_beads.xyz'.format( i ), 'w') as f, open('{:06d}/simulation.vel_beads.xyz'.format( i ), 'w') as g:
		for i2 in range(NBeads):
			# Preparation of the molecule
			PosMol = _ReadAtoms( PosMolDat[MolRun][i2], MolStart, _NMolAtoms )
			VelMol = _ReadAtoms( VelMolDat[MolRun][i2], MolStart, _NMolAtoms )

			# Ensure that the COM contains only internal motion, i.e., the velocity of the COM is zero
			VelMol -= VelMolCOM
			# Ensure that the molecule contains only vibrational motion, i.e., rotational motion is removed
			VelMol -= RotationalVelMol

			PosMol -= COM	# The centroid COM first need to be shifted to (0,0,0) before the rotation, and can be added back after the rotation
			PosMol = np.matmul( PosMol, RotationalMatrix )
			PosMol += COM + dXY
			VelMol = np.matmul( VelMol, RotationalMatrix )

			# Give the molecule its incidence velocity along -z
			VelMol[:,2] -= vCOM

			PosSurf = _ReadAtoms( PosSurfDat[SurfRun][i2], SurfStart, _NSurfAtoms )
			VelSurf = _ReadAtoms( VelSurfDat[SurfRun][i2], SurfStart, _NSurfAtoms )

			# Start writing the i-PI input file
			_WriteFrame( f, '# CELL(abcABC):    8.56093     8.56093    22.34082    90.00000    90.00000   120.00000  Step:           0  Bead:      {:2d} positions{{angstrom}}  cell{{angstrom}}\n'.format( i2 ), PosSurf, PosMol )
			_WriteFrame( g, '# CELL(abcABC):    8.56093     8.56093    22.34082    90.00000    90.00000   120.00000  Step:           0  Bead:      {:2d} velocities{{m/s}}  cell{{angstrom}}\n'.format( i2 ), VelSurf, VelMol )
