######################
# helper functions   #
######################
# Number of distinct non-missing values in a vector.
ncat <- function(myvec) {
  observed <- myvec[!is.na(myvec)]
  length(unique(observed))
}

# Elementwise test for (numerically) whole numbers; non-numeric input gives
# FALSE for every element.
is.wholenumber <- function(myvec, tol = .Machine$double.eps^0.5) {
  if (!is.numeric(myvec)) {
    return(rep(FALSE, length(myvec)))
  }
  abs(myvec - round(myvec)) < tol
}

# TRUE when exactly two distinct non-missing values are present.
test.binary <- function(myvec) {
  ncat(myvec) == 2
}

# TRUE for numeric whole-number vectors with 3 to 20 distinct values.
test.poisson <- function(myvec) {
  observed <- myvec[!is.na(myvec)]
  checks <- c(ncat(myvec) %in% 3:20,
              all(is.wholenumber(observed)),
              is.numeric(observed))
  if (all(checks)) TRUE else FALSE
}

# TRUE for factors with 3 to 9 distinct observed values.
test.multinomial <- function(myvec) {
  observed <- myvec[!is.na(myvec)]
  checks <- c(ncat(myvec) %in% 3:9, is.factor(observed))
  if (all(checks)) TRUE else FALSE
}

# TRUE when every non-missing value is 0 or 1.
is.binary <- function(myvec) {
  observed <- myvec[!is.na(myvec)]
  all(observed %in% c(0, 1))
}
   
# Assign a model family to every column of `mymat`, unless `families` is
# supplied, in which case it is returned unchanged.
#
# Per column (first matching rule wins; the rules are mutually exclusive):
#   - exactly 2 distinct values                      -> "binomial"
#   - numeric whole numbers with 3..20 distinct values -> "poisson"
#   - factor with 3..9 distinct values               -> "multinom(K=<k-1>)"
#   - otherwise                                      -> "gaussian"
#
# Returns a character vector with one entry per column (character(0) for a
# zero-column input, which previously failed because of 1:ncol()).
assign.family <- function(mymat, families = NULL) {
  if (!is.null(families)) {
    return(families)
  }
  n_col <- ncol(mymat)
  fam <- rep("gaussian", n_col)
  for (i in seq_len(n_col)) {
    # Work on the column with its own type (no apply()/as.matrix coercion).
    column <- mymat[, i]
    if (test.binary(column)) {
      fam[i] <- "binomial"
    } else if (test.poisson(column)) {
      fam[i] <- "poisson"
    } else if (test.multinomial(column)) {
      fam[i] <- paste0("multinom(K=", ncat(column) - 1, ")")
    }
  }
  fam
}

# Build the model formula(s) for one node.
#
# dat        data.frame whose LAST column is the outcome; all other columns
#            are candidate covariates.
# approach   "GLM" or "GAM" (only approach[1] is inspected), or otherwise a
#            vector of model strings of which approach[index] is parsed and
#            evaluated.
# index      position in `approach` used when approach[1] is not GLM/GAM.
# fam        family string as produced by assign.family(); for
#            "multinom(K=k)" a list of k formulas is returned (the first one
#            with the numeric outcome on the left, the others right-hand side
#            only).
# surv       if TRUE, covariates named in `survnodes` are dropped.
# survnodes  character vector of column names (used only when surv = TRUE).
#
# GAM: covariates with more than 10 distinct values enter as s(x), the rest
# linearly. If no covariates remain, an intercept-only model "y ~ 1" is
# returned.
make.formula <- function(dat, approach, index, fam, surv = FALSE, survnodes = NULL) {
  if (!(approach[1] %in% c("GLM", "GAM"))) {
    # SECURITY: evaluates arbitrary R code; approach strings must be trusted
    # model specifications (e.g. the "list(...)" multinomial strings).
    return(eval(parse(text = approach[index])))
  }
  outc <- colnames(dat)[ncol(dat)]
  dat2 <- subset(dat, select = -ncol(dat))
  if (isTRUE(surv == TRUE)) {
    dat2 <- dat2[, !(colnames(dat2) %in% survnodes), drop = FALSE]
  }

  if (ncol(dat2) == 0) {
    # Covers both a single-column `dat` and all covariates removed as survnodes.
    covar <- "1"
  } else if (approach[1] == "GLM") {
    covar <- paste(colnames(dat2), collapse = "+")
  } else {
    # Count distinct values on each column's own type (apply() would coerce a
    # data.frame to a formatted character matrix first).
    columns <- if (is.data.frame(dat2)) as.list(dat2) else split(dat2, col(dat2))
    cts.x <- vapply(columns, function(x) length(unique(x)) > 10, logical(1),
                    USE.NAMES = FALSE)
    smooth_part <- if (any(cts.x)) {
      paste(paste0("s(", colnames(dat2)[cts.x], ")"), collapse = "+")
    }
    linear_part <- if (any(!cts.x)) {
      paste(colnames(dat2)[!cts.x], collapse = "+")
    }
    covar <- paste(c(smooth_part, linear_part), collapse = " + ")
  }

  gam.formula <- as.formula(paste(outc, "~", covar))
  if (substr(fam, 1, 4) == "mult") {
    # Read K as a full integer (the old substr(fam, 12, 12) kept one digit).
    k_match <- regmatches(fam, regexec("K\\s*=\\s*([0-9]+)", fam))[[1]]
    if (length(k_match) < 2) {
      stop("could not read the number of categories K from family '", fam, "'",
           call. = FALSE)
    }
    K <- as.numeric(k_match[2])
    rhs_only <- paste("~", covar)
    with_response <- paste0("as.numeric(as.character(", outc, "))", "~", covar)
    gam.formula <- lapply(c(with_response, rep(list(rhs_only), K - 1)), as.formula)
  }
  gam.formula
}

