## We assume that a list of data frames is provided (one per sample), with columns Chr; Start; N; Alt; Depth;
## Weight; Genotype; Number of chromosomes; Number of copies; id;
## and that the number of clusters and the contamination are both known.
#' Expectation step calculation
#'
#' Computes the probability of each candidate row to belong to each clone.
#' The likelihoods returned by eval.fik() are multiplied by the clone weights,
#' then normalised per mutation: all rows sharing the same `id` (and all
#' clones) sum to 1 together.
#' @param Schrod A list of dataframes (one for each sample), generated by the Patient_schrodinger_cellularities() function.
#' @param centers Coordinates of the clones: a list of numeric vectors (1 per sample), with coordinates between 0 and 1.
#' @param weights Proportion of mutation in a clone
#' @param adj.factor Factor to compute the probability: makes transition between the cellularity of the clone and the frequency observed
#' @return A numeric matrix with the dimensions of the eval.fik() output (one
#'   column per clone). When all entries of a mutation are zero, each of its
#'   cells receives 1 / (number of rows of that mutation * number of clones).
#' @keywords E-Step
e.step <- function(Schrod, centers, weights, adj.factor) {
  f <- eval.fik(Schrod = Schrod, centers = centers, weights = weights,
                adj.factor = adj.factor)
  f <- as.matrix(f)
  # Prior-weight each clone's column
  for (k in seq_along(weights)) {
    f[, k] <- f[, k] * weights[k]
  }

  # Normalise by mutation: group rows by id in a single pass instead of
  # rescanning all rows for every id
  Id <- Schrod[[1]]$id
  grp <- match(Id, unique(Id))
  tot <- rowsum(rowSums(f), grp, reorder = FALSE)[grp]
  group_size <- tabulate(grp)[grp]

  # Dividing by a length-nrow vector scales each row by its mutation total
  f_0 <- matrix(f / tot, nrow = nrow(f), ncol = ncol(f))
  # Mutations with zero total likelihood get a uniform distribution
  zero <- which(tot == 0)
  if (length(zero) > 0) {
    f_0[zero, ] <- 1 / (group_size[zero] * ncol(f))
  }
  f_0
}

