#!/usr/bin/env python
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.proj3d import proj_transform
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
#from mayavi import mlab

from matplotlib.patches import FancyArrowPatch
class Arrow3D(FancyArrowPatch):
    """A ``FancyArrowPatch`` whose endpoints are given in 3D data coordinates.

    The arrow starts at ``(x, y, z)`` and extends by ``(dx, dy, dz)``. Each
    time it is drawn, both 3D endpoints are projected into 2D with the owning
    3D axes' projection matrix (``self.axes.M``).

    Any extra positional/keyword arguments (e.g. ``color``,
    ``mutation_scale``) are forwarded to ``FancyArrowPatch``.
    """

    def __init__(self, x, y, z, dx, dy, dz, *args, **kwargs):
        # Placeholder 2D endpoints; the real ones are set by _project().
        super().__init__((0,0), (0,0), *args, **kwargs)
        self._xyz = (x,y,z)
        self._dxdydz = (dx,dy,dz)

    def _project(self):
        """Project both 3D endpoints to 2D, update the patch, return depths."""
        x1, y1, z1 = self._xyz
        dx, dy, dz = self._dxdydz
        x2, y2, z2 = (x1 + dx, y1 + dy, z1 + dz)

        # Use the axes' own projection matrix: Axes3D no longer attaches
        # ``M`` to the renderer in recent Matplotlib releases.
        xs, ys, zs = proj_transform((x1, x2), (y1, y2), (z1, z2), self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        return zs

    def do_3d_projection(self, renderer=None):
        """Projection hook that Axes3D.draw calls on 3D patches (Matplotlib 3.5+).

        Returns the smallest projected depth, which Axes3D uses for z-sorting.
        ``renderer`` is accepted for compatibility with older call signatures.
        """
        return np.min(self._project())

    def draw(self, renderer):
        # Re-project here too so the arrow also renders on Matplotlib
        # versions that never call do_3d_projection on generic artists.
        self._project()
        super().draw(renderer)

def _arrow3D(ax, x, y, z, dx, dy, dz, *args, **kwargs):
    '''Add a 3D arrow to an `Axes3D` instance and return it.

    The arrow starts at ``(x, y, z)`` and extends by ``(dx, dy, dz)`` in data
    coordinates. Remaining positional and keyword arguments (e.g. ``color``,
    ``mutation_scale``) are passed through to `Arrow3D`.

    Returns the created `Arrow3D` artist so callers can restyle or remove it,
    matching how Matplotlib's own ``add_artist``/``arrow`` return the artist.
    '''
    arrow = Arrow3D(x, y, z, dx, dy, dz, *args, **kwargs)
    ax.add_artist(arrow)
    return arrow

# Attach the helper to Axes3D so it can be called as ``ax.arrow3D(...)``.
Axes3D.arrow3D = _arrow3D

from matplotlib import rcParams

# Global text styling: a negative label pad pulls the axis labels in toward
# the tick-less 3D axes, and usetex typesets all text through LaTeX (this
# needs a working TeX installation wherever the script runs).
rcParams.update({'axes.labelpad': -15, 'text.usetex': True})

# Figure layout: a 2 x 2 grid of 3D panels, (a)-(d).
Nrows, Ncols = 2, 2
fig = plt.figure(figsize=(5., 4.))

def spherical2cartesian(r, theta, phi):
    """Convert spherical coordinates to a Cartesian vector ``[x, y, z]``.

    Parameters
    ----------
    r : float
        Radial distance.
    theta : float or array_like
        Polar angle measured from the +z axis, in degrees.
    phi : float or array_like
        Azimuthal angle in the x-y plane measured from +x, in degrees.

    Returns
    -------
    numpy.ndarray
        ``r * [sin(theta)cos(phi), sin(theta)sin(phi), cos(theta)]``, i.e. a
        vector of length ``|r|`` (shape ``(3,)`` for scalar inputs).
    """
    theta_rad = np.radians(theta)
    phi_rad = np.radians(phi)
    sin_theta = np.sin(theta_rad)
    # Scale by r exactly once so the result has length |r|; applying r to
    # each component and then to the whole array would give length r**2.
    return r * np.array([
        sin_theta * np.cos(phi_rad),
        sin_theta * np.sin(phi_rad),
        np.cos(theta_rad),
    ])

def Rx(theta, pos):
    """Rotate point(s) about the x axis using the row-vector convention.

    Parameters
    ----------
    theta : float
        Rotation angle in degrees.
    pos : array_like
        Array whose last axis has length 3, e.g. a single 3-vector or an
        (N, 3) array of row vectors.

    Returns
    -------
    numpy.ndarray
        ``pos @ R``, where ``R`` is the standard right-handed x-rotation
        matrix for ``theta``. Because ``pos`` multiplies ``R`` from the left,
        each point is turned by ``-theta`` about +x (equivalently ``R.T`` is
        applied to column vectors).
    """
    angle = np.radians(theta)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rot_x = np.array([
        [1., 0., 0.],
        [0., cos_a, -sin_a],
        [0., sin_a, cos_a],
    ])
    return np.dot(pos, rot_x)

def Ry(theta, pos):
    """Rotate point(s) about the y axis using the row-vector convention.

    Parameters
    ----------
    theta : float
        Rotation angle in degrees.
    pos : array_like
        Array whose last axis has length 3, e.g. a single 3-vector or an
        (N, 3) array of row vectors.

    Returns
    -------
    numpy.ndarray
        ``pos @ R``, where ``R`` is the standard right-handed y-rotation
        matrix for ``theta``. With ``pos`` on the left, a +z vector is carried
        toward -x for positive ``theta`` (``R.T`` acts on column vectors).
    """
    angle = np.radians(theta)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rot_y = np.array([
        [cos_a, 0., sin_a],
        [0., 1., 0.],
        [-sin_a, 0., cos_a],
    ])
    return np.dot(pos, rot_y)

def Rz(phi, pos):
    """Rotate point(s) about the z axis using the row-vector convention.

    Parameters
    ----------
    phi : float
        Rotation angle in degrees.
    pos : array_like
        Array whose last axis has length 3, e.g. a single 3-vector or an
        (N, 3) array of row vectors.

    Returns
    -------
    numpy.ndarray
        ``pos @ R``, where ``R`` is the standard right-handed z-rotation
        matrix for ``phi``. Because ``pos`` multiplies ``R`` from the left,
        each point is turned by ``-phi`` about +z (``R.T`` on column vectors).
    """
    angle = np.radians(phi)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rot_z = np.array([
        [cos_a, -sin_a, 0.],
        [sin_a, cos_a, 0.],
        [0., 0., 1.],
    ])
    return np.dot(pos, rot_z)

# --- Panel (a): reference orientation -----------------------------------------
# hcl holds one row per atom. Row 0 (labelled 'Cl' below) sits at the origin;
# row 1 (labelled 'H') is spherical2cartesian(1., 90., 0.), i.e. on the +x axis.
hcl = np.zeros((2,3))
hcl[1] = spherical2cartesian( 1., 90., 0.)

# L points along +z and is drawn as the red arrow starting at the Cl atom.
L = spherical2cartesian( 1.5, 0., 0.)
# NOTE(review): c is rotated along with hcl in the later panels but is never
# plotted anywhere in this script. It looks like leftover state; confirm before
# removing it (every later `c = ...` line would have to go at the same time).
c = spherical2cartesian( 1.2, 90., 0.)

# Subplot 1 of the Nrows x Ncols grid. add_subplot also makes it the current
# axes, which is what the plt.xticks/xlim/xlabel calls below operate on.
ax = fig.add_subplot(Nrows, Ncols, 1, projection='3d')

# Row 0 as a large green marker; all remaining rows as small white markers.
ax.scatter( hcl[0,0], hcl[0,1], hcl[0,2], marker='o', s=200, c='g', edgecolors='k', alpha=1. )
ax.scatter( hcl[1:,0], hcl[1:,1], hcl[1:,2], marker='o', s=30, c='w', edgecolors='k', alpha=1. )
# Red arrow for L, anchored at the Cl atom (row 0).
ax.arrow3D(hcl[0,0], hcl[0,1], hcl[0,2], L[0], L[1], L[2], color='r', mutation_scale=5)
# The bond: a thick black line through all atom positions.
ax.plot(hcl[:,0], hcl[:,1], hcl[:,2], color='k', linewidth=3.)
ax.text(L[0], L[1], L[2], 'L', size=12, c='r')
ax.text(hcl[0,0]+0.2, hcl[0,1], hcl[0,2]+0.5, 'Cl', size=12, c='g')
ax.text(hcl[1,0], hcl[1,1], hcl[1,2]+0.3, 'H', size=12, c='k')
# Panel tag at a fixed data-coordinate position (the same spot in every panel).
ax.text( -1.8, -1.7, -0.8, '(a)', size=12, c='k' )

# Hide all ticks and fix the view box. xlim/ylim/zlim are module-level and
# are reused unchanged by panels (b)-(d).
plt.xticks([])
plt.yticks([])
ax.set_zticks([])
xlim = [-2,2]
ylim = [-2,2]
zlim = [-1,2]
plt.xlim(xlim)
plt.ylim(ylim)
ax.set_zlim(zlim)
#plt.tick_params(length=6, width=1, direction='in', top=True, right=True)
plt.xlabel(r'$X$')
plt.ylabel(r'$Y$')
ax.set_zlabel(r'$Z$')

# --- Panel (b): tilt by beta ---------------------------------------------------
# Subplot 2; becomes the current axes for the plt.* calls below.
ax = fig.add_subplot(Nrows, Ncols, 2, projection='3d')

# Tilt the molecule, c and L with Ry(30.). Ry computes pos @ R, so a +z vector
# is carried toward -x: L now points along (-sin 30deg, 0, cos 30deg).
# These overwrite the module-level hcl/c/L and are the input to panel (c).
hcl = Ry( 30., hcl )
c = Ry( 30., c )
L = Ry( 30., L )

# Same molecule/arrow drawing as panel (a), using the tilted coordinates.
ax.scatter( hcl[0,0], hcl[0,1], hcl[0,2], marker='o', s=200, c='g', edgecolors='k', alpha=1. )
ax.scatter( hcl[1:,0], hcl[1:,1], hcl[1:,2], marker='o', s=30, c='w', edgecolors='k', alpha=1. )
ax.arrow3D(hcl[0,0], hcl[0,1], hcl[0,2], L[0], L[1], L[2], color='r', mutation_scale=5)
ax.plot(hcl[:,0], hcl[:,1], hcl[:,2], color='k', linewidth=3.)
#ax.text(L[0]+0.1, L[1], L[2], 'L', size=12, c='r')
ax.text( -1.8, -1.7, -0.8, '(b)', size=12, c='k' )
ax.text( -0.4, 0., 1.5, r'$\beta$', size=12, c='g' )

# Green beta arc: radius r at azimuth phi = 180deg (the -x half of the x-z
# plane), with the polar angle sweeping 0 -> 30deg, shifted up by z0.
r = 2.2
phi = np.radians(180.)
theta = np.linspace(0., np.radians(30.), 201)
x0 = 0.
y0 = 0.
z0 = 0.3
x = r*np.sin(theta)*np.cos(phi) + x0
y = r*np.sin(theta)*np.sin(phi) + y0
z = r*np.cos(theta) + z0

ax.plot(x, y, z, c='g')
# Arrowhead direction: difference over the last 4 arc samples. The arrow
# starts 3 such steps beyond the arc's end point.
dx = x[-1]-x[-5]
dy = y[-1]-y[-5]
dz = z[-1]-z[-5]
ax.arrow3D(x[-1]+3*dx, y[-1]+3*dy, z[-1]+3*dz, dx, dy, dz, color='g', mutation_scale=5)

# Same tick/limit/label formatting as panel (a), reusing its xlim/ylim/zlim.
plt.xticks([])
plt.yticks([])
ax.set_zticks([])
plt.xlim(xlim)
plt.ylim(ylim)
ax.set_zlim(zlim)
#plt.tick_params(length=6, width=1, direction='in', top=True, right=True)
plt.xlabel(r'$X$')
plt.ylabel(r'$Y$')
ax.set_zlabel(r'$Z$')

# --- Panel (c): spin by gamma about the tilted axis ----------------------------
# Subplot 3; becomes the current axes for the plt.* calls below.
ax = fig.add_subplot(Nrows, Ncols, 3, projection='3d')

# Undo the beta tilt (Ry(-30.)), rotate 180deg about z, then re-apply the tilt
# (Ry(30.)). The net effect is a 180deg rotation of hcl (and c) about the
# tilted L direction. L itself is not updated (commented-out line below):
# it lies on that rotation axis, so the three steps would return it unchanged.
hcl = Ry( -30., hcl )
c = Ry( -30., c )
hcl = Rz( 180., hcl )
c = Rz( 180., c )
hcl = Ry( 30., hcl )
c = Ry( 30., c )
#L = Ry( 30., L )

# Same molecule/arrow drawing as the previous panels.
ax.scatter( hcl[0,0], hcl[0,1], hcl[0,2], marker='o', s=200, c='g', edgecolors='k', alpha=1. )
ax.scatter( hcl[1:,0], hcl[1:,1], hcl[1:,2], marker='o', s=30, c='w', edgecolors='k', alpha=1. )
ax.arrow3D(hcl[0,0], hcl[0,1], hcl[0,2], L[0], L[1], L[2], color='r', mutation_scale=5)
ax.plot(hcl[:,0], hcl[:,1], hcl[:,2], color='k', linewidth=3.)
#ax.text(L[0], L[1], L[2]+0.2, 'L', size=12, c='r')
ax.text( -1.8, -1.7, -0.8, '(c)', size=12, c='k' )
ax.text( -0.6, 0., 1.4, r'$\gamma$', size=12, c='g' )

# Green gamma arc: phi sweeps 90deg -> 180deg, while theta (already in radians,
# not passed through np.radians) grows from 0 to sin(30deg) = 0.5 rad. It is
# drawn around a centre shifted down by z0.
# NOTE(review): presumably a hand-tuned sketch of the spin around the tilted
# axis rather than an exact circle; confirm before changing.
r = 2.
phi = np.linspace(np.radians(90.), np.radians(180.), 201)
theta = -np.sin(np.radians(30.))*np.cos(phi)
x0 = 0.
y0 = 0.
z0 = -1.
x = r*np.sin(theta)*np.cos(phi) + x0
y = r*np.sin(theta)*np.sin(phi) + y0
z = r*np.cos(theta) + z0

ax.plot(x, y, z, c='g')
# Arrowhead direction: difference over the last 4 arc samples. The arrow
# starts 3 such steps beyond the arc's end point.
dx = x[-1]-x[-5]
dy = y[-1]-y[-5]
dz = z[-1]-z[-5]
ax.arrow3D(x[-1]+3*dx, y[-1]+3*dy, z[-1]+3*dz, dx, dy, dz, color='g', mutation_scale=5)

# Same tick/limit/label formatting as panel (a), reusing its xlim/ylim/zlim.
plt.xticks([])
plt.yticks([])
ax.set_zticks([])
plt.xlim(xlim)
plt.ylim(ylim)
ax.set_zlim(zlim)
#plt.tick_params(length=6, width=1, direction='in', top=True, right=True)
plt.xlabel(r'$X$')
plt.ylabel(r'$Y$')
ax.set_zlabel(r'$Z$')

# --- Panel (d): polar angle theta ------------------------------------------------
# Subplot 4; becomes the current axes for the plt.* calls below.
# No new rotation is applied: hcl and L still hold the panel (c) geometry. There
# the H atom (row 1) is at about (-0.87, 0, -0.5), i.e. 120deg from +z on the
# -x side, which is the angle swept by the green arc below.
ax = fig.add_subplot(Nrows, Ncols, 4, projection='3d')

ax.scatter( hcl[0,0], hcl[0,1], hcl[0,2], marker='o', s=200, c='g', edgecolors='k', alpha=1. )
ax.scatter( hcl[1:,0], hcl[1:,1], hcl[1:,2], marker='o', s=30, c='w', edgecolors='k', alpha=1. )
ax.arrow3D(hcl[0,0], hcl[0,1], hcl[0,2], L[0], L[1], L[2], color='r', mutation_scale=5)
ax.plot(hcl[:,0], hcl[:,1], hcl[:,2], color='k', linewidth=3.)
#ax.text(L[0], L[1], L[2], 'L', size=12, c='r')
ax.text( -1.8, -1.7, -0.8, '(d)', size=12, c='k' )
ax.text( -1.5, 0., 0.7, r'$\theta$', size=12, c='g' )

# Green theta arc: same construction as panel (b) (phi = 180deg, shifted up
# by z0), but the polar angle sweeps 0 -> 120deg.
r = 2.2
phi = np.radians(180.)
theta = np.linspace(0., np.radians(120.), 201)
x0 = 0.
y0 = 0.
z0 = 0.3
x = r*np.sin(theta)*np.cos(phi) + x0
y = r*np.sin(theta)*np.sin(phi) + y0
z = r*np.cos(theta) + z0

ax.plot(x, y, z, c='g')
# NOTE(review): unlike panels (b) and (c), the arrowhead direction here uses
# a single sample step (x[-1]-x[-2]). The arrow is therefore much shorter and
# starts right at the end of the arc. Confirm whether this is intentional.
dx = x[-1]-x[-2]
dy = y[-1]-y[-2]
dz = z[-1]-z[-2]
ax.arrow3D(x[-1]+3*dx, y[-1]+3*dy, z[-1]+3*dz, dx, dy, dz, color='g', mutation_scale=5)

# Same tick/limit/label formatting as panel (a), reusing its xlim/ylim/zlim.
plt.xticks([])
plt.yticks([])
ax.set_zticks([])
plt.xlim(xlim)
plt.ylim(ylim)
ax.set_zlim(zlim)
#plt.tick_params(length=6, width=1, direction='in', top=True, right=True)
plt.xlabel(r'$X$')
plt.ylabel(r'$Y$')
ax.set_zlabel(r'$Z$')

# Tighten the spacing between the four panels and write the figure to disk;
# the output format (PDF) follows from the file extension.
plt.tight_layout()
#plt.subplots_adjust(wspace=0, hspace=0)
#ax.view_init(elev=90., azim=0.)
plt.savefig('rotationalstate.pdf')
#plt.show()