# Estimate crude and model-based ("conditional") support for each
# intervention rule and time point, and map both to weights.
#
# dat          data.frame containing the intervention columns `A`. The
#              support model for A[i] regresses (an indicator version of)
#              A[i] on every column of `dat` up to and including A[i]'s
#              position (formula A[i] ~ . on those columns).
# A            character vector of intervention column names, treated as
#              consecutive time points in the given order.
# intervention numeric matrix with one column per element of A; row j holds
#              the values of rule j. Interval bounds are the midpoints
#              between consecutive rows (assumes each column is sorted --
#              TODO confirm with callers).
# projection   function applied to each column of the support matrices to
#              obtain weights (default projection_linear); `...` is passed
#              to it.
#
# Returns a list with
#   crude_support  share of all observations whose A values stayed inside
#                  rule j's interval at every time up to i
#                  (rows = rules, columns = time points),
#   cond_support   cumulative product over time of the average predicted
#                  probability of being in rule j's interval (rounded to 6
#                  digits; failed fits/predictions count as 0),
#   crude_weights, cond_weights  projection() applied column-wise,
#   gal            all.predictions: per-observation predicted probabilities.
calculate.support <- function(dat,A,intervention,projection=projection_linear,...){
 # With a single A column, duplicate it as "placeholder" so the code below
 # always works with >= 2 columns; size == 1 marks this case.
 if(length(A)==1){size<-1; dat$placeholder <- dat[,A]; A <- c(A,"placeholder"); intervention <- cbind(intervention,intervention)}else{size<-2} # this fixes the problem when we only have one A column
 # dat2 keeps the original values; dat3 is a working copy in which A[i] is
 # temporarily replaced by an interval indicator.
 dat2 <- dat3 <-dat
 if(dim(dat[,A])[2]!=ncol(intervention)){stop("Number of columns of A and intervention don't match")}
 # my.cuts[, i]: midpoints between consecutive intervention values of A[i].
 my.cuts <- matrix(NA,ncol=ncol(intervention),nrow=nrow(intervention)-1)
 for(i in 1:dim(my.cuts)[2]){my.cuts[,i] <- (intervention[,i][-length(intervention[,i])] + intervention[,i][-1])/2}  # definition of intervals (and epsilon indirectly)
 # NOTE(review): this loop writes cut() factors into `dat`, but `dat` is only
 # used afterwards via nrow(dat), so it does not affect the result.
 for(i in 1:length(A)){
 dat[,A][,i] <- cut(subset(dat,select=A)[,i],breaks=c(-Inf,my.cuts[,i],Inf))  # needs adaption if epsilon would be controlled directly
 }
 cut.intervals <- rep(list(NULL),length(A))
 # NOTE(review): support.model.names is not used further below.
 support.model.names  <-   paste("sm_",as.vector(outer(A, c(seq(1:length(intervention[,1]))), paste, sep="_")),sep="")
 followed_A <-  rep(list(rep(list(NULL),length(A))),nrow(intervention))
 followed_A_mat <- rep(list(NA),nrow(intervention))
 followed_A_consec <- rep(list(NA),length(A)) # overwritten below
 # cut.intervals[[i]]: one (lower, upper) row per intervention rule for A[i];
 # the outermost intervals are open to -Inf / Inf.
 for(i in 1:length(cut.intervals)){cut.intervals[[i]] <- cbind(c(-Inf,my.cuts[,i],Inf)[-length(c(-Inf,my.cuts[,i],Inf))] , c(-Inf,my.cuts[,i],Inf)[-1] )}
 # followed_A[[j]][[i]]: 1 if A[i] lies strictly inside rule j's interval.
 # Values exactly equal to a midpoint belong to no interval.
 for(i in 1:length(A)){
  for(j in 1:nrow(intervention)){
  followed_A[[j]][[i]] <- as.numeric((dat2[,A][,i] > cut.intervals[[i]][j,1]) & (dat2[,A][,i] <  cut.intervals[[i]][j,2]))  # crude
 }}
 # followed_A_mat[[j]]: nrow(dat) x length(A) 0/1 matrix for rule j.
 for(i in 1:length(followed_A)){followed_A_mat[[i]]<-matrix(unlist(followed_A[[i]]),ncol=length(followed_A[[i]]))} # crude 
 followed_A_consec <- followed_A_mat
 # itt(): NA counts as "not following"; once a 0 occurs, every later entry
 # is set to 0. Relies on length >= 2, which the placeholder step ensures.
 itt <- function(myvec){myvec[is.na(myvec)]<-0;for(i in 2:length(myvec)){if(myvec[i-1]==0){myvec[i]<-0}}
                        return(myvec)}
 itt2 <- function(mymatrix){t(apply(mymatrix,1,itt))}
 followed_A_consec <- lapply(followed_A_consec,itt2) # -> if you don't follow rule at time t, you can't follow in future
 # include[[j]] has length(A)+1 columns; column i marks who followed rule j
 # up to time i-1 (column 1 = time 0, everyone).
 include_time0 <- function(mymat){return(cbind(1,mymat))}
 include<- lapply(followed_A_consec,include_time0)   # at time 0, before intervention, everyone follows the rule 
                                                     # CRUCIAL: who to include in support model. Currently, only those who follow "rule"
 # For each time i and rule j: recode A[i] as the indicator of rule j's
 # interval, fit a logistic glm on all preceding columns among rows in
 # include[[j]][, i], and store the fit (or the try-error) in a local
 # variable named sm_<A[i]>_<j>; A[i] is restored afterwards.
 for(i in 1:length(A)){
  for(j in 1:nrow(intervention)){
  dat3[,A][,i]    <- as.numeric((dat2[,A][,i] > cut.intervals[[i]][j,1]) & (dat2[,A][,i] <  cut.intervals[[i]][j,2]))
  support.model   <- suppressWarnings(try(glm(as.formula(paste(A[i],"~.")),data=dat3[as.logical(include[[j]][,i]),1:(which(colnames(dat3)%in%A[i]))],family=binomial),silent=TRUE))  # support given FULL past, adapt as needed if required
  assign(paste("sm_",A[i],"_",c(seq(1:length(intervention[,1])))[j],sep=""),support.model)
  dat3[,A][,i]  <- dat2[,A][,i]
 }}
 include_list <- lapply(lapply(include,as.data.frame),as.list)  # list, to define subset which follows the "rule" of interest
 # all.predictions[[j]][[i]]: length nrow(dat), 0 for rows not in the model.
 # all.predictions.2 is not used further below.
 all.predictions  <- all.predictions.2 <- rep(list(rep(list(rep(0,nrow(dat))),length(A))),nrow(intervention))
  for(i in 1:length(A)){
  for(j in 1:nrow(intervention)){
  pred.model   <- try(get(paste("sm_",A[i],"_",c(seq(1:length(intervention[,1])))[j],sep="")),silent=TRUE) # TO DO: this model, or variable screening?
  # Predict on the rows the model was fitted on (pred.model$data). If the glm
  # failed, pred.model is a try-error and the predictions stay 0.
  # NOTE(review): if predict() itself fails, the try-error string is stored
  # in the numeric vector, making it character; its mean below is NA (warning
  # suppressed) and is then set to 0.
  if(!inherits(pred.model, "try-error")){all.predictions[[j]][[i]][as.logical(include_list[[j]][[i]])]<-suppressWarnings(try(predict(pred.model,type="response",newdata=pred.model$data),silent=TRUE))}
  #all.predictions[[j]][[i]][followed_A_consec[[j]][,i]==0]<- 0 #check? Should predictions for those that don't follow rule be zero? -> 13.2.21: ?
  }}
  mymean <- function(mymat){apply(mymat,2,mean)}
  crude_support <- matrix(unlist(lapply(followed_A_consec,mymean)),nrow=nrow(intervention),ncol=ncol(intervention),byrow=T) # CRUDE SUPPORT: share of all rows that followed rule j consecutively up to each time point
  exp.support <- function(vec){mean(vec)}
  mnz.list      <- function(lis){lapply(lis,exp.support)}
  cond_support  <- suppressWarnings(matrix(unlist(lapply(all.predictions,mnz.list)),nrow=nrow(intervention),ncol=ncol(intervention),byrow=T)) # warnings suppressed: mean() of a character vector (failed predict) returns NA with a warning
  cond_support[is.na(cond_support) | is.infinite(cond_support)] <-0
  # Conditional support over time: cumulative product across time points (columns).
  cond_support <- round(t(apply(cond_support,1,cumprod)),digits=6)
  crude_weights <- apply(crude_support,2,projection,...)
  cond_weights <- apply(cond_support,2,projection,...)
  # NOTE(review): in the single-A case only the support matrices are reduced
  # to one column; the weight matrices keep the duplicated placeholder
  # column -- confirm this is intended.
  if(size==1){crude_support <- crude_support[,1,drop=F];cond_support <- cond_support[,1,drop=F]}
  return(list(crude_weights=crude_weights,cond_weights=cond_weights,crude_support=crude_support,cond_support=cond_support, gal = all.predictions))
}

# Map support values to weights with a linear ramp.
#
# x   numeric vector of support values.
# c1  weight given to zero support (must lie in [0, 1]).
# c2  support threshold at or above which the weight is 1 (in [0, 1]);
#     between 0 and c2 the weight rises linearly from c1 to 1.
#
# Returns a vector shaped like x. NA entries stay NA (previously, the logical
# subscripted assignment failed when x contained NA).
projection_linear <- function(x, c1 = 0.1, c2 = 0.1) {
  if (c1 < 0 || c1 > 1) {
    stop("c1 needs to be in [0, 1]")
  }
  if (c2 < 0 || c2 > 1) {
    stop("c2 needs to be in [0, 1]")
  }
  w <- x
  # which() drops NA positions, so missing support is left untouched.
  above <- which(x >= c2)
  below <- which(x < c2)
  w[above] <- 1
  w[below] <- c1 + ((1 - c1) / c2) * x[below]
  # Zero support always maps to c1 (matters when c2 == 0).
  w[which(x == 0)] <- c1
  w
}


