#!/usr/bin/env python
import sys
from os import path, getcwd
from os import chdir, system, popen
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.interpolate import Rbf

# Layout of the panel grid; plot() reads Nrows/Ncols to address its subplot.
Nrows, Ncols = 3, 2
mpl.rc("text", usetex=True)
fig, ax = plt.subplots(nrows=Nrows, ncols=Ncols, figsize=(6.69,6.4))

# Number of intervals per cell edge of the accumulation grid used in plot().
density = 50
# plot() builds its contour levels as np.arange(minS, maxS, stepS).
minS, maxS, stepS = 0.24, 0.43, 0.005
# Forwarded to plot(), which currently does not use it.
legend_idx = 1

# One data folder per panel, in panel order.
#folders = [ '1.18_ms_TN', '1.29_ms_TN', '1.55_ms_TN', '1.80_ms_TN', '2.12_ms_TN', '2.56_ms_TN' ]
folders = [
	'1.18',
	'1.29_ms_TN',
	'1.55_ms_TN',
	'1.80_ms_TN',
	'2.12_ms_TN',
	'2.56_ms_TN',
]
# Value shown in each panel title (kJ/mol); plot() also divides the binned
# energies by it.
Ei = [114, 124, 150, 174, 205, 247]
titles = [r'$\left<E_i\right>=%d$ kJ/mol' % energy for energy in Ei]

def stattop(ax, Cell):
	"""Outline the four 'top' regions, one at each corner of the cell, on ``ax``.

	ax   -- object with a matplotlib-style ``plot(x, y, **kwargs)`` method.
	Cell -- 2x2 array; the fractional polygon coordinates are right-multiplied
	        by it (``np.dot(polygon, Cell)``) before drawing.

	Each polygon is given relative to its corner and then shifted by the
	corner's fractional position (0 or 1 along each axis). The outlines are
	drawn in blue, in the corner order (0,0), (0,1), (1,0), (1,1).
	"""
	sixth = 1./3. / 2.
	third = 2./3. / 2.
	quarter = 1./2. / 2.

	# (corner shift, closed polygon relative to that corner)
	corner_polygons = [
		((0, 0), [[sixth, sixth],
			[0., quarter],
			[0., 0.],
			[quarter, 0.],
			[sixth, sixth]]),
		((0, 1), [[quarter, 0.],
			[0., 0.],
			[0., -quarter],
			[sixth, -third],
			[third, -sixth],
			[quarter, 0.]]),
		((1, 0), [[0., quarter],
			[-sixth, third],
			[-third, sixth],
			[-quarter, 0.],
			[0., 0.],
			[0., quarter]]),
		((1, 1), [[0., 0.],
			[-quarter, 0.],
			[-sixth, -sixth],
			[0., -quarter],
			[0., 0.]]),
	]

	for (shift_x, shift_y), polygon in corner_polygons:
		top = np.array(polygon)
		top[:,0] += shift_x
		top[:,1] += shift_y
		top = np.dot( top, Cell )

		# Draw on the axes that was passed in rather than on whatever
		# pyplot currently considers the active axes.
		ax.plot( top[:,0], top[:,1], color='b' )

def statfcc(ax, Cell):
	"""Outline the triangular 'fcc' region of the cell on ``ax`` in green.

	ax   -- object with a matplotlib-style ``plot(x, y, **kwargs)`` method.
	Cell -- 2x2 array; the fractional triangle coordinates are
	        right-multiplied by it before drawing.
	"""
	sixth = 1./3. / 2.
	# Closed triangle in fractional coordinates (first vertex repeated).
	fcc = np.array([
		[sixth, sixth],
		[sixth, 2./3.],
		[2./3., sixth],
		[sixth, sixth],
	])

	fcc = np.dot( fcc, Cell )

	# Draw on the axes that was passed in, not on pyplot's current axes.
	ax.plot( fcc[:,0], fcc[:,1], color='g' )

