#!/usr/bin/env python

#########################################
import matplotlib.pyplot as plt         #
import numpy as np                      #
import scipy.special as spc		#
from math import exp			#
#########################################

from matplotlib import rc               #
#########################################
# Typeset mathtext ($...$) in axis labels and annotations with the STIX font set.
rc('mathtext', **{'fontset': 'stix'})


### DATA ######################################################
### Pt 211 ####################################################
# Normal-incidence laser-off data, plotted below as "Expt 2016".
# Columns: average E_i (kJ/mol, per the plot axis labels), expt S0, expt S0 error,
#          SRP S0, SRP S0 error.
Laser_off = np.array([[58.2384,         0.01580,         0.005,    	0.006,            0.002],            
             	     [69.2488,         0.02685,         0.005,    	0.019,     	  0.004],
             	     [79.532,          0.04285,         0.005,    	0.040,     	  0.009],
             	     [92.5326,         0.06100,         0.0061,    	0.058,     	  0.010],
             	     [96.8281,         0.07530,         0.00753,    	0.069,     	  0.008],              # Davide's values (0.07, 0.011) are for 500 trajectories
             	     [107.875,         0.09425,         0.00943,    	0.102,     	  0.014]])
 
# "Expt 2018 A" normal-incidence data.
# Columns: average E_i (taken from AIMD), expt S0, expt S0 error.
Laser_off_2018 = np.array([[55.5,  0.01761, 0.005],
                          [72.8,  0.03312, 0.005],
                          [79.9,  0.04563, 0.005],
                          [89.9,  0.06854, 0.007],
                          [98.5,  0.09861, 0.0098],
                          [107.8, 0.11363, 0.011]])

# "Expt 2018 B": a single (E_i, S0, S0 error) point, measured at the same E_i
# (98.5) as row 4 of Laser_off_2018.
Laser_off_2018_v2 = np.array([98.5, 0.084, 0.0084])


# Velocity tables; not referenced by the code shown here. Column meanings are
# not stated in this file. They are presumably (temperature, stream velocity v0,
# width alpha) in the form used by MB_vel. TODO confirm.
vel = np.array([[298, 2454., 159.],
		[350, 2671., 194.],
		[400, 2856., 232.],
		[450, 3076., 266.],
		[500, 3151., 257.],
		[550, 3321., 288.]])

# Same layout as `vel`, presumably for the 2018 measurements (TODO confirm).
vel_2018 = np.array([[298, 2400.71667, 126.41333],
		     [350, 2744.56, 170.79556],
		     [400, 2874.23, 192.87667],
		     [450, 3042.09333, 227.39333],
		     [500, 3177.48, 263.80333],
		     [550, 3310.52333, 316.53]])

# Angle-resolved S0 tables. The first column is the incidence angle theta_i in
# degrees. The plots below draw every set against -angle, at phi_i = 0 per the
# plot annotation.
# "Expt 2016":
rot_surf = np.array([[0,      0.0753,  0.00753],                      # Angle, S0, err
                     [-10,    0.0751,  0.00751],
                     [-20,    0.0702,  0.00702],
                     [-30,    0.0574,  0.00574],
                     [-40,    0.0426,  0.005],
                     [-50,    0.0228,  0.005]])

# "Expt 2018 A" (angle, S0, err). The plots rescale it to Expt 2016 at theta_i = 0.
rot_surf_2018 = np.array([[-50,	0.02743, 0.005],
			  [-40,	0.05977, 0.005977],
			  [-30,	0.06813, 0.006813],
			  [-20,	0.08561, 0.008561],
			  [-10,	0.09312, 0.009312],
			  [0,	0.10084, 0.010084],
			  [10,	0.09689, 0.009689],
			  [20,	0.08059, 0.008059],
			  [30,	0.04747, 0.005],
			  [40,	0.01734, 0.005]])

# SRP32-vdW AIMD results (angle, S0, err).
AIMD_rot_surf = np.array([[50,	0.008,	0.003],
			  [40,	0.021,	0.005],
			  [30,	0.037,	0.006],
			  [20,	0.039,	0.006],
			  [10,	0.060,	0.008],
			  [0,	0.069,	0.008],
			  [-10,	0.069,	0.008],
			  [-20,	0.053,	0.007],
			  [-30,	0.054,	0.007],
			  [-40,	0.050,	0.007],
			  [-50,	0.033,	0.006]])

