#!/usr/bin/env python
import sys
from os import path, getcwd, popen
from os import chdir, system
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from random import choice, random
from scipy import interpolate

# Directory the script was started from; collectangles() chdir()s back here when it is done.
here=getcwd()
def collectangles(folder):
	'''
	Bin the initial (theta, p_theta) of reactive trajectories relative to all clear trajectories.

	folder must contain DynamicsINPUTS/initial_pos_momenta.info (one line per trajectory;
	field 4 is theta in radians, field 10 is p_theta in Hartree atomic units) and the
	per-trajectory sub-directories 000001, 000002, ... . A trajectory is skipped when its
	directory lacks 'outcome' or 'PostAnalysis.dat', when 'outcome' is empty, or when the
	first outcome line contains 'UNCLEAR'. It counts as reactive when that line contains
	'REACTION '.

	Every trajectory is weighted by 1/sin(theta). Returns (angles_binned, [xedges, yedges]).
	angles_binned is the weighted 2-D density of reactive trajectories minus that of all
	clear trajectories, plus 1, so 1 means "as statistical". Axis 0 bins p_theta
	(amu*Angstrom^2/fs) and axis 1 bins theta over [20, 160] degrees, with 36 x 36 bins.

	The working directory is restored to the module-level `here` even on error. The process
	exits with status 1 if a trajectory directory is missing. ValueError is raised when no
	clear or no reactive trajectory was found.
	'''
	import os
	import re

	# Same selection as `ls | egrep "^[0-9].....$"`: a digit followed by exactly five characters.
	trajdir_pattern = re.compile(r'[0-9].{5}')

	angles = []
	angles_stat = []
	chdir(folder)
	try:
		with open('DynamicsINPUTS/initial_pos_momenta.info', 'r') as fh:
			data = fh.readlines()

		Ntraj = sum( 1 for name in os.listdir('.') if trajdir_pattern.fullmatch(name) )
		here_tmp = getcwd()				# workdir
		for i in range(1, Ntraj+1):
			newdir = path.join(here_tmp, '%.6d' % i)	# trajectory directory 000001, 000002, ...

			print( newdir )
			if not path.exists(newdir):
				print( "folder", newdir, "not found. exiting..." )
				sys.exit(1)

			chdir(newdir)

			if not path.exists('outcome') or not path.exists('PostAnalysis.dat'):
				continue

			with open('outcome', 'r') as fh:
				outcome = fh.readlines()
			if not outcome or 'UNCLEAR' in outcome[0]:
				continue

			fields = data[i-1].split()
			theta = np.degrees( float( fields[4] ) )
			# bohr to angstrom, electron mass to amu, velocity in Hartree atomic units to angstrom / fs
			p_theta = float( fields[10] ) * 0.529177249 / 1822.888486209 * 21.876913
			# Undo the sin(theta) distribution of initial orientations.
			weight = 1. / float( np.sin( np.radians( theta ) ) )

			angles_stat.append( [theta, p_theta, weight] )
			if 'REACTION ' in outcome[0]:
				angles.append( [theta, p_theta, weight] )
	finally:
		chdir(here)

	if not angles_stat or not angles:
		raise ValueError('collectangles: no (reactive) trajectories found in {}'.format(folder))

	angles_stat = np.array( angles_stat )
	angles_stat_binned, xedges, yedges = np.histogram2d( angles_stat[:,1], angles_stat[:,0], range=[[min(angles_stat[:,1]), max(angles_stat[:,1])], [20., 160.]], bins=36, density=True, weights=angles_stat[:,2] )

	# Reuse the statistical bin edges so both densities live on the same grid.
	angles = np.array( angles )
	angles_binned, xedges, yedges = np.histogram2d( angles[:,1], angles[:,0], bins=[xedges, yedges], density=True, weights=angles[:,2] )

	angles_binned -= angles_stat_binned
	angles_binned += 1.

	return angles_binned, [xedges, yedges]

# Reaction-coordinate window (Angstrom) binned by collectangles_heatmap() and spanned by the
# reference line in panel (d).
s_limit = [-8.,0.]
def collectangles_heatmap(folder):
	'''
	Histogram theta along the reaction coordinate s for all reactive trajectories in folder.

	Each trajectory sub-directory (000001, 000002, ...) must contain 'outcome' and
	'PostAnalysis.dat', otherwise it is skipped. A trajectory whose first outcome line
	contains 'REACTION ' contributes r, Z and theta from analysis_rcom.dat (one header line,
	then columns 1, 2 and 8). Its s is zeroed at step t0 = round(t / 0.4), where t is the
	second field of the second outcome line. Presumably t is the reaction time and 0.4 the
	output interval (confirm).

	Returns (density, [theta_edges, s_edges]). density is a (100, 40) array over theta in
	[0, 180] degrees and s in s_limit. The working directory is restored even on error.
	Raises ValueError when no reactive trajectory was found.
	'''
	import os
	import re

	Nangle = 100
	Ns = 40
	here = getcwd()
	# Same selection as `ls | egrep "^[0-9].....$"`: a digit followed by exactly five characters.
	trajdir_pattern = re.compile(r'[0-9].{5}')

	angles = []
	chdir(folder)
	try:
		Ntraj = sum( 1 for name in os.listdir('.') if trajdir_pattern.fullmatch(name) )
		here_tmp = getcwd()				# workdir
		for i in range(1, Ntraj+1):
			newdir = path.join(here_tmp, '%.6d' % i)	# trajectory directory 000001, 000002, ...

			print( newdir )
			if not path.exists(newdir):
				print( "folder", newdir, "not found. exiting..." )
				sys.exit(1)

			chdir(newdir)

			if not path.exists('outcome') or not path.exists('PostAnalysis.dat'):
				continue

			with open('outcome', 'r') as fh:
				outcome = fh.readlines()
			if not outcome or 'REACTION ' not in outcome[0]:
				continue

			t0 = round( float( outcome[1].split()[1] ) / 0.4 )

			with open('analysis_rcom.dat', 'r') as fh:
				rows = [ line.split() for line in fh.readlines()[1:] ]	# skip header line
			r = [ float( row[1] ) for row in rows ]
			Z = [ float( row[2] ) for row in rows ]
			theta = [ float( row[8] ) for row in rows ]

			s = reactioncoordinate_traj( r, Z, t0 )
			angles.extend( [s_j, theta_j] for s_j, theta_j in zip( s, theta ) )
	finally:
		chdir(here)

	if not angles:
		raise ValueError('collectangles_heatmap: no reactive trajectories found in {}'.format(folder))

	angles = np.array(angles)
	angles_binned, xedges, yedges = np.histogram2d( angles[:,1], angles[:,0], range=[[0., 180.], [s_limit[0], s_limit[1]]], bins=(Nangle, Ns), density=True )

	return angles_binned, [xedges, yedges]

def plotangles(idx, title, folder):
	'''
	Draw, in grid cell gs[idx], the relative (theta_i, p_theta_i) map returned by collectangles(folder).

	idx is a (row, column) pair into the module-level GridSpec `gs` of `fig`. The plotted
	values are "reactive minus statistical density + 1". The diverging colour map is
	therefore centred on 1, with limits symmetric about it labelled Low / Avg. / High.
	'''
	fig.add_subplot(gs[idx[0],idx[1]])

	angles, edges = collectangles(folder)

	# Symmetric about the neutral value 1: use the largest deviation on either side, so that
	# bins far below 1 are not clipped when the excess above 1 happens to be smaller.
	vlimit = 1. + np.amax( np.abs( angles - 1. ) )
	vlow = 1. - ( vlimit - 1. )
	plt.pcolormesh(edges[1], edges[0], angles, vmin=vlow, vmax=vlimit, cmap='bwr')
	cbar = plt.colorbar(label='Relative intensity', ticks=[vlow, 1., vlimit])
	cbar.ax.set_yticklabels(['Low', 'Avg.', 'High'])

	plt.annotate(title, xy=(0.05, 0.87), xycoords='axes fraction', size=10, color='k', bbox=dict(boxstyle='round', fc='w'))

	plt.xlabel(r'$\theta_\textrm{i}$ angle (degrees)')
	plt.ylabel(r'$p_{\theta_\textrm{i}}$ (amu\r{A}$^2$/fs)')

	plt.xticks([20., 55., 90., 125., 160.])

	return

def plot_traj(name, r_TS):
	'''
	Plot theta against the reaction coordinate for a single trajectory file.

	name is an analysis_rcom-style file: one header line, then whitespace-separated columns
	with r in column 1, Z in column 2 and theta in column 8. The reaction coordinate is
	zeroed at the first point whose r exceeds r_TS.

	Raises ValueError if r never exceeds r_TS. Previously this surfaced as an
	UnboundLocalError.
	'''
	with open(name, 'r') as fh:
		rows = [ line.split() for line in fh.readlines()[1:] ]	# skip header line

	r = [ float( row[1] ) for row in rows ]
	Z = [ float( row[2] ) for row in rows ]
	theta = [ float( row[8] ) for row in rows ]

	# Index of the first point past the transition-state distance.
	idx = next( (k for k, r_k in enumerate(r) if r_k > r_TS), None )
	if idx is None:
		raise ValueError('plot_traj: r never exceeds r_TS={} in {}'.format(r_TS, name))

	xaxis = reactioncoordinate_traj( r, Z, idx )
	plt.plot( xaxis, theta )

def reactioncoordinate( r, Z, E ):
	'''
	Cumulative path length along (r, Z), shifted so that s = 0 at the maximum of E.

	r, Z and E are equal-length sequences (lists or 1-D arrays). The function returns a list
	s with s[i] = sum_{k=1..i} sqrt((Z[k]-Z[k-1])**2 + (r[k]-r[k-1])**2) - s[argmax(E)].
	If the three lengths differ, it prints a message and exits with status 1.
	'''
	# A chained `len(r) != len(Z) != len(E)` only compares neighbours, so an E of the
	# wrong length slipped through whenever r and Z agreed.
	if not ( len(r) == len(Z) == len(E) ):
		print('lengths of r, Z and E are not equal')
		sys.exit(1)

	s = [ 0. ]
	for i in range( 1, len( r ) ):
		dZ = Z[i] - Z[i-1]
		dr = r[i] - r[i-1]
		s.append( s[-1] + np.sqrt( dZ**2 + dr**2 ) )

	ds = s[np.argmax(E)]	# path length at the energy maximum (barrier)
	return [ s_i - ds for s_i in s ]

def reactioncoordinate_traj( r, Z, idx ):
	'''
	Arc length along a trajectory's (r, Z) path, shifted so that point idx sits at s = 0.

	r and Z must have equal length. Otherwise a message is printed and exit() is called.
	Returns a new list with one entry per point.
	'''
	if len(Z) != len(r):
		print('lengths of r and Z are not equal')
		exit()

	arc = [ 0. ]
	for (r_a, r_b), (z_a, z_b) in zip( zip( r, r[1:] ), zip( Z, Z[1:] ) ):
		arc.append( arc[-1] + np.sqrt( (z_b - z_a)**2 + (r_b - r_a)**2 ) )

	origin = arc[idx]
	return [ point - origin for point in arc ]

def ptheta( j, mj, theta, direction, atomicunits=True ):
	'''
	Signed momentum conjugate to theta for a rotor in state (j, mj), with theta in radians.

	|L_theta| = hbar * [ j(j+1) - (mJ / sin(theta))^2 ]^0.5
	|L_phi| = hbar * mJ
	The result carries the sign of direction (+1 or -1). For mj == 0 the sin(theta) term is
	skipped, so theta = 0 is allowed. With atomicunits=False the value is converted from
	Hartree atomic units to amu*Angstrom^2/fs.
	'''
	if atomicunits:
		scale = 1.
	else:
		scale = 0.529177249 / 1822.888486209 * 21.876913	# bohr to angstrom, electron mass to amu, velocity in Hartree atomic units to angstrom / fs

	radicand = j*(j+1)
	if mj != 0:
		radicand = radicand - ( mj / np.sin(theta) )**2

	return scale * direction * np.sqrt( radicand )

def rotperiod( j, mu, r0 ):
	'''
	Classical rotation period of a rigid rotor.

	I = mu*R**2, omega = 2*pi*frequency, L = I*omega
	=> T = 1/frequency = 2*pi*mu*R**2 / L
	|L| is taken from ptheta(j, 0, ...) in amu*Angstrom^2/fs. With mu in amu and r0 in
	Angstrom, the period is therefore in fs. For j = 0, |L| = 0 and the numpy division
	gives inf.
	'''
	angular_momentum = abs( ptheta( j, 0, 0, 1, False ) )
	return 2.*np.pi*mu*r0**2 / angular_momentum

def theta_( j, mj, phase ):
	'''
	Polar angle theta and its momentum p_theta at a random point of the rotation and after
	advancing the rotation angle by phase.

	cos(beta) = mj / sqrt(j(j+1)), or 0 when mj == 0. The rotation angle gamma is drawn
	uniformly from [0, 2*pi) with random.random(), and theta = arccos(-sin(beta) * cos(gamma)).
	p_theta comes from ptheta() in non-atomic units. It is positive for gamma < pi and
	negative otherwise. The final state uses (gamma + phase) mod 2*pi.

	Returns (theta_i, theta_f, pt_i, pt_f), with angles in radians.
	'''
	cos_beta = mj / np.sqrt( j*(j+1) ) if mj != 0 else 0.
	sin_beta = np.sin( np.arccos( cos_beta ) )

	def _state( gamma ):
		# theta and signed p_theta at rotation angle gamma.
		angle = np.arccos( -sin_beta * np.cos( gamma ) )
		sign = 1 if gamma < np.pi else -1
		return angle, ptheta( j, mj, angle, sign, False )

	gamma_i = 2.*np.pi*random()
	theta_i, pt_i = _state( gamma_i )
	theta_f, pt_f = _state( ( gamma_i + phase ) % ( 2.*np.pi ) )

	return theta_i, theta_f, pt_i, pt_f

#****** V E L O C I T Y   D I S T R I B U T I O N *************************************************
#
#      F. Nattino, 05/2011
#
#     Modified to python by N. Gerrits, 01/2019
#
#     Select the incidence energy according to a flux-weighted velocity distribution as
#      in Michelsen et al., J. Chem. Phys. 94, 7502 (1991)
#    
#     G(E)dE = ( 1 / N ) * E * exp[ - 4Es * ( SQRT(E) - SQRT(Es) ) ** 2 / DelEs**2 ] dE
#
#     DelEs / Es = 2 DelVs / Vs = 2 / S
#
#     N = (DelEs / 2 S)**2 * { ( 1 + S**2 ) * exp( -S**2) + 
#         SQRT(PI) * S * ( 1.5 + S**2 )[ 1 + ERF(S) ] }
#
#     Units: mass in amu (converted to kg internally), Vs and DelVs in m/s, Emax in eV
#
#     Constants (J2eV, ev2kjmol, amu2kg) are defined inside the functions below; sampling uses numpy.random

def SetupVelocityDistribution():
	'''
	Build the translational-energy sampling parameters for the HCl beam.

	Returns a class used as a namespace, `vel`. It holds the beam settings (stream
	velocity, width, Emax, mass, ...) and the derived quantities read by
	SelectVelocityFromDistribution(): Estream, DelEstream, SqEstream, the normalization
	factor FacNor and the distribution maximum Gmax.

	When vel.output is True, the function does three extra things. It tabulates the
	normalized distribution to VelocityDistribution.dat, one "E G(E)" pair per line. It
	prints a summary. It exits with status 1 if the numerical normalization deviates from
	1 by more than 1e-4.
	'''
	from math import erf
	J2eV=6.24181 * 10**18
	ev2kjmol=96.4853365
	amu2kg=1.66053892 * 10**-27
	class vel:
		"""Translational information"""
		pass
	# Translational parameters
	vel.dist	= True	# Use a Velocity distribution for normal translational energy
	vel.stream	= 3616. # Stream Velocity (m/s), if vel.dist is false then this is used as the initial translational energy
	vel.width	= 371.	# Velocity width (m/s)
	vel.Emax	= 4.5	# Maximum translational energy considered (eV)
	vel.shift	= [False, 0.0]	# Shift velocity distribution. If true, by how much (eV); not read by the sampling code here
	vel.Etrans0	= 0.	# Translational energy (eV) if no velocity distribution is used, i.e. a monochromatic beam; not read by the sampling code here
	vel.output	= False	# If True, the velocity distribution will be written to a file (VelocityDistribution.dat) and a summary to the screen,  and a normalization check is performed
	vel.mass	= 36.458# Mass of HCl (amu)

	vel.Estream = 0.5 * vel.mass*amu2kg * vel.stream**2. * J2eV
	vel.DelEstream = 2. * vel.Estream * vel.width / vel.stream
	vel.SqEstream = np.sqrt( vel.Estream )

	srma = vel.stream / vel.width
	# Normalization factor
	vel.FacNor = ( vel.DelEstream /( 2. * srma) )**2. * (( 1. + srma**2. )* np.exp( -1. * srma**2. ) + np.sqrt( np.pi ) * srma * ( 1.5 + srma**2. ) * ( 1. + erf( srma ) ) )
	# Energy at which G(E) peaks, and the peak value (upper bound for rejection sampling).
	EGmax = (( vel.DelEstream**2. ) + 2. * ( vel.Estream**2. ) + 2. * vel.Estream * np.sqrt( ( vel.Estream**2. ) + ( vel.DelEstream ** 2.)))/(4. * vel.Estream)
	Expon = 4.* vel.Estream * ( np.sqrt(EGmax) - vel.SqEstream )**2. / vel.DelEstream**2.
	Expfac = np.exp( -1.* Expon)

	vel.Gmax = EGmax / vel.FacNor * Expfac

	# Distribution is checked and an output is written
	if vel.output:
		StepIntegration = 100000
		DelE = vel.Emax / float( StepIntegration )
		Eaver = 0.
		Dnorm = 0.
		with open('VelocityDistribution.dat', 'w') as Edist:
			# Newlines are required; without them every record ran together on one line.
			Edist.write( '# Normalized Translational Energy Distribution. E in eV.\n' )
			for i in range( StepIntegration ):
				E = float(i) * DelE
				Expon = 4. * vel.Estream * (np.sqrt( E ) - vel.SqEstream)**2. / vel.DelEstream **2.
				Expfac = np.exp( - Expon )
				GE = Expfac * E
				Eaver += GE * E * DelE
				Dnorm += GE * DelE
				Edist.write('{:f}\t{:f}\n'.format( E, (GE / vel.FacNor) ) )

		Eaver = Eaver / vel.FacNor
		Dnorm = Dnorm / vel.FacNor

		print( ' Energy corresponding to Stream Velocity:   E0 (eV) = ', vel.Estream )
		print( ' DeltaE (eV) = ', vel.DelEstream )
		print( ' Mass (amu) = ', vel.mass )
		print( ' Energy at Maximum Distribution (eV) = ', EGmax )
		print( ' The average collision energy is (eV) = '    ,  Eaver )
		print( ' The average collision energy is (kJ/mol) = ' ,  Eaver * ev2kjmol )
		print( ' Verify normalization of the distribution: {:12.10f}\n'.format( Dnorm ) )

		if (abs( Dnorm - 1. ) > 1.E-4 ):
			print( 'WARNING: Increase Etransmax' )
			print( ' Normalization of the distribution: ',  Dnorm )
			sys.exit(1)

	return vel

def SelectVelocityFromDistribution( vel ):
	'''
	Draw one incidence velocity by rejection sampling the distribution set up by
	SetupVelocityDistribution().

	A trial energy is drawn uniformly from [0, vel.Emax). A second uniform draw from
	[0, vel.Gmax) accepts the trial unless it exceeds the normalized G(E). The accepted
	energy is converted to a speed and returned in Angstrom/fs (m/s * 1e-5).
	'''
	J2eV = 6.24181 * 10**18
	amu2kg = 1.66053892 * 10**-27

	while True:
		Etrans = np.random.random_sample() * vel.Emax
		exponent = 4.* vel.Estream * ( np.sqrt( Etrans ) - vel.SqEstream )**2. / vel.DelEstream**2.
		g_value = Etrans / vel.FacNor * np.exp( -1.* exponent )
		threshold = np.random.random_sample() * vel.Gmax
		if not ( threshold > g_value ):
			break

	# E = m v^2 / 2 in SI units, then m/s -> Angstrom/fs.
	return np.sqrt( 2. * Etrans / J2eV / ( vel.mass * amu2kg) ) * 1.E-5

def theta_final( j, mj, vel ):
	'''
	Sample one molecule in rotational state (j, mj). Return its orientation at the start and
	on arrival at the TS.

	The speed v0 (Angstrom/fs) is drawn with SelectVelocityFromDistribution(vel) when
	vel.dist == True. Otherwise it is vel.stream (m/s) * 1e-5. While travelling 5.5 Angstrom
	the rotor advances by d / v0 / T_rot periods. T_rot comes from rotperiod with
	mu = 0.97 amu and r0 = 1.27 Angstrom. The fractional part of that count gives the phase
	-2*pi*frac(d / v0 / T_rot) passed to theta_().

	Returns [theta_i (deg), p_theta_i, theta_weight(theta_i), theta_f (deg)].
	'''
	if vel.dist == True:
		v0 = SelectVelocityFromDistribution( vel )
	else:
		v0 = vel.stream * 10**-5	# m/s -> Angstrom/fs

	travel = 5.5	# Distance between initial Z and Z_TS in Angstrom
	reduced_mass = 0.97	# Reduced mass in amu
	bond_length = 1.27	# Equilibrium gas phase HCl bond length in Angstrom
	period = rotperiod( j, reduced_mass, bond_length )
	phase = -( ( travel / v0 / period ) % 1. ) * 2.*np.pi

	theta_i, theta_f, pt_i, pt_f = theta_( j, mj, phase )
	theta_i_deg = np.degrees( theta_i )

	return [ theta_i_deg, pt_i, theta_weight( theta_i_deg ), np.degrees( theta_f ) ]

def theta_weight( THETA ):
	'''
	Weight 1/sin(theta) used for trajectories in the original plots, with THETA in degrees.

	It relies on the statistical initial distribution of orientations being proportional to
	sin(theta). Scalars and numpy arrays are both accepted.
	'''
	# Unused alternative, a reactivity function derived by Jan Geweke:
	# 0.5*exp(-(radians(THETA) - 80.)**2 / 1000.) + 2*exp(-(radians(THETA) - 170.)**2 / 8000.) - 1.
	sin_theta = np.sin( np.radians( THETA ) )
	return 1. / sin_theta

def theta_reactivity( THETA, nu ):
	'''
	Approximate reaction probability as a function of theta (degrees) for vibrational state nu.

	This interpolates the J=0 sticking probabilities tabulated at theta = 10, 25, ..., 175
	degrees for nu = 0, 1 or 2. The interpolant is an interpolating
	scipy.interpolate.UnivariateSpline (s=0); outside [10, 175] its default extrapolation
	applies. THETA may be a scalar or an array, and the spline's ndarray output is returned.

	Raises ValueError for any other nu. Previously this failed with an UnboundLocalError.
	'''
	sticking_by_nu = {
		0: [0.0000, 0.0019, 0.0132, 0.1494, 0.3149, 0.3544, 0.3224, 0.3913, 0.3926, 0.4464, 0.4814, 0.4561],
		1: [0.0000, 0.0000, 0.0327, 0.1857, 0.4057, 0.4968, 0.4798, 0.5634, 0.7256, 0.7260, 0.7364, 0.7748],
		2: [0.0048, 0.0141, 0.0503, 0.2008, 0.4619, 0.5459, 0.5474, 0.6561, 0.8207, 0.8576, 0.8716, 0.8416],
	}
	try:
		theta_sticking = sticking_by_nu[nu]
	except KeyError:
		raise ValueError('theta_reactivity: no sticking data for nu={}'.format(nu)) from None

	theta = np.arange(10,180,15)	# 12 grid points matching theta_sticking
	f = interpolate.UnivariateSpline(theta, theta_sticking, s=0)

	return f( THETA )

def plotangles_analytical(idx, title, nu, J, N, vel):
	'''
	Model counterpart of plotangles(), drawn on subplot idx of the Nrows x Ncols grid.

	The function samples N initial states for rotational state J, with mJ uniform over
	-J..J, using theta_final(). It then plots the density weighted by the 1/sin weight times
	the reactivity of the final theta, theta_reactivity(theta_f, nu). From this it subtracts
	the density weighted by the 1/sin weight alone, and adds 1. The diverging colour map is
	centred on 1 with symmetric limits.
	'''
	plt.subplot(Nrows,Ncols,idx)

	mJ = np.arange( -J, J+1, 1 )
	# Columns: theta_i (deg), p_theta_i, 1/sin weight, theta_f (deg).
	angles = np.array( [ theta_final( J, choice( mJ ), vel ) for _ in range( N ) ] )
	p_range = [ min(angles[:,1]), max(angles[:,1]) ]

	angles_stat_binned, xedges, yedges = np.histogram2d( angles[:,1], angles[:,0], range=[p_range, [20., 160.]], bins=36, density=True, weights=angles[:,2] )
	angles_binned, xedges, yedges = np.histogram2d( angles[:,1], angles[:,0], range=[p_range, [20., 160.]], bins=36, density=True, weights=angles[:,2]*theta_reactivity(angles[:,3], nu) )

	angles_binned -= angles_stat_binned	# In order to get the relative distribution of reactive trajectories vs the initial distribution
	angles_binned += 1.			# Arbitrary shift so that if reacted/statistical=1 it also yields a probability density of 1

	# Symmetric about 1 using the largest deviation on either side (no clipping of low bins).
	vlimit = 1. + np.amax( np.abs( angles_binned - 1. ) )
	plt.pcolormesh(yedges, xedges, angles_binned, vmin=1.-(vlimit-1.), vmax=vlimit, cmap='bwr')
	cbar = plt.colorbar(label='Relative intensity', ticks=[1.])
	cbar.ax.set_yticklabels(['Avg.'])

	plt.annotate(title, xy=(0.05, 0.87), xycoords='axes fraction', size=10, color='k')

	plt.xlabel(r'$\theta_\textrm{i}$ angle (degrees)')
	plt.ylabel(r'$p_{\theta_\textrm{i}}$ (amu\r{A}$^2$/fs)')

	plt.xticks([20., 55., 90., 125., 160.])

	return

# Minimum-energy-path data for the MEP labelled 'Top' in panels (e) and (f).
# Columns: r [ Ang ]   Z [ Ang ]   Theta [ Deg ]   E [ kJ/mol ]
# reactioncoordinate() uses r, Z and E; panel (e) plots Theta and panel (f) plots E.
top = np.array(
[[1.2944, 2.9713, 101.71,  17.96],
[1.2949, 2.9603, 101.19,  18.39],
[1.2954, 2.9492, 100.70,  18.88],
[1.2958, 2.9381, 100.24,  19.41],
[1.2963, 2.9269, 99.80,  19.98],
[1.2967, 2.9155, 99.38,  20.57],
[1.2970, 2.9041, 98.92,  21.18],
[1.2974, 2.8925, 98.39,  21.81],
[1.2976, 2.8808, 97.81,  22.44],
[1.2979, 2.8688, 97.20,  23.08],
[1.2980, 2.8568, 96.62,  23.72],
[1.2982, 2.8446, 96.08,  24.37],
[1.2985, 2.8322, 95.60,  25.02],
[1.2990, 2.8197, 95.18,  25.69],
[1.2997, 2.8072, 94.79,  26.40],
[1.3006, 2.7946, 94.39,  27.18],
[1.3016, 2.7818, 94.00,  28.02],
[1.3025, 2.7688, 93.64,  28.92],
[1.3033, 2.7555, 93.33,  29.87],
[1.3042, 2.7420, 93.05,  30.86],
[1.3051, 2.7282, 92.77,  31.90],
[1.3061, 2.7143, 92.44,  33.00],
[1.3075, 2.7003, 91.99,  34.18],
[1.3091, 2.6862, 91.36,  35.44],
[1.3106, 2.6718, 90.63,  36.79],
[1.3120, 2.6569, 89.86,  38.22],
[1.3135, 2.6417, 89.08,  39.73],
[1.3152, 2.6264, 88.33,  41.32],
[1.3172, 2.6110, 87.61,  43.03],
[1.3199, 2.5957, 86.91,  44.88],
[1.3229, 2.5804, 86.23,  46.88],
[1.3259, 2.5647, 85.63,  49.02],
[1.3291, 2.5487, 85.09,  51.32],
[1.3327, 2.5327, 84.56,  53.77],
[1.3369, 2.5168, 83.98,  56.40],
[1.3419, 2.5014, 83.31,  59.24],
[1.3476, 2.4865, 82.61,  62.28],
[1.3540, 2.4719, 81.96,  65.54],
[1.3614, 2.4581, 81.46,  69.02],
[1.3700, 2.4454, 81.23,  72.73],
[1.3804, 2.4349, 81.50,  76.66],
[1.3939, 2.4284, 82.92,  80.79],
[1.4134, 2.4303, 87.38,  85.04],
[1.4389, 2.4415, 96.52,  89.18],
[1.4662, 2.4568, 107.57,  92.96],
[1.5744, 2.5986, 131.15,  95.26],
[1.5919, 2.6070, 132.18,  96.76],
[1.6062, 2.6110, 132.76,  98.16],
[1.6198, 2.6147, 133.26,  99.49],
[1.6330, 2.6184, 133.75, 100.76],
[1.6460, 2.6222, 134.23, 101.97],
[1.6587, 2.6264, 134.72, 103.11],
[1.6712, 2.6309, 135.21, 104.19],
[1.6836, 2.6359, 135.68, 105.20],
[1.6956, 2.6409, 136.10, 106.14],
[1.7072, 2.6457, 136.47, 107.01],
[1.7183, 2.6502, 136.79, 107.82],
[1.7289, 2.6543, 137.04, 108.57],
[1.7391, 2.6578, 137.24, 109.27],
[1.7488, 2.6607, 137.40, 109.91],
[1.7580, 2.6628, 137.50, 110.51],
[1.7668, 2.6643, 137.57, 111.06],
[1.7753, 2.6650, 137.61, 111.58],
[1.7834, 2.6649, 137.62, 112.06],
[1.7913, 2.6641, 137.60, 112.51],
[1.7990, 2.6624, 137.56, 112.93],
[1.8065, 2.6601, 137.50, 113.32],
[1.8139, 2.6570, 137.41, 113.68],
[1.8212, 2.6532, 137.30, 114.01],
[1.8285, 2.6487, 137.18, 114.31],
[1.8359, 2.6434, 137.03, 114.57],
[1.8432, 2.6375, 136.86, 114.81],
[1.8507, 2.6309, 136.67, 115.02],
[1.8584, 2.6237, 136.46, 115.19],
[1.8662, 2.6160, 136.23, 115.32],
[1.8742, 2.6077, 135.99, 115.42],
[1.8825, 2.5991, 135.73, 115.48],
[1.8911, 2.5901, 135.46, 115.49],
[1.9092, 2.5717, 134.91, 115.40],
[1.9188, 2.5625, 134.64, 115.29],
[1.9287, 2.5535, 134.37, 115.13],
[1.9388, 2.5448, 134.12, 114.93],
[1.9493, 2.5366, 133.88, 114.69],
[1.9600, 2.5289, 133.67, 114.42],
[1.9710, 2.5219, 133.47, 114.11],
[1.9822, 2.5155, 133.29, 113.76],
[1.9935, 2.5100, 133.14, 113.40],
[2.0049, 2.5051, 132.99, 113.01],
[2.0164, 2.5013, 132.86, 112.61],
[2.0279, 2.4985, 132.74, 112.20],
[2.0393, 2.4966, 132.62, 111.80],
[2.0508, 2.4955, 132.49, 111.41],
[2.0622, 2.4949, 132.35, 111.03],
[2.0736, 2.4947, 132.16, 110.67],
[2.0850, 2.4949, 131.93, 110.35],
[2.0965, 2.4955, 131.64, 110.06],
[2.1080, 2.4959, 131.25, 109.80],
[2.1200, 2.4957, 130.72, 109.59],
[2.1324, 2.4946, 130.04, 109.41],
[2.1458, 2.4917, 129.13, 109.24],
[2.1612, 2.4847, 127.78, 109.08],
[2.1811, 2.4686, 125.53, 108.89],
[2.2016, 2.4532, 123.36, 108.63],
[2.2196, 2.4448, 121.89, 108.32],
[2.2359, 2.4409, 120.84, 107.98],
[2.2510, 2.4400, 120.07, 107.64],
[2.2652, 2.4411, 119.48, 107.31]]
)

# Minimum-energy-path data for the MEP labelled 'TS' in panels (e) and (f).
# Columns: r [ Ang ]   Z [ Ang ]   Theta [ Deg ]   E [ kJ/mol ]
TS = np.array(
[[1.2987, 2.9685, 113.74,  18.80],
[1.2994, 2.9370, 112.57,  20.15],
[1.3001, 2.9054, 111.36,  21.75],
[1.3007, 2.8736, 110.10,  23.42],
[1.3012, 2.8415, 108.83,  25.04],
[1.3020, 2.8091, 107.67,  26.67],
[1.3037, 2.7765, 106.85,  28.56],
[1.3055, 2.7435, 106.30,  30.69],
[1.3075, 2.7100, 105.85,  33.02],
[1.3104, 2.6762, 105.45,  35.67],
[1.3133, 2.6417, 104.98,  38.61],
[1.3168, 2.6068, 104.56,  41.86],
[1.3222, 2.5718, 104.45,  45.59],
[1.3283, 2.5365, 104.50,  49.79],
[1.3361, 2.5012, 104.78,  54.54],
[1.3471, 2.4671, 105.45,  59.96],
[1.3617, 2.4345, 106.60,  66.03],
[1.3838, 2.4069, 108.83,  72.70],
[1.4295, 2.3980, 113.54,  79.51],
[1.4927, 2.4065, 118.18,  85.49],
[1.5581, 2.4220, 121.06,  90.28],
[1.6335, 2.4529, 123.90,  93.69],
[1.6923, 2.4743, 125.32,  95.88],
[1.7373, 2.4861, 125.93,  97.35],
[1.7757, 2.4944, 126.20,  98.37],
[1.8084, 2.4987, 126.04,  99.09],
[1.8375, 2.5010, 125.54,  99.60],
[1.8641, 2.5019, 124.86,  99.97],
[1.8883, 2.5012, 124.12, 100.24],
[1.9110, 2.4993, 123.36, 100.43],
[1.9325, 2.4969, 122.64, 100.57],
[1.9533, 2.4942, 121.95, 100.65],
[1.9734, 2.4910, 121.28, 100.70],
[1.9929, 2.4873, 120.63, 100.71],
[2.0119, 2.4833, 119.99, 100.69],
[2.0307, 2.4788, 119.37, 100.64],
[2.0492, 2.4740, 118.76, 100.56],
[2.0676, 2.4689, 118.15, 100.46],
[2.0860, 2.4636, 117.54, 100.33],
[2.1044, 2.4579, 116.93, 100.18],
[2.1230, 2.4521, 116.32, 100.01],
[2.1418, 2.4461, 115.71,  99.81],
[2.1609, 2.4403, 115.11,  99.59],
[2.1803, 2.4350, 114.54,  99.34]]
)

# Minimum-energy-path data for the MEP labelled 'Fcc' in panels (e) and (f).
# Columns: r [ Ang ]   Z [ Ang ]   Theta [ Deg ]   E [ kJ/mol ]
fcc = np.array(
[[1.3075, 2.9688, 113.40,  24.41],
[1.3095, 2.9377, 113.90,  26.15],
[1.3115, 2.9066, 114.67,  28.11],
[1.3138, 2.8754, 115.60,  30.16],
[1.3159, 2.8440, 116.22,  32.22],
[1.3177, 2.8124, 116.36,  34.30],
[1.3207, 2.7807, 116.62,  36.58],
[1.3246, 2.7489, 117.18,  39.12],
[1.3291, 2.7169, 117.92,  41.86],
[1.3351, 2.6850, 118.71,  44.82],
[1.3422, 2.6532, 119.36,  47.99],
[1.3507, 2.6217, 120.11,  51.43],
[1.3633, 2.5917, 121.14,  55.21],
[1.3792, 2.5634, 121.35,  59.24],
[1.3949, 2.5349, 119.80,  63.39],
[1.4092, 2.5056, 117.12,  67.68],
[1.4245, 2.4766, 114.93,  72.27],
[1.4415, 2.4486, 113.57,  77.19],
[1.4623, 2.4233, 113.60,  82.47],
[1.4925, 2.4060, 115.52,  88.00],
[1.6892, 2.4322, 117.04, 106.77],
[1.7413, 2.4529, 117.52, 110.17],
[1.7890, 2.4734, 117.78, 112.99],
[1.8316, 2.4924, 117.94, 115.36],
[1.8698, 2.5099, 118.09, 117.35],
[1.9039, 2.5255, 118.27, 119.04],
[1.9342, 2.5389, 118.44, 120.49],
[1.9613, 2.5503, 118.55, 121.76],
[1.9858, 2.5599, 118.61, 122.87],
[2.0081, 2.5679, 118.63, 123.86],
[2.0285, 2.5743, 118.62, 124.75],
[2.0471, 2.5788, 118.59, 125.55],
[2.0644, 2.5813, 118.53, 126.28],
[2.0804, 2.5815, 118.41, 126.95],
[2.0954, 2.5788, 118.23, 127.56],
[2.1096, 2.5726, 117.95, 128.12],
[2.1232, 2.5621, 117.56, 128.63],
[2.1366, 2.5461, 117.02, 129.07],
[2.1504, 2.5242, 116.32, 129.45],
[2.1653, 2.4969, 115.52, 129.75],
[2.1818, 2.4663, 114.67, 129.93]]
)

# Minimum-energy-path data for the MEP labelled 'Bridge' in panels (e) and (f).
# Columns: r [ Ang ]   Z [ Ang ]   Theta [ Deg ]   E [ kJ/mol ]
bridge = np.array(
[[1.2991, 2.9685, 116.86,  20.31],
[1.2996, 2.9370, 114.99,  21.78],
[1.3000, 2.9054, 113.45,  23.48],
[1.3005, 2.8736, 112.20,  25.25],
[1.3008, 2.8414, 111.07,  26.96],
[1.3015, 2.8090, 109.81,  28.69],
[1.3030, 2.7763, 108.18,  30.68],
[1.3042, 2.7431, 106.28,  32.92],
[1.3056, 2.7094, 104.39,  35.32],
[1.3077, 2.6752, 102.90,  38.00],
[1.3095, 2.6402, 101.70,  40.91],
[1.3118, 2.6045, 100.72,  44.04],
[1.3157, 2.5687, 99.99,  47.59],
[1.3194, 2.5318, 99.19,  51.48],
[1.3247, 2.4946, 98.44,  55.81],
[1.3319, 2.4575, 97.95,  60.69],
[1.3399, 2.4199, 97.88,  66.04],
[1.3515, 2.3835, 99.33,  71.90],
[1.3699, 2.3515, 103.08,  78.23],
[1.4077, 2.3351, 110.70,  84.82],
[1.4736, 2.3460, 116.96,  90.61],
[1.5250, 2.3482, 119.09,  95.32],
[1.5778, 2.3557, 121.54,  99.22],
[1.6360, 2.3735, 123.58, 102.16],
[1.6861, 2.3875, 124.35, 104.21],
[1.7275, 2.3952, 124.42, 105.68],
[1.7642, 2.4002, 124.19, 106.75],
[1.7973, 2.4029, 123.76, 107.54],
[1.8274, 2.4037, 123.17, 108.12],
[1.8552, 2.4028, 122.46, 108.56],
[1.8814, 2.4007, 121.66, 108.87],
[1.9061, 2.3973, 120.81, 109.09],
[1.9297, 2.3928, 119.95, 109.23],
[1.9525, 2.3874, 119.08, 109.30],
[1.9748, 2.3811, 118.22, 109.31],
[1.9966, 2.3739, 117.37, 109.24],
[2.0182, 2.3659, 116.54, 109.12],
[2.0397, 2.3570, 115.70, 108.93],
[2.0613, 2.3472, 114.86, 108.68],
[2.0830, 2.3366, 114.00, 108.37],
[2.1052, 2.3253, 113.14, 108.00],
[2.1279, 2.3135, 112.29, 107.57],
[2.1512, 2.3013, 111.47, 107.07],
[2.1752, 2.2889, 110.65, 106.50]]
)

# This part is with the correct mJ sampling:
# Each row is [J, S0 at theta = 10, 25, ..., 175 deg (12 values), 12 matching error bars].
# Rows are J = 0, 2, 4, 6, 8. Panel (a) plots rows as errorbars against the global `theta`.
# One array per vibrational state: S0_v0 (nu=0), S0_v1 (nu=1), S0_v2 (nu=2).
S0_v0 = np.array(
[[0, 0.0000, 0.0019, 0.0132, 0.1494, 0.3149, 0.3544, 0.3224, 0.3913, 0.3926, 0.4464, 0.4814, 0.4561, 0.0000, 0.0019, 0.0040, 0.0112, 0.0131, 0.0132, 0.0129, 0.0141, 0.0159, 0.0182, 0.0249, 0.0466],
[2, 0.3333, 0.3773, 0.2808, 0.3733, 0.2962, 0.2363, 0.2861, 0.2859, 0.2693, 0.2901, 0.3945, 0.3223, 0.0348, 0.0191, 0.0182, 0.0134, 0.0128, 0.0125, 0.0134, 0.0125, 0.0139, 0.0178, 0.0234, 0.0425],
[4, 0.5320, 0.4973, 0.4565, 0.4700, 0.4513, 0.3648, 0.2913, 0.2104, 0.1478, 0.0745, 0.0188, 0.0000, 0.0350, 0.0211, 0.0170, 0.0175, 0.0127, 0.0139, 0.0133, 0.0113, 0.0114, 0.0102, 0.0062, 0.0000],
[6, 0.4468, 0.4037, 0.3911, 0.4099, 0.4374, 0.4344, 0.3788, 0.3240, 0.2309, 0.1517, 0.0954, 0.0455, 0.0363, 0.0202, 0.0174, 0.0150, 0.0137, 0.0142, 0.0138, 0.0136, 0.0136, 0.0126, 0.0163, 0.0199],
[8, 0.2019, 0.2032, 0.2257, 0.2565, 0.2746, 0.3245, 0.3855, 0.4247, 0.5000, 0.5467, 0.5871, 0.6379, 0.0275, 0.0169, 0.0145, 0.0136, 0.0122, 0.0135, 0.0137, 0.0147, 0.0162, 0.0185, 0.0246, 0.0446]]
)

