#!/usr/bin/python

import numpy as np    
from scipy import linalg as LA 
from scipy import integrate 
from scipy.optimize import brenth
from random import uniform, random
np.set_printoptions(suppress=True, precision=5)

# ==============================================================================
# Universal constants  
# ==============================================================================

# for more information: http://docs.scipy.org/doc/scipy/reference/constants.html
from scipy.constants import physical_constants as phc
# Each physical_constants entry is a (value, unit, uncertainty) tuple; only the
# numerical value (index 0) is kept.
atmas, eV2J, Bc, lsp, H2J = (
    phc[key][0]
    for key in (
        'atomic mass constant',               # atmas
        'electron volt-joule relationship',   # eV2J: multiply an energy in eV to get J
        'Boltzmann constant',                 # Bc
        'speed of light in vacuum',           # lsp
        'Hartree energy',                     # H2J
    )
)
hbar = phc['Planck constant'][0] / (2*np.pi)  # Planck constant divided by 2*pi

# ==============================================================================
# INPUT ------------------------------------------------------------------------        
# ==============================================================================

# load input: the entries are read positionally; text after '#' is ignored
data = np.genfromtxt('Input/input.txt', comments="#", skip_header=0,
                     invalid_raise=False, loose=True, dtype=float)

# -- molecule ------------------------------------------------------------------
# the two atomic masses, each scaled by the atomic mass constant
mass1, mass2 = data[0] * atmas, data[1] * atmas

# -- beam conditions and numerical set-up ---------------------------------------
Tn = data[2]                                  # nozzle temperature
# M: number of grids for the finite difference
# Nvs: maximum number of vibrational states
# Jmax: maximum angular momentum quantum number
M, Nvs, Jmax = int(data[3]), int(data[4]), int(data[5])
et = data[6]                                  # threshold energy for the harmonic fit in eV
Ek = data[7] * eV2J                           # initial kinetic energy, converted from eV to Joule
NoCT = int(data[8])                           # number of classical trajectories
dEoE = data[9]                                # translational energy distribution broadening

# -- ortho/para weighting -------------------------------------------------------
orpa = int(data[10])                          # ortho-para state  1="Yes", 0="No"
wje, wjo = data[11], data[12]                 # ortho-para ratio for even J / for odd J

# -- velocity distribution and state selection ----------------------------------
vel_dist = int(data[13])                      # velocity distribution in the molecular beam   0:"E_k=<E>", 1:"f(E_k)=max(f)"
# STR: state-resolved calculations 1="Yes", 0="No"; v_STR, j_STR: the selected v and j
STR, v_STR, j_STR = int(data[14]), int(data[15]), int(data[16])
vdis_YN = int(data[17])                       # are alpha and v0 specified? 1="Yes", 0="No"
alphb, vstr = data[18], data[19]              # alpha (m/s), stream velocity (m/s)
rotT = data[20]                               # rotational temperature (as a fraction of Tn)

# load interaction potential: first column -> rnew, second column -> Vnew
data = np.loadtxt('Output/Intpot.txt', dtype=float)
rnew, Vnew = data[:, 0], data[:, 1]

# load rovibrational energy eigenvalues and scale them by eV2J
Erv = np.loadtxt('Output/Erov.txt', dtype=float) * eV2J

# first entry of the system-parameter file; it sets the vibrational time step
# further down (presumably the vibrational angular frequency - TODO confirm)
w0 = np.loadtxt('Output/Syspar.txt', dtype=float)[0]

# ==============================================================================
# Define required functions     
# ==============================================================================

# clear screen
import os
def cls():
    """Clear the terminal: run 'cls' when os.name is 'nt', 'clear' otherwise."""
    if os.name == 'nt':
        command = 'cls'
    else:
        command = 'clear'
    os.system(command)


cls()

