
Training and Fine-tuning

If you don't already have a model that you can use for inference or evaluation tasks, you can train a new model on Bitfount. Training typically refers to the process of updating a model's weights from scratch on a dataset, i.e. starting with a randomly initialised model, whereas fine-tuning refers to taking a pre-trained model and updating its weights only slightly to suit your specific task or dataset. The process itself is the same in both cases, and you will ultimately end up with a model that you can use for inference or evaluation tasks.

The interface required for training or fine-tuning models is naturally more complex than the interface required for inference or evaluation tasks and is currently only supported for PyTorch Lightning models.

Required interface

In order to train a model on Bitfount, you need to extend the PyTorchBitfountModelv2 class. Details can be found in the API Reference.

The PyTorchBitfountModelv2 class uses the PyTorch Lightning library to provide a high-level implementation of a model in the PyTorch framework. This means you only need to implement the methods that dictate how model training should be performed.

In addition to subclassing the PyTorchBitfountModelv2 class, you will need to implement the following methods (a minimal skeleton is sketched after this list):

  • __init__(): how to setup the model
  • configure_optimizers(): how optimizers should be configured in the model
  • create_model(): how to create the model
  • forward(): how to perform a forward pass in the model, producing the outputs from which the loss is calculated
  • _training_step(): what one training step in the model looks like
  • _validation_step(): what one validation step in the model looks like
  • _test_step(): what one test step in the model looks like
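
A minimal skeleton of such a subclass might look like the following. This is a sketch only: the type annotations are simplified to Any, the method bodies are placeholders, and the full working example later in this page shows a complete implementation.

from typing import Any

import torch.nn as nn

from bitfount.backends.pytorch import PyTorchBitfountModelv2


class MySkeletonModel(PyTorchBitfountModelv2):
    def __init__(self, **kwargs: Any) -> None:
        # Store any hyperparameters and pass the remaining arguments to the base class.
        super().__init__(**kwargs)

    def create_model(self) -> nn.Module:
        # Build and return the underlying nn.Module to be trained.
        ...

    def forward(self, x: Any) -> Any:
        # Perform a forward pass through the model.
        ...

    def _training_step(self, batch: Any, batch_idx: int) -> Any:
        # Compute and return the training loss for one batch.
        ...

    def _validation_step(self, batch: Any, batch_idx: int) -> Any:
        # Compute and log validation metrics for one batch.
        ...

    def _test_step(self, batch: Any, batch_idx: int) -> Any:
        # Return predictions (and targets) for one test batch.
        ...

    def configure_optimizers(self) -> Any:
        # Return the optimizer(s) used for training.
        ...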

Classification models

Classification models are a very common type of model and are used to classify data into one of a number of classes. For this reason, we have provided some utilities to help you implement a classification model. These are:

PyTorchClassifierMixIn

The PyTorchClassifierMixIn class requires a multilabel argument, signifying whether a given record can belong to multiple classes. In exchange, it sets the n_classes attribute automatically based on the number of classes in the specified target column of the dataset, and provides a do_output_activation method that applies the appropriate activation function (sigmoid or softmax) to the model's output, based on the number of classes and whether the problem is multi-label. You will find many examples using this mixin class (such as the example below), but your model is not required to use it. If you do use it, make sure to specify the mixin class first in the model's inheritance hierarchy:

class MyClassificationModel(PyTorchClassifierMixIn, PyTorchBitfountModelv2):
...
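
For example, assuming do_output_activation takes the raw model outputs (logits) and returns the activated probabilities, it could be used in a test step roughly like this. This is a sketch only; the full example below applies softmax directly instead.

def _test_step(self, batch: Any, batch_idx: int) -> Any:
    x, y = batch
    logits = self(x)
    # Assumed behaviour: applies sigmoid or softmax depending on the task.
    preds = self.do_output_activation(logits)
    return {"predictions": preds, "targets": y}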

get_torchvision_classification_model

The get_torchvision_classification_model function is a helper function that creates a pre-trained classification model from the torchvision library. It takes the following arguments:

  • model_name: the name of the model to create. This can be any model supported by the torchvision library.
  • pretrained: whether to return a pre-trained model (typically trained on ImageNet) or a randomly initialised model.
  • num_classes: the number of classes, which determines the output size of the model.

It can be used directly in your model's create_model method to return a pre-trained classification model for fine-tuning:

from bitfount.backends.pytorch.models.nn import get_torchvision_classification_model


class MyClassificationModel(PyTorchClassifierMixIn, PyTorchBitfountModelv2):
    ...

    def create_model(self) -> nn.Module:
        """Creates the model to use."""
        model = get_torchvision_classification_model(
            model_name="resnet18", pretrained=True, num_classes=self.n_classes
        )
        return model

Full example

This example shows a simple logistic regression model that can be used for binary or multi-class classification tasks.

logistic_regression_model.py
from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy

from bitfount.backends.pytorch import PyTorchBitfountModelv2
from bitfount.backends.pytorch.models.base_models import (
    _TEST_STEP_OUTPUT,
    _TRAIN_STEP_OUTPUT,
    PyTorchClassifierMixIn,
    _OptimizerType,
)
from bitfount.types import _StrAnyDict


class LogisticRegressionModel(PyTorchClassifierMixIn, PyTorchBitfountModelv2):
    """A Logistic/Softmax Regression model built using PyTorch Lightning.

    This implements a single linear layer which acts as a Logistic Regression
    (for binary) or Softmax Regression (for multi-class) classifier.
    """

    def __init__(
        self, learning_rate: float = 0.0001, weight_decay: float = 0.0, **kwargs: Any
    ) -> None:
        """Initializes the LogisticRegressionModel.

        Args:
            learning_rate: The step size for the optimizer. Controls how much to
                change the model in response to the estimated error each time the
                model weights are updated.
            weight_decay: L2 regularization penalty. Adds a term to the loss function
                proportional to the sum of the squared weights, preventing the model
                from becoming too complex (overfitting).
            **kwargs: Additional arguments passed to the base PyTorchBitfountModelv2.
                This includes 'steps' (training iterations per round) or 'epochs'.
        """
        super().__init__(**kwargs)
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def create_model(self) -> nn.Module:
        """Creates the model architecture.

        Logistic Regression is essentially a single Linear layer mapping
        input features to class logits. The activation (Sigmoid/Softmax) is
        handled by the loss function (CrossEntropyLoss) during training.
        """
        if self.n_classes < 2:
            raise ValueError(
                "n_classes must be at least 2 for classification. "
                "For binary classification, use n_classes=2."
            )
        # A single linear layer: Input Features -> Output Classes
        return nn.Linear(self.datastructure.input_size, self.n_classes)

    def forward(self, x: Any) -> Any:
        """Defines the operations we want to use for prediction."""
        x, sup = x
        assert self._model is not None
        # Pass through the linear layer
        x = self._model(x.float())
        return x

    def _training_step(self, batch: Any, batch_idx: int) -> _TRAIN_STEP_OUTPUT:
        """Computes and returns the training loss for a batch of data."""
        if self.skip_training_batch(batch_idx):
            return None  # type: ignore[return-value] # reason: Allow None to skip a batch. # noqa: E501
        x, y = batch
        y_hat = self(x)
        # CrossEntropyLoss in PyTorch combines LogSoftmax and NLLLoss.
        # We squeeze y to ensure it is 1D (N,) as expected by CrossEntropyLoss for
        # class indices.
        loss = F.cross_entropy(y_hat, y.squeeze())
        return loss

    def _validation_step(self, batch: Any, batch_idx: int) -> _StrAnyDict:
        """Operates on a single batch of data from the validation set."""
        x, y = batch
        preds = self(x)
        # Ensure y is squeezed for loss calculation
        loss = F.cross_entropy(preds, y.squeeze())
        # Apply softmax to get probabilities for accuracy calculation
        preds_prob = F.softmax(preds, dim=1)
        acc = accuracy(
            preds_prob, y.squeeze(), task="multiclass", num_classes=self.n_classes
        )
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return {
            "val_loss": loss,
            "val_acc": acc,
        }

    def _test_step(self, batch: Any, batch_idx: int) -> _TEST_STEP_OUTPUT:
        """Operates on a single batch of data from the test set."""
        x, y = batch
        preds = self(x)
        preds = F.softmax(preds, dim=1)
        return {"predictions": preds, "targets": y}

    def configure_optimizers(self) -> _OptimizerType:
        """Configure the optimizer."""
        # Using AdamW optimizer with L2 regularization via weight_decay.
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )
        return optimizer
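
To give a rough idea of how this model might be instantiated, a sketch is shown below. The DataStructure and schema arguments are illustrative placeholders ("TARGET" and "my_table" are hypothetical names for your dataset's label column and table); consult the API Reference for the exact signatures in your Bitfount version.

from bitfount import BitfountSchema, DataStructure

# Placeholder dataset details: replace "TARGET" and "my_table" with the
# label column and table from your own dataset.
datastructure = DataStructure(target="TARGET", table="my_table")
model = LogisticRegressionModel(
    datastructure=datastructure,
    schema=BitfountSchema(),
    epochs=2,
    learning_rate=0.001,
)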

Tutorials

For more complex models, we have two tutorials that walk you through the process of training a model on Bitfount: