base_models

Base models and helper classes using PyTorch as the backend.

Classes

PyTorchClassifierMixIn

class PyTorchClassifierMixIn(multilabel: bool = False, param_clipping: Optional[dict[str, int]] = None):

MixIn for PyTorch classification problems.

PyTorch classification models must have this class in their inheritance hierarchy.

Arguments

  • multilabel: Whether the problem is a multi-label problem, i.e. each datapoint can belong to multiple classes.
  • param_clipping: Arguments for clipping BatchNorm parameters. Used for federated models with secure aggregation. It should contain the SecureShare variables and the number of workers in a dictionary, e.g. {"prime_q": 13, "precision": 10**3, "num_workers": 2}. See the sketch after this list.
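
For illustration, a minimal sketch of how a model might place PyTorchClassifierMixIn in its inheritance hierarchy and pass these arguments. The import path and the MyClassifier subclass are assumptions for this sketch, not verbatim library code:

    from bitfount.backends.pytorch.models.base_models import PyTorchClassifierMixIn  # assumed import path

    class MyClassifier(PyTorchClassifierMixIn):
        """Toy subclass used only to show how the mixin arguments are passed."""

    model = MyClassifier(
        multilabel=True,  # each datapoint can belong to multiple classes
        # BatchNorm clipping settings for secure aggregation across 2 workers
        param_clipping={"prime_q": 13, "precision": 10**3, "num_workers": 2},
    )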

Attributes

  • fields_dict: A dictionary mapping all attributes that will be serialized in the class to their marshmallow field type (e.g. fields_dict = {"class_name": fields.Str()}). See the sketch after this list.
  • multilabel: Whether the problem is a multi-label problem
  • n_classes: Number of classes in the problem
  • nested_fields: A dictionary mapping all nested attributes to a registry that contains class names mapped to the respective classes. (e.g. nested_fields = {"datastructure": datastructure.registry})
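
For illustration, a minimal sketch of what fields_dict and nested_fields might look like, assuming marshmallow is available; the exact entries shown are hypothetical:

    from marshmallow import fields

    # Each serialized attribute maps to a marshmallow field type.
    fields_dict = {
        "class_name": fields.Str(),
        "multilabel": fields.Bool(),  # hypothetical extra entry for illustration
    }

    # Nested attributes map to a registry of class names -> classes, for example:
    # nested_fields = {"datastructure": datastructure.registry}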

Ancestors

  • ClassifierMixIn
  • bitfount.models.base_models._BaseModelRegistryMixIn
  • bitfount.types._BaseSerializableObjectMixIn

Methods


set_number_of_classes

def set_number_of_classes(self, schema: TableSchema) -> None:

Inherited from:

ClassifierMixIn.set_number_of_classes:

Sets the target number of classes for the classifier.

If the data is a multi-label problem, the number of classes is set to the number of target columns as specified in the DataStructure. Otherwise, the number of classes is set to the number of unique values in the target column as specified in the BitfountSchema. The value is stored in the n_classes attribute.
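
For illustration, the documented behaviour can be sketched roughly as follows. This is a reconstruction of the description above, not the library's implementation; datastructure.target and unique_values_for are assumed names:

    def set_number_of_classes(self, schema: TableSchema) -> None:
        target = self.datastructure.target  # target column(s) from the DataStructure
        if self.multilabel:
            # Multi-label: one class per target column.
            self.n_classes = len(target)
        else:
            # Single-label: one class per unique value in the target column,
            # as recorded in the schema (unique_values_for is a hypothetical helper).
            self.n_classes = len(schema.unique_values_for(target))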