Estimation of conditional mixture Weibull distribution with right-censored data using neural network for time-to-event analysis
This paper extends classical parametric time-to-event models (a branch of Survival Analysis (SA)) with a neural network (NN) architecture. While walking through the implementation, we will review the following concepts:
Theoretical concepts:
- Survival analysis
- Parametric models
- Maximum Likelihood estimate
- Drawing random numbers from a given distribution
Implementation details (colab, p=1 and colab, any p):
- Custom loss and architecture with Tensorflow
- How to use Tensorflow probability as an alternative to build the model architecture and its loss
- Time-to-event prediction with simulated data (the ones discussed in the paper cf. Figure 4)
Our first goal was to try and reproduce Figure 4, but unfortunately this wasn't entirely possible even after hours of debugging. This is quite disappointing, but many interesting papers do not come with official code and we are therefore helpless when facing such issues. I believe the reproducibility issue comes from an error in the description of how the SYNTHETIC dataset was constructed (i.e the parameter matrices that link the features $X$ with the Weibull coefficients could be wrong). What makes me think this is that I don't find similar values for the real likelihood (the green vertical bars of Figure 4), so it seems my model implementation is not the issue here (even though you can never be sure a piece of code is 100% error-proof!). We will discuss this in detail in the implementation section (section 3).
Before starting the implementation, I'd like to say I really enjoy papers providing examples with synthetic data: it saves you from having to track down the real-world data, which usually comes with cleaning steps that can be hard to reproduce. It also helps tremendously with debugging: if your model doesn't work on simulated data, then you can be pretty sure something is wrong with the theory or with the implementation.
1. Time-to-event prediction
Time-to-event prediction is a field of Survival Analysis that reasons about when a future event will occur. The general setting is that you are given a study with a start and end date, in which you observe events happening (or not) within a population of $n$ individuals. The event itself can really be anything you can think of, but most of the literature is concerned with the death of patients (where the term "survival" is sadly appropriate) in the context of clinical trials. Beyond clinical trials, this field of study has also been applied successfully to system failures and customer churn.
The key feature of Time-to-event prediction is that the models take into account the fact that the event might not occur during the study for some individuals. This is referred to as "Right Censoring" (RC) and it is the flavour of most SA use cases. The paper under consideration is one of them.
Let's formalise this with a little bit of math (there are thousands of places where you can find these equations, but I find it convenient to have everything on the same page).
- $T_i^{\star}$: random variable representing the time at which the event is observed for individual $i$
- $C_i^{\star}$: the RC time (in practice it usually does not depend on $i$, but let's keep the indexing here for the sake of generalization)
- Because of this RC mechanism, what we actually observe is not $T_i^{\star}$ but rather $T_i=\min\left\{T_i^{\star}, C_i^{\star}\right\}$ (i.e we observe the censoring time rather than the event time when the event didn't happen within the study)
- $\Delta_i$ is a binary random variable equal to 1 when we observe $T_i^{\star}$ and 0 when we observe $C_i^{\star}$. In compact form, this reads $\Delta_i = \mathbb{1}_{\left\{T_i^{\star} \leq C_i^{\star} \right\}}$
Okay, enough notation! So now we basically want to be able to say something about $T_i$ using the data at hand. We can then use this knowledge to compare groups of individuals (such as the ones taking some new drug versus others), and make forecasts about how likely it is that an event will happen in the future.
Typically, knowing the density or the cumulative distribution function would be sweet. Think about it: if you know the distribution then by definition you also know how to compute the probability $\mathbb{P}\left[T_i > t \right] \, \forall \left(t, i\right)$, which covers the different use cases we just mentioned. This quantity is called the "Survival Function" and is central to SA. Within clinical trials, it tells us the probability that a patient survives at least up to time $t$.
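In symbols, the survival function is just the complement of the cumulative distribution function:
$$S_i(t) = \mathbb{P}\left[T_i > t\right] = 1 - F_{T_i}(t)$$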
Now we basically have enough math and vocabulary to come back to the paper. This paper deals with a family of Time-to-event models called parametric models. These models are very simple and assume that all you need to fully characterise the distribution of $T_i$ is a set of parameters $\mathbf{\theta} \in \mathbb{R}^d$ where $d$ is usually quite small (2, 3...). Typically, the models assume that the distribution belongs to a parametric family such as the Normal, the Exponential or, as the title suggests, the Weibull distribution. So when dealing with parametric SA models, the pipeline is usually straightforward:
- Pick a parametric distribution (it seems people venerate the Weibull family in SA...). This gives you $d$. With the Weibull parameterisation found on Wikipedia, we would have $\theta = \left(\lambda, k\right) \in \mathbb{R}^2$, more precisely $\left]0; +\infty\right[ \times \left]1; +\infty\right[$ (for the sake of SA, we put 1 as the lower bound for $k$, as otherwise the failure rate would not increase with time, see the Wikipedia page for more detail. But in general $k$ has the same support as $\lambda$)
- Estimate $\mathbf{\theta}$ using your data, usually using Maximum Likelihood Estimator (MLE)
- Done !
2. MLE for parametric Survival model
2.1 Without covariates
On page 4, the paper kind of pulls the likelihood's expression out of a hat using the specifics of the Weibull distribution, but without really explaining where this formula comes from. I know, you're totally lost, you'd give anything for a random blogger to tell you a little more. So let's go through the math together. On Wikipedia the parameters of the Weibull are denoted $(\lambda, k)$. The paper uses other Greek letters, $(\beta, \eta)$, so let's stick with those to avoid confusion. Also, the $\beta$ in the paper corresponds to the $k$ in Wikipedia, so the order and the notations are both different.
We want to build the likelihood for the pair of random variables $(T_i, \Delta_i)$. The trick is to write down the probability of $T_i$ depending on the value of $\Delta_i$ ($\delta_i$ represents a realization, i.e an actual observation, of $\Delta_i$):
- If $\delta_i=1$ it means we have observed the event before the censoring time, so if the realization of $T_i$ is $t_i$, the probability of such an event would be "$\mathbb{P}(T_i=t_i)$". But we're in the Weibull parametric world here, i.e we assumed $T_i$ follows a $\mathcal{W}(\beta, \eta)$, so we know this happens with probability $f_{T_i}(t_i)$ where $f$ is the density function of a $\mathcal{W}(\beta, \eta)$
- If $\delta_i =0$ then all we know is that the event happens after the censoring time. This happens with probability $\mathbb{P}(T_i>c_i) = 1-F_{T_i}(c_i)$ where $F$ is the cumulative distribution function of $T_i$
So now the likelihood of observing one data point $(t_i, \delta_i)$ can be written down in a single equation that covers both possibilities for $\delta_i$:
$$\mathcal{L}_i = f_{T_i}(t_i)^{\delta_i}(1-F_{T_i}(c_i))^{1-\delta_i}$$
Note that the above still remains valid for any parametric model, not just the Weibull one. But let's stick to the Weibull and keep writing what the above would look like:
- $1-F_{T_i}(c_i) = e^{-\left(\frac{c_i}{\eta} \right)^\beta} = S_{\beta, \eta}(c_i)$ (using the article notations).
- $f_{T_i}(t_i)=\frac{\beta}{\eta}\left(\frac{t_i}{\eta}\right)^{\beta-1}e^{-\left(\frac{t_i}{\eta} \right)^\beta}= S_{\beta, \eta}(t_i) \lambda_{\beta, \eta}(t_i)$ when $t_i >0 $ and $0$ otherwise, again using the article notations
$$\mathcal{L}_i = \left[S_{\beta, \eta}(t_i) \lambda_{\beta, \eta}(t_i)\right]^{\delta_i}S_{\beta, \eta}(c_i)^{1-\delta_i}$$
Usually the log-likelihood ($\mathcal{LL}_i$) is used rather than the likelihood. Here, using basic math, this would be:
$$\mathcal{LL}_i = {\delta_i}\log\left[S_{\beta, \eta}(t_i) \lambda_{\beta, \eta}(t_i)\right] + (1-\delta_i)\log\left[S_{\beta, \eta}(c_i)\right]$$
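For the Weibull specifically, plugging in the expressions of $S_{\beta, \eta}$ and $\lambda_{\beta, \eta}$ given above turns every exponential into a plain power (this is the numerical-stability point we come back to in section 3.1):
$$\mathcal{LL}_i = \delta_i\left[\log\frac{\beta}{\eta} + (\beta-1)\log\frac{t_i}{\eta} - \left(\frac{t_i}{\eta}\right)^{\beta}\right] - (1-\delta_i)\left(\frac{c_i}{\eta}\right)^{\beta}$$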
In practice, once we have the log-likelihood written down, the goal is usually to find the parameters that maximize it. Since individuals are assumed independent, the full-sample log-likelihood $\mathcal{LL}$ is simply the sum of the individual contributions $\mathcal{LL}_i$. The quantity:
$$(\hat{\beta}, \hat{\eta}) = \text{arg}\max_{(\beta, \eta)}\mathcal{LL}$$
is referred to as the maximum likelihood estimator (MLE) of $(\beta, \eta)$ and has some nice theoretical properties (consistency, efficiency, asymptotic unbiasedness, etc.). Once we have those estimates, we have an approximation of the time-to-event parametric distribution and we can use it right away for forecasting or any other use case.
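To make this concrete, here is a minimal NumPy/SciPy sketch (not taken from the paper; the simulated values and starting point are arbitrary) that maximises the censored Weibull log-likelihood by minimising its negative:

```python
import numpy as np
from scipy.optimize import minimize

def neg_log_likelihood(params, t, delta):
    """Negative censored log-likelihood; params = (beta, eta) = (shape, scale)."""
    beta, eta = params
    z = (t / eta) ** beta
    log_f = np.log(beta / eta) + (beta - 1.0) * np.log(t / eta) - z  # log density
    log_s = -z                                                       # log survival
    return -np.sum(delta * log_f + (1.0 - delta) * log_s)

# Simulate censored data with true beta = 1.5, eta = 2.0 (arbitrary values).
rng = np.random.default_rng(0)
t_event = 2.0 * rng.weibull(1.5, size=1000)   # event times
c = rng.uniform(0.0, 4.0, size=1000)          # censoring times
t = np.minimum(t_event, c)
delta = (t_event <= c).astype(float)

res = minimize(neg_log_likelihood, x0=[1.0, 1.0], args=(t, delta),
               bounds=[(1e-6, None), (1e-6, None)])
beta_hat, eta_hat = res.x                     # should land close to (1.5, 2.0)
```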
2.2 With covariates
In the above, it is assumed that all individuals have the same survival time distribution, that is, the distribution of $(T_i, \Delta_i)$ does not depend on $i$. I think we can all appreciate that this may be too simplistic for real-life cases: there are always attributes attached to individuals that can affect their survival probability (for clinical trials this can be past surgeries, smoking and other life habits). A natural extension of what we wrote above is therefore to assume not $T_i \sim \mathcal{W}(\beta, \eta)$ but rather $T_i \sim \mathcal{W}(\beta_i, \eta_i)$, where the parameters $(\beta_i, \eta_i)$ are functions (probably complex ones) of the individuals' attributes $X_i \in \mathbb{R}^p$, with $p$ the number of attributes. The paper describes a methodology to estimate $(\beta_i, \eta_i)$ using two things (summarised in the short sketch after the bullets below):
- The core event observations i.e the realizations of $(T_i, \Delta_i)$
- The individual covariates $X_i$
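Schematically (this is just my summary of the setup, with $g_{\mathbf{w}}$ denoting the neural network with weights $\mathbf{w}$), the network replaces the fixed parameters with a covariate-dependent mapping, and the censored log-likelihood of section 2.1 becomes the training objective:
$$(\beta_i, \eta_i) = g_{\mathbf{w}}(X_i), \qquad \hat{\mathbf{w}} = \text{arg}\max_{\mathbf{w}} \sum_{i=1}^{n} \mathcal{LL}_i\left(\beta_i, \eta_i\right)$$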
3. Model implementation
Alright, we can now start implementing the article! A first thing to note is that the article suggests modelling the time-to-event using either a single Weibull distribution or a mixture of such distributions. We will first go through the simplest case, which consists in retrieving the first three bars of the Figure 4 plot, that is, when the mixture consists of only one Weibull distribution (this means we do not need to estimate any $\alpha$).
3.0 Side note: sampling from a Weibull
If a random variable $Y \sim \mathcal{D}$ where $\mathcal{D}$ is a distribution with CDF $F_{\mathcal{D}}$, then $F^{-1}_{\mathcal{D}}(U)$ has the same distribution as $Y$ when $U \sim \mathcal{U}_{[0;1]}$
This means that we can draw samples from $Y$ by first sampling $U \sim \mathcal{U}_{[0;1]}$ and then computing $F^{-1}_{\mathcal{D}}(U)$. Of course, this doesn't work if you can't compute $F^{-1}_{\mathcal{D}}$, but for the Weibull distribution this is pretty simple (check out this cell of the notebook for the formula; a small sketch is given below).
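As an illustration, here is a minimal NumPy sketch of this trick for the Weibull (the function name and defaults are mine, not the notebook's):

```python
import numpy as np

def sample_weibull(beta, eta, size, rng=None):
    """Sample from a Weibull(beta, eta) (shape, scale) via the inverse-CDF trick.

    The CDF is F(t) = 1 - exp(-(t / eta) ** beta), hence
    F^{-1}(u) = eta * (-log(1 - u)) ** (1 / beta).
    """
    rng = rng or np.random.default_rng()
    u = rng.uniform(0.0, 1.0, size=size)           # U ~ Uniform[0, 1]
    return eta * (-np.log1p(-u)) ** (1.0 / beta)   # F^{-1}(U)
```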
3.1 Simple case: $p=1$ (without Tensorflow Probability)
When $p=1$, it's possible to simply ignore half of the model's architecture (i.e the block that the authors label as the classifier or clf). Note that so far we have used $p$ to denote the number of features (the number of columns of the matrix $X$), while the paper uses it to denote the number of Weibull distributions in the mixture; let's follow the paper and also use $p$ for the number of mixture components from here on.
In the paper's section 4.0.2 there is a paragraph called Network Configuration. Using the configuration described there, here is what our TensorFlow model looks like (we used tf.keras.utils.plot_model to generate the drawing):
As one can see, the NN learns a mapping between the features and the Weibull coefficients $(\beta_i, \eta_i)_{i \leq n}$. The survival examples $(t_i, \delta_i)$ are NOT fed into the model: they are only used to compute the loss. In this Colab notebook, one can find the implementation of the model with $p=1$ using the SYNTHETIC dataset described in the paper.
Here are some comments about the implementation (a minimal sketch of the model and its loss follows the list):
- We use the tf.data.Dataset API to batch through our data samples. Every batch is a tuple $(X, y)$. Usually, $y$ is a tensor with only one column (think about classification: $y$ would only be the class label for every sample). But here, there are actually two columns: $(t_i, \delta_i)$, so we must be careful about splitting it back into two columns before using it in our loss
- We use the Functional API to build our model. This API consists in instantiating the layers of the model, using them directly as functions, and then wrapping everything into a tf.keras.models.Model object. It is super powerful and lets you build a huge variety of network architectures
- We do not use a validation set for the experiment: the goal of the SYNTHETIC dataset is to make sure we are able to retrieve the true parameters of the simulated data, so for once we're totally allowed to overfit as much as we want. This is usually what you would do when building fake data to make sure everything works as expected
- Theoretically, minimizing the negative of $\mathcal{L}$ would yield the same result as minimizing the negative of $\mathcal{LL}$. However, if you implement $\mathcal{L}$ directly as your loss, you will run into numerical instability because of the exponential function (which appears in both the Weibull density and the Weibull CDF). This is why we write down $\mathcal{LL}$, which removes the exponential
- We offset the parameters $(\beta_i, \eta_i)$ so that they fall in the range we expect them to have, as suggested by the paper (check out the tf.add(beta, 2.) and tf.add(eta, 1.+1e-6) calls in the notebook).
- Batch size = 64
- Learning rate = $10^{-3}$ with Adam optimizer
- 200 epochs with an early-stopping patience of 40 epochs (monitoring the training loss, since there's no validation here)
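Putting these comments together, here is a minimal sketch of what the model and its custom loss can look like (the hidden layer size and activations are placeholders, not necessarily the paper's Network Configuration; the offsets mirror the tf.add calls mentioned above):

```python
import tensorflow as tf

def make_weibull_net(n_features, hidden_units=32):
    """Encoder mapping the covariates X_i to the Weibull parameters (beta_i, eta_i)."""
    x = tf.keras.Input(shape=(n_features,))
    h = tf.keras.layers.Dense(hidden_units, activation="relu")(x)
    beta = tf.keras.layers.Dense(1, activation="softplus")(h)
    eta = tf.keras.layers.Dense(1, activation="softplus")(h)
    # Offsets so the parameters fall in the expected range (cf. the tf.add calls above).
    beta = tf.keras.layers.Lambda(lambda b: b + 2.0)(beta)
    eta = tf.keras.layers.Lambda(lambda e: e + 1.0 + 1e-6)(eta)
    out = tf.keras.layers.Concatenate()([beta, eta])
    return tf.keras.models.Model(inputs=x, outputs=out)

def weibull_nll(y_true, y_pred):
    """Custom loss: y_true packs (t_i, delta_i), y_pred packs (beta_i, eta_i)."""
    t, delta = y_true[:, 0:1], y_true[:, 1:2]
    beta, eta = y_pred[:, 0:1], y_pred[:, 1:2]
    z = (t / eta) ** beta
    log_f = tf.math.log(beta / eta) + (beta - 1.0) * tf.math.log(t / eta) - z
    log_s = -z
    return -tf.reduce_mean(delta * log_f + (1.0 - delta) * log_s)

model = make_weibull_net(n_features=1)  # n_features is a placeholder here
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss=weibull_nll)
# model.fit(dataset, epochs=200,
#           callbacks=[tf.keras.callbacks.EarlyStopping(monitor="loss", patience=40)])
```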
3.2 Simple case: $p=1$ (with Tensorflow Probability)
Tensorflow Probability (TFP) ships with distribution objects that are very convenient here:
- They can evaluate their own densities, CDFs and... survival functions at any point(s)!
- You can build mixtures of distributions, meaning we can easily switch from the case $p=1$ to any $p$!
- We can build the REAL distribution using Tensorflow probability and then we can use it to generate samples and compute the real NLL. This is very handy when dealing with $p > 1$, as sampling from a mixture density can become cumbersome, especially in our specific setup where we have decided to use the inverse CDF theorem to sample from a Weibull distribution (I kind of shot myself in the foot here...)
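For instance (toy parameter values, just to illustrate the API):

```python
import tensorflow_probability as tfp
tfd = tfp.distributions

# concentration plays the role of beta (shape), scale the role of eta.
dist = tfd.Weibull(concentration=2.0, scale=1.5)

dist.prob(1.0)                 # density f(1)
dist.cdf(1.0)                  # F(1)
dist.survival_function(1.0)    # S(1) = 1 - F(1)
dist.sample(10)                # 10 draws, no inverse-CDF gymnastics needed
```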
The TFP implementation is organised around three building blocks (sketched together just after this list):
- A function called mydist: this function generates a distribution from input parameters (typically output by some upstream layer of your neural network). This is the function we should modify if we want to switch from the Weibull to another distribution, or if we want to use a mixture of Weibulls rather than a single Weibull
- The DistributionLambda layer: this is actually a Neural Network layer that will generate the distribution specified by mydist using input parameters
- The loss: here we called it survival_loss, it has the signature (y, distr) and it uses two methods of the Tensorflow distribution class: log_prob (the log density) and log_survival_function (self-explanatory)
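Here is a minimal sketch of how these three pieces can be wired together for $p=1$ (layer sizes and activations are placeholders, not necessarily the notebook's exact choices):

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

def mydist(params):
    """Turn the upstream layer's outputs into a Weibull distribution (p = 1 case)."""
    beta = params[..., 0:1] + 2.0          # same offsets as in the previous section
    eta = params[..., 1:2] + 1.0 + 1e-6
    return tfd.Weibull(concentration=beta, scale=eta)

def survival_loss(y, distr):
    """Censored negative log-likelihood, written with the distribution's own methods."""
    t, delta = y[..., 0:1], y[..., 1:2]
    return -tf.reduce_mean(delta * distr.log_prob(t)
                           + (1.0 - delta) * distr.log_survival_function(t))

n_features = 1  # placeholder: the number of columns of X
x = tf.keras.Input(shape=(n_features,))
h = tf.keras.layers.Dense(32, activation="relu")(x)
params = tf.keras.layers.Dense(2, activation="softplus")(h)   # raw (beta, eta)
out = tfp.layers.DistributionLambda(mydist)(params)
model = tf.keras.models.Model(inputs=x, outputs=out)
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss=survival_loss)
```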
3.3 General case $p\geq 1$ with Tensorflow probability
Compared to the $p=1$ case, the network needs two extra ingredients (a sketch of the mixture construction is given after the configuration bullets below):
- It returns the mixture weights (we need to implement the "classifier" part of the network, using the terminology of the paper)
- For every $i$, the network's encoder should return $p$ tuples $(\beta_i, \eta_i)$ since we now have a mixture, i.e there is a double indexing of the parameters: one index for the individuals and one for the mixture components: $((\beta_i^k, \eta_i^k)_{k\leq p})_{i\leq n}$
Training configuration:
- Adam optimizer with a learning rate of $10^{-4}$
- Batch size of 64
- Patience of 20 when monitoring the training loss, with a maximum of 800 epochs
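As a sketch of the mixture construction (placeholder sizes again, and using TFP's MixtureSameFamily rather than hand-rolling the mixture), mydist and the loss become:

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

P = 2  # number of mixture components (placeholder value)

def mydist(params):
    """Mixture of P Weibulls: params packs P mixture logits, then P betas, then P etas,
    so the upstream (linear) Dense layer must output 3 * P values per individual."""
    logits = params[..., :P]                                        # "classifier" head
    beta = tf.math.softplus(params[..., P:2 * P]) + 2.0             # (beta_i^k)_{k <= P}
    eta = tf.math.softplus(params[..., 2 * P:3 * P]) + 1.0 + 1e-6   # (eta_i^k)_{k <= P}
    return tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=logits),
        components_distribution=tfd.Weibull(concentration=beta, scale=eta),
    )

def survival_loss(y, distr):
    """Same censored NLL as before; the mixture's batch shape is (batch_size,),
    so t and delta are squeezed down to rank 1."""
    t, delta = y[..., 0], y[..., 1]
    # If log_survival_function is unavailable for the mixture in your TFP version,
    # tf.math.log1p(-distr.cdf(t)) is an equivalent (if less stable) fallback.
    return -tf.reduce_mean(delta * distr.log_prob(t)
                           + (1.0 - delta) * distr.log_survival_function(t))

# Wiring is the same as before, with a wider parameter head:
# params = tf.keras.layers.Dense(3 * P)(h)
# out = tfp.layers.DistributionLambda(mydist)(params)
```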