# Boltzmann distribution for rovibrational states
def Fb(v,J,Tn):
    """Return the unnormalised Boltzmann weight of the rovibrational state (v, J).

    The weight is the product of
      * the ortho/para factor: module-level ``wje`` for even J, ``wjo`` for odd J,
      * the degeneracy 2J+1,
      * exp(-(Erv[v,J]-Erv[v,0]) / (Bc*rotT*Tn)), and
      * exp(-Erv[v,0] / (Bc*Tn)).

    ``Erv``, ``Bc``, ``rotT``, ``wje`` and ``wjo`` are read from module scope.
    """
    kT_vib = Bc*Tn
    kT_rot = Bc*rotT*Tn
    spin_weight = wje if J%2 == 0 else wjo
    e_j0 = Erv[v,0]
    rot_factor = np.exp(-(Erv[v,J]-e_j0)/kT_rot )
    vib_factor = np.exp(-e_j0/kT_vib )
    return spin_weight*(2*J+1)*rot_factor*vib_factor

# Match point finder
def Mpf(v,vec):
    """Return the index of the entry of ``vec`` that lies closest to ``v``.

    ``vec`` may be a NumPy array or any sequence of numbers.  When several
    entries are equally close, the first (lowest) index is returned.
    """
    # np.argmin scans the distances once in compiled code and returns the first
    # minimum, replacing the builtin min() over the array plus a second pass.
    distance = np.abs(np.asarray(vec) - v)
    return np.argmin(distance)

# angle finder
def Angle(nx,ny):
    """Return the azimuthal angle of the vector (nx, ny) in degrees.

    The result lies in [-90, 270): arctan(ny/nx) gives (-90, 90) for nx > 0,
    and 180 degrees are added when nx < 0.  On the y axis (nx == 0), where
    ny/nx is undefined, the result is +90 for ny > 0, -90 for ny < 0 and 0
    for the null vector.
    """
    nx = float(nx)
    ny = float(ny)
    if nx == 0.:
        # float division by zero would raise ZeroDivisionError here
        return 90.*np.sign(ny)
    ang = np.arctan(ny/nx)*180/np.pi
    # arctan only covers the right half-plane; shift vectors pointing to the left
    if nx < 0:
        ang = ang+180
    return ang

# velocity distribution function in the beam
def veldist(x):
    """Evaluate the unnormalised beam velocity distribution at ``x``.

    f(x) = x**power * exp(-(x - vstr)**2 / alphb**2), with ``power``, ``vstr``
    and ``alphb`` read from module scope.  ``x`` may be a scalar or an array.
    """
    offset, scale = 0., 1.
    envelope = np.exp( -(x-vstr)**2/alphb**2 )
    return offset + scale * x**power * envelope

def rootvel(x):
    """Return veldist(x) minus ``coeff`` times veldist at ``velm_geuss``.

    The roots are the velocities where the distribution has dropped to the
    fraction ``coeff`` of its value at ``velm_geuss``; both names are read
    from module scope.
    """
    reference_level = coeff*veldist(velm_geuss)
    return veldist(x) - reference_level

def analytic_dist(x):
    """Return a closed-form energy expression for width ``x`` minus Ek, in eV.

    The expression depends on the module-level ``vstr``, ``mass1``, ``mass2``
    and ``Ek``.  Its root in ``x`` is what the script assigns to ``alphb``.
    NOTE(review): presumably the mean translational energy of the beam
    distribution - confirm against the derivation.
    """
    root_pi = np.sqrt(np.pi)
    total_mass = mass1+mass2
    distavg = vstr*root_pi*x*(3./2.*x**2 + vstr**2)
    weighted = (7./2.*x**2+vstr**2)*distavg - 3./2.*x**5*vstr*root_pi
    return 0.5*total_mass/eV2J * weighted / distavg  -  Ek/eV2J

# ==============================================================================
# Boltzmann distribution - random sampling of reactive states
# ==============================================================================
# Erv is indexed as Erv[v, J] everywhere below, so make it two-dimensional
n_levels = np.size(Erv)
if n_levels == 1:
    # a single eigenvalue becomes a 1x1 table
    Erv = np.array([[Erv]])
