Inference-Only Models: Custom Models for Federated Inference
Welcome to building custom models for federated inference! With Bitfount's inference model classes, you can create powerful federated inference models with minimal code and complexity.
Overview
Bitfount provides two base classes specifically designed for inference-only tasks:
- PytorchLightningInferenceModel - a PyTorch Lightning-based inference model
- PytorchInferenceModel - a simple PyTorch inference model without Lightning dependencies
These classes eliminate the need to implement training-related methods, making it easy to deploy pre-trained models for federated inference.
Prerequisites
```
!pip install bitfount
```
Getting Started with Inference Models
With Bitfount's PyTorch inference base classes, you only need to implement one method:
```python
from bitfount.backends.pytorch.models.inference_model import PytorchInferenceModel


class MySimpleInferenceModel(PytorchInferenceModel):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def create_model(self):
        """The only method you need to implement!"""
        ...
```
Image Classification with PyTorch
Let's set up the environment and then build a complete image-classification inference model (here, a simple CNN standing in for a pre-trained ResNet):
```python
import logging

import nest_asyncio
import torch
import torch.nn as nn
from pathlib import Path
from PIL import Image

from bitfount import (
    BitfountModelReference,
    BitfountSchema,
    ImageSource,
    DataStructure,
    ModelInference,
    ResultsOnly,
    get_pod_schema,
    setup_loggers,
)
from bitfount.backends.pytorch.models.inference_model import (
    PytorchLightningInferenceModel,
    PytorchInferenceModel,
)

nest_asyncio.apply()  # Needed because Jupyter also has an asyncio loop
```
```python
loggers = setup_loggers([logging.getLogger("bitfount")])
```
Now let's create our inference model, saving the class to its own file (ResNetInferenceModel.py) so we can reference it later:
```python
import torch.nn as nn

from bitfount.backends.pytorch.models.inference_model import PytorchInferenceModel


class ResNetInferenceModel(PytorchInferenceModel):
    """Simple inference model for image classification."""

    def __init__(self, n_classes: int = 10, **kwargs):
        super().__init__(**kwargs)
        self.n_classes = n_classes

    def create_model(self) -> nn.Module:
        """Create and return a simple CNN model."""
        model = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, self.n_classes),
            nn.Softmax(dim=1),
        )
        model.eval()
        return model
```
Testing Locally with Image Data
Let's test our model locally first:
```python
# Create some image data for testing
datasource = ImageSource(
    path="sample_images/",  # Path to your image directory
)
schema = BitfountSchema(
    name="image-inference-demo",
)
# For image data, specify image columns
force_stypes = {
    "image": ["Pixel Data"],  # Standard image column name
}
schema.generate_full_schema(datasource, force_stypes=force_stypes)

# Create a datastructure for inference (no target needed)
datastructure = DataStructure(
    target=None,  # No target for inference; can be skipped
    image_cols=["Pixel Data"],  # Specify the image column
    selected_cols=["Pixel Data"],  # Add selected columns
    # schema_requirements="full"  # Optional
)

# Initialize our model
model = ResNetInferenceModel(
    datastructure=datastructure,
    schema=schema,
    n_classes=2,
    batch_size=2,  # Adjust batch size as needed
)

# Test local inference
model.initialise_model(datasource)
local_results = model.predict(data=datasource)
print(f"Local inference completed! Got {len(local_results.preds)} predictions")
```
Run Inference on a Pod
Now use your simple model with the existing Bitfount infrastructure:
```python
# Use the image dataset pod
pod_identifier = "image-datasource"
schema = get_pod_schema(pod_identifier)

# Create a model reference
model_ref = BitfountModelReference(
    model_ref=Path("ResNetInferenceModel.py"),  # Your simple model file
    datastructure=datastructure,
    schema=schema,
)

# Run federated inference
protocol = ResultsOnly(algorithm=ModelInference(model=model_ref))
results = protocol.run(pod_identifiers=[pod_identifier])
print("Inference completed!")
print(f"Results: {results}")
```
A similar approach can be used to create a model that inherits from PytorchLightningInferenceModel:
```python
import torch.nn as nn

from bitfount.backends.pytorch.models.inference_model import (
    PytorchLightningInferenceModel,
)


class ResNetLightningInferenceModel(PytorchLightningInferenceModel):
    """Simple inference model for image classification."""

    def __init__(self, n_classes: int = 10, **kwargs):
        super().__init__(**kwargs)
        self.n_classes = n_classes

    def create_model(self) -> nn.Module:
        """Create and return a simple CNN model."""
        model = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, self.n_classes),
            nn.Softmax(dim=1),
        )
        model.eval()
        return model
```
Let's test the Lightning model locally:
```python
# Create some image data for testing
datasource = ImageSource(
    path="sample_images/",  # Path to your image directory
)
schema = BitfountSchema(
    name="image-inference-demo",
)
# For image data, specify image columns
force_stypes = {
    "image": ["Pixel Data"],  # Standard image column name
}
schema.generate_full_schema(datasource, force_stypes=force_stypes)

# Create a datastructure for inference (no target needed)
datastructure = DataStructure(
    target=None,  # No target for inference; can be skipped
    image_cols=["Pixel Data"],  # Specify the image column
    selected_cols=["Pixel Data"],  # Add selected columns
)

# Initialize our model
model = ResNetLightningInferenceModel(
    datastructure=datastructure,
    schema=schema,
    n_classes=2,
    batch_size=2,  # Adjust batch size as needed
)

# Test local inference
model.initialise_model(datasource)
local_lightning_results = model.predict(data=datasource)
print(
    f"Local inference completed! Got {len(local_lightning_results.preds)} predictions"
)
```
Understanding the Two Base Classes
PytorchLightningInferenceModel vs PytorchInferenceModel
Both classes provide the same core functionality but with different underlying architectures:
| Feature | PytorchLightningInferenceModel | PytorchInferenceModel |
| --- | --- | --- |
| Dependencies | Requires PyTorch Lightning | Pure PyTorch only |
| Execution | Uses Lightning Trainer | Direct PyTorch execution |
| GPU/Device Handling | Lightning's automatic device management | Custom device detection |
| Extensibility | Full Lightning ecosystem (callbacks, loggers) | Simple, direct control |
Key Architectural Differences
The main difference lies in how the predict() method is implemented.
PytorchLightningInferenceModel:
```python
# Uses PyTorch Lightning under the hood
def predict(self, data=None, **kwargs):
    # Uses pl.Trainer.test() internally
    self._pl_trainer.test(model=self, dataloaders=self.test_dl)
    return PredictReturnType(preds=self._test_preds, keys=self._test_keys)
```
PytorchInferenceModel:
```python
# Direct PyTorch execution
def predict(self, data=None, **kwargs):
    # Direct batch processing with torch.no_grad()
    all_predictions, all_keys = [], []
    with torch.no_grad():
        for batch in self.test_dl:
            batch_data = self.split_dataloader_output(batch)
            predictions = self.forward(batch_data)
            all_predictions.extend(predictions)
            # Collect keys, if the datasource provides them...
    return PredictReturnType(preds=all_predictions, keys=all_keys)
```
When to Use Which Base Class
Use PytorchLightningInferenceModel when:
- You want full PyTorch Lightning integration
- You need Lightning's advanced features (callbacks, logging, etc.)
- You want to leverage our dataloaders and datasets
- You prefer Lightning's structured approach to model organization
Use PytorchInferenceModel when:
- You want minimal dependencies and faster startup
- You prefer simple, direct PyTorch code
- You're building lightweight inference services
- You need fine-grained control over the inference loop
- You're deploying in resource-constrained environments
Advanced Customization: Overriding Methods
While you only need to implement create_model(), you can override other methods for custom behavior:
Customizing Your Inference Models
Both base classes share common functionality but have different advanced hooks available.
Common Base Functionality (Both Classes)
create_model() - Required Method
Every inference model must implement this abstract method:
- Return your PyTorch model architecture (an nn.Module)
- Called automatically during model initialization
- The returned model is moved to the appropriate device and set to evaluation mode
Shared Public Methods Available for Override:
initialise_model(data, data_splitter, context) - Model Setup
Default behavior:
- Prepares the model for inference
- Creates data loaders from the provided datasource
- Calls create_model() to instantiate your model
- Sets up the inference pipeline
When to override:
- Custom model initialization logic
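For example, here is a minimal sketch of custom initialization logic that loads pre-trained weights after the base class has finished its setup. The weights file, the hypothetical WarmStartInferenceModel class, and the self._model attribute are assumptions for illustration; check the base class source for the actual attribute name and signature.

```python
import torch
import torch.nn as nn

from bitfount.backends.pytorch.models.inference_model import PytorchInferenceModel


class WarmStartInferenceModel(PytorchInferenceModel):
    """Hypothetical model that loads pre-trained weights during setup."""

    def __init__(self, weights_path: str = "weights.pt", **kwargs):
        super().__init__(**kwargs)
        self.weights_path = weights_path

    def create_model(self) -> nn.Module:
        model = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(3, 2))
        model.eval()
        return model

    def initialise_model(self, data=None, data_splitter=None, context=None):
        # Let the base class create the dataloaders and call create_model() first
        super().initialise_model(data, data_splitter, context)
        # Custom step: load pre-trained weights into the freshly created model.
        # NOTE: self._model is an assumed attribute name for where the base
        # class stores the nn.Module; verify it against your bitfount version.
        state_dict = torch.load(self.weights_path, map_location="cpu")
        self._model.load_state_dict(state_dict)
```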
forward(x) - Model Forward Pass
Default behavior:
- Handles single and multi-image column scenarios
- Runs input through your created model
- Returns model predictions
When to override:
- Custom input preprocessing
- Multi-model ensemble logic
- Special output formatting needs
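For example, a subclass of the ResNetInferenceModel defined above could scale raw pixel values before delegating to the default forward pass. This is a sketch that assumes the base forward() accepts a single tensor for a single-image-column datasource:

```python
import torch


class NormalisingInferenceModel(ResNetInferenceModel):
    """Hypothetical subclass that preprocesses inputs before inference."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Custom preprocessing: scale raw uint8 pixel values into [0, 1]
        # before handing off to the default forward pass
        x = x.float() / 255.0
        return super().forward(x)
```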
Shared Utility Methods:
split_dataloader_output(data) - Data Parsing
Purpose: Properly extracts input data from dataloader output.
When to use: Processing batch data in custom methods instead of parsing it manually.
serialize(filename) and deserialize(content) - Model Persistence
Purpose: Save and load trained model weights.
Usage: Standard model checkpointing and deployment.
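A sketch of the round trip, assuming serialize() writes to a file path and deserialize() accepts the serialized content, per the signatures above:

```python
# Save the weights after validating the model locally...
model.serialize("resnet_inference_weights.pt")

# ...and load them back later (e.g. on another machine). deserialize()
# takes the serialized content rather than a path, per its signature.
with open("resnet_inference_weights.pt", "rb") as f:
    model.deserialize(f.read())
```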
PytorchLightningInferenceModel Customization
Lightning-Specific Override Methods:
test_step(batch, batch_idx) - Per-Batch Processing
Default behavior:
- Processes each batch during inference
- Extracts data and optional keys from batch
- Runs forward pass and collects results
- Handles prediction aggregation automatically
When to override:
- Custom preprocessing per batch
- Ensemble predictions across multiple models
- Custom metrics or logging during inference
- Special batch result formatting
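For example, a minimal sketch that delegates to the default test_step() and adds per-batch logging on top. What (if anything) the default step returns is an internal detail, so the type guard below is an assumption:

```python
import torch


class LoggingLightningModel(ResNetLightningInferenceModel):
    """Hypothetical subclass that logs a confidence statistic per batch."""

    def test_step(self, batch, batch_idx):
        # Let the default implementation extract the data, run the forward
        # pass, and collect the results...
        output = super().test_step(batch, batch_idx)
        # ...then add custom per-batch behaviour (assumes the default step
        # returns the batch predictions as a tensor)
        if isinstance(output, torch.Tensor):
            mean_conf = output.max(dim=1).values.mean().item()
            print(f"Batch {batch_idx}: mean max-probability {mean_conf:.3f}")
        return output
```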
on_test_epoch_end() - End-of-Inference Processing
Default behavior:
- Aggregates all batch results
- Prepares final prediction outputs
- Handles key-prediction alignment
When to override:
- Custom result aggregation logic
- Post-inference processing steps
- Custom validation or filtering
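As a sketch, post-inference validation could run after the default aggregation. The self._test_preds attribute comes from the default predict() shown earlier, but it is an internal detail that may change between versions:

```python
class ValidatedLightningModel(ResNetLightningInferenceModel):
    """Hypothetical subclass that sanity-checks the aggregated results."""

    def on_test_epoch_end(self):
        # Let the default implementation aggregate all batch results first
        super().on_test_epoch_end()
        # Custom step: validate the aggregated predictions
        # (attribute name assumed from the predict() snippet earlier)
        if self._test_preds is not None and len(self._test_preds) == 0:
            raise ValueError("Inference produced no predictions")
```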
predict(data, **kwargs) - Complete Pipeline Control
Default behavior:
- Uses Lightning trainer for inference execution
- Manages the complete inference workflow
- Returns formatted prediction results
Lightning Benefits:
- Automatic device management through trainer
- Built-in logging and metrics capabilities
- Structured approach with hooks and callbacks
- Easy integration with Lightning ecosystem
PytorchInferenceModel Customization
Inference Model Override Methods:
predict(data, **kwargs) - Direct Inference Control
Default behavior:
- Manual batch processing loop with torch.no_grad()
- Direct device management and model evaluation
- Explicit prediction collection and formatting
- No Lightning trainer dependency
When to override:
- Fine-grained control over inference loop
- Custom batch processing logic
- Memory-efficient streaming inference
- Integration with non-Lightning workflows
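A minimal sketch of a memory-efficient override, built from the names used in the default implementation shown earlier (self.test_dl, split_dataloader_output(), PredictReturnType). The PredictReturnType import path and the utility's return shape are assumptions to verify against your bitfount version:

```python
import torch

# NOTE: import path is an assumption; locate PredictReturnType in your version
from bitfount.models.base_models import PredictReturnType


class StreamingInferenceModel(ResNetInferenceModel):
    """Hypothetical override that moves each batch off the GPU immediately."""

    def predict(self, data=None, **kwargs):
        if data is not None:
            self.initialise_model(data)
        all_preds = []
        with torch.no_grad():
            for batch in self.test_dl:
                # Assumed: the utility returns the model inputs for the batch
                batch_data = self.split_dataloader_output(batch)
                preds = self.forward(batch_data)
                # Move results to the CPU as we go to free device memory
                all_preds.extend(preds.detach().cpu())
        return PredictReturnType(preds=all_preds, keys=None)
```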
Simple Model Benefits:
- No PyTorch Lightning dependency
- Direct PyTorch control and transparency
- Explicit device and memory management
- Faster startup and execution
Method Override Guidelines
Start Simple:
- Implement only create_model()
- Test basic inference functionality
- Add method overrides only when needed
Lightning Model Progression:
- Override test_step() for batch-level customization
- Override on_test_epoch_end() for result aggregation
- Override predict() for complete pipeline control
Simple Model Progression:
- Override forward() for input/output processing
- Override initialise_model() for setup customization
- Override predict() for complete pipeline control
Best Practices
- Choose the Right Base: Lightning for research, Simple for production
- Always call model.eval() in your create_model() method
- Start Minimal: Begin with just create_model(), add complexity incrementally
- Use Utilities: Leverage split_dataloader_output() for robust data handling
- Test Locally: Validate all customizations before federated deployment
- Handle Edge Cases: Consider different input formats and error conditions
- Document Changes: Comment custom logic for team collaboration
You've now learned how to create simple, powerful inference models for federated learning with Bitfount!
Contact our support team at support@bitfount.com if you have any questions.