#!/usr/bin/env python
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from math import erf
from scipy.stats import linregress
from pylab import *
from scipy.optimize import curve_fit

# Render all figure text through LaTeX (the labels below use math mode) and
# open a 3.69 x 4 inch figure for the plot.
mpl.rcParams['text.usetex'] = True
fig = plt.figure(figsize=(3.69, 4.0))

# HD-NNP results, one row per delta J = 0..8; the rows line up with Erot_exp,
# whose delta Erot values are used as their x positions.
# Columns: delta J, E_s, standard deviation and error.
# Column 1 (E_s) is plotted as-is on the "(kJ/mol)" axis and column 3 is used
# both as the error bar and as curve_fit's sigma; column 2 is not used in this
# file. NOTE(review): E_s assumed to already be in kJ/mol (no conversion is
# applied) — confirm.
NN_v21 = np.array(
[[0 , 41.72439015561618 , 9.485805763465775 , 0.7296773664204442],
[1 , 42.42073875907346 , 9.065015701905772 , 0.6016662241374643],
[2 , 42.902445888578946 , 10.082495441505088 , 0.6428364303070906],
[3 , 44.815613173152286 , 11.937525244727995 , 0.8526803746234284],
[4 , 44.25173956460955 , 12.195588174731927 , 0.9166765410695522],
[5 , 44.718068439717264 , 12.85240142127313 , 0.9945486813635908],
[6 , 43.3207975066849 , 13.208289612969727 , 0.9557182943292565],
[7 , 44.36679219273018 , 12.650799985481948 , 0.8415186973684327],
[8 , 43.534464394935874 , 12.460540160428605 , 0.8234156709728868]]
)

# Same column layout as NN_v21; plotted with the legend label
# "nu=2, j=1 -> nu=2" (NN_v21 is labelled "... -> nu=1").
NN_v22 = np.array(
[[0 , 40.82687347777624 , 6.786912939272291 , 0.1820391831590776],
[1 , 40.35738109402743 , 6.8524125493398 , 0.1741638528505864],
[2 , 40.085896389222924 , 7.057698684938841 , 0.18547220370622455],
[3 , 39.35597438838326 , 8.14618053327639 , 0.24190648068514037],
[4 , 37.36930721406284 , 8.982496505240706 , 0.3304264752869901],
[5 , 35.99289560301496 , 9.011280190530263 , 0.3728894274764357],
[6 , 33.52809466484295 , 7.866678869518778 , 0.3048276144458186],
[7 , 33.53010118886324 , 7.612504509096008 , 0.2679728502149074],
[8 , 33.29644537916573 , 6.496484179055161 , 0.24312480999168518]]
)

# Experimental data: delta Erot, E_s (one row per delta J, no uncertainties).
# Both columns are multiplied below by 96.48530749926, which matches the
# eV -> kJ/mol conversion factor, so the raw values are presumably in eV —
# confirm against the data source. The first delta Erot is zero up to
# floating-point rounding.
# NOTE(review): rows 3 and 4 are in swapped delta Erot order relative to
# Erot_exp (0.0340 listed before 0.0218). This is harmless for scatter/curve_fit,
# but check that each E_s is still paired with the right delta J.
exp_v21 = np.array(
[[-8.326672684688674e-17, 0.3041326145191908],
[0.00489539748953971, 0.30026308820618886],
[0.012175732217573204, 0.28959544404148974],
[0.03401673640167363, 0.2813100331996242],
[0.0218410041841004, 0.2733288825580715],
[0.04857740585774062, 0.27329919523664836],
[0.06564853556485362, 0.2839398002447462],
[0.08510460251046026, 0.2663299223112067],
[0.1068200836820084, 0.2806962164693551]]
)
exp_v21[:,:] *= 96.48530749926

