#-- AI4.py

import cv2
import random
import numpy as np
import matplotlib.pyplot as plt
import os 
import math
import time
from time import sleep
from motorm import turn

# Capture resolution requested from the camera, and the downscale factor F:
# colour segmentation works on Lx2 x Ly2 frames (Lx / F by Ly / F).
Lx = 800
Ly = 600
F = 2
Lx2 = Lx // F
Ly2 = Ly // F
def cw(k):
    """Pump the HighGUI event loop for up to *k* ms (0 = until a key); key code discarded."""
    cv2.waitKey(k)
def sh0(w, img):
    """Display *img* in the OpenCV window whose title is *w*."""
    cv2.imshow(w, img)
def cntr(img):
    """Return every contour found in the binary image *img* (no hierarchy).

    ``cv2.findContours`` returns ``(contours, hierarchy)`` in OpenCV 2.4/4.x but
    ``(image, contours, hierarchy)`` in 3.x; indexing ``[-2]`` selects the
    contours in all of them, whereas two-value unpacking breaks on 3.x.
    """
    return cv2.findContours(img, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]

##############################################################
def takepic():
    """Capture grayscale training crops for every label in ``LAB``.

    For each label, up to 80 frames are read from camera 0. The largest
    in-range blob is located with ``CNTX``, its box is cropped from the
    full-resolution frame, converted to grayscale and written to
    ``PNG/<label>1_<NNN>.png``. Between frames the motors are pulsed in a
    repeating 5-step pattern. After each label the user is prompted to swap
    the card. Pressing 'q' ends the current label early. The camera is always
    released on exit.
    """
    # cv2.imwrite silently returns False when the target folder does not exist.
    os.makedirs('PNG', exist_ok=True)
    # Two turn() calls per frame, chosen by frame index mod 5.
    # NOTE(review): turn() comes from motorm; arguments presumably
    # (left, right, seconds) -- confirm against motorm.
    wiggle = {
        0: ((0.3, 0.3, 0.07), (0, 0, 0.3)),
        1: ((0.3, -0.3, 0.07), (0, 0, 0.5)),
        2: ((-0.3, 0.3, 0.07), (0, 0, 0.5)),
        3: ((-0.3, 0.3, 0.07), (0, 0, 0.5)),
        4: ((0.3, -0.3, 0.07), (0, 0, 0.5)),
    }
    cap = cv2.VideoCapture(0)
    try:
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, Lx)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, Ly)
        sleep(2)  # wait 2 s for camera warm-up
        for k in LAB:
            jt = 0
            NP = 0
            while jt < 80 and cap.isOpened():
                ret, img = cap.read()
                jt += 1
                if not ret or img is None:
                    print('camera read failed; stopping label', k)
                    break
                (x, y, w, h), (cx, cy), masked, hst1 = CNTX(img)
                # Map the downscaled box onto the real frame size; the camera
                # may not honour the requested Lx x Ly resolution.
                sx = img.shape[1] / Lx2
                sy = img.shape[0] / Ly2
                crop = img[int(y * sy):int((y + h) * sy), int(x * sx):int((x + w) * sx), :]
                if crop.size > 0:  # empty when CNTX found nothing -> cvtColor would raise
                    img3 = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
                    cv2.imshow('img3', img3)
                    fn = 'PNG/' + k + '1_' + str(NP).zfill(3) + '.png'
                    cv2.imwrite(fn, img3)
                    NP += 1
                    print(jt, NP, fn, (x, y, w, h))
                sh0('hst1', hst1)
                if cv2.waitKey(10) & 0xFF == ord('q'):
                    break
                for args in wiggle[jt % 5]:
                    turn(*args)
            input('change card...')
    finally:
        cap.release()
        cv2.destroyAllWindows()

