#!/usr/bin/env python
import	numpy as np
import	matplotlib.pyplot as plt 
import matplotlib
from	scipy.interpolate import interp1d
from 	matplotlib.patches import Circle
from	numpy.random import rand
# default font size for all text in the figure
matplotlib.rc("font", size=15)
# two panels stacked vertically on a 5 x 10 inch figure
fig, ax = plt.subplots(2, 1, figsize=(5, 10))
####################################################################
def vector_plotter(start, end, color='blue'):
    '''Draw the segment start -> end on the current axes and dot its end.

    start, end: points, presumably (x, y) pairs (end[0], end[1] are used
    for the marker). color: any matplotlib colour. Returns None.
    '''
    segment = np.vstack([start, end])   # one row per endpoint
    per_axis = segment.T                # one row per coordinate axis
    plt.plot(*per_axis, c=color)
    plt.scatter(end[0], end[1], c=color)
####################################################################
def scatter_plotter(v, c):
    '''Mark the point v (first two entries used as x, y) with colour c.'''
    x, y = v[0], v[1]
    plt.scatter(x, y, color=c)
####################################################################
def derivate(x, y):
    '''Return the central finite-difference derivative dy/dx at interior points.

    For every interior index i (1 <= i <= len(y) - 2) the slope is
    (y[i+1] - y[i-1]) / (x[i+1] - x[i-1]).

    x, y: indexable sequences (lists or numpy arrays); x must have at
          least len(y) entries.
    Returns (x_prime, y_prime): two lists with len(y) - 2 entries each
          (both empty when len(y) < 3); x_prime holds the interior x values.
    Coinciding abscissae x[i+1] == x[i-1] lead to a division by zero.
    '''
    x_prime, y_prime = [], []
    for i in range(1, len(y) - 1):
        x_prime.append(x[i])
        # float() forces true division: this script runs under Python 2,
        # where integer inputs would otherwise be floor-divided.
        y_prime.append(float(y[i + 1] - y[i - 1]) / (x[i + 1] - x[i - 1]))
    return x_prime, y_prime
####################################################################
def fit_circle(A, B, C):
    '''Fit the circle passing through three points A, B, C.

    The centre is the circumcentre of triangle ABC, obtained in closed
    form after translating A to the origin (this limits cancellation when
    the points are close together, as on a finely sampled path). Unlike
    intersecting the perpendicular bisectors through their slopes, this
    has no division by zero for horizontal or vertical segments.

    A, B, C: (x, y) points (2-element sequences or numpy arrays).

    Returns (r, center, circ):
        r      -- circle radius; np.inf when the three points are collinear
                  (including coincident points), so that 1/r gives a
                  curvature of 0 instead of NaN.
        center -- numpy array [x, y] of the centre ([nan, nan] if collinear).
        circ   -- unfilled gray matplotlib Circle patch with that centre/radius.
    '''
    # float() also guards against integer floor division under Python 2
    x0, y0 = float(A[0]), float(A[1])
    # B and C relative to A
    bx, by = float(B[0]) - x0, float(B[1]) - y0
    cx, cy = float(C[0]) - x0, float(C[1]) - y0

    # proportional to the signed area of ABC: zero <=> collinear points
    d = 2.0 * (bx * cy - by * cx)
    if d == 0.0:
        r = np.inf
        center = np.array([np.nan, np.nan])
    else:
        b2 = bx * bx + by * by
        c2 = cx * cx + cy * cy
        # circumcentre relative to A
        ux = (cy * b2 - by * c2) / d
        uy = (bx * c2 - cx * b2) / d
        center = np.array([x0 + ux, y0 + uy])
        # distance from the centre to A (A sits at the relative origin)
        r = np.hypot(ux, uy)

    circ = Circle(center, r, fill=False, color='gray')
    return r, center, circ

#############################################################################
### MAIN ####################################################################
#############################################################################

# read files
# Each row of a MEP file holds r [Ang], Z [Ang], E [eV] in columns 0, 1, 2.
# np.loadtxt reads each file once and closes it, skips blank and '#'
# comment lines, and ndmin=2 keeps 1-D columns even for a one-row file.
r_pbe, z_pbe, E_pbe = np.loadtxt('PBE_MEP.dat', usecols=(0, 1, 2),
                                 unpack=True, ndmin=2)
r_srp, z_srp, E_srp = np.loadtxt('SRP_MEP.dat', usecols=(0, 1, 2),
                                 unpack=True, ndmin=2)

# find the TS: the highest-energy point along each MEP.
# print(...) with a single argument behaves the same on Python 2 and 3.
print("TS found on MEP")
print("          Eb [ eV ]     r [ Ang ]    Z [ Ang ]")
i_ts_pbe = np.argmax(E_pbe)
i_ts_srp = np.argmax(E_srp)
print(" PBE     %6.3f      %4.3f         %4.3f " % (E_pbe[i_ts_pbe], r_pbe[i_ts_pbe], z_pbe[i_ts_pbe]))
print(" SRP     %6.3f      %4.3f         %4.3f " % (E_srp[i_ts_srp], r_srp[i_ts_srp], z_srp[i_ts_srp]))

