---
title: "Saving and Using Saved Models"
date: "`r format(Sys.Date(), '%Y-%m-%d')`"
output:
  rmarkdown::html_vignette:
    toc: true
    toc_depth: 2
    fig_width: 7
    fig_height: 5
    dpi: 600
vignette: >
  %\VignetteIndexEntry{"Saving and Using Saved Models"}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

To save a CISS-VAE model, use reticulate to call `torch$save()` on the trained model:

```{r, eval=FALSE}
library(rCISSVAE)
library(reticulate)

res <- run_cissvae(data)

# Import the Python torch module
torch <- import("torch")

# `res$model` is the trained model, a Python object held by reticulate
torch$save(res$model, "trained_vae.pt")
```

> **IMPORTANT**
>
> The Python environment must be activated before trying to import the torch module.
{.warning}

To load a saved CISS-VAE model and use it to impute new data (called `data` in the example below, with cluster vector `clusters`):

```{r, eval=FALSE}
library(rCISSVAE)
library(reticulate)

## Activate your virtual environment
reticulate::use_virtualenv("cissvae_environment", required = TRUE)

## Load the model with torch
torch <- import("torch")

# Import the module so the class is registered (required for full-model loading)
import("ciss_vae.classes.vae")

# Load the full model object and switch it to evaluation mode
model <- torch$load("trained_vae.pt", map_location = "cpu", weights_only = FALSE)
model$eval()

# Optional: get imputed dataset
helpers <- import("ciss_vae.utils.helpers")
DataLoader <- import("torch.utils.data")$DataLoader

## Convert your dataset to a Python ClusterDataset object
CD_mod <- reticulate::import("ciss_vae.classes.cluster_dataset", convert = FALSE)$ClusterDataset
np <- reticulate::import("numpy", convert = FALSE)
pd <- reticulate::import("pandas", convert = FALSE)

## Make sure NAs are Python compatible
data[is.na(data)] <- NaN

## Convert data and clusters into Python objects
data_py <- pd$DataFrame(data = data, dtype = "float64")
clusters_py <- np$array(as.integer(clusters), dtype = "int64")

## Make ClusterDataset and DataLoader
dataset <- CD_mod(data = data_py, cluster_labels = clusters_py)
data_loader <- DataLoader(dataset, batch_size = 4000L)

## Get imputed dataset
imputed_df <- helpers$get_imputed_df(model, data_loader)
```
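
The conversion step above assumes that `data` is entirely numeric and that `clusters` holds one label per row. A minimal base R sketch of a check that could be run before handing the objects to Python is shown here. `check_imputation_inputs()` is a hypothetical helper, not a function provided by rCISSVAE, and the toy data frame is made up for illustration.

```{r, eval=FALSE}
# Hypothetical helper, not part of rCISSVAE: validates inputs before
# they are converted to Python objects.
check_imputation_inputs <- function(data, clusters) {
  data <- as.data.frame(data)

  # Every column must be numeric so pandas can build a float64 frame.
  non_numeric <- names(data)[!vapply(data, is.numeric, logical(1))]
  if (length(non_numeric) > 0) {
    stop("Non-numeric columns: ", paste(non_numeric, collapse = ", "))
  }

  # One cluster label per row, with no missing labels.
  if (length(clusters) != nrow(data)) {
    stop("`clusters` must have one entry per row of `data`.")
  }
  if (anyNA(clusters)) {
    stop("`clusters` must not contain missing values.")
  }

  # Store every column as double and mark missing entries as NaN.
  data[] <- lapply(data, as.numeric)
  data[is.na(data)] <- NaN

  list(data = data, clusters = as.integer(clusters))
}

# Toy inputs for illustration only.
toy <- data.frame(a = c(1, NA, 3), b = c(NA, 5L, 6L))
checked <- check_imputation_inputs(toy, clusters = c(0, 0, 1))
str(checked)
```

The returned `checked$data` and `checked$clusters` could then take the place of `data` and `clusters` in the `pd$DataFrame()` and `np$array()` calls. Failing early in R tends to give a clearer error message than a type error raised inside Python.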