# Drop the columns flagged for removal and return the columns of one type.
#
# The first row of `tempdat` holds a type code for every column. Columns whose
# code is 6 are discarded; of the remaining columns, those whose code equals
# `t` are returned.
#
# Args:
#   tempdat: data set whose first row contains the per-column type codes.
#   t:       type code of the columns to extract.
#
# Returns:
#   The columns of `tempdat` (every row, including the code row) whose code is
#   `t`. The result keeps its two dimensions even when it holds one column or
#   none.
vartype.f <- function(tempdat, t){
  # as.vector() turns the one-row comparison into a plain logical vector, and
  # drop = FALSE stops a single surviving column from collapsing into a bare
  # vector (which would make the second `tempdat[1, ]` fail).
  keep <- as.vector(!(tempdat[1, ] == 6))
  tempdat <- tempdat[, keep, drop = FALSE]
  wanted <- as.vector(tempdat[1, ] == t)
  tempdat[, wanted, drop = FALSE]
}

# Rebuild the imported data set with its columns grouped by type code (codes
# 1 to 5, in that order) behind a leading "id" column.
#
# The first row of `dat` is the type-code row and receives id 0; the rows below
# it are numbered 1, 2, ...
#
# Returns a list with
#   newdf:    the reordered data set (id column first),
#   typelist: cumulative number of columns per type code, plus 1,
#   nvartype: number of columns found for each of the five type codes.
# Stops when no column carries code 1 or no column carries code 2.
df <- function(dat){
  n_subject <- nrow(dat) - 1
  newdf <- data.frame(id = c(0, seq_len(n_subject)))
  nvartype <- integer(5)
  for (code in seq_len(5)) {
    type <- vartype.f(dat, code)
    nvartype[code] <- ncol(type)
    newdf <- data.frame(newdf, type)
  }
  no_treatment <- nvartype[1] == 0
  no_response <- nvartype[2] == 0
  if (no_treatment || no_response) {
    stop("Missing treatment and/or response variables")
  }
  list(newdf = newdf, typelist = cumsum(nvartype) + 1, nvartype = nvartype)
}

# Collect the distinct values of each nominal covariate.
#
# Args:
#   dat:       data set holding the covariates.
#   allnomcol: columns of `dat` that are nominal covariates.
#
# Returns a list with
#   nominalfactor:     sorted distinct values of each covariate (one list
#                      element per entry of `allnomcol`),
#   nnominalfactor:    number of distinct values per covariate,
#   combnominalfactor: 2^(k - 1) - 1 for a covariate with k distinct values,
#                      which is the number of rows nomcomb.f() returns for it.
nominal.f <- function(dat, allnomcol){
  nominalfactor <- lapply(allnomcol, function(u) sort(unique(dat[, u])))
  # vapply() keeps the result an integer vector even when allnomcol is empty
  nnominalfactor <- vapply(nominalfactor, length, integer(1))
  combnominalfactor <- 2^(nnominalfactor - 1) - 1
  list(nominalfactor = nominalfactor,
       nnominalfactor = nnominalfactor,
       combnominalfactor = combnominalfactor)
}

# Enumerate the allowable binary splits of a nominal covariate.
#
# Every split is a row of 0/1 group labels, one label per distinct (non-NA)
# value of the covariate in sorted order. Only rows whose first label is 0 are
# kept, and the all-zero row is removed, leaving 2^(k - 1) - 1 rows for a
# covariate with k distinct values.
#
# Args:
#   dat:       data set holding the covariate.
#   nomvarcol: column of `dat` that holds the nominal covariate.
#
# Returns:
#   A matrix with one row per split and one column per distinct value. The
#   result stays a matrix even when only one row (k = 2) or no row (k = 1)
#   remains.
nomcomb.f <- function(dat, nomvarcol){
  nlevel <- length(sort(unique(dat[ , nomvarcol])))
  comb <- as.matrix(do.call(expand.grid, rep(list(0:1), nlevel)))
  dimnames(comb) <- NULL
  # expand.grid() varies the first label fastest, so the odd rows are exactly
  # the rows whose first label is 0
  firstzero <- seq(from = 1, to = 2^nlevel, by = 2)
  comb <- comb[firstzero, , drop = FALSE]
  # the first remaining row is the all-zero row, which is not a split
  comb[-1, , drop = FALSE]
}

