
timm_federated_training

HuggingFace TIMM Federated Training Algorithm.

Borrowed with permission from https://github.com/huggingface/pytorch-image-models. Copyright 2020 Ross Wightman (https://github.com/rwightman)

Classes

TIMMFederatedTraining

class TIMMFederatedTraining(
    datastructure: DataStructure,
    model_id: str,
    labels: Optional[list[str]] = None,
    args: Optional[TIMMTrainingConfig] = None,
    pretrained_file: Optional[Union[str, os.PathLike[str]]] = None,
    lora_rank: "Optional[Union[int, Literal['auto']]]" = None,
):

TIMM-based federated training algorithm compatible with FederatedAveraging.

Enables collaborative DINOv2 fine-tuning across multiple fabs (pods) without any fab sharing its raw wafer map images. Each fab trains locally on its private data and sends only model weight updates to the OEM (modeller), which aggregates them using Federated Averaging.

This algorithm is designed for use with the FederatedAveraging protocol. For per-fab local fine-tuning without weight exchange, use TIMMFineTuning with the ResultsOnly protocol instead.
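Federated Averaging combines the fabs' weight updates into a single global model. The sketch below shows the textbook form of the aggregation step, a mean of each parameter weighted by each fab's local sample count, using plain Python lists. The `StateDict` alias and the `federated_average` helper are hypothetical, and this page does not state whether Bitfount weights by sample count or takes an unweighted mean.

```python
from typing import Dict, List, Tuple

# Hypothetical simplification of a model state dict: parameter name -> flat list of floats.
StateDict = Dict[str, List[float]]


def federated_average(updates: List[Tuple[StateDict, int]]) -> StateDict:
    """Average worker state dicts, weighting each by its local sample count.

    `updates` holds one (state_dict, num_samples) pair per fab.
    """
    if not updates:
        raise ValueError("need at least one worker update")
    total = sum(n for _, n in updates)
    if total <= 0:
        raise ValueError("total sample count must be positive")

    first_sd, _ = updates[0]
    averaged: StateDict = {}
    for name, values in first_sd.items():
        acc = [0.0] * len(values)
        for sd, n in updates:
            weight = n / total  # fabs with more data pull the average harder
            for i, v in enumerate(sd[name]):
                acc[i] += weight * v
        averaged[name] = acc
    return averaged


# Two fabs with different amounts of private data; only weights leave each fab.
fab_a = ({"head.weight": [1.0, 2.0]}, 100)
fab_b = ({"head.weight": [3.0, 4.0]}, 300)
print(federated_average([fab_a, fab_b]))  # {'head.weight': [2.5, 3.5]}
```

Because fab_b holds three times as much data, the averaged weights land three quarters of the way toward its values.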

Arguments

  • **kwargs: Additional keyword arguments.
  • args: Training configuration. Defaults to TIMMTrainingConfig().
  • datastructure: Data structure describing the image and label columns.
  • labels: Ordered list of class label strings corresponding to the classification head outputs (e.g. wafer defect types).
  • lora_rank: LoRA rank for parameter-efficient fine-tuning. None (default) uses full fine-tuning; "auto" selects rank from model size (recommended for DINOv2 and other large ViTs); an integer sets the rank directly. See module-level comments for guidance on rank selection.
  • model_id: The HuggingFace / TIMM model identifier, e.g. "timm/vit_base_patch14_dinov2.lvd142m".
  • pretrained_file: Optional path to a local pretrained checkpoint to start from instead of downloading from HuggingFace Hub. Defaults to None.
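LoRA freezes a pretrained weight matrix and trains a low-rank update B·A alongside it, so rank r adds r·(d_in + d_out) trainable weights per adapted layer instead of d_in·d_out. The arithmetic below, for a single 768 by 768 projection (768 is the ViT-Base hidden width), illustrates why small ranks are cheap. The helper names are hypothetical, and the sketch says nothing about which layers TIMMFederatedTraining adapts or which rank "auto" selects.

```python
def linear_params(d_in: int, d_out: int) -> int:
    """Trainable weights in a dense d_in x d_out layer (bias ignored)."""
    return d_in * d_out


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable weights added by LoRA: B is (d_out x r) and A is (r x d_in)."""
    return rank * (d_in + d_out)


# One 768 -> 768 projection, as found in a ViT-Base attention block.
d = 768
full = linear_params(d, d)
for r in (4, 8, 16):
    lora = lora_params(d, d, r)
    # e.g. rank=8: 12288 trainable vs 589824 full (2.08%)
    print(f"rank={r}: {lora} trainable vs {full} full ({lora / full:.2%})")
```

The trainable count grows linearly in r, while full fine-tuning grows with the product of the layer dimensions.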

Attributes

  • class_name: The name of the algorithm class.
  • fields_dict: A dictionary mapping every attribute that is serialized for the class to its marshmallow field type (e.g. fields_dict = {"class_name": fields.Str()}).
  • nested_fields: A dictionary mapping every nested attribute to a registry that maps class names to the corresponding classes (e.g. nested_fields = {"datastructure": datastructure.registry}).
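A registry of class names to classes lets a deserializer choose the right class for a nested attribute from a stored class name. The stdlib-only sketch below assumes that lookup pattern; `HypotheticalDataStructure`, `registry`, and `load_nested` are illustrative stand-ins, not Bitfount's marshmallow-based implementation.

```python
from typing import Any, Dict, Type


class HypotheticalDataStructure:
    """Illustrative stand-in for a serializable nested object."""

    def __init__(self, table: str) -> None:
        self.table = table


# Hypothetical registry with the documented shape: class name -> class.
registry: Dict[str, Type[Any]] = {"HypotheticalDataStructure": HypotheticalDataStructure}

# Mirrors nested_fields: attribute name -> registry to resolve it against.
nested_fields: Dict[str, Dict[str, Type[Any]]] = {"datastructure": registry}


def load_nested(name: str, payload: Dict[str, Any]) -> Any:
    """Rebuild a nested attribute by resolving its class_name in the registry."""
    data = dict(payload)  # copy so the caller's dict is not mutated
    cls = nested_fields[name][data.pop("class_name")]
    return cls(**data)


ds = load_nested(
    "datastructure", {"class_name": "HypotheticalDataStructure", "table": "wafer_maps"}
)
print(type(ds).__name__, ds.table)  # HypotheticalDataStructure wafer_maps
```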

Variables

  • model : str - Returns the model identifier string.

    This property satisfies the structural requirement of the _FederatedAveragingCompatibleAlgoFactory Protocol. Because that Protocol is decorated with @runtime_checkable, the isinstance() check only verifies that the attribute exists, not its type.
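The "presence, not type" behaviour of @runtime_checkable protocols can be seen with the standard library alone. In the sketch below, `HasModel` is a hypothetical protocol standing in for _FederatedAveragingCompatibleAlgoFactory, and the three classes are illustrative.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class HasModel(Protocol):
    """Hypothetical protocol that only requires a `model` attribute."""

    model: str


class WithStrModel:
    @property
    def model(self) -> str:
        return "timm/vit_base_patch14_dinov2.lvd142m"


class WithIntModel:
    model = 42  # wrong type, but isinstance() does not inspect types


class WithoutModel:
    pass


print(isinstance(WithStrModel(), HasModel))  # True
print(isinstance(WithIntModel(), HasModel))  # True: only presence is checked
print(isinstance(WithoutModel(), HasModel))  # False
```

A static type checker would flag WithIntModel, but the runtime check alone would let it through.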

Methods


create

def create(self, role: Union[str, Role], **kwargs: Any) -> Any:

Create an instance representing the role specified.
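create accepts the role either as a Role value or as a string. A minimal sketch of that dispatch follows, assuming the string form matches the enum value ("modeller" or "worker"). The local `Role` enum and `AlgorithmFactorySketch` class are hypothetical stand-ins; the real modeller() and worker() methods additionally require a keyword-only context argument.

```python
from enum import Enum
from typing import Any, Union


class Role(Enum):
    """Hypothetical stand-in for the Role type in the signature above."""

    MODELLER = "modeller"
    WORKER = "worker"


class AlgorithmFactorySketch:
    """Hypothetical factory illustrating role-based dispatch."""

    def modeller(self, **kwargs: Any) -> str:
        return "modeller side"

    def worker(self, **kwargs: Any) -> str:
        return "worker side"

    def create(self, role: Union[str, Role], **kwargs: Any) -> Any:
        # Normalise strings to enum members; Role("other") raises ValueError.
        role = Role(role) if isinstance(role, str) else role
        if role is Role.MODELLER:
            return self.modeller(**kwargs)
        if role is Role.WORKER:
            return self.worker(**kwargs)
        raise ValueError(f"unsupported role: {role!r}")


factory = AlgorithmFactorySketch()
print(factory.create("worker"))       # worker side
print(factory.create(Role.MODELLER))  # modeller side
```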

modeller

def modeller(
    self,
    *,
    context: ProtocolContext,
    **kwargs: Any,
) -> bitfount.federated.algorithms.hugging_face_algorithms.timm_federated_training._FedModellerSide:

Returns the modeller side of the TIMMFederatedTraining algorithm.

Arguments

  • context: Protocol context for the current task.
  • **kwargs: Additional keyword arguments.

Returns

The modeller-side algorithm instance.

worker

def worker(
    self,
    *,
    context: ProtocolContext,
    **kwargs: Any,
) -> bitfount.federated.algorithms.hugging_face_algorithms.timm_federated_training._FedWorkerSide:

Returns the worker side of the TIMMFederatedTraining algorithm.

The hub keyword argument passed by FederatedAveraging is accepted via **kwargs and forwarded to the worker side. It is currently unused, because TIMM downloads the model weights directly from the HuggingFace Hub.

Arguments

  • context: Protocol context for the current task.
  • **kwargs: Additional keyword arguments (includes hub from the FederatedAveraging protocol).

Returns

The worker-side algorithm instance.
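The hub forwarding described above amounts to a factory that accepts arbitrary keywords and passes them through unchanged. In the hypothetical sketch below, `WorkerSideSketch` and the module-level `worker` function stand in for the real _FedWorkerSide and TIMMFederatedTraining.worker, and the placeholder strings take the place of the real context and hub objects.

```python
from typing import Any, Dict, Optional


class WorkerSideSketch:
    """Hypothetical worker side that accepts, but does not use, a `hub` argument."""

    def __init__(self, *, context: Any, hub: Optional[Any] = None, **kwargs: Any) -> None:
        self.context = context
        # Stored only for interface compatibility; weights are fetched by TIMM
        # from the HuggingFace Hub rather than through this object.
        self.hub = hub
        self.extra: Dict[str, Any] = kwargs


def worker(*, context: Any, **kwargs: Any) -> WorkerSideSketch:
    # Any keyword the protocol supplies (e.g. hub=...) is forwarded untouched.
    return WorkerSideSketch(context=context, **kwargs)


w = worker(context="task-ctx", hub="hub-placeholder", note="ignored")
print(w.hub, w.extra)  # hub-placeholder {'note': 'ignored'}
```

Accepting **kwargs keeps the worker compatible with protocols that pass extra arguments it does not need.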