inference_model
PyTorch inference models for Bitfount.
Classes
PytorchInferenceModel
class PytorchInferenceModel(**kwargs: Any):
Simple PyTorch inference model for Bitfount.
This class provides a minimal implementation for inference-only models, without requiring PyTorch Lightning or complex inheritance.
Users only need to implement:
- create_model() - Return the PyTorch nn.Module to use
All other methods have sensible defaults for inference.
Inference model for PyTorch.
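The create_model() contract above can be illustrated with a small, self-contained sketch. SketchInferenceModel and DoublingModel are hypothetical names, not part of Bitfount. A plain Python callable stands in for the torch nn.Module that create_model() would normally return, so the example runs without PyTorch installed. It only mirrors the documented flow: initialise_model() builds the module via create_model(), initialised reports whether that has happened, and forward() delegates to the module.

```python
from typing import Any, Callable, Optional


class SketchInferenceModel:
    """Hypothetical stand-in for PytorchInferenceModel (not the Bitfount class)."""

    def __init__(self) -> None:
        self._model: Optional[Callable[[Any], Any]] = None

    @property
    def initialised(self) -> bool:
        # True once initialise_model() has built the model.
        return self._model is not None

    def create_model(self) -> Callable[[Any], Any]:
        # The one method subclasses must implement. In Bitfount this would
        # return a torch nn.Module; in this sketch any callable will do.
        raise NotImplementedError

    def initialise_model(self) -> None:
        self._model = self.create_model()

    def forward(self, x: Any) -> Any:
        if self._model is None:
            raise RuntimeError("initialise_model() must be called before forward()")
        return self._model(x)


class DoublingModel(SketchInferenceModel):
    def create_model(self) -> Callable[[Any], Any]:
        # A toy "module" that doubles each input value.
        return lambda xs: [2 * x for x in xs]


model = DoublingModel()
model.initialise_model()
print(model.initialised, model.forward([1, 2, 3]))  # True [2, 4, 6]
```

In a real subclass you would inherit from PytorchInferenceModel and return an nn.Module from create_model(); the other methods keep their inference defaults.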
Ancestors
- bitfount.backends.pytorch.models.inference_model._BaseInferenceModel
- InferrableModelProtocol
- ModelProtocol
- BaseModelProtocol
- typing.Protocol
- typing.Generic
Variables
initialised : bool
- Should return True if initialise_model has been called.
Methods
deserialize
def deserialize(self, content: Union[str, os.PathLike, bytes], **kwargs: Any) -> None:
Inherited from: InferrableModelProtocol.deserialize
Deserialises the model.
forward
def forward(self, x: Any) -> Any:
Forward pass through the model.
initialise_model
def initialise_model(self, data: Optional[BaseSource] = None, data_splitter: Optional[DatasetSplitter] = None, context: Optional[TaskContext] = None) -> None:
Initialize model and prepare dataloaders for inference.
Arguments
data
: Optional datasource for inference. If provided, a test dataloader is created using an inference-only splitter.
data_splitter
: Optional splitter to use instead of _InferenceSplitter.
context
: Optional execution context (unused).
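The data and data_splitter arguments above describe a fallback: use the supplied splitter if there is one, otherwise an inference-only splitter (assumed here to place every row in the test split). The sketch below models that rule with plain lists. InferenceOnlySplitter, LastHalfSplitter, the Splitter protocol with its test_indices() method, and make_test_batches() are all hypothetical stand-ins for _InferenceSplitter, DatasetSplitter and the real test dataloader, whose actual interfaces are not shown in this reference.

```python
from typing import List, Optional, Protocol, Sequence


class Splitter(Protocol):
    """Hypothetical splitter interface: which row indices form the test split."""

    def test_indices(self, n_rows: int) -> List[int]: ...


class InferenceOnlySplitter:
    """Hypothetical stand-in for _InferenceSplitter: every row is test data."""

    def test_indices(self, n_rows: int) -> List[int]:
        return list(range(n_rows))


class LastHalfSplitter:
    """Hypothetical custom splitter, as might be passed via data_splitter."""

    def test_indices(self, n_rows: int) -> List[int]:
        return list(range(n_rows // 2, n_rows))


def make_test_batches(
    rows: Sequence[float],
    data_splitter: Optional[Splitter] = None,
    batch_size: int = 2,
) -> List[List[float]]:
    # Use the caller's splitter if given, else fall back to inference-only.
    splitter = data_splitter if data_splitter is not None else InferenceOnlySplitter()
    test_rows = [rows[i] for i in splitter.test_indices(len(rows))]
    # Chunk the test split into fixed-size batches (a stand-in test dataloader).
    return [test_rows[i : i + batch_size] for i in range(0, len(test_rows), batch_size)]


print(make_test_batches([0.1, 0.2, 0.3, 0.4, 0.5]))  # [[0.1, 0.2], [0.3, 0.4], [0.5]]
print(make_test_batches([0.1, 0.2, 0.3, 0.4], data_splitter=LastHalfSplitter()))  # [[0.3, 0.4]]
```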
predict
def predict(self, data: Optional[BaseSource] = None, **_: Any) -> PredictReturnType:
Run inference and return predictions.
Arguments
data
: Optional datasource to run inference on. If provided, the model may be (re-)initialised to use this datasource.
Returns PredictReturnType containing predictions and optional data keys.
Raises
ValueError
: If no test dataloader is available.
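The predict() behaviour described above can be sketched as three steps: optionally (re-)initialise on a new datasource, refuse to run if there is still no test dataloader, then collect one prediction per row alongside its data key. This is a hypothetical illustration, not Bitfount's implementation. A {key: value} dict stands in for a BaseSource, batches of two stand in for the test dataloader, and SketchPredictReturn with preds and keys fields is an assumed stand-in for PredictReturnType.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional


@dataclass
class SketchPredictReturn:
    """Hypothetical stand-in for PredictReturnType; field names are assumed."""

    preds: List[Any]
    keys: Optional[List[str]] = None


class SketchPredictor:
    def __init__(self, model: Callable[[Any], Any], batch_size: int = 2) -> None:
        self._model = model
        self._batch_size = batch_size
        self._test_batches: Optional[List[List[Any]]] = None
        self._test_keys: Optional[List[str]] = None

    def initialise_model(self, data: Optional[Dict[str, Any]] = None) -> None:
        # A {data key: value} dict stands in for a BaseSource.
        if data is not None:
            self._test_keys = list(data)
            values = [data[k] for k in self._test_keys]
            self._test_batches = [
                values[i : i + self._batch_size]
                for i in range(0, len(values), self._batch_size)
            ]

    def predict(self, data: Optional[Dict[str, Any]] = None, **_: Any) -> SketchPredictReturn:
        if data is not None:
            # (Re-)initialise so the new datasource is used.
            self.initialise_model(data)
        if self._test_batches is None:
            raise ValueError("No test dataloader available; provide data first.")
        preds = [self._model(x) for batch in self._test_batches for x in batch]
        return SketchPredictReturn(preds=preds, keys=self._test_keys)


predictor = SketchPredictor(lambda score: score > 0.5)
result = predictor.predict({"a.png": 0.9, "b.png": 0.2, "c.png": 0.7})
print(result.preds)  # [True, False, True]
print(result.keys)  # ['a.png', 'b.png', 'c.png']
```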
PytorchLightningInferenceModel
class PytorchLightningInferenceModel(*, datastructure: DataStructure, schema: BitfountSchema, batch_size: int = 32, **kwargs: Any):
PyTorch Lightning inference model for Bitfount.
Inference model for PyTorch.
Ancestors
- pytorch_lightning.core.module.LightningModule
- lightning_fabric.utilities.device_dtype_mixin._DeviceDtypeModuleMixin
- pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
- pytorch_lightning.core.hooks.ModelHooks
- pytorch_lightning.core.hooks.DataHooks
- pytorch_lightning.core.hooks.CheckpointHooks
- torch.nn.modules.module.Module
- bitfount.backends.pytorch.models.inference_model._BaseInferenceModel
- InferrableModelProtocol
- ModelProtocol
- BaseModelProtocol
- typing.Protocol
- typing.Generic
Variables
initialised : bool
- Should return True if initialise_model has been called.
Methods
deserialize
def deserialize(self, content: Union[str, os.PathLike, bytes], **kwargs: Any) -> None:
Inherited from: InferrableModelProtocol.deserialize
Deserialises the model.
forward
def forward(self, x: Any) -> Any:
Forward pass through the model.
initialise_model
def initialise_model(self, data: Optional[BaseSource] = None, data_splitter: Optional[DatasetSplitter] = None, context: Optional[TaskContext] = None) -> None:
Initialise the model and prepare dataloaders for inference.
Arguments
data
: Optional datasource for inference. If provided, a test dataloader is created using an inference-only splitter.
data_splitter
: Optional splitter to use instead of _InferenceSplitter.
context
: Optional execution context (unused).
on_test_epoch_end
def on_test_epoch_end(self) -> None:
Called at the end of the test epoch.
Aggregates the predictions and targets from the test set.
If you override this method, ensure you set self._test_preds to maintain compatibility with self._predict_local, unless you are overriding both of them.
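A minimal sketch of the aggregation contract, assuming per-batch outputs are collected into a list during testing. The _test_step_outputs and _test_targets attributes and the "predictions"/"targets" keys are hypothetical. PyTorch Lightning 2.x no longer passes step outputs to on_test_epoch_end, so models typically store them themselves. The part that matters for the note above is that the override still assigns self._test_preds.

```python
from typing import Any, Dict, List, Optional


class SketchLightningInference:
    """Hypothetical stand-in showing the test_step / on_test_epoch_end handshake."""

    def __init__(self) -> None:
        self._test_step_outputs: List[Dict[str, List[Any]]] = []
        self._test_preds: Optional[List[Any]] = None
        self._test_targets: Optional[List[Any]] = None

    def on_test_epoch_end(self) -> None:
        # Flatten the per-batch outputs into epoch-level lists.
        preds: List[Any] = []
        targets: List[Any] = []
        for out in self._test_step_outputs:
            preds.extend(out["predictions"])
            targets.extend(out.get("targets", []))
        # An override must still set self._test_preds so the local
        # prediction path (self._predict_local in Bitfount) can read it.
        self._test_preds = preds
        self._test_targets = targets
        # Reset the buffer for the next epoch.
        self._test_step_outputs.clear()


module = SketchLightningInference()
module._test_step_outputs.append({"predictions": [0.1, 0.9], "targets": [0, 1]})
module._test_step_outputs.append({"predictions": [0.4], "targets": [0]})
module.on_test_epoch_end()
print(module._test_preds)  # [0.1, 0.9, 0.4]
```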
predict
def predict(self, data: Optional[BaseSource] = None, **_: Any) -> PredictReturnType:
Run inference and return predictions.
Arguments
data
: Optional datasource to run inference on. If provided, the model may be (re-)initialised to use this datasource.
Returns PredictReturnType containing predictions and optional data keys. Data keys must be present if the datasource is file-based.
Raises
ValueError
: If no test dataloader is available.
test_step
def test_step(self, batch: Any, batch_idx: int) -> bitfount.backends.pytorch.models.base_models._TEST_STEP_OUTPUT:
Process a single batch during testing/inference.
Override this step as required.
Arguments
batch
: The batch data.
batch_idx
: Index of the batch.
Returns Dictionary with predictions and targets.
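An overridden test_step() might look like the sketch below. It assumes each batch is an (inputs, targets) pair and that the returned dictionary uses "predictions" and "targets" keys. Both the batch layout and the key names are hypothetical: the real batch depends on the DataStructure, and the actual return type is _TEST_STEP_OUTPUT from base_models. A plain callable stands in for forward() so the example runs without PyTorch.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple


class SketchTestStep:
    """Hypothetical model fragment showing an overridden test_step()."""

    def __init__(self, forward: Callable[[Any], Any]) -> None:
        self.forward = forward
        self._test_step_outputs: List[Dict[str, List[Any]]] = []

    def test_step(
        self, batch: Tuple[Sequence[Any], Sequence[Any]], batch_idx: int
    ) -> Dict[str, List[Any]]:
        # Assumed batch layout: (inputs, targets). batch_idx is unused here.
        inputs, targets = batch
        output = {
            "predictions": [self.forward(x) for x in inputs],
            "targets": list(targets),
        }
        # Store the output so on_test_epoch_end() can aggregate it later.
        self._test_step_outputs.append(output)
        return output


step = SketchTestStep(lambda v: v * 10)
print(step.test_step(([1, 2], [0, 1]), batch_idx=0))
# {'predictions': [10, 20], 'targets': [0, 1]}
```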
trainer_init
def trainer_init(self) -> pytorch_lightning.trainer.trainer.Trainer:
Initialise the PyTorch Lightning trainer.