

Training a Custom Segmentation Model

In this tutorial you will learn how to train a custom segmentation model by extending a base model in the Bitfount framework. We will use the Pod you set up in the "Running a Segmentation Data Pod" tutorial, so make sure it is online. If it is offline, you can restart it by running the "Running a Segmentation Data Pod" tutorial again.

Setting everything up

Let's import the relevant pieces from the API Reference:

```python
import logging
from pathlib import Path

import nest_asyncio
import torch
import torch.nn as nn
import torch.nn.functional as F

from bitfount import (
    SEGMENTATION_METRICS,
    BitfountModelReference,
    BitfountSchema,
    DataStructure,
    PyTorchBitfountModel,
    SoftDiceLoss,
    get_pod_schema,
    setup_loggers,
)

nest_asyncio.apply()  # Needed because Jupyter also has an asyncio loop
```
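The nest_asyncio.apply() call matters because Jupyter already runs an asyncio event loop, and the standard library refuses to start a new loop from inside a running one. The stdlib-only sketch below does not use nest_asyncio or Bitfount; it simply reproduces the error that the patch is there to avoid, assuming the library code you call ends up invoking a blocking entry point such as asyncio.run. The names `inner` and `outer` are illustrative only.

```python
import asyncio


async def inner() -> str:
    return "done"


async def outer() -> str:
    # While this coroutine runs, an event loop is active, just as in a Jupyter cell.
    coro = inner()
    try:
        return asyncio.run(coro)
    except RuntimeError as exc:
        coro.close()  # the coroutine never ran; close it to avoid a warning
        return f"RuntimeError: {exc}"


result = asyncio.run(outer())
print(result)  # reports that asyncio.run() cannot be called from a running loop
```

Outside Jupyter (for example, in a plain script) no loop is running at the top level, which is why the patch is only needed in notebooks.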

Let's set up the loggers, which allow you to monitor the progress of your executed commands and surface errors if something goes wrong.

```python
loggers = setup_loggers([logging.getLogger("bitfount")])
```
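Passing only the top-level "bitfount" logger is enough because Python loggers form a dot-separated hierarchy: records from any logger named "bitfount.<something>" propagate up to handlers attached to "bitfount". We don't know exactly what setup_loggers configures internally, so the stdlib-only sketch below is a hypothetical stand-in showing the propagation mechanism; the child logger name `bitfount.hypothetical_module` is made up, and the output is captured in a StringIO buffer so it can be inspected.

```python
import io
import logging

# Hypothetical stand-in for a helper like setup_loggers:
# attach one handler to the top-level "bitfount" logger.
stream = io.StringIO()
parent = logging.getLogger("bitfount")
parent.setLevel(logging.INFO)
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter("%(name)s %(levelname)s: %(message)s"))
parent.addHandler(handler)

# A logger named "bitfount.<module>" is a child of "bitfount", so its records
# propagate upwards and are written by the handler attached above.
child = logging.getLogger("bitfount.hypothetical_module")
child.info("Training started")

print(stream.getvalue(), end="")
# → bitfount.hypothetical_module INFO: Training started
```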

Creating a custom model

As in the "Training a Custom Model" tutorial, we will create a custom model by extending and overriding the built-in BitfountModel class (specifically, the PyTorchBitfountModel class). Details can be found in the documentation for the bitfount.backends.pytorch.models.bitfount_model module.

The PyTorchBitfountModel uses the PyTorch Lightning library to provide high-level implementation options for a model in the PyTorch framework. This means you only have to implement the methods that dictate how training should be performed.

For our custom model we need to implement the following methods:

  • __init__(): how to set up the model
  • create_model(): how to build the underlying network (here, a UNet)
  • configure_optimizers(): how the optimizers should be configured
  • forward(): how to perform a forward pass through the model
  • training_step(): what one training step looks like, including how the training loss is calculated
  • validation_step(): what one validation step looks like
  • test_step(): what one test step looks like
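The key idea behind this hook-based design is inversion of control: the framework owns the training loop and calls your methods at the right moments. The pure-Python sketch below illustrates that pattern with a hypothetical `toy_fit` loop and a hypothetical one-weight `ToyModel` that reuse the hook names above. It is not Bitfount's or PyTorch Lightning's actual code; in particular, there is no autograd here, so the gradient is computed by hand and returned from training_step, whereas in Lightning you return the loss and the framework handles backpropagation.

```python
class ToySGD:
    """Hypothetical optimizer: plain gradient descent on the model's single weight."""

    def __init__(self, model, lr):
        self.model = model
        self.lr = lr

    def step(self, grad):
        self.model.w -= self.lr * grad


class ToyModel:
    """Hypothetical model (y = w * x) exposing the same hook names as the list above."""

    def __init__(self):
        self.w = 0.0

    def configure_optimizers(self):
        return ToySGD(self, lr=0.05)

    def forward(self, x):
        return self.w * x

    def _loss_and_grad(self, batch):
        # Mean squared error and its derivative with respect to w, computed by hand
        errors = [(self.forward(x) - y, x) for x, y in batch]
        loss = sum(e * e for e, _ in errors) / len(errors)
        grad = sum(2 * e * x for e, x in errors) / len(errors)
        return loss, grad

    def training_step(self, batch, batch_idx):
        loss, grad = self._loss_and_grad(batch)
        return {"loss": loss, "grad": grad}

    def validation_step(self, batch, batch_idx):
        loss, _ = self._loss_and_grad(batch)
        return {"loss": loss}


def toy_fit(model, train_batches, val_batches, epochs):
    """The framework side: it owns the loop and only calls the model's hooks."""
    optimizer = model.configure_optimizers()
    val_losses = []
    for _ in range(epochs):
        for idx, batch in enumerate(train_batches):
            optimizer.step(model.training_step(batch, idx)["grad"])
        outputs = [model.validation_step(b, idx) for idx, b in enumerate(val_batches)]
        # Average the per-batch validation outputs for the epoch
        val_losses.append(sum(o["loss"] for o in outputs) / len(outputs))
    return val_losses


toy_model = ToyModel()
train = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]  # data follows y = 2x
val = [[(4.0, 8.0)]]
losses = toy_fit(toy_model, train, val, epochs=5)
print(losses, toy_model.w)
```

Note that ToyModel never writes a loop itself; swapping in a different model only requires implementing the same hooks.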

Below is an implementation of a custom segmentation model based on the UNet architecture, but feel free to try out your own model here:

```python
class MyCustomSegmentationModel(PyTorchBitfountModel):
    # Implementation of a UNet model, used for testing purposes.
    def __init__(self, n_channels=3, n_classes=3, **kwargs):
        super().__init__(**kwargs)
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = True
        self.dice_loss = SoftDiceLoss()
        self.ce_loss = torch.nn.CrossEntropyLoss()
        self.metrics = SEGMENTATION_METRICS

    def create_model(self):
        class UNet(nn.Module):
            def __init__(self, n_channels, n_classes, **kwargs):
                super().__init__(**kwargs)
                self.n_channels = n_channels
                self.n_classes = n_classes

                def double_conv(in_channels, out_channels):
                    # Two (3x3 conv -> batch norm -> ReLU) blocks; spatial size is kept
                    return nn.Sequential(
                        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                        nn.BatchNorm2d(out_channels),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                        nn.BatchNorm2d(out_channels),
                        nn.ReLU(inplace=True),
                    )

                def down(in_channels, out_channels):
                    # Halve the spatial size, then apply a double convolution
                    return nn.Sequential(
                        nn.MaxPool2d(2), double_conv(in_channels, out_channels)
                    )

                class up(nn.Module):
                    def __init__(self, in_channels, out_channels, bilinear=True):
                        super().__init__()
                        if bilinear:
                            self.up = nn.Upsample(
                                scale_factor=2, mode="bilinear", align_corners=True
                            )
                        else:
                            self.up = nn.ConvTranspose2d(
                                in_channels // 2,
                                in_channels // 2,
                                kernel_size=2,
                                stride=2,
                            )
                        self.conv = double_conv(in_channels, out_channels)

                    def forward(self, x1, x2):
                        x1 = self.up(x1)
                        # Tensors are [Batch size, Channels, Height, Width]; pad the
                        # upsampled map so it matches the skip connection's size
                        diffY = x2.size()[2] - x1.size()[2]
                        diffX = x2.size()[3] - x1.size()[3]
                        x1 = F.pad(
                            x1,
                            [
                                diffX // 2,
                                diffX - diffX // 2,
                                diffY // 2,
                                diffY - diffY // 2,
                            ],
                        )
                        x = torch.cat([x2, x1], dim=1)
                        return self.conv(x)

                # Encoder
                self.inc = double_conv(self.n_channels, 64)
                self.down1 = down(64, 128)
                self.down2 = down(128, 256)
                self.down3 = down(256, 512)
                self.down4 = down(512, 512)
                # Decoder: input channels = upsampled channels + skip-connection channels
                self.up1 = up(1024, 256)
                self.up2 = up(512, 128)
                self.up3 = up(256, 64)
                self.up4 = up(128, 64)
                # 1x1 convolution mapping features to per-class logits
                self.out = nn.Conv2d(64, self.n_classes, kernel_size=1)

            def forward(self, x):
                x1 = self.inc(x)
                x2 = self.down1(x1)
                x3 = self.down2(x2)
                x4 = self.down3(x3)
                x5 = self.down4(x4)
                x = self.up1(x5, x4)
                x = self.up2(x, x3)
                x = self.up3(x, x2)
                x = self.up4(x, x1)
                return self.out(x)

        return UNet(self.n_channels, self.n_classes)

    def forward(self, x):
        return self._model(x)

    def split_dataloader_output(self, data):
        # During data loading, some extra columns are added alongside the images.
        # For this tutorial we only need the images, so we separate them from the
        # supplementary columns (sample weights and, if present, a category).
        images, sup = data
        weights = sup[:, 0].float()
        if sup.shape[1] > 2:
            category = sup[:, -1].long()
        else:
            category = None
        return images, weights, category

    def training_step(self, batch, batch_nb):
        x, y = batch
        x, *sup = self.split_dataloader_output(x)
        # Get rid of the number of channels dimension and make targets of type `long`
        y = y[:, 0].long()
        y_hat = self.forward(x)
        # Cross entropy loss
        ce_loss = (
            F.cross_entropy(y_hat, y)
            if self.n_classes > 1
            else F.binary_cross_entropy_with_logits(y_hat, y)
        )
        return {"loss": ce_loss}

    def validation_step(self, batch, batch_nb):
        x, y = batch
        x, *sup = self.split_dataloader_output(x)
        # Get rid of the number of channels dimension and make targets of type `long`
        y = y[:, 0].long()
        y_hat = self.forward(x)
        softmax_y_hat = F.softmax(y_hat, dim=1)
        # Cross entropy loss
        ce_loss = (
            F.cross_entropy(y_hat, y)
            if self.n_classes > 1
            else F.binary_cross_entropy_with_logits(y_hat, y)
        )
        # Dice loss
        dice_loss = self.dice_loss(softmax_y_hat, y)
        # Total loss
        total_loss = (ce_loss + dice_loss) / 2
        # We can log out some useful stats so we can see progress
        self.log("ce_loss", ce_loss, prog_bar=True)
        self.log("dice_loss", dice_loss, prog_bar=True)
        self.log("loss", total_loss, prog_bar=True)
        return {
            "ce_loss": ce_loss,
            "dice_loss": dice_loss,
            "loss": total_loss,
        }

    def validation_epoch_end(self, outputs):
        mean_outputs = {}
        for k in outputs[0].keys():
            mean_outputs[k] = torch.stack([x[k] for x in outputs]).mean()
        # Add the means to the validation stats.
        self.val_stats.append(mean_outputs)
        # Also log out these averaged metrics
        for k, v in mean_outputs.items():
            self.log(f"avg_{k}", v)

    def test_step(self, batch, batch_nb):
        x, y = batch
        x, *sup = self.split_dataloader_output(x)
        # Get rid of the number of channels dimension and make targets of type `long`
        y = y[:, 0].long()
        # Get test output and predictions
        y_hat = self.forward(x)
        pred = F.softmax(y_hat, dim=1)
        # Output targets and predictions for later
        return {"predictions": pred, "targets": y}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)
```
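The up blocks pad each upsampled feature map so it can be concatenated with its skip connection. Because each MaxPool2d(2) floors odd sizes, the decoder can come out one pixel short. The pure-Python sketch below (no PyTorch needed; `unet_pad_amounts` is a hypothetical helper) traces spatial sizes through this four-level UNet and returns the [left, right, top, bottom] padding each up block passes to F.pad, assuming default MaxPool2d(2) pooling and scale_factor=2 bilinear upsampling as in the model code.

```python
def unet_pad_amounts(height: int, width: int, depth: int = 4) -> list[tuple[int, int, int, int]]:
    """Return the F.pad amounts (left, right, top, bottom) used by each up block."""
    # Encoder: inc keeps the size; each down block's MaxPool2d(2) floors it in half
    sizes = [(height, width)]
    for _ in range(depth):
        h, w = sizes[-1]
        sizes.append((h // 2, w // 2))

    pads = []
    h, w = sizes[-1]  # bottleneck output (x5)
    # Decoder: walk the skip connections from deepest (x4) to shallowest (x1)
    for skip_h, skip_w in reversed(sizes[:-1]):
        up_h, up_w = 2 * h, 2 * w  # nn.Upsample(scale_factor=2)
        diff_y, diff_x = skip_h - up_h, skip_w - up_w
        # Same split as the model code: last dimension (width) first, then height
        pads.append((diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2))
        h, w = skip_h, skip_w  # after padding, the map matches the skip connection
    return pads


print(unet_pad_amounts(100, 100))
# → [(0, 0, 0, 0), (0, 1, 0, 1), (0, 0, 0, 0), (0, 0, 0, 0)]
```

For a 100 x 100 input, the 25 x 25 level is floored to 12 x 12, so the second up block has to pad one pixel on the right and bottom. Inputs whose height and width are divisible by 16 never need padding.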

Training on a Pod with your own custom segmentation model

If you have defined your segmentation model locally, you can train it on a remote Pod by passing the Pod identifiers to the .fit method via the pod_identifiers argument.

NOTE: Your model will be uploaded to the Bitfount Hub during this process. Models uploaded to the Hub are public by default, so please be sure you are happy for your model structure to be searchable by others before uploading. You can view your uploaded models here: https://hub.bitfount.com/my-models

```python
datastructure = DataStructure(
    table="segmentation-data-demo", image_cols=["img", "masks"], target="masks"
)
pod_identifier = "segmentation-data-demo"
schema = get_pod_schema(pod_identifier)

model = MyCustomSegmentationModel(
    datastructure=datastructure, schema=schema, epochs=1, batch_size=5
)
model.fit(
    pod_identifiers=[pod_identifier],
    model_out=Path("training_a_custom_segmentation_model.pt"),
)
```

Congrats! You've now successfully trained a custom segmentation model.