#'Maximization step
#'
#' Optimization of clone positions and proportion of mutations in each clone, based on the previously calculated expectation
#' @param fik Matrix giving the probability of each mutation to belong to a specific clone.
#' If NULL, uniform weights are used (this is how the "DEoptim" path calls it).
#' @param Schrod A list of dataframes (one for each sample), generated by the Patient_schrodinger_cellularities() function.
#' @param previous.weights Weights from the previous optimization step (only their number is used when fik is NULL)
#' @param previous.centers Clone coordinates from previous optimization step (starting point of "default" and
#' "optimx"; only their number is used by "DEoptim")
#' @param adj.factor Factor to compute the probability: makes transition between the cellularity of the clone and the frequency observed
#' @param contamination Numeric vector with the fraction of normal cells contaminating the sample (not used by this function)
#' @param optim use L-BFGS-B optimization from R ("default"), or from optimx ("optimx"), Differential Evolution
#' ("DEoptim"), or the update returned by grzero() ("exact")
#' @return A list with weights, centers (numeric vector, per-sample coordinates concatenated) and val (value of
#' the minimised objective); "DEoptim" also returns initialpop and itermax. NA when the "default" optimizer
#' does not return a list. An error is raised for an unknown optim value.
#' @keywords EM Maximization
m.step <- function(fik, Schrod, previous.weights,
                   previous.centers, contamination, adj.factor,
                   optim = "default") {
  known_methods <- c("default", "optimx", "DEoptim", "exact")
  # Validate up front: an unknown method used to fall through to a return
  # that referenced an undefined object
  if (!(length(optim) == 1 && optim %in% known_methods)) {
    stop("Unknown optimization method '", paste(optim, collapse = ", "),
         "'; expected one of: ", paste(known_methods, collapse = ", "),
         call. = FALSE)
  }

  if (!is.null(fik)) {
    # Weight of a clone = expected number of mutations in it / number of mutations
    weights <- colSums(fik) / length(unique(Schrod[[1]]$id))
  } else {
    weights <- rep(1 / length(previous.weights), times = length(previous.weights))
  }

  start <- unlist(previous.centers)
  n_par <- length(start)
  lower <- rep(.Machine$double.eps, times = n_par)
  upper <- rep(1, times = n_par)

  if (optim %in% c("default", "optimx", "exact")) {
    # One column per sample; passed to grbase()/grzero()
    Alt <- matrix(nrow = nrow(Schrod[[1]]), ncol = length(Schrod))
    Depth <- matrix(nrow = nrow(Schrod[[1]]), ncol = length(Schrod))
    for (i in seq_along(Schrod)) {
      Alt[, i] <- Schrod[[i]]$Alt
      Depth[, i] <- Schrod[[i]]$Depth
    }
  }

  # Objective for fixed responsibilities: -sum(fik * log-likelihood); cells
  # with fik == 0 contribute nothing
  fnx <- compiler::cmpfun(function(x) {
    r <- -fik * eval.fik.m(Schrod = Schrod, centers = x, adj.factor = adj.factor,
                           weights = weights, log = TRUE)
    r[fik == 0] <- 0
    sum(r, na.rm = TRUE)
  }, options = list(optimize = 3))

  # Exact objective: responsibilities recomputed from x, plus the
  # -sum(fik * log(weight)) mixture term
  efnx <- compiler::cmpfun(function(x) {
    fik <- e.step(Schrod = Schrod, centers = x, weights = weights, adj.factor = adj.factor)
    r <- -fik * eval.fik.m(Schrod = Schrod, centers = x, adj.factor = adj.factor,
                           weights = weights, log = TRUE)
    r[fik == 0] <- 0
    PI <- matrix(nrow = nrow(fik), ncol = ncol(fik))
    for (k in seq_along(weights)) {
      PI[, k] <- fik[, k] * log(weights[k])
    }
    PI[fik == 0] <- 0
    sum(r - PI, na.rm = TRUE)
  }, options = list(optimize = 3))

  if (optim == "default") {
    gradient <- function(x) {
      grbase(fik = fik, adj.factor = adj.factor, centers = x, Alt = Alt, Depth = Depth)
    }
    # The local character `optim` does not mask stats::optim in calls
    spare <- tryCatch(
      optim(par = start, fn = fnx, gr = gradient, method = "L-BFGS-B",
            lower = lower, upper = upper),
      # If the analytical gradient fails (e.g. infinite values), retry with
      # a finite-difference gradient
      error = function(e) {
        message("Gradient failed")
        optim(par = start, fn = fnx, method = "L-BFGS-B",
              lower = lower, upper = upper)
      }
    )
    if (!is.list(spare)) {
      return(NA)
    }
    return(list(weights = weights, centers = spare$par, val = spare$value))
  }

  if (optim == "optimx") {
    spare <- optimx::optimx(par = start, fn = fnx, method = "L-BFGS-B",
                            lower = lower, upper = upper)
    # The first n_par columns hold the parameters; return them as a plain
    # numeric vector like the other methods
    return(list(weights = weights,
                centers = as.numeric(unlist(spare[1, seq_len(n_par)], use.names = FALSE)),
                val = spare$value))
  }

  if (optim == "DEoptim") {
    spare <- suppressWarnings(DEoptim::DEoptim(
      fn = efnx,
      lower = rep(0, times = n_par),
      upper = upper,
      control = DEoptim::DEoptim.control(
        NP = min(10 * n_par, 40),
        strategy = 3,
        itermax = 200,
        initialpop = NULL,
        CR = 0.9
      )
    ))
    return(list(weights = weights, centers = spare$optim$bestmem, val = spare$optim$bestval,
                initialpop = spare$member$pop, itermax = 200))
  }

  # optim == "exact"
  new.centers <- grzero(fik, adj.factor, Alt, Depth)
  list(weights = weights, centers = new.centers, val = fnx(new.centers))
}

