timm_federated_training
HuggingFace TIMM Federated Training Algorithm.
Borrowed with permission from https://github.com/huggingface/pytorch-image-models. Copyright 2020 Ross Wightman (https://github.com/rwightman)
Classes
TIMMFederatedTraining
class TIMMFederatedTraining(
    datastructure: DataStructure,
    model_id: str,
    labels: Optional[list[str]] = None,
    args: Optional[TIMMTrainingConfig] = None,
    pretrained_file: Optional[Union[str, os.PathLike[str]]] = None,
    lora_rank: Optional[Union[int, Literal["auto"]]] = None,
)

TIMM-based federated training algorithm compatible with FederatedAveraging.
Enables collaborative DINOv2 fine-tuning across multiple fabs (pods) without any fab sharing its raw wafer map images. Each fab trains locally on its private data and sends only model weight updates to the OEM (modeller), which aggregates them using Federated Averaging.
This algorithm is designed for use with the FederatedAveraging protocol.
Use TIMMFineTuning with ResultsOnly instead for per-fab local fine-tuning
without weight exchange.
Arguments
- **kwargs: Additional keyword arguments.
- args: Training configuration. Defaults to TIMMTrainingConfig().
- datastructure: Data structure describing the image and label columns.
- labels: Ordered list of class label strings corresponding to the classification head outputs (e.g. wafer defect types).
- lora_rank: LoRA rank for parameter-efficient fine-tuning. None (default) uses full fine-tuning; "auto" selects the rank from model size (recommended for DINOv2 and other large ViTs); an integer sets the rank directly. See the module-level comments for guidance on rank selection.
- model_id: The HuggingFace / TIMM model identifier, e.g. "timm/vit_base_patch14_dinov2.lvd142m".
- pretrained_file: Optional path to a local pretrained checkpoint to start from instead of downloading from HuggingFace Hub. Defaults to None.
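To illustrate the lora_rank argument's three modes, here is a hypothetical sketch of how "auto" rank selection from model size might work. The actual selection logic lives inside the library; the function name and parameter-count thresholds below are assumptions for illustration only, not the library's real values.

```python
# Hypothetical sketch of "auto" LoRA rank resolution. The thresholds and the
# resolve_lora_rank helper are illustrative assumptions, not bitfount's
# actual implementation.
from typing import Literal, Optional, Union


def resolve_lora_rank(
    lora_rank: Optional[Union[int, Literal["auto"]]],
    num_params: int,
) -> Optional[int]:
    """Map the lora_rank argument to a concrete rank (None = full fine-tuning)."""
    if lora_rank is None:
        return None  # full fine-tuning, no LoRA adapters
    if lora_rank == "auto":
        # Larger models get a larger rank (illustrative thresholds only).
        if num_params >= 300_000_000:  # roughly ViT-L / DINOv2-scale models
            return 16
        if num_params >= 80_000_000:   # roughly ViT-B-scale models
            return 8
        return 4
    return int(lora_rank)


print(resolve_lora_rank("auto", 304_000_000))  # larger rank for a ViT-L-sized model
print(resolve_lora_rank(None, 304_000_000))    # None: full fine-tuning
```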
Attributes
- class_name: The name of the algorithm class.
- fields_dict: A dictionary mapping all attributes that will be serialized in the class to their marshmallow field type, e.g. fields_dict = {"class_name": fields.Str()}.
- nested_fields: A dictionary mapping all nested attributes to a registry that contains class names mapped to the respective classes, e.g. nested_fields = {"datastructure": datastructure.registry}.
Variables
- fields_dict : ClassVar[dict[str, marshmallow.fields.Field]] (static)
- model : str
  Returns the model identifier string. This property satisfies the structural requirement of the _FederatedAveragingCompatibleAlgoFactoryProtocol. The Protocol check is @runtime_checkable and only verifies that the attribute exists, not its type.
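The attribute-only nature of this check can be demonstrated with a small stdlib sketch. The HasModel protocol below is a simplified stand-in for _FederatedAveragingCompatibleAlgoFactoryProtocol; it shows that a @runtime_checkable isinstance() test passes whenever the model attribute exists, even with the wrong type.

```python
# Sketch of the structural check described above: a @runtime_checkable
# Protocol's isinstance() test only verifies that the named attribute exists,
# not that it has the annotated type. HasModel is a stand-in for the real
# _FederatedAveragingCompatibleAlgoFactoryProtocol.
from typing import Protocol, runtime_checkable


@runtime_checkable
class HasModel(Protocol):
    model: str


class GoodFactory:
    def __init__(self) -> None:
        self.model = "timm/vit_base_patch14_dinov2.lvd142m"


class WrongTypeFactory:
    def __init__(self) -> None:
        self.model = 42  # wrong type, but the attribute exists


print(isinstance(GoodFactory(), HasModel))       # True
print(isinstance(WrongTypeFactory(), HasModel))  # True: the type is not checked
print(isinstance(object(), HasModel))            # False: attribute missing
```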
Methods
create
def create(self, role: Union[str, Role], **kwargs: Any) -> Any

Create an instance representing the role specified.
modeller
def modeller(self, *, context: ProtocolContext, **kwargs: Any) -> bitfount.federated.algorithms.hugging_face_algorithms.timm_federated_training._FedModellerSide

Returns the modeller side of the TIMMFederatedTraining algorithm.
Arguments
- context: Protocol context for the current task.
- **kwargs: Additional keyword arguments.

Returns: The modeller-side algorithm instance.
worker
def worker(self, *, context: ProtocolContext, **kwargs: Any) -> bitfount.federated.algorithms.hugging_face_algorithms.timm_federated_training._FedWorkerSide

Returns the worker side of the TIMMFederatedTraining algorithm.
The hub keyword argument passed by FederatedAveraging is accepted via **kwargs and forwarded to the worker side. It is currently unused, since weights are downloaded from HuggingFace Hub directly via TIMM.
Arguments
- context: Protocol context for the current task.
- **kwargs: Additional keyword arguments (includes hub from the FederatedAveraging protocol).

Returns: The worker-side algorithm instance.