import sys
sys.path.insert(1, '/Users/lucagravina/dptqm/Strucutral codes/Liouvillian block diagonalization')

import numpy as np
import matplotlib.pyplot as plt
from qutip import * 

import scipy.sparse as spr
import block_diagonalize
import time
import os
import pickle


class Parameters:
    """Parameter set for the two-photon-driven Kerr resonator (used to study the
    first-order dissipative phase transition).

    The Kerr coupling and two-photon dissipation are rescaled by the
    thermodynamic-limit factor ``N`` (``U = Utilde / N``, ``η = ηtilde / N``),
    and the two-photon drive is specified in units of the semiclassical
    critical drive ``Gc``.
    """

    def __init__(self, Δ, N, Grat, Utilde, ηtilde=1):
        """Initialize the parameters.

        Args:
            Δ (Float): Detuning.
            N (Int): Scaling factor for the thermodynamic limit. Morally
                equivalent to the number of resonators or to the scaling of
                the non-linearity.
            Grat (Float): Two-photon drive in units of the semiclassical
                critical drive ``Gc``.
            Utilde (Float): Normalized Kerr coupling (``U = Utilde / N``).
            ηtilde (Float, optional): Normalized two-photon dissipation
                (``η = ηtilde / N``). Defaults to 1.
        """
        self.Δ = Δ
        self.N = N
        self.Grat = Grat
        self.ηtilde = ηtilde
        self.Utilde = Utilde
        self.U = self.Utilde / self.N
        self.η = self.ηtilde / self.N

        # Semiclassical critical drive Gc = |η Δ / sqrt(η² + U²)|.
        self.Gc = np.abs(self.η * self.Δ / np.sqrt(self.η**2 + self.U**2))
        self.G = Grat * self.Gc

    def n_sc(self, g):
        """Semiclassical approximation of the photon number.

        Defined as a method rather than a lambda attribute so that
        ``Parameters`` instances remain picklable.

        Args:
            g (Float or ndarray): Drive in units of ``Gc`` (the formula has
                its threshold at ``g = 1``).

        Returns:
            Float or ndarray: ``U Δ / (U² + η²) * (1 + sqrt(g² - 1))`` for
            ``g >= 1`` and ``0`` below threshold.
        """
        # Heaviside step with value 1 at g == 1; it also keeps the square-root
        # argument non-negative below threshold.
        above = np.heaviside(g - 1, 1)
        return (self.U * self.Δ / (self.U**2 + self.η**2)
                * (1 + np.sqrt((g**2 - 1) * above)) * above)