# SRP32-vdW AIMD results including trapping (angle, S0, err).
AIMD_rot_surf_trap = np.array([[50,  0.142, 0.011],
			       [40,  0.099, 0.009],
			       [30,  0.072, 0.008],
			       [20,  0.057, 0.007],
			       [10,  0.067, 0.008],
			       [0,   0.073, 0.008],
			       [-10, 0.070, 0.008],
			       [-20, 0.058, 0.007],
			       [-30, 0.071, 0.008],
			       [-40, 0.097, 0.009],
			       [-50, 0.153, 0.011]])


# Transpose each table so that row k holds column k of the data above,
# e.g. Laser_off[0] = energies, Laser_off[1] = S0, Laser_off[2] = S0 errors.
Laser_off, Laser_off_2018 = Laser_off.T, Laser_off_2018.T
rot_surf, rot_surf_2018 = rot_surf.T, rot_surf_2018.T
AIMD_rot_surf, AIMD_rot_surf_trap = AIMD_rot_surf.T, AIMD_rot_surf_trap.T

#~~~S shape curve~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def S_shape(x, E0, A, W):
    """Error-function ("S-shaped") curve.

    Evaluates S(x) = A/2 * (1 + erf((x - E0) / W)). The curve rises from 0 to A,
    passes through A/2 at x = E0, and W sets the width of the rise.

    x may be a scalar or a numpy array; the result has the same shape.
    """
    reduced = (x - E0) / W
    return A / 2.0 * (1 + spc.erf(reduced))
#~~~End~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#~~~Finding energy~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def Energy(S0, E0, A, W):
    """Invert the error-function curve: return x with A/2*(1+erf((x-E0)/W)) == S0.

    Meaningful for 0 < S0 < A. At S0 == 0 or S0 == A, erfinv yields -inf or +inf.
    Outside that range the result is nan. S0 may be a scalar or a numpy array.
    """
    erf_value = S0 * 2.0 / A - 1
    return spc.erfinv(erf_value) * W + E0
#~~~End~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#~~~MB velocity distribution~~~~~~~~~~~~~~~~~~
def MB_vel(v, v0, alpha):
    """Unnormalised velocity distribution v**3 * exp(-((v - v0) / alpha)**2).

    v may be a scalar or a numpy array, such as a velocity grid. np.exp is used
    instead of math.exp because math.exp raises TypeError for arrays with more
    than one element.

    Parameters
    ----------
    v : float or numpy.ndarray
        Velocity (or velocities) at which to evaluate the distribution.
    v0 : float
        Centre (stream velocity) of the Gaussian factor.
    alpha : float
        Width parameter of the Gaussian factor.

    Returns
    -------
    float or numpy.ndarray
        Same shape as ``v``.
    """
    u = (v - v0) / alpha
    return v * v * v * np.exp(-u * u)
#~~~End~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# Evaluation grids: integer energies 50..149 (xfit), and presumably the velocity
# grid for MB_vel: 1500..3990 in steps of 10 (vv).
xfit = np.arange(50, 150)
vv = np.arange(1500, 4000, 10)

# S_shape / Energy parameters, ordered (E0, A, W). The names suggest fits to the
# laser-off experimental data (2016, 2018, and Davide's 2018 fit).
LO_fit_expt = np.array([102.97779996, 0.16960127, 48.01642888])
LO_fit_expt_2018 = np.array([125.3667177, 0.35523554, 57.55801635])
LO_fit_expt_2018_Davide = np.array([106.161503831, 0.228161871789, 46.069353565])

# Energies at which each fitted curve reproduces the given S0 values
# (row 1 = expt S0 of the 2018 data, row 3 = SRP S0 of the 2016 table).
E_2018_LO_expt = Energy(Laser_off_2018[1], *LO_fit_expt)
E_SRP_LO_2018 = Energy(Laser_off[3], *LO_fit_expt_2018)

E_SRP_LO_2018_Davide = Energy(Laser_off[3], *LO_fit_expt_2018_Davide)
# The original author checked that Laser_off[0] - E_SRP_LO_2018_Davide agrees
# with these hard-coded shifts.
E_shift_Davide = np.array([-15.22, -8.15, -3.78, -7.91, -7.12, -6.05])

# Offset arrays; not referenced by the plotting code shown here.
shift_log = np.array([0.005, 0.005, 0.005, 0.002, 0.003, 0.003])
shift_lin = np.array([-2, 4, 1, -2, 2.5, 1])

shift_SRP_x = np.array([5, -2, 2, 0, -4, -3])
shift_SRP_y = np.array([0, 0, 0, -0.02, -0.02, -0.03])
### PLOTS ###################################################
# Expt 2016 vs Expt 2018 - normal incidence
# ~~~ LO ~~~
# Linear
# Start from an explicit new figure. On non-interactive backends (e.g. Agg),
# plt.show() does not close the current figure, so relying on the implicit
# current figure can draw on top of an earlier plot.
plt.figure()
plt.subplot(2, 2, 1)
plt.errorbar(Laser_off[0], Laser_off[1], yerr=Laser_off[2],
             color='red', marker='^', mec='red', markersize=5,
             linestyle='None', label='Expt 2016')
