
federated_averaging

Federated Averaging protocol.

Classes

FederatedAveraging

class FederatedAveraging(
    *,
    algorithm: _FederatedAveragingCompatibleAlgoFactory,
    aggregator: Optional[_BaseAggregatorFactory] = None,
    steps_between_parameter_updates: Optional[int] = None,
    epochs_between_parameter_updates: Optional[int] = None,
    auto_eval: bool = True,
    secure_aggregation: bool = False,
):

Original Federated Averaging algorithm by McMahan et al. (2017).

This protocol performs a predetermined number of epochs or steps of training on each remote Pod before sending the updated model parameters to the Modeller. The Modeller averages these parameters and sends the result back to the Pods. This cycle repeats for as many federated iterations as the Modeller specifies.

Tip: For more information, take a look at the seminal paper: https://arxiv.org/abs/1602.05629
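The train-then-average cycle described above can be sketched in plain Python. This is a toy illustration, not Bitfount's implementation: the function names, the quadratic loss, and the learning rate are all hypothetical, and the sketch uses an unweighted mean, whereas the original paper weights each client's update by its local dataset size.

```python
from typing import List


def local_update(
    weights: List[float], target: List[float], lr: float, steps: int
) -> List[float]:
    """Run `steps` gradient steps on the toy loss 0.5 * ||w - target||^2."""
    w = list(weights)
    for _ in range(steps):
        # The gradient of the toy loss with respect to w is (w - target).
        w = [wi - lr * (wi - ti) for wi, ti in zip(w, target)]
    return w


def average(updates: List[List[float]]) -> List[float]:
    """Unweighted element-wise mean of the parameter vectors from all pods."""
    n = len(updates)
    return [sum(column) / n for column in zip(*updates)]


def federated_averaging(
    initial: List[float],
    pod_targets: List[List[float]],
    iterations: int,
    steps_between_updates: int,
    lr: float = 0.5,
) -> List[float]:
    """Alternate local training on every pod with averaging on the modeller."""
    global_weights = list(initial)
    for _ in range(iterations):
        # Each pod starts from the current global parameters and trains locally.
        updates = [
            local_update(global_weights, target, lr, steps_between_updates)
            for target in pod_targets
        ]
        # The modeller averages the returned parameters and redistributes them.
        global_weights = average(updates)
    return global_weights


# Two hypothetical pods whose local optima differ; the averaged model
# settles between them.
result = federated_averaging(
    initial=[0.0, 0.0],
    pod_targets=[[1.0, 2.0], [3.0, 4.0]],
    iterations=20,
    steps_between_updates=2,
)
```

Here `steps_between_updates` plays the role of steps_between_parameter_updates: raising it means more local training per communication round.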

Arguments

  • aggregator: The aggregator to use for updating the model parameters across all Pods participating in the task. This argument takes priority over the secure_aggregation argument.
  • algorithm: The algorithm to use for training. This must be compatible with the FederatedAveraging protocol.
  • auto_eval: Whether to automatically evaluate the model on the validation dataset. Defaults to True.
  • epochs_between_parameter_updates: The number of epochs of local training on each Pod between parameter updates. Mutually exclusive with steps_between_parameter_updates: only one of the two can be provided. Defaults to None.
  • secure_aggregation: Whether to use secure aggregation. This argument is overridden by the aggregator argument. Defaults to False.
  • steps_between_parameter_updates: The number of steps of local training on each Pod between parameter updates. Mutually exclusive with epochs_between_parameter_updates: only one of the two can be provided. Defaults to None.
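The two argument interactions described above (steps versus epochs, and aggregator versus secure_aggregation) can be sketched as follows. These helpers are hypothetical and are not part of the Bitfount API. The exception type raised by the library when both schedule arguments are given is not stated here, so `ValueError` is an assumption, as are the string labels used for the aggregator choice.

```python
from typing import Optional, Tuple


def resolve_update_schedule(
    steps: Optional[int] = None, epochs: Optional[int] = None
) -> Optional[Tuple[str, int]]:
    """Return the unit and count of local training between parameter updates."""
    if steps is not None and epochs is not None:
        # Assumption: the real exception type is not documented in this page.
        raise ValueError(
            "Provide only one of steps_between_parameter_updates "
            "and epochs_between_parameter_updates."
        )
    if steps is not None:
        return ("steps", steps)
    if epochs is not None:
        return ("epochs", epochs)
    # Neither was given; what the library falls back to is not stated here.
    return None


def resolve_aggregator(aggregator: Optional[str], secure_aggregation: bool) -> str:
    """An explicit aggregator takes priority over the secure_aggregation flag."""
    if aggregator is not None:
        return aggregator
    # Hypothetical labels standing in for real aggregator objects.
    return "secure" if secure_aggregation else "plain"
```

For example, `resolve_aggregator("custom", secure_aggregation=True)` returns `"custom"`, mirroring the rule that the aggregator argument overrides secure_aggregation.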

Attributes

  • aggregator: The aggregator to use for updating the model parameters.
  • algorithm: The algorithm to use for training.
  • auto_eval: Whether to automatically evaluate the model on the validation dataset.
  • epochs_between_parameter_updates: The number of epochs between parameter updates.
  • fields_dict: A dictionary mapping all attributes that will be serialized in the class to their marshmallow field type. (e.g. fields_dict = {"class_name": fields.Str()}).
  • name: The name of the protocol.
  • nested_fields: A dictionary mapping all nested attributes to a registry that contains class names mapped to the respective classes. (e.g. nested_fields = {"datastructure": datastructure.registry}).
  • steps_between_parameter_updates: The number of steps between parameter updates.

Raises

  • TypeError: If the algorithm is not compatible with the protocol.

Ancestors

Variables

  • static algorithm : bitfount.federated.protocols.model_protocols.federated_averaging._FederatedAveragingCompatibleAlgoFactory

Methods


create

def create(self, role: Union[str, Role], **kwargs: Any) -> Any:

Create an instance representing the role specified.

dump

def dump(self) -> SerializedProtocol:

Inherited from: BaseProtocolFactory.dump

Returns the JSON-serializable representation of the protocol.

modeller

def modeller(
    self,
    mailbox: _ModellerMailbox,
    early_stopping: Optional[FederatedEarlyStopping] = None,
    **kwargs: Any,
) -> bitfount.federated.protocols.model_protocols.federated_averaging._ModellerSide:

Returns the modeller side of the FederatedAveraging protocol.

run

def run(
    self,
    pod_identifiers: Collection[str],
    session: Optional[BitfountSession] = None,
    username: Optional[str] = None,
    hub: Optional[BitfountHub] = None,
    ms_config: Optional[MessageServiceConfig] = None,
    message_service: Optional[_MessageService] = None,
    pod_public_key_paths: Optional[Mapping[str, Path]] = None,
    identity_verification_method: IdentityVerificationMethod = IdentityVerificationMethod.OIDC_DEVICE_CODE,
    private_key_or_file: Optional[Union[RSAPrivateKey, Path]] = None,
    idp_url: Optional[str] = None,
    require_all_pods: bool = False,
    run_on_new_data_only: bool = False,
    model_out: Optional[Union[Path, str]] = None,
    project_id: Optional[str] = None,
    batched_execution: Optional[bool] = None,
) -> Optional[Any]:

Inherited from: BaseProtocolFactory.run

Sets up a local Modeller instance and runs the protocol.

Arguments

  • pod_identifiers: The BitfountHub pod identifiers to run against.
  • session: Optional. Session to use for authenticated requests. Created if needed.
  • username: Username to run as. Defaults to logged in user.
  • hub: BitfountHub instance. Default: hub.bitfount.com.
  • ms_config: Message service config. Default: messaging.bitfount.com.
  • message_service: Message service instance. If not provided, one is created from ms_config, which defaults to messaging.bitfount.com.
  • pod_public_key_paths: Public keys of pods to be checked against.
  • identity_verification_method: The identity verification method to use.
  • private_key_or_file: Private key (to be removed).
  • idp_url: The IDP URL.
  • require_all_pods: If True, raise PodResponseError if at least one of the specified pods rejects or fails to respond to a task request. Defaults to False.
  • run_on_new_data_only: Whether to run the task on new datapoints only. Defaults to False.
  • model_out: The path to save the model to.
  • project_id: The project ID to run the task under.
  • batched_execution: Whether to run the task in batched mode. Defaults to False.

Returns

Results of the protocol.

Raises

  • PodResponseError: If require_all_pods is True and at least one of the specified pods rejects or fails to respond to a task request.
  • ValueError: If attempting to train on multiple pods, and the DataStructure table name is given as a string.
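The require_all_pods behaviour documented above can be sketched as a simple check over pod responses. Everything in this block is hypothetical: the `PodResponseError` class is a local stand-in for the library's exception, the `responses` mapping is an invented representation, and continuing with only the accepting pods when require_all_pods is False is an assumption rather than documented behaviour.

```python
from typing import List, Mapping, Optional


class PodResponseError(Exception):
    """Local stand-in for the library's exception of the same name."""


def accepted_pods(
    responses: Mapping[str, Optional[bool]], require_all_pods: bool = False
) -> List[str]:
    """Return the pods that accepted the task request.

    `responses` maps a pod identifier to True (accepted), False (rejected),
    or None (no response).
    """
    accepted = [pod for pod, ok in responses.items() if ok is True]
    failed = [pod for pod, ok in responses.items() if ok is not True]
    if require_all_pods and failed:
        raise PodResponseError(
            f"Pods rejected or did not respond: {', '.join(failed)}"
        )
    # Assumption: without require_all_pods, the task proceeds on accepting pods.
    return accepted
```

With `responses = {"alice/pod": True, "bob/pod": None}`, the call returns `["alice/pod"]` by default and raises `PodResponseError` when `require_all_pods=True`.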

worker

def worker(
    self,
    mailbox: _WorkerMailbox,
    hub: BitfountHub,
    **kwargs: Any,
) -> bitfount.federated.protocols.model_protocols.federated_averaging._WorkerSide:

Returns the worker side of the FederatedAveraging protocol.

Raises

  • TypeError: If the mailbox is not compatible with the aggregator.