#!/usr/bin/env python
import sys
from os import path, getcwd
from os import chdir, system, popen
from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.interpolate import Rbf
from ase import Atoms
from ase.io import read

# ---------------------------------------------------------------------------
# Figure layout and plotting configuration (module-level, read by plot() and
# by the driver code at the bottom of the file).
# ---------------------------------------------------------------------------

# Subplot grid: plot() fills cells idx*2-1 and idx*2, the driver fills 7 and 8.
Nrows = 4
Ncols = 2
# Render all text through LaTeX (the axis labels use \r{A} and \left< ... \right>).
mpl.rc("text", usetex=True)
#mpl.rcParams['axes.linewidth'] = 0. # change frame border thickness
fig, axs = plt.subplots(nrows=Nrows, ncols=Ncols, figsize=(6.69,5.4))

# Number of histogram bins along the two fractional cell axes (np.histogram2d).
Density = [100, 50]
# Number of intervals along each cell axis of the grid on which the
# interpolated density is evaluated (the grid has N+1 points per axis).
Density_interpolation = [100, 50]
# Contour levels are np.arange(min, max, step); *_1 is used for the
# initial-position panels, *_2 for the position-at-reaction panels.
minS_1 = 0.
maxS_1 = 0.0525
stepS_1 = 0.0025
minS_2 = 0.
maxS_2 = 0.31
stepS_2 = 0.01
# Passed to plot() as its second argument; plot() does not use it.
legend_idx = 1
# Data folders, one figure row each. The first five characters of a folder
# name are parsed as a float and multiplied by 96.48531 for the panel title
# (presumably an energy in eV converted to kJ/mol -- confirm).
#folders = ['0.007_fs_TN','0.008_fs_TN','0.012_fs_TN','0.021_fs_TN','0.030_fs_TN','0.044_fs_TN','0.053_fs_TN','0.069_fs_TN','0.089_fs_TN','0.108_fs_TN','0.125_fs_TN','0.141_fs_TN','0.144_fs_TN']
#folders = ['0.007_fs_TN','0.012_fs_TN','0.030_fs_TN','0.089_fs_TN','0.125_fs_TN','0.144_fs_TN']
folders = ['0.007_fs_TN','0.030_fs_TN','0.144_fs_TN']
#folders = ['0.007_fs_TN']

# Passed to plot() as its fourth argument; plot() builds its own titles and
# does not use these strings.
titles = [r'$\left<E_i\right>=114$ kJ/mol', r'$\left<E_i\right>=124$ kJ/mol', r'$\left<E_i\right>=150$ kJ/mol', r'$\left<E_i\right>=174$ kJ/mol', r'$\left<E_i\right>=205$ kJ/mol', r'$\left<E_i\right>=247$ kJ/mol']

###########################################################

# def mirrorImage(x_, y_, x1_, y1_, x2_, y2_, Cell, Cell_):
	# pos = np.dot( [x_, y_], Cell )
	# pos1 = np.dot( [x1_, y1_], Cell )
	# pos2 = np.dot( [x2_, y2_], Cell )
	# x = pos[0]
	# x1 = pos1[0]
	# x2 = pos2[0]
	# y = pos[1]
	# y1 = pos1[1]
	# y2 = pos2[1]

	# A = y2 - y1
	# B = x1 - x2
	# C = -A*x1 - B*y1
	# M = np.sqrt( A**2 + B**2 )

	# A /= M
	# B /= M
	# C /= M

	# D = A * x + B * y + C

	# x3 = x - 2*A*D
	# y3 = y - 2*B*D
	# pos3 = np.dot( [x3, y3], Cell_ )
	# x3 = pos3[0] % 1.
	# y3 = pos3[1] % 1.

	# return x3, y3

def mirrorImage(x_, y_, x1, y1, x2, y2, Cell, Cell_):
	"""Reflect a point across a line and return it wrapped into the unit cell.

	x_, y_   -- point in fractional coordinates (multiplied by Cell to obtain
	            Cartesian coordinates).
	x1, y1, x2, y2 -- two Cartesian points defining the mirror line.
	Cell     -- 2x2 matrix mapping fractional to Cartesian coordinates.
	Cell_    -- 2x2 matrix mapping Cartesian back to fractional coordinates.

	Returns the fractional coordinates (x, y) of the mirrored point, each
	taken modulo 1.
	"""
	cart_x, cart_y = np.dot( [x_, y_], Cell )

	# Line through the two points in implicit form a*x + b*y + c = 0 ...
	raw_a = y2 - y1
	raw_b = x1 - x2
	raw_c = -raw_a*x1 - raw_b*y1
	length = np.sqrt( raw_a**2 + raw_b**2 )

	# ... scaled so that (a, b) is a unit normal.
	a = raw_a / length
	b = raw_b / length
	c = raw_c / length

	# Signed distance of the point from the line.
	distance = a * cart_x + b * cart_y + c

	# Move the point twice that distance along the normal, back to fractional.
	mirrored = np.dot( [cart_x - 2*a*distance, cart_y - 2*b*distance], Cell_ )

	return mirrored[0] % 1., mirrored[1] % 1.