## Factor used to compute the probability of the binomial distribution:
## for every row and every sample, the ratio NC / NCh of that sample's data.
## Returns a numeric matrix with one row per row of Schrod[[1]] and one
## column per sample (no dimnames).
Compute.adj.fact <- function(Schrod) {
  n_rows <- nrow(Schrod[[1]])
  per_sample <- vapply(
    Schrod,
    function(sample_df) sample_df$NC / sample_df$NCh,
    FUN.VALUE = numeric(n_rows)
  )
  # vapply() collapses to a vector when there is a single row: reshape so the
  # result is always rows x samples
  matrix(per_sample, nrow = n_rows, ncol = length(Schrod))
}
#'Expectation Maximization algorithm
#'
#' Optimization of clone positions and proportion of mutations in each clone.
#' @param Schrod A list of dataframes (one for each sample), generated by the Patient_schrodinger_cellularities() function.
#' @param nclust Number of clones to look for. Mandatory unless it can be inferred from prior_weight
#' (its length) or from prior_center given as a list (length of its first element).
#' @param prior_center Clone coordinates (from another analysis) to be used: a list with one numeric vector
#' per sample. If NULL, random coordinates are drawn in each sample, the last clone being set to 1.
#' @param prior_weight Prior on the fraction of mutation in each clone (uniform if NULL)
#' @param contamination Numeric vector with the fraction of normal cells contaminating the sample
#' @param optim "default" (L-BFGS-B from R), "optimx", "exact", "DEoptim" (a single Differential Evolution
#' run instead of EM iterations), or "compound" (uses "exact" when every sample has a single adjustment
#' factor, "default" otherwise)
#' @param epsilon Stop value: maximal admitted value of the difference in cluster position and weights
#' between two optimization steps. If NULL, will take 1/(median depth) over all samples.
#' @return A list with fik, weights, centers and val (plus initialpop for "DEoptim").
#' @keywords EM
EM.algo <- function(Schrod, nclust = NULL,
                    prior_center = NULL, prior_weight = NULL,
                    contamination, epsilon = 10**(-2),
                    optim = "default") {
  n_sample <- length(Schrod)
  if (is.null(nclust)) {
    if (!is.null(prior_weight)) {
      nclust <- length(prior_weight)
    } else if (is.list(prior_center)) {
      nclust <- length(prior_center[[1]])
    } else {
      stop("nclust must be given when neither prior_weight nor prior_center is supplied",
           call. = FALSE)
    }
  }
  if (is.null(prior_weight)) {
    prior_weight <- rep(1 / nclust, times = nclust)
  }
  cur.weight <- prior_weight
  if (is.null(prior_center)) {
    # Random start: one vector per sample, the last clone at coordinate 1
    cur.center <- lapply(seq_len(n_sample), function(s) {
      c(runif(n = nclust - 1, min = 0, max = 1), 1)
    })
  } else {
    cur.center <- prior_center
  }
  prior_center <- unlist(cur.center)
  if (is.null(epsilon)) {
    # Documented fallback: 1 / median depth
    all_depths <- unlist(lapply(Schrod, function(s) s$Depth))
    epsilon <- 1 / median(all_depths, na.rm = TRUE)
  }

  adj.factor <- Compute.adj.fact(Schrod = Schrod)
  if (grepl(pattern = "compound", x = optim, ignore.case = TRUE)) {
    if (is.matrix(adj.factor) && ncol(adj.factor) > 1) {
      # "exact" only when each sample (column) has a single adjustment factor
      one_value_per_sample <- all(apply(adj.factor, 2, function(col) length(unique(col)) == 1))
      optim <- if (one_value_per_sample) "exact" else "default"
    } else if (length(unique(adj.factor)) == 1) {
      message("EM evaluation...")
      optim <- "exact"
    } else {
      message("default use...")
      optim <- "default"
    }
  }

  cur.val <- NULL
  if (optim != "DEoptim") {
    # Inf guarantees at least one EM iteration
    delta <- Inf
    # NOTE: "exact" can get stuck on meta-stable values; there is no iteration cap
    while (delta > epsilon) {
      tik <- e.step(Schrod = Schrod, centers = cur.center, weights = cur.weight,
                    adj.factor = adj.factor)
      m <- m.step(fik = tik, Schrod = Schrod, previous.weights = cur.weight,
                  previous.centers = cur.center, contamination = contamination,
                  adj.factor = adj.factor, optim = optim)
      if (!is.list(m)) {
        # Maximization failed: keep the last accepted estimates
        break
      }
      n.weights <- unlist(m$weights)
      n.val <- m$val
      # m$centers concatenates the per-sample coordinates: split it back
      block_len <- length(cur.center[[1]])
      n.centers <- lapply(seq_along(cur.center), function(s) {
        m$centers[((s - 1) * block_len + 1):(s * block_len)]
      })
      delta <- max(abs(c(n.weights, unlist(n.centers)) - c(cur.weight, unlist(cur.center))))
      cur.weight <- n.weights
      cur.center <- n.centers
      # Complete-data objective: subtract sum(tik * log(weight)); cells with
      # tik == 0 contribute 0 (otherwise 0 * log(0) yields NaN)
      PI <- matrix(nrow = nrow(tik), ncol = ncol(tik))
      for (k in seq_along(cur.weight)) {
        PI[, k] <- tik[, k] * log(cur.weight[k])
      }
      PI[tik == 0] <- 0
      cur.val <- n.val - sum(PI)
    }
    fik <- e.step(Schrod = Schrod,
                  centers = cur.center,
                  weights = cur.weight,
                  adj.factor = adj.factor)
    return(list(fik = fik, weights = cur.weight, centers = cur.center, val = cur.val))
  }

  ### DIRECT EVALUATION WITH DEoptim
  m <- m.step(fik = NULL, Schrod = Schrod,
              previous.weights = rep(1, times = length(prior_weight)),
              previous.centers = unlist(prior_center),
              adj.factor = adj.factor, optim = "DEoptim")
  fik <- e.step(Schrod = Schrod,
                centers = m$centers,
                weights = rep(1, times = length(prior_weight)),
                adj.factor = adj.factor)
  # NOTE(review): this val is +sum(fik * log-likelihood), whereas the
  # iterative branch returns a minimised (negative) objective; confirm the
  # sign expected by callers that select the minimal val.
  cur.val <- sum(fik * eval.fik.m(Schrod = Schrod,
                                  centers = m$centers,
                                  weights = cur.weight,
                                  adj.factor = adj.factor,
                                  log = TRUE))
  list(fik = fik, weights = cur.weight, centers = m$centers,
       val = cur.val, initialpop = m$initialpop)
}

