#!/usr/bin/env python
import re
import sys
from copy import deepcopy
from os import path, getcwd
from os import chdir, system, popen
from os import listdir

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.interpolate import Rbf
from ase import Atoms
from ase.io import read

# Panel grid: the driver loop draws one subplot per entry of `folders`.
Nrows, Ncols = 2, 6
# Render all figure text (labels, titles, colorbar) through LaTeX.
mpl.rc('text', usetex=True)
# Optional: uncomment to change the frame border thickness.
#mpl.rcParams['axes.linewidth'] = 0.
fig, axs = plt.subplots(Nrows, Ncols, figsize=(6.69, 4.))

# Histogram bins along the two fractional in-plane coordinates (np.histogram2d `bins`).
Density = [30, 40]
# Number of intervals of the interpolation grid along the same two directions.
Density_interpolation = [30, 40]
# Contour-level ranges: minS_1/maxS_1/stepS_1 are not referenced elsewhere in this
# script; minS_2/maxS_2/stepS_2 set the contour levels and the colorbar ticks.
minS_1 = 0.
maxS_1 = 0.031
stepS_1 = 0.001
minS_2 = 0.
maxS_2 = 0.21
stepS_2 = 0.01
# Passed through to plot(), which does not use it.
legend_idx = 1
#folders = ['0.007_fs_TN','0.008_fs_TN','0.012_fs_TN','0.021_fs_TN','0.030_fs_TN','0.044_fs_TN','0.053_fs_TN','0.069_fs_TN','0.089_fs_TN','0.108_fs_TN','0.125_fs_TN','0.141_fs_TN','0.144_fs_TN']
#folders = ['0.007_fs_TN','0.012_fs_TN','0.030_fs_TN','0.089_fs_TN','0.125_fs_TN','0.144_fs_TN']
#folders = ['0.007_fs_TN']

# Facet labels, in panel order. The data folders and the panel titles are both
# derived from this single list so the two can never drift out of step.
facets = ['533', '322', '755', '433', '977', '544', '1199', '655', '131111', '766', '151313', '877']
folders = ['{:s}-D2/0.007_fs_TN'.format(facet) for facet in facets]

#titles = [r'$\left<E_i\right>=114$ kJ/mol', r'$\left<E_i\right>=124$ kJ/mol', r'$\left<E_i\right>=150$ kJ/mol', r'$\left<E_i\right>=174$ kJ/mol', r'$\left<E_i\right>=205$ kJ/mol', r'$\left<E_i\right>=247$ kJ/mol']
titles = ['({:s})'.format(facet) for facet in facets]

###########################################################

def mirrorImage(x_, y_, x1, y1, x2, y2, Cell, Cell_):
	"""Reflect an in-plane point across a line and wrap it back into the unit cell.

	Parameters
	----------
	x_, y_ : float
		Fractional coordinates of the point to reflect.
	x1, y1, x2, y2 : float
		Cartesian coordinates of two distinct points on the mirror line.
	Cell : array_like with 2 rows
		Cell matrix whose rows are the lattice vectors; fractional -> Cartesian
		is computed as ``[x_, y_] @ Cell``.
	Cell_ : array_like with 2 rows
		Matrix mapping Cartesian back to fractional coordinates (callers pass
		a slice of the inverse cell).

	Returns
	-------
	(float, float)
		Fractional coordinates of the mirror image, each reduced modulo 1.

	Raises
	------
	ValueError
		If (x1, y1) and (x2, y2) coincide, so no mirror line is defined.
	"""
	# Fractional -> Cartesian.
	pos = np.dot( [x_, y_], Cell )
	x, y = pos[0], pos[1]

	# Line through (x1, y1) and (x2, y2) written as A*x + B*y + C = 0.
	A = y2 - y1
	B = x1 - x2
	C = -A*x1 - B*y1
	M = np.sqrt( A**2 + B**2 )
	if M == 0:
		# Without this check the divisions below silently produce NaN coordinates.
		raise ValueError( 'mirror line is undefined: (x1, y1) and (x2, y2) coincide' )

	# Normalise so that D is the signed distance of (x, y) from the line.
	A /= M
	B /= M
	C /= M
	D = A * x + B * y + C

	# Reflect, then go back to fractional coordinates wrapped into [0, 1).
	x3 = x - 2*A*D
	y3 = y - 2*B*D
	pos3 = np.dot( [x3, y3], Cell_ )
	return pos3[0] % 1., pos3[1] % 1.

here = getcwd()	# launch directory; the data folders given to plot() are resolved against it
images = []	# one contour set per plotted panel, in plotting order

# Trajectory folders: six-character names starting with a digit (e.g. 000001).
_TRAJ_DIR_PATTERN = re.compile( r'[0-9].{5}' )


def _count_trajectory_dirs(workdir):
	"""Return the number of entries in *workdir* whose name matches _TRAJ_DIR_PATTERN."""
	return sum( 1 for name in listdir( workdir ) if _TRAJ_DIR_PATTERN.fullmatch( name ) )


