Validate high-dimensional Cox models with time-dependent AUC
Usage
validate(
x,
time,
event,
model.type = c("lasso", "alasso", "flasso", "enet", "aenet", "mcp", "mnet", "scad",
"snet"),
alpha,
lambda,
pen.factor = NULL,
gamma,
lambda1,
lambda2,
method = c("bootstrap", "cv", "repeated.cv"),
boot.times = NULL,
nfolds = NULL,
rep.times = NULL,
tauc.type = c("CD", "SZ", "UNO"),
tauc.time,
seed = 1001,
trace = TRUE
)
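For the lasso-type and elastic-net-type fits (the print output in the examples reports these as glmnet models), alpha and lambda presumably enter through glmnet's standard elastic-net penalty. In that penalty, pen.factor supplies the per-coefficient weights $w_j$, and $w_j = 1$ for all $j$ when no penalty factor is given. The formula below is shown for reference only. MCP, SCAD, Mnet, Snet, and the fused lasso use different penalties that it does not cover.

```latex
P_{\lambda,\alpha}(\beta) = \lambda \sum_{j=1}^{p} w_j \left[ \frac{1-\alpha}{2}\,\beta_j^2 + \alpha\,\lvert\beta_j\rvert \right]
% alpha = 1: lambda * sum_j w_j |beta_j|           (lasso, l1)
% alpha = 0: lambda * sum_j w_j beta_j^2 / 2       (ridge, l2)
```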
Arguments
- x
Matrix of training data used to fit the model, on which to run the validation.
- time
Survival time. Must have the same length as the number of rows of x.
- event
Status indicator, normally 0 = alive, 1 = dead. Must have the same length as the number of rows of x.
- model.type
Model type to validate. Could be one of "lasso", "alasso", "flasso", "enet", "aenet", "mcp", "mnet", "scad", or "snet".
- alpha
Value of the elastic-net mixing parameter alpha for enet, aenet, mnet, and snet models. For lasso, alasso, mcp, and scad models, please set alpha = 1. alpha = 1 gives the lasso (l1) penalty; alpha = 0 gives the ridge (l2) penalty. Note that for mnet and snet models, alpha can be set very close to 0 but not exactly 0.
- lambda
Value of the penalty parameter lambda to use in the model fits on the resampled data, taken from the fitted Cox model.
- pen.factor
Penalty factors to apply to each coefficient, taken from the fitted adaptive lasso or adaptive elastic-net model.
- gamma
Value of the model parameter gamma for MCP, SCAD, Mnet, and Snet models.
- lambda1
Value of the penalty parameter lambda1 for the fused lasso model.
- lambda2
Value of the penalty parameter lambda2 for the fused lasso model.
- method
Validation method. Could be "bootstrap", "cv", or "repeated.cv".
- boot.times
Number of bootstrap repetitions.
- nfolds
Number of folds for cross-validation and repeated cross-validation.
- rep.times
Number of repetitions for repeated cross-validation.
- tauc.type
Type of time-dependent AUC: "CD", proposed by Chambless and Diao (2006); "SZ", proposed by Song and Zhou (2008); or "UNO", proposed by Uno et al. (2007).
- tauc.time
Numeric vector. Time points at which to evaluate the time-dependent AUC.
- seed
A random seed for resampling.
- trace
Logical. Whether to print the validation progress. Default is TRUE.
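The method, boot.times, nfolds, rep.times, and seed arguments together determine how the data are resampled. The sketch below builds train/test row indices for each of the three schemes using only base R. The helper make_resamples is hypothetical and is not hdnom's internal code. Evaluating each bootstrap fit on the full original sample is an assumption of this sketch.

```r
# Hypothetical helper, NOT part of hdnom: builds train/test row indices
# for the three resampling schemes accepted by the `method` argument.
make_resamples <- function(n, method = c("bootstrap", "cv", "repeated.cv"),
                           boot.times = NULL, nfolds = NULL,
                           rep.times = NULL, seed = 1001) {
  method <- match.arg(method)
  set.seed(seed)  # same seed, same resamples
  if (method == "bootstrap") {
    # Each replicate trains on n rows drawn with replacement; evaluating on
    # the full original sample is an assumption of this sketch.
    return(lapply(seq_len(boot.times), function(b) {
      list(train = sample.int(n, n, replace = TRUE), test = seq_len(n))
    }))
  }
  rounds <- if (method == "cv") 1L else rep.times
  out <- list()
  for (r in seq_len(rounds)) {
    # Reshuffle fold labels in every round so that repeats differ
    fold_id <- sample(rep_len(seq_len(nfolds), n))
    for (k in seq_len(nfolds)) {
      out[[length(out) + 1L]] <- list(
        round = r, fold = k,
        train = which(fold_id != k), test = which(fold_id == k)
      )
    }
  }
  out
}

# Usage mirroring the examples: 500 rows, 3 bootstraps, 5 folds, 3 repeats
bs  <- make_resamples(500, "bootstrap", boot.times = 3, seed = 1010)
cv  <- make_resamples(500, "cv", nfolds = 5, seed = 1010)
rcv <- make_resamples(500, "repeated.cv", nfolds = 5, rep.times = 3, seed = 1010)
length(rcv)  # 15 train/test splits: 3 rounds x 5 folds
```

Within each cross-validation round, the test folds partition the rows, so every observation is scored exactly once per round. Bootstrap replicates instead overlap heavily with the evaluation data.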
References
Chambless, L. E. and G. Diao (2006). Estimation of time-dependent area under the ROC curve for long-term risk prediction. Statistics in Medicine 25, 3474–3486.
Song, X. and X.-H. Zhou (2008). A semiparametric approach for the covariate specific ROC curve with survival outcome. Statistica Sinica 18, 947–965.
Uno, H., T. Cai, L. Tian, and L. J. Wei (2007). Evaluating prediction rules for t-year survivors with censored regression models. Journal of the American Statistical Association 102, 527–537.
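The estimators cited above all score how well a risk score separates subjects at each evaluation time in tauc.time. The sketch below is a deliberately simplified, unweighted time-dependent AUC for illustration only. It compares subjects with an observed event by time tau (cases) against subjects still event-free after tau (controls). Subjects censored before tau are dropped rather than reweighted, which is the bias that inverse-probability-of-censoring-weighted estimators such as Uno et al. (2007) are designed to correct. The helper naive_tauc and the toy data are hypothetical.

```r
# Hypothetical helper, NOT hdnom's estimator: unweighted time-dependent AUC
# at a single time point tau, using cases with an observed event by tau and
# controls still event-free after tau.
naive_tauc <- function(risk, time, event, tau) {
  cases    <- which(time <= tau & event == 1)
  controls <- which(time > tau)
  # Subjects censored before tau are simply dropped here; IPCW-type
  # estimators reweight to correct for exactly this.
  if (length(cases) == 0L || length(controls) == 0L) return(NA_real_)
  # All case-control pairs: concordant if the case has the higher risk,
  # ties count one half
  d <- outer(risk[cases], risk[controls], "-")
  mean((d > 0) + 0.5 * (d == 0))
}

# Toy data: a higher score means higher predicted risk
time  <- c(2, 4, 6, 8, 10)
event <- c(1, 1, 0, 1, 0)
risk  <- c(0.9, 0.2, 0.5, 0.3, 0.1)
tauc_time <- c(3, 5, 9)
sapply(tauc_time, function(tau) naive_tauc(risk, time, event, tau))
# returns c(1, 2/3, 1)
```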
Examples
library(hdnom)
data(smart)
x <- as.matrix(smart[, -c(1, 2)])[1:500, ]
time <- smart$TEVENT[1:500]
event <- smart$EVENT[1:500]
y <- survival::Surv(time, event)
fit <- fit_lasso(x, y, nfolds = 5, rule = "lambda.1se", seed = 11)
# Model validation by bootstrap with time-dependent AUC
# In practice boot.times should be 200 or more;
# it is set to 3 here only to keep the example fast.
val.boot <- validate(
x, time, event,
model.type = "lasso",
alpha = 1, lambda = fit$lambda,
method = "bootstrap", boot.times = 3,
tauc.type = "UNO", tauc.time = seq(0.25, 2, 0.25) * 365,
seed = 1010
)
#> Start bootstrap sample 1
#> Start bootstrap sample 2
#> Start bootstrap sample 3
# Model validation by 5-fold cross-validation with time-dependent AUC
val.cv <- validate(
x, time, event,
model.type = "lasso",
alpha = 1, lambda = fit$lambda,
method = "cv", nfolds = 5,
tauc.type = "UNO", tauc.time = seq(0.25, 2, 0.25) * 365,
seed = 1010
)
#> Start fold 1
#> Start fold 2
#> Start fold 3
#> Start fold 4
#> Start fold 5
# Model validation by repeated cross-validation with time-dependent AUC
val.repcv <- validate(
x, time, event,
model.type = "lasso",
alpha = 1, lambda = fit$lambda,
method = "repeated.cv", nfolds = 5, rep.times = 3,
tauc.type = "UNO", tauc.time = seq(0.25, 2, 0.25) * 365,
seed = 1010
)
#> Start repeat round 1 fold 1
#> Start repeat round 1 fold 2
#> Start repeat round 1 fold 3
#> Start repeat round 1 fold 4
#> Start repeat round 1 fold 5
#> Start repeat round 2 fold 1
#> Start repeat round 2 fold 2
#> Start repeat round 2 fold 3
#> Start repeat round 2 fold 4
#> Start repeat round 2 fold 5
#> Start repeat round 3 fold 1
#> Start repeat round 3 fold 2
#> Start repeat round 3 fold 3
#> Start repeat round 3 fold 4
#> Start repeat round 3 fold 5
# Bootstrap-based discrimination curves have a very narrow band
print(val.boot)
#> High-Dimensional Cox Model Validation Object
#> Random seed: 1010
#> Validation method: bootstrap
#> Bootstrap samples: 3
#> Model type: lasso
#> glmnet model alpha: 1
#> glmnet model lambda: 0.0466592
#> glmnet model penalty factor: not specified
#> Time-dependent AUC type: UNO
#> Evaluation time points for tAUC: 91.25 182.5 273.75 365 456.25 547.5 638.75 730
summary(val.boot)
#> Time-Dependent AUC Summary at Evaluation Time Points
#> 91.25 182.5 273.75 365 456.25 547.5 638.75
#> Mean 0.6076768 0.6770503 0.7310150 0.7669546 0.7678797 0.7805710 0.7592774
#> Min 0.5741414 0.6582779 0.7163194 0.7479257 0.7491532 0.7607504 0.7495020
#> 0.25 Qt. 0.5818182 0.6594434 0.7209654 0.7549788 0.7560774 0.7678429 0.7512654
#> Median 0.5894949 0.6606089 0.7256113 0.7620318 0.7630017 0.7749355 0.7530288
#> 0.75 Qt. 0.6244444 0.6864366 0.7383628 0.7764691 0.7772430 0.7904813 0.7641652
#> Max 0.6593939 0.7122642 0.7511143 0.7909064 0.7914843 0.8060271 0.7753015
#> 730
#> Mean 0.7503529
#> Min 0.7465142
#> 0.25 Qt. 0.7485221
#> Median 0.7505299
#> 0.75 Qt. 0.7522723
#> Max 0.7540147
plot(val.boot)
# k-fold CV provides a stricter evaluation than bootstrap
print(val.cv)
#> High-Dimensional Cox Model Validation Object
#> Random seed: 1010
#> Validation method: k-fold cross-validation
#> Cross-validation folds: 5
#> Model type: lasso
#> glmnet model alpha: 1
#> glmnet model lambda: 0.0466592
#> glmnet model penalty factor: not specified
#> Time-dependent AUC type: UNO
#> Evaluation time points for tAUC: 91.25 182.5 273.75 365 456.25 547.5 638.75 730
summary(val.cv)
#> Time-Dependent AUC Summary at Evaluation Time Points
#> 91.25 182.5 273.75 365 456.25 547.5 638.75
#> Mean 0.3594933 0.5725705 0.7359625 0.7898166 0.7913584 0.7995129 0.7911661
#> Min 0.0050000 0.0050000 0.5871406 0.5874762 0.5874762 0.5888388 0.5204137
#> 0.25 Qt. 0.0050000 0.4276249 0.6262626 0.7323377 0.7400465 0.7499517 0.7992726
#> Median 0.3724490 0.7473958 0.7653061 0.8174583 0.8174583 0.8437509 0.8211214
#> 0.75 Qt. 0.5867347 0.7653061 0.7747872 0.8854951 0.8854951 0.8902918 0.8902918
#> Max 0.8282828 0.9175258 0.9263158 0.9263158 0.9263158 0.9247312 0.9247312
#> 730
#> Mean 0.7617612
#> Min 0.5204137
#> 0.25 Qt. 0.7854386
#> Median 0.7926118
#> 0.75 Qt. 0.8200499
#> Max 0.8902918
plot(val.cv)
# Repeated CV gives results similar to k-fold CV
# but is more robust than a single k-fold CV run
print(val.repcv)
#> High-Dimensional Cox Model Validation Object
#> Random seed: 1010
#> Validation method: repeated cross-validation
#> Cross-validation folds: 5
#> Cross-validation repeated times: 3
#> Model type: lasso
#> glmnet model alpha: 1
#> glmnet model lambda: 0.0466592
#> glmnet model penalty factor: not specified
#> Time-dependent AUC type: UNO
#> Evaluation time points for tAUC: 91.25 182.5 273.75 365 456.25 547.5 638.75 730
summary(val.repcv)
#> Note: for repeated CV, we evaluated quantile statistic tables for
#> each CV repeat, then calculated element-wise mean across all tables.
#> Time-Dependent AUC Summary at Evaluation Time Points
#> 91.25 182.5 273.75 365 456.25 547.5
#> Mean of Mean 0.3610550 0.6420535 0.7360920 0.7818087 0.7829022 0.7915698
#> Mean of Min 0.0050000 0.2655556 0.5455362 0.5748210 0.5748210 0.5937272
#> Mean of 0.25 Qt. 0.0050000 0.5541139 0.6359196 0.7380272 0.7416286 0.7373083
#> Mean of Median 0.4021086 0.7194473 0.7891285 0.8121328 0.8139987 0.8129568
#> Mean of 0.75 Qt. 0.6086546 0.7929182 0.8198284 0.8606522 0.8606522 0.8881416
#> Mean of Max 0.7845118 0.8782325 0.8900470 0.9234104 0.9234104 0.9257152
#> 638.75 730
#> Mean of Mean 0.7713189 0.7498721
#> Mean of Min 0.5495460 0.5453652
#> Mean of 0.25 Qt. 0.6638718 0.6775897
#> Mean of Median 0.8185457 0.7557037
#> Mean of 0.75 Qt. 0.8941499 0.8607948
#> Mean of Max 0.9304813 0.9099070
plot(val.repcv)
# # Test fused lasso, SCAD, and Mnet models
#
# data(smart)
# x = as.matrix(smart[, -c(1, 2)])[1:500,]
# time = smart$TEVENT[1:500]
# event = smart$EVENT[1:500]
# y = survival::Surv(time, event)
#
# set.seed(1010)
# val.boot = validate(
# x, time, event, model.type = "flasso",
# lambda1 = 5, lambda2 = 2,
# method = "bootstrap", boot.times = 10,
# tauc.type = "UNO", tauc.time = seq(0.25, 2, 0.25) * 365,
# seed = 1010)
#
# val.cv = validate(
# x, time, event, model.type = "scad",
# gamma = 3.7, alpha = 1, lambda = 0.05,
# method = "cv", nfolds = 5,
# tauc.type = "UNO", tauc.time = seq(0.25, 2, 0.25) * 365,
# seed = 1010)
#
# val.repcv = validate(
# x, time, event, model.type = "mnet",
# gamma = 3, alpha = 0.3, lambda = 0.05,
# method = "repeated.cv", nfolds = 5, rep.times = 3,
# tauc.type = "UNO", tauc.time = seq(0.25, 2, 0.25) * 365,
# seed = 1010)
#
# print(val.boot)
# summary(val.boot)
# plot(val.boot)
#
# print(val.cv)
# summary(val.cv)
# plot(val.cv)
#
# print(val.repcv)
# summary(val.repcv)
# plot(val.repcv)
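The note printed by summary(val.repcv) states that a quantile table is built for each CV repeat and the tables are then averaged element-wise. The base-R sketch below shows that aggregation, with row labels copied from the printed summaries. The helper tauc_summary, the object rep_auc, and the simulated AUC values are hypothetical and are not hdnom code. The sketch uses R's default quantile definition, which may differ from the one hdnom uses. For bootstrap and single k-fold CV, the printed summary presumably corresponds to a single such table.

```r
# Hypothetical helper, NOT hdnom code: summarise a matrix of AUC values
# (rows = resamples or folds, columns = evaluation time points).
tauc_summary <- function(auc) {
  stats <- apply(auc, 2, function(a) {
    # R's default quantile definition (type 7); hdnom's may differ
    c(mean(a), min(a), quantile(a, 0.25, names = FALSE),
      median(a), quantile(a, 0.75, names = FALSE), max(a))
  })
  rownames(stats) <- c("Mean", "Min", "0.25 Qt.", "Median", "0.75 Qt.", "Max")
  stats
}

# Simulated fold-level AUCs for 3 repeats of 5-fold CV at two time points
set.seed(1010)
rep_auc <- lapply(1:3, function(r) {
  matrix(runif(5 * 2, 0.6, 0.9), nrow = 5,
         dimnames = list(NULL, c("365", "730")))
})

# One summary table per repeat, then the element-wise mean across tables
tables <- lapply(rep_auc, tauc_summary)
mean_of_tables <- Reduce(`+`, tables) / length(tables)
rownames(mean_of_tables) <- paste("Mean of", rownames(mean_of_tables))
round(mean_of_tables, 4)
```

Because every repeat has the same number of folds, the "Mean of Mean" row equals the grand mean of all fold-level AUCs. The averaged quantile rows, however, are not quantiles of the pooled values.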