Source code for flex.pool.aggregators

"""
Copyright (C) 2024  Instituto Andaluz Interuniversitario en Ciencia de Datos e Inteligencia Computacional (DaSCI).

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published
    by the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
"""File that contains the adapted aggregators in FLEXible for fast
development of a federated model in FLEXible.

These aggregators can also serve as examples for creating a custom aggregator.
"""

import tensorly as tl  # noqa: E402

from flex.pool.decorators import aggregate_weights  # noqa: E402


def flatten(xs):
    """Lazily yield the leaves of an arbitrarily nested list/tuple structure.

    Only ``list`` and ``tuple`` instances are descended into; every other
    object (including strings and tensors) is yielded unchanged, in
    depth-first, left-to-right order.

    Args:
    -----
        xs: Iterable whose items may themselves be lists or tuples.

    Yields:
    -------
        Each non-list, non-tuple item found while walking ``xs``.
    """
    # Stack of partially consumed iterators; the top one is the nesting
    # level currently being walked.
    pending = [iter(xs)]
    while pending:
        for item in pending[-1]:
            if isinstance(item, (list, tuple)):
                # Descend; the parent iterator keeps its position for later.
                pending.append(iter(item))
                break
            yield item
        else:
            # Current level exhausted: resume the enclosing one.
            pending.pop()
def set_tensorly_backend(
    aggregated_weights_as_list: list, supported_modules: list = None
):
    """Select the tensorly backend that matches the given weights.

    Each module name in ``supported_modules`` is tried in order. The first
    one that can be imported and whose ``is_tensor`` accepts every leaf of
    ``aggregated_weights_as_list`` (nested lists/tuples are flattened) is set
    as the tensorly backend; ``"torch"`` is passed to tensorly as
    ``"pytorch"``. Modules that are not installed are skipped. If no module
    matches, the ``"numpy"`` backend is set.

    Args:
    -----
        aggregated_weights_as_list (list): Possibly nested list of weights
        whose tensor type decides the backend.
        supported_modules (list, optional): Module names to try, in order.
        Defaults to ``["tensorflow", "torch"]``.
    """
    # jax support is planned
    candidates = (
        ["tensorflow", "torch"] if supported_modules is None else supported_modules
    )
    for candidate in candidates:
        try:
            framework = __import__(candidate)
            owns_every_weight = all(
                framework.is_tensor(leaf)
                for leaf in flatten(aggregated_weights_as_list)
            )
            if owns_every_weight:
                backend_name = f"py{candidate}" if candidate == "torch" else candidate
                tl.set_backend(backend_name)
                return
        except ImportError:
            # Framework not available here; try the next candidate.
            continue
    # Default backend
    tl.set_backend("numpy")
def fed_avg_f(aggregated_weights_as_list: list):
    """Aggregate the weights giving every entry the same ponderation.

    Builds a uniform ponderation of ``1 / n`` for each of the ``n`` entries
    and delegates to ``weighted_fed_avg_f``.

    Args:
    -----
        aggregated_weights_as_list (list): List which contains
        all the weights to aggregate, one entry per node.

    Returns:
    --------
        Whatever ``weighted_fed_avg_f`` returns for the uniform ponderation.

    Raises:
    -------
        ZeroDivisionError: If ``aggregated_weights_as_list`` is empty.
    """
    node_count = len(aggregated_weights_as_list)
    equal_share = 1 / node_count
    uniform_ponderation = [equal_share for _ in range(node_count)]
    return weighted_fed_avg_f(aggregated_weights_as_list, uniform_ponderation)
def weighted_fed_avg_f(aggregated_weights_as_list: list, ponderation: list):
    """Compute, layer by layer, the ponderated sum of the clients' weights.

    For every layer index, each client's tensor for that layer is multiplied
    by that client's ponderation, the scaled tensors are stacked and then
    summed along the stacking axis. The ponderation values are applied as
    given (they are not normalised here).

    Args:
    -----
        aggregated_weights_as_list (list): One entry per client; each entry
        is indexable by layer. The number of layers is taken from the
        first client.
        ponderation (list): One scalar weight per client, in the same order
        as ``aggregated_weights_as_list``.

    Returns:
    --------
        list: One aggregated tensor per layer.

    Raises:
    -------
        IndexError: If ``aggregated_weights_as_list`` is an empty list.
        ValueError: If ``ponderation`` and ``aggregated_weights_as_list``
        do not have the same length.
    """
    n_layers = len(aggregated_weights_as_list[0])
    # zip() stops at the shorter input, so a length mismatch would silently
    # drop clients (or ponderations) and yield a wrong aggregate.
    if len(ponderation) != len(aggregated_weights_as_list):
        raise ValueError(
            "ponderation must contain exactly one weight per client: got "
            f"{len(ponderation)} weights for "
            f"{len(aggregated_weights_as_list)} clients"
        )
    agg_weights = []
    for layer_index in range(n_layers):
        # The scalar is created with the layer's own tensorly context so that
        # the product is computed between compatible tensors.
        scaled_layers = [
            client_weights[layer_index]
            * tl.tensor(p, **tl.context(client_weights[layer_index]))
            for client_weights, p in zip(aggregated_weights_as_list, ponderation)
        ]
        agg_weights.append(tl.sum(tl.stack(scaled_layers), axis=0))
    return agg_weights
@aggregate_weights
def fed_avg(aggregated_weights_as_list: list):
    """Aggregate the collected weights with the FedAvg method.

    Selects the tensorly backend that matches the weights and then averages
    them with the same ponderation for every entry.

    Args:
    -----
        aggregated_weights_as_list (list): List which contains
        all the weights to aggregate

    Returns:
    --------
        tensor array: An array with the aggregated weights

    Example of use assuming you are using a client-server architecture:

        from flex.pool.aggregators import fed_avg

        aggregator = flex_pool.aggregators
        server = flex_pool.servers
        aggregator.map(server, fed_avg)

    Example of use using the FlexPool without separating server
    and aggregator, and following a client-server architecture:

        from flex.pool.aggregators import fed_avg

        flex_pool.aggregators.map(flex_pool.servers, fed_avg)
    """
    # The backend must be chosen before any tensorly operation is executed.
    set_tensorly_backend(aggregated_weights_as_list)
    averaged_weights = fed_avg_f(aggregated_weights_as_list)
    return averaged_weights
@aggregate_weights
def weighted_fed_avg(aggregated_weights_as_list: list, ponderation: list):
    """Aggregate the collected weights with the weighted FedAvg method.

    Selects the tensorly backend that matches the weights and then combines
    them using the given per-client ponderation. The ponderation values are
    passed to ``weighted_fed_avg_f`` as given.

    Args:
    -----
        aggregated_weights_as_list (list): List which contains
        all the weights to aggregate
        ponderation (list): weights assigned to each client

    Returns:
    --------
        tensor array: An array with the aggregated weights

    Example of use assuming you are using a client-server architecture:

        from flex.pool.aggregators import weighted_fed_avg

        aggregator = flex_pool.aggregators
        server = flex_pool.servers
        dummy_ponderation = [1.]*len(flex_pool.clients)
        aggregator.map(server, weighted_fed_avg, ponderation=dummy_ponderation)

    Example of use using the FlexPool without separating server
    and aggregator, and following a client-server architecture:

        from flex.pool.aggregators import weighted_fed_avg

        dummy_ponderation = [1.]*len(flex_pool.clients)
        flex_pool.aggregators.map(flex_pool.servers, weighted_fed_avg, ponderation=dummy_ponderation)
    """
    # The backend must be chosen before any tensorly operation is executed.
    set_tensorly_backend(aggregated_weights_as_list)
    combined_weights = weighted_fed_avg_f(aggregated_weights_as_list, ponderation)
    return combined_weights