# Ensure that `package` can be loaded (namespace only, not attached).
# Stops with `message` (no call shown) if it is not available; otherwise
# returns TRUE invisibly.
require.package <- function(package, message = paste("loading required package (", 
    package, ") failed; please install", sep = "")){
  available <- requireNamespace(package, quietly = FALSE)
  if (isFALSE(!available)) {
    return(invisible(TRUE))
  }
  stop(message, call. = FALSE)
}

# Screen the predictors of `form` by Cramer's V with the response.
#
# Columns with more than `cts.num` distinct values are first discretised
# into quintile bins. The `nscreen` predictors with the largest Cramer's V
# are kept; ties are broken at random (this consumes the RNG).
#
# Returns a character vector of predictor names, or NULL when `form` has
# fewer than two predictors. `...` is unused.
screen.cramersv <- function(dat, form, nscreen=4, cts.num=10, ...){
  model_vars <- all.vars(formula(form))
  predictors <- model_vars[-1]
  if (length(predictors) <= 1) {
    return(NULL)
  }
  complete <- na.omit(dat[, model_vars])
  many_values <- apply(complete, 2, function(x) length(unique(x)) > cts.num)
  to_quintiles <- function(x) {
    breaks <- unique(quantile(x, probs = c(0, 0.2, 0.4, 0.6, 0.8, 1)))
    cut(x, breaks, include.lowest = TRUE)
  }
  if (any(many_values)) {
    complete[, many_values] <- apply(complete[, many_values, drop = FALSE], 2, to_quintiles)
  }
  response <- complete[, model_vars[1]]
  candidates <- complete[, predictors]
  strength <- apply(candidates, 2,
                    function(x_var, y_var) cramer(table(y_var, x_var)),
                    y_var = response)
  keep <- unname(rank(-strength, ties.method = "random") <= nscreen)
  colnames(candidates)[keep]
}

# Screen the predictors of one model formula with cross-validated glmnet
# (lasso for alpha = 1), falling back to Cramer's V screening
# (screen.cramersv) when the glmnet fit fails.
#
# dat      data.frame containing all variables used in `form`.
# form     model formula as a character string. The multinomial form
#          "list(as.numeric(as.character(Y)) ~ rhs, ~ rhs, ...)" is reduced
#          to "Y ~ rhs" using the first component.
# alpha    passed to glmnet::cv.glmnet.
# pw       if TRUE, print a note when the Cramer's V fallback is used.
# nfolds   number of CV folds; folds are assigned deterministically as
#          1, 2, ..., nfolds, 1, 2, ... in row order.
# nlambda  length of the lambda path for the first cv.glmnet call.
# ...      unused.
#
# Returns a character vector of selected predictor names (possibly
# character(0)), or NULL when the formula has fewer than two predictors.
screen.glmnet.cramer <- function(dat, form, alpha = 1, pw=T, nfolds = 10, nlambda = 150, ...){
    require.package("glmnet")
    # Multinomial "list(...)" strings: take the response name inside
    # as.numeric(as.character(...)) and the rhs of the first component.
    if(substr(form,1,4)=="list"){form<-paste(strsplit(strsplit(form,"))")[[1]][1],"\\(")[[1]][4],"~",strsplit(strsplit(form,",")[[1]][1],"~")[[1]][2])} 
    if(length(all.vars(formula(form)))>2){
    dat <- na.omit(dat[,all.vars(formula(form))] )
    Y <- as.vector(as.matrix(dat[,all.vars(formula(form))[1]])); X <- as.data.frame(dat[,all.vars(formula(form))[-1]])
    savedat<-dat; saveX<-X
    myfamily <- assign.family(data.frame(dat[,all.vars(formula(form))[1]])) 
    if(substr(myfamily,1,4)=="mult"){myfamily<-"multinomial"}
    # needed for factor variables later on
    # Predictors are renamed to a, b, ..., z, aa, ab, ... so that dummy
    # column names from model.matrix can be mapped back to their variable.
    if (ncol(X) > 26 * 27) stop("Too many variables for this screening algorithm.\n Contact Michael Schomaker for solution.")
    let <- c(letters, sort(do.call("paste0", expand.grid(letters, letters[1:26]))))
    names(X) <- let[1:ncol(X)]
    # factors are coded as dummies which are standardized in cv.glmnet()
    # intercept is not in model.matrix() because its already in cv.glmnet()
    is_fact_var <- sapply(X, is.factor)
    # NOTE(review): if model.matrix() fails, X becomes a try-error and the
    # nrow(X) call below fails.
    X <- try(model.matrix(~ -1 + ., data = X), silent = FALSE)
    successfulfit <- FALSE
    cvIndex <- rep(1:nfolds,trunc(nrow(X)/nfolds)+1)[1:nrow(X)]
    fitCV <- try(glmnet::cv.glmnet(
      x = X, y = Y, lambda = NULL, type.measure = "deviance",
      nfolds = nfolds, family = myfamily, alpha = alpha,
      nlambda = nlambda, keep = T, foldid=cvIndex
    ), silent = TRUE)
    # if no variable was selected, penalization might have been too strong, try log(lambda)
    # (log(lambda + 1) < lambda for lambda > 0, i.e. a weaker penalty path)
     if(!inherits(fitCV,"try-error")){if (all(fitCV$nzero == 0) | all(is.na(fitCV$nzero))) {
      fitCV <- try(glmnet::cv.glmnet(
        x = X, y = Y, lambda = log(fitCV$glmnet.fit$lambda + 1), type.measure = "deviance",
        nfolds = nfolds, family = myfamily, alpha = alpha, keep = T, foldid=cvIndex
      ), silent = TRUE)
    }}
    if(inherits(fitCV,"try-error")){successfulfit <- FALSE}else{successfulfit <- TRUE}
    whichVariable <- NULL
    if(successfulfit==TRUE){
    coefs <- coef(fitCV$glmnet.fit, s = fitCV$lambda.min)
    # NOTE(review): the assignments below are no-ops (whichVariable is already
    # NULL) and the code after them always runs because the "else" is
    # commented out. An all-zero fit therefore yields character(0), and the
    # Cramer's V fallback is only reached when cv.glmnet fails.
    if(myfamily!="multinomial"){if(all(coefs[-1]==0)){whichVariable<-NULL}}else{
       min1<-function(vec){vec[-1]}
       if(all(unlist(lapply(coefs,min1))==0)){whichVariable<-NULL}}
    #}else{  
    # Multinomial: coef() gives one coefficient vector per class. Collapse
    # them into a pseudo-vector: 0.1 where any class has a nonzero coefficient,
    # 0 otherwise, with a leading 0 standing in for the intercept.
    if(myfamily!="multinomial"){var_nms <- coefs@Dimnames[[1]]}else{
                                var_nms <- coefs[[1]]@Dimnames[[1]]
                                allzero<-function(vec){all(vec==0)}  
                                sel<-apply(do.call("cbind",coefs)[-1,],1,allzero)
                                coefs <- coefs[[1]][-1];coefs[sel]<-0;coefs[!sel]<-0.1
                                coefs <- c(0,coefs)
                                }
    # Instead of Group Lasso:
    # If any level of a dummy coded factor is selected, the whole factor is selected
    if (any(is_fact_var)) {
      nms_fac <- names(which(is_fact_var))
      is_selected <- coefs[-1] != 0 # drop intercept
      # model.matrix adds numbers to dummy coded factors which we need to get rid of
      # NOTE(review): the pattern keeps only lower-case letters (and ':'), so
      # this assumes factor level labels contain no lower-case letters (e.g.
      # "0", "1"); a level such as "male" would give "amale" and the factor
      # would not be recognised.
      var_nms_sel <- gsub("[^::a-z::]", "", var_nms[-1][is_selected])
      sel_fac <- nms_fac[nms_fac %in% var_nms_sel]
      sel_numer <- var_nms_sel[!var_nms_sel %in% sel_fac]
      all_sel_vars <- c(sel_fac, sel_numer)
      whichVariable <- names(is_fact_var) %in% all_sel_vars
      } else {
      # metric variables only
        whichVariable <- coefs[-1] != 0
      }
      # Map the selection back to the original predictor names.
      whichVariable <- colnames(saveX)[whichVariable]
      }
    #}  
    if(is.null(whichVariable)){whichVariable<-screen.cramersv(dat=savedat,form=form)
    if(pw==T){cat("Lasso failed and screening was based on Cramer's V (for ",form,")\n")}}
    }else{whichVariable<-NULL}
    return(whichVariable)
} 