S0_v1 = np.array(
[[0, 0.0000, 0.0000, 0.0327, 0.1857, 0.4057, 0.4968, 0.4798, 0.5634, 0.7256, 0.7260, 0.7364, 0.7748, 0.0000, 0.0000, 0.0062, 0.0123, 0.0141, 0.0140, 0.0141, 0.0147, 0.0149, 0.0169, 0.0224, 0.0396],
[2, 0.5562, 0.4976, 0.4092, 0.5189, 0.4213, 0.3741, 0.4264, 0.4481, 0.4173, 0.4340, 0.6329, 0.7168, 0.0372, 0.0200, 0.0203, 0.0140, 0.0142, 0.0144, 0.0149, 0.0138, 0.0157, 0.0198, 0.0234, 0.0424],
[4, 0.7111, 0.7249, 0.7202, 0.7058, 0.6551, 0.5141, 0.4157, 0.3274, 0.2482, 0.1312, 0.0605, 0.0602, 0.0338, 0.0195, 0.0157, 0.0164, 0.0123, 0.0146, 0.0145, 0.0131, 0.0140, 0.0133, 0.0111, 0.0261],
[6, 0.7088, 0.6837, 0.6974, 0.7261, 0.6604, 0.6288, 0.5178, 0.3798, 0.2778, 0.2101, 0.1957, 0.1518, 0.0337, 0.0195, 0.0166, 0.0138, 0.0133, 0.0141, 0.0144, 0.0143, 0.0146, 0.0143, 0.0221, 0.0339],
[8, 0.5023, 0.5486, 0.5716, 0.5065, 0.4744, 0.5098, 0.5582, 0.5909, 0.6397, 0.6839, 0.6572, 0.6842, 0.0343, 0.0211, 0.0172, 0.0158, 0.0138, 0.0146, 0.0143, 0.0149, 0.0159, 0.0176, 0.0241, 0.0435]]
)

S0_v2 = np.array(
[[0, 0.0048, 0.0141, 0.0503, 0.2008, 0.4619, 0.5459, 0.5474, 0.6561, 0.8207, 0.8576, 0.8716, 0.8416, 0.0048, 0.0053, 0.0077, 0.0128, 0.0142, 0.0141, 0.0144, 0.0147, 0.0134, 0.0137, 0.0175, 0.0363],
[2, 0.5511, 0.5353, 0.4468, 0.5759, 0.4619, 0.4276, 0.4799, 0.4940, 0.4387, 0.5000, 0.6959, 0.8100, 0.0375, 0.0200, 0.0211, 0.0142, 0.0146, 0.0153, 0.0153, 0.0141, 0.0159, 0.0203, 0.0234, 0.0392],
[4, 0.8708, 0.8110, 0.8256, 0.7961, 0.7402, 0.5667, 0.4737, 0.3738, 0.3257, 0.2436, 0.1290, 0.0759, 0.0251, 0.0177, 0.0136, 0.0146, 0.0116, 0.0148, 0.0150, 0.0137, 0.0154, 0.0171, 0.0159, 0.0298],
[6, 0.8580, 0.7877, 0.7964, 0.7885, 0.7738, 0.7345, 0.6076, 0.4787, 0.3220, 0.2507, 0.2466, 0.2400, 0.0263, 0.0176, 0.0150, 0.0131, 0.0122, 0.0132, 0.0144, 0.0149, 0.0155, 0.0157, 0.0251, 0.0427],
[8, 0.7551, 0.7685, 0.7402, 0.7331, 0.6331, 0.6016, 0.6152, 0.6482, 0.7130, 0.6816, 0.7180, 0.8051, 0.0307, 0.0184, 0.0156, 0.0145, 0.0136, 0.0148, 0.0144, 0.0146, 0.0153, 0.0177, 0.0230, 0.0365]]
)

# Initial-theta grid (degrees) of the tabulated S0 data: 10, 25, ..., 175 (12 values).
theta = np.arange(10,180,15)

# 3 x 2 grid: (a) S0(theta), (b)/(c) theta-p_theta maps, (d) theta-vs-s heatmap, (e)/(f) MEPs.
Nrows=3
Ncols=2
#fig, ax = plt.subplots(nrows=Nrows, ncols=Ncols, figsize=(6.69,7.))
fig = plt.figure(figsize=(6.69,7.))
gs = fig.add_gridspec(Nrows, Ncols)
# The axis labels below use LaTeX text-mode commands (\textrm, \r{A}), hence usetex.
mpl.rc("text", usetex=True)

# Default colour cycle; only referenced by commented-out plotting code below.
prop_cycle = plt.rcParams['axes.prop_cycle']
color_cycle = prop_cycle.by_key()['color']

#################### Plot S0(theta)

#plt.subplot(Nrows,Ncols,1)
fig.add_subplot(gs[0,0])

# Tabulated S0 with error bars. Row k of S0_v0 is [J, 12 S0 values, 12 error bars].
plt.errorbar( theta, S0_v0[0,1:len(theta)+1], S0_v0[0,len(theta)+1:], label=r'$J=0$', capsize=4, ls='' )
plt.errorbar( theta, S0_v0[1,1:len(theta)+1], S0_v0[1,len(theta)+1:], label=r'$J=2$', capsize=4, ls='' )
#plt.errorbar( theta, S0_v0[2,1:len(theta)+1], S0_v0[2,len(theta)+1:], label=r'$J=4$', capsize=4, ls='' )
#plt.errorbar( theta, S0_v0[3,1:len(theta)+1], S0_v0[3,len(theta)+1:], label=r'$J=6$', capsize=4, ls='' )
plt.errorbar( theta, S0_v0[4,1:len(theta)+1], S0_v0[4,len(theta)+1:], label=r'$J=8$', capsize=4, ls='' )

# Interpolated J=0 reactivity curve.
xaxis = np.linspace( 0., 180., 181 )
plt.plot( xaxis, theta_reactivity( xaxis, 0 ), c='C0' )

# Model S0(theta_i) for J=2 and J=8. Sample initial states, give each the J=0 reactivity of
# its final theta, and average per theta_i bin.
nu = 0
N = 100000
vel = SetupVelocityDistribution()	# Rather set up the velocity distribution only once
for idx, J in enumerate( (2, 8) ):
	mJ = np.arange( -J, J+1, 1 )
	angles = np.array( [ theta_final( J, choice( mJ ), vel ) for _ in range( N ) ] )
	p_range = [ min(angles[:,1]), max(angles[:,1]) ]

	angles_stat_binned, xedges, yedges = np.histogram2d( angles[:,1], angles[:,0], range=[p_range, [5., 175.]], bins=170, density=False )
	angles_binned, xedges, yedges = np.histogram2d( angles[:,1], angles[:,0], range=[p_range, [5., 175.]], bins=170, density=False, weights=theta_reactivity(angles[:,3], nu) )

	# Sum over p_theta (axis 0) for every theta_i bin (axis 1). Indexing axis 1 by
	# len() of axis 0 only worked because the histogram happened to be square.
	with np.errstate( divide='ignore', invalid='ignore' ):	# empty theta bins -> NaN (gap in the line)
		sticking = angles_binned.sum( axis=0 ) / angles_stat_binned.sum( axis=0 )
	# Separate name so the global S0 grid `theta` is not overwritten.
	theta_centers = 0.5 * ( yedges[:-1] + yedges[1:] )
	plt.plot( theta_centers, sticking, color='C{:d}'.format( idx+1 ) )