# Recode a nominal covariate into the 0/1 labels of one allowable split.
#
# Args:
#   nomsub:     the nominal covariate (replacenominal.f() passes a one-column
#               data frame).
#   level:      distinct values of the covariate, in the column order of
#               `combchoice`.
#   combchoice: matrix of splits as returned by nomcomb.f().
#   m:          row of `combchoice` to apply.
#
# Returns:
#   `nomsub` with every value equal to level[v] replaced by combchoice[m, v],
#   and with `m` appended to its names.
combsub.f <- function (nomsub, level, combchoice, m) {
  # Match against the untouched input: matching against the partly recoded
  # data would re-map values already set to 0/1 whenever a later level is
  # itself 0 or 1 (e.g. levels -2, -1, 0).
  original <- nomsub
  for (v in seq_along(level)) {
    nomsub <- replace(nomsub, original == level[v], combchoice[m, v])
  }
  names(nomsub) <- paste(names(nomsub), m, sep = "")
  return(nomsub)
}

# Build the 0/1 indicator columns for every allowable split of one nominal
# covariate: one column per row of nomcomb.f(), each produced by combsub.f().
#
# Args:
#   dat:       data set holding the covariate.
#   nomvarcol: column of `dat` that holds the nominal covariate.
#
# Returns:
#   A data frame containing the recoded columns.
replacenominal.f <- function(dat, nomvarcol){
  column <- dat[, nomvarcol, drop = FALSE]
  lev <- sort(unique(dat[, nomvarcol]))
  splits <- nomcomb.f(dat, nomvarcol)
  recode <- function(s) {
    combsub.f(column, lev, splits, s)
  }
  recoded <- sapply(seq_len(nrow(splits)), recode)
  as.data.frame(recoded)
}

# Prepare the analysis data set and describe its column layout.
#
# Starts from the list returned by df() and adds, where applicable:
#   nvar:         number of columns of the reordered data set (id included),
#   varlist:      column positions 2..nvar grouped at the break points in
#                 `typelist` (only when nvar > 2),
#   orderedcol:   positions of the columns with type code 3,
#   binarycol:    positions of the columns with type code 4,
#   nominalcol:   positions of the columns with type code 5,
#   combnomlevel: number of allowable splits per nominal covariate,
#   replacenom:   positions of all appended split-indicator columns (only when
#                 there are several nominal covariates),
#   nominallist:  positions of the split-indicator columns; a plain vector for
#                 one nominal covariate, a list with one element per covariate
#                 otherwise,
#   mydata:       the data without the type-code row, followed by the
#                 split-indicator columns.
new.df<-function(dat){
  datainfo <- df(dat)
  withcodes <- as.data.frame(datainfo$newdf)
  coderow <- withcodes[1, ]
  subjects <- withcodes[-1, ]
  nvar <- ncol(withcodes)
  datainfo$nvar <- nvar

  # cut `values` into consecutive groups; a new group starts at every position
  # listed in `starts`
  splitat <- function(values, starts) {
    groupid <- cumsum(seq_along(values) %in% starts)
    unname(split(values, groupid))
  }

  if (nvar > 2) {
    datainfo$varlist <- splitat(c(2:nvar), datainfo$typelist)
  }
  if (datainfo$nvartype[3] >= 1) {
    datainfo$orderedcol <- which(coderow == 3)
  }
  if (datainfo$nvartype[4] >= 1) {
    datainfo$binarycol <- which(coderow == 4)
  }
  if (datainfo$nvartype[5] >= 1) {
    nomcol <- which(coderow == 5)
    datainfo$nominalcol <- nomcol
    # split indicators are derived from the data without the type-code row
    expanded <- lapply(seq_along(nomcol), function(d) replacenominal.f(subjects, nomcol[d]))
    nominalall <- do.call(data.frame, expanded)
    subjects <- data.frame(subjects, nominalall)
    if (length(nomcol) == 1) {
      nlev <- length(sort(unique(subjects[, nomcol])))
      datainfo$combnomlevel <- 2^(nlev - 1) - 1
      datainfo$nominallist <- c((nvar + 1):ncol(subjects))
    } else {
      datainfo$combnomlevel <- nominal.f(subjects, nomcol)$combnominalfactor
      datainfo$replacenom <- c((nvar + 1):ncol(subjects))
      datainfo$nominallist <- splitat(datainfo$replacenom, (cumsum(datainfo$combnomlevel) + 1))
    }
  }
  datainfo$mydata <- subjects
  datainfo
}

