Inference Models

Bitfount provides specialized base classes for inference-only tasks that simplify model deployment. These classes are designed for running pre-trained models in federated environments without the complexity of training infrastructure.

Overview

Inference models eliminate the need to implement training-related methods when you only need to run predictions: implementing a single create_model() method gives you full federated inference capabilities.

tip

For a complete hands-on example of using inference models, see the Inference-Only Models tutorial.

Available Base Classes

Bitfount provides two inference model base classes:

PytorchLightningInferenceModel

A PyTorch Lightning-based inference model that provides full Lightning ecosystem integration.

from bitfount.backends.pytorch.models.inference_model import PytorchLightningInferenceModel
import torch.nn as nn


class MyInferenceModel(PytorchLightningInferenceModel):
    def create_model(self):
        """The only method you need to implement."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
            nn.Softmax(dim=1),
        )

Use when:

  • You want full PyTorch Lightning integration
  • You want to use Bitfount's built-in dataloaders
  • You need Lightning's advanced features (callbacks, logging, etc.)
  • Your model is part of a larger Lightning ecosystem

PytorchInferenceModel

A lightweight PyTorch inference model without Lightning dependencies.

from bitfount.backends.pytorch.models.inference_model import PytorchInferenceModel
import torch.nn as nn


class MySimpleModel(PytorchInferenceModel):
    def create_model(self):
        """The only method you need to implement."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
            nn.Softmax(dim=1),
        )

Use when:

  • You want minimal dependencies and faster startup
  • You only need inference capabilities
  • You prefer simple, direct PyTorch code

Key Differences

Feature       | PytorchLightningInferenceModel       | PytorchInferenceModel
--------------|--------------------------------------|--------------------------
Dependencies  | Requires PyTorch Lightning           | Pure PyTorch only
Execution     | Uses Lightning Trainer               | Direct PyTorch execution
Memory Usage  | Slightly higher (Lightning overhead) | Lower (minimal overhead)
Extensibility | Full Lightning ecosystem             | Simple, direct control

Basic Usage

1. Define Your Model

Create a class inheriting from one of the inference base classes:

import torch.nn as nn
from bitfount.backends.pytorch.models.inference_model import PytorchLightningInferenceModel


class ResNetInference(PytorchLightningInferenceModel):
    def __init__(self, n_classes: int = 10, **kwargs):
        super().__init__(**kwargs)
        self.n_classes = n_classes

    def create_model(self) -> nn.Module:
        model = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, self.n_classes),
            nn.Softmax(dim=1),
        )
        model.eval()  # Important: set to evaluation mode
        return model

2. Use with BitfountModelReference

Deploy your inference model using the standard Bitfount infrastructure:

from pathlib import Path
from bitfount import BitfountModelReference, ModelInference, ResultsOnly

# Create model reference
model_ref = BitfountModelReference(
    model_ref=Path("ResNetInference.py"),  # Your model file
    datastructure=datastructure,  # Your datastructure
    schema=schema,  # The schema of the data you want to run inference on
    hyperparameters={"n_classes": 10, "batch_size": 32},
)

# Run federated inference
protocol = ResultsOnly(
    algorithm=ModelInference(model=model_ref)
)
results = protocol.run(pod_identifiers=[pod_identifier])

Advanced Customization: Overriding Methods

While you only need to implement create_model(), you can override other methods for custom behavior:

Customizing Your Inference Models

Both base classes share common functionality but have different advanced hooks available.

Common Base Functionality (Both Classes)

Required Method: create_model()

Every inference model must implement this abstract method:

  • Return your PyTorch model architecture (nn.Module)
  • Called automatically during model initialization
  • The model will be moved to the appropriate device and set to evaluation mode

Shared Public Methods Available for Override:

initialise_model(data, data_splitter, context) - Model Setup

Default behavior:

  • Prepares the model for inference
  • Creates data loaders from provided datasource
  • Calls create_model() to instantiate your model
  • Sets up the inference pipeline

When to override:

  • Custom model initialization logic
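For example, here is a minimal sketch of an initialise_model() override that runs the default setup and then performs extra, purely illustrative initialisation (the default argument values shown are assumptions):

def initialise_model(self, data=None, data_splitter=None, context=None):
    """Custom setup on top of the default initialisation (sketch)."""
    # Let the base class create dataloaders, call create_model(), and
    # set up the inference pipeline as described above.
    super().initialise_model(data, data_splitter, context)
    # Hypothetical extra step, e.g. loading additional pretrained weights here.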

forward(x) - Model Forward Pass

Default behavior:

  • Handles single and multi-image column scenarios
  • Runs input through your created model
  • Returns model predictions

When to override:

  • Custom input preprocessing
  • Multi-model ensemble logic
  • Special output formatting needs
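As an illustration, a hedged sketch of a forward() override that adds input preprocessing before delegating to the default forward pass; the normalisation shown is arbitrary and assumes the default pass remains appropriate for your model:

def forward(self, x):
    """Apply illustrative preprocessing before the default forward pass (sketch)."""
    # Hypothetical normalisation; replace with whatever your model expects.
    x = (x - x.mean()) / (x.std() + 1e-8)
    return super().forward(x)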

Shared Utility Methods:

split_dataloader_output(data) - Data Parsing

Purpose: Properly extracts input data from dataloader output.
When to use: Processing batch data in custom methods instead of manual parsing.

serialize(filename) and deserialize(content) - Model Persistence

Purpose: Save and load trained model weights.
Usage: Standard model checkpointing and deployment.
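A brief hedged example of the persistence helpers; whether deserialize() expects the file's raw bytes or a path-like object is an assumption here, so check the API reference for the exact content type:

from pathlib import Path

# Save the current model weights to disk.
model.serialize("resnet_inference_weights.pt")

# Later, restore the weights (assumes deserialize() accepts raw bytes).
model.deserialize(Path("resnet_inference_weights.pt").read_bytes())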

PytorchLightningInferenceModel Customization

Lightning-Specific Override Methods:

test_step(batch, batch_idx) - Per-Batch Processing

Default behavior:

  • Processes each batch during inference
  • Extracts data and optional keys from batch
  • Runs forward pass and collects results
  • Handles prediction aggregation automatically

When to override:

  • Custom preprocessing per batch
  • Ensemble predictions across multiple models
  • Custom metrics or logging during inference
  • Special batch result formatting
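For instance, a minimal sketch of a test_step() override that adds a hypothetical per-batch hook and then defers to the default behaviour for the forward pass and result collection:

def test_step(self, batch, batch_idx):
    """Illustrative per-batch hook (sketch, not the library's implementation)."""
    # Hypothetical example: report progress every 100 batches.
    if batch_idx % 100 == 0:
        print(f"Running inference on batch {batch_idx}")
    # Defer to the default behaviour described above.
    return super().test_step(batch, batch_idx)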

on_test_epoch_end() - End-of-Inference Processing

Default behavior:

  • Aggregates all batch results
  • Prepares final prediction outputs
  • Handles key-prediction alignment

When to override:

  • Custom result aggregation logic
  • Post-inference processing steps
  • Custom validation or filtering
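Similarly, a hedged sketch of an on_test_epoch_end() override that runs the default aggregation first and then applies a hypothetical post-processing step:

def on_test_epoch_end(self):
    """Run the default aggregation, then illustrative post-processing (sketch)."""
    super().on_test_epoch_end()
    # Hypothetical hook: inspect or filter the aggregated results here,
    # e.g. drop predictions below a confidence threshold.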

predict(data, **kwargs) - Complete Pipeline Control

Default behavior:

  • Uses Lightning trainer for inference execution
  • Manages the complete inference workflow
  • Returns formatted prediction results
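If you do override it, a hedged pattern is to wrap the default pipeline and post-process its output; the keyword-argument defaults and the shape of the returned predictions are assumptions here:

def predict(self, data=None, **kwargs):
    """Wrap the default Lightning-driven inference pipeline (sketch)."""
    results = super().predict(data, **kwargs)
    # Hypothetical post-processing of the returned predictions.
    return results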

Lightning Benefits:

  • Automatic device management through trainer
  • Built-in logging and metrics capabilities
  • Structured approach with hooks and callbacks
  • Easy integration with Lightning ecosystem

PytorchInferenceModel Customization

Inference Model Override Methods:

predict(data, **kwargs) - Direct Inference Control

Default behavior:

  • Manual batch processing loop with torch.no_grad()
  • Direct device management and model evaluation
  • Explicit prediction collection and formatting
  • No Lightning trainer dependency

When to override:

  • Fine-grained control over inference loop
  • Custom batch processing logic
  • Memory-efficient streaming inference
  • Integration with non-Lightning workflows
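As a rough sketch only: a memory-conscious predict() override that streams batches under torch.no_grad(). The attribute names for the model and dataloader are assumptions for illustration, not the library's actual internals:

import torch

def predict(self, data=None, **kwargs):
    """Illustrative streaming inference loop (internal attribute names assumed)."""
    self._model.eval()  # assumed: attribute holding the create_model() output
    predictions = []
    with torch.no_grad():
        for batch in self._test_dataloader:  # assumed: dataloader attribute
            x, *_ = self.split_dataloader_output(batch)  # assumed return shape
            predictions.append(self._model(x).cpu())
    return predictions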

Simple Model Benefits:

  • No PyTorch Lightning dependency
  • Direct PyTorch control and transparency
  • Explicit device and memory management
  • Faster startup and execution

Method Override Guidelines

Start Simple:

  1. Implement only create_model()
  2. Test basic inference functionality
  3. Add method overrides only when needed

Lightning Model Progression:

  1. Override test_step() for batch-level customization
  2. Override on_test_epoch_end() for result aggregation
  3. Override predict() for complete pipeline control

Simple Model Progression:

  1. Override forward() for input/output processing
  2. Override initialise_model() for setup customization
  3. Override predict() for complete pipeline control

Permissions

Inference models follow the same permission model as other custom models:

caution

Like all custom models, inference models require appropriate permissions to run on Pods. You need either:

  • Super Modeller role for arbitrary code execution, or
  • General Modeller role with specific permission for your inference model

For details on obtaining permissions, see the Custom Models guide.

Best Practices

  1. Choose the Right Base: Lightning for research, Simple for production
  2. Always call model.eval() in your create_model() method
  3. Start Minimal: Begin with just create_model(), add complexity incrementally
  4. Use Utilities: Leverage split_dataloader_output() for robust data handling
  5. Test Locally: Validate all customizations before federated deployment
  6. Handle Edge Cases: Consider different input formats and error conditions
  7. Document Changes: Comment custom logic for team collaboration

Next Steps