
inference_model

PyTorch inference models for Bitfount.

Classes

PytorchInferenceModel

class PytorchInferenceModel(**kwargs: Any):

Simple PyTorch inference model for Bitfount.

This class provides a minimal implementation for inference-only models, without requiring PyTorch Lightning or complex inheritance.

Users only need to implement:

  1. create_model() - Return the PyTorch nn.Module to use

All other methods have sensible defaults for inference.
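The "implement only create_model()" contract is an instance of the template-method pattern: the base class owns initialisation and the forward pass, and the subclass only supplies the model. The stdlib-only sketch below illustrates that idea with a hypothetical SketchInferenceBase standing in for PytorchInferenceModel. It is not Bitfount's implementation. A plain Python callable stands in for the torch nn.Module a real subclass would return.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional


class SketchInferenceBase(ABC):
    """Hypothetical stand-in for PytorchInferenceModel (not Bitfount code)."""

    def __init__(self) -> None:
        self._model: Optional[Callable[[Any], Any]] = None

    @abstractmethod
    def create_model(self) -> Callable[[Any], Any]:
        """The only hook a subclass must provide (an nn.Module in practice)."""

    @property
    def initialised(self) -> bool:
        # Mirrors the documented `initialised` flag.
        return self._model is not None

    def initialise_model(self) -> None:
        # Build the model lazily, and only once.
        if self._model is None:
            self._model = self.create_model()

    def forward(self, x: Any) -> Any:
        # Default forward pass: delegate to the wrapped model.
        if self._model is None:
            raise RuntimeError("call initialise_model() first")
        return self._model(x)


class DoubleModel(SketchInferenceBase):
    # A subclass only has to say which model to use.
    def create_model(self) -> Callable[[Any], Any]:
        return lambda x: 2 * x
```

Calling DoubleModel().forward(...) before initialise_model() raises, and the abstract base itself cannot be instantiated, which is how the sketch enforces the single required hook.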

Inference model for PyTorch.


Variables

  • initialised : bool - Should return True if initialise_model has been called.

Methods


deserialize

def deserialize(self, content: Union[str, os.PathLike, bytes], **kwargs: Any) -> None:

Inherited from:

InferrableModelProtocol.deserialize

Deserialises the model.

forward

def forward(self, x: Any) -> Any:

Forward pass through the model.

initialise_model

def initialise_model(
    self,
    data: Optional[BaseSource] = None,
    data_splitter: Optional[DatasetSplitter] = None,
    context: Optional[TaskContext] = None,
) -> None:

Initialise the model and prepare dataloaders for inference.

Arguments

  • data: Optional datasource for inference. If provided, a test dataloader is created using an inference-only splitter.
  • data_splitter: Optional splitter to use instead of _InferenceSplitter.
  • context: Optional execution context (unused).
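The argument handling above can be sketched as follows: if a datasource is given, split it (by default putting every row in the test split), then batch the test rows into a dataloader. This is a stdlib-only illustration under stated assumptions. InferenceOnlySplitter, its split() method, the batched() helper and SketchModel are all hypothetical stand-ins; the real private _InferenceSplitter and Bitfount dataloaders are not shown here.

```python
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Sequence


class InferenceOnlySplitter:
    """Hypothetical splitter: every row goes to the test split."""

    def split(self, rows: Sequence[Any]) -> Dict[str, List[Any]]:
        return {"train": [], "validation": [], "test": list(rows)}


def batched(rows: Sequence[Any], batch_size: int) -> Iterator[List[Any]]:
    # Dataloader stand-in: fixed-size batches, the last one may be smaller.
    for start in range(0, len(rows), batch_size):
        yield list(rows[start:start + batch_size])


@dataclass
class SketchModel:
    batch_size: int = 2
    test_batches: Optional[List[List[Any]]] = None

    def initialise_model(
        self,
        data: Optional[Sequence[Any]] = None,
        data_splitter: Optional[Any] = None,
        context: Any = None,  # accepted but unused, as documented
    ) -> None:
        if data is None:
            return  # no datasource: no test dataloader is created
        # A caller-supplied splitter replaces the inference-only default.
        splitter = data_splitter if data_splitter is not None else InferenceOnlySplitter()
        test_rows = splitter.split(data)["test"]
        self.test_batches = list(batched(test_rows, self.batch_size))
```

For example, SketchModel().initialise_model([1, 2, 3, 4, 5]) leaves test_batches as three batches of sizes 2, 2 and 1.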

predict

def predict(
    self,
    data: Optional[BaseSource] = None,
    **_: Any,
) -> PredictReturnType:

Run inference and return predictions.

Arguments

  • data: Optional datasource to run inference on. If provided, the model may be (re-)initialised to use this datasource.

Returns PredictReturnType containing predictions and optional data keys.

Raises

  • ValueError: If no test dataloader is available.
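The documented predict() behaviour (optional re-initialisation from a new datasource, a ValueError when no test dataloader exists, predictions plus optional data keys in the result) can be sketched with the standard library alone. SketchPredictResult is a hypothetical stand-in for PredictReturnType, and SketchPredictor is an illustration of the control flow, not Bitfount's code.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Sequence


@dataclass
class SketchPredictResult:
    """Hypothetical stand-in for PredictReturnType."""
    preds: List[Any]
    keys: Optional[List[str]] = None  # optional data keys


class SketchPredictor:
    def __init__(self, model: Callable[[Any], Any], batch_size: int = 2) -> None:
        self.model = model
        self.batch_size = batch_size
        self.test_batches: Optional[List[List[Any]]] = None

    def initialise_model(self, data: Sequence[Any]) -> None:
        rows = list(data)
        self.test_batches = [
            rows[i:i + self.batch_size] for i in range(0, len(rows), self.batch_size)
        ]

    def predict(self, data: Optional[Sequence[Any]] = None, **_: Any) -> SketchPredictResult:
        if data is not None:
            # A new datasource (re-)initialises the test dataloader.
            self.initialise_model(data)
        if self.test_batches is None:
            raise ValueError("No test dataloader available; pass data or call initialise_model() first.")
        # Run the model over every row of every batch, in order.
        preds = [self.model(x) for batch in self.test_batches for x in batch]
        return SketchPredictResult(preds=preds)
```

Once data has been supplied, later calls to predict() without arguments reuse the existing dataloader.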

PytorchLightningInferenceModel

class PytorchLightningInferenceModel(
    *,
    datastructure: DataStructure,
    schema: BitfountSchema,
    batch_size: int = 32,
    **kwargs: Any,
):

PyTorch Lightning inference model for Bitfount.

Inference model for PyTorch.

Variables

  • initialised : bool - Should return True if initialise_model has been called.

Methods


deserialize

def deserialize(self, content: Union[str, os.PathLike, bytes], **kwargs: Any) -> None:

Inherited from:

InferrableModelProtocol.deserialize

Deserialises the model.

forward

def forward(self, x: Any) -> Any:

Forward pass through the model.

initialise_model

def initialise_model(
    self,
    data: Optional[BaseSource] = None,
    data_splitter: Optional[DatasetSplitter] = None,
    context: Optional[TaskContext] = None,
) -> None:

Initialise the model and prepare dataloaders for inference.

Arguments

  • data: Optional datasource for inference. If provided, a test dataloader is created using an inference-only splitter.
  • data_splitter: Optional splitter to use instead of _InferenceSplitter.
  • context: Optional execution context (unused).

on_test_epoch_end

def on_test_epoch_end(self) -> None:

Called at the end of the test epoch.

Aggregates the predictions and targets from the test set.

Caution

If you override this method, make sure it still sets self._test_preds so that it remains compatible with self._predict_local, unless you are overriding both methods.
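A common way to satisfy the caution above is to collect per-batch outputs in an instance list during testing and flatten them in on_test_epoch_end; PyTorch Lightning 2.x no longer passes step outputs to this hook, so storing them yourself is the usual pattern. The stdlib-only sketch below shows such an override. Only _test_preds comes from the text; _step_outputs, _test_targets, record_step() and the "predictions"/"targets" keys are hypothetical.

```python
from typing import Any, Dict, List, Optional


class SketchLightningLikeModel:
    """Hypothetical stand-in showing the test_step -> on_test_epoch_end hand-off."""

    def __init__(self) -> None:
        self._step_outputs: List[Dict[str, List[Any]]] = []
        self._test_preds: Optional[List[Any]] = None
        self._test_targets: Optional[List[Any]] = None

    def record_step(self, output: Dict[str, List[Any]]) -> None:
        # In a real LightningModule this append would happen inside test_step.
        self._step_outputs.append(output)

    def on_test_epoch_end(self) -> None:
        # Flatten per-batch outputs into epoch-level lists; setting
        # _test_preds keeps downstream prediction code working.
        self._test_preds = [p for out in self._step_outputs for p in out["predictions"]]
        self._test_targets = [t for out in self._step_outputs for t in out["targets"]]
        self._step_outputs.clear()  # release memory before the next epoch
```

Clearing the buffer at the end of the epoch keeps a second test run from mixing its outputs with the first.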

predict

def predict(
    self,
    data: Optional[BaseSource] = None,
    **_: Any,
) -> PredictReturnType:

Run inference and return predictions.

Arguments

  • data: Optional datasource to run inference on. If provided, the model may be (re-)initialised to use this datasource.

Returns PredictReturnType containing predictions and optional data keys. Data keys must be present if the datasource is file-based.
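The data-key requirement for file-based datasources can be expressed as a small validation step. The sketch below is an assumption-laden illustration: SketchPredictResult is a hypothetical stand-in for PredictReturnType, check_result() is a hypothetical helper, and the rule that there must be exactly one key (for example a file path) per prediction is an assumption, not something the text states.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class SketchPredictResult:
    """Hypothetical stand-in for PredictReturnType."""
    preds: List[Any]
    keys: Optional[List[str]] = None  # e.g. one file path per prediction (assumed)


def check_result(result: SketchPredictResult, file_based: bool) -> None:
    """Hypothetical check enforcing the documented data-key rule."""
    if not file_based:
        return  # keys are optional for non-file datasources
    if result.keys is None:
        raise ValueError("File-based datasources require data keys.")
    if len(result.keys) != len(result.preds):
        # Assumption: keys and predictions are aligned one-to-one.
        raise ValueError("Expected one data key per prediction.")
```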

Raises

  • ValueError: If no test dataloader is available.

test_step

def test_step(
    self,
    batch: Any,
    batch_idx: int,
) -> bitfount.backends.pytorch.models.base_models._TEST_STEP_OUTPUT:

Process a single batch during testing/inference.

Override this step as required.

Arguments

  • batch: The batch data
  • batch_idx: Index of the batch

Returns a dictionary containing predictions and targets.
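An override of test_step typically unpacks the batch, runs the model, and returns predictions alongside targets. The stdlib-only sketch below assumes the batch is an (inputs, targets) pair and uses hypothetical "predictions"/"targets" keys; the actual structure of _TEST_STEP_OUTPUT is not specified here, and sketch_test_step is a free function standing in for the method.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple


def sketch_test_step(
    model: Callable[[Any], Any],
    batch: Tuple[Sequence[Any], Sequence[Any]],
    batch_idx: int,
) -> Dict[str, List[Any]]:
    """Hypothetical test_step: apply the model to one (inputs, targets) batch."""
    inputs, targets = batch
    # batch_idx is unused here; Lightning supplies it for logging or per-batch logic.
    predictions = [model(x) for x in inputs]
    return {"predictions": predictions, "targets": list(targets)}
```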

trainer_init

def trainer_init(self) -> pytorch_lightning.trainer.trainer.Trainer:

Initialise the PyTorch Lightning trainer.