#!/usr/bin/env python
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
from matplotlib.patches import FancyArrowPatch
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np

# class Arrow3D(FancyArrowPatch):
#     def __init__(self, x, y, z, dx, dy, dz, *args, **kwargs):
#         super().__init__((0,0), (0,0), *args, **kwargs)
#         self._xyz = (x,y,z)
#         self._dxdydz = (dx,dy,dz)
        
#     # def __init__(self, xs, ys, zs, *args, **kwargs):
#     #     super().__init__((0,0), (0,0), *args, **kwargs)
#     #     self._verts3d = xs, ys, zs

#     def draw(self, renderer):
#         x1,y1,z1 = self._xyz
#         dx,dy,dz = self._dxdydz
#         x2,y2,z2 = (x1+dx,y1+dy,z1+dz)

#         xs, ys, zs = proj3d.proj_transform((x1,x2),(y1,y2),(z1,z2), renderer.M)
#         self.set_positions((xs[0],ys[0]),(xs[1],ys[1]))
#         super().draw(renderer)

#     def do_3d_projection(self, renderer=None):
#         xs3d, ys3d, zs3d = self._xyz
#         xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
#         print(xs)
#         self.set_positions((xs[0],ys[0]),(xs[1],ys[1]))

#         return np.min(zs)


# def _arrow3D(ax, x, y, z, dx, dy, dz, *args, **kwargs):
#     '''Add an 3d arrow to an `Axes3D` instance.'''

#     arrow = Arrow3D(x, y, z, dx, dy, dz, *args, **kwargs)
#     ax.add_artist(arrow)
    
# setattr(Axes3D,'arrow3D',_arrow3D)