def _top_layer_sites(Surface, Cell, Cell_):
	"""Return ``(top, top_d, full)`` top-layer positions used for the overlay, window and mirror line.

	top   -- Cartesian positions of the selected atoms with Cartesian x <= 3.,
	         plus a periodic image at fractional y+1 for atoms on the y=0 edge;
	top_d -- ``top`` converted to fractional coordinates;
	full  -- Cartesian positions of all selected atoms plus periodic images at
	         fractional x+1, y+1 and (x+1, y+1) for atoms on the matching edges
	         (appended in that order, right after their atom).
	"""
	# NOTE(review): assumes every third of the first N*3 atoms is a top-layer
	# atom, with N = (natoms - 2) / 5 / 3 -- TODO confirm against the layout
	# of cfg_lammps.RuN.config.
	N = int( ( len( Surface ) - 2 ) / 5 / 3 )
	positions = Surface.positions
	edge = 1e-2	# fractional tolerance for "sits on the cell edge"
	x_shift = np.array( [1., 0., 0.] )
	y_shift = np.array( [0., 1., 0.] )

	top = []
	full = []
	for i in range( N ):
		atom = positions[i*3]
		frac = np.dot( atom, Cell_ )
		on_x_edge = (frac[0] % 1) < edge
		on_y_edge = (frac[1] % 1) < edge

		full.append( atom )
		if on_x_edge:
			full.append( np.dot( frac + x_shift, Cell ) )
		if on_y_edge:
			full.append( np.dot( frac + y_shift, Cell ) )
		if on_x_edge and on_y_edge:
			full.append( np.dot( frac + x_shift + y_shift, Cell ) )

		if atom[0] > 3.:
			continue
		top.append( atom )
		if on_y_edge:
			top.append( np.dot( frac + y_shift, Cell ) )

	top = np.array( top )
	return top, np.dot( top, Cell_ ), np.array( full )


def _reaction_positions(workdir, Cell, Cell_, Surface_Top_layer_d, Surface_Top_layer_full):
	"""Return fractional in-plane REACTION sites under *workdir*, each followed by its mirror image."""
	n_traj = _count_trajectory_dirs( workdir )
	x_limit = max( Surface_Top_layer_d[:,0] )
	cell2 = Cell[:2,:2]
	cell2_inv = Cell_[:2,:2]
	# The mirror line runs through the third and the last entry of the
	# periodic-completed top layer.
	x1, y1 = Surface_Top_layer_full[2,0], Surface_Top_layer_full[2,1]
	x2, y2 = Surface_Top_layer_full[-1,0], Surface_Top_layer_full[-1,1]

	sites = []
	for i in range( 1, n_traj + 1 ):
		traj_dir = path.join( workdir, '%.6d' % i )
		if not path.exists( traj_dir ):
			print( "folder", traj_dir, "not found. exiting..." )
			# Exit with a non-zero status so the failure is visible to the shell.
			sys.exit( 1 )

		outcome_file = path.join( traj_dir, 'outcome' )
		post_file = path.join( traj_dir, 'PostAnalysis.dat' )
		if not ( path.exists( post_file ) and path.exists( outcome_file ) ):
			continue

		with open( outcome_file, 'r' ) as fh:
			outcome = fh.readline()
		# Only REACTION events are used; UNCLEAR and SCATTERING runs are skipped.
		if 'UNCLEAR' in outcome or 'SCATTERING' in outcome or 'REACTION' not in outcome:
			continue

		with open( post_file, 'r' ) as fh:
			fields = fh.readlines()[32].split()
		# Columns 3 and 4 of line 33 are treated as the Cartesian in-plane
		# position and converted to fractional coordinates wrapped into [0, 1).
		frac = np.dot( [ float(fields[2]), float(fields[3]) ], cell2_inv ) % 1.
		if frac[0] > x_limit:
			continue
		sites.append( frac )
		sites.append( mirrorImage( frac[0], frac[1], x1, y1, x2, y2, cell2, cell2_inv ) )

	return np.array( sites )


