A more realistic convnet example
Medical MNIST data
Medical images pose unique challenges for image analysis. A common format for medical images is DICOM, and many medical images are 3D or 4D grayscale volumes. To get a sense of working with medical images, let’s consider a set of 2D color images from the medical MNIST library MedMNIST v2. A tutorial for working with these images can be found here. These are images of moles, collected to help determine whether or not they are cancerous. We’ll follow along with the MedMNIST code, then investigate the output.
Here are the imports
!pip install medmnist
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
import medmnist
from medmnist import INFO, Evaluator
import matplotlib.pyplot as plt
import scipy
The dermamnist code downloads the data and loads it into a pytorch dataloader. It also keeps track of outcomes.
## Using the code from https://github.com/MedMNIST/MedMNIST/blob/main/examples/getting_started.ipynb
data_flag = 'dermamnist'
## This defines our NN parameters
NUM_EPOCHS = 10
BATCH_SIZE = 128
lr = 0.001
info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])
DataClass = getattr(medmnist, info['python_class'])
# load the data
train_dataset = DataClass(split = 'train', transform = data_transform, download = True)
test_dataset = DataClass(split = 'test' , transform = data_transform, download = True)
pil_dataset = DataClass(split = 'train', download = True)
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval = data.DataLoader(dataset=train_dataset, batch_size= 2 * BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
Let’s look at a montage of images. Here is the Mayo Clinic’s page on examining moles for potential cancer. They recommend the ABCDE method, which they define exactly as below:
A is for asymmetrical shape. Look for moles with irregular shapes, such as two very different-looking halves.
B is for irregular border. Look for moles with irregular, notched or scalloped borders — characteristics of melanomas.
C is for changes in color. Look for growths that have many colors or an uneven distribution of color.
D is for diameter. Look for new growth in a mole larger than 1/4 inch (about 6 millimeters).
E is for evolving. Look for changes over time, such as a mole that grows in size or that changes color or shape. Moles may also evolve to develop new signs and symptoms, such as new itchiness or bleeding.
Our data is cross-sectional, so parts of C and D and all of E are challenging to assess. We do not have retrospective patient reports.
train_dataset.montage(length=20)
Here are the labels
info['label']
{'0': 'actinic keratoses and intraepithelial carcinoma',
 '1': 'basal cell carcinoma',
 '2': 'benign keratosis-like lesions',
 '3': 'dermatofibroma',
 '4': 'melanoma',
 '5': 'melanocytic nevi',
 '6': 'vascular lesions'}
Here’s a frequency table of the outcomes in the training data. Notice the prevalence of the sixth category (numbered 5), melanocytic nevi. It is by far the most prevalent, making up 67% of our training data.
import pandas as pd

train_targets = []
for i, t in train_loader:
    train_targets = np.append(train_targets, t.numpy())
pd.DataFrame({'target' : train_targets }).value_counts(normalize = True)
This is important to note, because 67% training data accuracy is obtainable by simply labeling every tumor as melanocytic nevi.
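To make the "accuracy for free" point concrete, here is a minimal sketch of a majority-class baseline. The labels in `toy_labels` are made up for illustration (they are not the DermaMNIST counts), and `majority_baseline_accuracy` is a hypothetical helper name.

```python
from collections import Counter

def majority_baseline_accuracy(labels):
    """Return the most common label and the accuracy of always predicting it."""
    counts = Counter(labels)
    majority_label, majority_count = counts.most_common(1)[0]
    return majority_label, majority_count / len(labels)

# Hypothetical toy labels (NOT the real data): class 5 dominates.
toy_labels = [5] * 67 + [4] * 11 + [2] * 11 + [1] * 5 + [0] * 3 + [6] * 2 + [3] * 1

label, acc = majority_baseline_accuracy(toy_labels)
print(f"Always predicting class {label} gives accuracy {acc:.2f}")
```

Any trained model should be compared against this number rather than against zero.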
Note that the categories are mutually exclusive. That is, if a mole is of type $j$, it can’t be of type $j' \neq j$. This is different from a task where we’re trying to model multiple properties of the mole. For example, imagine if our dermatologist recorded whether a mole satisfied each of the A, B, C or D criteria. Then the picture could have multiple 1s across different outcomes. Instead, we have only one possible outcome of the 7 for each image.
So, we use a softmax outcome. If $\eta_j$ is the output of the final layer of our network corresponding to category $j$, consider defining
$$\hat{P}(Y_i = j) = \frac{\exp(\eta_j)}{\sum_{j'=1}^{J} \exp(\eta_{j'})}$$
where $J$ is the number of categories (7 in our case) and $Y_i$ is an outcome for image $i$. Notice, this sums to 1 over $j$.
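The softmax described above can be checked numerically. The sketch below uses plain Python and a made-up vector of seven final-layer outputs (`logits` is a hypothetical example, not model output). Subtracting the maximum before exponentiating is a standard numerical-stability step and does not change the result.

```python
import math

def softmax(logits):
    """Map a list of real-valued scores to probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical final-layer outputs for one image, one value per category.
logits = [0.5, -1.0, 0.0, -2.0, 1.0, 3.0, -0.5]
probs = softmax(logits)

print([round(p, 3) for p in probs])
print("sum:", sum(probs))
print("predicted category:", probs.index(max(probs)))
```

Because the probabilities are forced to sum to 1, raising one category's score necessarily lowers the probability of the others, which matches the mutually exclusive labels in this task.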
Training the network
Here is the MedMNIST network. Note that Conv2d takes the arguments in_channels, out_channels and kernel_size. In this case the number of input channels is 3, and the first layer outputs a 16-channel image: each output channel is obtained by convolving a 3x3x3 kernel with the three input channels and adding a bias term, and this is repeated 15 more times with different kernels (documentation).
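To see where the `64 * 4 * 4` input size of the first fully connected layer in this network comes from, the sketch below traces the spatial size of an image through the convolution and pooling layers. It assumes 28x28 input images (an assumption not stated in the text above) and uses the standard output-size formula for unit-dilation convolutions. The helper names `conv_out` and `pool_out` are hypothetical.

```python
def conv_out(size, kernel_size, padding=0, stride=1):
    """Spatial output size of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1

def pool_out(size, kernel_size=2, stride=2):
    """Spatial output size of a max-pooling layer without padding."""
    return (size - kernel_size) // stride + 1

size = 28  # ASSUMPTION: 28x28 input images
trace = [size]

size = conv_out(size, 3); trace.append(size)             # layer1 conv
size = conv_out(size, 3); trace.append(size)             # layer2 conv
size = pool_out(size); trace.append(size)                # layer2 max pool
size = conv_out(size, 3); trace.append(size)             # layer3 conv
size = conv_out(size, 3); trace.append(size)             # layer4 conv
size = conv_out(size, 3, padding=1); trace.append(size)  # layer5 conv (padding=1)
size = pool_out(size); trace.append(size)                # layer5 max pool

flattened = 64 * size * size  # 64 channels leave layer5

# Parameters in the first convolution: 16 kernels of shape 3x3x3, plus 16 biases.
first_layer_params = 16 * (3 * 3 * 3) + 16

print("spatial sizes:", trace)
print("flattened features:", flattened)
print("first conv parameters:", first_layer_params)
```

Under the 28x28 assumption the final feature map is 4x4 with 64 channels, which agrees with the `nn.Linear(64 * 4 * 4, 128)` layer. The network definition itself follows.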
class Net(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(Net, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU())
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer3 = nn.Sequential(
            nn.Conv2d(16, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU())
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU())
        self.layer5 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Sequential(
            nn.Linear(64 * 4 * 4, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes))

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

model = Net(in_channels=n_channels, num_classes=n_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
Now we can train the network. Just for the sake of time, we’ve set the number of epochs to be fairly low.
for epoch in range(NUM_EPOCHS):
    train_correct = 0
    train_total = 0
    test_correct = 0
    test_total = 0
    model.train()
    for inputs, targets in train_loader:
        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = model(inputs)
        if task == 'multi-label, binary-class':
            targets = targets.to(torch.float32)
            loss = criterion(outputs, targets)
        else:
            targets = targets.squeeze().long()
            loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

from pathlib import Path
checkpoint_path = Path("book/.cache/convnet_example_dermamnist.pt")
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)
print(f"Saved checkpoint to {checkpoint_path}")
Saved checkpoint to book/.cache/convnet_example_dermamnist.pt
Evaluating the network
The MedMNIST repo includes some evaluation methods. The code below calculates the AUC and the accuracy.
def test(split):
    model.eval()
    y_true = torch.tensor([])
    y_score = torch.tensor([])
    data_loader = train_loader_at_eval if split == 'train' else test_loader
    with torch.no_grad():
        for inputs, targets in data_loader:
            outputs = model(inputs)
            outputs = outputs.softmax(dim=-1)
            if task == 'multi-label, binary-class':
                targets = targets.to(torch.float32)
            else:
                targets = targets.squeeze().long()
                targets = targets.float().resize_(len(targets), 1)
            y_true = torch.cat((y_true, targets), 0)
            y_score = torch.cat((y_score, outputs), 0)
        y_true = y_true.numpy()
        y_score = y_score.detach().numpy()
        evaluator = Evaluator(data_flag, split)
        metrics = evaluator.evaluate(y_score)
        print('%s auc: %.3f acc:%.3f' % (split, *metrics))

print('==> Evaluating ...')
test('train')
test('test')
==> Evaluating ...
train auc: 0.877 acc:0.708
test auc: 0.874 acc:0.700
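To make the two reported metrics more concrete, here is a minimal sketch of accuracy and of a binary AUC, computed as the fraction of (positive, negative) pairs in which the positive example receives the higher score. The labels and scores are made up for illustration. With several classes, one common approach is to compute a one-vs-rest AUC per class and average them; whether the MedMNIST `Evaluator` does exactly this is not confirmed here.

```python
def binary_auc(labels, scores):
    """AUC as the probability that a random positive outranks a random negative."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a win
    return wins / (len(pos) * len(neg))

def accuracy(labels, predictions):
    """Fraction of predictions that match the labels."""
    return sum(y == p for y, p in zip(labels, predictions)) / len(labels)

# Hypothetical labels and predicted probabilities of the positive class.
toy_labels = [1, 0, 1, 0, 0, 1]
toy_scores = [0.9, 0.2, 0.6, 0.7, 0.1, 0.8]
toy_predictions = [1 if s >= 0.5 else 0 for s in toy_scores]

auc = binary_auc(toy_labels, toy_scores)
acc = accuracy(toy_labels, toy_predictions)
print(f"auc: {auc:.3f} acc: {acc:.3f}")
```

AUC depends only on how the scores rank the examples, while accuracy depends on the final predicted label, so a model can rank examples well and still have an accuracy close to the majority-class baseline.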
Let’s look into the predictions. First let’s grab a batch and run it through the model. Recall there are 7 tumor types, so we plot the predicted probabilities for the batch as an image, with one row per image and one column per category (split into narrow strips, since a single strip would be too long).
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
import medmnist
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
from medmnist import INFO

data_flag = 'dermamnist'
info = INFO[data_flag]
n_channels = info['n_channels']
n_classes = len(info['label'])
batch_size = globals().get('BATCH_SIZE', 128)
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])
DataClass = getattr(medmnist, info['python_class'])
test_dataset = DataClass(split='test', transform=data_transform, download=True)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=2 * batch_size, shuffle=False)

# Redefine the network only if it is not already available in this session
if 'Net' not in globals():
    class Net(nn.Module):
        def __init__(self, in_channels, num_classes):
            super(Net, self).__init__()
            self.layer1 = nn.Sequential(
                nn.Conv2d(in_channels, 16, kernel_size=3),
                nn.BatchNorm2d(16),
                nn.ReLU())
            self.layer2 = nn.Sequential(
                nn.Conv2d(16, 16, kernel_size=3),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2))
            self.layer3 = nn.Sequential(
                nn.Conv2d(16, 64, kernel_size=3),
                nn.BatchNorm2d(64),
                nn.ReLU())
            self.layer4 = nn.Sequential(
                nn.Conv2d(64, 64, kernel_size=3),
                nn.BatchNorm2d(64),
                nn.ReLU())
            self.layer5 = nn.Sequential(
                nn.Conv2d(64, 64, kernel_size=3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2))
            self.fc = nn.Sequential(
                nn.Linear(64 * 4 * 4, 128),
                nn.ReLU(),
                nn.Linear(128, 128),
                nn.ReLU(),
                nn.Linear(128, num_classes))

        def forward(self, x):
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.layer5(x)
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            return x

# Reuse the trained model if present; otherwise load it from the checkpoint
checkpoint_path = Path('book/.cache/convnet_example_dermamnist.pt')
if 'model' not in globals():
    model = Net(in_channels=n_channels, num_classes=n_classes)
    if checkpoint_path.exists():
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    else:
        print(f'Checkpoint not found at {checkpoint_path}; using a fresh model.')
model.eval()

inputs, targets = next(iter(test_loader))
with torch.no_grad():
    outputs = model(inputs)
    outputs = outputs.softmax(dim=-1)
I, J = outputs.shape

# Show the I x J probability matrix as 8 side-by-side strips of rows
plt.figure(figsize=(12, 3))
for i in range(8):
    plt.subplot(1,8,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(outputs.detach().numpy()[int(i * I / 8 ) : int((i + 1) * I / 8), :]);

As expected, it’s predicting the sixth category, labeled 5 since we’re counting from 0, very often, since it is the most prevalent category. The test AUC of roughly 0.87 is reasonably good, but the roughly 70% accuracy should be read with care: recall that we get 67% accuracy for free. Here are 16 randomly selected images from our batch, each titled with the probability of its highest-probability category (P), that predicted category (Yhat) and its actual category (Y).
batch = outputs.detach().numpy()
# grab 16 images to plot
indices = np.random.permutation(np.array(range(I)))[0 : 16]
## The associated actual labels, predicted labels and predicted probabilities
actual = targets.squeeze(-1).numpy()[indices]
target_pred = batch.argmax(axis = 1)[indices]
target_prob = np.round(batch.max(axis = 1) * 100)[indices]
images = inputs.numpy()[indices,:,:,:]
plt.figure(figsize=(10,10))
for i in range(16):
    plt.subplot(4,4,i+1)
    plt.xticks([])
    plt.yticks([])
    # channels-first to channels-last, then undo the normalization
    img = np.transpose(images[i,:,:,:], (1, 2, 0));
    img = ((img + 1)* 255/2).astype(np.uint8)
    plt.title("P="+str(target_prob[i])+",Yhat="+str(target_pred[i])+",Y="+ str(actual[i]))
    plt.imshow(img);
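The expression `(img + 1) * 255/2` in the plotting code above undoes the preprocessing applied when the data was loaded. `ToTensor` scales 8-bit pixel values from [0, 255] to [0, 1], and `Normalize(mean=[.5], std=[.5])` then maps them to [-1, 1]. The sketch below checks this round trip in plain Python on single pixel values. The helper names are hypothetical, and it rounds rather than truncates when converting back, since truncation (as with `astype(np.uint8)`) could in principle be off by one when floating-point error leaves a value just below an integer.

```python
def normalize(pixel):
    """Mimic ToTensor (divide by 255) followed by Normalize(mean=0.5, std=0.5)."""
    return (pixel / 255 - 0.5) / 0.5

def denormalize(value):
    """Invert the normalization and round to the nearest 8-bit pixel value."""
    return round((value + 1) * 255 / 2)

# Round-trip every possible 8-bit pixel value.
recovered = [denormalize(normalize(p)) for p in range(256)]

print(normalize(0), normalize(255))       # ends of the normalized range
print(recovered == list(range(256)))      # True if the round trip is exact
```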
Let’s look at a cross tabulation of the model prediction versus the actual value.
targets_pred, targets_actual = [], []
for i,t in test_loader:
    targets_pred.append(model(i).softmax(dim=-1))
    targets_actual.append(t)
targets_pred = torch.cat(targets_pred, dim = 0).detach().numpy().argmax(axis = 1)
targets_actual = torch.cat(targets_actual, dim=0).numpy().squeeze()
## This is the confusion matrix data
confusion = pd.DataFrame({'y' : targets_actual, 'yhat' : targets_pred}).value_counts()
confusion.head(10)
Visualizing the network
Here are the first-layer convolution kernels. Note that they can be displayed as color images, since each kernel has 3 input channels.
### here's a list of the states that the model has kept track of
state_dictionary = list(model.state_dict())
state_dictionary[1 : 5]
### let's get the first round of weights
layer1_kernel = model.get_parameter('layer1.0.weight').detach().numpy()
### rescale the weights to the range 0 to 255 so they can be shown as images
layer1_kernel = layer1_kernel - layer1_kernel.min()
layer1_kernel = layer1_kernel / layer1_kernel.max()
layer1_kernel = layer1_kernel * 255
plt.figure(figsize=(5,5))
for h in range(16):
    plt.subplot(4,4,h+1)
    plt.xticks([])
    plt.yticks([])
    img = np.transpose(layer1_kernel[h,:,:,:],(1,2,0)).astype(np.uint8);
    plt.imshow(img);
Look at the convolutional layers by following the first image (upper left) through the network.
l1 = model.layer1(inputs)
## plot the 16 output channels of layer 1 for the first image in the batch
images = l1.detach().numpy()
i = 0
plt.figure(figsize=(10,10))
for h in range(16):
    plt.subplot(4,4,h+1)
    plt.xticks([])
    plt.yticks([])
    img = images[i,h,:,:];
    plt.imshow(img);
Let’s use SHAP to look at our convolutional neural network. SHAP uses Shapley values from game theory to explain NN predictions. Here’s a manuscript young2019deep that utilizes SHAP values on this problem.
The Shapley value is defined in game theory as the contribution of a player to a team’s output, factoring in their cooperation. In neural networks, a pixel could be considered a player and the network output the gain function. The goal is to see the contribution of each player/pixel, factoring in how it interacts with the others. Exact Shapley values are difficult to compute, since they involve every possible subset of players, and so approximations need to be used. Python has a package shap that can be used to calculate and visualize approximated Shapley values for an input.
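To illustrate the definition, here is a minimal sketch that computes exact Shapley values for a tiny cooperative game by enumerating every coalition. The three-player game in `toy_value` is made up for illustration, and `shapley_values` is a hypothetical helper, not part of the shap package. For player $i$ among $n$ players, the value is the weighted average of the marginal contribution $v(S \cup \{i\}) - v(S)$ over all coalitions $S$ not containing $i$, with weight $|S|!\,(n-|S|-1)!/n!$.

```python
from itertools import combinations
from math import factorial

def shapley_values(players, value):
    """Exact Shapley values by enumerating all coalitions (feasible only for small n)."""
    n = len(players)
    phi = {}
    for i in players:
        others = [p for p in players if p != i]
        total = 0.0
        for size in range(n):  # coalition sizes 0 .. n-1
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = frozenset(subset)
                # marginal contribution of player i when joining coalition s
                total += weight * (value(s | {i}) - value(s))
        phi[i] = total
    return phi

def toy_value(coalition):
    """Hypothetical game: individual worths plus a bonus when 'a' and 'b' cooperate."""
    base = {'a': 1.0, 'b': 2.0, 'c': 0.0}
    total = sum(base[p] for p in coalition)
    if 'a' in coalition and 'b' in coalition:
        total += 3.0  # interaction bonus, split between 'a' and 'b' by the Shapley value
    return total

phi = shapley_values(['a', 'b', 'c'], toy_value)
print(phi)
print("sum of values:", sum(phi.values()), "grand coalition:", toy_value({'a', 'b', 'c'}))
```

The number of coalitions doubles with every added player, which is why exact computation is out of reach when every pixel of an image is a player and approximations are needed.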
The following code is taken from here
# Assumes the shap package is installed (for example via: pip install shap)
import shap

# Re-obtain batch and SHAP values
batch_new = next(iter(test_loader))
images_new, _ = batch_new
background_new = images_new[:100]
test_images_new = images_new[100:105]
e_new = shap.DeepExplainer(model, background_new)
shap_values_new = e_new.shap_values(test_images_new)

def to_image_plot_format_new(values):
    # Convert SHAP output to a list (one entry per class) of channels-last arrays
    if isinstance(values, list):
        return [np.moveaxis(v, 1, -1) for v in values]
    values = np.asarray(values)
    n_samples = test_images_new.shape[0]
    if values.shape[0] == n_samples:
        return [np.moveaxis(values[..., i], 1, -1) for i in range(values.shape[-1])]
    if values.shape[1] == n_samples:
        return [np.moveaxis(values[i], 1, -1) for i in range(values.shape[0])]
    raise ValueError(f"Could not interpret SHAP output shape: {values.shape}")

def min_max_scale_new(arr):
    # Rescale an array to the range [0, 1]
    min_val = arr.min()
    max_val = arr.max()
    if max_val == min_val:
        return np.zeros_like(arr)
    return (arr - min_val) / (max_val - min_val)

shap_numpy_raw = to_image_plot_format_new(shap_values_new)
test_numpy_original = np.moveaxis(test_images_new.cpu().numpy(), 1, -1)
# Scale the original images for background, but do NOT scale shap_numpy_raw
test_numpy_scaled_for_bg = min_max_scale_new(test_numpy_original)
# Plot the SHAP values using the raw shap_numpy_raw and the scaled original images as background
# The raw shap_numpy_raw allows shap.image_plot to correctly apply diverging red/blue colormap
shap.image_plot(shap_numpy_raw, test_numpy_scaled_for_bg)
