base_models
Base models and helper classes using PyTorch as the backend.
Classes
PyTorchClassifierMixIn
class PyTorchClassifierMixIn(multilabel: bool = False, param_clipping: Optional[dict[str, int]] = None):
MixIn for PyTorch classification problems.
PyTorch classification models must have this class in their inheritance hierarchy.
Arguments
- multilabel: Whether the problem is a multi-label problem, i.e. each datapoint can belong to multiple classes.
- param_clipping: Arguments for clipping BatchNorm parameters. Used for federated models with secure aggregation. It should be a dictionary containing the SecureShare variables and the number of workers, e.g. {"prime_q": 13, "precision": 10**3, "num_workers": 2}.
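As a minimal sketch, the two arguments could be prepared as below. The dictionary keys and values are copied from the example in the argument description; the name mixin_kwargs is hypothetical and only illustrates how the values might be grouped before being passed to a model class that has PyTorchClassifierMixIn in its hierarchy.

```python
from typing import Optional

# Keys and values taken from the example in the argument description.
param_clipping: Optional[dict[str, int]] = {
    "prime_q": 13,         # SecureShare variable
    "precision": 10**3,    # SecureShare variable
    "num_workers": 2,      # number of workers in the federation
}

# Hypothetical grouping of the constructor arguments (not a Bitfount API).
mixin_kwargs = {
    "multilabel": False,
    "param_clipping": param_clipping,
}
```

Leaving param_clipping as None (the default) corresponds to a model that is not configured for BatchNorm parameter clipping.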
Attributes
- fields_dict: A dictionary mapping all attributes that will be serialized in the class to their marshmallow field type (e.g. fields_dict = {"class_name": fields.Str()}).
- multilabel: Whether the problem is a multi-label problem.
- n_classes: Number of classes in the problem.
- nested_fields: A dictionary mapping all nested attributes to a registry that contains class names mapped to the respective classes (e.g. nested_fields = {"datastructure": datastructure.registry}).
Ancestors
- ClassifierMixIn
- bitfount.models.base_models._BaseModelRegistryMixIn
- bitfount.types._BaseSerializableObjectMixIn
Variables
- static datastructure : DataStructure - set in _BaseModel
- static schema : BitfountSchema - set in _BaseModel
Methods
set_number_of_classes
def set_number_of_classes(self, schema: BitfountSchema) -> None:
Inherited from: ClassifierMixIn.set_number_of_classes
Sets the target number of classes for the classifier.
If the data is a multi-label problem, the number of classes is set to the number
of target columns as specified in the DataStructure. Otherwise, the number of
classes is set to the number of unique values in the target column as specified
in the BitfountSchema. The value is stored in the n_classes attribute.
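The rule described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's implementation: the function count_classes and its parameters target_columns and target_values are hypothetical stand-ins for the information that the DataStructure and BitfountSchema would provide.

```python
from typing import Hashable, Sequence


def count_classes(
    multilabel: bool,
    target_columns: Sequence[str],
    target_values: Sequence[Hashable],
) -> int:
    """Hypothetical stand-in for the class-count rule described above.

    target_columns: names of the target columns (assumed to come from
        the DataStructure).
    target_values: values observed in the single target column (assumed
        to come from the schema); only used when multilabel is False.
    """
    if multilabel:
        # Multi-label: one class per target column.
        return len(target_columns)
    # Single-label: one class per unique value in the target column.
    return len(set(target_values))


# Multi-label problem with three target columns.
n_multi = count_classes(True, ["is_cat", "is_dog", "is_bird"], [])

# Single-label problem whose target column holds two distinct values.
n_single = count_classes(False, ["label"], ["cat", "dog", "cat"])
```

In the real method the result is stored in the n_classes attribute rather than returned.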