def _interpolated_density(POS_reaction, Cell, Surface_Top_layer_d):
	"""Histogram the fractional sites and return ``(Xnew, Ynew, Znew)`` RBF-interpolated on a Cartesian grid.

	The grid has shape ``(Density_interpolation[0]+1, Density_interpolation[1]+1)``,
	spans the fractional bounding box of *Surface_Top_layer_d*, and Znew is
	clipped at 0.
	"""
	cell2 = Cell[:2,:2]
	binned, xedges, yedges = np.histogram2d( POS_reaction[:,0], POS_reaction[:,1], bins=Density, density=True )
	# NOTE(review): this divides each column (fixed y-bin) by its own sum, so
	# every column sums to 1 over x -- it does NOT make the whole cell sum to 1.
	# The contour levels (minS_2..maxS_2) look tuned to this scale, so it is
	# kept; use binned.sum() instead if whole-cell normalisation is intended.
	# Empty columns give 0/0 = NaN here, mapped to zero by nan_to_num below.
	with np.errstate( invalid='ignore' ):
		binned /= binned.sum( axis=0 )

	# Bin centres (fractional -> Cartesian), x-bin outer / y-bin inner order.
	xc = ( xedges[1:] + xedges[:-1] ) / 2.
	yc = ( yedges[1:] + yedges[:-1] ) / 2.
	XC, YC = np.meshgrid( xc, yc, indexing='ij' )
	centres = np.dot( np.column_stack( ( XC.ravel(), YC.ravel() ) ), cell2 )
	weights = np.nan_to_num( binned.ravel() )
	F = Rbf( centres[:,0], centres[:,1], weights, function='linear', smooth=0. )

	# Evaluation grid: regular in fractional coordinates, mapped to Cartesian.
	lo = Surface_Top_layer_d.min( axis=0 )
	hi = Surface_Top_layer_d.max( axis=0 )
	fx = np.linspace( lo[0], hi[0], Density_interpolation[0] + 1 )
	fy = np.linspace( lo[1], hi[1], Density_interpolation[1] + 1 )
	FX, FY = np.meshgrid( fx, fy, indexing='ij' )
	grid = np.dot( np.stack( ( FX, FY ), axis=-1 ), cell2 )
	Xnew = grid[..., 0]
	Ynew = grid[..., 1]
	Znew = F( Xnew, Ynew ).clip( min=0. )
	return Xnew, Ynew, Znew


def plot(idx, legend_idx, folder, title):
	"""Draw the reaction-site density map of one facet into subplot *idx*.

	Parameters
	----------
	idx : int
		1-based subplot index passed to ``plt.subplot(Nrows, Ncols, idx)``.
	legend_idx : int
		Unused; kept for call compatibility.
	folder : str
		Data folder, resolved against ``here``; must contain
		``000001/cfg_lammps.RuN.config`` and the numbered trajectory folders.
	title : str
		Panel title.

	The filled contour set is appended to the module-level ``images`` list.
	The process working directory is not changed.

	Raises
	------
	RuntimeError
		If no usable REACTION event is found under *folder*.
	"""
	base = path.join( here, folder )
	Surface = read( path.join( base, '000001', 'cfg_lammps.RuN.config' ), format='lammps-data', style='atomic' )

	Cell = Surface.get_cell()
	# NOTE(review): the second lattice vector is divided by 3, presumably because
	# the stored cell spans three repeats along it -- confirm.
	Cell[1,:] /= 3.
	Cell_ = np.linalg.inv( Cell )

	Surface_Top_layer, Surface_Top_layer_d, Surface_Top_layer_full = _top_layer_sites( Surface, Cell, Cell_ )
	POS_reaction = _reaction_positions( base, Cell, Cell_, Surface_Top_layer_d, Surface_Top_layer_full )
	if POS_reaction.size == 0:
		raise RuntimeError( 'no REACTION events found under %s' % base )
	Xnew, Ynew, Znew = _interpolated_density( POS_reaction, Cell, Surface_Top_layer_d )

	plt.subplot( Nrows, Ncols, idx )
	levels = np.arange( minS_2, maxS_2, stepS_2 )
	images.append( plt.contourf( Xnew, Ynew, Znew, zorder=-1, cmap='jet', levels=levels ) )

	# Top-layer atoms drawn as open circles on top of the density map.
	plt.scatter( *Surface_Top_layer[:,:2].T, s=50, marker='o', facecolors='none', linewidths=1, edgecolor='grey', zorder=5 )

	plt.xlabel( r'$X$ (\r{A})' )
	plt.ylabel( r'$Y$ (\r{A})' )
	plt.tick_params( length=3, width=1, direction='out', top=True, right=True )
	plt.axis( 'scaled' )
	plt.title( '{:s}'.format( title ), size=10 )

# Draw one panel per (folder, title) pair; subplot indices are 1-based.
if len(titles) < len(folders):
	raise ValueError('every entry of folders needs a matching entry in titles')
for panel, (folder, title) in enumerate(zip(folders, titles), start=1):
	plot(panel, legend_idx, folder, title)

#plt.subplots_adjust(wspace=0.0, hspace=-0.05)

plt.tight_layout()

# Every panel is drawn with the same contour levels and colormap, so any of the
# contour sets can drive the shared colorbar; the first one also exists when
# only a single folder is plotted.
cbar2_ax = fig.add_axes([0.25, 0.85, 0.5, 0.02]) # left, bottom, width, height (figure fraction)
fig.colorbar(images[0], cax=cbar2_ax, orientation='horizontal', ticks=np.linspace(minS_2, maxS_2-stepS_2, 6))
cbar2_ax.xaxis.set_label_position('top')
cbar2_ax.xaxis.set_ticks_position('top')
cbar2_ax.set_xlabel('Dissociation location probability density')

# Leave room above the panels for the colorbar.
fig.subplots_adjust(top=0.81)

#plt.text(0.25, 0.975, 'Laser off', fontsize=12, transform=plt.gcf().transFigure)
#plt.text(0.7, 0.975, r'$\nu_1=1$', fontsize=12, transform=plt.gcf().transFigure)


plt.savefig('site_density_SI.pdf')