elif n_levels == np.shape(Erv)[0]:
    # as many entries as the first axis is long (e.g. a flat list of levels):
    # wrap it in a new leading axis and transpose
    Erv = np.array([Erv]).T

# ortho-para states for N2 which is similar to D2; see p. 5 in J. Chem. Phys. 140, 084702 (2014).
# Without ortho/para statistics both parity weights read by Fb() must be unity.
# (The odd-J weight was previously written to a misspelt name, "wj0", so the
# value of wjo from the input file stayed in effect.)
if orpa==0:
   wje = 1.
   wjo = 1.

if STR==0:

    # all (v, J) pairs, in the order that defines the flat state index
    states = [(v, J) for v in range(0,Nvs) for J in range(0,Jmax+1)]

    # normalisation constant: sum of the weights of all states
    Fbsum = sum(Fb(v,J,Tn) for (v, J) in states)

    n_entries = np.size(Erv)
    PB = np.zeros(np.shape(Erv))          # normalised weight of each (v, J)
    jmat = np.zeros(n_entries)            # J of the state with flat index k
    vmat = np.zeros(n_entries)            # v of the state with flat index k
    # running sum of PB: randsam[0] = 0 and randsam[k] is the total weight of
    # the first k states (used for random sampling between 0 and 1)
    randsam = np.zeros(n_entries+1)
    for ind, (v, J) in enumerate(states, 1):
        jmat[ind-1] = J
        vmat[ind-1] = v
        PB[v,J] = Fb(v,J,Tn)/Fbsum
        randsam[ind] = randsam[ind-1]+PB[v,J]

#---------------------------------------------------------------------------------

if STR==1:

    # state-resolved run: every trajectory starts in (v_STR, j_STR)
    PB=np.zeros((np.shape(Erv)))
    jmat=np.zeros((np.size(Erv)))
    vmat=np.zeros((np.size(Erv)))
    probr=np.zeros((np.size(Erv)))        # number of trajectories per state
    ind=0
    for v in range(0,Nvs):
       for J in range(0,Jmax+1):
           jmat[ind]=J
           vmat[ind]=v
           if v==v_STR and J==j_STR:
               PB[v,J]=1.
               probr[ind]=NoCT
           ind=ind+1 

else:

    # Draw one state per trajectory from the cumulative weights in randsam.
    # Only the first Nvs*(Jmax+1)+1 entries of randsam hold cumulative values;
    # state k is selected when randsam[k] <= s < randsam[k+1].
    nstates = Nvs*(Jmax+1)
    cumulative = randsam[1:nstates+1]
    last_state = max(nstates-1, 0)
    probr=np.zeros((np.size(randsam)))    # number of trajectories per state
    for i in range(0,NoCT):
        s=uniform(0,1)
        # searchsorted(side='right') counts the cumulative values <= s.  The
        # clamp keeps a draw at or above the last cumulative value (possible
        # through round-off) in the last state instead of an unused slot.
        ind=min(int(np.searchsorted(cumulative, s, side='right')), last_state)
        probr[ind]=probr[ind]+1     


# For every state, print v, J, the weight PB[v,J] and the fraction of the NoCT
# trajectories assigned to that state (probr/NoCT).
print "vib qn., rot qn., expected Boltzmann weight, randomly-generated BW: \n"   
for i in range(0,len(jmat)):
     print ("\t {0:5d} {1:5d} \t {2:12.10f} {3:12.10f}". format(int(vmat[i]), int(jmat[i]), PB[int(vmat[i]),int(jmat[i])], probr[i]/NoCT))  

# ==============================================================================
# translational energy - velocity
# ==============================================================================

