Skip to main content

types

Type hints, enums and protocols for the Bitfount libraries.

Classes

BaseDistributedModelProtocol

class BaseDistributedModelProtocol(*args, **kwargs):

Federated Model structural type that only specifies the methods.

The reason for this protocol is that issubclass checks with Protocols can only be performed if the Protocol specifies only methods and no attributes. We still want to specify the attributes in another protocol for greater type safety (both statically and dynamically). We therefore have this protocol, which specifies only methods, and a separate protocol that also specifies the attributes.

Methods


apply_weight_updates

def apply_weight_updates(self, weight_updates: Sequence[_Weights]) -> collections.abc.Mapping:

Defined in DistributedModelMixIn.

deserialize

def deserialize(self, filename: Union[str, os.PathLike]) -> None:

Defined in _BaseModel.

deserialize_params

def deserialize_params(self, serialized_weights: _SerializedWeights) -> collections.abc.Mapping:

Defined in DistributedModelMixIn.

diff_params

def diff_params(self, old_params: _Weights, new_params: _Weights) -> collections.abc.Mapping:

Defined in DistributedModelMixIn.

evaluate

def evaluate(self, test_dl: Optional[BitfountDataLoader] = None, pod_identifiers: Optional[list[str]] = None, **kwargs: Any) -> Union[EvaluateReturnType, dict[str, float]]:

Defined in _BaseModel.

fit

def fit(self, data: Optional[BaseSource] = None, metrics: Optional[dict[str, Metric]] = None, pod_identifiers: Optional[list[str]] = None, **kwargs: Any) -> Optional[dict[str, str]]:

Defined in DistributedModelMixIn.

get_param_states

def get_param_states(self) -> collections.abc.Mapping:

Defined in DistributedModelMixIn.

initialise_model

def initialise_model(self, data: Optional[BaseSource] = None, context: Optional[TaskContext] = None) -> None:

Defined in _BaseModel.

log_

def log_(self, name: str, value: Any, **kwargs: Any) -> Any:

Defined in DistributedModelMixIn.

predict

def predict(self, data: Optional[BaseSource] = None, pod_identifiers: Optional[list[str]] = None, **kwargs: Any) -> Union[PredictReturnType, dict[str, list[np.ndarray]]]:

Defined in _BaseModel.

reset_trainer

def reset_trainer(self) -> None:

Defined in DistributedModelMixIn.

serialize

def serialize(self, filename: Union[str, os.PathLike]) -> None:

Defined in _BaseModel.

serialize_params

def serialize_params(self, weights: _Weights) -> collections.abc.Mapping:

Defined in DistributedModelMixIn.

set_datastructure_identifier

def set_datastructure_identifier(self, datastructure_identifier: str) -> None:

Defined in DistributedModelMixIn.

set_model_training_iterations

def set_model_training_iterations(self, iterations: int) -> None:

Defined in DistributedModelMixIn.

tensor_precision

def tensor_precision(self) -> +T_DTYPE:

Defined in DistributedModelMixIn.

update_params

def update_params(self, new_model_params: _Weights) -> None:

Defined in DistributedModelMixIn.

DistributedModelProtocol

class DistributedModelProtocol(*args, **kwargs):

Federated Model structural type.

This protocol should be implemented by classes that inherit either from BitfountModel or from both _BaseModel and DistributedModelMixIn.

Variables

  • static class_name : str
  • static datastructure : DataStructure
  • static epochs : Optional[int]
  • static fields_dict : ClassVar[T_FIELDS_DICT]
  • static metrics : Optional[MutableMapping[str, Metric]]
  • static nested_fields : ClassVar[T_NESTED_FIELDS]
  • static param_clipping : Optional[dict[str, int]]
  • static schema : BitfountSchema
  • static steps : Optional[int]
  • initialised : bool - Defined in _BaseModel.

Methods


apply_weight_updates

def apply_weight_updates(self, weight_updates: Sequence[_Weights]) -> collections.abc.Mapping:

Inherited from:

BaseDistributedModelProtocol.apply_weight_updates :

Defined in DistributedModelMixIn.

deserialize

def deserialize(self, filename: Union[str, os.PathLike]) -> None:

Inherited from:

BaseDistributedModelProtocol.deserialize :

Defined in _BaseModel.

deserialize_params

def deserialize_params(self, serialized_weights: _SerializedWeights) -> collections.abc.Mapping:

Inherited from:

BaseDistributedModelProtocol.deserialize_params :

Defined in DistributedModelMixIn.

diff_params

def diff_params(self, old_params: _Weights, new_params: _Weights) -> collections.abc.Mapping:

Inherited from:

BaseDistributedModelProtocol.diff_params :

Defined in DistributedModelMixIn.

evaluate

def evaluate(self, test_dl: Optional[BitfountDataLoader] = None, pod_identifiers: Optional[list[str]] = None, **kwargs: Any) -> Union[EvaluateReturnType, dict[str, float]]:

Inherited from:

BaseDistributedModelProtocol.evaluate :

Defined in _BaseModel.

fit

def fit(self, data: Optional[BaseSource] = None, metrics: Optional[dict[str, Metric]] = None, pod_identifiers: Optional[list[str]] = None, **kwargs: Any) -> Optional[dict[str, str]]:

Inherited from:

BaseDistributedModelProtocol.fit :

Defined in DistributedModelMixIn.

get_param_states

def get_param_states(self) -> collections.abc.Mapping:

Inherited from:

BaseDistributedModelProtocol.get_param_states :

Defined in DistributedModelMixIn.

initialise_model

def initialise_model(self, data: Optional[BaseSource] = None, context: Optional[TaskContext] = None) -> None:

Inherited from:

BaseDistributedModelProtocol.initialise_model :

Defined in _BaseModel.

log_

def log_(self, name: str, value: Any, **kwargs: Any) -> Any:

Inherited from:

BaseDistributedModelProtocol.log_ :

Defined in DistributedModelMixIn.

predict

def predict(self, data: Optional[BaseSource] = None, pod_identifiers: Optional[list[str]] = None, **kwargs: Any) -> Union[PredictReturnType, dict[str, list[np.ndarray]]]:

Inherited from:

BaseDistributedModelProtocol.predict :

Defined in _BaseModel.

reset_trainer

def reset_trainer(self) -> None:

Inherited from:

BaseDistributedModelProtocol.reset_trainer :

Defined in DistributedModelMixIn.

serialize

def serialize(self, filename: Union[str, os.PathLike]) -> None:

Inherited from:

BaseDistributedModelProtocol.serialize :

Defined in _BaseModel.

serialize_params

def serialize_params(self, weights: _Weights) -> collections.abc.Mapping:

Inherited from:

BaseDistributedModelProtocol.serialize_params :

Defined in DistributedModelMixIn.

set_datastructure_identifier

def set_datastructure_identifier(self, datastructure_identifier: str) -> None:

Inherited from:

BaseDistributedModelProtocol.set_datastructure_identifier :

Defined in DistributedModelMixIn.

set_model_training_iterations

def set_model_training_iterations(self, iterations: int) -> None:

Inherited from:

BaseDistributedModelProtocol.set_model_training_iterations :

Defined in DistributedModelMixIn.

tensor_precision

def tensor_precision(self) -> +T_DTYPE:

Inherited from:

BaseDistributedModelProtocol.tensor_precision :

Defined in DistributedModelMixIn.

update_params

def update_params(self, new_model_params: _Weights) -> None:

Inherited from:

BaseDistributedModelProtocol.update_params :

Defined in DistributedModelMixIn.