Skip to content

Commit

Permalink
ICDTA4FL added.
Browse files Browse the repository at this point in the history
  • Loading branch information
AlArgente committed Oct 25, 2024
1 parent 91ef286 commit ed3fd06
Show file tree
Hide file tree
Showing 18 changed files with 4,422 additions and 2 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The methods implemented in the repository are:
| Federated ID3 | The ID3 model adapted to a federated learning scenario. | [A Hybrid Approach to Privacy-Preserving Federated Learning](https://arxiv.org/pdf/1812.03224.pdf) |
| Federated Random Forest | The Random Forest (RF) model adapted to a federated learning scenario. Each client builds a RF locally, then `N` trees are randomly sampled from each client to get a global RF composed from the `N` trees retrieved from the clients. | [Federated Random Forests can improve local performance of predictive models for various healthcare applications](https://pubmed.ncbi.nlm.nih.gov/35139148/) |
| Federated Gradient Boosting Decision Trees | The Gradient Boosting Decision Trees model adapted to a federated learning scenario. In this model a global hash table is first created to aling the data between the clients within sharing it. After that, `N` trees (CART) are built by the clients. The process of building the ensemble is iterative, and one client builds the tree, then it is added to the ensemble, and after that the weights of the instances is updated, so the next client can build the next tree with the weights updated.| [Practical Federated Gradient Boosting Decision Trees](https://arxiv.org/abs/1911.04206) |
| Interpretable Client Decision Tree Aggregation For Federated Learning process (ICDTA4FL process) | The ICDTA4FL process is a process that allows the clients to build a decision tree locally, and then the trees are aggregated in a global tree by merging the rules extracted from the local trees. The process is iterative, and the clients can build a tree, then the trees that surpass a threshold are selected to be merged. In order the merge the trees, these are transformed into rules, and then the merged rules are used to build a global tree. This process is tree independent, and the code is available for merging ID3, CART and C4.5 trees. | [Interpretable Client Decision Tree Aggregation For Federated Learning process](https://arxiv.org/pdf/2404.02510) |

The tabular datasets available in the repository are:
| `Dataset` | `Description` | `Citation` |
Expand Down Expand Up @@ -58,4 +59,9 @@ pip install -e .

If you use this package, please cite the following paper:

``` TODO: Add citation ```
``` @article{herrera2024flex,
title={FLEX: FLEXible Federated Learning Framework},
author={Herrera, Francisco and Jim{\'e}nez-L{\'o}pez, Daniel and Argente-Garrido, Alberto and Rodr{\'\i}guez-Barroso, Nuria and Zuheros, Cristina and Aguilera-Martos, Ignacio and Bello, Beatriz and Garc{\'\i}a-M{\'a}rquez, Mario and Luz{\'o}n, M},
journal={arXiv preprint arXiv:2404.06127},
year={2024}
} ```
1 change: 1 addition & 0 deletions flextrees/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from flextrees.datasets.tabular_datasets import credit2
from flextrees.datasets.tabular_datasets import nursery
from flextrees.datasets.tabular_datasets import adult
from flextrees.datasets.tabular_datasets import adult_raw
from flextrees.datasets.tabular_datasets import bank
from flextrees.datasets.preprocessing_utils import preprocess_adult
from flextrees.datasets.preprocessing_utils import preprocess_credit2
37 changes: 37 additions & 0 deletions flextrees/datasets/tabular_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,43 @@ def nursery(out_dir: str = '.', ret_feature_names: bool = False, categorical=Tru
return train_data_object, test_data_object, col_names
return train_data_object, test_data_object

def adult_raw(out_dir: str = '.', categorical=False):
import os
import pandas as pd
if not os.path.exists(f"{out_dir}/adult_train.csv"):
path_to_train = 'http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data'
path_to_test = 'http://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test'
x_columns = ['x' + str(i) for i in range(14)]
y_column = 'label'
train_data = pd.read_csv(path_to_train, names=x_columns + [y_column])
test_data = pd.read_csv(path_to_test, names=x_columns + [y_column])
test_data = test_data.iloc[1:]
test_data['x0'] = [int(val) for val in test_data['x0']]

train_labels = train_data.apply(lambda row: 1 if '>50K' in row['label'] else 0, axis=1).to_numpy()
test_labels = test_data.apply(lambda row: 1 if '>50K' in row['label'] else 0, axis=1).to_numpy()
train_data['label'] = train_labels
test_data['label'] = test_labels
train_data.to_csv(f"{out_dir}/adult_train.csv", index=False)
test_data.to_csv(f"{out_dir}/adult_test.csv", index=False)
else:
train_data = pd.read_csv(f"{out_dir}/adult_train.csv")
test_data = pd.read_csv(f"{out_dir}/adult_test.csv")
train_labels = train_data['label']
test_labels = test_data['label']
train_data = train_data.drop(columns=['label'], axis=1)
test_data = test_data.drop(columns=['label'], axis=1)
y_data = train_labels.to_numpy()
X_data = train_data.to_numpy()
from sklearn.model_selection import train_test_split
X_data, X_test, y_data, y_test = train_test_split(X_data, y_data, test_size=0.3)
# y_test = test_labels.to_numpy()
# X_test = test_data.to_numpy()

train_data_object = Dataset.from_array(X_data, y_data)
test_data_object = Dataset.from_array(X_test, y_test)
return train_data_object, test_data_object

def adult(out_dir: str = '.', ret_feature_names: bool = False, categorical=True):
"""Function that load the adult dataset from the UCI database
Expand Down
23 changes: 23 additions & 0 deletions flextrees/pool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,29 @@
from flextrees.pool.primitives_fedgbdt import evaluate_global_model_clients_gbdt
from flextrees.pool.primitives_fedgbdt import train_n_estimators

# Primitives and aggregation functions for DTFL
from flextrees.pool.primitives_dtfl import init_server_model_dtfl
from flextrees.pool.primitives_dtfl import deploy_server_model_dtfl
from flextrees.pool.primitives_dtfl import train_dtfl
from flextrees.pool.primitives_dtfl import collect_clients_weights_dtfl
from flextrees.pool.primitives_dtfl import set_aggregated_weights_dtfl
from flextrees.pool.primitives_dtfl import set_local_trees_to_server
from flextrees.pool.primitives_dtfl import evaluate_server_model_dtfl
from flextrees.pool.primitives_dtfl import get_classes_branches
from flextrees.pool.primitives_dtfl import send_all_trees_to_client
from flextrees.pool.primitives_dtfl import deploy_global_model_dtfl
from flextrees.pool.primitives_dtfl import collect_clients_trees_dtfl
from flextrees.pool.primitives_dtfl import evaluate_global_model_dtfl_on_client
from flextrees.pool.primitives_dtfl import evaluate_global_trees
from flextrees.pool.primitives_dtfl import collect_local_evaluations_from_clients_dtfl
from flextrees.pool.primitives_dtfl import set_selected_trees_to_server_dtfl
from flextrees.pool.aggregators_dtfl import aggregate_dtfl
from flextrees.pool.aggregators_dtfl import aggregate_dtfl_prunning
from flextrees.pool.aggregators_dtfl import aggregate_client_dts
from flextrees.pool.aggregators_dtfl import aggregate_thresholds_and_select
from flextrees.pool.aggregators_dtfl import aggregate_transfer_learning


from flextrees.pool.primitives_fedgbdt import preprocessing_stage
# Functions from pool_functions
from flextrees.pool.pool_functions import select_client_by_id_from_pool
Expand Down
134 changes: 134 additions & 0 deletions flextrees/pool/aggregators_dtfl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from flex.pool.decorators import aggregate_weights


@aggregate_weights
def aggregate_dtfl(list_of_weights: list, *args, **kwargs):
"""Function that aggregate the rules from the clients
"""
from dtfl.utils.utils_function_aggregator import generate_cs_dt_branches_from_list
from dtfl.utils.branch_tree import TreeBranch
from dtfl.utils.branch_tree_categorical import TreeBranchCategorical

classes_ = set()
features_ = set()
classes_ |= {c for client in list_of_weights for c in client[1]}
features_ |= {fe for client in list_of_weights for fe in client[2].columns if 'upper' or 'lower' in fe}
classes_ = list(classes_)
features_.remove('probas')
features_ = list(set(features_)) # This param will be used in a future for VFL. Still need some coding to match the features correctly.
client_cs = [cs[0] for cs in list_of_weights]
tree_model_ = {client[3] for client in list_of_weights}
assert len(tree_model_) == 1
tree_model_ = TreeBranch if 'cart' in tree_model_ else TreeBranchCategorical
return generate_cs_dt_branches_from_list(client_cs, classes_, tree_model_)

@aggregate_weights
def aggregate_dtfl_prunning(list_of_weights: list, *args, **kwargs):
"""Function that aggregate the rules from the clients
"""
list_of_weights = [client for i, client in enumerate(list_of_weights) if i in kwargs['selected_indexes']]
from dtfl.utils.utils_function_aggregator import generate_cs_dt_branches_from_list
from dtfl.utils.branch_tree import TreeBranch
from dtfl.utils.branch_tree_categorical import TreeBranchCategorical
from dtfl.utils.branch_tree_mixed import TreeBranchMixed

classes_ = set()
features_ = set()
classes_ |= {c for client in list_of_weights for c in client[1]}
features_ |= {fe for client in list_of_weights for fe in client[2].columns if 'upper' or 'lower' in fe}
classes_ = list(classes_)
# breakpoint()
features_.remove('probas')
features_ = list(set(features_)) # This param will be used in a future for VFL. Still need some coding to match the features correctly.
client_cs = [cs[0] for cs in list_of_weights]
# breakpoint()
tree_model_ = {client[3] for client in list_of_weights}

try:
assert len(tree_model_) == 1
except AssertionError:
print(f"Tree model: {tree_model_}")
print(f"List of weights: {list_of_weights}")
print(f"Selected indexes: {kwargs['selected_indexes']}")
raise AssertionError
# tree_model_ = TreeBranch if 'cart' in tree_model_ else TreeBranchCategorical # OLD
if 'cart' in tree_model_:
tree_model_ = TreeBranch
elif 'id3' in tree_model_:
tree_model_ = TreeBranchCategorical
elif 'c45' in tree_model_:
print("Using TreeBranchMixed")
tree_model_ = TreeBranchMixed
else:
raise NotImplementedError(f"Tree model {tree_model_} not implemented.")
return generate_cs_dt_branches_from_list(client_cs, classes_, tree_model_)

@aggregate_weights
def aggregate_client_dts(list_of_weights: list, *args, **kwargs):
"""Function that aggregate all the client trees to send them to the clients
"""
return list_of_weights

@aggregate_weights
def aggregate_thresholds_and_select(list_of_weights: list, *args, **kwargs):
"""
Function that select those trees that pass the threshold in both accuracy and f1.
This function recieves a list with all the f1 and acc for each tree with the predictions
for each test dataset for each client, and returns the indices of those that surpass
the threshold given for both acc and macro f1.
"""
acc_threshold = kwargs['acc_threshold']
f1_threshold = kwargs['f1_threshold']
func_str = kwargs['func_str']
func_kwval = kwargs['func_kwargs']
# print(f"Metrics at client level: {list_of_weights}")
import numpy as np
sum_list_of_weights = np.sum(np.array(list_of_weights), axis=0)/len(list_of_weights)
acc_array = sum_list_of_weights[0]
f1_array = sum_list_of_weights[1]
def select_func_aggregation(func_str='percentile'):
func_opts = {
'percentile': (np.percentile, 'q'),
'quantile': (np.quantile, 'q'),
'mean': (np.mean, None),
'median': (np.median, None)
}
return func_opts[func_str]
func, func_kwargs = select_func_aggregation(func_str=func_str)
print(f"Using {func_str} as threshold.")
func_kwargs = {func_kwargs:func_kwval} if func_kwargs is not None else {}
acc_threshold, f1_threshold = func(np.mean(np.array(list_of_weights), axis=0),
axis=1, **func_kwargs)
# acc_threshold, f1_threshold = np.percentile(np.mean(np.array(list_of_weights), axis=0), q=75, axis=1) # noqa: E501
# END FOR TESTING PURPOSES #
selected_trees = np.where((acc_array >= acc_threshold) & (f1_array >= f1_threshold))[0]
if len(selected_trees) < 1:
"""
If no tree is selected, we select the best tree according to the accuracy threshold
and the best tree according to the f1 threshold. Instead of using the last
thresholds, we use a 98.9% of the original thresholds. This way, we can be sure
that at least one tree will be selected.
"""
# selected_trees = np.where((acc_array >= acc_threshold) | (f1_array >= f1_threshold))[0]
f1_threshold = f1_threshold * 0.989
acc_threshold = acc_threshold * 0.989
selected_trees = np.where((acc_array >= acc_threshold) & (f1_array >= f1_threshold))[0]
print(f"Number of selected trees: {len(selected_trees)}")
return list(selected_trees)

@aggregate_weights
def aggregate_transfer_learning(list_of_weights: list, *args, **kwargs):
"""Function that select the best models to aggregate them into one
Right now return all of them as the final model must be built first in
order to optimize the build of the global tree.
Args:
list_of_weights (list): _description_
Returns:
_type_: _description_
"""
import numpy as np
print("transfer_agg")
raise NotImplementedError("This function is not implemented yet.")
return list(np.arange(len(list_of_weights)))
Loading

0 comments on commit ed3fd06

Please sign in to comment.