# Censor a 0/1 indicator vector: find the first position that equals 1 and
# is listed in `C.index`, and set every later position to NA.
# The vector is returned unchanged if no such position exists or if it is
# the last position.
censor <- function(vec, C.index) {
  ones <- which(vec == 1)
  hits <- ones[ones %in% C.index]
  if (length(hits) > 0 && hits[1] < length(vec)) {
    vec[seq.int(hits[1] + 1, length(vec))] <- NA
  }
  vec
}
 
# Enforce absorbing events in simulated survival data.
#
# mat  matrix/data.frame with named columns in time order.
# Yn   names of the event-indicator columns (in time order).
#
# For every row with an event (first Yn column equal to 1) that is not in
# the last column of `mat`:
#   - all columns after that event column are set to NA, then
#   - all later event columns (or the last one, if the event is in the last
#     Yn column) are set back to 1.
# Rows without an event are unchanged. Zero-row input is returned as is
# (the former 1:nrow() loops failed on it).
adjust.sim.surv <- function(mat, Yn) {
  first_event <- apply(mat[, Yn, drop = FALSE], 1, function(v) {
    hit <- which(v == 1)
    if (length(hit) > 0) min(hit) else NA
  })
  # Column position in `mat` of each row's first event column (NA: no event).
  position <- match(Yn[first_event], colnames(mat))
  n_col <- ncol(mat)
  n_yn <- length(Yn)
  for (i in which(!is.na(position) & position < n_col)) {
    mat[i, (position[i] + 1):n_col] <- NA
    mat[i, Yn[min(first_event[i] + 1, n_yn):n_yn]] <- 1
  }
  mat
}


# Container for two results: a list with elements result1 and result2
# (kept even when NULL) and class c("list", "multiResultClass").
multiResultClass <- function(result1=NULL,result2=NULL) {
  out <- list(result1 = result1, result2 = result2)
  structure(out, class = c(class(out), "multiResultClass"))
}

 
# Draw one category from a multinomial with probabilities `vec` and return
# its 0-based index.
multi.help <- function(vec) {
  draw <- rmultinom(1, 1, vec)
  which(as.vector(draw) == 1) - 1
}

# Row-wise version of multi.help: one 0-based category draw per row of
# `probmat`.
rmulti <- function(probmat) {
  apply(probmat, MARGIN = 1, FUN = multi.help)
}
# Share of non-missing entries of `vec` equal to `categ` (NaN if none remain).
prop <- function(vec, categ = 0) {
  kept <- na.omit(vec)
  n_match <- sum(kept == categ)
  n_match / length(kept)
}
# Mean of the non-missing values; factors are interpreted through their
# labels (as.numeric(as.character(.))).
rmean <- function(vec) {
  kept <- na.omit(vec)
  if (is.factor(kept)) {
    kept <- as.numeric(as.character(kept))
  }
  mean(kept)
}

# rmean() for each of the columns `ind` of `mmat`.
lrmean <- function(mmat, ind) {
  selected <- subset(mmat, select = ind)
  sapply(selected, rmean)
}

# Recode a factor to 0, 1, ..., (number of levels - 1) following the order
# of levels(vec); NA stays NA. If `verb` is TRUE the mapping
# "<level> replaced by <code>" is printed. Returns a numeric vector.
factor.to.numeric <- function(vec, verb) {
  nf <- levels(na.omit(vec))
  nums <- 0:(length(nf) - 1)
  code <- data.frame(nf, nums)
  recoding <- paste(apply(code, 1, paste, collapse = " replaced by "),
                    collapse = " ; ")
  # Look up each value's row in the (level, code) table; unmatched -> NA.
  vec2 <- code$nums[match(as.character(vec), code$nf)]
  if (verb == TRUE) {
    cat(paste(recoding, "\n"))
  }
  as.numeric(vec2)
}

# Recode any vector to a factor with levels "0", "1", ..., following the
# order of levels(as.factor(vec)); NA stays NA. If `verb` is TRUE the mapping
# "<original> replaced by <code>" is printed.
#
# An all-NA input now gives an all-NA factor without levels; previously
# 0:(length(nf) - 1) produced c(0, -1) and data.frame() failed.
recode_to_factor <- function(vec, verb = TRUE) {
  vec <- as.factor(vec)
  nf <- levels(na.omit(vec))
  nums <- seq_along(nf) - 1L
  code <- data.frame(nf = nf, nums = nums)
  recoding <- paste(apply(code, 1, function(row) paste(row, collapse = " replaced by ")),
                    collapse = " ; ")
  # Level index minus one gives the 0-based code; NA stays NA.
  vec2 <- as.numeric(vec) - 1
  if (verb) {
    cat(paste(recoding, "\n"))
  }
  factor(vec2, levels = nums)
}


# Recode a two-valued vector to 0/1: the first distinct non-missing value
# (in order of appearance) becomes 0, the second becomes 1; NA stays NA.
# If `verb` is TRUE the mapping "<value> replaced by <code>" is printed.
binary.to.zeroone <- function(vec, verb) {
  nf <- unique(na.omit(vec))
  nums <- c(0, 1)
  code <- data.frame(nf, nums)
  recoding <- paste(apply(code, 1, paste, collapse = " replaced by "),
                    collapse = " ; ")
  # Apply the table rows in order; a later row overrides an earlier one.
  vec2 <- Reduce(function(acc, k) {
    acc[vec == code$nf[k]] <- code$nums[k]
    acc
  }, seq_along(nums), rep(NA, length(vec)))
  if (verb == TRUE) {
    cat(paste(recoding, "\n"))
  }
  vec2
}

# TRUE if the levels of `vec` (after dropping NA) are exactly "0", "1", ...,
# one per distinct observed value. Non-factor input has no levels and
# therefore gives FALSE.
right.coding <- function(vec) {
  observed <- unique(na.omit(vec))
  expected <- factor(0:(length(observed) - 1))
  identical(levels(observed), levels(expected))
}

# Look up, for each model formula string in `forms`, the family that
# assign.family() assigns to its outcome column in `fdata`.
#
# forms  character vector of formula strings ("Y ~ ..." or the multinomial
#        "list(as.numeric(as.character(Y)) ~ ..., ~ ...)" form), or NULL.
# fdata  data.frame whose column names include every outcome.
#
# Returns a character vector of families (NULL if `forms` is NULL).
# Outcome names are trimmed, so "Y ~ x" (with a space before "~") is matched;
# previously the trailing space broke the lookup. An outcome without a
# matching column raises an informative error.
extract.families <- function(forms = NULL, fdata) {
  if (is.null(forms)) {
    return(NULL)
  }
  fams <- assign.family(fdata)
  names(fams) <- colnames(fdata)

  outcome_of <- function(f) {
    lhs <- trimws(strsplit(f, "~")[[1]][1])
    if (nchar(lhs) > 4 && substr(lhs, 1, 4) == "list") {
      # "list(as.numeric(as.character(Y))" -> "Y"
      lhs <- strsplit(strsplit(lhs, "\\(")[[1]][4], "\\)")[[1]][1]
    }
    trimws(lhs)
  }
  outcomes <- vapply(forms, outcome_of, character(1), USE.NAMES = FALSE)

  vapply(outcomes, function(o) {
    hit <- which(names(fams) %in% o)
    if (length(hit) == 0) {
      stop("no column in fdata matches the outcome '", o, "' of a model formula",
           call. = FALSE)
    }
    fams[[hit[1]]]
  }, character(1), USE.NAMES = FALSE)
}