def stathcp(ax, Cell):
	"""Outline the triangular 'hcp' region of the cell on ``ax`` in red.

	ax   -- object with a matplotlib-style ``plot(x, y, **kwargs)`` method.
	Cell -- 2x2 array; the fractional triangle coordinates are
	        right-multiplied by it before drawing.
	"""
	# Closed triangle in fractional coordinates (first vertex repeated).
	hcp = np.array([
		[1./3., 5./6.],
		[5./6., 5./6.],
		[5./6., 1./3.],
		[1./3., 5./6.],
	])

	hcp = np.dot( hcp, Cell )

	# Draw on the axes that was passed in, not on pyplot's current axes.
	ax.plot( hcp[:,0], hcp[:,1], color='r' )

def statbridge(ax, Cell):
	"""Outline the five 'bridge' regions of the cell on ``ax``.

	ax   -- object with a matplotlib-style ``plot(x, y, **kwargs)`` method.
	Cell -- 2x2 array; the fractional polygon coordinates are right-multiplied
	        by it before drawing.

	Every region is a closed quadrilateral in fractional coordinates, shifted
	by the listed fractional offset and drawn as a dashed black line. The
	drawing order is: the two regions along the first cell axis (at second
	coordinate 0 and 1), the central region, then the two regions along the
	second cell axis (at first coordinate 0 and 1).
	"""
	sixth = 1./3. / 2.
	quarter = 1./2. / 2.

	# (fractional shift, closed polygon before the shift)
	shifted_polygons = [
		((0, 0), [[quarter, 0.],
			[sixth, sixth],
			[2./3., sixth],
			[3./4., 0.],
			[quarter, 0.]]),
		((0, 1), [[1./3., -sixth],
			[1./4., 0.],
			[3./4., 0.],
			[5./6., -sixth],
			[1./3., -sixth]]),
		((0, 0), [[2./3., sixth],
			[5./6., 1./3.],
			[1./3., 5./6.],
			[1./6., 2./3.],
			[2./3., sixth]]),
		((0, 0), [[sixth, sixth],
			[0., quarter],
			[0., 3./4.],
			[sixth, 2./3.],
			[sixth, sixth]]),
		((1, 0), [[0., quarter],
			[-sixth, 1./3.],
			[-sixth, 5./6.],
			[0., 3./4.],
			[0., quarter]]),
	]

	for (shift_x, shift_y), polygon in shifted_polygons:
		bridge = np.array(polygon)
		bridge[:,0] += shift_x
		bridge[:,1] += shift_y
		bridge = np.dot( bridge, Cell )

		# Draw on the axes that was passed in, not on pyplot's current axes.
		ax.plot( bridge[:,0], bridge[:,1], color='k', linestyle='--' )

###########################################################

# Tolerance used by the point-in-triangle helpers below: the triangle's
# bounding box is widened by EPSILON on every side.
EPSILON = 0.001
# Squared tolerance, compared against squared point-to-segment distances so
# that points within EPSILON of a triangle edge still count as inside.
EPSILON_SQUARE = EPSILON**2

def side(x1, y1, x2, y2, x, y):
	"""Signed side of point (x, y) relative to the directed line p1 -> p2.

	Returns zero when the point lies on the line through (x1, y1) and
	(x2, y2), a positive value when it lies to the right of the direction
	p1 -> p2 (x to the right, y upwards) and a negative value when it lies
	to the left. The magnitude scales with the edge length.
	"""
	edge_rise = y2 - y1
	edge_run_reversed = -x2 + x1
	return edge_rise*(x - x1) + edge_run_reversed*(y - y1)

