 | FLAIR | PD | T1 | T2 | FLAIR_10 | PD_10 | T1_10 | T2_10 | FLAIR_20 | PD_20 | T1_20 | T2_20 | GOLD_Lesions |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1.143692 | 1.586219 | -0.799859 | 1.634467 | 0.437568 | 0.823800 | -0.002059 | 0.573663 | 0.279832 | 0.548341 | 0.219136 | 0.298662 | 0 |
1 | 1.652552 | 1.766672 | -1.250992 | 0.921230 | 0.663037 | 0.880250 | -0.422060 | 0.542597 | 0.422182 | 0.549711 | 0.061573 | 0.280972 | 0 |
2 | 1.036099 | 0.262042 | -0.858565 | -0.058211 | -0.044280 | -0.308569 | 0.014766 | -0.256075 | -0.136532 | -0.350905 | 0.020673 | -0.259914 | 0 |
3 | 1.037692 | 0.011104 | -1.228796 | -0.470222 | -0.013971 | -0.000498 | -0.395575 | -0.221900 | 0.000807 | -0.003085 | -0.193249 | -0.139284 | 0 |
19 Machine learning validation
20 Basics
When comparing our predictions to our actual values, there are several considerations. One is the distinction between agreement and association (Agresti 2003). Agreement implies association, but the converse does not necessarily hold. A Pearson correlation, for example, measures association, since it is invariant to any linear transformation of the predictions. In this sense, Pearson correlations do not check to what extent the predictions are calibrated to the response.
In addition, we have several settings worth considering. The outcome could be continuous, multivariate continuous, binary, multi-class (categorical), multi-label, ordinal (ordered categorical) and mixtures of these. The predictions are typically continuous or multivariate continuous. We’ll start by discussing the case where the outcome is binary and the predictions are also binary, for example by binarizing a continuous prediction. However, before we begin, we should discuss testing/training strategies.
20.1 Testing versus training versions
It’s important to emphasize that for each component of validation discussed, there’s a version applied to the training data and a version applied to held-out data. Often, data is broken into three components:
- Training data: data used to train the model.
- Validation data: data used to choose hyperparameters, such as the number of layers in a neural network.
- testing data: data used only for final evaluation of the model.
So, in a two-class classification problem, there is a training ROC, a validation ROC and a testing ROC. I would also add testing the model on novel held-out datasets, which is a stronger form of validation focusing on generalizability. In addition, different criteria would need to be applied to time series data, where one might differentiate forecasting error from other sorts of error. Simply holding out random times in a time series dataset is often not enough, since then you would be using the future to predict the past.
Here we will focus on the most common settings and consider different ways to probe model fit. Some of the strategies are widely used in ML/AI; others are more widely used in traditional statistical settings.
20.2 Binary outcomes (two-class classification)
Throughout this section, let $D$ denote the indicator that the disease (outcome) is present and let $+$ and $-$ denote a positive and a negative prediction (test result), respectively.

Consider the following definitions for the results of a diagnostic test; a small numeric sketch follows the list.

- Sensitivity, $P(+ \mid D)$: the probability that the prediction is positive given the disease is present, also called the true positive rate. One minus the sensitivity is the false negative rate.
- Specificity, $P(- \mid D^c)$: the probability that the prediction is negative given the disease is absent, also called the true negative rate. One minus the specificity is the false positive rate.
- PPV, positive predictive value, $P(D \mid +)$: the probability that the disease is present given a positive prediction.
- NPV, negative predictive value, $P(D^c \mid -)$: the probability that the disease is absent given a negative prediction.
- DLR+, diagnostic likelihood ratio of a positive prediction, $P(+ \mid D) / P(+ \mid D^c)$: the sensitivity over one minus the specificity, i.e. the true positive rate divided by the false positive rate.
- DLR-, diagnostic likelihood ratio of a negative prediction, $P(- \mid D) / P(- \mid D^c)$: one minus the sensitivity divided by the specificity, i.e. the false negative rate divided by the true negative rate.
- The disease prevalence is $P(D)$ (and, generally less discussed, the prediction prevalence is $P(+)$).
- The accuracy is $P(D, +) + P(D^c, -)$, which is the sensitivity times the prevalence plus the specificity times one minus the prevalence.
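To make the definitions concrete, here is a minimal sketch that computes these quantities from a hypothetical 2x2 table of counts; the counts below are invented purely for illustration.

```python
import numpy as np

# hypothetical counts: diseased (first row) versus non-diseased (second row)
tp, fn = 90, 10    # predicted positive / negative among the diseased
fp, tn = 45, 855   # predicted positive / negative among the non-diseased
n = tp + fn + fp + tn

sens = tp / (tp + fn)   # P(+ | D), the true positive rate
spec = tn / (tn + fp)   # P(- | D^c), the true negative rate
ppv  = tp / (tp + fp)   # P(D | +)
npv  = tn / (tn + fn)   # P(D^c | -)
prev = (tp + fn) / n    # P(D)
acc  = (tp + tn) / n    # probability of a correct prediction

# accuracy equals sensitivity x prevalence + specificity x (1 - prevalence)
print(np.isclose(acc, sens * prev + spec * (1 - prev)))
```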
In a frequency setting with a positive prediction, one might argue that the PPV (or, for a negative prediction, the NPV) is the quantity most directly relevant to the subject being tested, since it conditions on what was actually observed.
If you have a cross-sectional sample, then all of these quantities are directly estimable. If the data were sampled by case/control status (that is, separately within the diseased and non-diseased groups), then the sensitivity and specificity remain directly estimable, but the PPV, NPV and prevalence are not without external information about the prevalence.
20.2.1 Basic example
A study comparing the efficacy of HIV tests reports on an experiment which concluded that HIV antibody tests have a sensitivity of 99.7% and a specificity of 98.5%. Suppose that a subject, from a population with a 0.1% prevalence of HIV, receives a positive test result. What is the positive predictive value?
Mathematically, we want $P(D \mid +)$ given the sensitivity, specificity and prevalence. By Bayes' rule,
$$
P(D \mid +) = \frac{P(+ \mid D) P(D)}{P(+ \mid D) P(D) + P(+ \mid D^c) P(D^c)}
= \frac{0.997 \times 0.001}{0.997 \times 0.001 + 0.015 \times 0.999} \approx 0.062.
$$
In this population a positive test result suggests only about a 6% probability that the subject has the disease (the positive predictive value is 6% for this test). If you were wondering how it could be so low, the low positive predictive value is due to the low prevalence of disease and the somewhat modest specificity.
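A quick check of the calculation above, using the stated sensitivity, specificity and prevalence:

```python
# PPV via Bayes' rule with sensitivity 99.7%, specificity 98.5%, prevalence 0.1%
sens, spec, prev = 0.997, 0.985, 0.001
ppv = sens * prev / (sens * prev + (1 - spec) * (1 - prev))
print(round(ppv, 3))   # approximately 0.062
```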
Suppose it was known that the subject was an intravenous drug user and routinely had intercourse with an HIV-infected partner. Our prevalence estimate would change dramatically, thus increasing the PPV. You might wonder whether there’s a way to summarize the evidence without appealing to an often unknowable prevalence. Diagnostic likelihood ratios provide this for us.
We have:
$$
P(D \mid +) = \frac{P(+ \mid D) P(D)}{P(+)}
$$
and
$$
P(D^c \mid +) = \frac{P(+ \mid D^c) P(D^c)}{P(+)}.
$$
Therefore, dividing these two equations we have:
$$
\frac{P(D \mid +)}{P(D^c \mid +)} = \frac{P(+ \mid D)}{P(+ \mid D^c)} \times \frac{P(D)}{P(D^c)}.
$$
In other words, the post-test odds of disease is the pretest odds of disease times the $DLR_+$. Similarly, the post-test odds of disease after a negative test is the pretest odds times the $DLR_-$.
**HIV example revisited.** Let’s reconsider our HIV antibody test again. Suppose a subject has a positive HIV test. Then
$$
DLR_+ = \frac{0.997}{1 - 0.985} \approx 66,
$$
so, regardless of the pretest odds of disease, the post-test odds are about 66 times greater.

Suppose instead that a subject has a negative test result. Then
$$
DLR_- = \frac{1 - 0.997}{0.985} \approx 0.003,
$$
so the post-test odds of disease are roughly 0.3% of the pretest odds.
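A short sketch of the odds updating for this example; converting the post-test odds after a positive test back to a probability recovers the PPV from before.

```python
sens, spec, prev = 0.997, 0.985, 0.001

dlr_pos = sens / (1 - spec)        # about 66
dlr_neg = (1 - sens) / spec        # about 0.003

pretest_odds = prev / (1 - prev)
posttest_odds_pos = pretest_odds * dlr_pos
posttest_odds_neg = pretest_odds * dlr_neg

# back-transform the positive-test odds to a probability
print(posttest_odds_pos / (1 + posttest_odds_pos))
```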
20.2.2 ROC curves
By thresholding a continuous prediction or marker at every possible cutoff, we generate a family of binary predictions. Plotting the associated true positive rate against the false positive rate as the threshold varies traces out the ROC (receiver operating characteristic) curve. As an example, consider using the FLAIR intensity to predict the gold standard lesion indicator. First, let's look at the distribution of FLAIR within the two gold standard groups.
```python
x = dat.FLAIR
y = dat.GOLD_Lesions
x0 = dat.FLAIR[y == 0]
x1 = dat.FLAIR[y == 1]

sns.kdeplot(x0, shade = True, label = 'Gold Std = 0')
sns.kdeplot(x1, shade = True, label = 'Gold Std = 1')
plt.show()
```
Consider a given FLAIR threshold, say $c$, and predict a lesion whenever the FLAIR value satisfies $X \geq c$. The associated true positive rate is $TPR(c) = P(X \geq c \mid D)$ and the false positive rate is $FPR(c) = P(X \geq c \mid D^c)$. The ROC curve is the plot of $TPR(c)$ versus $FPR(c)$ over all thresholds $c$.
The ROC curve satisfies:
1. Starts at the point (0, 0). This can be seen as $c \rightarrow \infty$ implies $P(X \geq c) \rightarrow 0$, which implies $TPR(c) \rightarrow 0$ and $FPR(c) \rightarrow 0$.
2. Ends at the point (1, 1). This can be seen as $c \rightarrow -\infty$ implies $P(X \geq c) \rightarrow 1$, which implies $TPR(c) \rightarrow 1$ and $FPR(c) \rightarrow 1$.
3. Is monotonic. This can be seen as $c$ increasing implies that both $TPR(c)$ and $FPR(c)$ are non-increasing, which implies the ROC curve (the TPR viewed as a function of the FPR) is non-decreasing.
4. A uniformly better ROC curve lies entirely above a worse ROC curve. This follows from 3 and from the interpretation that higher values in the curve mean higher true positive rates for a fixed false positive rate. (Note two ROC curves can cross, so that one may not be uniformly better than the other.)
5. Is always worse than the discontinuous function through (0, 0), (0, 1), (1, 1). This follows from 1-3.
6. If $X$ is independent of $D$, the ROC curve is the identity line from (0, 0) to (1, 1). This follows from $P(X \geq c \mid D) = P(X \geq c \mid D^c) = P(X \geq c)$ for all $c$, thus $TPR(c) = FPR(c)$ and hence the curve is the identity.
7. The ROC curve is invariant to strictly increasing monotonic transformations. Let $W = g(X)$ for a strictly increasing function $g$, and let $TPR_W$ and $FPR_W$ be the associated true and false positive rates. Note then $TPR_W\{g(c)\} = P(W \geq g(c) \mid D) = P(X \geq c \mid D) = TPR(c)$, and similarly $FPR_W\{g(c)\} = FPR(c)$. Then the ROC function, $TPR \circ FPR^{-1}$ (where $\circ$ is composition), is unchanged.
8. The ROC curve is an identity line whenever the marker has the same distribution in the diseased and non-diseased groups, provided that distribution is continuous. This follows from 6 and 7, since if $X$ follows distribution $F$ in both groups, then $F(X)$ is uniform in both groups by the probability integral transform.
9. The ROC curve for $X$ as a test for $D^c$ flips the ROC curve of $X$ as a test for $D$ over the identity line. (This simply reverses $TPR$ and $FPR$, hence the result.)
20.2.2.1 Estimation
A natural (and consistent) estimate of the ROC curve replaces $TPR(c)$ and $FPR(c)$ with their empirical versions: for each threshold $c$, the proportion of diseased observations with $X \geq c$ and the proportion of non-diseased observations with $X \geq c$, respectively.
```python
## Add terms at the beginning and the end over the max and under the min
c = np.concatenate([[np.min(x) - 1], np.sort(np.unique(x)), [np.max(x) + 1]])

tpr = [np.mean( (x1 >= citer) ) for citer in c]
fpr = [np.mean( (x0 >= citer) ) for citer in c]

plt.plot(fpr, tpr)
plt.plot([0, 1], [0, 1])
```
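As a cross-check of the hand-rolled curve, the same ROC can be obtained from scikit-learn (assuming it is available), which also sweeps over all observed thresholds:

```python
from sklearn.metrics import roc_curve

# y is the gold standard lesion indicator, x the FLAIR values from above
fpr_sk, tpr_sk, _ = roc_curve(y, x)

plt.plot(fpr, tpr, label = 'by hand')
plt.plot(fpr_sk, tpr_sk, '--', label = 'sklearn')
plt.plot([0, 1], [0, 1])
plt.legend()
plt.show()
```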
20.2.2.2 Binormal estimation
We could also assume distributional forms for the marker in each group, say $X \mid D^c \sim N(\mu_0, \sigma_0^2)$ and $X \mid D \sim N(\mu_1, \sigma_1^2)$, plug in estimated means and standard deviations, and evaluate the resulting true and false positive rates over a grid of thresholds. This is the so-called binormal ROC curve.
```python
from scipy.stats import norm

mu0, mu1 = np.mean(x0), np.mean(x1)
s0, s1 = np.std(x0), np.std(x1)
c_seq = np.linspace(0, 3, 1000)

fpr_binorm = 1 - norm.cdf(c_seq, mu0, s0)
tpr_binorm = 1 - norm.cdf(c_seq, mu1, s1)

plt.plot(fpr, tpr)
plt.plot(fpr_binorm, tpr_binorm)
plt.plot([0, 1], [0, 1])
```
20.2.2.3 AUC
The area under the ROC curve, $AUC = \int_0^1 ROC(t) \, dt$, is a common one-number summary of a test. It has the useful interpretation $AUC = P(X_1 > X_0)$, the probability that the marker value of a randomly selected diseased subject exceeds that of a randomly selected non-diseased subject.
Using this interpretation, we can easily calculate the AUC for the binormal model: since $X_1 - X_0 \sim N(\mu_1 - \mu_0, \sigma_0^2 + \sigma_1^2)$, we have $AUC = \Phi\left\{(\mu_1 - \mu_0) / \sqrt{\sigma_0^2 + \sigma_1^2}\right\}$.
It can be shown that the Wilcoxon rank sum (Mann-Whitney) test is a test of whether the AUC equals 0.5; in fact, the Mann-Whitney statistic, suitably normalized, is exactly the empirical AUC.
In some cases the full AUC isn’t of interest, so a partial AUC can be used.
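A sketch of both AUC calculations using the quantities defined above: the empirical AUC as the proportion of diseased/non-diseased pairs that are correctly ordered (counting ties as one half), and the binormal AUC from the formula $\Phi\{(\mu_1 - \mu_0)/\sqrt{\sigma_0^2 + \sigma_1^2}\}$.

```python
# empirical AUC: P(X_1 > X_0) estimated over all pairs, ties counted as 1/2
x0a, x1a = x0.to_numpy(), x1.to_numpy()
auc_emp = np.mean((x1a[:, None] > x0a[None, :]) + 0.5 * (x1a[:, None] == x0a[None, :]))

# binormal AUC using the group means and standard deviations from above
auc_binorm = norm.cdf((mu1 - mu0) / np.sqrt(s0 ** 2 + s1 ** 2))

print(auc_emp, auc_binorm)
```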
20.2.3 Calibration
We say that a classification probability, $\hat P(X)$, is calibrated if
$$
P\{Y = 1 \mid \hat P(X) = p\} = p
$$
for all $p$. That is, among observations assigned predicted probability $p$, the outcome occurs with probability $p$.

Let $\tilde P(X) = P(Y = 1 \mid X)$ denote the optimal classification probability.

One way to look at the optimal classification probability is to consider the following properties.

1. $\tilde P$ is calibrated.
2. $Y$ is conditionally independent of $X$ given $\tilde P(X)$. That is, the covariates contain no additional information about $Y$ beyond what is contained in $\tilde P(X)$.
3. $Y$ is conditionally independent of $\hat P(X)$ given $\tilde P(X)$ for any other classification prediction probability $\hat P(X)$; in particular this holds when $\hat P$ is coarser than $\tilde P$.
4. A classification probability, $\hat P(X)$, is calibrated if and only if $\hat P(X) = P\{Y = 1 \mid \hat P(X)\}$.

Here we say that one classification probability is finer than another (and the other is coarser) if the latter can be written as a function of the former.

For 1., iterated expectations give $P\{Y = 1 \mid \tilde P(X)\} = E[P(Y = 1 \mid X) \mid \tilde P(X)] = E[\tilde P(X) \mid \tilde P(X)] = \tilde P(X)$, so the optimal classification probability is calibrated.

For 3., the conditional independence follows since $P\{Y = 1 \mid \tilde P(X), \hat P(X)\} = E[\tilde P(X) \mid \tilde P(X), \hat P(X)] = \tilde P(X)$, which does not depend on $\hat P(X)$. Then it follows immediately for a classifier that is coarser: there then exists a function $g$ with $\hat P(X) = g\{\tilde P(X)\}$, so $\hat P(X)$ is determined once $\tilde P(X)$ is known.

For 4., calibration is, by definition, the statement that $P\{Y = 1 \mid \hat P(X) = p\} = p$ for every $p$, which is the same as the random variable identity $\hat P(X) = P\{Y = 1 \mid \hat P(X)\}$.
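A minimal sketch of an empirical calibration check: bin the predicted probabilities and compare each bin's average prediction to the observed event rate. Here `p_hat` and `y_obs` are simulated stand-ins for a vector of predicted probabilities and the corresponding binary outcomes, not objects defined elsewhere in this chapter.

```python
import numpy as np

def calibration_table(p_hat, y_obs, bins = 10):
    ## average predicted probability and observed event rate within bins of p_hat
    edges = np.linspace(0, 1, bins + 1)
    idx = np.digitize(p_hat, edges[1:-1])
    out = []
    for b in range(bins):
        keep = idx == b
        if keep.any():
            out.append((p_hat[keep].mean(), y_obs[keep].mean(), int(keep.sum())))
    return out

## simulate predictions that are calibrated by construction
rng = np.random.default_rng(0)
p_hat = rng.uniform(size = 1000)
y_obs = rng.binomial(1, p_hat)

for p, o, m in calibration_table(p_hat, y_obs):
    print(f"predicted {p:.2f}, observed {o:.2f}, n = {m}")
```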
20.2.4 Agreement
Expected agreement is optimized by thresholding the optimal classification probability at 0.5. That is, the optimal classifier (the Bayes classifier) is $\hat Y = I\{\tilde P(X) > 0.5\}$.
Let $\pi_{ij} = P(\hat Y = i, Y = j)$ denote the cell probabilities of the two by two table cross-classifying the prediction and the outcome. Strong agreement suggests that the discordant probabilities, $\pi_{01}$ and $\pi_{10}$, are small, so that $\pi_{00} + \pi_{11}$ is close to one; weaker notions of agreement, such as marginal homogeneity, only require the prediction and the outcome to have similar marginal distributions.
20.2.4.1 Marginal homogeneity
One question that one could reasonably ask is how well the marginal distribution of the fitted values mirrors the marginal distribution of the actual values. That is, does $P(\hat Y = 1) = P(Y = 1)$?

Consider comparing the two. Let $n_{ij}$ be the number of observations with $\hat Y = i$ and $Y = j$ out of $n$ total. The difference in estimated marginal probabilities is $\hat P(\hat Y = 1) - \hat P(Y = 1) = (n_{10} - n_{01}) / n$, which depends only on the discordant counts; McNemar's test of marginal homogeneity correspondingly compares $n_{10}$ to $n_{01}$.
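As an illustration with the imaging data from earlier, here is a sketch that thresholds FLAIR to get a binary prediction, compares the two marginal proportions and applies McNemar's test of marginal homogeneity (which uses only the discordant cells). The threshold of 1 is arbitrary and chosen only for illustration.

```python
import pandas as pd
from statsmodels.stats.contingency_tables import mcnemar

yhat_bin = (dat.FLAIR > 1).astype(int)   # arbitrary threshold, illustration only
y_bin = dat.GOLD_Lesions

## compare the marginal proportions of predicted and actual positives
print(yhat_bin.mean(), y_bin.mean())

## McNemar's test based on the 2 x 2 cross-classification
tab = pd.crosstab(yhat_bin, y_bin)
print(mcnemar(tab, exact = False, correction = True))
```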
20.2.4.2 Conditional logistic regression
An alternative to marginal agreement comes from subject-specific models. The idea here would be to create a model of the form
$$
\mathrm{logit}\{P(Y_{ij} = 1)\} = \alpha_i + \beta x_j,
$$
where $i$ indexes the subject, $j = 1, 2$ indexes whether the response is the actual outcome or the prediction, and $x_j$ is an indicator that the response is the prediction.

Here, this suggests that there is a person-specific effect, $\alpha_i$, that shifts both the actual outcome and the prediction, while $\beta$ captures a systematic difference between the prediction and the outcome; $\beta = 0$ corresponds to agreement in this subject-specific sense.

Thus, conditional logistic regression simply estimates $\beta$ after eliminating the nuisance parameters $\alpha_i$ by conditioning on each subject's total, $y_{i1} + y_{i2}$; only subjects with discordant responses contribute to the estimate.
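For the matched-pairs case, the conditional ML estimate of $\beta$ depends only on the discordant cells (its sign depends on which discordant cell is placed in the numerator). A sketch using the thresholded FLAIR prediction, `yhat_bin` and `y_bin`, from the sketch above:

```python
## discordant counts: prediction positive / actual negative, and the reverse
n10 = int(((yhat_bin == 1) & (y_bin == 0)).sum())
n01 = int(((yhat_bin == 0) & (y_bin == 1)).sum())

## conditional ML estimate of the prediction-versus-actual log odds ratio
beta_hat = np.log(n10 / n01)
print(n10, n01, beta_hat)
```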
20.3 Multi-label and multi-class
Multi-label prediction validation typically follows from two-class classification. For example, when trying to ascertain whether there is a dog and a tractor in a picture, it is reasonable to evaluate the performance of predicting the presence of a dog while separately evaluating the performance of predicting the presence of a tractor. The fact that records can have multiple labels does not change this per-label evaluation.
Multi-class problems are more direct generalizations of our two-class classification problem. Of the two-class definitions, the accuracy remains obviously well defined and can be used without modification. One could also evaluate the sensitivity and specificity, etc., associated with each class versus the rest. We’ll also show some more model-based approaches for evaluating multi-class performance that more completely characterize the agreement, association and performance.
The standard summary of a multi-class output is a confusion matrix. Let’s look at the confusion matrix from the medmnist algorithm.
First, here are example pathology images.
```python
train_dataset.montage(length = 20)
```
Here are the different class labels.
```python
info['label']
```
{'0': 'adipose',
'1': 'background',
'2': 'debris',
'3': 'lymphocytes',
'4': 'mucus',
'5': 'smooth muscle',
'6': 'normal colon mucosa',
'7': 'cancer-associated stroma',
'8': 'colorectal adenocarcinoma epithelium'}
And here’s a general description of the data and problem.
```python
info['description']
```
'The PathMNIST is based on a prior study for predicting survival from colorectal cancer histology slides, providing a dataset (NCT-CRC-HE-100K) of 100,000 non-overlapping image patches from hematoxylin & eosin stained histological images, and a test dataset (CRC-VAL-HE-7K) of 7,180 image patches from a different clinical center. The dataset is comprised of 9 types of tissues, resulting in a multi-class classification task. We resize the source images of 3×224×224 into 3×28×28, and split NCT-CRC-HE-100K into training and validation set with a ratio of 9:1. The CRC-VAL-HE-7K is treated as the test set.'
I ran the algorithm from the MEDMNIST site in the background (code in the quarto document). Here is our testing dataset confusion matrix.
```python
#| cache: true
targets_pred, targets_actual = [], []

for i, t in test_loader:
    targets_pred.append(model(i).softmax(dim = -1))
    targets_actual.append(t)

targets_pred = torch.cat(targets_pred, dim = 0).detach().numpy().argmax(axis = 1)
targets_actual = torch.cat(targets_actual, dim = 0).numpy().squeeze()

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(targets_actual, targets_pred)
print(cm)
plt.imshow(cm);
```
[[1290 0 0 0 8 31 8 0 1]
[ 0 847 0 0 0 0 0 0 0]
[ 0 0 152 9 0 149 0 29 0]
[ 0 0 56 554 0 0 19 3 2]
[ 41 25 1 0 881 13 7 35 32]
[ 3 18 23 40 0 462 0 43 3]
[ 1 0 30 41 14 2 584 11 58]
[ 0 0 131 7 2 61 2 191 27]
[ 0 2 59 24 7 2 57 15 1067]]
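Following the earlier remark about evaluating each class versus the rest, here is a sketch that computes each class's sensitivity (recall) and specificity directly from the confusion matrix `cm` (rows are actual classes, columns are predicted classes).

```python
totals = cm.sum()
row_sums = cm.sum(axis = 1)   # actual counts per class
col_sums = cm.sum(axis = 0)   # predicted counts per class

for k in range(cm.shape[0]):
    tp = cm[k, k]
    fn = row_sums[k] - tp
    fp = col_sums[k] - tp
    tn = totals - tp - fn - fp
    print(f"class {k}: sensitivity = {tp / (tp + fn):.3f}, specificity = {tn / (tn + fp):.3f}")
```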
We can calculate accuracy easily as the fraction of times that the predicted class equals the actual class.
print(np.round(np.mean( targets_pred == targets_actual ) * 100, 3))
print(len(targets_pred))
print(1 / np.sqrt(len(targets_pred)))
83.955
7180
0.011801515411874575
So, we’re getting roughly 84% test set accuracy on about 7,000 cases. Since a 95% binomial confidence interval has margin of error (MOE) of roughly $1/\sqrt{n}$, which here is about 0.012, the accuracy is estimated to within roughly plus or minus 1.2 percentage points.
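A sketch of the corresponding interval, assuming statsmodels is available:

```python
from statsmodels.stats.proportion import proportion_confint

n_correct = int(np.sum(targets_pred == targets_actual))
n_total = len(targets_pred)

## 95% Wald interval for the test set accuracy
print(proportion_confint(n_correct, n_total, alpha = 0.05, method = 'normal'))
```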
20.3.1 Loglinear models
Poisson loglinear models are often used for modeling contingency tables like confusion matrices. They have a correspondence with the associated multinomial models, since if the cell counts are independent Poisson variables, then their joint distribution conditional on the total count is multinomial.

In our case, let $n_{ij}$ be the count of observations with actual class $i$ and predicted class $j$, and let $\mu_{ij} = E[n_{ij}]$.
Independence is the model that
$$
\log(\mu_{ij}) = \lambda + \lambda_i^{Y} + \lambda_j^{\hat Y},
$$
i.e. that the actual and predicted classes are statistically independent. Under independence the prediction carries no information: the true positive rate for any class equals the marginal rate at which that class is predicted, and the false positive rate for that class equals that same marginal rate.
A second, less useful, log-linear model is symmetry. Symmetry assumes $\mu_{ij} = \mu_{ji}$ for all $i$ and $j$.

Note that symmetry implies marginal homogeneity, since $\mu_{i+} = \sum_j \mu_{ij} = \sum_j \mu_{ji} = \mu_{+i}$.
20.3.1.1 Quasi-independence
A useful deviation from independence is quasi-independence. This model assumes independence in the off-diagonal cells. Specifically,
$$
\log(\mu_{ij}) = \lambda + \lambda_i^{Y} + \lambda_j^{\hat Y} + \delta_i I(i = j),
$$
so the diagonal cells are fit exactly while the off-diagonal cells follow an independence structure.
20.3.1.1.1 Quasi-symmetry
Quasi-symmetry is a very general model that contains independence, symmetry, quasi-independence and marginal homogeneity as special cases. Quasi-symmetry specifies that
$$
\log(\mu_{ij}) = \lambda + \lambda_i^{Y} + \lambda_j^{\hat Y} + \lambda_{ij}, \qquad \text{where } \lambda_{ij} = \lambda_{ji}.
$$
20.3.1.2 Deviance
The deviance statistic is used to measure model fit. The deviance is specified as twice the difference between the maximized log likelihood of the saturated model and that of the fitted model,
$$
D = 2 \sum_{i, j} n_{ij} \log\left(\frac{n_{ij}}{\hat \mu_{ij}}\right),
$$
and, for an adequately fitting model, it is approximately chi-squared distributed with the residual degrees of freedom.
20.3.1.3 Example
Let’s look at these model fits for our confusion matrix. First, let’s get the data into a dataframe. There are 81 cells and value_counts omits zero counts (which we need).
```python
## This is the confusion matrix data
valdat = pd.DataFrame({'y' : targets_actual, 'yhat' : targets_pred}).value_counts().reset_index().rename(columns = {0 : 'n'})
valdat.head(12)

## There are a ton of values with nothing
## so create a matrix with all of the values
## then we can merge that in and set the missing values to 0
grid = np.array([(x, y) for x in range(n_classes) for y in range(n_classes)])
grid = pd.DataFrame(grid, columns = ['y', 'yhat'])

## Merge in the values with nothing
valdat = valdat.merge(grid, how = "outer").fillna(0)
```
Let’s fit each of these models and compare them.
```python
import statsmodels.api as sm
import statsmodels.formula.api as smf

indep_formula = 'n ~ C(y) + C(yhat)'
mod_indep = smf.glm(formula = indep_formula,
                    data = valdat,
                    family = sm.families.Poisson()
                    ).fit()

## The symmetry term is just the concatenated y and yhat in order
valdat['symm'] = np.fmin(valdat['y'], valdat['yhat']).astype(str) + np.fmax(valdat['y'], valdat['yhat']).astype(str)

symm_formula = 'n ~ symm'
mod_symm = smf.glm(formula = symm_formula,
                   data = valdat,
                   family = sm.families.Poisson()
                   ).fit()

## This sets the non-diagonal elements as the lowest value
## so when the model fits, that's a reference category and
## it includes terms for every diagonal element.
valdat['indep'] = '0'
diag = valdat['y'] == valdat['yhat']
valdat.loc[diag, 'indep'] = valdat.symm[diag]

quasi_indep_formula = 'n ~ C(y) + C(yhat) + indep'
mod_quasi_indep = smf.glm(formula = quasi_indep_formula,
                          data = valdat,
                          family = sm.families.Poisson()
                          ).fit()

## define the model
quasi_symm_formula = 'n ~ C(y) + C(yhat) + symm'
## this is what we would fit if we could
qs_design = smf.glm(formula = quasi_symm_formula,
                    data = valdat, family = sm.families.Poisson())
## The above doesn't fit for me. So, I grab the design
## matrix from this formula, drop the redundant columns
## and fit directly without the formula
x, d, v = np.linalg.svd(qs_design.exog, 0)
## the correct DF for a QS model is
df_qs = int(1 + 2 * (n_classes - 1) + n_classes * (n_classes - 1) / 2)
x = x[:, 0 : df_qs]
y = valdat['n'].to_numpy()
mod_quasi_symm = sm.GLM(y, x, family = sm.families.Poisson()).fit()

## Note the model DF typically excludes the intercept
print(
    pd.DataFrame(
        {'models' : ['I', 'S', 'QI', 'QS'],
         'deviance' : [mod_indep.deviance, mod_symm.deviance, mod_quasi_indep.deviance, mod_quasi_symm.deviance],
         'model df' : [mod_indep.df_model, mod_symm.df_model, mod_quasi_indep.df_model, mod_quasi_symm.df_model],
         'resid df' : [mod_indep.df_resid, mod_symm.df_resid, mod_quasi_indep.df_resid, mod_quasi_symm.df_resid]
        }
    )
)
print(
    pd.DataFrame(
        {'models' : ['S', 'S - QS'],
         'deviance' : [mod_symm.deviance, mod_symm.deviance - mod_quasi_symm.deviance],
         'model df' : [mod_symm.df_model, mod_quasi_symm.df_model - mod_symm.df_model]
        }
    )
)
```
models deviance model df resid df
0 I 22463.164035 16 64
1 S 634.422898 44 36
2 QI 1359.774930 25 55
3 QS 298.520707 52 28
models deviance model df
0 S 634.422898 44
1 S - QS 335.902190 8
It would appear that none of independence, symmetry or QI fit this data well. The quasi-symmetry model does appear to fit substantially better than the others. Moreover, models that incorporate symmetry seem to be preferable to those that focus on independence.
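As a rough check, each deviance can be compared to a chi-squared distribution with the model's residual degrees of freedom, with the caveat that the approximation is questionable when many cells are small. A sketch:

```python
from scipy.stats import chi2

for name, mod in [('I', mod_indep), ('S', mod_symm),
                  ('QI', mod_quasi_indep), ('QS', mod_quasi_symm)]:
    ## upper tail probability of the deviance under the chi-squared reference
    print(name, round(mod.deviance, 1), int(mod.df_resid), chi2.sf(mod.deviance, mod.df_resid))
```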
20.3.1.4 Kappa
Recall the distinction between agreement and association. In a two-class problem, a perfect algorithm where someone accidentally switched the labels would still have perfect association but no agreement.
Kappa measures agreement beyond chance. Accuracy in our multinomial model is the total diagonal probability, $P_o = \sum_i \pi_{ii}$, while the agreement expected by chance under independence of the margins is $P_e = \sum_i \pi_{i+} \pi_{+i}$. Kappa is
$$
\kappa = \frac{P_o - P_e}{1 - P_e},
$$
with estimator
$$
\hat \kappa = \frac{\hat P_o - \hat P_e}{1 - \hat P_e}, \qquad \hat P_o = \sum_i \frac{n_{ii}}{n}, \qquad \hat P_e = \sum_i \frac{n_{i+}}{n} \frac{n_{+i}}{n}.
$$
Here is the kappa score for our pathology example.
sklearn.metrics.cohen_kappa_score(targets_actual, targets_pred)
0.8159986513828732
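As a check, computing the estimator above directly from the confusion matrix `cm` gives the same value:

```python
n = cm.sum()
p_o = np.trace(cm) / n                                      # observed agreement (accuracy)
p_e = np.sum(cm.sum(axis = 0) * cm.sum(axis = 1)) / n ** 2  # chance agreement from the margins
kappa_hand = (p_o - p_e) / (1 - p_e)
print(kappa_hand)
```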
A standard error for kappa can be obtained using asymptotic normality from the multinomial (see Agresti 2003, sec. 10.5.4). In addition, one could simply use a bootstrap.
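A minimal bootstrap sketch for the kappa standard error, resampling cases with replacement:

```python
import sklearn.metrics

rng = np.random.default_rng(0)
n_obs = len(targets_actual)

boot = []
for _ in range(1000):
    idx = rng.integers(0, n_obs, n_obs)
    boot.append(sklearn.metrics.cohen_kappa_score(targets_actual[idx], targets_pred[idx]))

print(np.std(boot))   # bootstrap standard error of kappa
```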
20.4 Prediction
The standard metrics for prediction of a continuous outcome are the mean squared error, the correlation and $R^2$.
It’s interesting to note that, simply by the calculation of a variance, the theoretical MSE satisfies:
$$
E[(Y - \hat Y)^2] = \mathrm{Var}(Y - \hat Y) + \left\{E[Y - \hat Y]\right\}^2.
$$
Thus, we have that our MSE breaks down into residual variability and squared bias.
The correlation is a measure of association between $\hat Y$ and $Y$; recall that it is invariant to linear transformations of the predictions and therefore says nothing about calibration.
A useful plot in this scenario is the mean/difference plot, where the mean of the prediction and the actual value, $(\hat Y + Y)/2$, is plotted on the horizontal axis and the difference, $\hat Y - Y$, on the vertical axis. A closely related display plots the residuals against the fitted values, as below.
```python
fit = smf.ols('FLAIR ~ T2', data = dat).fit()
yhat = fit.predict()
e = fit.resid

fig, ax = plt.subplots()
ax.axhline(y = 0, color = "orange");
plt.scatter(yhat, e);
```
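A numeric check of the decomposition and of the other standard metrics for this fit:

```python
mse = np.mean(e ** 2)
decomp = np.var(e) + np.mean(e) ** 2   # residual variance plus squared bias
print(mse, decomp)                      # equal up to floating point error

## squared correlation between fitted values and outcome versus the model R^2
print(np.corrcoef(yhat, dat.FLAIR)[0, 1] ** 2, fit.rsquared)
```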
One could test equality of the marginal distributions of the predictions and the actual values using any two-sample statistic, including the Wasserstein distance, the Kolmogorov-Smirnov statistic and so on, obtaining a reference distribution by swapping prediction/actual labels. Moreover, one could use a rank-based statistic in the same way.
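A sketch of such a test using the Kolmogorov-Smirnov statistic with a label-swapping (permutation) reference distribution; the Wasserstein distance, or a rank-based statistic, could be substituted in exactly the same way. Here the labels are permuted across the pooled sample; swapping within each prediction/actual pair is another option.

```python
from scipy.stats import ks_2samp

obs_stat = ks_2samp(yhat, dat.FLAIR).statistic

pooled = np.concatenate([np.asarray(yhat), np.asarray(dat.FLAIR)])
n_half = len(yhat)
rng = np.random.default_rng(0)

null_stats = []
for _ in range(1000):
    perm = rng.permutation(pooled)
    null_stats.append(ks_2samp(perm[:n_half], perm[n_half:]).statistic)

## permutation p-value
print(np.mean(np.array(null_stats) >= obs_stat))
```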