#'Data filter
#'
#' Keep, for each mutation id, a single possibility: the row holding the highest
#' clone membership probability. Ties between rows are broken by the largest total
#' probability over all clones (first such row if still tied).
#' @param Schrod A list of dataframes (one for each sample), generated by the Patient_schrodinger_cellularities.
#' @param fik matrix of probability of each possibility to belong to a clone
#' @return Schrod with, in every sample, only the kept rows (in order of first appearance of each id).
#' @keywords filter
filter_on_fik <- function(Schrod, fik) {
  ids <- Schrod[[1]]$id
  # Row indices of each id, ids kept in order of first appearance
  rows_by_id <- split(seq_along(ids), factor(ids, levels = unique(ids)))
  keep <- vapply(rows_by_id, function(rows) {
    if (length(rows) == 1L) {
      return(rows)
    }
    # drop = FALSE keeps a matrix even when fik has a single clone column
    probs <- fik[rows, , drop = FALSE]
    best <- max(probs)
    candidates <- which(apply(probs, 1, function(z) any(z == best)))
    if (length(candidates) > 1L) {
      candidates <- candidates[which.max(rowSums(probs[candidates, , drop = FALSE]))]
    }
    rows[candidates]
  }, integer(1))
  keep <- unname(keep)

  result <- Schrod
  for (s in seq_along(result)) {
    result[[s]] <- result[[s]][keep, , drop = FALSE]
  }
  result
}

