# Source code for flex.pool.decorators

"""
Copyright (C) 2024  Instituto Andaluz Interuniversitario en Ciencia de Datos e Inteligencia Computacional (DaSCI).

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published
    by the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
import functools
from typing import Hashable, List, Mapping

from flex.common.utils import check_min_arguments
from flex.model import FlexModel


def ERROR_MSG_MIN_ARG_GENERATOR(f, min_args):
    """Build the message reported when a decorated function takes too few arguments.

    Args:
        f: The function being decorated; only its ``__name__`` is used.
        min_args: The minimum number of arguments the decorator requires.

    Returns:
        A human-readable message naming ``f`` and ``min_args``.
    """
    template = (
        "The decorated function: {name} is expected to have at least "
        "{count} argument/s."
    )
    return template.format(name=f.__name__, count=min_args)


def init_server_model(func):
    """Decorate ``func`` so that its output initialises the server's FlexModel.

    ``func`` is not handed any model. It is called as ``func(*args, **kwargs)``,
    and whatever it returns is merged into the server model via ``update``.
    No minimum-argument check is applied, because ``func`` receives no model.

    Args:
        func: User function that builds the initial server model contents.

    Returns:
        A wrapper with signature ``(server_flex_model, _, *args, **kwargs)``.
        The second positional argument is accepted but ignored, and the
        wrapper returns ``None``.
    """

    @functools.wraps(func)
    def _init_server_model_(server_flex_model: FlexModel, _, *args, **kwargs):
        initial_contents = func(*args, **kwargs)
        server_flex_model.update(initial_contents)

    return _init_server_model_


def deploy_server_model(func):
    """Decorate ``func`` so that it deploys the server model to every client.

    ``func`` must accept at least one argument, as judged by
    ``check_min_arguments``. It is called as
    ``func(server_flex_model, *args, **kwargs)`` separately for each client,
    and the value it returns is merged into that client's FlexModel via
    ``update``.

    Args:
        func: User function producing the contents to deploy to a client.

    Returns:
        A wrapper with signature
        ``(server_flex_model, clients_flex_models, *args, **kwargs)`` that
        returns ``None``.

    Raises:
        AssertionError: At decoration time, if ``func`` accepts fewer than one
            argument. This check is an ``assert``, so it is skipped under
            ``python -O``.
    """
    min_args = 1
    assert check_min_arguments(func, min_args), ERROR_MSG_MIN_ARG_GENERATOR(
        func, min_args
    )

    @functools.wraps(func)
    def _deploy_model_(
        server_flex_model: FlexModel,
        clients_flex_models: Mapping[Hashable, FlexModel],
        *args,
        **kwargs,
    ):
        # ``clients_flex_models`` is iterated for its keys and indexed by them,
        # so it is a mapping of client key -> FlexModel, not a list.
        for k in clients_flex_models:
            # Reminder: assigning ``clients_flex_models[k] = ...`` is not
            # possible here, so each client's FlexModel is updated in place.
            # ``func`` is invoked once per client rather than once overall,
            # so clients do not share a single returned object.
            clients_flex_models[k].update(func(server_flex_model, *args, **kwargs))

    return _deploy_model_


def collect_clients_weights(func):
    """Decorate ``func`` so that it gathers every client's weights on the aggregator.

    ``func`` must accept at least one argument, as judged by
    ``check_min_arguments``. It is called as
    ``func(client_flex_model, *args, **kwargs)`` for each client. Each result
    is appended, in the mapping's iteration order, to the list stored under
    ``"weights"`` in the aggregator's FlexModel. That list is created if the
    key is absent.

    Args:
        func: User function extracting the weights from one client model.

    Returns:
        A wrapper with signature
        ``(aggregator_flex_model, clients_flex_models, *args, **kwargs)``
        that returns ``None``.

    Raises:
        AssertionError: At decoration time, if ``func`` accepts fewer than one
            argument. This check is an ``assert``, so it is skipped under
            ``python -O``.
    """
    min_args = 1
    assert check_min_arguments(func, min_args), ERROR_MSG_MIN_ARG_GENERATOR(
        func, min_args
    )

    @functools.wraps(func)
    def _collect_weights_(
        aggregator_flex_model: FlexModel,
        clients_flex_models: Mapping[Hashable, FlexModel],
        *args,
        **kwargs,
    ):
        # Start a fresh buffer only if no earlier call has created one, so
        # repeated collections accumulate into the same list.
        if "weights" not in aggregator_flex_model:
            aggregator_flex_model["weights"] = []
        # ``clients_flex_models`` is iterated for keys and indexed by them,
        # so it is a mapping of client key -> FlexModel, not a list.
        for k in clients_flex_models:
            client_weights = func(clients_flex_models[k], *args, **kwargs)
            aggregator_flex_model["weights"].append(client_weights)

    return _collect_weights_


def aggregate_weights(func):
    """Decorate ``func`` so that it aggregates the weights collected on the aggregator.

    ``func`` must accept at least one argument, as judged by
    ``check_min_arguments``. It is called as ``func(weights, *args, **kwargs)``,
    where ``weights`` is the value stored under ``"weights"`` in the
    aggregator's FlexModel. Its result is stored under
    ``"aggregated_weights"``, and ``"weights"`` is then reset to an empty list.

    Args:
        func: User aggregation function.

    Returns:
        A wrapper with signature ``(aggregator_flex_model, _, *args, **kwargs)``
        that returns ``None``. The second positional argument is ignored.

    Raises:
        AssertionError: At decoration time, if ``func`` accepts fewer than one
            argument. This check is an ``assert``, so it is skipped under
            ``python -O``.
    """
    required_args = 1
    assert check_min_arguments(func, required_args), ERROR_MSG_MIN_ARG_GENERATOR(
        func, required_args
    )

    @functools.wraps(func)
    def _aggregate_weights_(aggregator_flex_model: FlexModel, _, *args, **kwargs):
        # NOTE(review): assumes "weights" was populated by an earlier
        # collection step; a missing key is not handled here.
        collected = aggregator_flex_model["weights"]
        aggregator_flex_model["aggregated_weights"] = func(collected, *args, **kwargs)
        # Empty the buffer so the next round starts from scratch.
        aggregator_flex_model["weights"] = []

    return _aggregate_weights_


def set_aggregated_weights(func):
    """Decorate ``func`` so that it applies the aggregated weights to every server model.

    ``func`` must accept at least two arguments, as judged by
    ``check_min_arguments``. For each server it is called as
    ``func(server_flex_model, aggregated_weights, *args, **kwargs)``, where
    ``aggregated_weights`` is read from the aggregator's FlexModel under
    ``"aggregated_weights"``. The return value of ``func`` is discarded, so
    ``func`` is expected to modify the server model in place.

    Args:
        func: User function writing the aggregated weights into a model.

    Returns:
        A wrapper with signature
        ``(aggregator_flex_model, servers_flex_models, *args, **kwargs)``
        that returns ``None``.

    Raises:
        AssertionError: At decoration time, if ``func`` accepts fewer than two
            arguments. This check is an ``assert``, so it is skipped under
            ``python -O``.
    """
    min_args = 2
    assert check_min_arguments(func, min_args), ERROR_MSG_MIN_ARG_GENERATOR(
        func, min_args
    )

    @functools.wraps(func)
    def _deploy_aggregated_weights_(
        aggregator_flex_model: FlexModel,
        servers_flex_models: Mapping[Hashable, FlexModel],
        *args,
        **kwargs,
    ):
        # ``servers_flex_models`` is iterated for keys and indexed by them,
        # so it is a mapping of key -> FlexModel, not a single FlexModel.
        for k in servers_flex_models:
            func(
                servers_flex_models[k],
                aggregator_flex_model["aggregated_weights"],
                *args,
                **kwargs,
            )

    return _deploy_aggregated_weights_


def evaluate_server_model(func):
    """Decorate ``func`` so that it evaluates the server model and returns the result.

    ``func`` must accept at least one argument, as judged by
    ``check_min_arguments``. It is called as
    ``func(server_flex_model, *args, **kwargs)``, and its return value is
    passed straight back to the caller.

    Args:
        func: User evaluation function.

    Returns:
        A wrapper with signature ``(server_flex_model, _, *args, **kwargs)``
        that returns whatever ``func`` returns. The second positional argument
        is ignored.

    Raises:
        AssertionError: At decoration time, if ``func`` accepts fewer than one
            argument. This check is an ``assert``, so it is skipped under
            ``python -O``.
    """
    required_args = 1
    assert check_min_arguments(func, required_args), ERROR_MSG_MIN_ARG_GENERATOR(
        func, required_args
    )

    @functools.wraps(func)
    def _evaluate_server_model_(server_flex_model: FlexModel, _, *args, **kwargs):
        evaluation = func(server_flex_model, *args, **kwargs)
        return evaluation

    return _evaluate_server_model_