class Kerr_2γ():
    """Dissipative Kerr resonator with two-photon drive and two-photon loss.

    Builds the Hamiltonian
        H = -Δ a†a + G/2 (a†² + a²) + U/2 a†² a²
    with the single jump operator sqrt(η) a², and assembles the corresponding
    Liouvillian with QuTiP.
    """

    def __init__(self, Nfock, p):
        """Build operators, Hamiltonian and Liouvillian.

        Args:
            Nfock (Int): Hilbert space truncation.
            p (Parameters): Configuration of the Parameters class.
        """
        self.Nfock = Nfock
        self.a = destroy(Nfock)
        self.n_op = self.a.dag() * self.a
        # Photon-number parity operator exp(iπ a†a).
        self.parity = (1.j * np.pi * self.a.dag() * self.a).expm()

        self.p = p
        self.c_ops = [np.sqrt(p.η) * self.a**2]
        self.H = (-p.Δ * self.a.dag() * self.a
                  + p.G / 2 * (self.a.dag()**2 + self.a**2)
                  + p.U / 2 * self.a.dag()**2 * self.a**2)
        self.LL = liouvillian(self.H, self.c_ops)

    def bd(self):
        """Block-diagonalize the Liouvillian and extract its slow spectral data.

        The permutation from ``block_diagonalize.PermMat`` brings the
        Liouvillian into block form. The blocks are diagonalized one by one
        and each block's last eigenvector is mapped back to an operator.
        A block whose eigenvector has a non-negligible trace is treated as a
        steady-state block and classified by its parity. The first block
        without a steady state provides the even-odd coherence rate.

        Assumes ``Qobj.eigenstates`` orders eigenvalues so that the last one
        has the largest real part (slowest mode) -- NOTE(review): confirm for
        the installed QuTiP version.

        Returns:
            Tuple: ``(num_even, num_odd, gap_even, gap_odd, gap_even_odd,
            gap_odd_even)``:
            - photon number of the even / odd steady state;
            - second-to-last eigenvalue (Liouvillian gap) of the even / odd
              steady-state block;
            - slowest eigenvalue of the first non-steady-state block, with
              positive / negative imaginary part.

        Raises:
            RuntimeError: if any of the above could not be identified.
        """
        P, _block_bfs, _block_sizes = block_diagonalize.PermMat(self.LL)
        self.bd_L = np.dot(P, np.dot(self.LL.data, np.transpose(P)))
        self.num_blocks, self.blocks_list, self.bl_indices = block_diagonalize.get_blocks(self.bd_L)

        num_even = num_odd = gap_even = gap_odd = None
        gap_even_odd = gap_odd_even = None

        for i, block in enumerate(self.blocks_list):
            evals, evecs = Qobj(block).eigenstates()

            # Embed the block's last eigenvector into the full Liouville space,
            # undo the permutation and reshape it into an operator.
            start = self.bl_indices[i]
            vec = spr.dok_matrix((self.LL.shape[0], 1), dtype='complex')
            vec[start:start + block.shape[0]] = evecs[-1].data
            vec = np.dot(np.transpose(P), vec.tocsr())
            ss = vector_to_operator(Qobj(vec, dims=[self.LL.dims[0], [1]]))

            # For a trace-preserving Liouvillian, eigenvectors of non-zero
            # eigenvalues are traceless, so only steady-state blocks give a
            # finite trace. The eigensolver fixes the eigenvector only up to an
            # arbitrary complex phase, so test the modulus of the trace, not
            # its real part.
            trace = ss.tr()
            if np.abs(trace) > 0.05:
                # Dividing by the trace removes the arbitrary phase; averaging
                # with the adjoint removes numerical anti-Hermitian noise.
                ss = ss / trace
                ss = (ss + ss.dag()) / 2
                # np.real guards against tiny imaginary parts in <parity>.
                if np.real(expect(self.parity, ss)) > 0.5:
                    num_even = expect(self.n_op, ss)
                    gap_even = evals[-2]
                else:
                    num_odd = expect(self.n_op, ss)
                    gap_odd = evals[-2]
            elif gap_even_odd is None:
                # First block without a steady state: use its slowest eigenvalue
                # as the coherence rate, in both conjugate forms.
                slowest = evals[-1]
                gap_even_odd = np.real(slowest) + 1.j * np.abs(np.imag(slowest))
                gap_odd_even = np.real(slowest) - 1.j * np.abs(np.imag(slowest))

        found = {
            "num_even": num_even,
            "num_odd": num_odd,
            "gap_even": gap_even,
            "gap_odd": gap_odd,
            "gap_even_odd": gap_even_odd,
        }
        missing = [name for name, value in found.items() if value is None]
        if missing:
            raise RuntimeError(
                "Kerr_2γ.bd(): could not identify " + ", ".join(missing)
                + f" (Nfock={self.Nfock}, {len(self.blocks_list)} blocks)"
            )
        return num_even, num_odd, gap_even, gap_odd, gap_even_odd, gap_odd_even
    
    
