Suppose we have a dataset spanning ten cancers and we want to fit a lasso logistic regression model to predict 10-year survival.
Some cancer classes in our dataset are large (e.g. breast), some are small (e.g. head and neck).
We suppose that some features are predictive across all classes. Combining all of our data gives us the best chance to discover these features.
We also suppose that some features are predictive in just one (or a few) classes. Discovering them is easiest when we fit a separate model for each class.
Pretraining gives us the best of both worlds.
Pretraining fits models in two steps: the first step discovers features that are predictive across all classes, and the second discovers features for individual classes.
Pretraining is a general method to pass information from one model to another, with many more use cases, including datasets with a multinomial target, multi-response data and time series data.
These cases are described in detail and with examples in Part 2 of this series.
Before we describe pretraining in more detail, we will do a quick review of the lasso.
For the Gaussian family with data \((x_i,y_i), i=1,2,\ldots n\), the lasso has the form \[ {\rm argmin}_{\beta_0, \beta} \frac{1}{2} \sum_{i=1}^n\left(y_i- \beta_0 - x_i^T \beta\right)^2 + \lambda \sum_{j=1}^p |\beta_j |. \]
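This is composed of two terms: the squared-error loss, which measures how well the model fits the data, and the \(\ell_1\) penalty, which shrinks the coefficients toward zero and sets some of them exactly to zero. The regularization parameter \(\lambda\) controls the trade-off between the two.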
Although we show a lasso linear regression problem here, the lasso can be applied to the entire GLM family (including logistic and Poisson regression, for example), as well as to Cox’s proportional hazards model.
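To make the objective concrete, the following minimal sketch evaluates the Gaussian lasso objective for a given intercept and coefficient vector. The helper `lasso_objective` and the simulated data are illustrative assumptions, not part of any package. Note that glmnet divides the squared-error term by \(n\), so its values of \(\lambda\) are not on the same scale as the \(\lambda\) in the formula above.

```r
# Gaussian lasso objective, written exactly as in the formula above.
# `lasso_objective` is a hypothetical helper for illustration only.
lasso_objective <- function(x, y, beta0, beta, lambda) {
  resid <- y - beta0 - as.numeric(x %*% beta)
  loss <- 0.5 * sum(resid^2)          # squared-error term
  penalty <- lambda * sum(abs(beta))  # L1 penalty; the intercept is not penalized
  loss + penalty
}

set.seed(1)
n <- 50; p <- 5
x <- matrix(rnorm(n * p), n, p)
beta_true <- c(2, -1, 0, 0, 0)
y <- as.numeric(1 + x %*% beta_true + rnorm(n))

# For a fixed beta, only the penalty term changes with lambda
obj_small <- lasso_objective(x, y, 1, beta_true, lambda = 0.1)
obj_large <- lasso_objective(x, y, 1, beta_true, lambda = 10)
obj_large - obj_small  # (10 - 0.1) * sum(abs(beta_true)) = 29.7
```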
Many software packages implement the lasso; a popular one is glmnet in R.
Varying the regularization parameter \(\lambda \ge 0\) yields a path of solutions; an optimal value \(\hat\lambda\) is usually chosen by cross-validation, for example with the cv.glmnet function from the glmnet package.
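As a minimal sketch of choosing \(\hat\lambda\) by cross-validation, the example below runs cv.glmnet on simulated data. The data-generating setup (three informative features out of twenty) is an assumption made purely for illustration.

```r
library(glmnet)

set.seed(1)
n <- 200; p <- 20
x <- matrix(rnorm(n * p), n, p)
y <- as.numeric(x[, 1:3] %*% c(3, -2, 1) + rnorm(n))

# Fit the lasso path and estimate the CV error at each value of lambda
cvfit <- cv.glmnet(x, y, family = "gaussian", nfolds = 10)

length(cvfit$lambda)  # number of lambda values on the path
cvfit$lambda.min      # lambda with the smallest cross-validated error

# Coefficients at the selected lambda (sparse column, intercept first)
b <- coef(cvfit, s = "lambda.min")
sum(b[-1, 1] != 0)    # number of selected features
```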
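Pretraining relies on two further ingredients of the lasso. The first is an offset: a fixed term \(O_i\) for each observation that enters the model without an estimated coefficient:
\[ {\rm argmin}_{\beta_0, \beta} \frac{1}{2} \sum_{i=1}^n\left(y_i- O_i - \beta_0 - x_i^T \beta\right)^2 + \lambda \sum_{j=1}^p |\beta_j |. \]
The second is a penalty factor: a nonnegative weight \({\rm pf}_j\) for each feature that scales its penalty:
\[ {\rm argmin}_{\beta_0, \beta} \frac{1}{2} \sum_{i=1}^n\left(y_i- O_i - \beta_0 - x_i^T \beta\right)^2 + \lambda \sum_{j=1}^p {\rm pf}_j |\beta_j |. \]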
At the extremes:
- penalty factor = \(0\) \(\rightarrow\) the feature is always included in the model;
- penalty factor = \(+\infty\) \(\rightarrow\) the feature is discarded.
Both options are standard in generalized linear models (and in the
software package glmnet).
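The sketch below illustrates both options with glmnet, using its `offset` and `penalty.factor` arguments. The simulated data and the particular offset are assumptions chosen only to show the behaviour at the two extremes of the penalty factor.

```r
library(glmnet)

set.seed(1)
n <- 200; p <- 10
x <- matrix(rnorm(n * p), n, p)
y <- as.numeric(x[, 1:3] %*% c(3, -2, 1) + rnorm(n))

# Offset: a fixed per-observation term that enters the model with no coefficient.
# This one is an arbitrary made-up "prior prediction", for illustration only.
off <- 0.5 * x[, 1]

# Penalty factors: feature 1 is unpenalized, feature 2 is excluded,
# and the remaining features receive the usual penalty.
pf <- c(0, Inf, rep(1, p - 2))

fit <- glmnet(x, y, offset = off, penalty.factor = pf)

b <- as.numeric(coef(fit, s = 0.5))  # intercept first, then the p features
b[2] != 0  # feature 1 (penalty factor 0) stays in the model
b[3] == 0  # feature 2 (infinite penalty factor) is never used
```

When a model was fit with an offset, predicting on new data with glmnet also requires supplying the offset for the new observations through `newoffset`.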
For the input grouped setting, pretraining fits models in two steps.
Step 1: Fit a lasso penalized model using the full dataset \(X, y\) to get coefficients \(\hat{\beta}\).
Step 2: Choose the hyperparameter \(\alpha \in [0, 1]\), either by trying a few values or by cross-validation (details later), then fit a model for each group \(X_k, y_k\) using:
- Offset: \((1 - \alpha) X_k \hat{\beta}\)
- Penalty factor: \(1\) (features selected by the overall model) or \(\frac{1}{\alpha}\) (other features)
Note that when \(\alpha = 0\), this fits a fine-tuned version of the overall model for each group: the pretrained model can only use features that were selected by the overall model, and it must use the full prediction from the overall model as an offset.
At the other extreme when \(\alpha = 1\), this approach completely ignores the overall model and fits a separate model for each group.
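The two steps can be sketched directly with cv.glmnet. This is an illustration under stated assumptions, not the ptLasso implementation: the simulated data, the name `pt_alpha` for the pretraining parameter, and the decision to include the overall intercept in the offset are all choices made here for the example.

```r
library(glmnet)

set.seed(1)
n <- 300; p <- 20
pt_alpha <- 0.5  # pretraining parameter (not glmnet's elastic net alpha)
x <- matrix(rnorm(n * p), n, p)
groups <- rep(1:3, each = n / 3)
# Features 1-3 matter in every group; feature 3 + g matters only in group g
y <- as.numeric(x[, 1:3] %*% c(3, -2, 1)) +
  2 * x[cbind(1:n, 3 + groups)] + rnorm(n)

# Step 1: overall lasso on all of the data
overall <- cv.glmnet(x, y)
b <- as.numeric(coef(overall, s = "lambda.min"))  # intercept first
selected <- b[-1] != 0

# Step 2: one model per group, with offset and penalty factor from step 1
pretrained <- lapply(sort(unique(groups)), function(g) {
  idx <- groups == g
  # Offset: (1 - alpha) times the overall prediction.
  # Including the intercept b[1] here is an assumption of this sketch.
  off <- (1 - pt_alpha) * (b[1] + as.numeric(x[idx, ] %*% b[-1]))
  # Penalty factor: 1 for features chosen in step 1, 1 / alpha otherwise
  pf <- ifelse(selected, 1, 1 / pt_alpha)
  cv.glmnet(x[idx, ], y[idx], offset = off, penalty.factor = pf)
})

length(pretrained)  # one fitted cv.glmnet object per group
```

Setting `pt_alpha` to 0 makes `1 / pt_alpha` infinite, which glmnet treats as excluding the feature; this matches the description of the \(\alpha = 0\) extreme.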
ptLasso is a package that fits pretrained models using the glmnet package, including lasso, elastic net and ridge models.
All model fitting in ptLasso is done with
cv.glmnet.
The first step of pretraining is a straightforward call to
cv.glmnet; the second step is done by calling
cv.glmnet with an offset and penalty factor.
Additionally, one call to ptLasso fits an overall model, a pretrained model for each group, and an individual model for each group (without pretraining).
The ptLasso package also includes methods for
prediction and plotting, and a function that performs K-fold
cross-validation.
A typical modeling pipeline looks like this:
# Data: features X, response y and group identifier "groups"
# Fit a model:
fit = ptLasso(X, y, groups)
# ...optionally using cross-validation for parameter selection:
cvfit = cv.ptLasso(X, y, groups)
# Visualize the model:
plot(fit)
plot(cvfit)
# And make predictions:
predict(fit, Xtest, groupstest)
predict(cvfit, Xtest, groupstest)
# Inspect the coefficients
coef(fit)
coef(cvfit)
We’ll walk through these steps using ptLasso with a
simple simulated dataset.
First, we load the ptLasso package:
require(ptLasso)
Now, we simulate data with \(3\)
groups and a continuous response using the helper function
gaussian.example.data:
set.seed(1234)
out = gaussian.example.data(k = 3)
x = out$x; y = out$y; groups = out$groups
outtest = gaussian.example.data(k = 3)
xtest = outtest$x; ytest = outtest$y; groupstest = outtest$groups
Sanity check:
dim(x)
## [1] 600 80
head(y)
## [1] -7.593043389 10.842131973 -9.845550692 -36.527392301 0.003346178
## [6] 18.835316767
We are ready to fit a model using ptLasso. We’ll choose the pretraining parameter \(\alpha = 0.5\) (an arbitrary choice).
Here, we’ll go through three steps: fitting a model, predicting with a validation set, and extracting the model coefficients.
fit <- ptLasso(x, y, groups, alpha = 0.5)
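The function ptLasso used cv.glmnet to fit \(7\) models: one overall model (using all three groups), three pretrained models (one per group) and three individual models (one per group).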
A call to plot displays the cross validation curves for
each model. The top row shows the overall model, the middle row the
pretrained models, and the bottom row the individual models.
plot(fit)
predict makes predictions from all \(7\) models. It returns a list containing:
- yhatoverall (predictions from the overall model),
- yhatpre (predictions from the pretrained models), and
- yhatind (predictions from the individual models).
By default, predict uses lambda.min, the value of the regularization parameter \(\lambda\) that minimized the cross-validated mean squared error, for all \(7\) cv.glmnet models.
preds = predict(fit, xtest, groupstest=groupstest)
If you also provide ytest (for model validation),
predict will additionally compute a measure of
performance.
preds = predict(fit, xtest, groupstest=groupstest, ytest=ytest)
preds
##
## Call:
## predict.ptLasso(object = fit, xtest = xtest, groupstest = groupstest,
## ytest = ytest)
##
##
## alpha = 0.5
##
## Performance (Mean squared error):
##
##            allGroups  mean group_1 group_2 group_3    r^2
## Overall        566.1 566.1   492.7   574.7   630.7 0.3591
## Pretrain       551.5 551.5   517.5   595.0   542.1 0.3756
## Individual     575.7 575.7   528.5   596.7   602.0 0.3482
##
## Support size:
##
## Overall 23
## Pretrain 60 (13 common + 47 individual)
## Individual 76
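The list returned by predict can also be used to compute performance by hand. The sketch below refits the model on freshly simulated data and computes the test mean squared error from yhatoverall, yhatpre and yhatind, overall and within each group. The helpers `mse` and `group_mse` are illustrative, and the seed differs in effect from the run above, so the numbers are not expected to match the printed table exactly.

```r
library(ptLasso)

set.seed(1234)
out <- gaussian.example.data(k = 3)
outtest <- gaussian.example.data(k = 3)

fit <- ptLasso(out$x, out$y, out$groups, alpha = 0.5)
preds <- predict(fit, outtest$x, groupstest = outtest$groups)

# Test mean squared error across all observations (hypothetical helper)
mse <- function(yhat) mean((outtest$y - as.numeric(yhat))^2)

# Test mean squared error within each group (hypothetical helper)
group_mse <- function(yhat) {
  tapply((outtest$y - as.numeric(yhat))^2, outtest$groups, mean)
}

c(overall = mse(preds$yhatoverall),
  pretrain = mse(preds$yhatpre),
  individual = mse(preds$yhatind))
group_mse(preds$yhatpre)
```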
To look at the coefficients of our models, we can use the
coef function:
coefs = coef(fit)
The variable coefs is a list with the coefficients for
the overall, individual and pretrained models:
names(coefs)
## [1] "individual" "pretrain" "overall"
And coefs$pretrain is a list of length 3, containing the
coefficients for groups 1, 2 and 3:
length(coefs$pretrain)
## [1] 3
The coefficients themselves are returned as a single-column matrix,
as in glmnet:
head(coefs$pretrain[[3]])
## 6 x 1 sparse Matrix of class "dgCMatrix"
## s1
## (Intercept) -0.3221821
## V1 5.0591788
## V2 4.9309031
## V3 3.7828370
## V4 5.1038363
## V5 5.6944842
In our previous example, we chose the parameter \(\alpha\) arbitrarily. In practice we recommend a more deliberate choice: either compare performance for a few values of \(\alpha\), or use cv.ptLasso, which recommends a value of \(\alpha\) based on cross-validated performance.
Here, we’ll try cv.ptLasso.
cvfit <- cv.ptLasso(x, y, groups)
cvfit
##
## Call:
## cv.ptLasso(x = x, y = y, groups = groups, family = "gaussian",
## type.measure = "mse", use.case = "inputGroups", group.intercepts = TRUE)
##
##
##
## type.measure: mse
##
##
##            alpha overall  mean wtdMean group_1 group_2 group_3
## Overall            572.6 572.6   572.6   517.0   509.3   691.5
## Pretrain     0.0   524.0 524.0   524.0   486.8   549.4   535.6
## Pretrain     0.1   494.4 494.4   494.4   470.3   492.4   520.6
## Pretrain     0.2   494.4 494.4   494.4   467.1   493.9   522.1
## Pretrain     0.3   481.7 481.7   481.7   458.2   489.2   497.7
## Pretrain     0.4   483.4 483.4   483.4   454.3   472.3   523.5
## Pretrain     0.5   494.1 494.1   494.1   458.1   506.5   517.8
## Pretrain     0.6   496.0 496.0   496.0   464.9   483.9   539.2
## Pretrain     0.7   490.8 490.8   490.8   461.1   496.3   514.9
## Pretrain     0.8   507.5 507.5   507.5   483.1   491.0   548.5
## Pretrain     0.9   507.9 507.9   507.9   503.8   488.1   531.8
## Pretrain     1.0   516.0 516.0   516.0   504.4   523.0   520.6
## Individual         516.0 516.0   516.0   504.4   523.0   520.6
##
## alphahat (fixed) = 0.3
## alphahat (varying):
## group_1 group_2 group_3
##     0.4     0.4     0.3
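The reported values of alphahat can be checked against the table. The sketch below copies the pretrained rows of the printed output and takes the value of \(\alpha\) with the smallest CV error in each column, assuming (as the output suggests) that alphahat is simply that minimizer.

```r
# CV mean squared errors copied from the pretrained rows of the table above
alpha <- seq(0, 1, by = 0.1)
cv_mse <- cbind(
  overall = c(524.0, 494.4, 494.4, 481.7, 483.4, 494.1,
              496.0, 490.8, 507.5, 507.9, 516.0),
  group_1 = c(486.8, 470.3, 467.1, 458.2, 454.3, 458.1,
              464.9, 461.1, 483.1, 503.8, 504.4),
  group_2 = c(549.4, 492.4, 493.9, 489.2, 472.3, 506.5,
              483.9, 496.3, 491.0, 488.1, 523.0),
  group_3 = c(535.6, 520.6, 522.1, 497.7, 523.5, 517.8,
              539.2, 514.9, 548.5, 531.8, 520.6)
)

# Value of alpha minimizing each column
alphahat <- alpha[apply(cv_mse, 2, which.min)]
names(alphahat) <- colnames(cv_mse)
alphahat
# overall: 0.3 (fixed alphahat); group_1: 0.4, group_2: 0.4, group_3: 0.3 (varying)
```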
Plotting the cv.ptLasso object visualizes performance as
a function of \(\alpha\), and draws a
vertical line to show the value of \(\alpha\) that minimized the CV MSE.
plot(cvfit, plot.alphahat = TRUE)
And, as with ptLasso, we can predict. By
default, predict uses the \(\alpha\) that minimized the cross validated
MSE.
preds = predict(cvfit, xtest, groupstest=groupstest, ytest=ytest)
preds
##
## Call:
## predict.cv.ptLasso(object = cvfit, xtest = xtest, groupstest = groupstest,
## ytest = ytest)
##
##
## alpha = 0.3
##
## Performance (Mean squared error):
##
##            allGroups  mean group_1 group_2 group_3    r^2
## Overall        561.0 561.0   497.6   572.0   613.3 0.3649
## Pretrain       537.8 537.8   493.7   591.8   527.8 0.3911
## Individual     564.8 564.8   540.5   593.6   560.3 0.3606
##
## Support size:
##
## Overall 30
## Pretrain 48 (13 common + 35 individual)
## Individual 70
We can also extract coefficients as we did before, only now the return format is slightly different. As before, we have a list containing the overall, individual and pretrained coefficients.
coefs = coef(cvfit)
names(coefs)
## [1] "individual" "pretrain" "overall"
But now, coefs$pretrain is a list of length 11: one set
of coefficients for each value of \(\alpha\).
length(coefs$pretrain)
## [1] 11
For a fixed choice of \(\alpha\), the coefficients are as they were in our previous example – a list of length 3, with one set of coefficients for each group.
length(coefs$pretrain[[1]])
## [1] 3
In our example, we supposed that each row of \(X\) belonged to one cancer class.
But what if we don’t have a predefined grouping on our data?
An example of this may be data with clinical variables like age and sex, where we have some prior belief that risks are different for different subpopulations.
In this case, we can use clinical variables to fit e.g. a shallow
CART tree; each node of the tree defines a separate group. Given these
groups, we can then use pretraining as described here. There is an
example of this in the vignette for ptLasso.
Let’s do a quick simulation. We’ll simulate 3 groups as before, this time varying the intercepts across groups. We will also simulate clinical variables:
out = gaussian.example.data(k = 3, intercepts = c(-10, 0, 10))
x = out$x; y = out$y; groups = out$groups
outtest = gaussian.example.data(k = 3, intercepts = c(-10, 0, 10))
xtest = outtest$x; ytest = outtest$y; groupstest = outtest$groups
n = nrow(x)
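# Group 1 is aged 20-50 with random sex; groups 2 and 3 are aged 50-90
# and are distinguished by sex (0 for group 2, 1 for group 3)
clinvars = cbind(
  age = ifelse(c(groups, groupstest) == 1,
               runif(2*n, min = 20, max = 50),
               runif(2*n, min = 50, max = 90)),
  sex = ifelse(c(groups, groupstest) == 1,
               sample(c(0, 1), 2*n, replace = TRUE),
               ifelse(c(groups, groupstest) == 2, 0, 1))
)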
clinvars.train = clinvars[1:n, ]
clinvars.test = clinvars[-(1:n), ]
Now, we train our CART tree using only our “clinical variables”. We
will load and use the package rpart.
require(rpart)
treefit = rpart(y ~ .,
                data = data.frame(clinvars.train, y),
                control = rpart.control(maxdepth = 2, minsplit = 50))
rpart.plot::rpart.plot(treefit)
The tree did a great job finding our groups: group 1 is under 50, and groups 2 and 3 are over 50 and divided by sex.
Now, we want our tree to return the ID of the terminal node for each
observation (instead of a predicted value of \(y\)). The following is a trick that causes
predict to behave as desired.
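# Replace each leaf's fitted value with a leaf index (1, 2, ...)
leaf = treefit$frame[, 1] == "<leaf>"
treefit$frame[leaf, "yval"] = 1:sum(leaf)
# predict() now returns the leaf index for each observation
predgroups.train = predict(treefit, data.frame(clinvars.train))
predgroups.test = predict(treefit, data.frame(clinvars.test))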
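The same trick can be tried in isolation. The sketch below is a self-contained version on made-up data (the variables `age`, `sex` and the outcome are assumptions for illustration) and checks that predict returns leaf indices rather than fitted values.

```r
library(rpart)

set.seed(1)
n <- 300
age <- runif(n, 20, 90)
sex <- sample(c(0, 1), n, replace = TRUE)
# Outcome whose mean depends on age and, for older subjects, on sex
y <- ifelse(age < 50, -10, ifelse(sex == 0, 0, 10)) + rnorm(n)
dat <- data.frame(age = age, sex = sex, y = y)

tree <- rpart(y ~ ., data = dat,
              control = rpart.control(maxdepth = 2, minsplit = 50))

# Overwrite each leaf's fitted value with a leaf index, so that predict()
# returns a group label instead of a predicted y
leaf <- tree$frame$var == "<leaf>"
tree$frame[leaf, "yval"] <- seq_len(sum(leaf))

ids <- predict(tree, dat)
table(ids)  # number of observations assigned to each leaf
```

Because the modified tree no longer predicts \(y\), it is worth keeping an unmodified copy if the original predictions are still needed.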
Finally, we are ready to apply pretraining using the predicted groups as our grouping variable.
cvfit = cv.ptLasso(x, y, predgroups.train)
plot(cvfit)
predict(cvfit, xtest, predgroups.test, ytest = ytest)
##
## Call:
## predict.cv.ptLasso(object = cvfit, xtest = xtest, groupstest = predgroups.test,
## ytest = ytest)
##
##
## alpha = 0.8
##
## Performance (Mean squared error):
##
##            allGroups  mean group_1 group_2 group_3    r^2
## Overall        565.4 565.8   480.2   503.2   714.0 0.4433
## Pretrain       487.4 487.4   498.8   475.6   487.7 0.5201
## Individual     491.1 491.0   511.1   477.3   484.6 0.5165
##
## Support size:
##
## Overall 47
## Pretrain 76 (26 common + 50 individual)
## Individual 76
Now we have all the ingredients we need to use ptLasso
with a real dataset.
See Part 2 for more use cases of ptLasso including
datasets with a multinomial target, multi-response data and time series
data!