timm_inference
Hugging Face TIMM inference algorithm.
Adapted from: https://github.com/huggingface/api-inference-community/
Classes
TIMMInference
class TIMMInference(datastructure: DataStructure, model_id: str, num_classes: Optional[int] = None, batch_size: int = 1, checkpoint_path: Optional[Union[os.PathLike, str]] = None, class_outputs: Optional[list[str]] = None):
Hugging Face TIMM inference algorithm.
Arguments
- **kwargs: Additional keyword arguments.
- batch_transformations: A list of dictionaries containing the batch transformations. Defaults to None.
- checkpoint_path: The path to a checkpoint file local to the Pod. Defaults to None.
- class_outputs: A list of explicit class outputs to use as labels. Defaults to None.
- datastructure: The data structure to use for the algorithm.
- model_id: The model id to use from the Hugging Face Hub.
- num_classes: The number of classes in the model. Defaults to None.
Attributes
- checkpoint_path: The path to a checkpoint file local to the Pod. Defaults to None.
- class_name: The name of the algorithm class.
- class_outputs: A list of explicit class outputs to use as labels. Defaults to None.
- fields_dict: A dictionary mapping all attributes that will be serialized in the class to their marshmallow field type (e.g. fields_dict = {"class_name": fields.Str()}).
- model_id: The model id to use from the Hugging Face Hub.
- nested_fields: A dictionary mapping all nested attributes to a registry that contains class names mapped to the respective classes (e.g. nested_fields = {"datastructure": datastructure.registry}).
- num_classes: The number of classes in the model. Defaults to None.
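The relationship between num_classes and class_outputs can be illustrated with a small standalone sketch. It assumes, without confirmation from this page, that each label in class_outputs corresponds to the model output at the same index and that raw scores are normalised with a softmax. The helper label_scores and the fallback names class_0, class_1, ... are hypothetical and are not part of Bitfount or timm.

```python
import math
from typing import Optional


def label_scores(
    logits: list[float],
    class_outputs: Optional[list[str]] = None,
) -> dict[str, float]:
    """Pair softmax probabilities with labels (hypothetical helper)."""
    # Assumption: one explicit label per model output, in output order.
    if class_outputs is not None and len(class_outputs) != len(logits):
        raise ValueError("class_outputs must have one label per model output")
    # Numerically stable softmax: subtract the max logit before exponentiating.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    # Fall back to index-based names when no explicit labels are given.
    labels = class_outputs or [f"class_{i}" for i in range(len(logits))]
    return {label: e / total for label, e in zip(labels, exps)}


scores = label_scores([2.0, 0.5, -1.0], class_outputs=["cat", "dog", "bird"])
best = max(scores, key=scores.get)  # label with the highest probability
```

In this sketch a length mismatch between the labels and the outputs raises an error early, which is one reasonable way to keep num_classes and class_outputs consistent; the actual algorithm may handle this differently.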
Ancestors
- BaseNonModelAlgorithmFactory
- BaseAlgorithmFactory
- abc.ABC
- bitfount.federated.roles._RolesMixIn
- bitfount.types._BaseSerializableObjectMixIn
- typing.Generic
Variables
- static fields_dict : ClassVar[dict[str, marshmallow.fields.Field]]
Methods
create
def create(self, role: Union[str, Role], **kwargs: Any) -> Any:
Create an instance representing the role specified.
modeller
def modeller(self, *, context: ProtocolContext, **kwargs: Any) -> bitfount.federated.algorithms.hugging_face_algorithms.base._HFModellerSide:
Returns the modeller side of the TIMMInference algorithm.
worker
def worker(self, *, context: ProtocolContext, **kwargs: Any) -> bitfount.federated.algorithms.hugging_face_algorithms.timm_inference._WorkerSide:
Returns the worker side of the TIMMInference algorithm.
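The three methods above follow a factory pattern: create takes a role and yields the matching side of the algorithm. The following is a minimal standalone sketch of that kind of dispatch, not Bitfount's implementation. The Role enum values, the SketchFactory class, and the string return values are all hypothetical stand-ins, and the context argument is typed as Any rather than ProtocolContext.

```python
from enum import Enum
from typing import Any, Union


class Role(Enum):
    """Hypothetical role enum; the real Role may define other members."""

    MODELLER = "modeller"
    WORKER = "worker"


class SketchFactory:
    """Hypothetical stand-in for an algorithm factory; not Bitfount code."""

    def modeller(self, *, context: Any, **kwargs: Any) -> str:
        return "modeller-side"

    def worker(self, *, context: Any, **kwargs: Any) -> str:
        return "worker-side"

    def create(self, role: Union[str, Role], **kwargs: Any) -> Any:
        # Accept either the enum member or its string value.
        role = Role(role)
        # Look up the method named after the role and forward the kwargs.
        return getattr(self, role.value)(**kwargs)


factory = SketchFactory()
worker_side = factory.create("worker", context=None)
modeller_side = factory.create(Role.MODELLER, context=None)
```

Routing through a single create method lets calling code stay role-agnostic, while modeller and worker remain available for direct use when the role is known in advance.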