added param to train med3pa models on all sets + updated comparison criteria
lyna1404 committed Aug 16, 2024
1 parent 777840d commit 486b161
Showing 5 changed files with 107 additions and 23 deletions.
47 changes: 47 additions & 0 deletions MED3pa/datasets/manager.py
@@ -301,3 +301,50 @@ def __get_testing_data(self, return_instance: bool = False):
else:
raise ValueError("Testing set not initialized.")

def combine(self, dataset_types: list = None) -> MaskedDataset:
"""
Combines the specified datasets and returns a new MaskedDataset instance.
Args:
dataset_types (list, optional): List of dataset types to combine. Valid options are
'training', 'validation', 'reference', 'testing'.
If None, combines all datasets that are set.
Returns:
MaskedDataset: A new MaskedDataset instance containing the combined data.
Raises:
ValueError: If any specified dataset is not set or if no datasets are provided.
"""
if dataset_types is None:
dataset_types = ['training', 'validation', 'reference', 'testing']

combined_observations = []
combined_true_labels = []

for dataset_type in dataset_types:
dataset = self.get_dataset_by_type(dataset_type, True)
if dataset is None:
raise ValueError(f"Dataset '{dataset_type}' is not set.")

combined_observations.append(dataset.get_observations())
combined_true_labels.append(dataset.get_true_labels())

# Combine all observations and true labels into single arrays
combined_observations = np.vstack(combined_observations)
combined_true_labels = np.concatenate(combined_true_labels)

# Create a new MaskedDataset instance with the combined data
combined_dataset = MaskedDataset(combined_observations, combined_true_labels, column_labels=self.column_labels)

return combined_dataset

def reset(self) -> None:
"""
Resets all datasets in the manager and clears the column labels.
"""
self.base_model_training_set = None
self.base_model_validation_set = None
self.reference_set = None
self.testing_set = None
self.column_labels = None
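
As a rough illustration of what the new combine step does (a minimal numpy-only sketch, not the package API: the per-set arrays below are hypothetical stand-ins for get_observations() / get_true_labels() on each configured set):

import numpy as np

# Hypothetical per-set data standing in for get_observations() / get_true_labels().
sets = {
    'training':   (np.random.rand(80, 4), np.random.randint(0, 2, 80)),
    'validation': (np.random.rand(20, 4), np.random.randint(0, 2, 20)),
    'testing':    (np.random.rand(30, 4), np.random.randint(0, 2, 30)),
}

# Stack observations row-wise and concatenate labels, as combine() does internally.
combined_observations = np.vstack([obs for obs, _ in sets.values()])
combined_true_labels = np.concatenate([labels for _, labels in sets.values()])

print(combined_observations.shape, combined_true_labels.shape)  # (130, 4) (130,)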
17 changes: 14 additions & 3 deletions MED3pa/detectron/comparaison.py
@@ -36,14 +36,17 @@ def is_comparable(self) -> bool:
self.compare_config()

datasets_different = self.config_file['datasets']['different']
datasets_different_sets = self.config_file['datasets']['different_datasets']
base_model_different = self.config_file['base_model']['different']
detectron_params_different = self.config_file['detectron_params']['different']

# Check the conditions for comparability
can_compare = False
if datasets_different and not base_model_different and not detectron_params_different:
# First condition: params are the same, base model is the same, only the testing_set is different
if not detectron_params_different and not base_model_different and datasets_different_sets == ['testing_set']:
can_compare = True
elif base_model_different and not datasets_different and not detectron_params_different:
# Second condition: base model is different, params are the same, datasets are the same or only differ in training and validation sets
elif base_model_different and not detectron_params_different and (not datasets_different or set(datasets_different_sets) <= {'training_set', 'validation_set'}):
can_compare = True

return can_compare
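
A standalone sketch of the updated comparability rule (reconstructed from the two conditions above; this free function is an assumption for illustration, not the package's comparison class):

def can_compare(params_different: bool,
                base_model_different: bool,
                datasets_different: bool,
                different_sets: list) -> bool:
    # Same params, same base model: only the testing set may differ.
    if not params_different and not base_model_different and different_sets == ['testing_set']:
        return True
    # Different base model, same params: datasets identical,
    # or differing only in the training/validation sets.
    if base_model_different and not params_different and \
            (not datasets_different or set(different_sets) <= {'training_set', 'validation_set'}):
        return True
    return False

print(can_compare(False, False, True, ['testing_set']))                    # True
print(can_compare(False, True, True, ['training_set', 'validation_set']))  # True
print(can_compare(False, False, True, ['reference_set', 'testing_set']))   # False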
@@ -171,12 +174,20 @@ def compare_config(self):
config2 = json.load(f2)

combined['datasets'] = {}

dataset_keys = ['training_set', 'validation_set', 'reference_set', 'testing_set']
different_datasets = []

if config1["datasets"] == config2["datasets"]:
combined['datasets']['different'] = False
else:
combined['datasets']['different'] = True

for key in dataset_keys:
if config1["datasets"].get(key) != config2["datasets"].get(key):
different_datasets.append(key)

combined['datasets']['different_datasets'] = different_datasets

combined['datasets']['datasets1'] = config1["datasets"]
combined['datasets']['datasets2'] = config2["datasets"]

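The per-key comparison that fills different_datasets can be sketched in isolation (the plain dicts below are made-up stand-ins for the two loaded config files):

config1 = {'datasets': {'training_set': 'a.csv', 'validation_set': 'b.csv',
                        'reference_set': 'c.csv', 'testing_set': 'd.csv'}}
config2 = {'datasets': {'training_set': 'a.csv', 'validation_set': 'b.csv',
                        'reference_set': 'c.csv', 'testing_set': 'e.csv'}}

dataset_keys = ['training_set', 'validation_set', 'reference_set', 'testing_set']
different_datasets = [key for key in dataset_keys
                      if config1['datasets'].get(key) != config2['datasets'].get(key)]

print(different_datasets)  # ['testing_set']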
19 changes: 15 additions & 4 deletions MED3pa/med3pa/comparaison.py
@@ -96,6 +96,7 @@ def is_comparable(self) -> bool:
self.compare_config()

datasets_different = self.config_file['datasets']['different']
datasets_different_sets = self.config_file['datasets']['different_datasets']
base_model_different = self.config_file['base_model']['different']

if self.compare_detectron:
@@ -120,11 +121,13 @@ def is_comparable(self) -> bool:
params_different = (params1 != params2)
# Check the conditions for comparability
can_compare = False
if datasets_different and not base_model_different and not params_different:
# First condition: params are the same, base model is the same, only the testing_set is different
if not params_different and not base_model_different and datasets_different_sets == ['testing_set']:
can_compare = True
elif base_model_different and not datasets_different and not params_different:
# Second condition: base model is different, params are the same, datasets are the same or only differ in training and validation sets
elif base_model_different and not params_different and (not datasets_different or set(datasets_different_sets) <= {'training_set', 'validation_set'}):
can_compare = True

if can_compare:
if self.compare_detectron:
self.mode = self.config_file['med3pa_detectron_params']['med3pa_detectron_params1']['med3pa_params']['mode']
@@ -275,12 +278,20 @@ def compare_config(self):
config2 = json.load(f2)

combined['datasets'] = {}

dataset_keys = ['training_set', 'validation_set', 'reference_set', 'testing_set']
different_datasets = []

if config1["datasets"] == config2["datasets"]:
combined['datasets']['different'] = False
else:
combined['datasets']['different'] = True

for key in dataset_keys:
if config1["datasets"].get(key) != config2["datasets"].get(key):
different_datasets.append(key)

combined['datasets']['different_datasets'] = different_datasets

combined['datasets']['datasets1'] = config1["datasets"]
combined['datasets']['datasets2'] = config2["datasets"]

45 changes: 30 additions & 15 deletions MED3pa/med3pa/experiment.py
@@ -279,6 +279,7 @@ def run(datasets_manager: DatasetsManager,
med3pa_metrics: List[str] = ['Accuracy', 'BalancedAccuracy', 'Precision', 'Recall', 'F1Score', 'Specificity', 'Sensitivity', 'Auc', 'LogLoss', 'Auprc', 'NPV', 'PPV', 'MCC'],
evaluate_models: bool = False,
use_ref_models: bool = False,
train_all: bool = True,
mode: str = 'mpc',
models_metrics: List[str] = ['MSE', 'RMSE', 'MAE']) -> Med3paResults:

@@ -312,22 +313,22 @@ def run(datasets_manager: DatasetsManager,
ipc_type=ipc_type, ipc_params=ipc_params, ipc_grid_params=ipc_grid_params, ipc_cv=ipc_cv, pretrained_ipc=pretrained_ipc,
apc_params=apc_params,apc_grid_params=apc_grid_params, apc_cv=apc_cv, pretrained_apc=pretrained_apc,
samples_ratio_min=samples_ratio_min, samples_ratio_max=samples_ratio_max, samples_ratio_step=samples_ratio_step,
med3pa_metrics=med3pa_metrics, evaluate_models=evaluate_models, models_metrics=models_metrics, mode=mode)
med3pa_metrics=med3pa_metrics, evaluate_models=evaluate_models, models_metrics=models_metrics, mode=mode, train_all=train_all)
print("Running MED3pa Experiment on the test set:")
if use_ref_models:
results_testing, ipc_config, apc_config = Med3paExperiment._run_by_set(datasets_manager=datasets_manager,set= 'testing',base_model_manager= base_model_manager,
uncertainty_metric=uncertainty_metric,
ipc_type=ipc_type, ipc_params=ipc_params, ipc_grid_params=ipc_grid_params, ipc_cv=ipc_cv, pretrained_ipc=pretrained_ipc, ipc_instance=ipc_config,
apc_params=apc_params,apc_grid_params=apc_grid_params, apc_cv=apc_cv, pretrained_apc=pretrained_apc, apc_instance=apc_config,
samples_ratio_min=samples_ratio_min, samples_ratio_max=samples_ratio_max, samples_ratio_step=samples_ratio_step,
med3pa_metrics=med3pa_metrics, evaluate_models=evaluate_models, models_metrics=models_metrics, mode=mode)
med3pa_metrics=med3pa_metrics, evaluate_models=evaluate_models, models_metrics=models_metrics, mode=mode, train_all=train_all)
else:
results_testing, ipc_config, apc_config = Med3paExperiment._run_by_set(datasets_manager=datasets_manager,set= 'testing',base_model_manager= base_model_manager,
uncertainty_metric=uncertainty_metric,
ipc_type=ipc_type, ipc_params=ipc_params, ipc_grid_params=ipc_grid_params, ipc_cv=ipc_cv, pretrained_ipc=pretrained_ipc, ipc_instance=None,
apc_params=apc_params,apc_grid_params=apc_grid_params, apc_cv=apc_cv, pretrained_apc=pretrained_apc, apc_instance=None,
samples_ratio_min=samples_ratio_min, samples_ratio_max=samples_ratio_max, samples_ratio_step=samples_ratio_step,
med3pa_metrics=med3pa_metrics, evaluate_models=evaluate_models, models_metrics=models_metrics, mode=mode)
med3pa_metrics=med3pa_metrics, evaluate_models=evaluate_models, models_metrics=models_metrics, mode=mode, train_all=train_all)

results = Med3paResults(results_reference, results_testing)
med3pa_params = {
@@ -373,6 +374,7 @@ def _run_by_set(datasets_manager: DatasetsManager,
samples_ratio_step: int = 5,
med3pa_metrics: List[str] = ['Accuracy', 'BalancedAccuracy', 'Precision', 'Recall', 'F1Score', 'Specificity', 'Sensitivity', 'Auc', 'LogLoss', 'Auprc', 'NPV', 'PPV', 'MCC'],
evaluate_models: bool = False,
train_all: bool = True,
mode: str = 'mpc',
models_metrics: List[str] = ['MSE', 'RMSE', 'MAE']) -> Tuple[Med3paRecord, dict, dict]:

@@ -412,18 +414,22 @@ def _run_by_set(datasets_manager: DatasetsManager,
# retrieve different dataset components needed for the experiment
x = dataset.get_observations()
y_true = dataset.get_true_labels()
predicted_probabilities = dataset.get_pseudo_probabilities()
features = datasets_manager.get_column_labels()

combined_dataset = datasets_manager.combine()
x_combined = combined_dataset.get_observations()
y_true_combined = combined_dataset.get_true_labels()

# Initialize base model and predict probabilities if not provided
if base_model_manager is None and predicted_probabilities is None:
raise ValueError("Either the base model or the predicted probabilities should be provided!")
if base_model_manager is None:
raise ValueError("The base model must be provided!")

if predicted_probabilities is None:
base_model = base_model_manager.get_instance()
predicted_probabilities = base_model.predict(x, True)
base_model = base_model_manager.get_instance()
predicted_probabilities = base_model.predict(x, True)
predicted_probabilities_combined = base_model.predict(x_combined, True)

dataset.set_pseudo_probs_labels(predicted_probabilities, 0.5)
combined_dataset.set_pseudo_probs_labels(predicted_probabilities_combined, 0.5)

# Step 2 : Mode and metrics setup
valid_modes = ['mpc', 'apc', 'ipc']
@@ -436,13 +442,19 @@ def _run_by_set(datasets_manager: DatasetsManager,
# Step 3 : Calculate uncertainty values
uncertainty_calc = UncertaintyCalculator(uncertainty_metric)
uncertainty_values = uncertainty_calc.calculate_uncertainty(x, predicted_probabilities, y_true)
print(predicted_probabilities_combined.shape, y_true_combined.shape)
uncertainty_values_combined = uncertainty_calc.calculate_uncertainty(x_combined, predicted_probabilities_combined, y_true_combined)

# Step 4: Set up splits to evaluate the models
if evaluate_models:
_, x_test, _, uncertainty_test = train_test_split(x, uncertainty_values, test_size=0.1, random_state=42)

x_train = x
uncertainty_train = uncertainty_values
if not train_all:
x_train = x
uncertainty_train = uncertainty_values
else:
x_train = x_combined
uncertainty_train = uncertainty_values_combined

results = Med3paRecord()

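A self-contained sketch of the data-selection step the new train_all flag controls (the toy numpy arrays below are assumptions standing in for the single-set and combined observations and uncertainty values):

import numpy as np

train_all = True

# Stand-ins for the current set alone vs. all sets combined.
x = np.random.rand(100, 4)
uncertainty_values = np.random.rand(100)
x_combined = np.random.rand(230, 4)
uncertainty_values_combined = np.random.rand(230)

# With train_all, the IPC/APC models see every configured set, not just the current one.
if not train_all:
    x_train, uncertainty_train = x, uncertainty_values
else:
    x_train, uncertainty_train = x_combined, uncertainty_values_combined

print(x_train.shape)  # (230, 4) when train_all is True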
@@ -465,6 +477,8 @@ def _run_by_set(datasets_manager: DatasetsManager,

# Predict IPC values
IPC_values = IPC_model.predict(x)
IPC_values_combined = IPC_model.predict(x_combined)
IPC_values_train = IPC_values if not train_all else IPC_values_combined
print("Individualized confidence scores calculated.")
# Save the calculated confidence scores by the APCmodel
ipc_dataset = dataset.clone()
@@ -478,15 +492,15 @@ def _run_by_set(datasets_manager: DatasetsManager,
# Step 6: Create and train APCModel
if pretrained_apc is None and apc_instance is None:
APC_model = APCModel(features=features, params=apc_params)
APC_model.train(x, IPC_values)
APC_model.train(x_train, IPC_values_train)
print("APC Model training complete.")
# optimize APC model if grid params were provided
if apc_grid_params is not None:
APC_model.optimize(apc_grid_params, apc_cv, x_train, uncertainty_train)
APC_model.optimize(apc_grid_params, apc_cv, x_train, IPC_values_train)
print("APC Model optimization complete.")
elif pretrained_apc is not None:
APC_model = APCModel(features=features, params=apc_params, pretrained_model=pretrained_apc)
APC_model.train(x, IPC_values)
APC_model.train(x_train, IPC_values_train)
print("Loaded a pretrained APC model.")
else:
APC_model = apc_instance
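
As a loose stand-in for the APC training step above (sklearn's DecisionTreeRegressor is used here instead of the package's APCModel, purely to show that the regressor is now fit on the train_all-selected features against the matching IPC confidence scores; the arrays are made-up):

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Stand-ins for the selected training features and their IPC confidence scores.
x_train = np.random.rand(230, 4)
IPC_values_train = np.random.rand(230)

# A shallow tree regressor approximating per-profile confidence, as the APC does.
apc_stand_in = DecisionTreeRegressor(max_depth=3, random_state=42)
apc_stand_in.fit(x_train, IPC_values_train)

print(apc_stand_in.predict(x_train[:5]))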
@@ -578,6 +592,7 @@ def run(datasets: DatasetsManager,
med3pa_metrics: List[str] = ['Accuracy', 'BalancedAccuracy', 'Precision', 'Recall', 'F1Score', 'Specificity', 'Sensitivity', 'Auc', 'LogLoss', 'Auprc', 'NPV', 'PPV', 'MCC'],
evaluate_models: bool = False,
use_ref_models: bool = False,
train_all: bool = True,
models_metrics: List[str] = ['MSE', 'RMSE', 'MAE'],
mode: str = 'mpc',
all_dr: bool = False) -> Med3paResults:
@@ -626,7 +641,7 @@ def run(datasets: DatasetsManager,
apc_params=apc_params, apc_grid_params=apc_grid_params, apc_cv=apc_cv, pretrained_apc=pretrained_apc,
evaluate_models=evaluate_models, models_metrics=models_metrics,
samples_ratio_min=samples_ratio_min, samples_ratio_max=samples_ratio_max, samples_ratio_step=samples_ratio_step,
med3pa_metrics=med3pa_metrics, mode=mode, use_ref_models=use_ref_models)
med3pa_metrics=med3pa_metrics, mode=mode, use_ref_models=use_ref_models, train_all=train_all)

print("Running Global Detectron Experiment:")
detectron_results = DetectronExperiment.run(datasets=datasets, training_params=training_params, base_model_manager=base_model_manager,
2 changes: 1 addition & 1 deletion setup.py
@@ -9,7 +9,7 @@

setup(
name="MED3pa",
version="0.1.28",
version="0.1.29",
author="MEDomics consortium",
author_email="[email protected]",
description="Python Open-source package for ensuring robust and reliable ML models deployments",