here=getcwd()	# directory the script was started from; plot() returns to it
images = []	# contour sets in drawing order, read later for the colorbars


def _read_xy(line):
	"""Return [x, y] parsed from the 3rd and 4th whitespace-separated fields of *line*."""
	fields = line.split()
	return [ float(fields[2]), float(fields[3]) ]


def _classify_site(x_frac, layer_x):
	"""Assign a fractional x coordinate to a strip between top-layer atom rows.

	layer_x is the sorted array of unique fractional x coordinates of the
	top-layer atoms. Returns (index, is_step), where index addresses the
	per-strip counter array (length len(layer_x) - 2) and is_step is True
	for the strip counted in element 0. Returns None when x_frac falls in
	no strip (e.g. exactly on a boundary).
	"""
	upper_split = (2*layer_x[-2] + layer_x[-1])/3.
	if layer_x[0] < x_frac < (layer_x[1] + layer_x[2])/2.:
		return 0, True
	if upper_split < x_frac < layer_x[-1]:
		return 0, True
	if (layer_x[-3] + layer_x[-2])/2. < x_frac < upper_split:
		return -1, False
	# Remaining strips are bounded by the midpoints between neighbouring rows.
	for j in range( 2, len( layer_x ) - 1 ):
		if (layer_x[j-1] + layer_x[j])/2. < x_frac < (layer_x[j] + layer_x[j+1])/2.:
			return j-1, False
	return None


def _ratio_and_error(count, total):
	"""Return (count/total, binomial standard error); (999., 999.) if total is 0."""
	if total == 0:
		# Same "undefined" sentinel as used for the shadow ratio.
		return 999., 999.
	ratio = float(count)/float(total)
	return ratio, np.sqrt( ratio * ( 1. - ratio ) / float(total) )


def _print_site_statistics(division):
	"""Print step/terrace counts, their ratio and the shadow ratio for *division*.

	division[0] is the step counter, division[1:] the terrace strips. Among
	the terrace strips, even indices are accumulated as "in the shadow" and
	odd indices as "out of the shadow" (naming taken from the printed header).
	"""
	n_step = division[0]
	n_terrace = sum( division[1:] )
	R, RSE = _ratio_and_error( n_step, n_step + n_terrace )

	shadow = [0, 0]
	for i in range(1, len(division)):
		shadow[0 if i % 2 == 0 else 1] += division[i]
	n_shadow = shadow[0] + shadow[1]
	RShadow, RSEShadow = _ratio_and_error( shadow[0], n_shadow )
	if n_shadow != 0 and shadow[0] == 0:
		# The binomial error collapses to zero for zero hits; use
		# 1 - 0.68**(1/N) instead.
		RSEShadow = 1. - 0.68**(1./float(n_shadow))

	print('{:d} , {:d} , {:5.4f} +/- {:5.4f} , {:5.4f} +/- {:5.4f}'.format( int(n_step), int(n_terrace), R, RSE, RShadow, RSEShadow ))