# Determine the parameters of the beam velocity distribution
#     f(v) = v**power * exp(-(v - vstr)**2 / alphb**2)
# together with the two velocities vela < velm_geuss < velb at which f has
# dropped to the fraction coeff of f(velm_geuss).
# NOTE: veldist(), rootvel() and analytic_dist() read power, vstr, alphb, coeff
# and velm_geuss from module scope, so the order of the assignments below matters.
if vdis_YN == 0:

     # alpha and v0 are not given: derive them from Ek and dEoE
     velm_geuss = np.sqrt( 2.*Ek/(mass1+mass2) )  
    
     if vel_dist == 1:
         # alphb follows from requiring df/dv = 0 at v = velm_geuss
         power = 3
         vstr = 0.7*velm_geuss                                 # initial guess for the stream velocity
         dvstr = 1e-5*velm_geuss                               # step size for determining the stream velocity 
         alphb = np.sqrt( (velm_geuss**2 - vstr*velm_geuss)/(power/2.) )  
       
         coeff = 0.5                                       # FWHM 
         vela = brenth( rootvel, 0., velm_geuss, args=(), xtol=dvstr*1e-9, rtol=dvstr*1e-11, maxiter=100, full_output=False, disp=True )    
         velb = brenth( rootvel, velm_geuss, 3.*velm_geuss, args=(), xtol=dvstr*1e-9, rtol=dvstr*1e-11, maxiter=100, full_output=False, disp=True )    
         # sign of (relative width - dEoE) for the initial guess
         sign = np.sign ( (velb**2 - vela**2)/velm_geuss**2 - dEoE  )
       
         # raise vstr in steps of dvstr until that sign flips, i.e. until the
         # relative width (velb^2 - vela^2)/velm_geuss^2 crosses dEoE
         swchv = 0
         while vstr < 1.3*velm_geuss and swchv == 0:
            vstr = vstr + dvstr
            alphb = np.sqrt( (velm_geuss**2 - vstr*velm_geuss)/(power/2.) )
            vela = brenth( rootvel, 0., velm_geuss, args=(), xtol=dvstr*1e-9, rtol=dvstr*1e-11, maxiter=100, full_output=False, disp=True )    
            velb = brenth( rootvel, velm_geuss, 3.*velm_geuss, args=(), xtol=dvstr*1e-9, rtol=dvstr*1e-11, maxiter=100, full_output=False, disp=True )    
            if np.sign( (velb**2 - vela**2)/velm_geuss**2 - dEoE ) != sign:
               # values are divided by 100 for printing and labelled A/ps
               print "[alpha, vmax, v0, vb, va] = ", [alphb/100, velm_geuss/100, vstr/100, velb/100, vela/100], "(A/ps)"
               swchv = 1
     else:
         # same scan over vstr, but alphb is the root of analytic_dist(),
         # bracketed between half and twice the peak-condition estimate
         power = 3
         vstr = 0.7*velm_geuss                                 # initial guess for the stream velocity
         dvstr = 1e-5*velm_geuss                               # step size for determining the stream velocity 
         alph_guess = np.sqrt( (velm_geuss**2 - vstr*velm_geuss)/(power/2.) )  
         alphb = brenth( analytic_dist, 0.5*alph_guess, 2.*alph_guess, args=(), xtol=dvstr*1e-9, rtol=dvstr*1e-11, maxiter=100, full_output=False, disp=True )   
    
         coeff = 0.5                                       # FWHM 
         vela = brenth( rootvel, 0., velm_geuss, args=(), xtol=dvstr*1e-9, rtol=dvstr*1e-11, maxiter=100, full_output=False, disp=True )    
         velb = brenth( rootvel, velm_geuss, 3.*velm_geuss, args=(), xtol=dvstr*1e-9, rtol=dvstr*1e-11, maxiter=100, full_output=False, disp=True )    
         sign = np.sign ( (velb**2 - vela**2)/velm_geuss**2 - dEoE  )
       
         swchv = 0
         while vstr < 1.3*velm_geuss and swchv == 0:
            vstr = vstr + dvstr
            alph_guess = np.sqrt( (velm_geuss**2 - vstr*velm_geuss)/(power/2.) )
            alphb = brenth( analytic_dist, 0.5*alph_guess, 2.*alph_guess, args=(), xtol=dvstr*1e-9, rtol=dvstr*1e-11, maxiter=100, full_output=False, disp=True )   
            vela = brenth( rootvel, 0., velm_geuss, args=(), xtol=dvstr*1e-9, rtol=dvstr*1e-11, maxiter=100, full_output=False, disp=True )    
            velb = brenth( rootvel, velm_geuss, 3.*velm_geuss, args=(), xtol=dvstr*1e-9, rtol=dvstr*1e-11, maxiter=100, full_output=False, disp=True )    
            if np.sign( (velb**2 - vela**2)/velm_geuss**2 - dEoE ) != sign:
               print "[alpha, vmax, v0, vb, va] = ", [alphb/100, velm_geuss/100, vstr/100, velb/100, vela/100], "(A/ps)"
               swchv = 1