plt.subplot(2, 1, 1)

# MEPs in the (r, Z) plane, drawn below the markers (zorder 0).
# linewidth is a float property in matplotlib, so pass a number, not '2'.
plt.plot(r_pbe, z_pbe, color='green', linewidth=2, zorder=0)
plt.plot(r_srp, z_srp, color='blue', linewidth=2, zorder=0)

# TS = highest-energy point of each MEP (only the first gets a legend entry)
i_ts_pbe = np.argmax(E_pbe)
i_ts_srp = np.argmax(E_srp)
plt.scatter(r_pbe[i_ts_pbe], z_pbe[i_ts_pbe], color='red', label='transition state', zorder=2)
plt.scatter(r_srp[i_ts_srp], z_srp[i_ts_srp], color='red', zorder=2)

# compute and store the curvature of the MEPs from the circle through
# every triple of consecutive points; one row per interior point:
#   [ r, Z, curvature = 1/radius, circle centre, Circle patch ]
curvature_pbe = []
curvature_srp = []

# Each MEP is walked over its own length: the two files need not contain
# the same number of points.
for r_mep, z_mep, table in ((r_pbe, z_pbe, curvature_pbe),
                            (r_srp, z_srp, curvature_srp)):
    for i in range(1, len(r_mep) - 1):
        A = np.array([r_mep[i - 1], z_mep[i - 1]])
        B = np.array([r_mep[i], z_mep[i]])
        C = np.array([r_mep[i + 1], z_mep[i + 1]])
        radius, center, circle = fit_circle(A, B, C)
        table.append([B[0], B[1], 1. / radius, center, circle])

# dtype=object: the rows mix floats, a centre array and a patch object,
# which NumPy >= 1.24 refuses to pack without an explicit object dtype.
curvature_pbe = np.array(curvature_pbe, dtype=object)
curvature_srp = np.array(curvature_srp, dtype=object)


# the "elbow" is the interior point of maximum curvature on each MEP.
# print(...) with a single argument behaves the same on Python 2 and 3.
print("MAX CURVATURE")
PBEmax = np.argmax(curvature_pbe[:, 2])
SRPmax = np.argmax(curvature_srp[:, 2])
print("        r [Ang]  Z [Ang]    curv [1/Ang]")
print(" PBE:   %4.3f     %4.3f     %6.1f     " % (curvature_pbe[PBEmax][0], curvature_pbe[PBEmax][1], curvature_pbe[PBEmax][2]))
print(" SRP:   %4.3f     %4.3f     %6.1f     " % (curvature_srp[SRPmax][0], curvature_srp[SRPmax][1], curvature_srp[SRPmax][2]))

# mark the maximum-curvature point ("elbow") of each MEP in black;
# only the first marker contributes a legend entry
elbow_pbe = curvature_pbe[PBEmax]
elbow_srp = curvature_srp[SRPmax]
plt.scatter(elbow_pbe[0], elbow_pbe[1], color='k', zorder=2, label='elbow')
plt.scatter(elbow_srp[0], elbow_srp[1], color='k', zorder=2)

# first and second derivatives dZ/dr of both MEPs by central differences
# (the results are not used by the plotting below)
pbe_1st = derivate(r_pbe, z_pbe)
pbe_1st_r, pbe_1st_z = pbe_1st
pbe_2nd_r, pbe_2nd_z = derivate(*pbe_1st)

srp_1st = derivate(r_srp, z_srp)
srp_1st_r, srp_1st_z = srp_1st
srp_2nd_r, srp_2nd_z = derivate(*srp_1st)

#### plot ###
ax = plt.gca()
# overlay the circle fitted around each elbow; the Circle patch is the
# last entry of each curvature row
for elbow_circle in (curvature_pbe[PBEmax][-1], curvature_srp[SRPmax][-1]):
    ax.add_patch(elbow_circle)

# equal scaling on both axes so the circles look round, then zoom in
ax.axis('scaled')
ax.set_xlim(1.1, 1.8)
ax.set_ylim(2.05, 2.75)
ax.set_ylabel(r'Z / [ $\AA$ ]')
ax.set_xlabel(r'r / [ $\AA$ ]')
ax.legend(loc='lower left', frameon=False, scatterpoints=1)
ax.text(1.7, 2.61, 'A', fontsize=40)

# plot curvature along each MEP as a function of r (lower panel)
plt.subplot(2, 1, 2)

# linewidth is a float property in matplotlib, so pass a number, not '1'
plt.plot(curvature_pbe[:, 0], curvature_pbe[:, 2], color='green', label='PBE', linewidth=1)
plt.plot(curvature_srp[:, 0], curvature_srp[:, 2], color='blue', label='SRP32-vdW', linewidth=1)
plt.xlim(1.1, 1.8)
plt.ylabel(r'curvature / [ $\AA^{-1}$ ]')
plt.xlabel(r'r / [ $\AA$ ]')
plt.legend(loc='center right', frameon=False)
plt.text(1.7, 4.8, 'B', fontsize=40)
plt.tight_layout()
plt.savefig('curvature.png')
#plt.show()