###########################################################
def ANADAT(DAT, rmin=25):
    """Return indices of detections to drop because a wider one sits nearby.

    Each row of *DAT* is ``[cx, cy, x, y, w, h]`` (as built by ``CNTS``). Two
    rows whose centres are closer than *rmin* (Euclidean distance) are treated
    as duplicates. The one with the larger ``w`` (column 4) is kept; on a tie
    the earlier-visited row ``j`` is dropped.

    A row that has been dropped no longer suppresses others, and each index
    appears at most once in the result.

    Args:
        DAT: sequence of detection rows.
        rmin: centre-distance threshold in the same units as ``cx``/``cy``.

    Returns:
        List of row indices to skip, in the order they were suppressed.
    """
    LD = len(DAT)
    SKIP = []
    for j in range(LD):
        if j in SKIP:
            continue
        for i in range(LD):
            if i == j or i in SKIP:
                continue
            # Distance between centres (cx, cy); the y term previously reused cx.
            r = math.hypot(DAT[j][0] - DAT[i][0], DAT[j][1] - DAT[i][1])
            if r < rmin:
                if DAT[j][4] > DAT[i][4]:
                    SKIP.append(i)
                else:
                    SKIP.append(j)
                    break  # j is suppressed; it must not suppress anyone else
    return SKIP


###############################################################
def CNTS(img):
    """Find every sufficiently large in-range blob in a camera frame.

    The frame is downscaled to Lx2 x Ly2 and thresholded in HSV between the
    module-level ``lower1``/``upper1`` bounds. Boxes narrower or shorter than
    30 px, or smaller than 0.01% of the frame area, are discarded. Overlapping
    duplicates are removed with ``ANADAT``.

    Returns:
        (DAT2, masked, hst1):
        - DAT2: list of ``[cx, cy, x, y, w, h]`` rows in downscaled coordinates.
        - masked: the downscaled frame with out-of-range pixels zeroed, and each
          kept box and centre drawn on it.
        - hst1: the downscaled frame stacked above ``masked``.
    """
    img1 = cv2.resize(img, (Lx2, Ly2))  # resize allocates a new array; img is untouched
    Ac = Lx2 * Ly2 * 0.0001  # minimum bounding-box area
    hsv = cv2.cvtColor(img1, cv2.COLOR_BGR2HSV)
    mask1 = cv2.inRange(hsv, lower1, upper1)
    masked = cv2.bitwise_and(img1, img1, mask=mask1)
    cnts = cntr(mask1)
    if len(cnts) == 0:
        return [], masked, np.vstack((img1, masked))
    DAT = []
    for C in cnts:
        x, y, w, h = cv2.boundingRect(C)
        if w * h < Ac or w < 30 or h < 30:
            continue
        DAT.append([int(x + w / 2), int(y + h / 2), x, y, w, h])
    SKIP = set(ANADAT(DAT))
    DAT2 = [row for j, row in enumerate(DAT) if j not in SKIP]
    for cx, cy, x, y, w, h in DAT2:
        cv2.rectangle(masked, (x, y), (x + w, y + h), (200, 200, 200), 2)
        cv2.circle(masked, (cx, cy), 5, (255, 255, 255), -1)
    # Stack once, after all boxes are drawn (was rebuilt on every iteration).
    hst1 = np.vstack((img1, masked))
    return DAT2, masked, hst1


###############################################################
def CNTX(img):
    """Locate the largest in-range blob in a camera frame.

    The frame is downscaled to Lx2 x Ly2 and thresholded in HSV between the
    module-level ``lower1``/``upper1`` bounds. The contour with the largest
    area is chosen.

    Returns:
        ((x, y, w, h), (cx, cy), masked, hst1):
        - The box and centre are in downscaled coordinates.
        - masked: the downscaled frame with out-of-range pixels zeroed; the box
          and centre are drawn on it when a blob is found.
        - hst1: the downscaled frame stacked above ``masked``.

        When no contour exists, or the best bounding box is smaller than 0.01%
        of the frame area, the box is (0, 0, 0, 0) and the centre (0, 0).
    """
    img1 = cv2.resize(img, (Lx2, Ly2))  # resize allocates a new array; img is untouched
    Ac = Lx2 * Ly2 * 0.0001  # minimum bounding-box area
    hsv = cv2.cvtColor(img1, cv2.COLOR_BGR2HSV)
    mask1 = cv2.inRange(hsv, lower1, upper1)
    masked = cv2.bitwise_and(img1, img1, mask=mask1)
    hst1 = np.vstack((img1, masked))
    cnts = cntr(mask1)
    if len(cnts) == 0:
        return (0, 0, 0, 0), (0, 0), masked, hst1
    C = max(cnts, key=cv2.contourArea)  # largest contour by area
    x, y, w, h = cv2.boundingRect(C)
    if w * h < Ac:
        return (0, 0, 0, 0), (0, 0), masked, hst1
    cx = int(x + w / 2)
    cy = int(y + h / 2)
    cv2.rectangle(masked, (x, y), (x + w, y + h), (200, 200, 200), 2)
    cv2.circle(masked, (cx, cy), 5, (255, 255, 255), -1)
    hst1 = np.vstack((img1, masked))
    return (x, y, w, h), (cx, cy), masked, hst1