plt.errorbar(Laser_off_2018[0], Laser_off_2018[1], yerr=Laser_off_2018[2],
             color='blue', marker='^', mec='blue', markersize=5,
             linestyle='None', label='Expt 2018 A')
plt.errorbar(Laser_off_2018_v2[0], Laser_off_2018_v2[1], yerr=Laser_off_2018_v2[2],
             color='k', marker='^', mec='k', markersize=5,
             linestyle='None', label='Expt 2018 B')
plt.ylim([0.00, 0.15])
plt.xlim([55, 115])
plt.xticks(np.arange(60, 120, 10.0), fontsize=12)
plt.yticks(fontsize=12)
plt.xlabel('E$_{\\rm{i}}$ (kJ/mol)', fontsize=12)
plt.ylabel('S$_0$', fontsize=12)
# frameon=False has the same effect as the legacy lg.draw_frame(False).
lg = plt.legend(loc='best', numpoints=1, prop={'size': 12}, frameon=False)
plt.savefig('SI_Fig5_Norm_expt.png', bbox_inches='tight', dpi=300)
plt.show()

# S0 at normal incidence (theta_i = 0) of each angular data set. The rows are
# looked up by angle rather than by hard-coded row indices.
S0_2016_normal = rot_surf[1][rot_surf[0] == 0][0]
S0_2018_normal = rot_surf_2018[1][rot_surf_2018[0] == 0][0]

# Reference curve S0(theta) = S0(theta=0) * cos^2(theta), theta in degrees.
angfit = np.arange(-60.0, 60.0)
cos_ang = np.cos(np.radians(angfit))
cosfit = S0_2016_normal * cos_ang * cos_ang

# Scaled expt and AIMD comparison at phi_i = 0 (see the in-plot annotation).
# Start from an explicit new figure. plt.show() does not close the previous
# figure on non-interactive backends.
plt.figure()
plt.subplot(2, 2, 1)
# Factor that maps the Expt 2018 A angular data onto the Expt 2016 S0 at theta_i = 0.
scl = S0_2016_normal / S0_2018_normal
plt.plot(angfit, cosfit, 'k--')
plt.errorbar(-rot_surf[0], rot_surf[1], yerr=rot_surf[2],
             color='red', marker='o', mec='red', markersize=4,
             linestyle='None', label='Expt 2016')
# Rescale the error bars together with the data so the relative uncertainty is
# preserved (as done for the scaled LO data in Fig3).
plt.errorbar(-rot_surf_2018[0], scl * rot_surf_2018[1], yerr=scl * rot_surf_2018[2],
             color='white', marker='o', mec='black', ecolor='black', markersize=4,
             mew=1, linestyle='None', label='Scaled Expt 2018 A')
plt.errorbar(-AIMD_rot_surf[0], AIMD_rot_surf[1], yerr=AIMD_rot_surf[2],
             color='blue', marker='o', mec='blue', markersize=4,
             linestyle='None', label='SRP32-vdW')
plt.annotate('', xy=(-20, 0.07), xytext=(-20, 0.09),
             arrowprops=dict(facecolor='black', shrink=0.1))
plt.annotate('Terrace', xy=(-31, 0.09), xytext=(-31, 0.09), fontsize=12)
plt.annotate('', xy=(40, 0.07), xytext=(40, 0.09),
             arrowprops=dict(facecolor='black', shrink=0.1))
plt.annotate('Step', xy=(33, 0.0885), xytext=(33, 0.0885), fontsize=12)
plt.ylim([0.0, 0.1])
plt.xlim([-55, 55])
plt.xticks(np.arange(-40, 50, 20.0), fontsize=12)
plt.yticks(fontsize=12)
plt.xlabel('$\\rm{\\theta_i}$ ('u'\xb0)', fontsize=12)
plt.ylabel('S$_0$', fontsize=12)
plt.annotate('$\\rm{\\phi_i} = 0$'u'\xb0', xy=(-53, 0.0885), xytext=(-53, 0.0885), fontsize=12)
lg = plt.legend(loc='lower center', numpoints=1, prop={'size': 10}, frameon=False)
plt.savefig('Fig4.png', bbox_inches='tight', dpi=300)
plt.show()