def _draw_density_panel(subplot_index, positions, levels, cell2d, Surface_Top_layer, step, heading):
	"""Draw one interpolated density map with the surface atoms overlaid.

	subplot_index     -- cell of the Nrows x Ncols grid to draw into.
	positions         -- (n, 2) array of fractional coordinates to histogram.
	levels            -- contour levels for contourf.
	cell2d            -- 2x2 matrix mapping fractional to Cartesian coordinates.
	Surface_Top_layer -- Cartesian positions of the top-layer atoms.
	step              -- 4x2 array holding the end points of the two step lines.
	heading           -- panel title, or None for no title.

	The contour set is appended to the module-level list `images`.
	"""
	plt.subplot(Nrows,Ncols,subplot_index)

	binned, xedges, yedges = np.histogram2d( positions[:,0], positions[:,1], bins=Density, density=True )
	# The builtin sum() adds up the rows of the 2-D histogram, i.e. it sums
	# over the first (x) axis only, so every column of y-bins is normalised
	# to unity on its own rather than the histogram as a whole.
	# NOTE(review): the contour levels look tuned to this behaviour, so it is
	# kept; confirm whether a global normalisation was intended.
	binned /= sum( binned )

	values = []
	for i in range( len(binned) ):
		for j in range( len(binned[0]) ):
			POS = np.dot( [(xedges[i+1] + xedges[i])/2., (yedges[j+1] + yedges[j])/2.], cell2d )
			values.append( [ POS[0], POS[1], binned[i,j] ] )
	# A column without any counts is 0/0 = NaN after the normalisation above;
	# map those to zero so the interpolation stays finite.
	values = np.nan_to_num( values )

	F = Rbf(values[:,0], values[:,1], values[:,2], function='linear', smooth=0. )
	Xnew = np.zeros(( Density_interpolation[0]+1, Density_interpolation[1]+1 ))
	Ynew = np.zeros(( Density_interpolation[0]+1, Density_interpolation[1]+1 ))
	for i in range( Density_interpolation[0]+1 ):
		for j in range( Density_interpolation[1]+1 ):
			POS = np.dot( [ float(i)/Density_interpolation[0], float(j)/Density_interpolation[1] ], cell2d )
			Xnew[i,j] = POS[0]
			Ynew[i,j] = POS[1]
	Znew = F(Xnew, Ynew).clip(min=0.)
	images.append( plt.contourf(Xnew, Ynew, Znew, zorder=-1, cmap='jet', levels=levels ) )

	for POS in Surface_Top_layer[:,:2]:	# A bit dirty but it is for the paper
		# Hard-coded Cartesian windows select the marker colour.
		if 3. < POS[0] < 15.:
			colour = 'pink' if (POS[1] < 0.5 or POS[1] > 2.5) else 'r'
		else:
			colour = 'k'
		plt.scatter( POS[0], POS[1], s=200, marker='o', facecolors='none', linewidths=2, edgecolor=colour, zorder=5)
	plt.plot( [step[0,0], step[1,0]], [step[0,1], step[1,1]], color='k' )
	plt.plot( [step[2,0], step[3,0]], [step[2,1], step[3,1]], color='k' )
	# Corners of the cell, used to hatch the regions outside the two step lines.
	edge = np.array(
	[np.dot( [0.,0.], cell2d ),
	np.dot( [0.,1.], cell2d ),
	np.dot( [1.,0.], cell2d ),
	np.dot( [1.,1.], cell2d )]
	)
	plt.fill_betweenx( [0., step[1,1]], [edge[0,0], edge[1,0]], [step[0,0], step[1,0]], hatch='/', facecolor='None', edgecolor='k', zorder=4, linewidth=0.0 )
	plt.fill_betweenx( [0., step[3,1]], [step[2,0], step[3,0]], [edge[2,0], edge[3,0]], hatch='/', facecolor='None', edgecolor='k', zorder=4, linewidth=0.0 )

	plt.xlabel(r'$X$ (\r{A})')
	plt.ylabel(r'$Y$ (\r{A})')
	plt.tick_params(length=3, width=1, direction='out', top=True, right=True)
	plt.axis('scaled')

	if heading is not None:
		plt.title(heading, size=10)


