Estimation of conditional mixture Weibull distribution with right-censored data using neural network for time-to-event analysis

Paper link

This paper extends classical parametric time-to-event models (a branch of Survival Analysis (SA)) with a neural network (NN) architecture. Alongside the implementation, we will review the following concepts:

Theoretical concepts:

  • Survival analysis
  • Parametric models
  • Maximum Likelihood estimate
  • Drawing random numbers from a given distribution

Implementation details (colab, p=1 and colab, any p):

    • Custom loss and architecture with Tensorflow
    • How to use Tensorflow probability as an alternative to build the model architecture and its loss
    • Time-to-event prediction with simulated data (the ones discussed in the paper cf. Figure 4)

    The key takeaway of this article is how to leverage TensorFlow Probability for survival analysis. Using that framework removes a lot of pain, typically yielding a code base that is easy to modify and that can accommodate every setting described in this article (i.e. using a mixture or not, using the Weibull or some other distribution). So I really hope you will make it all the way to that section!

    Our first goal was to try and reproduce Figure 4, but unfortunately this wasn't entirely possible even after hours of debugging. This is quite disappointing, but many interesting papers do not come with official code, and we are therefore helpless when facing such issues. I believe the reproducibility issue comes from an error in the description of how the SYNTHETIC dataset was constructed (i.e. the parameter matrices that link the features $X$ with the Weibull coefficients could be wrong). What makes me think this is that I don't find similar values for the real likelihood (the green vertical bars of Figure 4), so the model implementation does not seem to be the issue here (even though you're never sure a code base is 100% error-free!). We will discuss this in detail in the implementation section (section 3).

    Before starting the implementation, I'd like to say I really enjoy papers that provide examples with synthetic data: it spares you from hunting down real-world data, which usually comes with cleaning steps that can be hard to reproduce. It also helps tremendously with debugging: if your model doesn't work on simulated data, then you can be pretty sure something is wrong with the theory or with the implementation.

    1. Time-to-event prediction

    Time-to-event prediction is a field of Survival Analysis that reasons about when a future event will occur. The general setting is that you are given a study with a start and end date, in which you observe events happening (or not) within a population of $n$ individuals. The event itself can be anything you can think of, but most of the literature is concerned with the death of patients (where the term "survival" is sadly appropriate) within the context of clinical trials. Beyond clinical trials, this field of study has also been successfully applied to system failures and customer churn.

    The key feature of Time-to-event prediction is that the models take into account the fact that the event might not appear during the study for some individuals. This is referred to as "Right Censoring" (RC) and is how most of the SA use cases are flavoured. The paper under consideration is one of those.

    Let's just formalise this with a little bit of math (there are thousands of places where you can find these equations, but I find it convenient to have them all on the same page).

    1. $T_i^{\star}$: random variable representing the time at which the event is observed for individual $i$
    2. $C_i^{\star}$: the RC time (in practice it usually does not depend on $i$, but let's keep the indexing here for the sake of generalization)
    3. Because of this RC thing, the general observations we have is not $T_i^{\star}$ but rather $T_i=\min\left\{T_i^{\star}, C_i^{\star}\right\}$ (i.e we observe the censoring time rather than the event time when the event didn't happen within the study)
    4. $\Delta_i$ is a binary random variable equal to 1 when we observe $T_i^{\star}$ and 0 when we observe $C_i^{\star}$. In compact form, this comes as $\Delta_i = \mathbb{1}_{\left\{T_i^{\star} \leq C_i^{\star} \right\}}$ (a tiny toy illustration follows this list)
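
    To make the notation concrete, here is a toy illustration in numpy (the numbers are arbitrary and purely for demonstration):

```python
import numpy as np

# Toy illustration of the censoring notation (arbitrary numbers).
t_star = np.array([2.0, 5.0, 1.0])      # true event times T_i*
c_star = np.array([3.0, 3.0, 3.0])      # censoring times C_i* (here identical for everyone)

t = np.minimum(t_star, c_star)          # what we actually observe: T_i = min(T_i*, C_i*)
delta = (t_star <= c_star).astype(int)  # Delta_i = 1 if the event was observed before censoring

print(t)      # [2. 3. 1.]
print(delta)  # [1 0 1]
```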

    Okay, enough notation! So now we basically want to be able to say something about $T_i$ using the data at hand. We can then use this knowledge to compare groups of individuals (such as the ones taking some new drug versus others), and make forecasts about how likely it is that an event will happen in the future.

    Typically, knowing the density or the cumulative distribution function of $T_i$ would be sweet. Think about it: if you know the distribution then by definition you also know how to compute the probability $\mathbb{P}\left[T_i > t \right] \, \forall \left(t, i\right)$, which covers the different use cases we just mentioned. This quantity is called the "Survival Function" and is central to SA. Within clinical trials, it tells us the probability that a patient survives at least up to time $t$.

    Now we basically have enough math and vocab to come back to the paper. This paper deals with a family of Time-to-event models called parametric models. These models are very simple and assume that all you need to know to fully characterise the distribution of $T_i$ is a set of parameters $\mathbf{\theta} \in \mathbb{R}^d$ where $d$ is usually quite small (2, 3...). Typically, the models will assume that the distribution belongs to a parametric distribution such as the Normal distribution, Exponential distribution, or, as the title suggests, the Weibull distribution. So when dealing with parametric SA models, the pipeline is usually straightforward:

    • Pick a parametric distribution (it seems people venerate the Weibull family in SA...). This gives you $d$. With the Weibull parameterisation found on Wikipedia, we would have $\theta = \left(\lambda, k\right) \in \mathbb{R}^2$, more precisely $\left]0; +\infty\right[ \times \left]1; +\infty\right[$ (for the sake of SA, we have put 1 as the lower bound for $k$, because otherwise the hazard rate would not increase with time, see the Wikipedia page for more detail; in general $k$ has the same support as $\lambda$)
    • Estimate $\mathbf{\theta}$ using your data, usually using Maximum Likelihood Estimator (MLE)
    • Done !

    2. MLE for parametric Survival model

    2.1 Without covariates

    On page 4, the paper kind of pops the likelihood's expression out of its hat using the specifics of the Weibull distribution, but without really explaining where this formula comes from. I know, you're totally lost, you'd give anything for a random blogger to tell you a little more. So let's go through the math together. On Wikipedia the parameters of the Weibull are denoted $(\lambda, k)$. The paper uses other Greek letters, $(\beta, \eta)$, so let's stick with those to avoid confusion. Also note that the $\beta$ in the paper corresponds to the $k$ on Wikipedia, so both the order and the notations are different.

    We want to build the likelihood for the pair of random variables $(T_i, \Delta_i)$. The trick is to write down the probability of $T_i$ depending on the value of $\Delta_i$ ($\delta_i$ represents realizations, i.e. actual observations, of $\Delta_i$):

    1. If $\delta_i=1$, it means we have observed the event before the censoring time, so the realization of $T_i$ is $t_i$ and the probability of such an event would be "$\mathbb{P}(T_i=t_i)$". But we're in the Weibull parametric world here, i.e. we assumed $T_i$ follows a $\mathcal{W}(\beta, \eta)$, so we know this happens with probability $f_{T_i}(t_i)$ where $f$ is the density function of a $\mathcal{W}(\beta, \eta)$
    2. If $\delta_i =0$ then all we know is that the event happens after the censoring time. This happens with probability $\mathbb{P}(T_i>c_i) = 1-F_{T_i}(c_i)$ where $F$ is the cumulative distribution function of $T_i$

    So now the likelihood of observing one data point $t_i$ can be written down in only one equation that deals with both possibilities for $\delta_i$:
    $$\mathcal{L}_i = f_{T_i}(t_i)^{\delta_i}(1-F_{T_i}(c_i))^{1-\delta_i}$$

    Note that the above still remains valid for any parametric model, not just the Weibull one. But let's stick to the Weibull and keep writing what the above would look like:

    Using Wikipedia, we know that if $T_i \sim \mathcal{W}(\beta, \eta)$ then:
    • $1-F_{T_i}(c_i) = e^{-\left(\frac{c_i}{\eta} \right)^\beta} = S_{\beta, \eta}(c_i)$ (using the article notations).
    • $f_{T_i}(t_i)=\frac{\beta}{\eta}\left(\frac{t_i}{\eta}\right)^{\beta-1}e^{-\left(\frac{t_i}{\eta} \right)^\beta}= S_{\beta, \eta}(t_i) \lambda_{\beta, \eta}(t_i)$ when $t_i >0 $ and $0$ otherwise, again using the article notations
    So we're there ! In the $\mathcal{W}(\beta, \eta)$ situation we have:

    $$\mathcal{L}_i =  \left[S_{\beta, \eta}(t_i) \lambda_{\beta, \eta}(t_i)\right]^{\delta_i}S_{\beta, \eta}(c_i)^{1-\delta_i}$$

    Usually the log likelihood ($\mathcal{LL}_i$) is used rather than the likelihood. Here, using basic math, this would be:

    $$\mathcal{LL}_i =  {\delta_i}\log\left[S_{\beta, \eta}(t_i) \lambda_{\beta, \eta}(t_i)\right] + (1-\delta_i)\log\left[S_{\beta, \eta}(c_i)\right]$$

    In practice, once we have the log likelihood written down, the goal is usually to find the parameters that maximize it. The quantity:

    $$(\hat{\beta}, \hat{\eta}) = \text{arg}\max_{(\beta, \eta)}\mathcal{LL}$$ 

    is referred to as the maximum likelihood estimator (MLE) of $(\beta, \eta)$ and has some nice theoretical properties (consistency, efficiency, asymptotically unbiased etc.). Once we have those estimates, then we have an approximation of the time-to-event parametric distribution and we can use it right away for forecasting or any other use cases.
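
    To make this concrete, here is a minimal sketch (not taken from the paper) of how one could fit a censored Weibull by maximum likelihood on toy data, using numpy and scipy:

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)

# Simulate right-censored data from a W(beta=2, eta=1.5).
t_star = 1.5 * rng.weibull(2.0, size=500)   # true event times (numpy's weibull has scale 1)
c = rng.uniform(0.5, 3.0, size=500)         # censoring times
t = np.minimum(t_star, c)                   # observed times
delta = (t_star <= c).astype(float)         # event indicators

def neg_log_likelihood(params, t, delta):
    # Optimize on the log scale so that beta and eta stay positive.
    beta, eta = np.exp(params)
    z = (t / eta) ** beta
    log_f = np.log(beta / eta) + (beta - 1.0) * np.log(t / eta) - z  # log density
    log_s = -z                                                       # log survival function
    return -np.sum(delta * log_f + (1.0 - delta) * log_s)

res = minimize(neg_log_likelihood, x0=np.zeros(2), args=(t, delta))
beta_hat, eta_hat = np.exp(res.x)            # should land close to (2, 1.5)
```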

    2.2 With covariates

    In the above, it is assumed that all individuals have the same survival time distribution, that is, the distribution of $(T_i, \Delta_i)$ does not depend on $i$. I think we all appreciate this may be too simplistic for real-life cases: there are always attributes attached to individuals that can affect their survival probability (for clinical trials these can be past surgeries, smoking and other life habits). Therefore a natural extension of what we wrote above is to not consider $T_i \sim \mathcal{W}(\beta, \eta)$ but rather $T_i \sim \mathcal{W}(\beta_i, \eta_i)$, where the parameters $(\beta_i, \eta_i)$ are functions (probably complex ones) of the individual's attributes $X_i \in \mathbb{R}^p$, with $p$ the number of attributes. The paper describes a methodology for estimating $(\beta_i, \eta_i)$ using two things:

    • The core event observations i.e the realizations of $(T_i, \Delta_i)$
    • The individual covariates $X_i$

    Of course, we cannot just replace $(\beta, \eta)$ with $(\beta_i, \eta_i)$ in the log likelihood and find the MLE for $((\beta_1, \eta_1), ..., (\beta_n, \eta_n))$: with one pair of parameters per single observation, the problem is hopelessly over-parameterized and out of reach for classical optimization algorithms such as Newton-Raphson. Instead, the article suggests learning a mapping between the covariates $X_i$ and the parameters $(\beta_i, \eta_i)$, i.e. it assumes there exists a function $g$ such that $g\left(X_i\right) = (\beta_i, \eta_i)$. It then suggests using a NN as a placeholder for $g$, with the negative log likelihood as the loss after replacing $(\beta, \eta)$ with $g(X_i)$. This article is a beautiful example of how neural networks can be customised to "hack" around a problem, which is why I chose to talk about it.

    3. Model implementation 

    Alright, we can now start implementing the article! A first thing to note is that the article suggests modelling the time-to-event using either a simple Weibull distribution or a mixture of such distributions. We will first go through the simplest case, which consists in retrieving the first three bars of Figure 4, that is when the mixture of Weibull distributions consists of only one distribution (this means we do not need to estimate any $\alpha$).

    3.0 Side note: sampling from a Weibull

    Usually, I would rely on scipy or numpy to draw random numbers from a given distribution. However, when it comes to the Weibull distribution, scipy uses a different parameterization than the one used in Wikipedia or this article. The scipy page (at least at the time of writing) gives you some advice about how to transform their parameterization into the one used in Wikipedia, but that sounded like extra work and validations that I didn't feel like performing. Instead, I'd rather use the following theorem:

    If a random variable $Y \sim \mathcal{D}$ where $\mathcal{D}$ is a distribution with CDF $F_{\mathcal{D}}$, then $F^{-1}_{\mathcal{D}}(U)$ has the same distribution as $Y$ when $U \sim \mathcal{U}_{[0;1]}$

    This means that we can draw samples from $Y$ by first sampling $U \sim \mathcal{U}_{[0;1]}$ and then computing $F^{-1}_{\mathcal{D}}(U)$. Of course, this doesn't work if you can't compute $F^{-1}_{\mathcal{D}}$, but with the Weibull distribution this is pretty simple (check out this cell of the notebook for the formula).
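
    For reference, here is a minimal numpy sketch of this inverse-CDF trick for the Weibull, using the $(\beta, \eta)$ parameterization of the article (the notebook's cell may differ in the details):

```python
import numpy as np

rng = np.random.default_rng(42)

def sample_weibull(beta, eta, size):
    # F(t) = 1 - exp(-(t/eta)**beta)  =>  F^{-1}(u) = eta * (-log(1 - u)) ** (1 / beta)
    u = rng.uniform(size=size)
    return eta * (-np.log(1.0 - u)) ** (1.0 / beta)

samples = sample_weibull(beta=2.0, eta=1.5, size=10_000)
print(samples.mean())  # should be close to eta * Gamma(1 + 1/beta) ~ 1.33
```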

    3.1 Simple case: $p=1$ (without Tensorflow Probability)

    When $p=1$, it's possible to simply ignore half of the model's architecture (i.e. the block that the authors label as the classifier or clf). Note that so far we have used $p$ to denote the number of features (the number of columns in the matrix $X$), whereas the paper uses it to denote the number of Weibull distributions in the mixture; from now on, let's follow the paper and use $p$ for the number of mixture components.

    In the paper's section 4.0.2 there is a paragraph called Network Configuration. Using the configuration described there, here is what our TensorFlow model looks like (we used tf.keras.utils.plot_model to generate this drawing):



    The implementation of the model is pretty straightforward as the model is a sequence of Dense / BatchNorm layers which are readily available in the TensorFlow ecosystem. You can find it in this cell here.

    As one can see, the NN learns a mapping between the features and the Weibull coefficients $(\beta_i, \eta_i)_{i \leq n}$. The survival examples $(t_i, \delta_i)$ are NOT fed in the model: they are just used to compute the loss. In this Colab notebook, one can find the implementation of the model with $p=1$ and using the SYNTHETIC dataset described in the paper. 
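
    For illustration, here is a minimal sketch of such a regressor in the Functional API (the exact layer sizes, ordering and activations of the paper and notebook may differ; the offsets mirror point 5 of the list below):

```python
import tensorflow as tf

def build_regressor(n_features, hidden=32):
    # Maps the covariates X_i to the Weibull parameters (beta_i, eta_i).
    inputs = tf.keras.Input(shape=(n_features,))
    x = tf.keras.layers.Dense(hidden)(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation("relu")(x)
    x = tf.keras.layers.Dense(hidden)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation("relu")(x)
    raw = tf.keras.layers.Dense(2, activation="softplus")(x)
    # Offset the parameters into their expected range (see point 5 of the list below).
    beta = raw[:, 0:1] + 2.0
    eta = raw[:, 1:2] + 1.0 + 1e-6
    params = tf.keras.layers.Concatenate()([beta, eta])
    return tf.keras.Model(inputs, params)

model = build_regressor(n_features=1)
```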

    Here are some comments about the implementation:

    1. We use the tf.data.Dataset API to batch through our data samples. Every batch is a tuple $(X, y)$. Usually, $y$ is a tensor with only one column (think about classification: $y$ would only be the class label for every sample). But here, there are actually two columns: $(t_i, \delta_i)$, so we must be careful about splitting it back into two columns before using it in our loss
    2. We use the Functional API to build our model. This API consists in instantiating the layers of the model and directly using them as functions, and then wrapping all of them into an object tf.keras.models.Model. This API is super powerful and allows you to build a huge amount of network architectures
    3. We do not use a validation set for the experiment: the goal of the SYNTHETIC dataset is to make sure we are able to retrieve the true parameters of the simulated data, so for once we're totally allowed to overfit as much as we want. This is usually what you would do when building fake data to make sure everything works as expected
    4. Theoretically, minimizing the loss built from $\mathcal{L}$ would yield the same results as the one built from $\mathcal{LL}$. However, if you implement $\mathcal{L}$ as your loss, you will run into numerical instability because of the exponential function (which is part of the Weibull density and CDF). This is why we write down $\mathcal{LL}$, which removes the exponential (a minimal sketch of such a loss follows this list)
    5. We offset the parameters $(\beta_i, \eta_i)$ so that they fall in the range we expect them to have, as suggested in the paper (check out the tf.add(beta, 2.) and tf.add(eta, 1.+1e-6) in the notebook).
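
    Here is a minimal sketch of such a loss (the notebook's exact code may differ): $y$ packs $(t_i, \delta_i)$ as two columns (point 1), and the model's output packs the predicted, already-offset $(\beta_i, \eta_i)$ (point 5):

```python
import tensorflow as tf

def weibull_nll(y_true, y_pred):
    # y_true: (batch, 2) with columns (t_i, delta_i); y_pred: (batch, 2) with (beta_i, eta_i).
    t, delta = y_true[:, 0], y_true[:, 1]
    beta, eta = y_pred[:, 0], y_pred[:, 1]
    z = (t / eta) ** beta
    log_f = tf.math.log(beta / eta) + (beta - 1.0) * tf.math.log(t / eta) - z  # log density
    log_s = -z                                                                  # log survival
    return -tf.reduce_mean(delta * log_f + (1.0 - delta) * log_s)
```
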
    Now, about the red and green bars in Figure 4 of the article: when computing the real NLL, our results are quite different. Using the same parameters and method described in the paper, we obtain 0.6318, 0.827 and 0.652. Note that I've tried simulating the SYNTHETIC dataset multiple times, but I always get numbers within the same range. What I conclude from this is that something is wrong either in the coefficients given in the paper or in my understanding of how they should be used.

    By the way, in case this was not clear, the real log likelihood is nothing but the log-likelihood using the real Weibull coefficients $(\beta_i, \eta_i)_{i\leq n}$ that were simulated when building the SYNTHETIC data. The Colab notebook has a section that is dedicated to computing it, and the code should make it crystal clear for you.


    Even though the numbers being different is quite disappointing, it doesn't really matter: it's still possible to train the model and check whether the loss of the trained model gets close to the real loss. The good news is that it does! For instance, in one experiment with $f_i=f_2$ (i.e. using $X^2$ as features, check the paper to see what $f_i$ means) I could approximate the real loss with an error of the order of $10^{-4}$, just like in the paper (see the last cell of the notebook).

    Here are some comments about the training:
    • Batch size = 64
    • Learning rate = $10^{-3}$ with Adam optimizer
    • Number of epochs 200 with early stopping of 40 epochs (using the training loss, there's no validation here)
    I ran the fit twice with 200 epochs, as the early stopping criterion was never met, and stopped there once I noticed I had reached a very good approximation of the real NLL.
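
    For reference, a sketch of what such a training call could look like (train_dataset is a hypothetical tf.data.Dataset of $(X, y)$ pairs, and weibull_nll is the loss sketched above):

```python
import tensorflow as tf

# Assuming train_dataset yields unbatched (X, y) pairs and weibull_nll is defined as above.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), loss=weibull_nll)
model.fit(
    train_dataset.batch(64),
    epochs=200,
    callbacks=[tf.keras.callbacks.EarlyStopping(monitor="loss", patience=40)],
)
```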

    3.2 Simple case: $p=1$ (with Tensorflow Probability)

    Really glad you've read the article up to here! At its core, the model we've been studying is a probabilistic model, as it directly models a probability distribution. Up to now we have abstracted the distribution through the mathematical formulae of its density and CDF, because these are the kinds of objects Tensorflow handles. But TensorFlow Probability handles distributions as first-class objects, and it has layers that take tensors as input (typically matrices of parameters) and directly output such distributions (rather than mathematical expressions).

    Here are a few cool things that those distribution objects can handle and that are of direct interest for our task of implementing the paper (a tiny illustration follows the list):
    • They can evaluate their own densities, CDFs and... survival functions at any point(s) !
    • You can build mixtures of distributions, meaning we can easily switch from the case $p=1$ to any $p$ !
    • We can build the REAL distribution using Tensorflow probability and then we can use it to generate samples and compute the real NLL. This is very handy when dealing with $p > 1$, as sampling from a mixture density can become cumbersome, especially in our specific setup where we have decided to use the inverse CDF theorem to sample from a Weibull distribution (I kind of shot myself in the foot here...)
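
    As a tiny illustration of these points (the parameter values below are arbitrary, not those of the SYNTHETIC dataset):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

# A batch of two Weibull distributions, one per individual (concentration = beta, scale = eta).
real = tfd.Weibull(concentration=[2.0, 3.0], scale=[1.5, 0.8])

samples = real.sample(10_000)                    # draw event times, no hand-written inverse CDF
log_s = real.log_survival_function([1.0, 1.0])   # log S(t) at t = 1 for each individual
log_f = real.log_prob([1.0, 1.0])                # log density at t = 1
```
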
    This is what our model looks like when using TensorFlow Probability:




    The final layer that feeds from the predicted values of $(\beta_i, \eta_i)_{i\leq n}$ is a DistributionLambda layer. It basically feeds from parameter matrices and returns a distribution. Usually, when we create a loss we have something of the form loss(true, preds). Well here we will have exactly the same signature, except that preds will be a distribution rather than a tensor of numbers. 

    The model is built in this cell, and the magic happens in the code it contains.


    We need three elements (a minimal sketch follows this list):
    1. A function called mydist: this function generates a distribution using input parameters (that are typically output by some upstream layer in your neural network). This is the function we should modify if we want to switch from Weibull to another distribution, or if we want to include mixtures of Weibull rather than a simple Weibull
    2. The DistributionLambda layer: this is actually a Neural Network layer that will generate the distribution specified by mydist using input parameters
    3. The loss: here we called it survival_loss and it has the signature (y, distr) and it uses some methods of the Tensorflow distribution class: log_prob (which is the log density) and log_survival_function (self explanatory)
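
    Here is a minimal sketch of how these three pieces could fit together (the notebook's exact code may differ; the function names follow the post, while the layer sizes are assumptions and the offsets mirror the pure-Tensorflow version):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfpl = tfp.layers

def mydist(params):
    # Turn the two predicted columns into a Weibull (offsets as in the pure-TF version).
    beta = params[..., 0] + 2.0          # concentration (shape)
    eta = params[..., 1] + 1.0 + 1e-6    # scale
    return tfd.Weibull(concentration=beta, scale=eta)

def survival_loss(y, distr):
    # y packs (t_i, delta_i); distr is the distribution output by DistributionLambda.
    t, delta = y[..., 0], y[..., 1]
    ll = delta * distr.log_prob(t) + (1.0 - delta) * distr.log_survival_function(t)
    return -tf.reduce_mean(ll)

inputs = tf.keras.Input(shape=(1,))
x = tf.keras.layers.Dense(32, activation="relu")(inputs)
params = tf.keras.layers.Dense(2, activation="softplus")(x)
outputs = tfpl.DistributionLambda(mydist)(params)

model = tf.keras.Model(inputs, outputs)
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss=survival_loss)
```
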
    That's it! Even if we change the distribution of our model, we do not need to rewrite the loss, which could have been very error prone. The beautiful thing is that this code converges faster than the one using "pure" Tensorflow (365 epochs rather than 400) on the same synthetic data, and has a final loss that is even closer to the real NLL!

    3.3 General case $p\geq 1$ with Tensorflow probability

    So as we said in the section above, we now need to change the function mydist so that it returns a mixture of Weibull rather than a single Weibull. But we also need to change the network's architecture so that:
    • It returns the mixture weights (we need to implement the "classifier" part of the network, using the terminology used in the paper)
    • For every $i$, the network's encoder should return $p$ tuples $(\beta_i, \eta_i)$ as we have mixtures i.e there will now be a double indexing of the parameters: one for the individuals and one for the mixture components: $((\beta_i^k, \eta_i^k)_{k\leq p})_{i\leq n}$
    All of this is implemented in this cell of the notebook (and the next one); as you can see, there is not a lot of difference with the $p=1$ implementation! In particular, we didn't have to write down the NLL of a mixture of Weibulls by hand, which could have involved a lot of work to debug properly. Instead, we put our trust in the hands of the Google developers, which sounds reasonable.
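
    For illustration, here is what a mixture version of mydist could look like. This is a sketch assuming the upstream layers output $3p$ values per individual ($p$ mixture logits, then $p$ shape offsets, then $p$ scale offsets; this layout is an assumption, and the notebook may organise things differently). As long as the resulting distribution exposes log_prob and log_survival_function, the survival_loss above can be reused unchanged:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

def mydist_mixture(params, p=2):
    # params has shape (batch, 3 * p): p mixture logits, p shapes, p scales (assumed layout).
    logits = params[..., :p]
    beta = params[..., p:2 * p] + 2.0
    eta = params[..., 2 * p:3 * p] + 1.0 + 1e-6
    return tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=logits),
        components_distribution=tfd.Weibull(concentration=beta, scale=eta),
    )
```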

    After refactoring the model to account for the $p\geq 1$ case, the architecture now looks like this (input shape is 3 as we have displayed the architecture with $f_i=f_2$ i.e 3 covariates):



     
    Last step is... training ! We use the following training configuration (I didn't experiment a lot with the training configuration, just stopped as soon as I had something yielding good results, but you should tune this when working with real-world data):
    • Adam with LR $10^{-4}$
    • Batch size of 64
    • Patience of 20 when monitoring the training loss, with a maximum of 800 epochs
    After training, the theoretical and the predicted losses agree to within an error of $10^{-5}$!

    So here ends our discussion of the paper. In this post we're not going to check how the model performs on the real data provided by the authors, and we actually don't need to: the model is "perfect" on simulated data, so we know we got the code right. Now it's up to you to tweak it to solve your own time-to-event use case. Happy coding!