plt.annotate(r'(a) $\nu=0$', xy=(0.05, 0.87), xycoords='axes fraction', bbox=dict(boxstyle='round', fc='w'))

plt.legend(loc='upper center', numpoints=1, frameon=True, fontsize=8, ncol=1, columnspacing=1.)
plt.xlim(0., 180.)
plt.xticks(np.linspace(0.,180.,5))
plt.ylim(0.0, 0.8)
plt.tick_params(length=6, width=1, direction='in', top=True, right=True)
plt.xlabel(r'$\theta_\textrm{i}$ (degrees)')
plt.ylabel('Reaction probability')

############ Plot theta vs p_theta

# Panels (b) and (c): relative (theta_i, p_theta_i) maps of reactive MD trajectories for
# nu=0 with J=2 and J=8, read from the matching MD-data folders.
nu = 0
plotangles( (0,1), r'(b) $\nu={0:d},J={1:d}$'.format(nu, 2), 'MD-data/2.56_ms_v{0:d}j{1:d}'.format(nu, 2) )
plotangles( (1,0), r'(c) $\nu={0:d},J={1:d}$'.format(nu, 8), 'MD-data/2.56_ms_v{0:d}j{1:d}'.format(nu, 8) )

############ Plot heatmap of theta along S

# Panel (d): density of theta along the reaction coordinate for reactive nu=0, J=8 MD trajectories.
fig.add_subplot(gs[1,1])

# Optional single-trajectory overlays (disabled); 1.89 is the r_TS threshold passed to plot_traj.
#plot_traj('MEP_theta_MD/2.56ms_v2j8_angles/analysis_rcom_002087.dat', 1.89)
#plot_traj('MEP_theta_MD/2.56ms_v2j8_angles/analysis_rcom_002988.dat', 1.89)
#plot_traj('MEP_theta_MD/2.56ms_v2j8_angles/analysis_rcom_003197.dat', 1.89)
#plot_traj('MEP_theta_MD/2.56ms_v2j8_angles/analysis_rcom_003815.dat', 1.89)
#plot_traj('MEP_theta_MD/2.56ms_v2j8_angles/analysis_rcom_004117.dat', 1.89)
#plot_traj('MEP_theta_MD/2.56ms_v2j8_angles/analysis_rcom_005410.dat', 1.89)
#plot_traj('MEP_theta_MD/2.56ms_v2j8_angles/analysis_rcom_006953.dat', 1.89)
#plot_traj('MEP_theta_MD/2.56ms_v2j8_angles/analysis_rcom_007355.dat', 1.89)
#plot_traj('MEP_theta_MD/2.56ms_v2j8_angles/analysis_rcom_008116.dat', 1.89)
#plot_traj('MEP_theta_MD/2.56ms_v2j8_angles/analysis_rcom_008191.dat', 1.89)
#plot_traj('MEP_theta_MD/2.56ms_v2j8_angles/analysis_rcom_009262.dat', 1.89)

# Fixed colour range for the probability density.
vmin = 0.
vmax = 0.003

# angles has theta on axis 0 and s on axis 1, so s (edges[1]) is the x axis.
angles, edges = collectangles_heatmap('MD-data/2.56_ms_v0j8_angles')
areas = plt.pcolormesh(edges[1], edges[0], angles, cmap='jet', vmin=vmin, vmax=vmax)

plt.colorbar(areas, label='Probability density')

# Horizontal reference line at theta = 117 degrees across the plotted s window.
plt.plot(s_limit, [117.,117.], c='k', ls='--')

#xaxis = reactioncoordinate( top[:,0], top[:,1], top[:,3] )
#plt.scatter( xaxis, top[:,2], marker='o', label='MEP', facecolor='None', color=color_cycle[1], zorder=100, alpha=0.5 )

#plt.plot( [0.,0.], [0.,180.], color='k', ls='--' )

plt.annotate(r'(d) $\nu=0, J=8$', xy=(0.05,0.87), xycoords='axes fraction', bbox=dict(boxstyle='round', fc='w'))

#plt.legend(loc='upper center', numpoints=1, frameon=True, handletextpad=0.)
# Same window as s_limit.
plt.xlim(-8.,0.)
plt.ylim(0.,180.)
plt.yticks([0.,45.,90.,135.,180.])
plt.tick_params(length=6, width=1, direction='in', top=True, right=True)
plt.ylabel(r'$\theta$ (degrees)')
plt.xlabel(r'Reaction coordinate (\r{A})')

############ Plot MEP vs theta or E

# Panel (e): theta along each MEP, with s = 0 at that path's energy maximum.
fig.add_subplot(gs[2,0])

for mep, label in ( (TS, 'TS'), (top, 'Top'), (fcc, 'Fcc'), (bridge, 'Bridge') ):
	xaxis = reactioncoordinate( mep[:,0], mep[:,1], mep[:,3] )
	plt.plot( xaxis, mep[:,2], marker='o', label=label )

# Vertical marker at s = 0 (energy maximum).
plt.plot( [0.,0.], [80.,140.], color='k', ls='--' )

plt.annotate('(e) MEP', xy=(0.05,0.85), xycoords='axes fraction', bbox=dict(boxstyle='round', fc='w'))

plt.legend(loc='best', numpoints=1, frameon=True)
plt.xlim(-1.5,0.5)
plt.ylim(80.,140.)
plt.tick_params(length=6, width=1, direction='in', top=True, right=True)
plt.ylabel(r'$\theta$ (degrees)')
plt.xlabel(r'Reaction coordinate (\r{A})')

# Panel (f): energy along each MEP, with s = 0 at that path's energy maximum.
fig.add_subplot(gs[2,1])

for mep, label in ( (TS, 'TS'), (top, 'Top'), (fcc, 'Fcc'), (bridge, 'Bridge') ):
	xaxis = reactioncoordinate( mep[:,0], mep[:,1], mep[:,3] )
	plt.plot( xaxis, mep[:,3], marker='o', label=label, markerfacecolor='None' )

# Vertical marker at s = 0 (energy maximum).
plt.plot( [0.,0.], [0.,140.], color='k', ls='--' )

plt.annotate('(f) MEP', xy=(0.05,0.85), xycoords='axes fraction', bbox=dict(boxstyle='round', fc='w'))

plt.legend(loc='best', numpoints=1, frameon=True)
plt.xlim(-1.5,0.5)
plt.ylim(0.,140.)
plt.tick_params(length=6, width=1, direction='in', top=True, right=True)
plt.ylabel(r'$E$ (kJ/mol)')
plt.xlabel(r'Reaction coordinate (\r{A})')

# Lay out all six panels and write the figure.
plt.tight_layout()

plt.savefig('explanation_rotational_mechanism.pdf')