def plot(idx, legend_idx, folder, title):
	"""Analyse the trajectories in *folder* and draw one row of the figure.

	idx        -- 1-based row number; panels idx*2-1 (initial positions of
	              reacted trajectories) and idx*2 (positions at the moment of
	              reaction) are drawn. Only rows 1-3 receive a title.
	legend_idx -- unused, kept for call compatibility.
	folder     -- data directory (relative to the start directory). Its
	              first five characters are parsed as a float and multiplied
	              by 96.48531 for the title (presumably eV -> kJ/mol).
	title      -- unused, kept for call compatibility.

	Reads 000001/cfg_lammps.RuN.config for the surface geometry and, for each
	numbered trajectory folder, the files `outcome` and `PostAnalysis.dat`.
	Prints site statistics to stdout, appends two contour sets to `images`
	and changes the working directory back to `here` before returning.
	Exits the interpreter with status 1 if a numbered trajectory folder is
	missing.
	"""
	from os import listdir	# local import: only needed to count trajectory folders

	chdir(folder)

	Surface = read('000001/cfg_lammps.RuN.config', format='lammps-data', style='atomic')

	Cell = Surface.get_cell()
	Cell[1,:] /= 3.	# work with one third of the cell along the second lattice vector
	Cell_= np.linalg.inv(Cell)
	cell2d = Cell[:2,:2]
	inv2d = Cell_[:2,:2]

	# Every third atom of the first N*3 is taken as a top-layer atom
	# (presumably the file lists the top layer first -- confirm). Atoms lying
	# on a cell boundary also get their periodic image on the opposite side.
	N = int( ( len( Surface ) - 2 ) / 5 / 3 )
	Surface_Top_layer = []
	for i in range( N ):
		Surface_Top_layer.append( Surface.positions[i*3] )
		POS_ = np.dot( Surface_Top_layer[-1], Cell_ )
		if (POS_[0] % 1) < 1e-2:
			POS_[0] += 1.
			POS = np.dot( POS_, Cell )
			Surface_Top_layer.append( POS )
			POS_[0] -= 1.
		if (POS_[1] % 1) < 1e-2:
			POS_[1] += 1.
			POS = np.dot( POS_, Cell )
			Surface_Top_layer.append( POS )
			POS_[1] -= 1.
		if (POS_[0] % 1) < 1e-2 and (POS_[1] % 1) < 1e-2:
			POS_[0] += 1.
			POS_[1] += 1.
			POS = np.dot( POS_, Cell )
			Surface_Top_layer.append( POS )
			POS_[0] -= 1.
			POS_[1] -= 1.
	Surface_Top_layer = np.array( Surface_Top_layer )
	Surface_Top_layer_d = np.dot( Surface_Top_layer, Cell_ )
	Surface_Top_layer_X = np.unique( Surface_Top_layer_d[:,0].round(decimals=4) )

	# The two step lines sit midway between atom rows 1/2 and rows -2/-1.
	lower_x = (Surface_Top_layer_X[1] + Surface_Top_layer_X[2])/2.
	upper_x = (Surface_Top_layer_X[-2] + Surface_Top_layer_X[-1])/2.
	step = np.array( [ np.dot( [x,y], cell2d ) for x in (lower_x, upper_x) for y in (0., 1.) ] )
	step_size = lower_x + 1. - upper_x
	print('Statistical step ratio: {:6.5f}'.format( step_size ))

	POS_reac_initial = []
	POS_reaction = []
	SurfaceDivision_initial = np.zeros( len( Surface_Top_layer_X ) - 2 )
	SurfaceDivision_reac = np.zeros( len( Surface_Top_layer_X ) - 2 )
	# [direct on step, indirect on step, direct on terrace, indirect on terrace]
	Reac_directness = np.zeros( 4, dtype=int )

	# Mirror line through two top-layer atoms, used to symmetrise the data.
	mirror_line = ( Surface_Top_layer[2,0], Surface_Top_layer[2,1], Surface_Top_layer[-1,0], Surface_Top_layer[-1,1] )

	here_tmp=getcwd()				#workdir
	# Trajectory folders: six-character names starting with a digit
	# (same selection as the former `ls | egrep "^[0-9].....$" | wc -l`).
	Ntraj = sum( 1 for name in listdir(here_tmp) if len(name) == 6 and name[0] in '0123456789' )
	for i in range(1,(Ntraj+1)):
		newdir=path.join(here_tmp, '%.6d' % i)

		if not path.exists(newdir):
			print( "folder", newdir, "not found. exiting..." )
			sys.exit(1)

		outcome_file = path.join(newdir, 'outcome')
		post_file = path.join(newdir, 'PostAnalysis.dat')
		if not ( path.exists(post_file) and path.exists(outcome_file) ):
			continue

		with open(outcome_file, 'r') as handle:
			outcome = handle.readline()

		# Only reacted trajectories are plotted; unclear and scattered ones
		# are skipped.
		if 'UNCLEAR' in outcome or 'SCATTERING' in outcome or 'REACTION' not in outcome:
			continue

		with open(post_file, 'r') as handle:
			post_dat = handle.readlines()

		# Line 31 holds the initial position, line 32 the position at the
		# moment of reaction; directness is only tracked for the latter.
		for line_no, positions, division, track_directness in (
			(31, POS_reac_initial, SurfaceDivision_initial, False),
			(32, POS_reaction, SurfaceDivision_reac, True),
		):
			POS_ = np.dot( _read_xy( post_dat[line_no] ), inv2d )
			POS_ %= 1.
			positions.append( POS_ )

			site = _classify_site( POS_[0], Surface_Top_layer_X )
			if site is not None:
				division[site[0]] += 1
				if track_directness:
					# First field of line 8 below 3 counts as a direct reaction.
					is_direct = float( post_dat[8].split()[0] ) < 3.
					Reac_directness[(0 if site[1] else 2) + (0 if is_direct else 1)] += 1

			positions.append( mirrorImage(POS_[0], POS_[1], *mirror_line, cell2d, inv2d) )

	print(SurfaceDivision_initial)
	print(SurfaceDivision_reac)
	print(Reac_directness)

	print(folder)
	print('Initial position at step, terrace, ratio step/total, and ratio of terrace atoms in the shadow/out of the shadow of the step atoms')
	_print_site_statistics( SurfaceDivision_initial )

	print('Position at moment of reaction at step, terrace, and ratio')
	_print_site_statistics( SurfaceDivision_reac )

	print('Number of direct and indirect reactions on the step, direct and indirect reactions on the terrace, and direct/total of both step and terrace for reaction')
	# 999 is printed when no reaction was counted on the step or terrace.
	direct_step = _ratio_and_error( Reac_directness[0], Reac_directness[0] + Reac_directness[1] )[0]
	direct_terrace = _ratio_and_error( Reac_directness[2], Reac_directness[2] + Reac_directness[3] )[0]
	print('{:d} , {:d} , {:d} , {:d} , {:5.4f} , {:5.4f}'.format( *Reac_directness, direct_step, direct_terrace ))
	print('\n')

	POS_reac_initial = np.array(POS_reac_initial)
	POS_reaction = np.array(POS_reaction)

	E = float( folder[:5] ) * 96.48531
	panel_letters = {1: ('a', 'b'), 2: ('c', 'd'), 3: ('e', 'f')}.get(idx)
	if panel_letters is None:
		headings = (None, None)
	else:
		headings = tuple( r'({}) $E_\mathrm{{i}}={:3.1f}$ kJ/mol'.format( letter, E ) for letter in panel_letters )

	_draw_density_panel( idx*2-1, POS_reac_initial, np.arange( minS_1, maxS_1, stepS_1 ), cell2d, Surface_Top_layer, step, headings[0] )
	_draw_density_panel( idx*2, POS_reaction, np.arange( minS_2, maxS_2, stepS_2 ), cell2d, Surface_Top_layer, step, headings[1] )

	chdir(here)

