Inference-Only Models: Custom Models for Federated Inference

Welcome to building custom models for federated inference! With Bitfount's inference model classes, you can create powerful federated inference models with minimal code and complexity.

Overview

Bitfount provides two base classes specifically designed for inference-only tasks:

  • PytorchLightningInferenceModel - PyTorch Lightning-based inference model
  • PytorchInferenceModel - Simple PyTorch inference model without Lightning dependencies

These classes eliminate the need to implement training-related methods, making it easy to deploy pre-trained models for federated inference.

Prerequisites

!pip install bitfount

Getting Started with Inference Models

With Bitfount's PyTorch inference base classes, you only need to implement one method:

from bitfount.backends.pytorch.models.inference_model import PytorchInferenceModel


class MySimpleInferenceModel(PytorchInferenceModel):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def create_model(self):
        """The only method you need to implement!"""
        ...

Image Classification with PyTorch

Let's set up the environment and then create a complete image classification inference model:

import logging

import nest_asyncio
import torch
import torch.nn as nn
from pathlib import Path
from PIL import Image

from bitfount import (
    BitfountModelReference,
    BitfountSchema,
    ImageSource,
    DataStructure,
    ModelInference,
    ResultsOnly,
    get_pod_schema,
    setup_loggers,
)
from bitfount.backends.pytorch.models.inference_model import (
    PytorchLightningInferenceModel,
    PytorchInferenceModel,
)

nest_asyncio.apply()  # Needed because Jupyter also has an asyncio loop
loggers = setup_loggers([logging.getLogger("bitfount")])

Now let's define our inference model and save it to a file (ResNetInferenceModel.py) so it can be referenced later when running on a Pod:

import torch.nn as nn

from bitfount.backends.pytorch.models.inference_model import PytorchInferenceModel


class ResNetInferenceModel(PytorchInferenceModel):
    """Simple inference model for image classification."""

    def __init__(self, n_classes: int = 10, **kwargs):
        super().__init__(**kwargs)
        self.n_classes = n_classes

    def create_model(self) -> nn.Module:
        """Create and return a simple CNN model."""
        model = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, self.n_classes),
            nn.Softmax(dim=1),
        )
        model.eval()
        return model

Testing Locally with Image Data

Let's test our model locally first:

# Create some image data for testing
datasource = ImageSource(
    path="sample_images/",  # Path to your image directory
)
schema = BitfountSchema(
    name="image-inference-demo",
)
# For image data, specify image columns
force_stypes = {
    "image": ["Pixel Data"],  # Standard image column name
}
schema.generate_full_schema(datasource, force_stypes=force_stypes)

# Create datastructure for inference (no target needed)
datastructure = DataStructure(
    target=None,  # No target for inference, can be skipped
    image_cols=["Pixel Data"],  # Specify the image column
    selected_cols=["Pixel Data"],  # Add selected columns
    # schema_requirements="full"  # Optional
)

# Initialize our model
model = ResNetInferenceModel(
    datastructure=datastructure,
    schema=schema,
    n_classes=2,
    batch_size=2,  # Adjust batch size as needed
)

# Test local inference
model.initialise_model(datasource)
local_results = model.predict(data=datasource)
print(f"Local inference completed! Got {len(local_results.preds)} predictions")

Run Inference on a Pod

Now use your simple model with the existing Bitfount infrastructure:

# Use the image dataset pod
pod_identifier = "image-datasource"
schema = get_pod_schema(pod_identifier)

# Create model reference
model_ref = BitfountModelReference(
    model_ref=Path("ResNetInferenceModel.py"),  # Your simple model file
    datastructure=datastructure,
    schema=schema,
)

# Run federated inference
protocol = ResultsOnly(algorithm=ModelInference(model=model_ref))
results = protocol.run(pod_identifiers=[pod_identifier])
print("Inference completed!")
print(f"Results: {results}")

A similar approach can be used to create a model that inherits from PytorchLightningInferenceModel.

import torch.nn as nn

from bitfount.backends.pytorch.models.inference_model import (
    PytorchLightningInferenceModel,
)


class ResNetLightningInferenceModel(PytorchLightningInferenceModel):
    """Simple inference model for image classification."""

    def __init__(self, n_classes: int = 10, **kwargs):
        super().__init__(**kwargs)
        self.n_classes = n_classes

    def create_model(self) -> nn.Module:
        """Create and return a simple CNN model."""
        model = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, self.n_classes),
            nn.Softmax(dim=1),
        )
        model.eval()
        return model

Let's test the Lightning model locally:

# Create some image data for testing
datasource = ImageSource(
    path="sample_images/",  # Path to your image directory
)
schema = BitfountSchema(
    name="image-inference-demo",
)
# For image data, specify image columns
force_stypes = {
    "image": ["Pixel Data"],  # Standard image column name
}
schema.generate_full_schema(datasource, force_stypes=force_stypes)

# Create datastructure for inference (no target needed)
datastructure = DataStructure(
    target=None,  # No target for inference, can be skipped
    image_cols=["Pixel Data"],  # Specify the image column
    selected_cols=["Pixel Data"],  # Add selected columns
)

# Initialize our model
model = ResNetLightningInferenceModel(
    datastructure=datastructure,
    schema=schema,
    n_classes=2,
    batch_size=2,  # Adjust batch size as needed
)

# Test local inference
model.initialise_model(datasource)
local_lightning_results = model.predict(data=datasource)
print(
    f"Local inference completed! Got {len(local_lightning_results.preds)} predictions"
)

Understanding the Two Base Classes

PytorchLightningInferenceModel vs PytorchInferenceModel

Both classes provide the same core functionality but with different underlying architectures:

Feature             | PytorchLightningInferenceModel                | PytorchInferenceModel
--------------------|-----------------------------------------------|---------------------------
Dependencies        | Requires PyTorch Lightning                    | Pure PyTorch only
Execution           | Uses Lightning Trainer                        | Direct PyTorch execution
GPU/Device Handling | Lightning's automatic device management       | Custom device detection
Extensibility       | Full Lightning ecosystem (callbacks, loggers) | Simple, direct control

Key Architectural Differences

The main difference lies in how the predict() method is implemented.

PytorchLightningInferenceModel:

# Uses PyTorch Lightning under the hood
def predict(self, data=None, **kwargs):
    # Uses pl.Trainer.test() internally
    self._pl_trainer.test(model=self, dataloaders=self.test_dl)
    return PredictReturnType(preds=self._test_preds, keys=self._test_keys)

PytorchInferenceModel:

# Direct PyTorch execution
def predict(self, data=None, **kwargs):
    # Direct batch processing with torch.no_grad()
    with torch.no_grad():
        for batch in self.test_dl:
            predictions = self.forward(batch_data)
            # Process predictions...
    return PredictReturnType(preds=all_predictions, keys=all_keys)

When to Use Which Base Class

Use PytorchLightningInferenceModel when:

  • You want full PyTorch Lightning integration
  • You need Lightning's advanced features (callbacks, logging, etc.)
  • You want to leverage our dataloaders and datasets
  • You prefer Lightning's structured approach to model organization

Use PytorchInferenceModel when:

  • You want minimal dependencies and faster startup
  • You prefer simple, direct PyTorch code
  • You're building lightweight inference services
  • You need fine-grained control over the inference loop
  • You're deploying in resource-constrained environments

Advanced Customization: Overriding Methods

While you only need to implement create_model(), you can override other methods for custom behavior:

Customizing Your Inference Models

Both base classes share common functionality but have different advanced hooks available.

Common Base Functionality (Both Classes)

create_model() - Required Method

Every inference model must implement this abstract method:

  • Return your PyTorch model architecture (nn.Module)
  • Called automatically during model initialization
  • The model will be moved to appropriate device and set to evaluation mode

Shared Public Methods Available for Override:

initialise_model(data, data_splitter, context) - Model Setup

Default behavior:

  • Prepares the model for inference
  • Creates data loaders from provided datasource
  • Calls create_model() to instantiate your model
  • Sets up the inference pipeline

When to override:

  • Custom model initialization logic
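
As a minimal sketch (the extra setup step is left as a placeholder, and the signature follows the description above), an override usually delegates to the base class first so that the dataloaders and the model are created as usual:

import torch.nn as nn

from bitfount.backends.pytorch.models.inference_model import PytorchInferenceModel


class CustomInitInferenceModel(PytorchInferenceModel):
    """Sketch: add extra setup after the default initialization."""

    def create_model(self) -> nn.Module:
        return nn.Sequential(nn.Flatten(), nn.LazyLinear(2), nn.Softmax(dim=1))

    def initialise_model(self, data=None, data_splitter=None, context=None):
        # Keep the default behavior: build dataloaders and call create_model()
        super().initialise_model(data=data, data_splitter=data_splitter, context=context)
        # Custom initialization logic goes here, e.g. loading pre-trained
        # weights or warming up caches (placeholder)
        ...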

forward(x) - Model Forward Pass

Default behavior:

  • Handles single and multi-image column scenarios
  • Runs input through your created model
  • Returns model predictions

When to override:

  • Custom input preprocessing
  • Multi-model ensemble logic
  • Special output formatting needs
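
As a hedged sketch (assuming the input arrives as a single image tensor batch; multi-column inputs may differ), a forward() override that rescales pixel values before delegating to the default pass could look like:

import torch.nn as nn

from bitfount.backends.pytorch.models.inference_model import PytorchInferenceModel


class RescalingInferenceModel(PytorchInferenceModel):
    """Sketch: normalize raw pixel inputs before the standard forward pass."""

    def create_model(self) -> nn.Module:
        return nn.Sequential(nn.Flatten(), nn.LazyLinear(2), nn.Softmax(dim=1))

    def forward(self, x):
        # Illustrative preprocessing: rescale uint8 pixel values into [0, 1]
        # (assumes x is already a tensor batch of images)
        x = x.float() / 255.0
        # Delegate to the base implementation, which runs the created model
        return super().forward(x)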

Shared Utility Methods:

split_dataloader_output(data) - Data Parsing

Purpose: Properly extracts input data from dataloader output.
When to use: Processing batch data in custom methods instead of manual parsing.
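
For illustration only (the exact return structure of split_dataloader_output() is not documented here, so this sketch assumes the model inputs come first), a custom batch loop might look like:

def predict(self, data=None, **kwargs):
    all_predictions = []
    with torch.no_grad():
        for batch in self.test_dl:
            # Use the shared utility instead of unpacking the batch by hand;
            # assumption: the model inputs are returned as the first element
            batch_data, *extras = self.split_dataloader_output(batch)
            all_predictions.append(self.forward(batch_data))
    ...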

serialize(filename) and deserialize(content) - Model Persistence

Purpose: Save and load trained model weights.
Usage: Standard model checkpointing and deployment.
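
A minimal usage sketch, reusing the ResNetInferenceModel and datasource from the local test above. It assumes serialize() takes a destination path and deserialize() takes the serialized content read back from that file; the file name is illustrative:

from pathlib import Path

weights_file = Path("resnet_inference_weights.pt")  # illustrative file name

# Save the initialized model's weights to disk
model.serialize(weights_file)

# Later, restore them into a freshly constructed model
restored_model = ResNetInferenceModel(
    datastructure=datastructure,
    schema=schema,
    n_classes=2,
    batch_size=2,
)
restored_model.initialise_model(datasource)
restored_model.deserialize(weights_file.read_bytes())  # assumption: content = raw bytes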

PytorchLightningInferenceModel Customization

Lightning-Specific Override Methods:

test_step(batch, batch_idx) - Per-Batch Processing

Default behavior:

  • Processes each batch during inference
  • Extracts data and optional keys from batch
  • Runs forward pass and collects results
  • Handles prediction aggregation automatically

When to override:

  • Custom preprocessing per batch
  • Ensemble predictions across multiple models
  • Custom metrics or logging during inference
  • Special batch result formatting
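
As a minimal sketch, a custom test_step() can delegate to the base implementation and layer extra behavior on top, for example simple progress logging (the logging interval is arbitrary):

import torch.nn as nn

from bitfount.backends.pytorch.models.inference_model import (
    PytorchLightningInferenceModel,
)


class LoggingLightningInferenceModel(PytorchLightningInferenceModel):
    """Sketch: keep default batch handling, add simple progress logging."""

    def create_model(self) -> nn.Module:
        return nn.Sequential(nn.Flatten(), nn.LazyLinear(2), nn.Softmax(dim=1))

    def test_step(self, batch, batch_idx):
        # The base implementation extracts data/keys, runs forward() and
        # collects the results for aggregation
        output = super().test_step(batch, batch_idx)
        if batch_idx % 10 == 0:
            print(f"Processed inference batch {batch_idx}")
        return output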

on_test_epoch_end() - End-of-Inference Processing

Default behavior:

  • Aggregates all batch results
  • Prepares final prediction outputs
  • Handles key-prediction alignment

When to override:

  • Custom result aggregation logic
  • Post-inference processing steps
  • Custom validation or filtering
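
For example, a sketch that keeps the default aggregation and adds a post-processing placeholder:

def on_test_epoch_end(self):
    # Keep the default aggregation and key-prediction alignment
    super().on_test_epoch_end()
    # Custom post-processing of the aggregated results goes here
    # (placeholder; the attribute holding the predictions is not shown above)
    ...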

predict(data, **kwargs) - Complete Pipeline Control

Default behavior:

  • Uses Lightning trainer for inference execution
  • Manages the complete inference workflow
  • Returns formatted prediction results

Lightning Benefits:

  • Automatic device management through trainer
  • Built-in logging and metrics capabilities
  • Structured approach with hooks and callbacks
  • Easy integration with Lightning ecosystem

PytorchInferenceModel Customization

Inference Model Override Methods:

predict(data, **kwargs) - Direct Inference Control

Default behavior:

  • Manual batch processing loop with torch.no_grad()
  • Direct device management and model evaluation
  • Explicit prediction collection and formatting
  • No Lightning trainer dependency

When to override:

  • Fine-grained control over inference loop
  • Custom batch processing logic
  • Memory-efficient streaming inference
  • Integration with non-Lightning workflows
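
A hedged sketch of a predict() override along those lines, mirroring the default loop shown earlier and moving each batch of results off the GPU immediately (PredictReturnType, self.test_dl, and split_dataloader_output() are used as in the snippets above; treat the details as assumptions):

def predict(self, data=None, **kwargs):
    all_predictions = []
    with torch.no_grad():
        for batch in self.test_dl:
            # Assumption: the model inputs are the first element returned
            batch_data, *extras = self.split_dataloader_output(batch)
            preds = self.forward(batch_data)
            # Move results to the CPU straight away to keep GPU memory flat
            all_predictions.append(preds.detach().cpu())
    # Key handling omitted in this sketch; see the default implementation above
    return PredictReturnType(preds=all_predictions, keys=None)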

Simple Model Benefits:

  • No PyTorch Lightning dependency
  • Direct PyTorch control and transparency
  • Explicit device and memory management
  • Faster startup and execution

Method Override Guidelines

Start Simple:

  1. Implement only create_model()
  2. Test basic inference functionality
  3. Add method overrides only when needed

Lightning Model Progression:

  1. Override test_step() for batch-level customization
  2. Override on_test_epoch_end() for result aggregation
  3. Override predict() for complete pipeline control

Simple Model Progression:

  1. Override forward() for input/output processing
  2. Override initialise_model() for setup customization
  3. Override predict() for complete pipeline control

Best Practices

  1. Choose the Right Base: use the Lightning model when you need its ecosystem and hooks, and the plain PyTorch model for lightweight, dependency-minimal deployments
  2. Set Evaluation Mode: always call model.eval() in your create_model() method
  3. Start Minimal: Begin with just create_model(), add complexity incrementally
  4. Use Utilities: Leverage split_dataloader_output() for robust data handling
  5. Test Locally: Validate all customizations before federated deployment
  6. Handle Edge Cases: Consider different input formats and error conditions
  7. Document Changes: Comment custom logic for team collaboration

You've now learned how to create simple, powerful inference models for federated learning with Bitfount!

Contact our support team at support@bitfount.com if you have any questions.