# Calculate DIFF for a node (any node, parent or daughter).
#
# Args:
#   finaldata: analysis data set; columns 2 to 4 are used and must be named
#              `tr`, `y` and `status`.
#   id:        rows of `finaldata` that belong to the node.
#
# Returns:
#   c(DIFF, n0, n1, mean0, mean1), where n0 / n1 are the numbers of rows with
#   tr == 0 / tr == 1, mean0 / mean1 are the restricted mean values read from
#   rows 1 and 2 of the survfit() summary table, and DIFF is
#   (mean1 - mean0)^2.
#
# Stops with "Insufficient subjects within the node" when either treatment
# group is empty.
dif.f <- function(finaldata, id){
  subdata <- finaldata[id , 2:4]
  n0 <- nrow(subdata[subdata$tr == 0, ])
  n1 <- nrow(subdata[subdata$tr == 1, ])
  if (n0 == 0 || n1 == 0) stop("Insufficient subjects within the node")

  # require(survival)
  fit <- summary(survfit(Surv(y, status) ~ tr, data = subdata))$table
  # NOTE(review): the restricted-mean column is looked up under both "*rmean"
  # and "rmean" because its label is believed to differ between survival
  # releases -- confirm against the installed version.
  rmeancol <- intersect(c("*rmean", "rmean"), colnames(fit))
  if (length(rmeancol) == 0) {
    stop("Restricted mean column not found in the survfit summary table")
  }
  avertime0 <- fit[1, rmeancol[1]]
  avertime1 <- fit[2, rmeancol[1]]

  diff <- (avertime1 - avertime0)^2
  c(diff, n0, n1, avertime0, avertime1)
}

# Calculate DIFF for the two offspring nodes produced by one candidate split.
#
# Args:
#   finaldata: analysis data set; column 1 must be named `id`.
#   id:        rows of `finaldata` that belong to the parent node.
#   cov:       column of `finaldata` that holds the splitting covariate.
#   c:         cut point; rows with covariate <= c go left, rows with
#              covariate > c go right. (The name shadows base c(); the call
#              c(1, cov) below still resolves to the function.)
#   minsize:   minimum number of rows required per treatment group in each
#              offspring node.
#
# Returns a list with the ids and sizes of both offspring nodes, the cut
# point, the DIFF of each offspring node and `diffsplit`, the average of the
# two DIFF values weighted by offspring size.
#
# Stops when dif.f() fails for either offspring node or when any treatment
# group in either node has fewer than `minsize` rows.
diffsplit.f <- function(finaldata, id, cov, c, minsize){
  subdata <- finaldata[id, c(1, cov)]
  idl <- subdata[subdata[ , 2] <= c, ]$id
  idr <- subdata[subdata[ , 2] > c, ]$id
  nidl <- length(idl)
  nidr <- length(idr)
  lediff <- dif.f(finaldata, idl)
  ridiff <- dif.f(finaldata, idr)
  # elements 2 and 3 of a dif.f() result are the per-treatment group sizes
  bigenough <- lediff[2] >= minsize && lediff[3] >= minsize &&
    ridiff[2] >= minsize && ridiff[3] >= minsize
  if (!bigenough) {
    stop ("The sample size per treatment assignment within the node is too small")
  }
  diffidl <- lediff[1]
  diffidr <- ridiff[1]
  diffsplit <- (diffidl*nidl + diffidr*nidr)/(nidl + nidr)
  list(idleft=idl, idright=idr, nidl=nidl, nidr=nidr, splittingpoint=c,
       diffidl=diffidl, diffidr=diffidr, diffsplit=diffsplit)
}

# Evaluate every candidate cut point of one ordered covariate within a node.
#
# Args:
#   finaldata: analysis data set.
#   id:        rows of `finaldata` that belong to the node.
#   ordcov:    column of `finaldata` that holds the ordered covariate.
#   minsize:   passed on to diffsplit.f().
#
# Returns a list with
#   choice:    the sorted distinct covariate values in the node, largest
#              value excluded (splitting at the largest value would leave the
#              right node empty),
#   diffsplit: for each entry of `choice`, the `diffsplit` value returned by
#              diffsplit.f(), or NA when diffsplit.f() signals an error.
#              Always a numeric vector, also when `choice` is empty.
orddiffsplit.f <- function(finaldata, id, ordcov, minsize){
  distinct <- sort(unique(finaldata[id, ordcov]))
  ndistinct <- length(distinct)
  choice <- distinct[-ndistinct]
  diffsplit <- vapply(choice, function(cutpoint) {
    tryCatch(
      diffsplit.f(finaldata, id, ordcov, cutpoint, minsize)$diffsplit,
      error = function(e) NA_real_
    )
  }, numeric(1), USE.NAMES = FALSE)
  list(choice = choice, diffsplit = diffsplit)
}

# Recursively grow the tree below one node and record it in `node` / `nodeinfo`.
#
# For the node given by `id`, every candidate split on the ordered, binary and
# nominal covariates is scored by |diffsplit - DIFF of the node|. If the best
# score is positive the node is split, both daughters are printed and recorded,
# and the function calls itself on each daughter.
#
# Args:
#   node:       numeric matrix with one row per possible node; this function
#               writes columns 2 (existence flag), 3 (layer) and 8:9 (the two
#               mean times returned by dif.f()).
#   nodeinfo:   matrix with one row per possible node; this function writes
#               column 1 (the two treatment-group sizes), column 2 (the two
#               mean times), column 3 (name of the splitting covariate),
#               column 4 (split rule) and column 5 ("nominal" or "0").
#   finaldata:  analysis data set.
#   id:         rows of `finaldata` that belong to the current node.
#   nlayer:     layer of the current node.
#   nnode:      number of the current node; its daughters are numbered
#               2 * nnode and 2 * nnode + 1.
#   npredictor: not used inside this function.
#   orderedcol, binarydcol, nominalcol: columns of the ordered, binary and
#               nominal covariates, or NULL when there are none.
#   nomlist:    columns of the 0/1 split indicators of the nominal covariates.
#   combnom:    number of allowable splits per nominal covariate.
#   minsize:    passed on to diffsplit.f().
#   maxlay:     a daughter is not split further once its layer exceeds maxlay.
#   maxlayer:   when TRUE nothing is done and the inputs are returned as is.
#   printroot:  when TRUE the header and the root line are printed and the
#               root row of `nodeinfo` is filled.
#
# Returns:
#   list(node = , nodeinfo = ) with the updated matrices.
tree.f <- function(node, nodeinfo, finaldata, id, nlayer, nnode, npredictor, orderedcol, binarydcol, nominalcol, nomlist, combnom, minsize, maxlay, maxlayer=FALSE, printroot=TRUE){
  if((!maxlayer))
  {
    nparent <- length(id)
    parentdiff <- dif.f(finaldata, id)
    diffbefore <- parentdiff[1]
    node[nnode, c(2, 3, 8:9)] <- c(1, nlayer, parentdiff[4:5])
    if (printroot) {
      cat("node), split,  n,  c(n1, n2),  c(avertime0,avertime1)",  "\n", sep=" ")
      cat("+ denotes a nominal covariate\n", sep=" ")
      cat("1 ) root", nparent,  deparse(parentdiff[c(2:3)]), deparse(parentdiff[c(4:5)]),  "\n", sep=" ")
      nodeinfo[nnode, 1] <- paste(round(parentdiff[2], 2), ":", round(parentdiff[3], 2))
      nodeinfo[nnode, 2] <- paste(round(parentdiff[4], 2), ":", round(parentdiff[5], 2))
    }
    # Score the candidate splits per covariate type. A failing candidate
    # contributes NA; a type without covariates contributes a single NA.
    if (!is.null(orderedcol)) {
      orderedsplitinfo <- lapply(orderedcol, function(j) tryCatch(orddiffsplit.f(finaldata, id, j, minsize), error = function(e) NA))
      ordereddiffsplit <- unlist(lapply(seq_along(orderedsplitinfo), function(i) tryCatch(orderedsplitinfo[[i]]$diffsplit, error = function(e) NA)))
      ordereddiff <- abs(ordereddiffsplit - diffbefore)
      whichorderedmax <- which.max(ordereddiff)
      orderedmax <- ordereddiff[whichorderedmax]
    } else {orderedmax <- NA}
    if (!is.null(binarydcol)) {
      # binary covariates are split at 0: values <= 0 go left
      binarysplitinfo <- lapply(binarydcol, function(j) tryCatch(diffsplit.f(finaldata, id, j, 0, minsize), error = function(e) NA))
      binarydiffsplit <-  unlist(lapply(seq_along(binarysplitinfo), function(i) tryCatch(binarysplitinfo[[i]]$diffsplit, error = function(e) NA)))
      binarydiff <- abs(binarydiffsplit - diffbefore)
      whichbinarymax <- which.max(binarydiff)
      binarymax <- binarydiff[whichbinarymax]
    } else {binarymax <- NA}
    if (!is.null(nominalcol)) {
      # each 0/1 split-indicator column is treated like a binary covariate
      nomsplitinfo <- lapply(unlist(nomlist), function(j) tryCatch(diffsplit.f(finaldata, id, j, 0, minsize), error = function(e) NA))
      nomdiffsplit <- unlist(lapply(seq_along(nomsplitinfo), function(i) tryCatch(nomsplitinfo[[i]]$diffsplit, error = function(e) NA)))
      nomdiff <- abs(nomdiffsplit - diffbefore)
      whichnommax <- which.max(nomdiff)
      nommax <- nomdiff[whichnommax]
    } else {nommax <- NA}

    if (!all(is.na(c(orderedmax, binarymax, nommax))))
    {
      diffmax <- c(orderedmax, binarymax, nommax)
      whichmax <- which.max(diffmax)
      maxdiffsplit <- diffmax[whichmax]
      if (maxdiffsplit > 0){
        # On ties the ordered covariates win over the binary ones, which win
        # over the nominal ones.
        if (any(maxdiffsplit==orderedmax, na.rm = TRUE)) {
          ordchoice <- lapply(seq_along(orderedsplitinfo), function(i) tryCatch(orderedsplitinfo[[i]]$choice, error = function(e) NA))
          ordcov <-  rep(orderedcol, unlist(lapply(ordchoice, length)))
          if (length(orderedcol)==1)  {maxcov <-  orderedcol} else {maxcov <- ordcov[whichorderedmax]}
          maxgroup <- unlist(ordchoice)[whichorderedmax]
          offspringsplit <- diffsplit.f(finaldata, id, maxcov, maxgroup, minsize)
        } else
          if (any(maxdiffsplit==binarymax, na.rm = TRUE)) {
            if (length(binarydcol)==1)  {maxcov <- binarydcol} else {maxcov <- binarydcol[whichbinarymax]}
            maxgroup <- 0
            offspringsplit <- diffsplit.f(finaldata, id, maxcov, 0, minsize)
          }  else if (any(maxdiffsplit==nommax, na.rm = TRUE)){
            nomgrouplist <- unlist(lapply(combnom, seq))
            maxgroup <- nomgrouplist[whichnommax]
            if (length(nominalcol)==1) {maxcov <- nominalcol; maxnomcov <- nomlist[whichnommax]} else {
              nomcovlist <-  rep(nominalcol, combnom)
              maxcov <- nomcovlist[whichnommax]
              maxnomcov <- nomlist[[which(maxcov==nominalcol)]][maxgroup]
            }
            offspringsplit <- diffsplit.f(finaldata, id, maxnomcov, 0, minsize)
          }
        else stop ("stop")

        idl <- offspringsplit$idleft
        idr <- offspringsplit$idright
        nleft <- offspringsplit$nidl
        nright <- offspringsplit$nidr
        nnodeleft <- 2*nnode
        nnoderight <- nnodeleft + 1
        leftdiff <- dif.f(finaldata, idl)
        rightdiff <- dif.f(finaldata, idr)

        if (any(maxdiffsplit==nommax, na.rm = TRUE)) {
          # pair each level of the nominal covariate with its 0/1 label in
          # the chosen split: label 0 goes left, label 1 goes right
          nomprint <- cbind(sort(unique(finaldata[ , maxcov])), nomcomb.f (finaldata, maxcov)[maxgroup, ])
          leftnomprint <-  nomprint[nomprint[,2]==0, 1]
          rightnomprint <- nomprint[nomprint[,2]==1, 1]
          stringleftnomprint <- paste(paste(leftnomprint,";", sep=""), collapse = " ")
          stringrightnomprint <- paste(paste(rightnomprint,";", sep=""), collapse = " ")

          cat(nnodeleft,") ", "+", names(finaldata[maxcov]),  "={", stringleftnomprint,"}",  nleft, deparse(leftdiff[c(2:3)]),  deparse(leftdiff[c(4:5)]),  "\n", sep=" ")
          cat(nnoderight, ") ", "+", names(finaldata[maxcov]),  "={", stringrightnomprint,"}",   nright, deparse(rightdiff[c(2:3)]), deparse(rightdiff[c(4:5)]),  "\n", sep=" ")
          nodeinfo[nnodeleft, 3] <- paste(names(finaldata[maxcov]))
          nodeinfo[nnodeleft, 4] <- paste("= {", substr(stringleftnomprint, 1, nchar(stringleftnomprint) - 1), "}")
          nodeinfo[nnoderight, 3] <- paste(names(finaldata[maxcov]))
          nodeinfo[nnoderight, 4] <- paste("= {", substr(stringrightnomprint, 1, nchar(stringrightnomprint) - 1), "}")
          nodeinfo[c(nnodeleft, nnoderight), 5] <- paste("nominal")
        } else {
          cat(nnodeleft,") ", names(finaldata[maxcov]),  "<=", maxgroup, nleft, deparse(leftdiff[c(2:3)]),  deparse(leftdiff[c(4:5)]), "\n", sep=" ")
          cat(nnoderight, ") ", names(finaldata[maxcov]),  ">", maxgroup,  nright, deparse(rightdiff[c(2:3)]), deparse(rightdiff[c(4:5)]), "\n", sep=" ")
          nodeinfo[nnodeleft, 3] <- paste(names(finaldata[maxcov]))
          nodeinfo[nnodeleft, 4] <- paste("<=", maxgroup)
          nodeinfo[nnoderight, 3] <- paste(names(finaldata[maxcov]))
          nodeinfo[nnoderight, 4] <- paste(">", maxgroup)
          nodeinfo[c(nnodeleft, nnoderight), 5] <- "0"
        }
        node[c(nnodeleft, nnoderight), 2] <- 1
        node[c(nnodeleft, nnoderight), 3] <- nlayer + 1
        # Each daughter row stores the daughter's own group sizes and mean
        # times, matching the line printed for it above.
        nodeinfo[nnodeleft, 1] <- paste(round(leftdiff[2], 2), ":", round(leftdiff[3], 2))
        nodeinfo[nnodeleft, 2] <- paste(round(leftdiff[4], 2), ":", round(leftdiff[5], 2))

        nodeinfo[nnoderight, 1] <- paste(round(rightdiff[2], 2), ":", round(rightdiff[3], 2))
        nodeinfo[nnoderight, 2] <- paste(round(rightdiff[4], 2), ":", round(rightdiff[5], 2))

        # grow the left subtree first, then hand its updated matrices on to
        # the right subtree
        left <- tree.f(node, nodeinfo, finaldata, idl, (nlayer + 1), nnodeleft, npredictor, orderedcol, binarydcol, nominalcol, nomlist, combnom, minsize, maxlay, ((nlayer + 1) > maxlay), printroot=FALSE)
        node <- left$node
        nodeinfo <- left$nodeinfo
        right <- tree.f(node, nodeinfo, finaldata, idr, (nlayer + 1), nnoderight, npredictor, orderedcol, binarydcol, nominalcol, nomlist, combnom, minsize, maxlay, ((nlayer + 1) > maxlay), printroot=FALSE)
        node <- right$node
        nodeinfo <- right$nodeinfo
      }
    }
  }
  return(list(node=node, nodeinfo=nodeinfo))
}


