bitfount_model
Contains the base classes for handling custom models.
Classes
BitfountModel
class BitfountModel( datastructure: DataStructure, schema: BitfountSchema, seed: Optional[int] = None, param_clipping: Optional[Dict[str, int]] = None,):
Base class for custom models, which must implement DistributedModelProtocol.
A base tagging class to highlight custom models which are designed to be uploaded to Bitfount Hub.
Arguments
datastructure
: DataStructure to be passed to the model when initialised.
param_clipping
: Arguments for clipping BatchNorm parameters. Used for federated models with secure aggregation. It should contain the SecureShare variables and the number of workers in a dictionary, e.g. {"prime_q": 13, "precision": 10**3, "num_workers": 2}. Defaults to None.
schema
: The BitfountSchema object associated with the datasource on which the model will be trained.
seed
: Random number seed. Used for setting the random seed for all libraries. Defaults to None.
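The seed and param_clipping arguments can be illustrated with a small, stdlib-only sketch. The helpers `seed_everything` and `check_param_clipping` are hypothetical and are not part of the Bitfount API. The sketch assumes only what the argument descriptions state: the seed is applied to every random number generator in use (here Python's `random` and, if installed, NumPy), and param_clipping is a dictionary with the keys shown in the example above.

```python
import random
from typing import Dict, Optional

# Keys taken from the param_clipping example in the argument description.
REQUIRED_CLIPPING_KEYS = {"prime_q", "precision", "num_workers"}


def seed_everything(seed: Optional[int]) -> None:
    """Hypothetical helper: seed every random number generator in use."""
    if seed is None:
        return  # No seed given: leave the generators as they are.
    random.seed(seed)
    try:
        import numpy as np  # Optional dependency, seeded only if installed.

        np.random.seed(seed)
    except ImportError:
        pass


def check_param_clipping(param_clipping: Optional[Dict[str, int]]) -> None:
    """Hypothetical helper: reject a clipping dict that lacks required keys."""
    if param_clipping is None:
        return  # Clipping is optional and defaults to None.
    missing = REQUIRED_CLIPPING_KEYS - set(param_clipping)
    if missing:
        raise ValueError(f"param_clipping is missing keys: {sorted(missing)}")
```

For example, `seed_everything(42)` followed by `check_param_clipping({"prime_q": 13, "precision": 10**3, "num_workers": 2})` seeds the generators and accepts the dictionary, while a dictionary containing only `"prime_q"` raises a ValueError.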
Attributes
fields_dict
: A dictionary mapping all attributes that will be serialized in the class to their marshmallow field type (e.g. fields_dict = {"class_name": fields.Str()}).
nested_fields
: A dictionary mapping all nested attributes to a registry that contains class names mapped to the respective classes (e.g. nested_fields = {"datastructure": datastructure.registry}).
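The two attributes work together during serialization: fields_dict names the attributes to write and how to convert them, and nested_fields names the registry to consult when rebuilding a nested object from its class name. The following is a minimal stdlib-only sketch of that pattern, not Bitfount's implementation. Plain callables such as `int` stand in for marshmallow fields so that it runs without marshmallow, and `DataStructureStub`, `DATASTRUCTURE_REGISTRY`, `ModelStub`, `to_dict` and `from_dict` are hypothetical names.

```python
from typing import Any, Callable, ClassVar, Dict, Mapping


class DataStructureStub:
    """Hypothetical nested object that can be rebuilt from a dict."""

    def __init__(self, target: str) -> None:
        self.target = target

    def to_dict(self) -> Dict[str, Any]:
        # Record the class name so the registry can find the class later.
        return {"class_name": type(self).__name__, "target": self.target}


# Registry: class names mapped to the classes themselves.
DATASTRUCTURE_REGISTRY: Dict[str, type] = {"DataStructureStub": DataStructureStub}


class ModelStub:
    # Attribute name -> converter (stand-in for a marshmallow field).
    fields_dict: ClassVar[Dict[str, Callable[[Any], Any]]] = {"seed": int}
    # Attribute name -> registry used to rebuild the nested object.
    nested_fields: ClassVar[Dict[str, Mapping[str, type]]] = {
        "datastructure": DATASTRUCTURE_REGISTRY
    }

    def __init__(self, datastructure: DataStructureStub, seed: int) -> None:
        self.datastructure = datastructure
        self.seed = seed

    def to_dict(self) -> Dict[str, Any]:
        out = {
            name: convert(getattr(self, name))
            for name, convert in self.fields_dict.items()
        }
        for name in self.nested_fields:
            out[name] = getattr(self, name).to_dict()
        return out

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ModelStub":
        kwargs = {name: convert(data[name]) for name, convert in cls.fields_dict.items()}
        for name, registry in cls.nested_fields.items():
            nested = dict(data[name])
            nested_cls = registry[nested.pop("class_name")]  # Look up by name.
            kwargs[name] = nested_cls(**nested)
        return cls(**kwargs)
```

A round trip such as `ModelStub.from_dict(ModelStub(DataStructureStub("label"), seed=42).to_dict())` rebuilds an equivalent model, with the nested object recreated through the registry rather than hard-coded.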
Ancestors
- bitfount.models.base_models._BaseModel
- bitfount.models.base_models._BaseModelRegistryMixIn
- bitfount.types._BaseSerializableObjectMixIn
- abc.ABC
- typing.Generic
Variables
- static fields_dict : ClassVar[Dict[str, marshmallow.fields.Field]]
- static nested_fields : ClassVar[Dict[str, Mapping[str, Any]]]
Static methods
serialize_model_source_code
def serialize_model_source_code( filename: Union[str, os.PathLike], extra_imports: Optional[List[str]] = None,) -> None:
Serializes the source code of the model to a file.
This is required so that the model source code can be uploaded to Bitfount Hub.
Arguments
filename
: The filename to save the source code to.
extra_imports
: A list of extra import statements to include in the source code.
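A rough sketch of how source code and extra import statements might be combined into one file is shown below. `write_model_source` is a hypothetical helper, not a Bitfount function. It takes the source text as an argument, which in practice could be obtained with `inspect.getsource(MyModel)` for a model class defined in a file, and it assumes each entry in extra_imports is a complete import statement to be placed at the top of the file.

```python
import os
from typing import List, Optional, Union


def write_model_source(
    source: str,
    filename: Union[str, os.PathLike],
    extra_imports: Optional[List[str]] = None,
) -> None:
    """Hypothetical helper: write extra imports followed by the model source."""
    lines = list(extra_imports or [])  # Import statements go first.
    if lines:
        lines.append("")  # Blank line between the imports and the code.
    lines.append(source.rstrip("\n"))
    with open(filename, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
```

For example, calling it with the source `"class MyModel:\n    pass\n"` and `extra_imports=["import torch"]` produces a file whose first line is `import torch`, followed by a blank line and the class definition. Putting the imports first keeps the saved file importable on its own.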
Methods
initialise_model
def initialise_model( self, data: Optional[BaseSource] = None, context: Optional[TaskContext] = None,) -> None:
Can be implemented to initialise the model if necessary.
This method is called automatically by fit() when initialisation is required.
Arguments
data
: The data used for model training.
context
: Indicates whether the model is running as a modeller or a worker. If None, there is no difference between modeller and worker. Defaults to None.
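The "called by fit() if necessary" behaviour can be sketched with a lazy-initialisation guard. `ModelBase`, `MyModel` and the `_initialised` flag are hypothetical stand-ins so that the example runs without Bitfount installed; a list of rows stands in for BaseSource and a plain string for TaskContext. The real base class may decide differently when initialisation is necessary.

```python
from typing import Any, Optional


class ModelBase:
    """Hypothetical stand-in for a model base class with lazy initialisation."""

    def __init__(self) -> None:
        self._initialised = False

    def initialise_model(
        self, data: Optional[Any] = None, context: Optional[str] = None
    ) -> None:
        """Default hook: nothing to initialise."""

    def fit(self, data: Any) -> None:
        if not self._initialised:
            # Initialise only once, the first time fit() is called.
            self.initialise_model(data=data)
            self._initialised = True
        self._train(data)

    def _train(self, data: Any) -> None:
        """Training loop omitted in this sketch."""


class MyModel(ModelBase):
    def __init__(self) -> None:
        super().__init__()
        self.n_features: Optional[int] = None
        self.init_calls = 0

    def initialise_model(
        self, data: Optional[Any] = None, context: Optional[str] = None
    ) -> None:
        # Size the model from the first batch of data it sees.
        self.init_calls += 1
        if data is not None:
            self.n_features = len(data[0])
```

After `model = MyModel()` and two calls to `model.fit(...)` with rows of three values, `model.n_features` is 3 and `initialise_model` has run only once. Deferring initialisation this way lets a model size itself from data it has not seen at construction time.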