import warnings

import matplotlib.pyplot as plt
import numpy as np


def schem_expl(y0, T, a, M, N, left=1.0, right=0.0):
    """Solve y_t = y_xx + y(1 - y) on [-a, a] x [0, T] with an explicit scheme.

    Forward Euler in time and centred second differences in space, with
    Dirichlet values ``left`` at x = -a and ``right`` at x = a.

    Parameters
    ----------
    y0 : callable
        Initial condition. It is called once on the array of the N + 2 grid
        points, and its result is stored as the t = 0 row.
    T : float
        Final time.
    a : float
        Half-width of the spatial domain [-a, a].
    M : int
        Number of interior time levels; the time grid has M + 2 points.
    N : int
        Number of interior space points; the space grid has N + 2 points.
    left, right : float, optional
        Dirichlet boundary values at x = -a and x = a. They default to 1 and 0,
        the values this solver always used.

    Returns
    -------
    tpt : ndarray, shape (M + 2,)
        Time grid from 0 to T.
    xpt : ndarray, shape (N + 2,)
        Space grid from -a to a.
    Y : ndarray, shape (M + 2, N + 2)
        ``Y[i, j]`` approximates y(tpt[i], xpt[j]).

    Warns
    -----
    RuntimeWarning
        If dt / dx**2 > 1/2, where the explicit scheme is unstable.
    """
    dx, dt = 2 * a / (N + 1), T / (M + 1)
    r = dt / dx**2
    # Forward Euler applied to the diffusion term is only stable for r <= 1/2.
    if r > 0.5:
        warnings.warn(
            f"explicit scheme unstable: dt/dx**2 = {r:.3g} > 0.5; "
            "increase M or decrease N",
            RuntimeWarning,
            stacklevel=2,
        )

    tpt = np.linspace(0, T, M + 2)
    xpt = np.linspace(-a, a, N + 2)
    Y = np.zeros((M + 2, N + 2))
    Y[0, :] = y0(xpt)
    # Impose the Dirichlet values on every time level, including t = 0, so the
    # stored boundary columns match the values the update actually reads.
    Y[:, 0] = left
    Y[:, -1] = right

    for i in range(M + 1):
        u = Y[i]
        inner = u[1:-1]
        # The three-point stencil u_{j-1} - 2 u_j + u_{j+1} reads the boundary
        # columns directly, so no separate boundary source term is needed.
        lap = u[:-2] - 2 * inner + u[2:]
        Y[i + 1, 1:-1] = inner + r * lap + dt * inner * (1 - inner)
    return tpt, xpt, Y

def y1(x):
    """Step initial condition: 1.0 where x < -4, 0.0 elsewhere.

    Returns a float array with the same shape as ``x``.
    """
    # np.where keeps x's shape. Multiplying by the list ``[x < -4]`` would add
    # a spurious leading axis of length 1.
    return np.where(np.asarray(x) < -4, 1.0, 0.0)

# Run the solver on [-10, 10] x [0, 10] with the step initial condition y1.
tpt, xpt, Y = schem_expl(y1, 10, 10, 10000, 200)

# Plot one spatial profile every PLOT_EVERY time levels, starting at t = 0.
PLOT_EVERY = 200
for i in range(0, tpt.size, PLOT_EVERY):
    plt.plot(xpt, Y[i, :])
plt.xlabel("x")
plt.ylabel("y(t, x)")
plt.title(f"Profiles every {PLOT_EVERY} time steps")
plt.show()