def get_cutoff(p, c0=5, cmax=100, step=5, precision=0.005, observable=None):
    """Estimate an appropriate Hilbert space cutoff for a Kerr-resonator configuration.

    Starting from ``c0``, a trial cutoff is increased by ``step`` until two
    successive evaluations of the observable agree within the relative
    tolerance ``precision``. The smaller of the two agreeing cutoffs is then
    lowered further with a step-halving search. Each trial cutoff is compared
    against the value at the larger agreeing cutoff.

    Args:
        p (Parameters): Configuration of the Kerr resonator.
        c0 (Int, optional): Starting (lowest) trial cutoff. Defaults to 5.
        cmax (Int, optional): Upper bound; coarse-search cutoffs must be
            strictly below it. If no convergence is found, ``np.inf`` is
            returned. Defaults to 100.
        step (Int, optional): Increment of the trial cutoff during the coarse
            search. Defaults to 5.
        precision (Float, optional): Relative-difference threshold for
            convergence. Defaults to 0.005.
        observable (callable, optional): ``observable(Nfock, p)`` returning
            the quantity whose convergence is tested. Defaults to the
            even-sector photon number ``Kerr_2γ(Nfock, p).bd()[0]``.

    Returns:
        Int or float: Estimated cutoff, or ``np.inf`` if none converged below ``cmax``.
    """
    if observable is None:
        def observable(nfock, params):
            return Kerr_2γ(Nfock=nfock, p=params).bd()[0]

    def converged(x1, x2):
        # Relative difference with respect to the smaller magnitude. Guard the
        # zero denominator: two exact zeros agree, a zero and a non-zero do not.
        denom = min(np.abs(x1), np.abs(x2))
        if denom == 0:
            return x1 == x2
        return np.abs(x1 - x2) / denom < precision

    previous = observable(c0, p)
    c = c0 + step
    while c < cmax:
        current = observable(c, p)
        if not converged(current, previous):
            previous = current
            c += step
            continue

        # The value at c agrees with the one at c - step, so c - step is
        # already converged. Try to lower it further with a step-halving search.
        c -= step
        while True:
            step = max(step // 2, 1)
            trial = c - step
            # A Fock space needs at least one level: never build Nfock < 1.
            if trial >= 1 and converged(observable(trial, p), current):
                c = trial
            elif step == 1:
                return c
    return np.inf


def map_cutoff(Utilderange, Δrange, Grange, N=10, ηtilde=1,
               c0=5, cmax=100, step=5, precision=0.005):
    """Create a map of optimal truncations over a user-defined parameter grid.

    Args:
        Utilderange (array_like): Selected values of Utilde.
        Δrange (array_like): Selected values of Δ.
        Grange (array_like): Selected values of Grat.
        N (Int, optional): Thermodynamic-limit scaling factor passed to
            ``Parameters``. Defaults to 10.
        ηtilde (Float, optional): Normalized two-photon dissipation passed to
            ``Parameters``. Defaults to 1.
        c0, cmax, step, precision: Forwarded to ``get_cutoff``. The defaults
            match those of ``get_cutoff``.

    Returns:
        ndarray: Float array of shape ``(len(Utilderange), len(Δrange),
        len(Grange))``. Entry ``[i, j, k]`` is the cutoff for
        ``(Utilderange[i], Δrange[j], Grange[k])``; ``inf`` marks points
        that did not converge.
    """
    # Accept lists and scalars as well as 1-D ndarrays.
    Utilderange = np.ravel(np.asarray(Utilderange))
    Δrange = np.ravel(np.asarray(Δrange))
    Grange = np.ravel(np.asarray(Grange))

    C = np.zeros((Utilderange.size, Δrange.size, Grange.size), dtype='float')
    # np.ndindex walks the grid in C order (Utilde outermost, Grat innermost).
    for i, j, k in np.ndindex(C.shape):
        p = Parameters(Δ=Δrange[j], N=N, Grat=Grange[k],
                       ηtilde=ηtilde, Utilde=Utilderange[i])
        C[i, j, k] = get_cutoff(p, c0=c0, cmax=cmax, step=step, precision=precision)
    return C





# Previous iterations of `get_cutoff`, kept for reference only.
# They are stored inside a bare string literal, so they are never defined or run.
# NOTE(review): get_cutoff_MK1 uses `deque`, which is not imported at the top of
# this module (it would need `from collections import deque`). get_cutoff_MK2
# refines the cutoff with unit steps only, whereas the active get_cutoff halves
# its step during refinement.
"""     def get_cutoff_MK1(cuts, p, precision=0.01):
    metric = lambda x1,x2 : ((np.abs(x1-x2)/np.abs(np.min([x1,x2])))<precision)
    q = deque()
    
    k = Kerr_2γ(Nfock=cuts[0], p=p)
    q.append(0)
    q.append(k.bd()[0])
    
    for i in range(1,cuts.size):
        k = Kerr_2γ(Nfock=cuts[i], p=p)
        ni = k.bd()[0]
        μi = metric(ni,q.pop())
        q.append(μi)
        print(q[0],q[1],i)
        if (q.popleft())*(q[0])==True : 
            return i-1
        q.append(ni)
    
    if q[0]==True:
        print("Carefull: Convergence on border")
        return i
    else: return np.inf
    
    
    
    
def get_cutoff_MK2(p, c0=5, cmax=100, step=5, precision=0.005):
    metric = lambda x1,x2 : ((np.abs(x1-x2)/np.abs(np.min([x1,x2])))<precision)
    k = Kerr_2γ(Nfock=c0,p=p)
    nprev = k.bd()[0]
    c = c0+step
    while c<cmax:
        k = Kerr_2γ(Nfock=c,p=p)
        n = k.bd()[0]
        if metric(n,nprev):
            c -= step
            while 1:
                k = Kerr_2γ(Nfock=c-1,p=p)
                if metric(k.bd()[0],n): c -= 1
                else: return c
        else:
            nprev = n
            c += step
    return np.inf """