#'Expectation Maximization algorithm
#'
#' Optimization of clone positions and proportion of mutations in each clone followed
#' by filtering on the most likely possibility for each mutation.
#' @param Schrod A list of dataframes (one for each sample), generated by the Patient_schrodinger_cellularities() function.
#' @param nclust Number of clones to look for
#' @param prior_center Clone coordinates (from another analysis) to be used (random if NULL)
#' @param prior_weight Prior on the fraction of mutation in each clone; replaced by uniform weights when NULL
#' or when its length differs from nclust
#' @param contamination Numeric vector with the fraction of normal cells contaminating the sample
#' @param epsilon Stopping condition for the algorithm: what is the minimal tolerated difference of position
#' or weights between two steps
#' @param optim use L-BFGS-B optimization from R ("default"), or from optimx ("optimx"), or Differential Evolution ("DEoptim")
#' @return A list with EM.output (output of EM.algo()) and filtered.data (output of filter_on_fik(),
#' NULL if EM.algo() did not return a list)
#' @keywords EM
FullEM <- function(Schrod, nclust, prior_center = NULL, prior_weight = NULL,
                   contamination, epsilon = 5 * 10**(-3),
                   optim = "default") {
  # Fall back to uniform weights only when the given ones do not fit nclust
  if (length(prior_weight) != nclust) {
    prior_weight <- rep(1 / nclust, times = nclust)
  }
  E_out <- EM.algo(Schrod = Schrod, nclust = nclust,
                   prior_center = prior_center, prior_weight = prior_weight,
                   contamination = contamination, epsilon = epsilon,
                   optim = optim)
  F_out <- NULL
  if (is.list(E_out)) {
    F_out <- filter_on_fik(Schrod = Schrod, fik = E_out$fik)
  }
  list(EM.output = E_out, filtered.data = F_out)
}

#'Clonal fraction prior creation
#'
#' Semi-random generation of clonal priors
#' @param nclust Number of clones to look for.
#' @param nsample Number of samples
#' @param prior Possible priors known: a list with one numeric vector per sample, where position j of
#' every vector refers to the same clone j
#' @return A list of nsample numeric vectors of length nclust. When priors must be completed, a clone whose
#' list_prod() value is 1 is treated as ancestral; when priors must be dropped, the first clone whose value
#' exceeds 0.95^nsample is kept as ancestral. Without an ancestral clone, a clone at coordinate 1 is added last.
#' @keywords EM
create_priors <- function(nclust, nsample, prior = NULL) {
  if (is.null(prior)) {
    # Fully random priors; the last clone is the ancestral one (coordinate 1)
    return(lapply(seq_len(nsample), function(i) {
      c(runif(n = nclust - 1, min = 0, max = 1), 1)
    }))
  }
  n_prior <- length(prior[[1]])
  if (n_prior == nclust) {
    return(prior)
  }
  # NOTE(review): list_prod() is defined elsewhere; presumably the product
  # across samples of each clone's coordinate
  lp <- list_prod(prior)

  if (n_prior < nclust) {
    if (any(lp == 1, na.rm = TRUE)) {
      # Ancestral clone already given: complete with random clones
      return(lapply(seq_len(nsample), function(i) {
        c(prior[[i]], runif(n = nclust - length(prior[[i]])))
      }))
    }
    # No ancestral clone: complete with random clones plus one at 1
    return(lapply(seq_len(nsample), function(i) {
      c(prior[[i]], runif(n = nclust - 1 - length(prior[[i]])), 1)
    }))
  }

  # Too many priors: draw the kept clone indices ONCE so that the same clones
  # are kept in every sample (position j must stay the same clone everywhere)
  pool <- seq_len(n_prior)
  ancestral <- which(lp > 0.95^nsample)
  if (length(ancestral) > 0) {
    w <- ancestral[1]
    # Exclude the ancestral clone from the draw so it is not duplicated
    pool <- pool[pool != w]
    kept <- pool[sample.int(length(pool), size = nclust - 1)]
    return(lapply(seq_len(nsample), function(i) c(prior[[i]][kept], prior[[i]][w])))
  }
  kept <- pool[sample.int(length(pool), size = nclust - 1)]
  lapply(seq_len(nsample), function(i) c(prior[[i]][kept], 1))
}

