Inference Models
Bitfount provides specialized base classes for inference-only tasks that simplify model deployment. These classes are designed for running pre-trained models in federated environments without the complexity of training infrastructure.
Overview
Inference models eliminate the need to implement training-related methods when you only need to run predictions: implementing a single create_model() method gives you full federated inference capabilities.
For a complete hands-on example of using inference models, see the Inference-Only Models tutorial.
Available Base Classes
Bitfount provides two inference model base classes:
PytorchLightningInferenceModel
A PyTorch Lightning-based inference model that provides full Lightning ecosystem integration.
```python
from bitfount.backends.pytorch.models.inference_model import PytorchLightningInferenceModel
import torch.nn as nn


class MyInferenceModel(PytorchLightningInferenceModel):
    def create_model(self):
        """The only method you need to implement."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
            nn.Softmax(dim=1)
        )
```
Use when:
- You want full PyTorch Lightning integration
- You want to use Bitfount's built-in dataloaders
- You need Lightning's advanced features (callbacks, logging, etc.)
- Your model is part of a larger Lightning ecosystem
PytorchInferenceModel
A lightweight PyTorch inference model without Lightning dependencies.
```python
from bitfount.backends.pytorch.models.inference_model import PytorchInferenceModel
import torch.nn as nn


class MySimpleModel(PytorchInferenceModel):
    def create_model(self):
        """The only method you need to implement."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
            nn.Softmax(dim=1)
        )
```
Use when:
- You want minimal dependencies and faster startup
- You only need inference capabilities
- You prefer simple, direct PyTorch code
Key Differences
| Feature | PytorchLightningInferenceModel | PytorchInferenceModel |
| --- | --- | --- |
| Dependencies | Requires PyTorch Lightning | Pure PyTorch only |
| Execution | Uses Lightning Trainer | Direct PyTorch execution |
| Memory Usage | Slightly higher (Lightning overhead) | Lower (minimal overhead) |
| Extensibility | Full Lightning ecosystem | Simple, direct control |
Basic Usage
1. Define Your Model
Create a class inheriting from one of the inference base classes:
```python
import torch.nn as nn
from bitfount.backends.pytorch.models.inference_model import PytorchLightningInferenceModel


class ResNetInference(PytorchLightningInferenceModel):
    def __init__(self, n_classes: int = 10, **kwargs):
        super().__init__(**kwargs)
        self.n_classes = n_classes

    def create_model(self) -> nn.Module:
        model = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, self.n_classes),
            nn.Softmax(dim=1)
        )
        model.eval()  # Important: set to evaluation mode
        return model
```
2. Use with BitfountModelReference
Deploy your inference model using the standard Bitfount infrastructure:
```python
from pathlib import Path
from bitfount import BitfountModelReference, ModelInference, ResultsOnly

# Create model reference
model_ref = BitfountModelReference(
    model_ref=Path("ResNetInference.py"),  # Your model file
    datastructure=datastructure,  # Your datastructure
    schema=schema,  # The schema of the data you want to run inference on
    hyperparameters={"n_classes": 10, "batch_size": 32}
)

# Run federated inference
protocol = ResultsOnly(
    algorithm=ModelInference(model=model_ref)
)
results = protocol.run(pod_identifiers=[pod_identifier])
```
Advanced Customization: Overriding Methods
While you only need to implement create_model(), you can override other methods for custom behavior:
Customizing Your Inference Models
Both base classes share common functionality but have different advanced hooks available.
Common Base Functionality (Both Classes)
Required Method: create_model()
Every inference model must implement this abstract method:
- Return your PyTorch model architecture (nn.Module)
- Called automatically during model initialization
- The model will be moved to the appropriate device and set to evaluation mode
Shared Public Methods Available for Override:
initialise_model(data, data_splitter, context) - Model Setup
Default behavior:
- Prepares the model for inference
- Creates data loaders from the provided datasource
- Calls create_model() to instantiate your model
- Sets up the inference pipeline
When to override:
- Custom model initialization logic
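For example, here is a minimal sketch of layering custom setup on top of the default initialisation. The weight file name is hypothetical, and deserialize() is assumed to accept the raw file contents:

```python
import torch.nn as nn
from bitfount.backends.pytorch.models.inference_model import PytorchLightningInferenceModel


class WarmStartInference(PytorchLightningInferenceModel):
    def create_model(self):
        return nn.Sequential(nn.Flatten(), nn.Linear(3, 10), nn.Softmax(dim=1))

    def initialise_model(self, data, data_splitter, context):
        # Run the default setup first: dataloaders, create_model(), inference pipeline.
        super().initialise_model(data, data_splitter, context)
        # Custom initialisation, e.g. restoring externally trained weights.
        # "finetuned_weights.pt" is a hypothetical file; deserialize() is
        # assumed to accept the raw file contents here.
        with open("finetuned_weights.pt", "rb") as f:
            self.deserialize(f.read())
```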
forward(x) - Model Forward Pass
Default behavior:
- Handles single and multi-image column scenarios
- Runs input through your created model
- Returns model predictions
When to override:
- Custom input preprocessing
- Multi-model ensemble logic
- Special output formatting needs
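As a concrete illustration, the sketch below overrides forward() to add input preprocessing before delegating to the default forward pass. The pixel scaling is purely illustrative and assumes a single image tensor input:

```python
import torch
import torch.nn as nn
from bitfount.backends.pytorch.models.inference_model import PytorchLightningInferenceModel


class NormalisedInference(PytorchLightningInferenceModel):
    def create_model(self):
        return nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(3, 10),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        # Custom preprocessing: scale raw pixel values to [0, 1] before the
        # default forward pass (assumes x is a single image tensor).
        if isinstance(x, torch.Tensor):
            x = x.float() / 255.0
        return super().forward(x)
```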
Shared Utility Methods:
split_dataloader_output(data) - Data Parsing
Purpose: Properly extracts input data from dataloader output
When to use: Processing batch data in custom methods instead of manual parsing
serialize(filename) and deserialize(content) - Model Persistence
Purpose: Save and load trained model weights
Usage: Standard model checkpointing and deployment
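A rough usage sketch for checkpointing: MyInferenceModel is the class defined earlier, constructor arguments are omitted for brevity, and deserialize() is assumed to accept the raw file contents:

```python
# Assumes both model instances have already been initialised for inference.
model = MyInferenceModel()
model.serialize("inference_weights.pt")  # save the current model weights

restored = MyInferenceModel()
with open("inference_weights.pt", "rb") as f:
    restored.deserialize(f.read())  # load the saved weights back in
```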
PytorchLightningInferenceModel Customization
Lightning-Specific Override Methods:
test_step(batch, batch_idx) - Per-Batch Processing
Default behavior:
- Processes each batch during inference
- Extracts data and optional keys from batch
- Runs forward pass and collects results
- Handles prediction aggregation automatically
When to override:
- Custom preprocessing per batch
- Ensemble predictions across multiple models
- Custom metrics or logging during inference
- Special batch result formatting
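For example, a minimal sketch that adds periodic progress logging around the default per-batch processing:

```python
import torch.nn as nn
from bitfount.backends.pytorch.models.inference_model import PytorchLightningInferenceModel


class LoggingInference(PytorchLightningInferenceModel):
    def create_model(self):
        return nn.Sequential(nn.Flatten(), nn.Linear(3, 10), nn.Softmax(dim=1))

    def test_step(self, batch, batch_idx):
        # Custom per-batch behaviour, then defer to the default data
        # extraction, forward pass and result collection.
        if batch_idx % 50 == 0:
            print(f"Running inference on batch {batch_idx}")
        return super().test_step(batch, batch_idx)
```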
on_test_epoch_end() - End-of-Inference Processing
Default behavior:
- Aggregates all batch results
- Prepares final prediction outputs
- Handles key-prediction alignment
When to override:
- Custom result aggregation logic
- Post-inference processing steps
- Custom validation or filtering
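As an example, the sketch below records top-class confidences during forward() and summarises them after the default aggregation has run. The _confidences attribute is custom state added for illustration, not part of the base API:

```python
import torch
import torch.nn as nn
from bitfount.backends.pytorch.models.inference_model import PytorchLightningInferenceModel


class ConfidenceTrackingInference(PytorchLightningInferenceModel):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._confidences = []  # custom state, not part of the base API

    def create_model(self):
        return nn.Sequential(nn.Flatten(), nn.Linear(3, 2), nn.Softmax(dim=1))

    def forward(self, x):
        preds = super().forward(x)
        # Record the top-class probability for each sample in the batch.
        if isinstance(preds, torch.Tensor):
            self._confidences.extend(preds.max(dim=1).values.detach().cpu().tolist())
        return preds

    def on_test_epoch_end(self):
        # Default aggregation and key-prediction alignment first, then a
        # custom summary of the whole inference run.
        super().on_test_epoch_end()
        if self._confidences:
            mean_conf = sum(self._confidences) / len(self._confidences)
            print(f"Mean top-class confidence: {mean_conf:.3f}")
```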
predict(data, **kwargs) - Complete Pipeline Control
Default behavior:
- Uses Lightning trainer for inference execution
- Manages the complete inference workflow
- Returns formatted prediction results
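If you only need to adjust the final outputs, one option (sketched below, with the post-processing step left as a placeholder) is to wrap the default predict() call rather than replacing the whole workflow:

```python
import torch.nn as nn
from bitfount.backends.pytorch.models.inference_model import PytorchLightningInferenceModel


class PostProcessedInference(PytorchLightningInferenceModel):
    def create_model(self):
        return nn.Sequential(nn.Flatten(), nn.Linear(3, 10), nn.Softmax(dim=1))

    def predict(self, data, **kwargs):
        # Let the Lightning trainer run the full inference workflow...
        predictions = super().predict(data, **kwargs)
        # ...then apply any custom post-processing to the formatted results here.
        return predictions
```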
Lightning Benefits:
- Automatic device management through trainer
- Built-in logging and metrics capabilities
- Structured approach with hooks and callbacks
- Easy integration with Lightning ecosystem
PytorchInferenceModel Customization
Inference Model Override Methods:
predict(data, **kwargs) - Direct Inference Control
Default behavior:
- Manual batch processing loop with torch.no_grad()
- Direct device management and model evaluation
- Explicit prediction collection and formatting
- No Lightning trainer dependency
When to override:
- Fine-grained control over inference loop
- Custom batch processing logic
- Memory-efficient streaming inference
- Integration with non-Lightning workflows
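The sketch below shows what a hand-rolled loop might look like. The self.test_dl and self._model attribute names are assumptions about the objects the base class prepares during initialisation, not documented API, so adapt them to your setup:

```python
import torch
import torch.nn as nn
from bitfount.backends.pytorch.models.inference_model import PytorchInferenceModel


class MemoryEfficientInference(PytorchInferenceModel):
    def create_model(self):
        return nn.Sequential(nn.Flatten(), nn.Linear(3, 10), nn.Softmax(dim=1))

    def predict(self, data, **kwargs):
        # Hand-rolled inference loop with explicit no-grad evaluation.
        self._model.eval()
        results = []
        with torch.no_grad():
            for batch in self.test_dl:
                # split_dataloader_output() extracts the input data from the
                # batch; the unpacking assumes the first element is the input.
                x, *_ = self.split_dataloader_output(batch)
                results.append(self._model(x).cpu())
        return results
```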
Simple Model Benefits:
- No PyTorch Lightning dependency
- Direct PyTorch control and transparency
- Explicit device and memory management
- Faster startup and execution
Method Override Guidelines
Start Simple:
- Implement only create_model()
- Test basic inference functionality
- Add method overrides only when needed
Lightning Model Progression:
- Override test_step() for batch-level customization
- Override on_test_epoch_end() for result aggregation
- Override predict() for complete pipeline control
Simple Model Progression:
- Override forward() for input/output processing
- Override initialise_model() for setup customization
- Override predict() for complete pipeline control
Permissions
Inference models follow the same permission model as other custom models. To run on Pods, you need either:
- Super Modeller role for arbitrary code execution, or
- General Modeller role with specific permission for your inference model
For details on obtaining permissions, see the Custom Models guide.
Best Practices
- Choose the Right Base: Lightning for research, Simple for production
- Always call model.eval() in your create_model() method
- Start Minimal: Begin with just create_model(), add complexity incrementally
- Use Utilities: Leverage split_dataloader_output() for robust data handling
- Test Locally: Validate all customizations before federated deployment
- Handle Edge Cases: Consider different input formats and error conditions
- Document Changes: Comment custom logic for team collaboration
Next Steps
- Try the Inference-Only Models tutorial for hands-on examples
- Learn about Custom Models for more complex scenarios
- Explore Bitfount Task Elements for other model types