else: 
      # alphb and vstr come from the input file; velm_geuss is the positive
      # root of v**2 - vstr*v - power*alphb**2/2 = 0 (the df/dv = 0 condition)
      power = 3
      velm_geuss = vstr/2. + 1./2.*np.sqrt(vstr**2 + 2*power*alphb**2 )  
      dvstr = 1e-5*velm_geuss                           # step size for determining vb and va 
      coeff = 0.5                                       # FWHM 
      vela = brenth( rootvel, 0., velm_geuss, args=(), xtol=dvstr*1e-9, rtol=dvstr*1e-11, maxiter=100, full_output=False, disp=True )    
      velb = brenth( rootvel, velm_geuss, 3.*velm_geuss, args=(), xtol=dvstr*1e-9, rtol=dvstr*1e-11, maxiter=100, full_output=False, disp=True )    
   
# Sample one translational velocity per trajectory from veldist(), restricted
# to the window [v_min, v_max] where it exceeds 0.5% of its value at velm_geuss.
coeff = 0.005
v_min = brenth( rootvel, 0., velm_geuss, args=(), xtol=dvstr*1e-9, rtol=dvstr*1e-11, maxiter=100, full_output=False, disp=True )
v_max = brenth( rootvel, velm_geuss, 3.*velm_geuss, args=(), xtol=dvstr*1e-9, rtol=dvstr*1e-11, maxiter=100, full_output=False, disp=True )    

