Source code for flex.data.fed_dataset_config

"""
Copyright (C) 2024  Instituto Andaluz Interuniversitario en Ciencia de Datos e Inteligencia Computacional (DaSCI).

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published
    by the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.

    You should have received a copy of the GNU Affero General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
from dataclasses import asdict, dataclass
from typing import Hashable, List, Optional, Tuple, Union

import numpy as np
import numpy.typing as npt


class InvalidConfig(ValueError):
    """Raised when the configuration used to federate a dataset is invalid.

    It derives from ``ValueError``, so callers that already catch
    ``ValueError`` also catch this exception.
    """
@dataclass
class FedDatasetConfig:
    """Class used to represent a configuration to federate a centralized dataset.

    The following table shows the compatibility of each option
    (``Y``: the two options can be combined, ``N``: they can not):

    +-------------------------+-------------+--------------+-------------+-----------------------+-----------------+---------------------+-----------------------+----------------------+--------------------------+-----------------+-------------+
    | Options compatibility   | **n_nodes** | **node_ids** | **weights** | **weights_per_label** | **replacement** | **labels_per_node** | **features_per_node** | **indexes_per_node** | **group_by_label_index** | **keep_labels** | **shuffle** |
    +=========================+=============+==============+=============+=======================+=================+=====================+=======================+======================+==========================+=================+=============+
    | **n_nodes**             | -           | Y            | Y           | Y                     | Y               | Y                   | Y                     | N                    | N                        | Y               | Y           |
    +-------------------------+-------------+--------------+-------------+-----------------------+-----------------+---------------------+-----------------------+----------------------+--------------------------+-----------------+-------------+
    | **node_ids**            | -           | -            | Y           | Y                     | Y               | Y                   | Y                     | Y                    | N                        | Y               | Y           |
    +-------------------------+-------------+--------------+-------------+-----------------------+-----------------+---------------------+-----------------------+----------------------+--------------------------+-----------------+-------------+
    | **weights**             | -           | -            | -           | N                     | Y               | Y                   | Y                     | N                    | N                        | Y               | Y           |
    +-------------------------+-------------+--------------+-------------+-----------------------+-----------------+---------------------+-----------------------+----------------------+--------------------------+-----------------+-------------+
    | **weights_per_label**   | -           | -            | -           | -                     | Y               | N                   | N                     | N                    | N                        | Y               | Y           |
    +-------------------------+-------------+--------------+-------------+-----------------------+-----------------+---------------------+-----------------------+----------------------+--------------------------+-----------------+-------------+
    | **replacement**         | -           | -            | -           | -                     | -               | Y                   | N                     | N                    | N                        | Y               | Y           |
    +-------------------------+-------------+--------------+-------------+-----------------------+-----------------+---------------------+-----------------------+----------------------+--------------------------+-----------------+-------------+
    | **labels_per_node**     | -           | -            | -           | -                     | -               | -                   | N                     | N                    | N                        | Y               | Y           |
    +-------------------------+-------------+--------------+-------------+-----------------------+-----------------+---------------------+-----------------------+----------------------+--------------------------+-----------------+-------------+
    | **features_per_node**   | -           | -            | -           | -                     | -               | -                   | -                     | N                    | N                        | Y               | Y           |
    +-------------------------+-------------+--------------+-------------+-----------------------+-----------------+---------------------+-----------------------+----------------------+--------------------------+-----------------+-------------+
    | **indexes_per_node**    | -           | -            | -           | -                     | -               | -                   | -                     | -                    | N                        | Y               | Y           |
    +-------------------------+-------------+--------------+-------------+-----------------------+-----------------+---------------------+-----------------------+----------------------+--------------------------+-----------------+-------------+
    | **group_by_label_index**| -           | -            | -           | -                     | -               | -                   | -                     | -                    | -                        | N               | Y           |
    +-------------------------+-------------+--------------+-------------+-----------------------+-----------------+---------------------+-----------------------+----------------------+--------------------------+-----------------+-------------+
    | **keep_labels**         | -           | -            | -           | -                     | -               | -                   | -                     | -                    | -                        | -               | Y           |
    +-------------------------+-------------+--------------+-------------+-----------------------+-----------------+---------------------+-----------------------+----------------------+--------------------------+-----------------+-------------+
    | **shuffle**             | -           | -            | -           | -                     | -               | -                   | -                     | -                    | -                        | -               | -           |
    +-------------------------+-------------+--------------+-------------+-----------------------+-----------------+---------------------+-----------------------+----------------------+--------------------------+-----------------+-------------+

    Attributes
    ----------
    seed: Optional[int]
        Seed used to make the federated dataset generated with this
        configuration reproducible. Default None.
    n_nodes: int
        Number of nodes among which to split a centralized dataset. Default 2.
    shuffle: bool
        If True data is shuffled before being sampled. Default False.
    node_ids: Optional[List[Hashable]]
        Ids to identify each node, if not provided, nodes will be indexed
        using integers. If n_nodes is also given, we consider up to n_nodes
        elements. Default None.
    weights: Optional[npt.NDArray]
        A numpy.array which provides the proportion of data to give to each
        node. Default None.
    weights_per_label: Optional[npt.NDArray]
        A numpy.array which provides the proportion of data to give to each
        node and class of the dataset to federate. We expect a bidimensional
        array of shape (n, m) where "n" is the number of nodes and "m" is the
        number of labels of the dataset to federate. Default None.
    replacement: bool
        Whether the sampling procedure used to split a centralized dataset is
        with replacement or not. Default False.
    labels_per_node: Optional[Union[int, npt.NDArray, Tuple[int, int]]]
        Labels to assign to each node. If provided as an int, it is the
        number of labels per node; if provided as a tuple of ints, it
        establishes a minimum and a maximum number of labels per node, and a
        random number sampled in such interval decides the number of labels
        of each node. If provided as a list of lists, it establishes the
        labels assigned to each node. Default None.
    features_per_node: Optional[Union[int, npt.NDArray, Tuple[int, int]]]
        Features to assign to each node, it shares the same interface as
        labels_per_node. Default None.
    indexes_per_node: Optional[npt.NDArray]
        Data indexes to assign to each node. Default None.
    group_by_label_index: Optional[int]
        Index which indicates which feature unique values will be used to
        generate federated nodes. Default None.
    keep_labels: Optional[List[bool]]
        Whether each node keeps or not the labels or y_data. Default None.
    """

    seed: Optional[int] = None
    n_nodes: int = 2
    shuffle: bool = False
    node_ids: Optional[List[Hashable]] = None
    weights: Optional[npt.NDArray] = None
    weights_per_label: Optional[npt.NDArray] = None
    replacement: bool = False
    labels_per_node: Optional[Union[int, npt.NDArray, Tuple[int, int]]] = None
    features_per_node: Optional[Union[int, npt.NDArray, Tuple[int, int]]] = None
    indexes_per_node: Optional[npt.NDArray] = None
    group_by_label_index: Optional[int] = None
    keep_labels: Optional[List[bool]] = None

    def _check_incomp(self, dict, option1, option2):
        """Raise ``InvalidConfig`` if ``option1`` and ``option2`` are both set.

        Parameters
        ----------
        dict: mapping
            Maps option names to their current values. The parameter name
            shadows the builtin and is kept only for backward compatibility.
        option1, option2: str
            Names of the two options to check.

        Raises
        ------
        InvalidConfig
            If both options are set at the same time.
        """
        # An option counts as "set" when it is not None, except for the
        # boolean ``replacement`` flag, which counts as set when it is truthy.
        cond1 = dict[option1] is not None if option1 != "replacement" else dict[option1]
        cond2 = dict[option2] is not None if option2 != "replacement" else dict[option2]
        if cond1 and cond2:
            raise InvalidConfig(
                f"Options {option1} and {option2} are incompatible, please provide only one."
            )

    def validate(self):
        """Check whether the configuration to federate a dataset is correct.

        Raises
        ------
        InvalidConfig
            If two incompatible options are provided together or if any
            option holds an invalid value.
        """
        from dataclasses import fields

        # Shallow snapshot of the fields. Only None/truthiness is inspected,
        # so there is no need to deep-copy potentially large arrays.
        self_dict = {f.name: getattr(self, f.name) for f in fields(self)}

        # By default every option is compatible, therefore we only specify
        # incompatibilities. The order decides which error is reported first.
        # NOTE(review): the class docstring table also lists
        # node_ids/group_by_label_index, indexes_per_node/group_by_label_index
        # and group_by_label_index/keep_labels as incompatible, but they are
        # not enforced here. It also lists replacement/features_per_node as
        # incompatible, while __validate_features_per_node requires
        # replacement to be True. Confirm the intended contract.
        incompatible_pairs = (
            ("weights", "group_by_label_index"),
            ("weights", "weights_per_label"),
            ("weights", "indexes_per_node"),
            ("weights_per_label", "indexes_per_node"),
            ("weights_per_label", "labels_per_node"),
            ("weights_per_label", "features_per_node"),
            ("weights_per_label", "group_by_label_index"),
            ("replacement", "indexes_per_node"),
            ("replacement", "group_by_label_index"),
            ("labels_per_node", "features_per_node"),
            ("labels_per_node", "indexes_per_node"),
            ("labels_per_node", "group_by_label_index"),
            ("features_per_node", "indexes_per_node"),
            ("features_per_node", "group_by_label_index"),
        )
        for option1, option2 in incompatible_pairs:
            self._check_incomp(self_dict, option1, option2)

        self.__validate_nodes_and_weights()

        # These three options are mutually exclusive (checked above), so at
        # most one of the branches applies.
        if self.indexes_per_node is not None:
            self.__validate_indexes_per_node()
        elif self.labels_per_node is not None:
            self.__validate_labels_per_node()
        elif self.features_per_node is not None:
            self.__validate_features_per_node()

        if self.keep_labels is not None:
            self.__validate_keep_labels()

    def __validate_indexes_per_node(self):
        """Check that ``indexes_per_node`` has one entry per node."""
        if len(self.indexes_per_node) != self.n_nodes:
            raise InvalidConfig(
                "The number of provided nodes should equal the length of indexes per node."
            )

    def __validate_keep_labels(self):
        """Check that ``keep_labels`` has one entry per node."""
        if len(self.keep_labels) != self.n_nodes:
            raise InvalidConfig(
                "keep_labels list should have the same length as n_nodes."
            )

    def __validate_nodes_and_weights(self):
        """Check ``n_nodes``, ``node_ids``, ``weights`` and ``weights_per_label``."""
        if self.n_nodes < 2:
            raise InvalidConfig(
                "The number of nodes must be greater or equal to 2. Default is 2"
            )
        # Extra ids are tolerated; what is rejected is having fewer ids than
        # nodes, so the message states exactly that.
        if self.node_ids is not None and self.n_nodes > len(self.node_ids):
            raise InvalidConfig(
                "The number of nodes, n_nodes, can not be greater than the number of named nodes, node_ids"
            )
        if self.weights is not None and self.n_nodes != len(self.weights):
            raise InvalidConfig("The number of weights must equal the number of nodes.")
        if (
            self.weights_per_label is not None
            and np.asarray(self.weights_per_label).ndim != 2
        ):
            raise InvalidConfig(
                "weights_per_label must be a two dimensional array where the first dimension is the number of nodes and the second is the number of labels of the dataset to federate."
            )
        if self.weights_per_label is not None and self.n_nodes != len(
            self.weights_per_label
        ):
            raise InvalidConfig(
                "The length of weights_per_label must equal the number of nodes."
            )
        if self.weights is not None and max(self.weights) > 1:
            raise InvalidConfig(
                "Provided weights contains an element greater than 1, we do not allow sampling more than one time the entire dataset per node."
            )
        if self.weights is not None and min(self.weights) < 0:
            raise InvalidConfig(
                "Provided weights contains negative numbers, we do not allow that."
            )

    def __validate_labels_per_node(self):
        """Check the shape of ``labels_per_node`` for its three accepted forms."""
        if isinstance(self.labels_per_node, tuple):
            if len(self.labels_per_node) != 2:
                raise InvalidConfig(
                    f"labels_per_node if provided as a tuple, it must have two elements, mininum number of labels per node and maximum number of labels per node, but labels_per_node={self.labels_per_node}."
                )
        # NumPy integer scalars are treated like plain ints; calling len() on
        # them would raise TypeError instead of a meaningful error.
        elif not isinstance(
            self.labels_per_node, (int, np.integer)
        ) and self.n_nodes != len(self.labels_per_node):
            raise InvalidConfig(
                "labels_per_node if provided as a list o np.ndarray, its length and n_nodes must equal."
            )

    def __validate_features_per_node(self):
        """Check ``features_per_node``, which also requires ``replacement``."""
        if not self.replacement:
            raise InvalidConfig(
                "By setting replacement to False and specifying features_per_node, nodes will not share any data instances."
            )
        if isinstance(self.features_per_node, tuple):
            if len(self.features_per_node) != 2:
                raise InvalidConfig(
                    f"features_per_node if provided as a tuple, it must have two elements, mininum number of features per node and maximum number of features per node, but features_per_node={self.features_per_node}."
                )
        # NumPy integer scalars are treated like plain ints (see
        # __validate_labels_per_node).
        elif not isinstance(
            self.features_per_node, (int, np.integer)
        ) and self.n_nodes != len(self.features_per_node):
            raise InvalidConfig(
                "features_per_node if provided as a list o np.ndarray, its length and n_nodes must equal."
            )