## Converts its input to a list via as.list().
## NOTE(review): despite the name, every argument is forwarded to as.list(),
## whose first argument is the object converted; for most inputs the extra
## arguments are ignored, so only the first argument ends up in the result.
## Confirm this is what callers expect (it is shipped to foreach workers via
## `.export` in EM_clustering).
add.to.list <- function(...) {
  as_list <- as.list(...)
  c(as_list)
}

#'Expectation Maximization algorithm
#'
#' Optimization of clone positions and proportion of mutations in each clone followed
#' by filtering on the most likely possibility for each mutation. Then gives out the possibility with maximal likelihood
#' (lowest objective value). The initializations are run one after the other.
#' @param Schrod A list of dataframes (one for each sample), generated by the Patient_schrodinger_cellularities() function.
#' @param nclust Number of clones to look for (mandatory if prior_center or prior_weight are null)
#' @param prior_center Clone coordinates (from another analysis) to be used
#' @param prior_weight Prior on the fraction of mutation in each clone
#' @param contamination Numeric vector with the fraction of normal cells contaminating the sample
#' @param epsilon Stopping condition for the algorithm: what is the minimal tolerated difference of position or weights between two steps
#' @param Initializations Maximal number of independent initial condition tests to be tried (at least 1)
#' @param optim use L-BFGS-B optimization from R ("default"), or from optimx ("optimx"), or Differential Evolution ("DEoptim")
#' @param keep.all.models Should the function output the best model (default; FALSE), or all models tested (if set to true)
#' @return The FullEM() output with the lowest EM.output$val (runs without a usable val are skipped; the first
#' run is returned if none has one), or all outputs when keep.all.models is TRUE.
#' @import foreach
#' @importFrom doParallel registerDoParallel
#' @importFrom parallel makeCluster stopCluster
#' @keywords EM
parallelEM <- function(Schrod, nclust, epsilon, contamination,
                       prior_center = NULL, prior_weight = NULL,
                       Initializations = 1,
                       optim = "default",
                       keep.all.models = FALSE) {
  if (Initializations < 1) {
    stop("Initializations must be at least 1", call. = FALSE)
  }
  result <- lapply(seq_len(Initializations), function(run) {
    # Fresh (partly random) priors for every initialization
    FullEM(Schrod = Schrod, nclust = nclust,
           prior_weight = prior_weight,
           contamination = contamination, epsilon = epsilon,
           prior_center = create_priors(nclust = nclust,
                                        nsample = length(Schrod),
                                        prior = prior_center),
           optim = optim)
  })

  if (keep.all.models) {
    if (Initializations > 1) {
      return(result)
    }
    return(result[[1]])
  }

  # Objective value of each run; NULL/NA/NaN values are ignored instead of
  # breaking the comparison
  vals <- vapply(result, function(run) {
    val <- run$EM.output$val
    if (is.numeric(val) && length(val) == 1L) as.numeric(val) else NA_real_
  }, numeric(1))
  best <- which.min(vals)
  if (length(best) == 0L) {
    best <- 1L
  }
  result[[best]]
}