class Arrow3D(FancyArrowPatch):
    """A FancyArrowPatch whose two end points are given in 3D data coordinates.

    The 3D end points are stored at construction time and re-projected to
    2D every time the owning 3D axes calls ``do_3d_projection``.

    Parameters
    ----------
    xs, ys, zs : sequence of two floats
        Tail (index 0) and head (index 1) coordinate along each axis.
    *args, **kwargs
        Forwarded unchanged to ``FancyArrowPatch`` (arrowstyle, color, ...).
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # The 2D positions are placeholders; do_3d_projection sets the real ones.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def do_3d_projection(self, renderer=None):
        """Project the stored 3D end points with the axes' current matrix.

        Returns the smallest projected depth value.
        """
        px, py, pz = proj3d.proj_transform(*self._verts3d, self.axes.M)
        tail = (px[0], py[0])
        head = (px[1], py[1])
        self.set_positions(tail, head)
        return np.min(pz)

    def draw(self, renderer=None):
        super().draw(renderer)

def draw_arrow(origin, c, L, axes=None):
    """Draw the c (black) and L (red) vectors as 3D arrows.

    Both arrows start at the first atom, ``origin[0]``, and end at
    ``origin[0] + vector``.

    Parameters
    ----------
    origin : numpy.ndarray, shape (n_atoms, 3)
        Atom coordinates; only row 0 is used, as the common arrow tail.
    c, L : array_like of length 3
        Components of the black and the red arrow, respectively.
    axes : Axes3D, optional
        Axes to draw on. When omitted, the module-level ``ax`` is used, so
        existing three-argument calls keep working unchanged.
    """
    # Only touch the global `ax` when no explicit axes was supplied.
    target = ax if axes is None else axes
    tail = origin[0]
    # c gets the higher zorder (100 vs 50) so, with computed_zorder disabled,
    # it is drawn on top of L wherever the two arrows overlap.
    for vector, color, zorder in ((c, "k", 100), (L, "r", 50)):
        arrow = Arrow3D(
            [tail[0], tail[0] + vector[0]],
            [tail[1], tail[1] + vector[1]],
            [tail[2], tail[2] + vector[2]],
            mutation_scale=10, lw=2, arrowstyle="-|>", color=color,
            connectionstyle="arc3,rad=0.", zorder=zorder,
        )
        target.add_artist(arrow)

from matplotlib import rcParams

# Global styling: pull the axis names in close to the (tick-less) 3D axes
# and typeset all text through LaTeX.
rcParams['axes.labelpad'] = -15
mpl.rc("text", usetex=True)

# One figure holding a 3 x 2 grid of panels (a)-(f).
Nrows, Ncols = 3, 2
fig = plt.figure(figsize=(5., 5.5))

def spherical2cartesian(r, theta, phi):
    """Convert spherical coordinates to a Cartesian ``[x, y, z]`` array.

    Parameters
    ----------
    r : float or array_like
        Radial distance.
    theta : float or array_like
        Polar angle measured from the +z axis, in degrees.
    phi : float or array_like
        Azimuthal angle in the x-y plane measured from +x, in degrees.

    Returns
    -------
    numpy.ndarray
        ``r * [sin(theta) cos(phi), sin(theta) sin(phi), cos(theta)]``,
        a vector of length ``r``.
    """
    theta_rad = np.radians(theta)
    phi_rad = np.radians(phi)
    sin_theta = np.sin(theta_rad)
    # Scale the unit direction by r exactly once, so |result| == r
    # (multiplying by r twice would give length r**2).
    return r * np.array([
        sin_theta * np.cos(phi_rad),
        sin_theta * np.sin(phi_rad),
        np.cos(theta_rad),
    ])

def Rx(theta, pos):
    """Rotate point(s) about the x axis.

    Parameters
    ----------
    theta : float
        Angle in degrees.
    pos : array_like, shape (3,) or (n, 3)
        Point(s) stored as row vectors.

    Returns
    -------
    numpy.ndarray
        ``np.dot(pos, R)`` with R the standard right-handed rotation matrix
        about x. Because the row vectors multiply R from the left, this
        applies ``R.T``, i.e. each point is turned by ``-theta`` about +x.
    """
    angle = np.radians(theta)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rotation = np.array([
        [1., 0., 0.],
        [0., cos_a, -sin_a],
        [0., sin_a, cos_a],
    ])
    return np.dot(pos, rotation)

def Ry(theta, pos):
    """Rotate point(s) about the y axis.

    Parameters
    ----------
    theta : float
        Angle in degrees.
    pos : array_like, shape (3,) or (n, 3)
        Point(s) stored as row vectors.

    Returns
    -------
    numpy.ndarray
        ``np.dot(pos, R)`` with R the standard right-handed rotation matrix
        about y. Because the row vectors multiply R from the left, this
        applies ``R.T``, i.e. each point is turned by ``-theta`` about +y.
    """
    angle = np.radians(theta)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rotation = np.array([
        [cos_a, 0., sin_a],
        [0., 1., 0.],
        [-sin_a, 0., cos_a],
    ])
    return np.dot(pos, rotation)

def Rz(phi, pos):
    """Rotate point(s) about the z axis.

    Parameters
    ----------
    phi : float
        Angle in degrees.
    pos : array_like, shape (3,) or (n, 3)
        Point(s) stored as row vectors.

    Returns
    -------
    numpy.ndarray
        ``np.dot(pos, R)`` with R the standard right-handed rotation matrix
        about z. Because the row vectors multiply R from the left, this
        applies ``R.T``, i.e. each point is turned by ``-phi`` about +z.
    """
    angle = np.radians(phi)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rotation = np.array([
        [cos_a, -sin_a, 0.],
        [sin_a, cos_a, 0.],
        [0., 0., 1.],
    ])
    return np.dot(pos, rotation)


def plot_ammonia(ax, ammonia):
    """Scatter the molecule's atoms on a 3D axes.

    Parameters
    ----------
    ax : Axes3D
        Target axes.
    ammonia : numpy.ndarray, shape (n_atoms, 3)
        Atom coordinates.
        - Row 0 is drawn as a larger black marker.
        - Row 1 is drawn as a red marker.
        - Every remaining row is drawn as a white marker.
        All markers have black edges. They use zorder 200/201, above the
        arrows drawn by ``draw_arrow`` (zorder 50/100).
    """
    central, marked, rest = ammonia[0], ammonia[1], ammonia[2:]
    common = dict(marker='o', edgecolors='k', alpha=1.)
    ax.scatter(central[0], central[1], central[2], s=50, c='k', zorder=200, **common)
    ax.scatter(marked[0], marked[1], marked[2], s=30, c='r', zorder=201, **common)
    ax.scatter(rest[:, 0], rest[:, 1], rest[:, 2], s=30, c='w', zorder=201, **common)

# Molecule coordinates at unit radius:
# - row 0 is the central atom at the origin;
# - row 1 lies on +z;
# - rows 2-4 sit at a polar angle of 109.5 deg with azimuths 30, 150, 270 deg.
ammonia = np.vstack([
    np.zeros(3),
    spherical2cartesian(1., 0., 0.),
    spherical2cartesian(1., 109.5, 30.),
    spherical2cartesian(1., 109.5, 150.),
    spherical2cartesian(1., 109.5, 270.),
])

# Vectors drawn from the central atom: L (red) and c (black), both along +z.
L = spherical2cartesian(1.8, 0., 0.)
c = spherical2cartesian(1.5, 0., 0.)

# --- Panel (a): reference orientation ----------------------------------------
ax = fig.add_subplot(Nrows, Ncols, 1, projection='3d')
ax.computed_zorder = False  # honour the explicit zorder values of the artists

plot_ammonia(ax, ammonia)
draw_arrow(ammonia, c, L)
ax.text(*L, 'L', size=12, c='r')
ax.text(c[0] + 0.2, c[1], c[2] * 0.8, 'c', size=12, c='k')
ax.text(-1.8, -1.7, -0.8, '(a)', size=12, c='k')

# Axis box shared by every panel; ticks hidden, only the axis names shown.
xlim = [-2, 2]
ylim = [-2, 2]
zlim = [-1, 2]
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_zlim(zlim)
ax.set_xlabel(r'$X$')
ax.set_ylabel(r'$Y$')
ax.set_zlabel(r'$Z$')

# --- Panel (b): rotation by alpha about the z axis ---------------------------
ax = fig.add_subplot(Nrows, Ncols, 2, projection='3d')
ax.computed_zorder = False

# Rotate the molecule and both vectors by 90 deg about z.
ammonia = Rz(90., ammonia)
c = Rz(90., c)
L = Rz(90., L)

plot_ammonia(ax, ammonia)
# Draw the c/L arrows exactly once; a second call would only stack
# duplicate, overlapping arrow artists on the same axes.
draw_arrow(ammonia, c, L)
ax.text(L[0], L[1], L[2]+0.1, 'L', size=12, c='r')
ax.text(c[0]+0.2, c[1], c[2]*0.8, 'c', size=12, c='k')
ax.text(-1.8, -1.7, -0.8, '(b)', size=12, c='k')
ax.text(-0.8, -0.2, 1.5, r'$\alpha$', size=12, c='b')

# Three-quarter circle of radius r in the horizontal plane z = z0, marking
# alpha. It is centred on the z axis: y0 = r*sin(pi) is zero up to rounding.
r = 1.8
phi = np.radians(180.)
y0 = r*np.sin(phi)
z0 = 1.3
theta = np.linspace(phi, phi+3*np.pi/2., 201)
x = r*np.sin(theta)
y = r*np.cos(theta) + y0
z = np.full(len(theta), z0)
ax.plot(x, y, z, c='b')

# Arrowhead: continue the final arc segment for `offset` steps.
# TODO(review): the author flagged this arrowhead placement as "FIX THIS".
dx = x[-1] - x[-2]
dy = y[-1] - y[-2]
dz = z[-1] - z[-2]
offset = 10
a2 = Arrow3D([x[-1], x[-1]+offset*dx], [y[-1], y[-1]+offset*dy],
             [z[-1], z[-1]+offset*dz], mutation_scale=10, lw=1,
             arrowstyle="-|>", color="b", connectionstyle="arc3,rad=0.")
ax.add_artist(a2)

# Same axis box and hidden ticks as panel (a).
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_zlim(zlim)
ax.set_xlabel(r'$X$')
ax.set_ylabel(r'$Y$')
ax.set_zlabel(r'$Z$')

# --- Panel (c): rotation by beta about the y axis ----------------------------
ax = fig.add_subplot(Nrows, Ncols, 3, projection='3d')
ax.computed_zorder = False

# Tilt the molecule and c by 90 deg about y; L keeps its direction.
ammonia = Ry(90., ammonia)
c = Ry(90., c)

plot_ammonia(ax, ammonia)
draw_arrow(ammonia, c, L)
ax.text(L[0] + 0.1, L[1], L[2], 'L', size=12, c='r')
ax.text(c[0] * 0.8, c[1] * 0.8, c[2] + 0.4, 'c', size=12, c='k')
ax.text(-1.8, -1.7, -0.8, '(c)', size=12, c='k')
ax.text(-0.8, 0., 1.5, r'$\beta$', size=12, c='b')

# Arc of radius r centred at (x0, y0, z0), sweeping 120 deg down from the top.
# phi = 180 deg puts it on the negative-x side of the (nearly) x-z plane.
r = 2.2
phi = np.radians(180.)
theta = np.linspace(0., np.radians(120.), 201)
x0, y0, z0 = 0., 0., 0.3
radial = r * np.sin(theta)
x = radial * np.cos(phi) + x0
y = radial * np.sin(phi) + y0
z = r * np.cos(theta) + z0
ax.plot(x, y, z, c='b')

# Arrowhead: continue the final arc segment for `offset` steps.
# TODO(review): the author flagged this arrowhead placement as "FIX THIS".
dx, dy, dz = x[-1] - x[-2], y[-1] - y[-2], z[-1] - z[-2]
offset = 10
arc_head = Arrow3D(
    [x[-1], x[-1] + offset * dx],
    [y[-1], y[-1] + offset * dy],
    [z[-1], z[-1] + offset * dz],
    mutation_scale=10, lw=1, arrowstyle="-|>", color="b",
    connectionstyle="arc3,rad=0.",
)
ax.add_artist(arc_head)

# Same axis box and hidden ticks as panel (a).
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_zlim(zlim)
ax.set_xlabel(r'$X$')
ax.set_ylabel(r'$Y$')
ax.set_zlabel(r'$Z$')

# --- Panel (d): rotation by gamma about the c axis ---------------------------
ax = fig.add_subplot(Nrows, Ncols, 4, projection='3d')
ax.computed_zorder = False

# Only the molecule turns (20 deg about x); c and L are left unchanged.
ammonia = Rx(20., ammonia)

plot_ammonia(ax, ammonia)
draw_arrow(ammonia, c, L)
ax.text(*L, 'L', size=12, c='r')
ax.text(c[0] * 0.8, c[1] * 0.8, c[2] + 0.4, 'c', size=12, c='k')
ax.text(-1.8, -1.7, -0.8, '(d)', size=12, c='k')
ax.text(-1.5, -1.1, 1.55, r'$\gamma$', size=12, c='b')

# Small arc of radius r centred near the tip of c (c scaled by 1.05 and
# nudged by -0.2 in x).
# - phi = 90 deg makes cos(phi) ~ 0, so x stays ~x0 and the arc lies in a
#   plane of (nearly) constant x.
# - theta runs from 270 deg down to -30 deg.
r = 0.6
phi = np.radians(90.)
x0 = c[0] * 1.05 - 0.2
y0 = c[1] * 1.05
z0 = c[2] * 1.05
theta = np.linspace(np.radians(270.), np.radians(-30.), 201)
radial = r * np.sin(theta)
x = radial * np.cos(phi) + x0
y = radial * np.sin(phi) + y0
z = r * np.cos(theta) + z0
ax.plot(x, y, z, c='b')

# Arrowhead: continue the final arc segment for `offset` steps.
# TODO(review): the author flagged this arrowhead placement as "Fix this".
dx, dy, dz = x[-1] - x[-2], y[-1] - y[-2], z[-1] - z[-2]
offset = 10
arc_head = Arrow3D(
    [x[-1], x[-1] + offset * dx],
    [y[-1], y[-1] + offset * dy],
    [z[-1], z[-1] + offset * dz],
    mutation_scale=10, lw=1, arrowstyle="-|>", color="b",
    connectionstyle="arc3,rad=0.",
)
ax.add_artist(arc_head)

# Same axis box and hidden ticks as panel (a).
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_zlim(zlim)
ax.set_xlabel(r'$X$')
ax.set_ylabel(r'$Y$')
ax.set_zlabel(r'$Z$')

# --- Panel (e): rotation by theta about the y axis ---------------------------
ax = fig.add_subplot(Nrows, Ncols, 5, projection='3d')
ax.computed_zorder = False

# Tilt the molecule and both vectors by 30 deg about y.
ammonia = Ry(30., ammonia)
c = Ry(30., c)
L = Ry(30., L)

plot_ammonia(ax, ammonia)
draw_arrow(ammonia, c, L)
ax.text(L[0], L[1], L[2] + 0.2, 'L', size=12, c='r')
ax.text(c[0] + 0.6, c[1], c[2] - 0.4, 'c', size=12, c='k')
ax.text(-1.8, -1.7, -0.8, '(e)', size=12, c='k')
ax.text(-0.6, 0., 1.4, r'$\theta$', size=12, c='b')

# Arc of radius r centred at (x0, y0, z0), sweeping 120 deg down from the top.
# phi = 180 deg puts it on the negative-x side of the (nearly) x-z plane.
r = 2.
phi = np.radians(180.)
theta = np.linspace(0., np.radians(120.), 201)
x0, y0, z0 = 0., 0., 0.3
radial = r * np.sin(theta)
x = radial * np.cos(phi) + x0
y = radial * np.sin(phi) + y0
z = r * np.cos(theta) + z0
ax.plot(x, y, z, c='b')

# Arrowhead: continue the final arc segment for `offset` steps.
# TODO(review): the author flagged this arrowhead placement as "FIX THIS".
dx, dy, dz = x[-1] - x[-2], y[-1] - y[-2], z[-1] - z[-2]
offset = 10
arc_head = Arrow3D(
    [x[-1], x[-1] + offset * dx],
    [y[-1], y[-1] + offset * dy],
    [z[-1], z[-1] + offset * dz],
    mutation_scale=10, lw=1, arrowstyle="-|>", color="b",
    connectionstyle="arc3,rad=0.",
)
ax.add_artist(arc_head)

# Same axis box and hidden ticks as panel (a).
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_zlim(zlim)
ax.set_xlabel(r'$X$')
ax.set_ylabel(r'$Y$')
ax.set_zlabel(r'$Z$')

# --- Panel (f): rotation by phi about the z axis -----------------------------
ax = fig.add_subplot(Nrows, Ncols, 6, projection='3d')
ax.computed_zorder = False

# Rotate the molecule and both vectors by 150 deg about z.
rotation_deg = 150.
ammonia = Rz(rotation_deg, ammonia)
c = Rz(rotation_deg, c)
L = Rz(rotation_deg, L)

plot_ammonia(ax, ammonia)
draw_arrow(ammonia, c, L)
ax.text(*L, 'L', size=12, c='r')
ax.text(c[0], c[1], c[2] + 0.2, 'c', size=12, c='k')
ax.text(-1.8, -1.7, -0.8, '(f)', size=12, c='k')
ax.text(0.2, 0., 2.6, r'$\phi$', size=12, c='b')

# Three-quarter circle of radius r in the horizontal plane z = z0, centred
# on the z axis (y0 = r*sin(pi) is zero up to rounding). x0 is not used by
# the arc.
r = 1.8
phi = np.radians(180.)
x0 = r*np.cos(phi)
y0 = r*np.sin(phi)
z0 = 1.2
theta = np.linspace(phi, phi+3*np.pi/2., 201)
x = r*np.sin(theta)
y = r*np.cos(theta) + y0
z = np.full(len(theta), z0)
ax.plot(x, y, z, c='b')

# Arrowhead: continue the final arc segment for `offset` steps.
# TODO(review): the author flagged this arrowhead placement as "FIX THIS".
dx, dy, dz = x[-1] - x[-2], y[-1] - y[-2], z[-1] - z[-2]
offset = 10
arc_head = Arrow3D(
    [x[-1], x[-1] + offset * dx],
    [y[-1], y[-1] + offset * dy],
    [z[-1], z[-1] + offset * dz],
    mutation_scale=10, lw=1, arrowstyle="-|>", color="b",
    connectionstyle="arc3,rad=0.",
)
ax.add_artist(arc_head)

# Same axis box and hidden ticks as panel (a).
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_zlim(zlim)
ax.set_xlabel(r'$X$')
ax.set_ylabel(r'$Y$')
ax.set_zlabel(r'$Z$')

plt.tight_layout()
# Write the same figure as a vector PDF and as a 300-dpi TIFF raster.
for out_name, save_kwargs in (("rotationalstate.pdf", {}),
                              ("rotationalstate.tiff", {"dpi": 300})):
    plt.savefig(out_name, **save_kwargs)
# For interactive inspection, call plt.show() here instead.
