Model Fine-tuning

Model fine-tuning tasks are supported for both Bitfount and TIMM models.

Bitfount models

This task uses the bitfount.FederatedAveraging protocol together with the bitfount.FederatedModelTraining algorithm. This combination also supports federated learning tasks in which the model is trained across multiple datasets; in this example we use only a single dataset.
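Conceptually, federated averaging combines the parameter updates from each participating pod by taking their element-wise mean. A minimal sketch in plain Python (an illustration, not the Bitfount implementation; parameter names and the list-of-floats representation are assumptions for the example):

```python
def federated_average(pod_weights):
    """Average model parameters received from several pods.

    pod_weights: list of dicts mapping parameter name -> list of floats.
    Returns one dict with the element-wise mean of each parameter.
    """
    n = len(pod_weights)
    averaged = {}
    for name in pod_weights[0]:
        # Collect the same parameter from every pod and average element-wise.
        layers = [weights[name] for weights in pod_weights]
        averaged[name] = [sum(values) / n for values in zip(*layers)]
    return averaged

# With a single dataset, as in this guide, the "average" is simply that pod's weights.
print(federated_average([{"w": [1.0, 3.0]}, {"w": [3.0, 5.0]}]))  # {'w': [2.0, 4.0]}
```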

tip

For more information on running federated learning tasks, please refer to the documentation here.

An example task file for fine-tuning a Bitfount-hosted model is shown below. In this case, the model is a simple binary classification model. No features are specified, meaning that all columns in the dataset will be used for training.

```yaml
modeller:
  identity_verification_method: key-based

pods:
  identifiers:
    - <replace-with-dataset-identifier>

batched_execution: false
test_run: false
run_on_new_data_only: false

task:
  protocol:
    name: bitfount.FederatedAveraging
    arguments:
      epochs_between_parameter_updates: 10 # No need to share the model weights until the end of the training
  algorithm:
    - name: bitfount.FederatedModelTraining
      arguments:
        modeller_checkpointing: true # Whether to save the last checkpoint on the modeller side
        checkpoint_filename: best_checkpoint.pt
  model:
    bitfount_model:
      model_ref: MyBinaryClassificationModel
      model_version: 1
      username: bitfount
    hyperparameters:
      epochs: 10
      batch_size: "{{ batch_size }}"
      learning_rate: "{{ learning_rate }}"
      weight_decay: "{{ weight_decay }}"
  aggregator:
    secure: False
  data_structure:
    schema_requirements: partial
    assign:
      target:
        - "{{ target_column_name }}"

template:
  batch_size:
    label: "Batch size"
    tooltip: "Number of samples per batch during training."
    type: "number"
    default: 8
  learning_rate:
    label: "Learning rate"
    tooltip: "Learning rate for the model optimizer."
    type: "number"
    default: 0.0001
  weight_decay:
    label: "Weight decay"
    tooltip: "Weight decay (L2 regularization) for the model optimizer."
    type: "number"
    default: 0.01
  target_column_name:
    label: "Target column"
    tooltip: "The column containing dataset labels."
    type:
      schema_column_name:
        semantic_type: "categorical"
```
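For illustration, each `{{ ... }}` placeholder is filled in from the corresponding template entry when the task is run. Using the defaults above (and a hypothetical target column named `diagnosis`), the templated fields would render as:

```yaml
hyperparameters:
  epochs: 10
  batch_size: 8
  learning_rate: 0.0001
  weight_decay: 0.01
data_structure:
  assign:
    target:
      - diagnosis
```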

TIMM models

A good example of a TIMM model is the RETFound (retina foundation) model, a multiclass image classification model. The example task file below shows how to use this model in a multiclass image classification task. This algorithm is only compatible with the bitfount.ResultsOnly protocol, which simply runs the task and returns the results (if any) to the modeller. The updated model parameters are not part of the results returned by the algorithm: they are saved only on the Pod side, i.e. where the dataset is located, and the modeller may receive only metrics about the training process.

tip

Take a look at the RETFound demo project to easily run this task in the Bitfount app.

```yaml
modeller:
  identity_verification_method: key-based

pods:
  identifiers:
    - <replace-with-dataset-identifier>

task:
  protocol:
    name: bitfount.ResultsOnly
  algorithm:
    - name: bitfount.TIMMFineTuning
      arguments:
        model_id: bitfount/RETFound_MAE
        labels:
          - "0"
          - "1"
          - "2"
          - "3"
          - "4"
        args:
          epochs: 1
          batch_size: 32
          num_classes: 5
  data_structure:
    table_config:
      table: <replace-with-dataset-identifier>
    select:
      include:
        - Image name
        - Retinopathy grade
    assign:
      target:
        - Retinopathy grade
```
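Note that num_classes must agree with the labels list: the five label strings "0" through "4" correspond to the five retinopathy grades, and each value in the target column is mapped to one of these classes. A minimal sketch of that mapping in plain Python (an illustration only; the grades list is hypothetical example data, not the Bitfount implementation):

```python
labels = ["0", "1", "2", "3", "4"]
num_classes = len(labels)  # must match num_classes in the task file

# Map each label string to its class index.
label_to_index = {label: i for i, label in enumerate(labels)}

# Hypothetical "Retinopathy grade" values from the dataset.
grades = ["0", "3", "4", "1"]
targets = [label_to_index[g] for g in grades]
print(num_classes, targets)  # 5 [0, 3, 4, 1]
```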