# Velocity grid with NoCT/10 points, capped at 1000.  At least one point is
# kept so that NoCT < 10 no longer produces an empty grid (which crashed below).
npts = max(int(min(NoCT//10,1000)), 1)
velrange = np.mgrid[v_min:v_max:npts*1j]
velfun = veldist(velrange)/sum( veldist(velrange) )    # discrete probabilities
Randsamp = np.cumsum( velfun )                         # cumulative distribution

# One uniform draw per trajectory, taken in the same order as before.
draws = [uniform(0,1) for i in range(0,NoCT)]
# searchsorted(side='right') gives the first grid index whose cumulative value
# exceeds the draw.  The clamp covers draws at or above Randsamp[-1], which can
# fall slightly below 1 through round-off and used to raise on an empty match.
picked = np.minimum( np.searchsorted(Randsamp, draws, side='right'), len(Randsamp)-1 )
#picked[:] = ...  (for only the exact velocity use np.sqrt( 2.*Ek/(mass1+mass2) ) instead of velrange[picked])

Vtr = 0.01*velrange[picked]                          # in A/ps units
# fraction of sampled velocities above velm_geuss, then the distances from
# velm_geuss to v_max and to v_min (all velocities divided by 100)
print 1.*len(Vtr[velm_geuss/100.<Vtr])/len(Vtr), (v_max-velm_geuss)/100., (velm_geuss-v_min)/100.
# grid average of 0.5*(mass1+mass2)*v^2 weighted by velfun, converted to eV
print "Average energy is : ",  0.5*(mass1+mass2)*sum(velrange**2*velfun)/sum(velfun)/eV2J, "(eV)" 

if vdis_YN == 1: 
     print "Energy broadening (Delta E/<E>) is :", (velb**2-vela**2)/(sum(velrange**2*velfun)/sum(velfun))    # (vb^2-va^2)/<v^2>


# ==============================================================================
# classical motion due to vibrational degree of freedom
# ==============================================================================

# Force along the bond: minus the backward finite difference of the potential,
# scaled by eV2J/1e-10 (presumably eV/Angstrom -> N; confirm the units of
# Intpot.txt).  The first grid point reuses the value of the second one.
Force = np.zeros(len(Vnew))
Force[1:] = -eV2J/1.0e-10*np.diff(Vnew)/np.diff(rnew)
Force[0] = Force[1]

Redm = mass1*mass2/(mass1+mass2)     # reduced mass
dt = 2*np.pi/(w0*1000)               # time step: 2*pi/w0 split into 1000 steps

# For each vibrational state, integrate the bond length r(t) starting at the
# grid point whose potential is closest to the J=0 level, and store the path.
Rvst = []
for v in range(0,Nvs):
    ind = Mpf( Erv[v,0]/eV2J,Vnew )
    rr = [rnew[ind]*1e-10]
    rr.append( Force[ind]*dt**2/mass1+rr[0] )  # this equation is correct for only homogeneous diatomic molecules
    swch = 0
    # Verlet recurrence r[i+1] = 2*r[i] - r[i-1] + F*dt^2/Redm.  Every reversal
    # of the direction of motion adds 2 to swch, so the loop ends after the
    # second reversal.
    while swch <= 2:
        ind = Mpf( rr[-1]*1.0e10,rnew )
        rr.append( Force[ind]*dt**2/Redm+2*rr[-1]-rr[-2] )
        swch = swch+abs(np.sign(rr[-1]-rr[-2])-np.sign(rr[-2]-rr[-3]))
    Rvst.append( rr[:-1] )               # drop the last point

# ==============================================================================
# initial velocity - rotational degree of freedom
# ==============================================================================

# coordinates and velocity
# Build, for every trajectory, the internal initial conditions of the molecule:
# bond length, orientation (theta, phi) and the velocity from rotation plus
# vibration.  One entry is appended to each list per trajectory.
# coordinates and velocity
Vx=[]    # Vx=[0]*NoCT
Vy=[]
Vz=[]
re=[]
theta=[]
phi=[]
# information of the j and v quantum numbers
JQN=[]
VQN=[]
MQN=[]
ERV=[]

count=0       # counter (not used in the code below)
for i in range(0,len(vmat)):
    v=int(vmat[i])
    j=int(jmat[i])
    # probr[i] trajectories start in state (v, j)
    for k in range(0,int(probr[i])):    
        m= -j + int(uniform(0,1)*(2*j+1))             # random integer in -j..j
        p=uniform(0,2.*np.pi)                         # angle phi for precessional motion of J vector
        pp=uniform(0,2.*np.pi)                        # angle phi-prime for the orientation of diatomic molecule perpendicular to J
        # the 1e-100 keeps the ratio finite for j = 0
        t=np.arccos( m/(1e-100+np.sqrt(j*(j+1.))) )   # angle theta for J vector
        # rotation matrix for transforming (x',y',z') in polar coordinates to (x,y,z)
        RotMat=np.array([[np.cos(t)*np.cos(p),np.cos(t)*np.sin(p),-np.sin(t)],\
                         [-np.sin(p),np.cos(p),0],\
                         [np.sin(t)*np.cos(p),np.sin(t)*np.sin(p),np.cos(t)]])
        MO=np.array([[np.cos(pp)],[-np.sin(pp)],[0]]) # diatomic molecule orientation in (x',y',z') coordinates 
        MOxyz=np.dot(LA.inv(RotMat),MO)               # diatomic molecule orientation in (x,y,z) coordinates                      
        if j==0:
             # no rotation: pick a random direction on the unit sphere, with
             # cos(THETA) uniform in [-1, 1] and PHI uniform in [0, 2*pi)
             # (an earlier variant drew a point inside the unit ball by
             # rejection sampling and normalised it)
             THETA=np.arccos( uniform(-1.0,1.0) )     # kind of weighted random number sampling
             PHI=uniform(0,2.*np.pi)
             nx=np.sin(THETA)*np.cos(PHI)
             ny=np.sin(THETA)*np.sin(PHI)
             nz=np.cos(THETA)
        else:
             nx=MOxyz[0][0]                     
             ny=MOxyz[1][0]
             nz=MOxyz[2][0]
        Jx=np.sqrt(j*(j+1.))*np.sin(t)*np.cos(p)      # in hbar units 
        Jy=np.sqrt(j*(j+1.))*np.sin(t)*np.sin(p)
        Jz=m                
        # add vibrational velocity
        rr=Rvst[v]
        # random interior point of the stored path, so rr[ind-1] and rr[ind+1] exist
        ind = int(uniform(1,len(rr)-2) + 0.5)
        Vvib=(rr[ind+1]-rr[ind-1])/(2*dt)             # central difference dr/dt
        # the matrix product is the cross product J x n; scaled by hbar/(r*mass1)
        Vvec=np.dot([[0,-Jz,Jy],[Jz,0,-Jx],[-Jy,Jx,0]],[[nx],[ny],[nz]]) * hbar/ (rr[ind]*mass1)
        re.append( rr[ind]/1e-10 )
        theta.append( np.arccos(nz)*180/np.pi )      # in degrees; arccos ranges between [0 pi]
        phi.append( Angle(nx,ny) )                   # in degrees; Angle(nx,ny) is between [-90 270)
        # half of Vvib along the bond direction is added to the rotational part
        Vx.append( 0.01*(Vvec[0][0] + Vvib*nx/2.) )  # in A/ps units
        Vy.append( 0.01*(Vvec[1][0] + Vvib*ny/2.) )  # in A/ps units
        Vz.append( 0.01*(Vvec[2][0] + Vvib*nz/2.) )  # in A/ps units
        JQN.append( j )
        VQN.append( v )
        MQN.append( m )
        ERV.append( Erv[v,j]/eV2J )  

# ==============================================================================
# Output files
# ==============================================================================

# Each file is written inside a "with" block so that it is closed even when a
# write fails part-way through.

# Initial conditions, one line per trajectory: two random numbers, bond length,
# theta, phi, then (Vx, Vy, Vz-Vtr) followed by (-Vx, -Vy, -Vz-Vtr).
# NOTE(review): presumably the velocities of the two atoms with the beam
# velocity Vtr along -z - confirm against the trajectory code.
with open('Output/ICON.txt', 'w') as g:
    for r0, th, ph, vx, vy, vz, vt in zip(re, theta, phi, Vx, Vy, Vz, Vtr):
        row = (random(), random(), r0, th, ph, vx, vy, vz-vt, -vx, -vy, -vz-vt)
        g.write( "%5.6E \t %5.6E \t %5.6E \t %5.6E \t %+5.6E \t %+5.6E \t %+5.6E \t %+5.6E \t %+5.6E \t %+5.6E \t %+5.6E\n" % row )

# Quantum numbers v, j, m and the rovibrational energy of every trajectory.
with open('Output/jv_QN_Inf.txt', 'w') as g:
    for qn in zip(VQN, JQN, MQN, ERV):
        g.write( "%5.0f \t %5.0f \t %+5.0f \t %+5.6E\n" % qn )

# Translational velocity of every trajectory that was generated.
with open('Output/veldist.txt', 'w') as g:
    for vt in Vtr[:len(Vx)]:
        g.write( "%5.8f \n" % (vt,) )



