Source code for flex.data.fed_dataset

"""
Copyright (C) 2024  Instituto Andaluz Interuniversitario en Ciencia de Datos e Inteligencia Computacional (DaSCI).

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published
    by the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
from collections import UserDict
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Hashable, List, Optional

from multiprocess import Pool

from flex.common.utils import check_min_arguments
from flex.data.dataset import Dataset
from flex.data.preprocessing_utils import normalize, one_hot_encoding


class FedDataset(UserDict):
    """Class that represents a federated dataset for the Flex library.

    The dataset contains the ids of the nodes and the dataset associated
    with each node.

    Attributes
    ----------
    data (collections.UserDict): The structure is a dictionary with the
        node ids as keys and the dataset as value.
    """

    def __setitem__(self, key: Hashable, item: Dataset) -> None:
        """Store ``item`` as the dataset associated with node ``key``."""
        self.data[key] = item

    def get(self, key: Hashable, default: Optional[Any] = None) -> Any:
        """Return the dataset stored for ``key``, or ``default`` if absent.

        Args:
        -----
            key (Hashable): Node id to look up.
            default (Any, optional): Value returned when ``key`` is not
                present. Defaults to None.

        Returns:
        --------
            Any: The stored dataset, or ``default`` on a ``KeyError``.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def apply(
        self,
        func: Callable,
        node_ids: Optional[List[Hashable]] = None,
        num_proc: int = 1,
        **kwargs,
    ) -> "FedDataset":
        r"""Apply a custom function to the datasets of the selected nodes.

        The \**kwargs provided to this function are forwarded as keyword
        arguments to ``func``, which is called as
        ``func(dataset, **kwargs)`` for every selected node.

        Args:
        -----
            func (Callable): Function to apply to preprocess the data. It
                must accept at least one argument (the node dataset).
            node_ids (List[Hashable], optional): Node ids where func will
                be applied. A single string is treated as one node id; any
                other iterable is treated as a collection of node ids.
                Every id must be part of the FedDataset. Defaults to None,
                which selects every node.
            num_proc (int, optional): Number of processes to parallelize.
                Values below 2 (including negative ones) run sequentially.
                Defaults to 1.

        Returns:
        --------
            FedDataset: A deep copy of this FedDataset in which the
            selected nodes hold the values returned by ``func``.

        Raises:
        -------
            ValueError: If any given node id is not in the FedDataset.
            AssertionError: If ``func`` does not accept at least one
                argument.
        """
        if node_ids is None:
            node_ids = list(self.keys())
        elif isinstance(node_ids, str):
            if node_ids not in self:
                raise ValueError("All node ids given must be in the FedDataset.")
            # Wrap the single id so the parallel path does not iterate over
            # the characters of the string.
            node_ids = [node_ids]
        else:
            # Materialise once so validation and mapping see the same ids
            # and so tuples/sets are not mistaken for a single key later.
            node_ids = list(node_ids)
            if any(node not in self for node in node_ids):
                raise ValueError("All nodes ids given must be in the FedDataset.")

        # Raise explicitly instead of using ``assert`` so the check survives
        # ``python -O``; the exception type is kept for existing callers.
        if not check_min_arguments(func, min_args=1):
            # Not every callable has ``__name__`` (e.g. functools.partial).
            func_name = getattr(func, "__name__", repr(func))
            raise AssertionError(
                f"The provided function: {func_name} is expected to have at least 1 argument/s."
            )

        # NOTE: a guard that rejected generator-backed data with a
        # NotImplementedError used to live here and is currently disabled.

        if num_proc < 2:
            updates = self._map_single(func, node_ids, **kwargs)
        else:
            f = partial(self._map_single, func)
            updates = self._map_parallel(f, node_ids, num_proc=num_proc, **kwargs)

        # TODO: probably only the nodes that are not in node_ids need to be
        # copied, since the others are overwritten right below.
        new_fld = deepcopy(self)
        new_fld.update(updates)
        return new_fld

    def _map_parallel(
        self,
        func: Callable,
        node_ids: List[Hashable],
        num_proc: int = 2,
        **kwargs,
    ) -> dict:
        """Call ``func`` once per node id using a pool of processes.

        The **kwargs provided to this function are bound to every call of
        ``func``.

        Args:
        -----
            func (Callable): Callable invoked as ``func(node_id, **kwargs)``
                that must return a mapping of node ids to results.
            node_ids (List[Hashable]): Node ids where func will be applied.
                Each element must be hashable and part of the FedDataset.
            num_proc (int): Number of worker processes in the pool.
                Defaults to 2.

        Returns:
        --------
            dict: The merged mappings returned by every call of ``func``.
        """
        updates = {}
        f = partial(func, **kwargs)  # bind **kwargs arguments to each call
        with Pool(processes=num_proc) as p:
            for i in p.imap(f, node_ids):
                updates.update(i)
        return updates

    def _map_single(
        self,
        func: Callable,
        node_ids: List[Hashable],
        **kwargs,
    ) -> dict:
        """Call ``func`` sequentially on the dataset of each given node.

        The **kwargs provided to this function are forwarded to every call
        of ``func``.

        Args:
        -----
            func (Callable): Function to apply to preprocess the data,
                invoked as ``func(self[node_id], **kwargs)``.
            node_ids (List[Hashable]): Node ids where func will be applied.
                Anything that is not a list is treated as a single node id.

        Returns:
        --------
            dict: Mapping from each node id to the value returned by
            ``func`` for that node.
        """
        if not isinstance(node_ids, list):
            node_ids = [node_ids]
        return {node_id: func(self[node_id], **kwargs) for node_id in node_ids}

    def normalize(
        self,
        node_ids: Optional[List[Hashable]] = None,
        num_proc: int = 0,
        *args,
        **kwargs,
    ):
        """Function that normalizes the data over the nodes.

        Delegates to ``apply`` with the imported ``normalize`` function.
        Extra options must be passed as keyword arguments: ``apply`` accepts
        no additional positional arguments, so a non-empty ``*args`` makes
        the call fail with a ``TypeError``.

        Args:
        -----
            node_ids (List[Hashable], optional): Node ids whose data will be
                normalized. Each element must be hashable. Defaults to None,
                which selects every node.
            num_proc (int, optional): Number of processes to parallelize.
                Values below 2 run sequentially. Defaults to 0.

        Returns:
        --------
            FedDataset: The FedDataset normalized.
        """
        return self.apply(normalize, node_ids, num_proc, *args, **kwargs)

    def one_hot_encoding(
        self,
        node_ids: Optional[List[Hashable]] = None,
        num_proc: int = 0,
        *args,
        **kwargs,
    ):
        """Function that applies one hot encoding to the node labels.

        Delegates to ``apply`` with the imported ``one_hot_encoding``
        function. Extra options must be passed as keyword arguments:
        ``apply`` accepts no additional positional arguments, so a non-empty
        ``*args`` makes the call fail with a ``TypeError``.

        Args:
        -----
            node_ids (List[Hashable], optional): Node ids whose labels will
                be one hot encoded. Each element must be hashable. Defaults
                to None, which selects every node.
            num_proc (int, optional): Number of processes to parallelize.
                Values below 2 run sequentially. Defaults to 0.

        Returns:
        --------
            FedDataset: The FedDataset with one hot encoding applied.
        """
        return self.apply(one_hot_encoding, node_ids, num_proc, *args, **kwargs)