def naivePointInTriangle(x1, y1, x2, y2, x3, y3, x, y):
	"""Return True when (x, y) is inside or on the triangle p1, p2, p3.

	The point is inside when it lies on the same side of all three directed
	edges p1->p2, p2->p3 and p3->p1 (as reported by side()); a value of zero,
	i.e. a point exactly on an edge line, is accepted for either sign.

	Both vertex orderings are accepted. Requiring only non-negative values
	would recognise triangles of a single winding and reject every interior
	point of a triangle given in the opposite order.

	No tolerance is applied; see accuratePointInTriangle() for that.
	"""
	edge_sides = (
		side(x1, y1, x2, y2, x, y),
		side(x2, y2, x3, y3, x, y),
		side(x3, y3, x1, y1, x, y),
	)
	all_non_negative = all(value >= 0 for value in edge_sides)
	all_non_positive = all(value <= 0 for value in edge_sides)
	return all_non_negative or all_non_positive

def pointInTriangleBoundingBox(x1, y1, x2, y2, x3, y3, x, y):
	"""Return True when (x, y) lies in the triangle's bounding box widened by EPSILON.

	This is a cheap pre-check: a point outside the widened axis-aligned
	bounding box of p1, p2, p3 cannot be inside the triangle or within
	EPSILON of one of its edges.
	"""
	xMin = min(x1, x2, x3) - EPSILON
	xMax = max(x1, x2, x3) + EPSILON
	yMin = min(y1, y2, y3) - EPSILON
	yMax = max(y1, y2, y3) + EPSILON

	# The point is outside as soon as it violates ANY of the four bounds.
	# (Joining these comparisons with `and` can never be true, because x
	# cannot be below xMin and above xMax at the same time.)
	if ( x < xMin ) or ( xMax < x ) or ( y < yMin ) or ( yMax < y ):
		return False
	return True

def distanceSquarePointToSegment(x1, y1, x2, y2, x, y):
	"""Return the squared distance from (x, y) to the segment p1-p2.

	The point is projected onto the line through (x1, y1) and (x2, y2). If
	the projection falls before p1 or beyond p2, the squared distance to that
	end point is returned; otherwise the squared perpendicular distance to
	the line is returned.

	A zero-length segment (p1 == p2) is treated as the single point p1
	instead of dividing by zero.
	"""
	p1_p2_squareLength = (x2 - x1)*(x2 - x1) + (y2 - y1)*(y2 - y1)
	p_p1_squareLength = (x1 - x)*(x1 - x) + (y1 - y)*(y1 - y)
	if p1_p2_squareLength == 0:
		return p_p1_squareLength

	# Position of the projection along the segment: 0 at p1, 1 at p2.
	dotProduct = ((x - x1)*(x2 - x1) + (y - y1)*(y2 - y1)) / p1_p2_squareLength
	if ( dotProduct < 0 ):
		return (x - x1)*(x - x1) + (y - y1)*(y - y1)
	elif ( dotProduct <= 1 ):
		# Pythagoras: |p - p1|^2 minus the squared length of the projection.
		return p_p1_squareLength - dotProduct * dotProduct * p1_p2_squareLength
	else:
		return (x - x2)*(x - x2) + (y - y2)*(y - y2)

def accuratePointInTriangle(x1, y1, x2, y2, x3, y3, x, y):
	"""Point-in-triangle test with a tolerance of EPSILON around the edges.

	Returns False right away when pointInTriangleBoundingBox() rejects the
	point. Otherwise returns True when naivePointInTriangle() accepts it, or
	when its squared distance to any of the three edges is at most
	EPSILON_SQUARE; returns False if none of these hold.
	"""
	if pointInTriangleBoundingBox(x1, y1, x2, y2, x3, y3, x, y):
		if naivePointInTriangle(x1, y1, x2, y2, x3, y3, x, y):
			return True

		# Edges are checked in the order p1-p2, p2-p3, p3-p1; the first
		# edge that is close enough settles the answer.
		edges = (
			(x1, y1, x2, y2),
			(x2, y2, x3, y3),
			(x3, y3, x1, y1),
		)
		for start_x, start_y, end_x, end_y in edges:
			if distanceSquarePointToSegment(start_x, start_y, end_x, end_y, x, y) <= EPSILON_SQUARE:
				return True

	return False