# Same layout and unit conversion as exp_v21. There is no delta J = 0 point:
# the first row sits at 0.00489539748953971, the delta J = 1 entry of Erot_exp.
exp_v22 = np.array(
[[0.00489539748953971, 0.2253796771468951],
[0.01217573221757326, 0.22537159327533388],
[0.0218410041841004, 0.21563401247146252],
[0.034016736401673575, 0.20456119908903128],
[0.04870292887029287, 0.20161351288819998],
[0.06564853556485362, 0.1963981613373511],
[0.08510460251046031, 0.1899808217116064],
[0.10694560669456066, 0.1742337186645444]]
)
exp_v22[:,:] *= 96.48530749926

# delta Erot (column 0) and the corresponding delta J (column 1) according to
# experiment. Column 0 is the x position of every HD-NNP point and fit, and
# supplies the x-tick positions; column 1 supplies the tick labels. Only
# column 0 is rescaled by the same eV -> kJ/mol factor used for the
# experimental arrays (raw values presumably in eV — confirm). The first entry
# is zero up to floating-point rounding.
Erot_exp = np.array(
[[-5.551115123125783e-17, 0],
[0.00489539748953971, 1],
[0.012175732217573149, 2],
[0.0218410041841004, 3],
[0.03401673640167363, 4],
[0.04857740585774062, 5],
[0.06564853556485356, 6],
[0.08510460251046031, 7],
[0.1068200836820084, 8]]
)
Erot_exp[:,0] *= 96.48530749926

# delta Erot, delta J according to NN for v=2. delta Erot is not rescaled here,
# so it is presumably already in kJ/mol — confirm.
# NOTE(review): Erot_NN is not referenced anywhere else in this file.
Erot_NN = np.array(
[[0., 0 ],
[0.234290361100104, 1],
[0.702720903919177, 2],
[1.40499145301706, 3],
[2.34065212635048, 4],
[3.50910388523928, 5],
[4.90959914222389, 6],
[6.54124258119044, 7],
[8.40299214152086, 8],
[10.4936601373225, 9]]
)

def baule(E, mass):
	"""Return the energy transfer E * 2.4 * mu / (1 + mu)**2.

	This has the same mass-ratio form as ``baule_limit`` but a prefactor of
	2.4 instead of 4.

	E    : energy being partially transferred (float or ndarray); the result
	       is linear in E and carries the same units.
	mass : mass of the collision partner; mu is the fixed projectile mass
	       35.453 + 1.008 (the standard atomic weights of Cl and H) divided
	       by this value.
	"""
	ratio = (35.453 + 1.008) / mass
	denominator = (1 + ratio)**2
	return E * 2.4 * ratio / denominator

def baule_limit(E, mass):
	"""Return the energy transfer E * 4 * mu / (1 + mu)**2.

	E    : energy being partially transferred (float or ndarray); the result
	       is linear in E and carries the same units.
	mass : mass of the collision partner; mu is the fixed projectile mass
	       35.453 + 1.008 (the standard atomic weights of Cl and H) divided
	       by this value.
	"""
	ratio = (35.453 + 1.008) / mass
	denominator = (1 + ratio)**2
	return E * 4. * ratio / denominator
	
# HD-NNP mean E_s with error bars (fourth data column), placed at the
# experimental delta Erot of each delta J; orange: final nu=1, blue: final nu=2.
nn_series = (
	(NN_v21, r'HD-NNP ($\nu=2,j=1 \rightarrow \nu=1$)', 'orange'),
	(NN_v22, r'HD-NNP ($\nu=2,j=1 \rightarrow \nu=2$)', 'b'),
)
for nn_data, nn_label, nn_color in nn_series:
	plt.errorbar(
		Erot_exp[:,0], nn_data[:,1], yerr=nn_data[:,3],
		linestyle='None', capsize=4, marker='o', label=nn_label, color=nn_color,
	)

def func(x, a, c):
	"""Straight-line model for curve_fit: slope ``a``, intercept ``c``.

	Works element-wise on ndarray ``x``.
	"""
	slope_term = a*x
	return slope_term + c

nstd = 2. # shaded bands span +/- nstd standard deviations, i.e. 2-sigma