# TRUE if at least one element of the list `ml` is not a matrix.
missing.data <- function(ml) {
  !all(vapply(ml, is.matrix, logical(1)))
}

# Cramer's V of a two-way contingency table `mt`:
#   sqrt(chi^2 / (n * (min(nrow, ncol) - 1)))
# with chi^2 the Pearson statistic against row-total x column-total / n.
#
# Counts are converted to double first: with integer tables (as returned by
# table()), rowSum * colSum overflowed to NA once n exceeded ~46k.
# A table with an empty row/column or a single row/column gives NaN.
cramer <- function(mt) {
  obs <- matrix(as.numeric(as.matrix(mt)), nrow = nrow(mt), ncol = ncol(mt))
  n <- sum(obs)
  expected <- outer(rowSums(obs), colSums(obs)) / n
  chi2 <- sum((obs - expected)^2 / expected)
  sqrt(chi2 / (n * (min(dim(obs)) - 1)))
}

# Split 0..vec[length(vec)] into consecutive integer-step sequences:
# element i is seq(previous bound, vec[i]) with the first bound being 0,
# e.g. make.interval(c(3, 5)) -> list(0:3, 3:5).
# An empty `vec` now returns list() (the former 1:length() loop failed).
make.interval <- function(vec) {
  bounds <- c(0, vec)
  lapply(seq_along(vec), function(i) seq(bounds[i], bounds[i + 1]))
}

# Package attach hook: emits the startup note (suppressible with
# suppressPackageStartupMessages()).
.onAttach <- function(libname = find.package("CICI"), pkgname = "CICI") {
  startup_text <- "The manual for this package will be available soon. \n Type ?CICI for a first overview."
  packageStartupMessage(startup_text)
}

# Pool estimates from m multiply imputed data sets with Rubin's rules.
#
# est         list of length m; element i is the vector of estimates from
#             imputation i (all of equal length).
# std.err     list of length m with the matching standard errors.
# confidence  confidence level of the returned intervals.
#
# Returns a list with the pooled estimate (est), total standard error
# (std.err), degrees of freedom (df), interval bounds (lower, upper) and
# the relative increase in variance due to nonresponse (r).
# Fewer than two imputations, or mismatched list lengths, now raise a clear
# error instead of "subscript out of bounds" or silently mixed columns.
mi.inference <- function(est, std.err, confidence = 0.95) {
  if (length(est) < 2) {
    stop("mi.inference needs estimates from at least 2 imputations", call. = FALSE)
  }
  if (length(std.err) != length(est)) {
    stop("est and std.err must contain the same number of imputations", call. = FALSE)
  }
  qstar <- do.call(cbind, est)          # one column per imputation
  u <- do.call(cbind, std.err)^2        # within-imputation variances
  m <- ncol(qstar)
  qbar <- apply(qstar, 1, mean)
  ubar <- apply(u, 1, mean)
  bm <- apply(qstar, 1, var)            # between-imputation variance
  tm <- ubar + (1 + 1/m) * bm           # total variance
  rem <- (1 + 1/m) * bm / ubar
  nu <- (m - 1) * (1 + 1/rem)^2
  alpha <- 1 - (1 - confidence) / 2
  half_width <- qt(alpha, nu) * sqrt(tm)
  list(est = qbar, std.err = sqrt(tm), df = nu,
       lower = qbar - half_width, upper = qbar + half_width, r = rem)
}

# Model formulas (as strings) grouped by node type: "L", "A" and "Y"
# (by their names presumably time-varying confounders, intervention efv.t
# and outcome VL.t -- confirm against the data-generating code).
# Presumably the correctly specified models for the package's simulated
# example data; verify against the simulation.
# Multinomial nodes (dose.t) use the
# "list(as.numeric(as.character(Y)) ~ rhs, ~ rhs, ~ rhs)" convention: three
# formulas, the same structure make.formula() produces for "multinom(K=3)".
correct.models <- list(
  "L"=c("adherence.1 ~ comorbidity.0 + efv.0",
        "weight.1 ~ sex + log_age",
        "comorbidity.1 ~ log_age + weight.0 + comorbidity.0",
        "list(as.numeric(as.character(dose.1)) ~  I(sqrt(weight.1)) + dose.0,  ~   I(sqrt(weight.1)) + dose.0, ~ I(sqrt(weight.1)) + dose.0)",
        "adherence.2 ~ comorbidity.1 + adherence.1 + efv.1",
        "weight.2 ~ weight.1 + comorbidity.1",
        "comorbidity.2 ~ log_age + weight.1 + comorbidity.1",
        "list(as.numeric(as.character(dose.2)) ~ I(sqrt(weight.2)) + dose.1, ~ I(sqrt(weight.2)) + dose.1, ~  I(sqrt(weight.2)) + dose.1)",
        "adherence.3 ~ comorbidity.2 + adherence.2 + efv.2",
        "weight.3 ~ weight.2 + comorbidity.2",
        "comorbidity.3 ~ log_age + weight.2 + comorbidity.2",
        "list(as.numeric(as.character(dose.3)) ~ I(sqrt(weight.3)) + dose.2, ~ I(sqrt(weight.3)) + dose.2, ~ I(sqrt(weight.3)) + dose.2)",
        "adherence.4 ~ comorbidity.3 + adherence.3 + efv.3",
        "weight.4 ~ weight.3 + comorbidity.3",
        # NOTE(review): every other comorbidity.t uses weight.(t-1) and
        # comorbidity.(t-1); this one uses lag 2 (weight.2, comorbidity.2).
        # Possibly a copy-paste slip -- confirm against the simulation.
        "comorbidity.4 ~ log_age + weight.2 + comorbidity.2",
        "list(as.numeric(as.character(dose.4)) ~ I(sqrt(weight.4)) + dose.3, ~ I(sqrt(weight.4)) + dose.3, ~ I(sqrt(weight.4)) + dose.3)"
  ),
  "A"=c("efv.0 ~ log_age + metabolic*dose.0",
        "efv.1 ~ log_age + dose.0 + metabolic*adherence.1",
        "efv.2 ~ log_age + dose.0 + metabolic*adherence.2*dose.1",
        "efv.3 ~ log_age + dose.0 + metabolic*adherence.3*dose.2",
        "efv.4 ~ log_age + dose.0 + metabolic*adherence.4*dose.3"
  ),
  "Y"=c("VL.0 ~ I(sqrt(efv.0))",
        "VL.1 ~ I(sqrt(efv.1)) + comorbidity.0",
        "VL.2 ~ I(sqrt(efv.2)) + comorbidity.1",
        "VL.3 ~ I(sqrt(efv.3)) + comorbidity.2",
        "VL.4 ~ I(sqrt(efv.4)) + comorbidity.3"
  )
)

# Effect measures comparing two probabilities k and l (presumably k under
# the reference and l under the comparison condition). All are vectorised.

# (l - k) / k
excessRelativeRisk <- function(k, l) {
  (l - k) / k
}

# 1 / (l - k)
numberNeededToTreat <- function(k, l) {
  1 / (l - k)
}

# (k - l) / k
relativeRiskReduction <- function(k, l) {
  (k - l) / k
}

# (1 - l) / (1 - k)
survivalRatio <- function(k, l) {
  (1 - l) / (1 - k)
}

# (l - k) / (1 - k)
relativeSusceptibility <- function(k, l) {
  (l - k) / (1 - k)
}

# (l - k) / l
exposureAttributableFraction <- function(k, l) {
  (l - k) / l
}


