# coding=Big5
from visual import *
import matplotlib.pyplot as plt
import numpy as np

def V_pot(x):
    global V0,sqa,sqb
    if(sqa<=x<=sqb): V=0.
    else: V=V0
    return V


def Shoot(a,b,Nh,Ea,Eb,psi_00,psi_01,V):
    E1=Ea; E2=Eb; Nst=120
    for j in range(Nst):
        ps1=psi_r(a,b,Nh,E1,psi_00,psi_01,V)
        ps2=psi_r(a,b,Nh,E2,psi_00,psi_01,V)
        if((ps1-psi_00)*(ps2-psi_00) > 0.):
            return E1,E2
        else:
            pst=psi_r(a,b,Nh,(E1+E2)/2.,psi_00,psi_01,V)
            if((pst-psi_00)*(ps1-psi_00) < 0.):
                E2=(E1+E2)/2.
                ps2=pst
            else:
                E1=(E1+E2)/2.
                ps1=pst
    return E1,E2
        
        

def psi_r(a,b,Nh,E,psi_00,psi_01,V):
    h=(b-a)/Nh
    h2=h**2
    psi_0=psi_00
    psi_1=psi_01
    for i in range(Nh):
        x=a+h*i
        ksq=2.*(E-V(x))
        psi_2=2.*psi_1-psi_0-ksq*psi_1*h2
        psi_0=psi_1
        psi_1=psi_2
    DP=abs(psi_1-psi_00)
    return psi_1


def PsiE(a,b,Nh,E,psi_00,psi_01,V):
    h=1.*(b-a)/Nh
    h2=h**2
    psi_0=psi_00
    psi_1=psi_01
    Psi=[]; xa=[]; Psi2=[]
    Psi.append(psi_0); xa.append(a)
    Psi.append(psi_1); xa.append(a+h)
    norm=0.
    for i in range(Nh):
        norm+=(psi_0)**2*h
        x=a+h*i
        ksq=2.*(E-V(x))
        psi_2=2.*psi_1-psi_0-ksq*psi_1*h2
        psi_0=psi_1
        psi_1=psi_2
        Psi.append(psi_2)
        xa.append(x+h)       
    DP=abs(psi_1-psi_00)
    print '%13.5e %13.5e %13.5e %13.5e' %(DP,psi_0,psi_1,psi_00)
    #---------------normalization---------
    A=sqrt(norm)
    for i in range(len(Psi)):
        Psi[i]=(Psi[i]/A)
        Psi2.append(Psi[i]**2)
    return Psi,xa,Psi2

###############################----main---------------

V0,sqa,sqb=250,0.,1.
a,b=sqa-1.00,sqb+1.00
Nh=1000
h=(b-a)/Nh
pot='1D finite square well V0='+str(V0)
psi_00,psi_01=1.E-10*exp(-(a)**2),1.E-10*exp(-(a+h)**2)

#======shoot eigen-energies====
N=1200
EN=[0 for i in range(N)]
for i in range(N):
    EN[i]=0.+0.5*i

ENN=[]; NEIG=8
print '=====Shoot======='
accuracy=1.e-3; KK=0
for i in range(N):
    if(len(ENN)==NEIG): break
    E1,E2=Shoot(a,b,Nh,EN[i],EN[i+1],psi_00,psi_01,V_pot)
    if(KK == 0) and (abs((E1-E2)/E2) < accuracy):
        ENN.append((E1+E2)/2.)
        Ep=(E1+E2)/2.
        print '%6d %6d %12.6f' %(i,len(ENN),Ep)
        KK=1
        continue

    if(abs((E1-E2)/E2) < 1.e-3) and (E1+E2)/2.-Ep > 0.1:
        ENN.append((E1+E2)/2.)
        Ep=(E1+E2)/2.
        print '%6d %6d %12.6f' %(i,len(ENN),Ep)
    
print '==== ENN======'
Ex0=4.*pi**2/8.
N=len(ENN)
strE=[]
for i in range(N):
    strE.append(str(ENN[i]))
    Exi=Ex0*(i+1.)**2       #----exact eigenenergies for inf Well----
    print '%6d %12.5f %12.5f' %(i+1,ENN[i],Exi)

print '=========wave function Psi========='
Psi=[[] for i in range(N)]; xa=[[] for i in range(N)];Psi2=[[] for i in range(N)]
for i in range(N):
    Psi[i],xa[i],Psi2[i]=PsiE(a,b,Nh,ENN[i],psi_00,psi_01,V_pot)
print'-------------check the first 3 norms-----------'
for j in range(3):
    sum=0.
    for i in range(Nh):
        sum+=Psi2[j][i]*h
    print j,' norm=',sum

# matplot µe¹Ï  ----Psi----
plt.figure()
plt.subplot(421)
#------------------eigen energy---------
for i in range(7):
    plt.text(3, -4.5-i*0.3,r'E'+str(i+1)+'='+strE[i][0:6])
plt.title('$\Psi_n(x)$'+pot,fontsize=15)
plt.plot(xa[0], Psi[0], label='n=1'); plt.legend()
plt.subplot(422)
plt.plot(xa[1], Psi[1], label='n=2'); plt.legend()
plt.subplot(423)
plt.plot(xa[2], Psi[2], label='n=3'); plt.legend()
plt.subplot(424)
plt.plot(xa[3], Psi[3], label='n=4'); plt.legend()
plt.subplot(425)
plt.plot(xa[4], Psi[4], label='n=5'); plt.legend()
plt.subplot(426)
plt.plot(xa[5], Psi[5], label='n=6'); plt.legend()
plt.subplot(427)
plt.plot(xa[6], Psi[6], label='n=7'); plt.legend()

# matplot µe¹Ï -----Psi^2------
plt.figure()
plt.subplot(421)
plt.title('$\Psi_n^2(x)$'+pot,fontsize=15)
plt.plot(xa[0], Psi2[0])
plt.subplot(422)
plt.plot(xa[1], Psi2[1])
plt.subplot(423)
plt.plot(xa[2], Psi2[2])
plt.subplot(424)
plt.plot(xa[3], Psi2[3])
plt.subplot(425)
plt.plot(xa[4], Psi2[4])
plt.subplot(426)
plt.plot(xa[5], Psi2[5])
plt.subplot(427)
plt.plot(xa[6], Psi2[6])
plt.subplot(428)
plt.plot(xa[7], Psi2[7])

plt.figure()
plt.title('$\Psi_n^2(x)$ vs. x\n'+pot,fontsize=15)
for i in range(3):
    plt.plot(xa[i], Psi2[i], label='n='+str(i+1),linewidth=i+1) 
plt.legend()
plt.show()
