Training and Fine-tuning
If you don't already have a model that you can use for inference or evaluation tasks, you can train a new model on Bitfount. Training typically refers to updating a model's weights from scratch on a dataset, i.e. starting from a randomly initialised model, whereas fine-tuning refers to taking a pre-trained model and updating its weights only slightly to suit your specific task or dataset. The process itself is the same in both cases, and you will ultimately end up with a model that you can use for inference or evaluation tasks.
The interface required for training or fine-tuning models is naturally more complex than the one required for inference or evaluation tasks, and is currently only supported for PyTorch Lightning models.
Required interface
In order to train a model on Bitfount, you need to extend the PyTorchBitfountModelv2 class. Details on this can be found in the API Reference documentation.
The PyTorchBitfountModelv2 class uses the PyTorch Lightning library to provide high-level implementation options for a model in the PyTorch framework, meaning you only need to implement the methods that dictate how model training should be performed.
When subclassing the PyTorchBitfountModelv2 class, you will need to implement the following methods (a minimal skeleton follows the list):
- __init__(): how to set up the model
- configure_optimizers(): how optimizers should be configured in the model
- create_model(): how to create the model
- forward(): how to perform a forward pass in the model and how the loss is calculated
- _training_step(): what one training step in the model looks like
- _validation_step(): what one validation step in the model looks like
- _test_step(): what one test step in the model looks like
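For orientation, the sketch below shows the rough shape of such a subclass. The method bodies are elided, MyModel is a placeholder name, and the full example later on this page shows working implementations:

```python
from typing import Any

import torch.nn as nn

from bitfount.backends.pytorch import PyTorchBitfountModelv2


class MyModel(PyTorchBitfountModelv2):
    """Skeleton only: each body below must be filled in for a real model."""

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Store hyperparameters (learning rate, etc.) here.

    def create_model(self) -> nn.Module:
        """Return the underlying nn.Module architecture."""
        ...

    def forward(self, x: Any) -> Any:
        """Run one forward pass and return the model outputs."""
        ...

    def _training_step(self, batch: Any, batch_idx: int) -> Any:
        """Compute and return the training loss for one batch."""
        ...

    def _validation_step(self, batch: Any, batch_idx: int) -> Any:
        """Compute and return validation metrics for one batch."""
        ...

    def _test_step(self, batch: Any, batch_idx: int) -> Any:
        """Return predictions (and targets) for one test batch."""
        ...

    def configure_optimizers(self) -> Any:
        """Return the optimizer(s) the model should use."""
        ...
```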
Classification models
Classification models are a very common type of model, used to assign data to one of a number of classes. For this reason, we have provided some utilities to help you implement a classification model. These are:
- PyTorchClassifierMixIn: a mixin that provides helper methods and attributes for a classification model
- get_torchvision_classification_model: a function that creates a pre-trained classification model from the torchvision library
PyTorchClassifierMixIn
The PyTorchClassifierMixIn class requires the multilabel argument to be provided, signifying whether a given record can belong to multiple classes. In exchange, it sets the n_classes attribute automatically, based on the number of classes in the specified target column of the dataset, and provides a do_output_activation method that applies the appropriate activation function (sigmoid or softmax) to the model's output, based on the number of classes and whether the problem is multi-label. You will find many examples that use this mixin (such as the one below), but your model is not required to use it. If you do, make sure to specify the mixin first in the model's inheritance hierarchy:
```python
class MyClassificationModel(PyTorchClassifierMixIn, PyTorchBitfountModelv2):
    ...
```
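For example, a test step could use the mixin's activation helper instead of hard-coding sigmoid or softmax. This is a sketch; we assume here that do_output_activation takes the raw model outputs and returns the activated predictions:

```python
from typing import Any

from bitfount.backends.pytorch import PyTorchBitfountModelv2
from bitfount.backends.pytorch.models.base_models import PyTorchClassifierMixIn


class MyClassificationModel(PyTorchClassifierMixIn, PyTorchBitfountModelv2):
    ...

    def _test_step(self, batch: Any, batch_idx: int) -> Any:
        x, y = batch
        logits = self(x)
        # Assumption: do_output_activation applies sigmoid or softmax based
        # on n_classes and the multilabel flag supplied to the mixin.
        preds = self.do_output_activation(logits)
        return {"predictions": preds, "targets": y}
```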
get_torchvision_classification_model
The get_torchvision_classification_model function is a helper function that creates a pre-trained classification model from the torchvision library. It takes the following arguments:
- model_name: the name of the model to create. This can be any model supported by the torchvision library.
- pretrained: whether to return a pre-trained model (typically trained on ImageNet) or a randomly initialised model
- num_classes: the number of classes, which determines the output size of the model
It can be used directly in your model's create_model method to return a pre-trained classification model to be used for fine-tuning.
```python
from bitfount.backends.pytorch.models.nn import get_torchvision_classification_model


class MyClassificationModel(PyTorchClassifierMixIn, PyTorchBitfountModelv2):
    ...

    def create_model(self) -> nn.Module:
        """Creates the model to use."""
        model = get_torchvision_classification_model(
            model_name="resnet18", pretrained=True, num_classes=self.n_classes
        )
        return model
```
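If you are fine-tuning, one common (entirely optional) refinement is to freeze the pre-trained backbone so that only the classification head is updated. A sketch, assuming the returned ResNet keeps its final layer under the usual torchvision attribute name "fc":

```python
class MyClassificationModel(PyTorchClassifierMixIn, PyTorchBitfountModelv2):
    ...

    def create_model(self) -> nn.Module:
        """Creates a pre-trained model and freezes the backbone."""
        model = get_torchvision_classification_model(
            model_name="resnet18", pretrained=True, num_classes=self.n_classes
        )
        # Freeze all parameters except the classification head. Assumes the
        # head keeps the torchvision ResNet attribute name "fc"; adjust the
        # check for other architectures.
        for name, param in model.named_parameters():
            if not name.startswith("fc"):
                param.requires_grad = False
        return model
```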
Full example
This example shows a simple logistic regression model that can be used for binary or multi-class classification tasks.
```python
from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy

from bitfount.backends.pytorch import PyTorchBitfountModelv2
from bitfount.backends.pytorch.models.base_models import (
    _TEST_STEP_OUTPUT,
    _TRAIN_STEP_OUTPUT,
    PyTorchClassifierMixIn,
    _OptimizerType,
)
from bitfount.types import _StrAnyDict


class LogisticRegressionModel(PyTorchClassifierMixIn, PyTorchBitfountModelv2):
    """A Logistic/Softmax Regression model built using PyTorch Lightning.

    This implements a single linear layer which acts as a Logistic Regression
    (for binary) or Softmax Regression (for multi-class) classifier.
    """

    def __init__(
        self, learning_rate: float = 0.0001, weight_decay: float = 0.0, **kwargs: Any
    ) -> None:
        """Initializes the LogisticRegressionModel.

        Args:
            learning_rate: The step size for the optimizer. Controls how much
                to change the model in response to the estimated error each
                time the model weights are updated.
            weight_decay: L2 regularization penalty. Adds a term to the loss
                function proportional to the sum of the squared weights,
                preventing the model from becoming too complex (overfitting).
            **kwargs: Additional arguments passed to the base
                PyTorchBitfountModelv2. This includes 'steps' (training
                iterations per round) or 'epochs'.
        """
        super().__init__(**kwargs)
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def create_model(self) -> nn.Module:
        """Creates the model architecture.

        Logistic Regression is essentially a single Linear layer mapping input
        features to class logits. The activation (Sigmoid/Softmax) is handled
        by the loss function (CrossEntropyLoss) during training.
        """
        if self.n_classes < 2:
            raise ValueError(
                "n_classes must be at least 2 for classification. "
                "For binary classification, use n_classes=2."
            )
        # A single linear layer: Input Features -> Output Classes
        return nn.Linear(self.datastructure.input_size, self.n_classes)

    def forward(self, x: Any) -> Any:
        """Defines the operations we want to use for prediction."""
        x, sup = x
        assert self._model is not None
        # Pass through the linear layer
        x = self._model(x.float())
        return x

    def _training_step(self, batch: Any, batch_idx: int) -> _TRAIN_STEP_OUTPUT:
        """Computes and returns the training loss for a batch of data."""
        if self.skip_training_batch(batch_idx):
            return None  # type: ignore[return-value] # reason: Allow None to skip a batch. # noqa: E501
        x, y = batch
        y_hat = self(x)
        # CrossEntropyLoss in PyTorch combines LogSoftmax and NLLLoss.
        # We squeeze y to ensure it is 1D (N,) as expected by CrossEntropyLoss
        # for class indices.
        loss = F.cross_entropy(y_hat, y.squeeze())
        return loss

    def _validation_step(self, batch: Any, batch_idx: int) -> _StrAnyDict:
        """Operates on a single batch of data from the validation set."""
        x, y = batch
        preds = self(x)
        # Ensure y is squeezed for loss calculation
        loss = F.cross_entropy(preds, y.squeeze())
        # Apply softmax to get probabilities for accuracy calculation
        preds_prob = F.softmax(preds, dim=1)
        acc = accuracy(
            preds_prob, y.squeeze(), task="multiclass", num_classes=self.n_classes
        )
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return {
            "val_loss": loss,
            "val_acc": acc,
        }

    def _test_step(self, batch: Any, batch_idx: int) -> _TEST_STEP_OUTPUT:
        """Operates on a single batch of data from the test set."""
        x, y = batch
        preds = self(x)
        preds = F.softmax(preds, dim=1)
        return {"predictions": preds, "targets": y}

    def configure_optimizers(self) -> _OptimizerType:
        """Configure the optimizer."""
        # Using AdamW optimizer with L2 regularization via weight_decay.
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )
        return optimizer
```
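Stripped of the Bitfount scaffolding, the model's core computation is just a linear layer plus cross-entropy. The standalone snippet below (plain PyTorch, with hypothetical sizes) mirrors what _training_step and _validation_step compute and can help verify the expected tensor shapes:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: 4 input features, 3 classes, batch of 8.
linear = torch.nn.Linear(4, 3)
x = torch.randn(8, 4)          # feature batch, shape (N, features)
y = torch.randint(0, 3, (8,))  # class indices, shape (N,)

logits = linear(x)                 # raw class scores, shape (N, n_classes)
loss = F.cross_entropy(logits, y)  # combines LogSoftmax and NLLLoss
probs = F.softmax(logits, dim=1)   # probabilities; each row sums to 1

print(loss.item(), probs.sum(dim=1))
```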
Tutorials
For more complex models, we have two tutorials that walk you through the process of training a model on Bitfount:
- Training a Custom Model: This tutorial walks you through the process of training a tabular classification model on CSV data
- Training a Custom Segmentation Model: This tutorial walks you through the process of training a segmentation model on an image dataset