library(testthat)
library(survival)
library(dplyr)

# Reference Cox fit shared by every test below: overall survival (AVAL, EVENT)
# regressed on the planned treatment arm, using exact handling of ties.
cox_fit <- coxph(
  formula = Surv(AVAL, EVENT) ~ TRT01P,
  data = codebreak200,
  ties = "exact"
)

test_that("tipping_point_model_based returns correct structure for model-based imputations", {
  tipping_range <- seq(0.1, 1, by = 0.1)

  res <- tipping_point_model_based(
    dat = codebreak200,
    reason = "Early dropout",
    impute = "docetaxel",
    imputation_model = "weibull",
    J = 2,
    tipping_range = tipping_range,
    cox_fit = cox_fit,
    verbose = FALSE,
    seed = 123
  )

  expect_s3_class(res, "tipse")

  # Top-level components expected on the returned "tipse" object.
  expected_components <- c(
    "original_data", "original_HR", "reason_to_impute", "arm_to_impute",
    "method_to_impute", "imputation_results", "imputation_data"
  )
  expect_true(all(expected_components %in% names(res)))

  # Columns expected in imputation_results.
  expected_columns <- c(
    "HR", "HR_upperCI", "HR_lowerCI", "parameter", "tipping_point"
  )
  expect_true(all(expected_columns %in% names(res$imputation_results)))

  # imputation_data should be a list named by the tipping_range values.
  # expect_named() is required because the %in% check alone passes vacuously
  # when names are NULL (all(logical(0)) is TRUE).
  expect_type(res$imputation_data, "list")
  expect_named(res$imputation_data)
  expect_true(all(
    names(res$imputation_data) %in% as.character(tipping_range)
  ))
})

test_that("tipping_point_model_based throws error for invalid cox_fit", {
  # A plain character string is passed where a fitted Cox model is expected;
  # the call should fail with the validation message below.
  expected_msg <- "Argument 'cox_fit' must be a valid cox model object"

  expect_error(
    object = tipping_point_model_based(
      dat = codebreak200,
      reason = "Early dropout",
      impute = "docetaxel",
      imputation_model = "weibull",
      J = 2,
      tipping_range = seq(0.1, 1, by = 0.1),
      cox_fit = "not_a_coxph"
    ),
    regexp = expected_msg
  )
})

test_that("tipping_point_model_based throws error for invalid impute arm", {
  # "invalid_arm" is not one of the treatment arms present in the data, so
  # argument validation should reject it.
  expected_msg <- "Argument 'impute' must be one of the arms"

  expect_error(
    object = tipping_point_model_based(
      dat = codebreak200,
      reason = "Early dropout",
      impute = "invalid_arm",
      imputation_model = "weibull",
      J = 2,
      tipping_range = seq(0.1, 1, by = 0.1),
      cox_fit = cox_fit
    ),
    regexp = expected_msg
  )
})

test_that("tipping_point_model_based throws error when reason is empty", {
  # A zero-length character vector supplies no censoring reason at all.
  expected_msg <- "Argument 'reason' must specify at least one censoring reason"

  expect_error(
    object = tipping_point_model_based(
      dat = codebreak200,
      reason = character(0),
      impute = "sotorasib",
      imputation_model = "weibull",
      J = 2,
      tipping_range = seq(0.1, 1, by = 0.1),
      cox_fit = cox_fit
    ),
    regexp = expected_msg
  )
})

test_that("tipping_point_model_based validates distribution argument", {
  # An unsupported imputation_model should be rejected. The expected text is
  # the message format produced by base R's match.arg(). It uses straight
  # quotes, which presumably relies on fancy quotes being disabled in the test
  # context — NOTE(review): confirm if this ever fails in a UTF-8 locale.
  expected_msg <- "'arg' should be one of \"weibull\", \"exponential\""

  expect_error(
    object = tipping_point_model_based(
      dat = codebreak200,
      reason = "Early dropout",
      impute = "sotorasib",
      imputation_model = "invalid_dist",
      J = 2,
      tipping_range = seq(0.1, 1, by = 0.1),
      cox_fit = cox_fit
    ),
    regexp = expected_msg
  )
})

test_that("tipping_point_model_based sets tipping_point flag correctly", {
  res <- tipping_point_model_based(
    dat = codebreak200,
    reason = "Early dropout",
    impute = "docetaxel",
    imputation_model = "weibull",
    J = 2,
    tipping_range = seq(0.1, 1, by = 0.1),
    cox_fit = cox_fit,
    verbose = FALSE,
    seed = 123
  )

  expect_true("tipping_point" %in% names(res$imputation_results))

  # The flag must be a non-empty logical vector whose every entry is a
  # definite TRUE/FALSE. A bare any(... %in% c(TRUE, FALSE)) is too weak: it
  # passes when only one element qualifies (others could be NA). It also
  # passes for numeric 0/1, which %in% coerces to match logicals.
  flags <- res$imputation_results$tipping_point
  expect_type(flags, "logical")
  expect_gt(length(flags), 0)
  expect_true(all(flags %in% c(TRUE, FALSE)))
})

test_that("tipping_point_model_based respects seed for reproducibility", {
  # Build one fixed configuration (including seed = 123) and run it twice;
  # identical seeds should yield identical imputation_results.
  run_with_seed <- function() {
    tipping_point_model_based(
      dat = codebreak200,
      reason = "Early dropout",
      impute = "sotorasib",
      imputation_model = "weibull",
      J = 2,
      tipping_range = seq(0.1, 1, by = 0.1),
      cox_fit = cox_fit,
      verbose = FALSE,
      seed = 123
    )
  }

  first_run <- run_with_seed()
  second_run <- run_with_seed()

  expect_equal(first_run$imputation_results, second_run$imputation_results)
})

test_that("tipping_point_model_based detects method correctly based on tipping_range", {
  # Runs the model-based analysis for one arm / tipping range and returns
  # only the method label the function reports.
  detected_method <- function(arm, range) {
    fit <- tipping_point_model_based(
      dat = codebreak200,
      reason = "Early dropout",
      impute = arm,
      imputation_model = "weibull",
      J = 2,
      tipping_range = range,
      cox_fit = cox_fit,
      verbose = FALSE,
      seed = 123
    )
    fit$method_to_impute
  }

  # Every value at or below 1 -> expected "hazard deflation".
  expect_equal(
    detected_method("docetaxel", seq(0.05, 1, by = 0.05)),
    "hazard deflation"
  )

  # Every value above 1 -> expected "hazard inflation".
  expect_equal(
    detected_method("sotorasib", c(1.1, 1.2, 1.5)),
    "hazard inflation"
  )
})
