"""
Copyright (C) 2024 Instituto Andaluz Interuniversitario en Ciencia de Datos e Inteligencia Computacional (DaSCI).
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
"""File that contains the primitive functions to build an easy training loop of the federated learning model.
In this file we specify some functions for each framework, i.e., TensorFlow (tf), PyTorch (pt), among others, but
we only give functions for a general purpose. For a more personalized use of FLEXible, users must create
their own functions. These functions can be used as templates for writing a custom function for each step
of the training process in a federated learning environment.
Note that each function is using the decorators we've created to facilitate the use of the library. For a better
understanding on how the platform works, please go to the flex_decorators file.
"""
from copy import deepcopy # noqa: E402
from flex.pool.decorators import ( # noqa: E402
collect_clients_weights,
deploy_server_model,
set_aggregated_weights,
)
[docs]
@deploy_server_model
def deploy_server_model_pt(server_flex_model, *args, **kwargs):
    """Return an independent deep copy of the server's FlexModel.

    The returned copy is handed to the ``@deploy_server_model`` decorator, which is
    expected to place it on the client nodes (see ``flex.pool.decorators``).

    Args:
    -----
        server_flex_model (FlexModel): object storing information needed to run a PyTorch model

    Returns:
    --------
        A deep copy of ``server_flex_model``.
    """
    # Deep copy so the receiver does not share state with the server's object.
    model_copy = deepcopy(server_flex_model)
    return model_copy
[docs]
def check_ignored_weights_pt(name, ignore_weights=None):
    """Check whether ``name`` contains any of the substrings in ``ignore_weights``.

    Args:
    -----
        name (str): name to check
        ignore_weights (list, optional): A list of str. When None, the list
            ``["num_batches_tracked"]`` is used. Defaults to None.

    Returns:
    --------
        bool: True if any element of ``ignore_weights`` occurs in ``name``, otherwise False.
    """
    # Only None selects the default; an explicitly empty list ignores nothing.
    patterns = ["num_batches_tracked"] if ignore_weights is None else ignore_weights
    for pattern in patterns:
        if pattern in name:
            return True
    return False
[docs]
@collect_clients_weights
def collect_client_diff_weights_pt(client_flex_model, *args, **kwargs):
    # sourcery skip: raise-specific-error
    """Collect what a client's PyTorch model learnt during local training.

    For every entry of the state dict of ``client_flex_model["model"]`` the
    difference with the same entry of ``client_flex_model["previous_model"]``
    is computed, i.e. the model after training minus the model before training.
    The model before training must therefore be stored under the key
    ``previous_model``.

    Args:
    -----
        client_flex_model (FlexModel): A client's FlexModel
        ignore_weights (list): keyword argument with the name of the weights not
            to collect; by default, those containing the words
            `num_batches_tracked` are not collected, as they only make sense in
            the local model. An empty tensor is returned in their place.

    Returns:
    --------
        List: one entry per state dict entry, in state dict order, holding the
        weight difference or an empty tensor for ignored weights.

    Raises:
    -------
        Exception: if looking up the previous model's state dict raises KeyError.

    Example of use assuming you are using a client-server architecture:

        from flex.pool import collect_client_diff_weights_pt

        clients = flex_pool.clients
        aggregator = flex_pool.aggregators

        clients.map(collect_client_diff_weights_pt, aggregator)

    Example of using the FlexPool without separating clients
    and aggregator, and following a client-server architecture.

        from flex.pool import collect_client_diff_weights_pt

        flex_pool.clients.map(collect_client_diff_weights_pt, flex_pool.aggregators)
    """
    import torch

    skipped_names = kwargs.get("ignore_weights", None)

    with torch.no_grad():
        current_state = client_flex_model["model"].state_dict()
        try:
            previous_state = client_flex_model["previous_model"].state_dict()
        except KeyError as e:
            raise Exception(
                'A copy of the model before training must be stored in client FlexModel using key: "previous_model"'
            ) from e

        deltas = []
        for layer_name, current in current_state.items():
            if check_ignored_weights_pt(layer_name, ignore_weights=skipped_names):
                # Empty placeholder keeps the list aligned with the state dict order.
                delta = torch.tensor([])
            else:
                delta = current - previous_state[layer_name]
            deltas.append(delta)
    return deltas
[docs]
@collect_clients_weights
def collect_clients_weights_pt(client_flex_model, *args, **kwargs):
    """Collect all the weights of a client's PyTorch model.

    The state dict of ``client_flex_model["model"]`` is walked in order and one
    entry is returned per state dict entry. Ignored weights are not dropped:
    an empty tensor is returned in their place, so the result stays positionally
    aligned with the state dict. ``set_aggregated_weights_pt`` pairs weights with
    state dict entries by position and skips empty entries, and
    ``collect_client_diff_weights_pt`` uses the same placeholder convention.

    Args:
    -----
        client_flex_model (FlexModel): A client's FlexModel
        ignore_weights (list): keyword argument with the name of the weights not
            to collect; by default, those containing the words
            `num_batches_tracked` are not collected, as they only make sense in
            the local model.

    Returns:
    --------
        List: one entry per state dict entry, in state dict order, holding the
        weight or an empty tensor for ignored weights.

    Example of use assuming you are using a client-server architecture:

        from flex.pool import collect_clients_weights_pt

        clients = flex_pool.clients
        aggregator = flex_pool.aggregators

        clients.map(collect_clients_weights_pt, aggregator)

    Example of using the FlexPool without separating clients
    and aggregator, and following a client-server architecture.

        from flex.pool import collect_clients_weights_pt

        flex_pool.clients.map(collect_clients_weights_pt, flex_pool.aggregators)
    """
    import torch

    ignore_weights = kwargs.get("ignore_weights", None)

    with torch.no_grad():
        weight_dict = client_flex_model["model"].state_dict()
        # Ignored entries become an empty tensor instead of being skipped;
        # skipping them would shift every later weight onto the wrong layer
        # when the list is zipped back against state_dict().
        parameters = [
            torch.tensor([])
            if check_ignored_weights_pt(name, ignore_weights=ignore_weights)
            else weight
            for name, weight in weight_dict.items()
        ]
    return parameters
[docs]
@set_aggregated_weights
def set_aggregated_weights_pt(server_flex_model, aggregated_weights, *args, **kwargs):
    """Replace the weights of the server model with the aggregated weights.

    The entries of ``aggregated_weights`` are paired by position with the entries
    of ``server_flex_model["model"].state_dict()`` and copied into them in place.
    Pairing stops at the shorter of the two sequences. Entries of length zero are
    skipped; entries that have no length are copied as they are.

    Args:
    -----
        server_flex_model (FlexModel): The server's FlexModel
        aggregated_weights: Aggregated weights, one entry per state dict entry,
            in state dict order.

    Example of use assuming you are using a client-server architecture:

        from flex.pool import set_aggregated_weights_pt

        aggregator = flex_pool.aggregators

        aggregator.map(set_aggregated_weights_pt)

    Example of using the FlexPool without separating clients
    and aggregator, and following a client-server architecture.

        from flex.pool import set_aggregated_weights_pt

        flex_pool.aggregators.map(set_aggregated_weights_pt)
    """
    import torch

    with torch.no_grad():
        weight_dict = server_flex_model["model"].state_dict()
        for layer_key, new in zip(weight_dict, aggregated_weights):
            # Guard only the emptiness probe: a TypeError raised by copy_ itself
            # must propagate instead of triggering a second copy_ attempt.
            try:
                is_empty = len(new) == 0
            except TypeError:  # new has no len property
                is_empty = False
            if not is_empty:  # Do not copy empty layers
                weight_dict[layer_key].copy_(new)
[docs]
@set_aggregated_weights
def set_aggregated_diff_weights_pt(
    server_flex_model, aggregated_diff_weights, *args, **kwargs
):
    """Add the aggregated weight differences to the weights of the server model.

    The entries of ``aggregated_diff_weights`` are paired by position with the
    entries of ``server_flex_model["model"].state_dict()`` and added to them in
    place. Pairing stops at the shorter of the two sequences. Entries of length
    zero are skipped; entries that have no length are added as they are.

    Args:
    -----
        server_flex_model (FlexModel): The server's FlexModel
        aggregated_diff_weights: Aggregated weight differences, one entry per
            state dict entry, in state dict order.

    Example of use assuming you are using a client-server architecture:

        from flex.pool import set_aggregated_diff_weights_pt

        aggregator = flex_pool.aggregators

        aggregator.map(set_aggregated_diff_weights_pt)

    Example of using the FlexPool without separating clients
    and aggregator, and following a client-server architecture.

        from flex.pool import set_aggregated_diff_weights_pt

        flex_pool.aggregators.map(set_aggregated_diff_weights_pt)
    """
    import torch

    with torch.no_grad():
        weight_dict = server_flex_model["model"].state_dict()
        for layer_key, new in zip(weight_dict, aggregated_diff_weights):
            # Guard only the emptiness probe: a TypeError raised by add_ itself
            # must propagate instead of triggering a second add_ attempt.
            try:
                is_empty = len(new) == 0
            except TypeError:  # new has no len property
                is_empty = False
            if not is_empty:  # Do not add empty layers
                weight_dict[layer_key].add_(new)