#!/usr/bin/env python
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np

# Scale factor applied to the DFT energy columns below so they can be drawn
# on the "Potential energy (kJ/mol)" axis (name suggests eV -> kJ/mol).
eV2kJ = 96.4853075

# DFT values were with MP smearing instead of Fermi-Dirac.
# Columns: Z, 1H, 2H, 3H
DFT = np.array([
    [2.00, -128.27688, -139.32078, -139.42756],
    [2.25, -137.27648, -140.61445, -140.62328],
    [2.50, -140.24511, -141.19300, -141.20520],
    [2.75, -141.25500, -141.49727, -141.52012],
    [3.00, -141.60124, -141.66230, -141.68597],
    [3.25, -141.72308, -141.74289, -141.76188],
    [3.50, -141.76144, -141.77211, -141.78491],
    [3.75, -141.76406, -141.77150, -141.77908],
    [4.00, -141.75061, -141.75528, -141.75951],
    [4.25, -141.72959, -141.73227, -141.73469],
    [4.50, -141.70642, -141.70780, -141.70896],
    [4.75, -141.68429, -141.68456, -141.68455],
    [5.00, -141.66439, -141.66369, -141.66282],
    [5.25, -141.64713, -141.64580, -141.64456],
    [5.50, -141.63269, -141.63121, -141.63007],
    [5.75, -141.62132, -141.62002, -141.61924],
    [6.00, -141.61322, -141.61223, -141.61184],
    [6.25, -141.60834, -141.60771, -141.60765],
    [6.50, -141.60650, -141.60621, -141.60646],
])

# Shift every energy column so that its last row (largest Z) is the zero of
# energy, then apply the unit conversion. The right-hand side is evaluated
# into a temporary before assignment, so the reference row is read intact.
DFT[:, 1:] = (DFT[:, 1:] - DFT[-1, 1:]) * eV2kJ

# TODO: this needs to be updated to the values of the final NN.
# Columns: Z, 1H, 2H, 3H
# The energies are used as-is (no shift or unit conversion is applied to
# this table); the last row is already zero in every energy column.
NN = np.array([
    [2.50, 156.983270578861, 39.4597630465862, 43.2202924352098],
    [2.75, 42.3626992329711, 10.5350273433831, 10.990892845786],
    [3.00, 4.29568498013963, -4.9234419439399, -6.30463853927935],
    [3.25, -9.40960692292108, -12.8241742240523, -14.659792296991],
    [3.50, -14.322100533363, -16.1733928391543, -17.6600032450186],
    [3.75, -15.1318308771743, -16.5526987720683, -17.4868515437959],
    [4.00, -14.075854922178, -15.2281867139447, -15.8129642486467],
    [4.25, -12.1619181949864, -12.9497518718072, -13.5200891584922],
    [4.50, -9.70623587218974, -10.242205368594, -11.0134983976833],
    [4.75, -7.37618370714135, -7.80592548812972, -8.50383580229713],
    [5.00, -5.47445555327579, -5.81204229692249, -6.15346227003858],
    [5.25, -3.83341325861505, -4.08716779877236, -4.07220245082669],
    [5.50, -2.63881092265161, -2.53927823131981, -2.36539139019488],
    [5.75, -1.60147601509566, -1.29048558279592, -1.03192637830078],
    [6.00, -0.775388808528528, -0.418242092569217, -0.153014118988096],
    [6.25, -0.17664361574455, 0.019481207323872, 0.117701148836934],
    [6.50, 0.0, 0.0, 0.0],
])

# Figure setup: render text through LaTeX (the x-label below relies on
# LaTeX macros), pick the colour-blind-friendly palette, and open a
# 3.69 x 3.0 inch canvas.
mpl.rc("text", usetex=True)
plt.style.use('tableau-colorblind10')
fig = plt.figure(figsize=(3.69, 3.))

# (table column, legend label, colour) for each of the three curves; the
# same colour is shared by the DFT and NN versions of a curve.
curves = (
    (1, '1H', 'C0'),
    (2, '2H', 'C1'),
    (3, '3H', 'C2'),
)

# DFT curves first: dashed, and unlabelled so they stay out of the legend.
for column, _, colour in curves:
    plt.plot(DFT[:, 0], DFT[:, column], c=colour, ls='--')

# NN curves on top: solid, and these provide the legend entries.
for column, tag, colour in curves:
    plt.plot(NN[:, 0], NN[:, column], label=tag, c=colour)

# Horizontal zero-energy baseline; it extends past the visible x-range set
# below, so it always spans the full width of the axes.
plt.plot([0., 7.5], [0, 0], c='black')

plt.legend(numpoints=1, loc='lower right')
plt.xlim(2.5, 6.5)
plt.ylim(-20, 10)
plt.xlabel(r'$Z_\textrm{C}$ (\r{A})')
plt.ylabel('Potential energy (kJ/mol)')
# Inward-pointing ticks drawn on all four sides of the frame.
plt.tick_params(
    labelbottom=True,
    length=6,
    width=1,
    direction='in',
    top=True,
    right=True,
)

plt.tight_layout()
plt.savefig('vdwwell.pdf')