def _plot_linear_fit(x, y, sigma, color, x_eval):
	"""Fit ``func`` (a straight line) to (x, y) and draw it with an error band.

	Parameters
	----------
	x, y : array_like
		Data to fit.
	sigma : array_like or None
		Per-point y uncertainties passed to ``curve_fit``; ``None`` gives an
		unweighted fit.
	color : str
		Fill colour of the +/- ``nstd``-sigma band; the fitted line is black.
	x_eval : array_like
		Points at which the fit and its band are evaluated and drawn.

	Returns
	-------
	popt, pcov : ndarray
		Best-fit (slope, intercept) and their covariance from ``curve_fit``.
	"""
	popt, pcov = curve_fit(func, x, y, sigma=sigma)
	x_eval = np.asarray(x_eval, dtype=float)
	fit = func(x_eval, *popt)
	# Propagate the full parameter covariance to y = a*x + c:
	#   var(y) = x^2 var(a) + 2 x cov(a, c) + var(c).
	# Shifting a and c independently by +/- nstd*perr ignores their
	# correlation. For x >= 0 that band is never narrower, and is generally
	# wider, than the propagated one.
	var_fit = x_eval**2 * pcov[0, 0] + 2. * x_eval * pcov[0, 1] + pcov[1, 1]
	# Clip tiny negative values caused by rounding before taking the root.
	half_width = nstd * np.sqrt(np.clip(var_fit, 0., None))
	plt.fill_between(x_eval, fit + half_width, fit - half_width, alpha=.25, color=color)
	plt.plot(x_eval, fit, 'k')
	return popt, pcov


# HD-NNP fits, weighted by the error column of the NN data.
# NOTE(review): curve_fit's default absolute_sigma=False rescales pcov by the
# reduced chi-square. Pass absolute_sigma=True if column 3 holds absolute
# standard errors — confirm.
_plot_linear_fit(Erot_exp[:,0], NN_v21[:,1], NN_v21[:,3], 'orange', Erot_exp[:,0])
_plot_linear_fit(Erot_exp[:,0], NN_v22[:,1], NN_v22[:,3], 'b', Erot_exp[:,0])

plt.scatter( exp_v21[:,0], exp_v21[:,1], marker='s', label=r'Exp. ($\nu=2,j=1 \rightarrow \nu=1$)', color='orange' )
plt.scatter( exp_v22[:,0], exp_v22[:,1], marker='s', label=r'Exp. ($\nu=2,j=1 \rightarrow \nu=2$)', color='b' )

# The experimental data carry no uncertainties, so these fits are unweighted.
# They are drawn on the same delta Erot grid as the HD-NNP fits.
_plot_linear_fit(exp_v21[:,0], exp_v21[:,1], None, 'orange', Erot_exp[:,0])
_plot_linear_fit(exp_v22[:,0], exp_v22[:,1], None, 'b', Erot_exp[:,0])

# Horizontal reference levels: 50 minus the transfer predicted by baule()
# (dotted) and baule_limit() (dashed) for mass 196.966. Presumably 50 is the
# incident energy in kJ/mol — confirm. The lines are drawn over x = -5..15 and
# the previous x-limits are restored, so they span the axis without changing
# its range.
xlim = plt.xlim()
for transfer, dash in ((baule, ':'), (baule_limit, '--')):
	level = 50. - transfer(50., 196.966)
	plt.plot([-5., 15.], [level, level], color='black', linestyle=dash, label='')
plt.xlim(xlim)

# Legend, y range, inward ticks on all sides and axis labels. The x ticks sit
# at the experimental delta Erot positions but are labelled with integer
# delta J values.
plt.legend(loc='best', numpoints=1, frameon=True)
plt.ylim(15., 60.)
plt.tick_params(length=6, width=1, direction='in', top=True, right=True)
plt.ylabel(r'$\left<E_s\right>$ (kJ/mol)')
plt.xlabel(r'$\Delta J$')
plt.xticks(Erot_exp[:,0], Erot_exp[:,1].astype(int).tolist())

plt.tight_layout()
plt.savefig('translationJ.pdf')
