suppressPackageStartupMessages({
library(tidyverse)
library(tidymodels)
library(survival)
library(cowplot)
library(future)
library(censored)
})
theme_set(theme_cowplot())
options(repr.plot.width = 15, repr.plot.height = 9)
set.seed(42)
# the multi-model fitting process will use multiple cores
plan(multicore, workers = 6)

Survival prediction models, part 1#

When it comes to predicting survival, we can’t just do a simple regression, because the dataset often includes censored data points.

The most common model used in these cases is the Cox proportional hazards model.

There are other options, but when considering them it is always good to check whether they out-perform the Cox model on your data.

If you are not familiar with survival analysis, it’s better to start here: http://www.sthda.com/english/wiki/survival-analysis-basics
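As a minimal sketch of the classical approach (using the lung dataset that ships with the survival package, purely for illustration): Surv() bundles the follow-up time with an event indicator, and coxph() fits a Cox proportional hazards model on it.

# minimal sketch: a classical Cox proportional hazards fit on the
# survival package's built-in lung dataset (status == 2 means death)
library(survival)
coxph(Surv(time, status == 2) ~ age + sex, data = lung)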

Dataset preprocessing#

To illustrate, we will use a dataset from the survival package; check ?survival::pbc for more info.

head(pbc)
A data.frame: 6 × 20
  id time status trt              age sex ascites hepato spiders edema bili chol albumin copper alk.phos    ast trig platelet protime stage
  <int> <int> <int> <int> <dbl> <fct> <int> <int> <int> <dbl> <dbl> <int> <dbl> <int> <dbl> <dbl> <int> <int> <dbl> <int>
1  1  400      2   1 58.7652292950034   f       1      1       1   1.0 14.5  261    2.60    156   1718.0 137.95  172      190    12.2     4
2  2 4500      0   1 56.4462696783025   f       0      1       1   0.0  1.1  302    4.14     54   7394.8 113.52   88      221    10.6     3
3  3 1012      2   1 70.0725530458590   m       0      0       0   0.5  1.4  176    3.48    210    516.0  96.10   55      151    12.0     4
4  4 1925      2   1 54.7405886379192   f       0      1       1   0.5  1.8  244    2.54     64   6121.8  60.63   92      183    10.3     4
5  5 1504      1   2 38.1054072553046   f       0      1       1   0.0  3.4  279    3.53    143    671.0 113.15   72      136    10.9     3
6  6 2503      2   2 66.2587268993840   f       0      1       0   0.0  0.8  248    3.98     50    944.0  93.00   63       NA    11.0     3
data <- 
    # remove rows with missing data
    na.omit(pbc) |> 
    # the surv variable will be our outcome
    mutate(surv=Surv(time, status==2)) 
dim(data)
  1. 276
  2. 21
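In pbc, status is 0/1/2 for censored, transplant, and dead (see ?survival::pbc), so the Surv() call above treats death as the event and everything else as censored. A quick inspection sketch:

# quick sketch: only status == 2 counts as an event;
# censored follow-up times print with a trailing "+"
head(data$surv)
table(pbc$status)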
# separate 20% of the dataset for testing the models later
data_split <- initial_split(data, prop = 0.8, strata=time)
# create a recipe, to prepare the data for fitting
rec <- recipe(surv ~ ., data=data) |>
    # remove unneeded variables
    step_rm(time, status, id) |> 
    # remove variables with zero variance
    step_zv(all_predictors()) |> 
    # transform factors into integers
    step_integer(all_nominal_predictors()) 
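To see what the recipe actually produces, it can be prepped and baked; a quick inspection sketch (glimpse() comes from dplyr):

# sketch: estimate the recipe on the training data and inspect the
# processed columns that will be passed to the models
rec |>
    prep(training = training(data_split)) |>
    bake(new_data = NULL) |>
    glimpse()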

Model definitions#

These are the models we will be testing; they are provided by the censored package.

Note that some models, such as the random forests, have hyper-parameters.

These parameters are tunable, and we will tune some of them to illustrate the process.

models <- list(
    cox_ph_survival = proportional_hazards(engine='survival'),
    cox_ph_glmnet = proportional_hazards(penalty=tune(), mixture=tune(), engine='glmnet'),
    survreg_flexsurv = survival_reg(engine='flexsurv'),
    rand_forest_partykit = rand_forest(trees = tune(), engine='partykit'),
    rand_forest_aorsf = rand_forest(trees = tune(), engine='aorsf'),
    decision_tree_partykit = decision_tree(engine='partykit'),
    boost_tree_mboost = boost_tree(trees = tune(), engine='mboost')
) |>
map(~set_mode(.x,'censored regression'))
wsets <- workflow_set(
    preproc=list(rec),
    models=models
)
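Before tuning, we can list which hyper-parameters are marked with tune() in any workflow of the set, for example the glmnet Cox model; a quick sketch using the tidymodels extractors:

# sketch: show the tunable parameters (penalty and mixture) of the
# glmnet-based Cox workflow in the set
extract_workflow(wsets, id = 'recipe_cox_ph_glmnet') |>
    extract_parameter_set_dials()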

Model fitting#

Now we fit all the models with our training data.

To evaluate which model is best, we will use cross-validation resampling, so each model will be fitted multiple times.

If a model has tunable parameters, the cross-validation fits are repeated for each candidate parameter set.

The number of candidates tested is controlled by the grid argument below.

# 5-fold cross-validation resampling for the tuning and comparisons
folds <- vfold_cv(training(data_split), v=5, strata=time)
res <- workflow_map(
    wsets,
    "tune_grid",
    seed = 42,
    grid = 10,
    resamples = folds,
    # this metric will be calculated for each fit
    metrics=metric_set(concordance_survival)
)
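workflowsets also ships convenience helpers for this result object; as a quick sketch, rank_results() and autoplot() summarise the same comparison we will build by hand below:

# sketch: rank the workflows by concordance, keeping only the best
# parameter combination per workflow, and plot the comparison
rank_results(res, rank_metric = 'concordance_survival', select_best = TRUE)
autoplot(res, metric = 'concordance_survival', select_best = TRUE)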

Metrics#

During the fit, the concordance index was calculated for each fold and parameter combination.

Now we can gather these results and build a table:

# collect metrics from the result object
m <- collect_metrics(res)
# add a short label and drop a few unrelated columns
mb <- mutate(m,
    label = gsub('recipe_', '', wflow_id),
    preproc=NULL,
    .estimator=NULL,
    n=NULL,
    .metric=NULL
) |>
# select the best parameter set for each model, according to the metric
slice_max(mean, by=label, n=1, with_ties=FALSE)
mb
A tibble: 7 × 6
wflow_id                       .config                model                 mean               std_err             label
<chr>                          <chr>                  <chr>                 <dbl>              <dbl>               <chr>
recipe_cox_ph_survival         Preprocessor1_Model1   proportional_hazards  0.809202426920204  0.0246510956539134  cox_ph_survival
recipe_cox_ph_glmnet           Preprocessor1_Model02  proportional_hazards  0.827374478729870  0.0260384841178290  cox_ph_glmnet
recipe_survreg_flexsurv        Preprocessor1_Model1   survival_reg          0.812798987394412  0.0254010738957239  survreg_flexsurv
recipe_rand_forest_partykit    Preprocessor1_Model03  rand_forest           0.832411219728944  0.0195804127464225  rand_forest_partykit
recipe_rand_forest_aorsf       Preprocessor1_Model08  rand_forest           0.848819770177259  0.0249418414304294  rand_forest_aorsf
recipe_decision_tree_partykit  Preprocessor1_Model1   decision_tree         0.725301941041225  0.0341123982952786  decision_tree_partykit
recipe_boost_tree_mboost       Preprocessor1_Model08  boost_tree            0.820025751587363  0.0252241065143019  boost_tree_mboost

Note that since we are using cross-validation, each model was fitted multiple times, so we have a mean and std_err for the metric.

We can also take the best configuration of each model, refit it on the full 80% training set, and evaluate it on the 20% test data we separated earlier:

final_fits <- 
    # map across each model in the table above
    set_names(mb$wflow_id) |>
    map(function(wi) {
        # select the best model config, in case of tunable parameters
        bwr <- extract_workflow_set_result(res, wi)
        best_params <- select_best(bwr, metric='concordance_survival')
        bw <- finalize_workflow(
            extract_workflow(res, wi), 
            best_params
        )
    
        # re-fit the best config with the initial 80% train / 20% test
        last_fit(
            bw,
            split = data_split,
            metrics = metric_set(concordance_survival)
        )
    })
test_m <-
    # collect the metrics of the final model
    map(final_fits, collect_metrics) |>
    bind_rows(.id='wflow_id') |>
    select(wflow_id, test_metric=.estimate)

test_m
A tibble: 7 × 2
wflow_id                       test_metric
<chr>                          <dbl>
recipe_cox_ph_survival         0.870346598202824
recipe_cox_ph_glmnet           0.884467265725289
recipe_survreg_flexsurv        0.871630295250321
recipe_rand_forest_partykit    0.889602053915276
recipe_rand_forest_aorsf       0.889602053915276
recipe_decision_tree_partykit  0.819640564826701
recipe_boost_tree_mboost       0.875481386392811

To make it easier to visualize all these results, we can make a plot.

The cross-validation mean and std_err are plotted in black as a point and error bar,

and for comparison, a dark red cross indicates the score of the final fitted model on the test set.

inner_join(mb, test_m, by='wflow_id') |>
ggplot(aes(y=fct_reorder(label,mean), x=mean)) +
geom_point(size=4) +
geom_point(aes(x=test_metric), shape=3, size=3, stroke=3, color='darkred') +
geom_pointrange(aes(xmin=mean-std_err, xmax=mean+std_err)) +
labs(x='concordance index', y='models')

Needless to say, you shouldn’t assume the best model here will also be the best on other datasets; always test on your own data.

Variable importance#

Once we decide which model we want to use, we can extract the model object and do further analyses,

such as ranking the dataset variables by the importance the model assigned to them.

names(final_fits)
  1. 'recipe_cox_ph_survival'
  2. 'recipe_cox_ph_glmnet'
  3. 'recipe_survreg_flexsurv'
  4. 'recipe_rand_forest_partykit'
  5. 'recipe_rand_forest_aorsf'
  6. 'recipe_decision_tree_partykit'
  7. 'recipe_boost_tree_mboost'
aorsf_obj <-
    final_fits[['recipe_rand_forest_aorsf']] |>
    extract_fit_engine()
aorsf_obj
---------- Oblique random survival forest

     Linear combinations: Accelerated Cox regression
          N observations: 220
                N events: 89
                 N trees: 84
      N predictors total: 17
   N predictors per node: 5
 Average leaves per tree: 17.3928571428571
Min observations in leaf: 5
      Min events in leaf: 1
          OOB stat value: 0.83
           OOB stat type: Harrell's C-index
     Variable importance: anova

-----------------------------------------
aorsf_obj$get_importance_clean() |>
enframe() |>
ggplot(aes(y=fct_reorder(name, value), x=value)) +
geom_col() +
labs(x='aorsf model variable importance', y='variable')
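The fitted workflows can also be used directly for prediction. This is a rough sketch assuming the current censored/parsnip interface (supported prediction types and the eval_time argument vary between engines and versions, so the plain Cox workflow is used here):

# sketch: pull the fitted workflow out of a last_fit result and predict
# on the held-out test set; type='time' gives expected survival time and
# type='survival' gives survival probabilities at the requested eval_time
cox_wf <- extract_workflow(final_fits[['recipe_cox_ph_survival']])
predict(cox_wf, new_data = testing(data_split), type = 'time')
predict(cox_wf, new_data = testing(data_split), type = 'survival', eval_time = c(365, 730, 1095))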