# Fit the tree to a data set and return the result as an object of class
# "rpst".
#
# Args:
#   data:     data set to analyse; when it is missing or NULL the data are
#             read from `datapath` with read.table(header = TRUE).
#   datapath: file to read when `data` is not supplied.
#   maxlay:   passed on to tree.f(); the node tables are sized for
#             2^(maxlay + 2) - 1 nodes.
#   minsize:  passed on to tree.f().
#
# Returns a list of class "rpst" with
#   data, maxlay, minsize: the inputs,
#   node:         numeric matrix with one row per possible node (see the
#                 column names set below),
#   nodeinfo:     matrix of per-node labels filled by tree.f(),
#   existingnode: the rows of `node` whose existence flag is 1.
rpst <- function(data, 
                 datapath,
                 maxlay = 12,
                 minsize = 5)
  {
  # `data` has no default, so it has to be tested with missing() before it may
  # be evaluated; this lets callers pass `datapath` alone.
  if (missing(data) || is.null(data)) {
    data <- read.table(datapath, header = TRUE)
  }
  newdf <- new.df(data)
  maxnodenum <- 2^(maxlay+2) - 1
  node <- matrix(0, nrow = maxnodenum, ncol = 9)
  nodenum <- 1:maxnodenum
  # Static layout of a complete binary tree: node k has parent floor(k / 2),
  # daughters 2k and 2k + 1 (0 when beyond the table), and is paired with
  # k + 1 (k even) or k - 1 (k odd).
  node[, 1] <- nodenum
  node[, 4] <- floor(nodenum/2)
  node[, 5] <- ifelse(nodenum%%2==0, nodenum+1, nodenum-1)
  node[, 6] <- ifelse(nodenum*2 > maxnodenum, 0, nodenum*2)
  node[, 7] <- ifelse( (nodenum*2+1) > maxnodenum, 0, (nodenum*2+1))
  colnames(node) <- c("Node number", "Node existence", "Node layer", "Upper node number", "Paired node number", "Left daughter node number", "Right daughter node number","avertime0","avertime1")
  nodeinfo <- matrix(0, nrow = maxnodenum, ncol = 5)
  # NOTE(review): tree.f() stores the two mean times in the column labelled
  # "p-value" -- confirm whether the label is still what downstream code
  # expects.
  colnames(nodeinfo) <- c("n1 : n2", "p-value", "Split 1", "Split 2", "Nominal covariate or not")
  existingnode <- tree.f(node, nodeinfo, finaldata=newdf$mydata, id=newdf$mydata$id, 1, 1, npredictor=newdf$nvar, orderedcol=newdf$orderedcol, binarydcol=newdf$binarycol, nominalcol=newdf$nominalcol, 
                         nomlist=newdf$nominallist, combnom=newdf$combnomlevel, minsize, maxlay, maxlayer=FALSE, printroot=TRUE)
  node <- existingnode$node
  nodeinfo <- existingnode$nodeinfo
  # NOTE(review): when only one node exists this subset drops to a plain
  # vector instead of a one-row matrix -- check what the consumers expect.
  existingnode <- existingnode$node[existingnode$node[, 2]==1, ]
  # plottree(node, nodeinfo, existingnode, xlength=1, ylength=3, xshift=0.5 , ysegment=6)
  out <- list(data = data, maxlay = maxlay, minsize = minsize,
              node=node, nodeinfo=nodeinfo, existingnode=existingnode)
  class(out) <- "rpst"
  return(out)
}