# Estimate numerator and denominator densities of each intervention node on
# the grid `abar` by discretising A into bins and modelling the discrete
# hazard of the bin index.
#
# For each node i:
#   1. A is binned with findInterval() on midpoints between consecutive
#      abar values (padded by the mean spacing at both ends).
#   2. The bin index is expanded with as_ped() into one row per bin up to
#      the observed bin, with an event in the last row.
#   3. A binomial model for the event is fitted on s(bin_id) + the rhs of
#      form.n[[i]] / form.d[[i]] (mgcv::gam, or SuperLearner if requested).
#   4. Hazards are predicted at bin_id = 1..length(abar) and converted to
#      bin probabilities with hazard_to_density().
#
# form.n, form.d  formula strings, one per node. The event column is renamed
#                 to Anodes[i], so the lhs is presumably expected to be
#                 Anodes[i] -- confirm with callers.
# X               data.frame with the intervention and covariate columns.
# Anodes          names of the intervention columns (one per formula).
# abar            grid of intervention values (length >= 2).
# SL.library      if non-NULL and the model has covariates besides bin_id,
#                 SuperLearner is used instead of mgcv::gam.
# verbose         print effective sample sizes and fitting messages.
# ...             passed to SuperLearner().
#
# Returns list(g.n, g.d): each a list with one nrow(X) x length(abar) matrix
# per node (rows = observations, columns = abar values, no column names).
hazardbinning <- function(form.n,form.d,X,Anodes,abar,SL.library=NULL,verbose=FALSE,...){
  #
  if(length(abar)<2){stop("abar needs to have >= 2 values to apply binning")}
  contin_var <- apply(subset(X, select=Anodes), 2, function(var) length(unique(var)))
  if(any(contin_var<10)){warning("Some of your intervention variables have less than 10 unique values. Are you sure A is continuous?")}
  #
  # NOTE(review): g.n / g.d are overwritten from result_list at the end, so
  # this preallocation (including its column names) is never used.
  g.d <- g.n <- rep(list(matrix(NA,nrow=nrow(X),ncol=length(abar),dimnames=list(NULL,paste(abar)))),
                    length(form.n))
  #
  # Bin edges: midpoints of abar, padded by the mean spacing on both sides.
  cuts <- (head(abar, -1) + tail(abar, -1)) / 2
  cuts <- c(cuts[1] - mean(diff(abar)), cuts, cuts[length(cuts)] + mean(diff(abar)))
  #
  result_list <- lapply(seq_along(form.n), function(i) {
    A <- X[, Anodes[i]]
    # Prepend s(bin_id) to the rhs; vars.n[1] is then "bin_id", so
    # vars.n[-1] are the covariates taken from X.
    fml.n <- as.formula(sub("~\\s*", "~ s(bin_id) + ", form.n[[i]]));vars.n <- all.vars(fml.n[[3]]);y.n <- all.vars(fml.n[[2]])
    fml.d <- as.formula(sub("~\\s*", "~ s(bin_id) + ", form.d[[i]]));vars.d <- all.vars(fml.d[[3]]);y.d <- all.vars(fml.d[[2]])
    W.n <- X[, vars.n[-1], drop = FALSE]
    W.d <- X[, vars.d[-1], drop = FALSE]
    # Bin index of each observation (0 .. length(cuts)).
    y_bins0 <- findInterval(A, cuts, rightmost.closed = TRUE)
    #
    dat.n <- cbind(data.frame(y = y_bins0, status = 1), W.n)
    dat.d <- cbind(data.frame(y = y_bins0, status = 1), W.d)
    #
    # Long format: one row per bin up to the observed bin; the event
    # indicator (ped_status) is renamed to the intervention name.
    long_data_s.n <- as_ped(survival::Surv(y, status) ~ ., data = dat.n, cut = c(0:length(cuts)))
    long_data_s.d <- as_ped(survival::Surv(y, status) ~ ., data = dat.d, cut = c(0:length(cuts)))
    names(long_data_s.n)[names(long_data_s.n) == "ped_status"] <- Anodes[i]
    names(long_data_s.d)[names(long_data_s.d) == "ped_status"] <- Anodes[i]
    if(verbose){
      message(paste0("t = ", i, ", effective sample size (hazardbinning) = ", nrow(long_data_s.n)))
    }
    #
    use_superlearner_n <- !is.null(SL.library) & length(vars.n) > 1
    use_superlearner_d <- !is.null(SL.library) & length(vars.d) > 1
    #
    # NOTE(review): SuperLearner() is called unqualified, so the SuperLearner
    # package must be attached. mgcv's s(bin_id) uses its default basis size,
    # which may exceed the number of distinct bin_id values when abar is
    # short -- confirm.
    if (use_superlearner_n) {
      if(verbose){
        message(paste0("Fitting SuperLearner with modeling library: ", paste(SL.library, collapse = ", "), " for numerator models"))
      }
      fit_n <- SuperLearner(
        Y = long_data_s.n[, y.n],
        X = long_data_s.n[, vars.n, drop = FALSE],
        id = long_data_s.n$id,
        family = binomial(),
        verbose = FALSE,
        SL.library = SL.library,
        ...
      )
    } else {
      fit_n <- mgcv::gam(fml.n, data = long_data_s.n, family = "binomial")
    }
    if (use_superlearner_d) {
      if(verbose){
        message(paste0("Fitting SuperLearner with modeling library: ", paste(SL.library, collapse = ", "), " for denominator models"))
      }
      fit_d <- SuperLearner(
        Y = long_data_s.d[, y.d],
        X = long_data_s.d[, vars.d, drop = FALSE],
        id = long_data_s.d$id,
        family = binomial(),
        verbose = FALSE,
        SL.library = SL.library,
        ...
      )
    } else {
      fit_d <- mgcv::gam(fml.d, data = long_data_s.d, family = "binomial")
    }
    #
    # Predicted hazards: nrow(X) x length(abar), column j = bin_id j.
    g_n_hazard <- sapply(seq_along(abar), function(j) {
      pX <- cbind(bin_id = j, W.n)
      if (use_superlearner_n) as.numeric(predict(fit_n, pX)$pred)
      else as.numeric(predict(fit_n, type = "response", newdata = pX))
    })
    
    g_d_hazard <- sapply(seq_along(abar), function(j) {
      pX <- cbind(bin_id = j, W.d)
      if (use_superlearner_d) as.numeric(predict(fit_d, pX)$pred)
      else as.numeric(predict(fit_d, type = "response", newdata = pX))
    })
    
    # Row-wise hazard -> bin probability conversion.
    g_n_density <- t(apply(g_n_hazard, 1, hazard_to_density))
    g_d_density <- t(apply(g_d_hazard, 1, hazard_to_density))
    
    list(g_n = g_n_density, g_d = g_d_density)
  })
  #
  g.n <- lapply(result_list, `[[`, "g_n")
  g.d <- lapply(result_list, `[[`, "g_d")
  #
  return(list(g.n,g.d))
  #
}

# Expand survival data to piece-wise (long) format.
#
# formula  Surv(time, status) ~ covariates, evaluated in `data`.
# data     data.frame.
# cut      numeric interval end points; interval k runs from the previous
#          end point (-1 for the first) to cut[k].
#
# Every observation gets one row per interval that starts before its time,
# with bin_id = min(interval end, time) and ped_status = 1 only in its last
# row when status == 1. Columns: id (row of the model frame), bin_id,
# ped_status, followed by the covariates. Row names are reset.
as_ped <- function(formula, data, cut) {
  if (!inherits(formula, "formula")) stop("`formula` must be a formula.")
  if (!is.data.frame(data)) stop("`data` must be a data.frame.")
  if (anyNA(cut) || !is.numeric(cut)) stop("`cut` must be a numeric vector without NAs.")

  frame <- model.frame(formula, data)
  response <- frame[[1]]
  if (!inherits(response, "Surv")) stop("Left-hand side of formula must be a Surv object.")

  event_time <- response[, "time"]
  event_status <- response[, "status"]
  covariates <- frame[, -1, drop = FALSE]
  bounds_start <- head(c(-1, cut), -1)
  bounds_stop <- cut

  expand_row <- function(row) {
    t_row <- event_time[row]
    keep <- which(bounds_start < t_row)
    n_keep <- length(keep)
    if (n_keep == 0) {
      return(NULL)
    }
    flags <- rep(0, n_keep)
    if (event_status[row] == 1) {
      flags[n_keep] <- 1
    }
    cbind(
      id = row,
      bin_id = pmin(bounds_stop[keep], t_row),
      ped_status = flags,
      covariates[rep(row, n_keep), , drop = FALSE]
    )
  }

  long <- do.call(rbind, lapply(seq_along(event_time), expand_row))
  rownames(long) <- NULL
  long
}