# One figure row (initial-position panel + reaction-position panel) per folder.
for row_number, data_folder in enumerate(folders, start=1):
	plot(row_number, legend_idx, data_folder, titles[row_number-1])

# Axis limits of the panel drawn last; reused for the two bottom panels.
xlim = plt.xlim()
ylim = plt.ylim()

# Bottom row: two panels that only get a title, labels and matching limits.
bottom_panels = (
	(7, r'(g) Side view of Pt(433)', r'$Z$ (\r{A})', dict(top=True, right=False, labelleft=False, left=False)),
	(8, r'(h) Top view of Pt(433)', r'$Y$ (\r{A})', dict(top=True, right=True)),
)
for grid_cell, heading, vertical_label, tick_sides in bottom_panels:
	subax = plt.subplot(Nrows,Ncols,grid_cell)

	plt.title(heading, size=10)

	plt.xlabel(r'$X$ (\r{A})')
	plt.ylabel(vertical_label)
	plt.tick_params(length=3, width=1, direction='out', **tick_sides)
	plt.axis('scaled')
	plt.xlim(xlim)
	plt.ylim(ylim)

plt.tight_layout()

# Horizontal colorbars above the two columns, taken from the first row's
# contour sets. Each rectangle is [xmin, ymin, xwidth, ywidth] in figure
# coordinates.
colorbars = (
	([0.095, 0.87, 0.385, 0.02], images[0], np.linspace(minS_1, maxS_1-stepS_1, 6), 'Initial location probability density'),
	([0.585, 0.87, 0.385, 0.02], images[1], np.linspace(minS_2, maxS_2-stepS_2, 7), 'Dissociation location probability density'),
)
for rectangle, contour_set, tick_values, caption in colorbars:
	bar_axes = fig.add_axes(rectangle)
	fig.colorbar(contour_set, cax=bar_axes, orientation='horizontal', ticks=tick_values)
	bar_axes.xaxis.set_label_position('top')
	bar_axes.xaxis.set_ticks_position('top')
	bar_axes.set_xlabel(caption)

# Leave room at the top of the figure for the colorbars.
fig.subplots_adjust(top=0.84)

plt.savefig('site_density.pdf')
