---
title: "Introduction to ptLasso"
author:
  name: Erin Craig and Rob Tibshirani
  affiliation: Stanford University
output:
  html_document: 
    toc: true
  pdf_document: default
date: "2024-06-25"
self_contained: true
header-includes: \usepackage{amsmath} \usepackage{amssymb}
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE)
```

## A challenge

- Suppose we have a dataset spanning ten cancers and we want to fit a lasso logistic regression model to predict 10-year survival. 

- Some cancer classes in our dataset are large (e.g. breast), and some are small (e.g. head and neck). 

<!--<div align="center">
<img src="figs/HardChoice.png" width=700>
</div>-->

- There are two obvious approaches: 
1. fit a "pancancer model" to the entire training set and use it to predict for all cancer classes or 
2. fit a separate (class specific) model for each cancer and use it to predict for that class only. 

## Our conceptual model

- We suppose that some features are predictive across all classes. Combining all of our data gives us the best chance to discover these features.

- We also suppose that some features are predictive in just one (or a few) classes. Discovering them is easiest when we fit a separate model for each class.

<!--<div align="center">
<img src="figs/ConceptualModel.png" width=650>
</div>-->

- __Pretraining gives us the best of both worlds.__

- [Pretraining](https://arxiv.org/abs/2401.12911) fits models in two steps: the first step discovers features that are predictive across all classes, and the second discovers features for individual classes.

## Pretraining 


- Pretraining is a _general method_ to pass information from one model to another, with many more use cases including:

   - data with a multinomial outcome,
   - time series data, and
   - multi-response data.

- These cases are described in detail and with examples in Part 2 of this series.

- Before we describe pretraining in more detail, we will do a quick review of the lasso.



### Review of the lasso
For the Gaussian family with data $(x_i,y_i), i=1,2,\ldots, n$, the lasso has the form 
\[
{\rm argmin}_{\beta_0, \beta} \frac{1}{2} \sum_{i=1}^n\left(y_i- \beta_0 - x_i^T \beta\right)^2 + \lambda \sum_{j=1}^p |\beta_j |.
\]

This is composed of:

1. a least squares component, encouraging the predictions $\beta_0 + x_i^T \beta$ to be close to the observations $y_i$, and
2. a penalty encouraging coefficients to be sparse -- this does _feature selection_. Sparsity can yield better prediction accuracy and interpretability.

Though we show a lasso linear regression problem here, the lasso can be applied to the entire GLM family (including logistic and Poisson regression, for example), as well as Cox's proportional hazards model.

Many software packages implement the lasso; a popular one is `glmnet` in R.

Varying the regularization parameter $\lambda \ge 0$ yields a path of solutions. An optimal value $\hat\lambda$ is usually chosen by cross-validation, for example with the `cv.glmnet` function from the `glmnet` package.
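
As a minimal sketch of this workflow (not part of the analysis in this document), the chunk below fits a lasso path with `cv.glmnet` on simulated Gaussian data. The simulated data, the variable names (`x_demo`, `y_demo`, `cv_demo`) and the choice of three true signals are assumptions made purely for illustration.

```{r, eval = FALSE}
library(glmnet)

set.seed(1)
n_demo <- 100; p_demo <- 20
x_demo <- matrix(rnorm(n_demo * p_demo), n_demo, p_demo)
# Only the first three features carry signal.
y_demo <- as.numeric(x_demo[, 1:3] %*% c(2, -1.5, 1) + rnorm(n_demo))

# Fit the whole lasso path and choose lambda by 10-fold cross-validation.
cv_demo <- cv.glmnet(x_demo, y_demo, family = "gaussian")

cv_demo$lambda.min                 # lambda with the smallest CV error
coef(cv_demo, s = "lambda.min")    # sparse coefficients at that lambda
```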

### Modeling options that we use
- We specify an offset $O_i$ for each observation $x_i$.  
The offset acts like an extra column of the data whose coefficient is fixed at 1. For linear regression, this is the same as fitting the residual $y_i - O_i$.

\[
{\rm argmin}_{\beta_0, \beta} \frac{1}{2} \sum_{i=1}^n\left(y_i- O_i - \beta_0 -  x_i^T \beta\right)^2 + \lambda \sum_{j=1}^p |\beta_j |.
\]


- We specify a penalty factor ${\rm pf}_j$ that modifies the lasso penalty for the $j^\text{th}$ feature.

\[
{\rm argmin}_{\beta_0, \beta} \frac{1}{2} \sum_{i=1}^n\left(y_i- O_i - \beta_0 -  x_i^T \beta\right)^2 + \lambda \sum_{j=1}^p {\rm pf}_j  |\beta_j |.
\]

&emsp; &emsp;  At the extremes:

&emsp; &emsp; &emsp;  - penalty factor = 0 $\rightarrow$ the feature is always included in the model; 

&emsp; &emsp; &emsp;  - penalty factor = $+\infty$ $\rightarrow$ the feature is always excluded from the model.

Both options are standard in generalized linear models (and in the software package `glmnet`).
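
The sketch below illustrates both options with `glmnet` on simulated data. Everything in it (the data, the offset values, the `_demo` names) is hypothetical and chosen only for illustration. It checks the two extremes of the penalty factor, and compares a fit with an offset against a fit to the residual $y_i - O_i$.

```{r, eval = FALSE}
library(glmnet)

set.seed(2)
n_demo <- 100; p_demo <- 10
x_demo <- matrix(rnorm(n_demo * p_demo), n_demo, p_demo)
y_demo <- as.numeric(x_demo[, 1] + rnorm(n_demo))
offset_demo <- rnorm(n_demo)   # hypothetical per-observation offset O_i

# Feature 1 is unpenalized, feature 2 has an infinite penalty factor,
# and the remaining features get the usual lasso penalty.
pf_demo <- c(0, Inf, rep(1, p_demo - 2))

fit_demo <- glmnet(x_demo, y_demo, offset = offset_demo,
                   penalty.factor = pf_demo)
beta_path <- as.matrix(coef(fit_demo))[-1, ]   # drop the intercept row

all(beta_path[1, ] != 0)   # is feature 1 in the model at every lambda?
all(beta_path[2, ] == 0)   # is feature 2 excluded at every lambda?

# Gaussian case: compare the offset fit with a fit to the residual y - O.
fit_resid <- glmnet(x_demo, y_demo - offset_demo,
                    penalty.factor = pf_demo, lambda = fit_demo$lambda)
max(abs(as.matrix(coef(fit_demo)) - as.matrix(coef(fit_resid))))
```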

### Pretraining details

In the input-grouped setting, where each observation belongs to one known group, pretraining fits models in two steps.

1. __Fit an overall model__: 

&emsp; &emsp; Fit a lasso penalized model using the full dataset $X, y$ to get coefficients $\hat{\beta}$.

2. __Fit group-specific models__:

&emsp; &emsp; Choose the hyperparameter $\alpha$ either by trying a few values, or using cross-validation (details later). 

Fit a model for each group $X_k, y_k$ using 

&emsp; &emsp; - Offset: $(1 - \alpha) X_k \hat{\beta}$

&emsp; &emsp; - Penalty factor: $1$ (features selected by the overall model) or $\frac{1}{\alpha}$ (other features)
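
Combining the offset and the penalty factor with the lasso objective from the review above gives the following sketch of the second step for group $k$ in the Gaussian case. Here $S$ denotes the set of features selected by the overall model; the symbol $S$ is notation introduced only for this display.

\[
{\rm argmin}_{\beta_0, \beta} \frac{1}{2} \sum_{i \in \text{group } k}\left(y_i - (1 - \alpha)\, x_i^T \hat{\beta} - \beta_0 - x_i^T \beta\right)^2 + \lambda \sum_{j=1}^p {\rm pf}_j |\beta_j |,
\qquad
{\rm pf}_j =
\begin{cases}
1 & \text{if } j \in S, \\
\frac{1}{\alpha} & \text{if } j \notin S.
\end{cases}
\]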

Note that when $\alpha = 0$, this fits a fine-tuned version of the overall model for each group: the pretrained model can _only_ use features that were selected by the overall model, and it must use the full prediction from the overall model as an offset.

At the other extreme when $\alpha = 1$, this approach completely ignores the overall model and fits a separate model for each group.
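
The two steps can be sketched by hand with `cv.glmnet`. This is an illustration under stated assumptions, not the internal code of `ptLasso`: the simulated data and all names (`x_demo`, `overall`, `pre_k`, and so on) are hypothetical, and the offset here is built from the overall model's full prediction, which includes its intercept.

```{r, eval = FALSE}
library(glmnet)

set.seed(3)
n_k <- 100; p_demo <- 20
groups_demo <- rep(1:3, each = n_k)
x_demo <- matrix(rnorm(3 * n_k * p_demo), 3 * n_k, p_demo)
# Features 1-3 are shared; feature 3 + k matters only in group k.
signal <- x_demo[, 1:3] %*% c(2, -2, 1.5) +
  2 * x_demo[cbind(seq_along(groups_demo), 3 + groups_demo)]
y_demo <- as.numeric(signal + rnorm(3 * n_k))

# Step 1: overall model fit to all groups together.
overall <- cv.glmnet(x_demo, y_demo)
beta_hat <- as.matrix(coef(overall, s = "lambda.min"))[-1, 1]
support <- which(beta_hat != 0)   # features selected by the overall model

# Step 2: model for one group, using an offset and a penalty factor.
alpha_demo <- 0.5
k <- 1
in_k <- groups_demo == k
offset_k <- (1 - alpha_demo) *
  as.numeric(predict(overall, x_demo[in_k, ], s = "lambda.min"))
pf_k <- rep(1 / alpha_demo, p_demo)   # 1 / alpha for unselected features
pf_k[support] <- 1                    # 1 for selected features

pre_k <- cv.glmnet(x_demo[in_k, ], y_demo[in_k],
                   offset = offset_k, penalty.factor = pf_k)
coef(pre_k, s = "lambda.min")
```

Repeating step 2 for each value of `k` gives one pretrained model per group.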

<!--<div align="center">
<img src="figs/Alpha.png" width=1000>
</div>-->

## The ptLasso package

- `ptLasso` is a package that fits pretrained models using the `glmnet` package, including lasso, elastic net and ridge models.

- All model fitting in `ptLasso` is done with `cv.glmnet`. 

- The first step of pretraining is a straightforward call to `cv.glmnet`; the second step is done by calling `cv.glmnet` with an offset and penalty factor.

- Additionally, one call to `ptLasso` fits an overall model, pretrained class specific models, and class specific models for each group (without pretraining). 

- The `ptLasso` package also includes methods for prediction and plotting, and a function that performs K-fold cross-validation.

## A quick start example

A typical modeling pipeline looks like this:
```{r, eval = FALSE}

# Data: features X, response y and group identifier "groups"

# Fit a model:
fit = ptLasso(X, y, groups)
# ...optionally using cross-validation for parameter selection:
cvfit = cv.ptLasso(X, y, groups)

# Visualize the model:
plot(fit)
plot(cvfit)

# And make predictions:
predict(fit, Xtest, groupstest)
predict(cvfit, Xtest, groupstest)

# Inspect the coefficients
coef(fit)
coef(cvfit)
```

We'll walk through these steps using `ptLasso` with a simple simulated dataset.

## Example using simulated data

First, we load the `ptLasso` package:
```{r, echo=FALSE}
suppressPackageStartupMessages(require(ptLasso))
suppressPackageStartupMessages(require(rpart))
```

```{r eval = FALSE}
require(ptLasso)
```

### Simulating data
Now, we simulate data with $3$ groups and a continuous response using the helper function `gaussian.example.data`:

- $n = 200$ observations in each group and $p = 80$ features. 
- All groups share $10$ informative features with different coefficient values. 
- Each group has $10$ additional group-specific features.
- Other features are noise.

```{r}
set.seed(1234)

out = gaussian.example.data(k = 3)
x = out$x; y = out$y; groups = out$groups

outtest = gaussian.example.data(k = 3)
xtest = outtest$x; ytest = outtest$y; groupstest = outtest$groups
```

Sanity check:
```{r}
dim(x)
head(y)
```

### Model fitting

We are ready to fit a model using `ptLasso`. We'll choose the pretraining parameter $\alpha = 0.5$ (an arbitrary choice). 

Here, we'll go through three steps: fitting a model, predicting with a validation set, and extracting the model coefficients.

#### Fit

```{r}
fit <- ptLasso(x, y, groups, alpha = 0.5)
```

The function `ptLasso` used `cv.glmnet` to fit $7$ models: 

- the *overall* model (one model trained using all $3$ groups), 
- the 3 *pretrained* models (one for each group) and
- the 3 *individual* models (one for each group).

A call to `plot` displays the cross-validation curves for each model. The top row shows the overall model, the middle row the pretrained models, and the bottom row the individual models.

```{r, fig.width=7, fig.height=6, dpi=100}
plot(fit)
```

#### Predict

`predict` makes predictions from all $7$ models. It returns a list containing:

1. `yhatoverall` (predictions from the overall model), 
2. `yhatpre` (predictions from the pretrained models) and 
3. `yhatind` (predictions from the individual models).

By default, `predict` uses `lambda.min` -- the value of the regularization parameter $\lambda$ that minimized the cross-validated mean squared error -- for all $7$ `cv.glmnet` models. 

```{r}
preds = predict(fit, xtest, groupstest=groupstest)
```

If you also provide `ytest` (for model validation), `predict` will additionally compute a measure of performance. 

```{r}
preds = predict(fit, xtest, groupstest=groupstest, ytest=ytest)
preds
```

#### Interpret

To look at the coefficients of our models, we can use the `coef` function:
```{r}
coefs = coef(fit)
```

The variable `coefs` is a list with the coefficients for the overall, individual and pretrained models:
```{r}
names(coefs)
```

And `coefs$pretrain` is a list of length 3, containing the coefficients for groups 1, 2 and 3:
```{r}
length(coefs$pretrain)
```

The coefficients themselves are returned as a single-column matrix, as in `glmnet`:
```{r}
head(coefs$pretrain[[3]])
```



### Cross validation to choose $\alpha$

In our previous example, we chose the parameter $\alpha$ arbitrarily. In practice we recommend making a more thoughtful choice by using:

1. a validation set to measure performance for a few different choices of $\alpha$ (e.g. $0, 0.25, 0.5, 0.75, 1.0$), or 
2. `cv.ptLasso`, which will recommend a choice of $\alpha$ based on CV performance.
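
The first option can be sketched as follows, using only the `ptLasso` and `predict` calls shown earlier. The names `train_demo`, `valid_demo` and `valid_mse` are hypothetical, and the sketch assumes that `yhatpre` is a numeric vector of predictions in the same row order as the validation data.

```{r, eval = FALSE}
library(ptLasso)

set.seed(4)
train_demo <- gaussian.example.data(k = 3)
valid_demo <- gaussian.example.data(k = 3)

# Fit one pretrained model per candidate alpha and score it on the
# validation set using the pretrained predictions.
alphas <- c(0, 0.25, 0.5, 0.75, 1)
valid_mse <- sapply(alphas, function(a) {
  fit_a <- ptLasso(train_demo$x, train_demo$y, train_demo$groups, alpha = a)
  pred_a <- predict(fit_a, valid_demo$x, groupstest = valid_demo$groups)
  mean((valid_demo$y - pred_a$yhatpre)^2)
})

data.frame(alpha = alphas, valid_mse = valid_mse)
alphas[which.min(valid_mse)]   # alpha with the smallest validation MSE
```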

Here, we'll try `cv.ptLasso`.

```{r}
cvfit <- cv.ptLasso(x, y, groups)
cvfit
```

Plotting the `cv.ptLasso` object visualizes performance as a function of $\alpha$, and draws a vertical line to show the value of $\alpha$ that minimized the CV MSE.
```{r, fig.width=5, fig.height=4, dpi=100}
plot(cvfit, plot.alphahat = TRUE)
```

And, as with `ptLasso`, we can `predict`. By default, `predict` uses the $\alpha$ that minimized the cross-validated MSE.
```{r}
preds = predict(cvfit, xtest, groupstest=groupstest, ytest=ytest)
preds
```

We can also extract coefficients as we did before, only now the return format is slightly different. As before, we have a list containing the overall, individual and pretrained coefficients. 
```{r}
coefs = coef(cvfit)
names(coefs)
```

But now, `coefs$pretrain` is a list of length 11: one set of coefficients for each value of $\alpha$.
```{r}
length(coefs$pretrain)
```

For a fixed choice of $\alpha$, the coefficients are as they were in our previous example -- a list of length 3, with one set of coefficients for each group.
```{r}
length(coefs$pretrain[[1]])
```

## What if we don't know the input groups?

In our example, we supposed that each row of $X$ belonged to one cancer class. 

But what if we don't have a predefined grouping on our data? 

An example of this may be data with clinical variables like age and sex, where we have some prior belief that risks are different for different subpopulations. 

In this case, we can use clinical variables to fit e.g. a shallow CART tree; each terminal node (leaf) of the tree defines a separate group. Given these groups, we can then use pretraining as described here. There is an example of this in the vignette for `ptLasso`.

Let's do a quick simulation. We'll simulate 3 groups as before, this time varying the intercepts across groups. We will also simulate clinical variables: 

- in group 1, age is 50 or under and sex is equally likely male or female,
- in group 2, age is over 50 and sex is male,
- in group 3, age is over 50 and sex is female.

```{r}

out = gaussian.example.data(k = 3, intercepts = c(-10, 0, 10))
x = out$x; y = out$y; groups = out$groups

outtest = gaussian.example.data(k = 3, intercepts = c(-10, 0, 10))
xtest = outtest$x; ytest = outtest$y; groupstest = outtest$groups

n = nrow(x)
clinvars = cbind(
  age = ifelse(c(groups, groupstest) == 1, 
               runif(2*n, min = 20, max = 50),
               runif(2*n, min = 50, max = 90)),
  sex = ifelse(c(groups, groupstest) == 1, 
               sample(c(0, 1), 2*n, replace = TRUE), 
               ifelse(c(groups, groupstest) == 2, 0, 1)
               )
)
clinvars.train = clinvars[1:n, ]
clinvars.test  = clinvars[-(1:n), ]
```

Now, we train our CART tree using only our "clinical variables". We will load and use the package `rpart`; the plot below also requires the `rpart.plot` package to be installed.
```{r eval = FALSE}
require(rpart)
```

```{r, fig.width=4, fig.height=5, fig.align='center'}
treefit = rpart(y~., 
                data = data.frame(clinvars.train, y), 
                control=rpart.control(maxdepth=2, minsplit=50))

rpart.plot::rpart.plot(treefit)
```

The tree did a great job finding our groups: group 1 is under 50, and groups 2 and 3 are over 50 and divided by sex.

Now, we want our tree to return the ID of the terminal node for each observation (instead of a predicted value of $y$). The following is a trick that causes `predict` to behave as desired. 
```{r}
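# Flag the terminal nodes (leaves) of the tree...
leaf = treefit$frame[, 1] == "<leaf>"
# ...and overwrite each leaf's fitted value with a leaf index,
# so that predict() returns a group ID instead of a predicted y.
treefit$frame[leaf, "yval"] = 1:sum(leaf)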

predgroups.train = predict(treefit, data.frame(clinvars.train))
predgroups.test  = predict(treefit, data.frame(clinvars.test))
```
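
To see what the trick does in isolation, here is a small self-contained sketch on hypothetical data (the `_demo` variables and `true_group` are invented for illustration). It applies the same relabeling and then cross-tabulates the recovered leaf IDs against the simulated subpopulations.

```{r, eval = FALSE}
library(rpart)

set.seed(5)
n_demo <- 300
age_demo <- runif(n_demo, min = 20, max = 90)
sex_demo <- sample(c(0, 1), n_demo, replace = TRUE)
# Three hypothetical subpopulations with different mean responses.
true_group <- ifelse(age_demo <= 50, 1, ifelse(sex_demo == 0, 2, 3))
y_demo <- c(-10, 0, 10)[true_group] + rnorm(n_demo)

tree_demo <- rpart(y_demo ~ age_demo + sex_demo,
                   control = rpart.control(maxdepth = 2, minsplit = 50))

# Replace each leaf's fitted value with a leaf index.
is_leaf <- tree_demo$frame$var == "<leaf>"
tree_demo$frame$yval[is_leaf] <- seq_len(sum(is_leaf))
leaf_id <- predict(tree_demo, data.frame(age_demo, sex_demo))

# Cross-tabulate the leaf IDs against the simulated groups.
table(leaf = leaf_id, truth = true_group)
```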

Finally, we are ready to apply pretraining using the predicted groups as our grouping variable. 
```{r}
cvfit = cv.ptLasso(x, y, predgroups.train)
plot(cvfit)

predict(cvfit, xtest, groupstest = predgroups.test, ytest = ytest)
```


## Wrap-up

Now we have all the ingredients we need to use `ptLasso` with a real dataset.

See Part 2 for more use cases of `ptLasso` including datasets with a multinomial target, multi-response data and time series data!