# Convert discrete-time hazards into bin probabilities:
#   d[k] = h[k] * prod_{m < k} (1 - h[m]),  d[1] = h[1].
# Vectorised with cumprod (O(K) instead of the former O(K^2) loop). An empty
# hazard vector now returns numeric(0); previously it returned a single NA.
# The result is always an unnamed numeric vector.
hazard_to_density <- function(h) {
  # survival_before[k] = prod(1 - h[1:(k-1)]), with 1 for k = 1
  survival_before <- c(1, cumprod(1 - h))[seq_along(h)]
  as.numeric(survival_before * h)
}


# 1) binning
# Binned propensity estimation. For every bin around each abar value, the
# intervention columns are recoded as "A lies strictly inside this bin"
# indicators, and binomial GAMs (form.n / form.d) are fitted on them. The
# predicted probabilities of being in the bin are returned.
#
# Bin edges are the midpoints between consecutive abar values, padded on
# both sides by the mean spacing of abar (the first and last midpoints are
# used as given, without sorting).
#
# Returns list(g.n, g.d): one nrow(X) x length(abar) matrix per formula,
# with columns named by abar. `...` is unused.
binning <- function(form.n, form.d, X, Anodes, abar, ...) {
  if (length(abar) < 2) {
    stop("abar needs to have >= 2 values to apply binning")
  }
  A_obs <- subset(X, select = Anodes)
  n_unique <- apply(A_obs, 2, function(var) length(unique(var)))
  if (any(n_unique < 10)) {
    warning("Some of your intervention variables have less than 10 unique values. Are you sure A is continuous?")
  }

  n_models <- length(form.n)
  n_bins <- length(abar)
  empty_g <- matrix(NA, nrow = nrow(X), ncol = n_bins, dimnames = list(NULL, paste(abar)))
  g.n <- g.d <- rep(list(empty_g), n_models)
  fitted.n <- fitted.d <- rep(list(NULL), n_models)

  mids <- (abar[-n_bins] + abar[-1]) / 2
  spacing <- mean(diff(abar))
  cuts <- c(mids[1] - spacing, mids, mids[length(mids)] + spacing)

  for (i in 1:n_models) {
    for (j in 1:n_bins) {
      in_bin <- (A_obs > cuts[j]) & (A_obs < cuts[j + 1])
      cutX <- X
      cutX[, Anodes] <- apply(in_bin, 2, as.numeric)
      fitted.n[[i]] <- mgcv::gam(as.formula(form.n[i]), data = cutX, family = "binomial")
      fitted.d[[i]] <- mgcv::gam(as.formula(form.d[i]), data = cutX, family = "binomial")
      # Predict the probability that A falls into bin j.
      newX <- cutX
      newX[, Anodes] <- 1
      g.n[[i]][, j] <- dbinom(1, size = 1, predict(fitted.n[[i]], type = "response", newdata = newX))
      g.d[[i]][, j] <- dbinom(1, size = 1, predict(fitted.d[[i]], type = "response", newdata = newX))
    }
  }
  list(g.n, g.d)
}

# 2) dnorm, dpois, dbinom

# Parametric intervention densities. For each formula, numerator and
# denominator GAMs are fitted with the family assign.family() picks for the
# corresponding intervention column. Their densities are then evaluated at
# every abar value, with all intervention columns set to that value:
#   gaussian -> dnorm (sd from the model's sig2),
#   poisson  -> dpois,
#   other    -> dbinom(size = 1).
# Multinomial interventions are rejected.
#
# Returns list(g.n, g.d): one nrow(X) x length(abar) matrix per intervention
# column, with columns named by abar. `...` is unused.
parametric <- function(form.n, form.d, X, Anodes, abar, ...) {
  fams <- assign.family(as.data.frame(X[, Anodes]))
  if (any(substr(fams, 1, 4) == "mult")) {
    stop("multinomial intervention currently not supported")
  }
  n_fams <- length(fams)
  empty_g <- matrix(NA, nrow = nrow(X), ncol = length(abar), dimnames = list(NULL, paste(abar)))
  g.n <- g.d <- rep(list(empty_g), n_fams)
  fitted.n <- fitted.d <- rep(list(NULL), n_fams)

  # Density of `value` under `fit` for the given family, at covariates `newdata`.
  density_at <- function(fit, family, value, newdata) {
    if (family == "gaussian") {
      return(dnorm(value, mean = predict(fit, newdata = newdata), sd = sqrt(fit$sig2)))
    }
    mu <- predict(fit, type = "response", newdata = newdata)
    if (family == "poisson") {
      dpois(value, lambda = mu)
    } else {
      dbinom(value, size = 1, prob = mu)
    }
  }

  for (i in 1:length(form.n)) {
    fitted.n[[i]] <- mgcv::gam(as.formula(form.n[i]), data = X, family = fams[i])
    fitted.d[[i]] <- mgcv::gam(as.formula(form.d[i]), data = X, family = fams[i])
    for (j in 1:length(abar)) {
      XA <- X
      XA[, Anodes] <- abar[j]
      g.n[[i]][, j] <- density_at(fitted.n[[i]], fams[i], abar[j], XA)
      g.d[[i]][, j] <- density_at(fitted.d[[i]], fams[i], abar[j], XA)
    }
  }
  list(g.n, g.d)
}

# 3) haldensify