# Including trapping for SI
# Start from an explicit new figure. plt.show() does not close the previous
# figure on non-interactive backends.
plt.figure()
plt.subplot(2, 2, 1)
# Factor that maps the Expt 2018 A angular data onto the Expt 2016 S0 at
# theta_i = 0. The rows are looked up by angle rather than by fixed index.
scl = rot_surf[1][rot_surf[0] == 0][0] / rot_surf_2018[1][rot_surf_2018[0] == 0][0]
plt.errorbar(-rot_surf[0], rot_surf[1], yerr=rot_surf[2],
             color='red', marker='o', mec='red', markersize=4,
             linestyle='None', label='Expt 2016')
# Rescale the error bars together with the data so the relative uncertainty is preserved.
plt.errorbar(-rot_surf_2018[0], scl * rot_surf_2018[1], yerr=scl * rot_surf_2018[2],
             color='white', marker='o', mec='blue', ecolor='blue', markersize=4,
             mew=1, linestyle='None', label='Scaled Expt 2018 A')
plt.errorbar(-AIMD_rot_surf[0], AIMD_rot_surf[1], yerr=AIMD_rot_surf[2],
             color='black', marker='o', mec='black', markersize=4,
             linestyle='None', label='SRP32-vdW')
plt.errorbar(-AIMD_rot_surf_trap[0], AIMD_rot_surf_trap[1], yerr=AIMD_rot_surf_trap[2],
             color='green', marker='o', mec='green', markersize=4,
             linestyle='None', label='SRP32-vdW incl. trapping')
plt.ylim([0.0, 0.18])
plt.xlim([-55, 55])
plt.xticks(np.arange(-40, 50, 20.0), fontsize=12)
plt.yticks(np.arange(0, 0.18, 0.04), fontsize=12)
plt.xlabel('$\\rm{\\theta_i}$ ('u'\xb0)', fontsize=12)
plt.ylabel('S$_0$', fontsize=12)
plt.annotate('$\\rm{\\phi_i} = 0$'u'\xb0', xy=(-10, 0.005), xytext=(-10, 0.005), fontsize=12)
lg = plt.legend(loc='upper center', numpoints=1, prop={'size': 10}, frameon=False)
plt.savefig('SI_Fig6_S0_trap.png', bbox_inches='tight', dpi=300)
plt.show()

# ~~~ LO ~~~
# Linear: Expt 2016, Expt 2018 A rescaled onto Expt 2018 B, and Expt 2018 B.
# Start from an explicit new figure. plt.show() does not close the previous
# figure on non-interactive backends.
plt.figure()
plt.subplot(2, 2, 3)
# Row of Expt 2018 A measured at the same E_i as the single Expt 2018 B point.
# The literals are identical (98.5), so exact equality is intended here.
i_ref = np.nonzero(Laser_off_2018[0] == Laser_off_2018_v2[0])[0][0]
# Rescale the whole Expt 2018 A curve (values and errors) so that this
# reference point coincides with Expt 2018 B.
S0_2018A_scaled = Laser_off_2018[1] / Laser_off_2018[1][i_ref] * Laser_off_2018_v2[1]
err_2018A_scaled = Laser_off_2018[2] / Laser_off_2018[1][i_ref] * Laser_off_2018_v2[1]
plt.errorbar(Laser_off[0], Laser_off[1], yerr=Laser_off[2],
             color='red', marker='^', mec='red', markersize=5,
             linestyle='None', label='Expt 2016')
plt.errorbar(Laser_off_2018[0][i_ref], Laser_off_2018[1][i_ref], yerr=Laser_off_2018[2][i_ref],
             color='blue', marker='o', mec='blue', markersize=5,
             linestyle='None', label='Expt 2018 A')
plt.errorbar(Laser_off_2018[0], S0_2018A_scaled, yerr=err_2018A_scaled,
             color='white', marker='^', mec='blue', ecolor='blue', markersize=5,
             linestyle='None', label='Scaled Expt 2018 A')
plt.errorbar(Laser_off_2018_v2[0], Laser_off_2018_v2[1], yerr=Laser_off_2018_v2[2],
             color='k', marker='^', mec='k', markersize=5,
             linestyle='None', label='Expt 2018 B')
plt.ylim([0.00, 0.15])
plt.xlim([55, 115])
plt.xticks(np.arange(60, 120, 10.0), fontsize=12)
plt.yticks(fontsize=12)
plt.xlabel('E$_{\\rm i}$ (kJ/mol)', fontsize=12)
plt.ylabel('S$_0$', fontsize=12)
lg = plt.legend(loc='best', numpoints=1, prop={'size': 11}, frameon=False)
plt.subplots_adjust(left=0.1, bottom=0.08, right=0.95, top=0.99, wspace=0.15, hspace=0.15)
plt.savefig('Fig3.png', bbox_inches='tight', dpi=300)
plt.show()