#' Expectation Maximization
#'
#' Maximization of the likelihood given a mixture of binomial distributions
#' @param Schrod List of dataframes, output of the Schrodinger function or the EM algorithm
#' @param contamination The fraction of normal cells in the sample
#' @param prior_weight If known a list of priors (fraction of mutations in a clone) to be used in the clustering
#' @param nclone_range Number of clusters to look for
#' @param Initializations Number of independent initial condition tests to be tried for each number of clusters
#' @param epsilon Stop value: maximal admitted value of the difference in cluster position and weights between two optimization steps.
#' @param ncores Number of CPUs to be used (values above 1 run the fits on a parallel cluster)
#' @param clone_priors If known a list of priors (cell prevalence) to be used in the clustering
#' @param FLASH should it use FLASH algorithm to create priors (priors are jittered for every initialization after the first)
#' @param optim use L-BFGS-B optimization from R ("default"), or from optimx ("optimx"), or Differential Evolution ("DEoptim")
#' @param keep.all.models Should the function output the best model (default; FALSE), or all models tested (if set to true)
#' @param model.selection The function to minimize for the model selection: can be "AIC", "BIC", or numeric. In numeric, the function
#'uses a variant of the BIC by multiplication of the k*ln(n) factor. If >1, it will select models with lower complexity.
#' @return The best model (with its hard clustering in $cluster), or, when keep.all.models is TRUE, every
#' model with its $cluster and criterion value $Crit.
#' @keywords EM clustering number
EM_clustering <- function(Schrod, contamination, prior_weight = NULL, clone_priors = NULL, Initializations = 1,
                          nclone_range = 2:5, epsilon = 0.01, ncores = 2,
                          model.selection = "BIC", optim = "default", keep.all.models = FALSE,
                          FLASH = FALSE) {
  if (FLASH) {
    tree <- Cellular_preclustering(Schrod)$tree
  }
  if (ncores > 1) {
    cl <- parallel::makeCluster(ncores)
    # Release the workers even if one of the tasks fails
    on.exit(parallel::stopCluster(cl), add = TRUE)
    doParallel::registerDoParallel(cl)

    # One task per (number of clones, initialization); every run after the
    # first one for a given number of clones is tagged "_jit"
    tasks <- paste0(rep(nclone_range, each = Initializations),
                    c("", rep("_jit", times = Initializations - 1)))
    list_out_EM <- foreach::foreach(i = tasks,
                                    .export = c("parallelEM", "FullEM", "EM.algo", "create_priors",
                                                "add.to.list", "e.step", "m.step", "list_prod",
                                                "Compute.adj.fact", "eval.fik", "eval.fik.m",
                                                "filter_on_fik", "Create_prior_cutTree", "grbase")) %dopar% {
      jitter <- grepl(pattern = "_", x = i)
      n_clones <- as.numeric(unlist(strsplit(x = i, split = "_"))[1])
      if (FLASH) {
        priors <- Create_prior_cutTree(tree, Schrod, n_clones, jitter)
        centers <- priors$centers
        weights <- priors$weights
      } else {
        centers <- clone_priors
        weights <- prior_weight
      }
      parallelEM(Schrod = Schrod, nclust = n_clones, epsilon = epsilon,
                 contamination = contamination, prior_center = centers,
                 prior_weight = weights, Initializations = 1,
                 optim = optim, keep.all.models = keep.all.models)
    }
  } else {
    list_out_EM <- vector("list", length(nclone_range) * Initializations)
    index <- 0
    for (n_clones in nclone_range) {
      for (init in seq_len(Initializations)) {
        if (FLASH) {
          # Jitter the FLASH priors for every initialization after the first
          priors <- Create_prior_cutTree(tree, Schrod, n_clones, jitter = init > 1)
          centers <- priors$centers
          weights <- priors$weights
        } else {
          centers <- clone_priors
          weights <- prior_weight
        }
        index <- index + 1
        # A single initialization per call, as in the parallel branch, so the
        # total number of fits is length(nclone_range) * Initializations and
        # every element of list_out_EM is a single model
        list_out_EM[[index]] <- parallelEM(Schrod = Schrod, nclust = n_clones,
                                           epsilon = epsilon,
                                           contamination = contamination,
                                           prior_center = centers,
                                           prior_weight = weights,
                                           Initializations = 1,
                                           optim = optim,
                                           keep.all.models = keep.all.models)
      }
    }
  }

  if (!keep.all.models) {
    # Model selection criterion: keep the model minimising it
    result <- list_out_EM[[which.min(BIC_criterion(EM_out_list = list_out_EM, model.selection = model.selection))]]
    result$cluster <- hard.clustering(EM_out = result$EM.output)
    return(result)
  }
  Crit <- BIC_criterion(EM_out_list = list_out_EM, model.selection = model.selection)
  for (i in seq_along(list_out_EM)) {
    list_out_EM[[i]]$cluster <- hard.clustering(EM_out = list_out_EM[[i]]$EM.output)
    list_out_EM[[i]]$Crit <- Crit[i]
  }
  list_out_EM
}