def mirrorImage(x, y, x1, y1, x2, y2, Cell, Cell_):
	"""Mirror a point in the line through two other points and wrap it into the cell.

	(x, y), (x1, y1) and (x2, y2) are each right-multiplied by ``Cell``. The
	first point is then reflected in the straight line through the other
	two, the result is right-multiplied by ``Cell_`` and both components are
	reduced modulo 1.

	Returns the tuple (x3, y3) of the wrapped, mirrored coordinates.
	"""
	point_x, point_y = np.dot( [x, y], Cell )
	first_x, first_y = np.dot( [x1, y1], Cell )
	second_x, second_y = np.dot( [x2, y2], Cell )

	# Mirror line in implicit form  normal_x*X + normal_y*Y + constant = 0.
	normal_x = second_y - first_y
	normal_y = first_x - second_x
	constant = -normal_x*first_x - normal_y*first_y
	length = np.sqrt( normal_x**2 + normal_y**2 )

	# Scale to a unit normal so the line equation yields a signed distance.
	normal_x = normal_x / length
	normal_y = normal_y / length
	constant = constant / length

	signed_distance = normal_x * point_x + normal_y * point_y + constant

	# Step twice the signed distance against the normal to land on the image.
	image_x = point_x - 2*normal_x*signed_distance
	image_y = point_y - 2*normal_y*signed_distance
	image = np.dot( [image_x, image_y], Cell_ )

	return image[0] % 1., image[1] % 1.

# Working directory at start-up; plot() changes back to it before returning.
here=getcwd()
# Filled-contour sets appended by plot(), one per call; the shared colour bar
# at the end of the script is built from one of them.
images = []
def _cell_grid(intervals, cell):
	"""Return (X, Y) Cartesian coordinates of a regular fractional grid.

	The grid has ``intervals + 1`` points per cell axis; element [j][i]
	belongs to the fractional position (i/intervals, j/intervals).
	"""
	frac = np.arange(intervals + 1) / float(intervals)
	grid_x = frac[np.newaxis, :] * cell[0][0] + frac[:, np.newaxis] * cell[1][0]
	grid_y = np.repeat( (frac * cell[1][1])[:, np.newaxis], intervals + 1, axis=1 )
	return grid_x, grid_y

def _count_trajectories(folder_path):
	"""Count the entries of ``folder_path`` that look like trajectory folders.

	An entry counts when its name is six characters long and starts with a
	digit, the same selection the shell pipeline
	``ls | egrep "^[0-9].....$" | wc -l`` makes.
	"""
	from os import listdir	# not among the names imported at file level
	return sum(
		1 for name in listdir(folder_path)
		if len(name) == 6 and name[0] in '0123456789'
	)

def _read_scattered(traj_dir):
	"""Return the lines of PostAnalysis.dat for a scattered trajectory.

	Returns None when ``outcome`` or ``PostAnalysis.dat`` is missing in
	``traj_dir`` or when the first line of ``outcome`` does not contain
	'SCATTERING' followed by a newline.
	"""
	outcome_file = path.join(traj_dir, 'outcome')
	post_file = path.join(traj_dir, 'PostAnalysis.dat')
	if not (path.exists(post_file) and path.exists(outcome_file)):
		return None

	with open(outcome_file, 'r') as handle:
		outcome = handle.readline()
	if 'SCATTERING\n' not in outcome:
		return None

	with open(post_file, 'r') as handle:
		return handle.readlines()

def _add_sample(ET, ET_N, frac_x, frac_y, energy):
	"""Add ``energy`` to the bin nearest to the fractional position.

	``ET`` accumulates the energies and ``ET_N`` the number of samples; both
	are indexed [y][x] on the global ``density`` grid.
	"""
	xd = int( round( frac_x*density ) )
	yd = int( round( frac_y*density ) )
	ET[yd][xd] += energy
	ET_N[yd][xd] += 1.