#################################################################
def Sload_images_from_folder(iszx, iszy, folder):
    """Load every readable image in *folder* (sorted by filename) as grayscale.

    Args:
        iszx, iszy: target width and height for the resized copies.
        folder: directory to scan; unreadable/non-image entries are skipped.

    Returns:
        (count, names, images, originals):
        - count: number of images actually loaded.
        - names: the filenames that were loaded, aligned index-for-index with
          ``images``.
        - images: array of shape (count, iszy, iszx), scaled to [0, 1].
        - originals: the unresized grayscale images.
    """
    names = []
    images = []
    img5 = []
    for filename in sorted(os.listdir(folder)):
        img = cv2.imread(os.path.join(folder, filename), cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue  # not an image (or unreadable): keep names/images aligned
        images.append(cv2.resize(img, (iszx, iszy)) / 255)
        img5.append(img)
        names.append(filename)
    # Count the images loaded, not the directory entries: callers index the array with it.
    return len(images), names, np.array(images), img5
##################################################
def load_images_from_folder(folder, nfile):
    """Load up to *nfile* grayscale images from *folder*, resized to (isx, isy).

    Files are visited in sorted filename order, so the selected subset is
    deterministic (``os.listdir`` order is arbitrary). Reading stops as soon as
    *nfile* images are loaded. Non-image entries are skipped.

    Returns:
        Array of shape (n, isy, isx) with values scaled to [0, 1]. It has shape
        (0,) when nothing was loaded.
    """
    # Early stop only for a plain non-negative count; other values (e.g. None)
    # fall through to the final slice exactly as before.
    limit = nfile if isinstance(nfile, int) and nfile >= 0 else None
    images = []
    for filename in sorted(os.listdir(folder)):
        if limit is not None and len(images) >= limit:
            break
        img = cv2.imread(os.path.join(folder, filename), cv2.IMREAD_GRAYSCALE)
        if img is not None:
            images.append(cv2.resize(img, (isx, isy)) / 255)
    return np.array(images[:nfile])

##################################################
def AITEST(pathT):
    """Classify every image in folder *pathT* with the global ``model`` and print the results.

    Images are loaded through ``Sload_images_from_folder`` at (isx, isy). They are
    reshaped into a (n, isy, isx, 1) float32 batch, and the arg-max class indices
    are printed.
    """
    ntest, fimgs, test, img5 = Sload_images_from_folder(isx, isy, pathT)
    print('load is done path=', pathT)
    batch = test.reshape(test.shape[0], isy, isx, 1).astype("float32")
    print('X_test:type,len,shape=', type(batch), len(batch), batch.shape)
    scores = model.predict(batch)
    pred = np.argmax(scores, axis=1)
    print('pred=', pred)
    show_images = False  # debug switch: step through each test image with its label
    if show_images:
        for i in range(ntest):
            frame = batch[i, :, :, 0]
            print(i, 'pred[i]=', pred[i], ' LAB=', LAB[pred[i]])
            cv2.imshow('img', frame)
            cv2.waitKey(0)

##################################################
def VERIFY():
    """Run ``AITEST`` on each trained class folder, then on TEST/ and (if present) TEST2/.

    The original loop used ``range(5)``, so its TEST2 branch could never run.
    TEST2 is now evaluated when the folder exists. Class folders follow
    ``LAB[:ncat]``, the same classes ``buildmodel`` trains on.
    """
    folders = [path + '/' + lab + '/' for lab in LAB[:ncat]]
    folders.append(path + '/TEST/')
    extra = path + '/TEST2/'
    if os.path.isdir(extra):
        folders.append(extra)
    for pathT in folders:
        AITEST(pathT)

##################################################
def _load_training_set():
    """Load the first ``ncat`` label folders; return (X, y) stacked over classes.

    X has shape (n, isy, isx). y holds the float class index of each row
    (0 for LAB[0], 1 for LAB[1], ...).
    """
    xs = []
    ys = []
    for nc in range(ncat):
        imgs = load_images_from_folder(path + '/' + LAB[nc] + '/', nfile)
        # An empty folder yields shape (0,); reshape to (0, isy, isx) so concatenation works.
        imgs = np.asarray(imgs, dtype=float).reshape(-1, isy, isx)
        print('class', nc, LAB[nc], 'train.shape=', imgs.shape)
        xs.append(imgs)
        ys.append(np.full(imgs.shape[0], nc, dtype=float))
    return np.concatenate(xs, axis=0), np.concatenate(ys)


def _build_cnn():
    """Return the compiled CNN: two conv/max-pool stages, then dense 64 -> 32 -> ncat softmax."""
    model = Sequential()
    model.add(Conv2D(16, (3, 3), activation='relu', input_shape=(isy, isx, 1)))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(32, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(64, activation='relu'))
    model.add(Dropout(0.2))
    model.add(Dense(32, activation='relu'))
    model.add(Dense(ncat, activation='softmax'))
    model.summary()
    model.compile(loss='categorical_crossentropy',
                  optimizer="adam",
                  metrics=['accuracy'])
    return model


def _plot_history(hist):
    """Plot training/validation accuracy and loss curves from a Keras ``History.history`` dict."""
    epochs = range(1, len(hist["loss"]) + 1)
    plt.plot(epochs, hist["accuracy"], "bo-", label="Training Acc")
    plt.plot(epochs, hist["val_accuracy"], "ro--", label="Validation Acc")
    plt.title("Training and Validation Accuracy")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.show()
    plt.plot(epochs, hist["loss"], "bo-", label="Training Loss")
    plt.plot(epochs, hist["val_loss"], "ro--", label="Validation Loss")
    plt.title("Training and Validation Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()


def buildmodel():
    """Train the card classifier on the PNG label folders and save it to ``mod1``.

    Steps:
    - Load up to ``nfile`` images per class for the first ``ncat`` labels and
      one-hot encode the labels.
    - Hold out 10% as a test split (``random_state=42``).
    - Train 60 epochs with a further 10% validation split.
    - Print train/test accuracy and save the model.

    Returns:
        The trained Keras model. The main script itself reloads the model from
        ``mod1`` on disk.
    """
    print('load path=', path)
    X_train, Y_train = _load_training_set()
    print('X_train.shape=', X_train.shape)
    print('y_train.shape=', Y_train.shape)
    # num_classes pins the one-hot width to ncat even if the last class folder is
    # empty, so it always matches the Dense(ncat) output layer.
    Y_train = to_categorical(Y_train, num_classes=ncat)
    print('Y_train.shape2=', Y_train.shape)
    # Transform images into 4D tensors (n, isy, isx, 1)
    X_train = X_train.reshape(X_train.shape[0], isy, isx, 1).astype("float32")
    X_train, X_test, Y_train, Y_test = train_test_split(
        X_train, Y_train, test_size=0.1, random_state=42)
    model = _build_cnn()
    history = model.fit(X_train, Y_train, shuffle=True, validation_split=0.1,
                        batch_size=128, epochs=60)
    show_plots = False  # debug switch: plot accuracy/loss curves
    if show_plots:
        _plot_history(history.history)
    print("\nTesting ...")
    loss, accuracy = model.evaluate(X_train, Y_train, verbose=0)
    print("accuracy of training data={:.2f}".format(accuracy))
    loss, accuracy = model.evaluate(X_test, Y_test, verbose=0)
    print("accuracy of testing data={:.2f}".format(accuracy))
    model.save(mod1)
    print('model ', mod1, ' is saved...')
    return model

#################### model prediction for a single image #############
def modelPred(img):
    """Return the predicted class index for a single grayscale crop.

    *img* is resized to (isx, isy) and scaled to [0, 1], the same preprocessing
    used for training. It is then fed to the global ``model`` as a batch of one.
    It must be a single-channel 2-D image, since the reshape adds exactly one
    channel axis.

    Previously this re-read the whole TEST folder from disk on every call and
    discarded the result. That was a large per-frame cost in the driving loop,
    and it crashed if the folder was absent.
    """
    x = cv2.resize(img, (isx, isy)) / 255
    X_test = x.reshape(1, isy, isx, 1).astype("float32")
    # verbose=0: no progress bar per frame in the real-time loop.
    pred = np.argmax(model.predict(X_test, verbose=0), axis=1)
    return pred[0]

################################################################
def move1():
    """Drive according to the class of the largest in-range blob, one step per frame.

    Up to 60 frames are read from the module-level ``cap``. The largest blob
    (``CNTX``) is cropped from the full frame, classified with ``modelPred``, and
    mapped to one motor pulse per class index 0-3, followed by a 0.5 s stop.
    After each step the loop blocks until a key is pressed (``cw(0)``).
    'q' during the 1 ms poll ends the run. Frames with no blob are shown and
    skipped, and a failed camera read stops the loop.
    """
    # (left, right) speeds for each predicted class index.
    # NOTE(review): turn() comes from motorm; arguments presumably
    # (left, right, seconds) -- confirm against motorm.
    moves = {0: (0.3, 0.3), 1: (-0.3, 0.3), 2: (0.3, -0.3), 3: (-0.3, -0.3)}
    jt = 0
    NP = 0
    while jt < 60 and cap.isOpened():
        ret, img = cap.read()
        jt += 1
        if not ret or img is None:
            print('camera read failed at frame', jt)
            break
        (x, y, w, h), (cx, cy), masked, hst1 = CNTX(img)
        # Map the downscaled box onto the real frame size; the camera may not
        # honour the requested Lx x Ly resolution.
        sx = img.shape[1] / Lx2
        sy = img.shape[0] / Ly2
        crop = img[int(y * sy):int((y + h) * sy), int(x * sx):int((x + w) * sx), :]
        if crop.size == 0:
            # Nothing detected: an empty crop would make cvtColor/resize raise.
            sh0('hst1', hst1)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
            continue
        img3 = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
        cv2.imshow('img3', img3)
        sh0('hst1', hst1)
        KPred = int(modelPred(img3))
        LB = LABEL[KPred]
        NP += 1
        print(jt, img3.shape, KPred, LB)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
        if KPred in moves:
            left, right = moves[KPred]
            turn(left, right, 0.1)
            turn(0, 0, 0.5)
        cw(0)
    cv2.destroyAllWindows()

##################################################
def move2(TARG=0):
    """Steer toward blobs classified as class *TARG*, for up to 60 frames.

    Every blob from ``CNTS`` is cropped from the full frame and classified with
    ``modelPred``. For each blob of class *TARG*, the car pulses one of three
    motor patterns depending on whether the blob centre lies in the left,
    middle or right third of the downscaled frame. It then waits for a key
    (``cw(0)``).

    The GUI is pumped once per frame, so windows repaint and 'q' quits even
    when no target is visible. A failed camera read stops the loop.

    Args:
        TARG: class index to follow (default 0, the previous hard-coded value).
    """
    # NOTE(review): turn() comes from motorm; arguments presumably
    # (left, right, seconds) -- confirm against motorm.
    jt = 0
    while jt < 60 and cap.isOpened():
        ret, img = cap.read()
        jt += 1
        if not ret or img is None:
            print('camera read failed at frame', jt)
            break
        DAT, masked, hst1 = CNTS(img)
        lenD = len(DAT)
        print('lenD=', lenD, DAT)
        sh0('hst1', hst1)
        # Map downscaled boxes onto the real frame size; the camera may not
        # honour the requested Lx x Ly resolution.
        sx = img.shape[1] / Lx2
        sy = img.shape[0] / Ly2
        for j, (xc, yc, x, y, w, h) in enumerate(DAT):
            crop = img[int(y * sy):int((y + h) * sy), int(x * sx):int((x + w) * sx), :]
            img3 = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
            KPred = int(modelPred(img3))
            LB1 = LAB[KPred]
            LB2 = LABEL[KPred]
            print(jt, j, img3.shape, KPred, LB1, LB2)
            if KPred != TARG:
                continue
            # xc is in downscaled (Lx2-wide) coordinates, matching the thirds below.
            if xc < Lx2 / 3:
                turn(-0.3, 0.3, 0.1)
            elif xc > Lx2 / 3 * 2:
                turn(0.3, -0.3, 0.1)
            else:
                turn(0.3, 0.3, 0.1)
            turn(0, 0, 0.1)
            sh0('img3', img3)
            cw(0)
        # imshow only repaints inside waitKey; previously skipped when no target was seen.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cv2.destroyAllWindows()



#-----------------------------------------------------------
#------------------------------------- main ----------------
# Colour-segmentation settings: HSV bounds passed to cv2.inRange by CNTS/CNTX.
color = 'green'
h = 60
lower1 = np.array([50, 30, 90])       # green, lower HSV bound
upper1 = np.array([80, 255, 255])     # green, upper HSV bound
# Card classes. A, L, P, S stand for the four Chinese characters
# (ying, yong, wu, li) of the phrase "Applied Physics".
LAB = ['A', 'L', 'P', 'S']
LABEL = ['App', 'Lied', 'Phy', 'Sics']

if False:  # switch on to collect a fresh training set with the camera
    takepic()
    input('takepic done....')

# Deferred TensorFlow / Keras / scikit-learn imports, timed because they are slow.
# These names (Sequential, Conv2D, to_categorical, train_test_split, keras, ...)
# become module globals used by buildmodel() and by the model loading below.
if(1==1):
    print('importing tensorflow...')
    t0=time.time()
    from tensorflow.keras.preprocessing import image
    from tensorflow.keras.optimizers import RMSprop
    from tensorflow import keras
    import tensorflow as tf
    import shutil
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense
    from tensorflow.keras.layers import Flatten
    from tensorflow.keras.layers import Conv2D
    from tensorflow.keras.layers import MaxPooling2D
    from tensorflow.keras.layers import Dropout
    from tensorflow.keras.utils import to_categorical
    from sklearn.model_selection import train_test_split
    # TEN: import time in seconds; t1 is also the start time for the optional training timer.
    t1=time.time(); TEN=round(t1-t0,3); print('importing done. TEN=',TEN)

# Data-set and model configuration.
path = 'PNG'                  # root folder with one sub-folder per label plus TEST/
ncat = 4                      # number of classes trained / evaluated
nfile = 60                    # max images loaded per class folder
mod1 = 'model_ALPS_60A.h5'    # Keras model file written by buildmodel() and loaded below
ntest = 20
pathT = path + '/TEST/'
print('load path=', path)
isx = 64                      # network input width
isy = 64                      # network input height
seed = 7
np.random.seed(seed)
if False:  # switch on to retrain the model before loading it
    buildmodel()
    t2 = time.time()
    time_BM = round(t2 - t1, 2)
    print('time_BM=', time_BM)
model = keras.models.load_model(mod1)
print('load ', mod1, ' ... done')
if False:  # switch on to evaluate the loaded model on every folder
    VERIFY()
    input('AI2-build model done...')

#======================= Now apply the model in AI-car movements =====
# Open camera 0 (used as the global ``cap`` by move1/move2) and always release it,
# even if the drive loop raises or is interrupted with Ctrl-C.
cap = cv2.VideoCapture(0)
try:
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, Lx)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, Ly)
    sleep(2)  # wait 2 s for camera warm-up
    # move1()  # alternative: drive by the class of the single largest blob
    move2()
finally:
    cap.release()
    cv2.destroyAllWindows()