hal_density <- function(form.n, form.d, X, Anodes, abar,
                        n_bins = max(10, sqrt(nrow(X))),
                        lambda_seq = exp(seq(-0.1, -10, length = 100)),
                        grid_type = c("equal_range", "equal_mass"),
                        max_degree = NULL,
                        smoothness_orders = NULL,
                        hal.verbose = FALSE,
                        verbose = FALSE,
                        runtime = c("fast", "very_fast", "fairly_fast",
                                    "somewhat_fast", "regular"),
                        ...) {
  
  if (length(abar) < 2) {
    stop("abar needs to have >= 2 values to apply haldensify")
  }
  if (!requireNamespace("haldensify", quietly = TRUE)) {
    stop("Package 'haldensify' is required for this function.")
  }
  
  grid_type <- match.arg(grid_type)
  runtime   <- match.arg(runtime)
  
  contin_var <- apply(X[, Anodes, drop = FALSE], 2, function(z) length(unique(z)))
  if (any(contin_var < 10)) {
    warning("Some of your intervention variables have fewer than 10 unique values. Are you sure A is continuous?")
  }
  
  very_fast     <- list(c(50, 25),   c(50, 25, 10),   c(25, 10),   c(25, 10, 5))
  fast          <- list(c(100, 50),  c(100, 50, 25),  c(40, 15),   c(40, 15, 10))
  fairly_fast   <- list(c(200, 100), c(200, 100, 50), c(50, 25),   c(50, 25, 15))
  somewhat_fast <- list(c(400, 200), c(400, 200, 100),c(100, 75),  c(100, 75, 50))
  regular       <- list(c(500, 200), c(500, 200, 50), c(200, 100), c(200, 100, 50))
  
  runtime_defaults <- list(
    very_fast     = list(max_degree = 1, smoothness_orders = 0),
    fast          = list(max_degree = 1, smoothness_orders = 1),
    fairly_fast   = list(max_degree = 2, smoothness_orders = 1),
    somewhat_fast = list(max_degree = 2, smoothness_orders = 2),
    regular       = list(max_degree = 3, smoothness_orders = 2)
  )
  
  user_deg    <- !is.null(max_degree)
  user_smooth <- !is.null(smoothness_orders)
  
  if (!user_deg || !user_smooth) {
    defaults <- runtime_defaults[[runtime]]
    if (!user_deg) {
      max_degree <- defaults$max_degree
    }
    if (!user_smooth) {
      smoothness_orders <- defaults$smoothness_orders
    }
  }
  
  if (hal.verbose) {
    msg_suffix <- if (user_deg || user_smooth) {
      " (user-specified max_degree / smoothness_orders override runtime defaults where provided)"
    } else {
      ""
    }
    message(
      "hal_density configuration: runtime = ", runtime,
      "; max_degree = ", max_degree,
      "; smoothness_orders = ", paste(smoothness_orders, collapse = ","),
      msg_suffix
    )
  }
  
  sel_rt <- get(runtime)
  n_obs  <- nrow(X)
  K      <- length(abar)
  abar_names <- paste(abar)
  
  parse_formula <- function(f) {
    parts     <- strsplit(f, "~", fixed = TRUE)[[1]]
    lhs       <- gsub(" ", "", parts[1])
    rhs_raw   <- gsub(" ", "", parts[2])
    rhs_terms <- strsplit(rhs_raw, "+", fixed = TRUE)[[1]]
    rhs_terms <- rhs_terms[rhs_terms != ""]
    list(lhs = lhs, rhs = rhs_terms)
  }
  
  comp_list <- lapply(seq_along(form.n), function(i) {
    info_n <- parse_formula(form.n[i])
    info_d <- parse_formula(form.d[i])
    
    tc_n <- info_n$rhs
    null_model_n <- identical(tc_n, "1") || (length(tc_n) == 1 && tc_n[1] == "1")
    
    if (null_model_n) {
      W.n <- matrix(1, n_obs, 1)
      colnames(W.n) <- "(Intercept)"
    } else {
      covars_n <- setdiff(info_n$rhs, "1")
      W.n <- X[, covars_n, drop = FALSE]
    }
    
    nknots_n <- num_knots_generator(
      max_degree        = max_degree,
      smoothness_orders = smoothness_orders,
      base_num_knots_0  = if (ncol(W.n) >= 20) sel_rt[[1]] else sel_rt[[2]],
      base_num_knots_1  = if (ncol(W.n) >= 20) sel_rt[[3]] else sel_rt[[4]]
    )
    
    if (!null_model_n) {
      A_n <- as.numeric(X[, info_n$lhs])
      fit_n <- haldensify::haldensify(
        A          = A_n,
        W          = W.n,
        n_bins     = n_bins,
        grid_type  = grid_type,
        lambda_seq = lambda_seq,
        num_knots  = nknots_n,
        ...
      )
      if (hal.verbose) {
        message(
          "Numerator density (t = ", i, "): ",
          "max_degree = ", max_degree,
          "; smoothness_orders = ", paste(smoothness_orders, collapse = ","),
          "; num_knots = ", paste(nknots_n, collapse = ",")
        )
      }
      if (verbose) {
        message(
          "t = ", i,
          "; Effective sample size for numerator density (haldensify) = ",
          fit_n$hal_fit$lasso_fit$nobs
        )
      }
    } else {
      fit_n <- NULL
    }
    
    tc_d <- info_d$rhs
    null_model_d <- identical(tc_d, "1") || (length(tc_d) == 1 && tc_d[1] == "1")
    
    if (null_model_d) {
      W.d <- matrix(1, n_obs, 1)
      colnames(W.d) <- "(Intercept)"
    } else {
      covars_d <- setdiff(info_d$rhs, "1")
      W.d <- X[, covars_d, drop = FALSE]
    }
    
    nknots_d <- num_knots_generator(
      max_degree        = max_degree,
      smoothness_orders = smoothness_orders,
      base_num_knots_0  = if (ncol(W.d) >= 20) sel_rt[[1]] else sel_rt[[2]],
      base_num_knots_1  = if (ncol(W.d) >= 20) sel_rt[[3]] else sel_rt[[4]]
    )
    
    if (!null_model_d) {
      A_d <- as.numeric(X[, info_d$lhs])
      fit_d <- haldensify::haldensify(
        A          = A_d,
        W          = W.d,
        n_bins     = n_bins,
        grid_type  = grid_type,
        lambda_seq = lambda_seq,
        num_knots  = nknots_d,
        ...
      )
      if (hal.verbose) {
        message(
          "Denominator density (t = ", i, "): ",
          "max_degree = ", max_degree,
          "; smoothness_orders = ", paste(smoothness_orders, collapse = ","),
          "; num_knots = ", paste(nknots_d, collapse = ",")
        )
      }
      if (verbose) {
        message(
          "t = ", i,
          "; Effective sample size for denominator density (haldensify) = ",
          fit_d$hal_fit$lasso_fit$nobs
        )
      }
    } else {
      fit_d <- NULL
    }
    
    if (null_model_n) {
      gn_mat <- matrix(1 / n_obs, nrow = n_obs, ncol = K)
      colnames(gn_mat) <- abar_names
    } else {
      covars_n_full <- setdiff(info_n$rhs, "1")
      gn_cols <- lapply(seq_along(abar), function(j) {
        WA.n <- X[, covars_n_full, drop = FALSE]
        sel_n <- colnames(WA.n) %in% Anodes
        if (any(sel_n)) {
          WA.n[, sel_n] <- abar[j]
        }
        haldensify::predict.haldensify(
          fit_n,
          new_A = rep(abar[j], n_obs),
          new_W = WA.n,
          trim  = FALSE,
          ...
        )
      })
      gn_mat <- do.call(cbind, gn_cols)
      colnames(gn_mat) <- abar_names
    }
    
    if (null_model_d) {
      gd_mat <- matrix(1 / n_obs, nrow = n_obs, ncol = K)
      colnames(gd_mat) <- abar_names
    } else {
      covars_d_full <- setdiff(info_d$rhs, "1")
      gd_cols <- lapply(seq_along(abar), function(j) {
        WA.d <- X[, covars_d_full, drop = FALSE]
        sel_d <- colnames(WA.d) %in% Anodes
        if (any(sel_d)) {
          WA.d[, sel_d] <- abar[j]
        }
        haldensify::predict.haldensify(
          fit_d,
          new_A = rep(abar[j], n_obs),
          new_W = WA.d,
          trim  = FALSE,
          ...
        )
      })
      gd_mat <- do.call(cbind, gd_cols)
      colnames(gd_mat) <- abar_names
    }
    
    list(gn = gn_mat, gd = gd_mat)
  })
  
  g.n <- lapply(comp_list, `[[`, "gn")
  g.d <- lapply(comp_list, `[[`, "gd")
  
  list(g.n, g.d)
}

# Number of knots to use for each interaction degree of a HAL basis.
#
# Degree d (d = 1, ..., max_degree) receives round(base / 2^(d - 1)) knots,
# i.e. the knot budget halves with every additional interaction degree.
#
# Arguments:
#   max_degree        non-negative integer; highest interaction degree.
#   smoothness_orders numeric vector; if ALL elements are > 0 the base count
#                     is `base_num_knots_1`, otherwise `base_num_knots_0`.
#   base_num_knots_0  knots for degree 1 when some smoothness order is <= 0.
#   base_num_knots_1  knots for degree 1 when all smoothness orders are > 0.
#   (both bases are assumed to be scalars -- callers pass single list elements)
#
# Returns: numeric vector of length `max_degree` (numeric(0) when
#   max_degree is 0).
num_knots_generator <- function(max_degree, smoothness_orders, base_num_knots_0 = 500,
                                base_num_knots_1 = 200) {
  base <- if (all(smoothness_orders > 0)) base_num_knots_1 else base_num_knots_0
  # Vectorised over degrees: unlike the former sapply() this always yields a
  # numeric vector, including the empty case max_degree = 0 (sapply gave list()).
  round(base / 2^(seq_len(max_degree) - 1L))
}