def plot(idx, legend_idx, folder, title, Ei):
	"""Draw one panel: energy transfer per impact position, averaged over trajectories.

	idx        -- 1-based subplot number within the Nrows x Ncols grid.
	legend_idx -- accepted for compatibility; not used.
	folder     -- directory (relative to the current working directory) that
	              holds the numbered trajectory folders 000001, 000002, ...
	title      -- text annotated inside the panel.
	Ei         -- value the binned mean energies are divided by.

	For every trajectory whose ``outcome`` reports SCATTERING, a position and
	an energy are read from ``PostAnalysis.dat``. The position is folded into
	the cell ``Cell2`` and binned together with seven symmetry images (two
	mirror images, two rotations, the image mirrored in the x=y line, plus
	the point itself). The per-bin mean divided by ``Ei`` is interpolated
	with a cubic radial basis function and drawn as filled contours.

	Side effects: appends the contour set to the global ``images`` list,
	changes the working directory to the global ``here`` before returning,
	and terminates the program when a trajectory folder is missing.
	"""
	# Trajectory files are addressed by path, so the working directory does
	# not have to be changed while reading them.
	folder_path = path.join(getcwd(), folder)

	subax = plt.subplot(Nrows,Ncols,idx)

	# Rows are the two lattice vectors of the cell the positions refer to.
	# NOTE(review): Cell2 = Cell/3 is presumably the primitive cell of a 3x3
	# supercell -- confirm.
	Cell = np.array(
		[[8.6930937287162351,    0.0000000000000000],
		[4.3465468643581175,    7.5284400065474486]]
	)
	Cell2 = Cell / 3.
	Cell_= np.linalg.inv(Cell)
	Cell2_= np.linalg.inv(Cell2)

	# Site outlines; statbridge(subax, Cell2) would add the bridge regions.
	statfcc(subax, Cell2)
	stathcp(subax, Cell2)
	stattop(subax, Cell2)

	# Markers at the four corners of Cell2.
	Surface = np.dot( [[0., 0.], [0., 1.], [1., 0.], [1., 1.]], Cell2 )
	plt.scatter(Surface[:,0], Surface[:,1], c='grey', s=50, marker='o', facecolors='none', linewidths=2, edgecolor='grey', zorder=5)

	Xnew, Ynew = _cell_grid(density, Cell2)

	ET = np.zeros((density+1,density+1))
	ET_N = np.zeros((density+1,density+1))

	Ntraj = _count_trajectories(folder_path)
	for i in range(1,(Ntraj+1)):
		newdir = path.join(folder_path, '%.6d' % i)

		if not path.exists(newdir):
			print( "folder", newdir, "not found. exiting..." )
			sys.exit()

		post_dat = _read_scattered(newdir)
		if post_dat is None:
			continue

		# Last field of the last line, scaled by 96.4853075
		# (presumably eV -> kJ/mol -- confirm).
		energy = float( post_dat[-1].split()[-1] )*96.4853075

		# Position: third and fourth field of line 29 of the file.
		x = float(post_dat[28].split()[2])
		y = float(post_dat[28].split()[3])

		# Fractional coordinates in Cell, folded into a single Cell2.
		pos = np.dot( [x,y], Cell_ )
		pos *= 3.
		pos %= 1.
		pos_c = np.dot( pos, Cell2 )

		# The first fields of lines 17 and 19 decide which site the
		# trajectory is assigned to; that fixes the rotation centre and the
		# two mirror lines (given as x1, y1, x2, y2 in fractional coordinates).
		if float(post_dat[16].split()[0]) < float(post_dat[18].split()[0]):	# fcc
			offset_d = 1./3.
			mirror_lines = ((0., 0.5, 1., 0.), (0., 1., 0.5, 0.))
		else:	# hcp
			offset_d = 2./3.
			mirror_lines = ((0., 1., 1., 0.5), (0.5, 1., 1., 0.))

		for line_x1, line_y1, line_x2, line_y2 in mirror_lines:
			mirror_x, mirror_y = mirrorImage(pos[0], pos[1], line_x1, line_y1, line_x2, line_y2, Cell2, Cell2_)
			_add_sample(ET, ET_N, mirror_x, mirror_y, energy)

		# Rotations by 120 and 240 degrees about the site centre.
		offset = np.dot( [offset_d, offset_d], Cell2 )
		for degrees in (120., 240.):
			THETA = np.radians( degrees )
			x = (pos_c[0]-offset[0])*np.cos( THETA ) - (pos_c[1]-offset[1])*np.sin( THETA ) + offset[0]
			y = (pos_c[0]-offset[0])*np.sin( THETA ) + (pos_c[1]-offset[1])*np.cos( THETA ) + offset[1]
			tmp_d = np.dot( [x,y], Cell2_ ) % 1.
			_add_sample(ET, ET_N, tmp_d[0], tmp_d[1], energy)

		# Mirror in the x=y line (fractional coordinates swapped).
		_add_sample(ET, ET_N, pos[1], pos[0], energy)

		# Original point.
		_add_sample(ET, ET_N, pos[0], pos[1], energy)

	# Mean energy per bin, relative to Ei. Bins without samples stay NaN and
	# are left out of the fit: a single NaN among the Rbf nodes would
	# otherwise spoil the whole interpolated surface.
	filled = ET_N > 0
	Znew = np.full(ET.shape, np.nan)
	Znew[filled] = ET[filled] / ET_N[filled] / float(Ei)

	# Interpolate onto a grid that is twice as fine.
	Xnew2, Ynew2 = _cell_grid(density*2, Cell2)
	if filled.any():
		F 	= Rbf(Xnew[filled], Ynew[filled], Znew[filled], function='cubic', smooth=0. )
		Znew2	= F(Xnew2, Ynew2)
	else:
		# No scattered trajectory at all: nothing to draw in this panel.
		Znew2	= np.full(Xnew2.shape, np.nan)

	levels = np.arange( minS, maxS, stepS )

	areas 	 = plt.contourf(Xnew2, Ynew2, Znew2, levels, zorder=-1, cmap='nipy_spectral'  )			# coloured areas
	images.append(areas)

	# x labels only on the bottom row, y labels only on the first column.
	if idx > (Nrows-1)*Ncols:
		plt.xlabel(r'x (\r{A})')
	else:
		plt.tick_params(labelbottom=False)
	if (idx - 1) % Ncols == 0:
		plt.ylabel(r'y (\r{A})')
	else:
		plt.tick_params(labelleft=False)
	plt.tick_params(length=6, width=1, direction='in', top=True, right=True)
	plt.axis('scaled')
	plt.axis([-0.5, 4.5, -0.4, 3.1])

	plt.annotate(title, xy=(0., 2.6), size=12)

	# Kept so that the working directory after the call is the start-up
	# directory, as callers of this function have come to expect.
	chdir(here)

# One panel per data folder; subplot numbers are 1-based.
for panel, (folder, title, energy) in enumerate(zip(folders, titles, Ei), 1):
	plot(panel, legend_idx, folder, title, energy)

plt.tight_layout()
plt.subplots_adjust(wspace=0.0, hspace=-0.06)

# Shared colour bar in its own axes at the right edge of the figure.
colorbar_rect = [0.89, 0.1, 0.025, 0.85] #xmin, ymin, xwidth, ywidth
cbar_ax = fig.add_axes(colorbar_rect)
#fig.colorbar(images[1], cax=cbar_ax, ticks=np.arange(0.0,0.7,0.1))
fig.colorbar(images[1], cax=cbar_ax)
cbar_ax.set_ylabel('Fraction of energy transferred')
fig.subplots_adjust(right=0.87, left=0.08)

#plt.text(0.25, 0.975, 'Laser off', fontsize=12, transform=plt.gcf().transFigure)
#plt.text(0.7, 0.975, r'$\nu_1=1$', fontsize=12, transform=plt.gcf().transFigure)

plt.savefig('ET_site.pdf')
