From a22aaab38ce613c83bdb0d4704fb6e4081dda3b6 Mon Sep 17 00:00:00 2001 From: Xichen Wu <102925032+wxicu@users.noreply.github.com> Date: Mon, 24 Jun 2024 11:54:56 +0200 Subject: [PATCH] Add two distance metrics, three-way comparison and bootstrapping (#608) * add two distance metrics * add obsm_key param to distance test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add agg fct * speed up tests * add type * add description * Update pertpy/tools/_distances/_distances.py Co-authored-by: Lukas Heumos * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update pertpy/tools/_distances/_distances.py Co-authored-by: Lukas Heumos * Update pertpy/tools/_distances/_distances.py Co-authored-by: Lukas Heumos * Update pertpy/tools/_distances/_distances.py Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com> * Update pertpy/tools/_distances/_distances.py Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com> * Update pertpy/tools/_distances/_distances.py Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com> * Update pertpy/tools/_distances/_distances.py Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com> * update code * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix drug * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add bootstrapping and metrics_3g * speed up tests, * remove test classes * drop test classes * update compare_de * correct the comments * speed tests * speed up tests * split metrics_3g * fix pre-commit * pin numpy <2 * unpin numpy * speed up mahalanobis distance * use scipy to calculate mahalanobis distance * rename DGE to DGEEVAL --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Lukas Heumos Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com> --- pertpy/tools/__init__.py | 17 +- .../_differential_gene_expression/__init__.py | 1 + .../_dge_comparison.py | 86 ++++ pertpy/tools/_distances/_distance_tests.py | 13 +- pertpy/tools/_distances/_distances.py | 457 +++++++++++++++--- .../tools/_perturbation_space/_comparison.py | 112 +++++ .../_differential_gene_expression/test_dge.py | 80 +++ tests/tools/_distances/test_distance_tests.py | 2 + tests/tools/_distances/test_distances.py | 227 ++++++--- .../_perturbation_space/test_comparison.py | 29 ++ 10 files changed, 880 insertions(+), 144 deletions(-) create mode 100644 pertpy/tools/_differential_gene_expression/_dge_comparison.py create mode 100644 pertpy/tools/_perturbation_space/_comparison.py create mode 100644 tests/tools/_differential_gene_expression/test_dge.py create mode 100644 tests/tools/_perturbation_space/test_comparison.py diff --git a/pertpy/tools/__init__.py b/pertpy/tools/__init__.py index 4bc976a2..4e2c709e 100644 --- a/pertpy/tools/__init__.py +++ b/pertpy/tools/__init__.py @@ -3,18 +3,31 @@ from pertpy.tools._coda._sccoda import Sccoda from pertpy.tools._coda._tasccoda import Tasccoda from pertpy.tools._dialogue import Dialogue -from pertpy.tools._differential_gene_expression import EdgeR, PyDESeq2, Statsmodels, TTest, WilcoxonTest +from pertpy.tools._differential_gene_expression import ( + DGEEVAL, + EdgeR, + PyDESeq2, + Statsmodels, + TTest, + WilcoxonTest, +) from pertpy.tools._distances._distance_tests import DistanceTest from 
pertpy.tools._distances._distances import Distance from pertpy.tools._enrichment import Enrichment from pertpy.tools._milo import Milo from pertpy.tools._mixscape import Mixscape from pertpy.tools._perturbation_space._clustering import ClusteringSpace +from pertpy.tools._perturbation_space._comparison import PerturbationComparison from pertpy.tools._perturbation_space._discriminator_classifiers import ( LRClassifierSpace, MLPClassifierSpace, ) -from pertpy.tools._perturbation_space._simple import CentroidSpace, DBSCANSpace, KMeansSpace, PseudobulkSpace +from pertpy.tools._perturbation_space._simple import ( + CentroidSpace, + DBSCANSpace, + KMeansSpace, + PseudobulkSpace, +) from pertpy.tools._scgen import Scgen __all__ = [ diff --git a/pertpy/tools/_differential_gene_expression/__init__.py b/pertpy/tools/_differential_gene_expression/__init__.py index 178ecaa3..35ccef51 100644 --- a/pertpy/tools/_differential_gene_expression/__init__.py +++ b/pertpy/tools/_differential_gene_expression/__init__.py @@ -1,4 +1,5 @@ from ._base import ContrastType, LinearModelBase, MethodBase +from ._dge_comparison import DGEEVAL from ._edger import EdgeR from ._pydeseq2 import PyDESeq2 from ._simple_tests import SimpleComparisonBase, TTest, WilcoxonTest diff --git a/pertpy/tools/_differential_gene_expression/_dge_comparison.py b/pertpy/tools/_differential_gene_expression/_dge_comparison.py new file mode 100644 index 00000000..0492dcc5 --- /dev/null +++ b/pertpy/tools/_differential_gene_expression/_dge_comparison.py @@ -0,0 +1,86 @@ +import numpy as np +import pandas as pd +from anndata import AnnData + + +class DGEEVAL: + def compare( + self, + adata: AnnData | None = None, + de_key1: str | None = None, + de_key2: str | None = None, + de_df1: pd.DataFrame | None = None, + de_df2: pd.DataFrame | None = None, + shared_top: int = 100, + ) -> dict[str, float]: + """Compare two differential expression analyses. + + Compare two sets of DE results and evaluate the similarity by the overlap of top DEGs and + the correlation of their scores and adjusted p-values. + + Args: + adata: AnnData object containing DE results in `uns`. Required if `de_key1` and `de_key2` are used. + de_key1: Key for DE results in `adata.uns`, e.g., output of `tl.rank_genes_groups`. + de_key2: Another key for DE results in `adata.uns`, e.g., output of `tl.rank_genes_groups`. + de_df1: DataFrame containing DE results, e.g., output from the pertpy differential gene expression interface. + de_df2: DataFrame containing DE results, e.g., output from the pertpy differential gene expression interface. + shared_top: The number of top DEGs used to compute the proportion of shared top genes. + + """ + if (de_key1 or de_key2) and (de_df1 is not None or de_df2 is not None): + raise ValueError( + "Please provide either both `de_key1` and `de_key2` with `adata`, or `de_df1` and `de_df2`, but not both." + ) + + if de_df1 is None and de_df2 is None: # use keys + if not de_key1 or not de_key2: + raise ValueError("Both `de_key1` and `de_key2` must be provided together if using `adata`.") + + else: # use dfs + if de_df1 is None or de_df2 is None: + raise ValueError("Both `de_df1` and `de_df2` must be provided together if using DataFrames.") + + if de_key1: + if adata is None: + raise ValueError("`adata` must be provided when using `de_key1` and `de_key2`.") + assert all( + k in adata.uns for k in [de_key1, de_key2] + ), "Provided `de_key1` and `de_key2` must exist in `adata.uns`."
+ vars = adata.var_names + + if de_df1 is not None: + for df in (de_df1, de_df2): + if not {"variable", "log_fc", "adj_p_value"}.issubset(df.columns): + raise ValueError("Each DataFrame must contain columns: 'variable', 'log_fc', and 'adj_p_value'.") + + assert set(de_df1["variable"]) == set(de_df2["variable"]), "Variables in both dataframes must match." + vars = de_df1["variable"].sort_values() + + shared_top = min(shared_top, len(vars)) + vars_ranks = np.arange(1, len(vars) + 1) + results = pd.DataFrame(index=vars) + top_names = [] + + if de_key1 and de_key2: + for i, k in enumerate([de_key1, de_key2]): + label = adata.uns[k]["names"].dtype.names[0] + srt_idx = np.argsort(adata.uns[k]["names"][label]) + results[f"scores_{i}"] = adata.uns[k]["scores"][label][srt_idx] + results[f"pvals_adj_{i}"] = adata.uns[k]["pvals_adj"][label][srt_idx] + results[f"ranks_{i}"] = vars_ranks[srt_idx] + top_names.append(adata.uns[k]["names"][label][:shared_top]) + else: + for i, df in enumerate([de_df1, de_df2]): + srt_idx = np.argsort(df["variable"]) + results[f"scores_{i}"] = df["log_fc"].values[srt_idx] + results[f"pvals_adj_{i}"] = df["adj_p_value"].values[srt_idx] + results[f"ranks_{i}"] = vars_ranks[srt_idx] + top_names.append(df["variable"][:shared_top]) + + metrics = {} + metrics["shared_top_genes"] = len(set(top_names[0]).intersection(top_names[1])) / shared_top + metrics["scores_corr"] = results["scores_0"].corr(results["scores_1"], method="pearson") + metrics["pvals_adj_corr"] = results["pvals_adj_0"].corr(results["pvals_adj_1"], method="pearson") + metrics["scores_ranks_corr"] = results["ranks_0"].corr(results["ranks_1"], method="spearman") + + return metrics diff --git a/pertpy/tools/_distances/_distance_tests.py b/pertpy/tools/_distances/_distance_tests.py index adfe47af..0c6f67a0 100644 --- a/pertpy/tools/_distances/_distance_tests.py +++ b/pertpy/tools/_distances/_distance_tests.py @@ -66,11 +66,14 @@ def __init__( self.alpha = alpha self.correction = correction self.cell_wise_metric = ( - cell_wise_metric if cell_wise_metric else Distance(self.metric, self.obsm_key).cell_wise_metric + cell_wise_metric if cell_wise_metric else Distance(self.metric, obsm_key=self.obsm_key).cell_wise_metric ) self.distance = Distance( - self.metric, layer_key=self.layer_key, obsm_key=self.obsm_key, cell_wise_metric=self.cell_wise_metric + self.metric, + layer_key=self.layer_key, + obsm_key=self.obsm_key, + cell_wise_metric=self.cell_wise_metric, ) def __call__( @@ -176,7 +179,8 @@ def test_xy(self, adata: AnnData, groupby: str, contrast: str, show_progressbar: # Evaluate the test # count times shuffling resulted in larger distance comparison_results = np.array( - pd.concat([r["distance"] - df["distance"] for r in results], axis=1) > 0, dtype=int + pd.concat([r["distance"] - df["distance"] for r in results], axis=1) > 0, + dtype=int, ) n_failures = pd.Series(np.clip(np.sum(comparison_results, axis=1), 1, np.inf), index=df.index) pvalues = n_failures / self.n_perms @@ -284,7 +288,8 @@ def test_precomputed(self, adata: AnnData, groupby: str, contrast: str, verbose: # Evaluate the test # count times shuffling resulted in larger distance comparison_results = np.array( - pd.concat([r["distance"] - df["distance"] for r in results], axis=1) > 0, dtype=int + pd.concat([r["distance"] - df["distance"] for r in results], axis=1) > 0, + dtype=int, ) n_failures = pd.Series(np.clip(np.sum(comparison_results, axis=1), 1, np.inf), index=df.index) pvalues = n_failures / self.n_perms diff --git 
a/pertpy/tools/_distances/_distances.py b/pertpy/tools/_distances/_distances.py index 16830dff..5cde82d5 100644 --- a/pertpy/tools/_distances/_distances.py +++ b/pertpy/tools/_distances/_distances.py @@ -1,7 +1,8 @@ from __future__ import annotations +import multiprocessing from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal, NamedTuple import numba import numpy as np @@ -13,18 +14,26 @@ from pandas import Series from rich.progress import track from scipy.sparse import issparse -from scipy.spatial.distance import cosine +from scipy.spatial.distance import cosine, mahalanobis from scipy.special import gammaln from scipy.stats import kendalltau, kstest, pearsonr, spearmanr from sklearn.linear_model import LogisticRegression from sklearn.metrics import pairwise_distances, r2_score from sklearn.metrics.pairwise import polynomial_kernel, rbf_kernel +from sklearn.neighbors import KernelDensity from statsmodels.discrete.discrete_model import NegativeBinomialP if TYPE_CHECKING: + from collections.abc import Callable + from anndata import AnnData +class MeanVar(NamedTuple): + mean: float + variance: float + + class Distance: """Distance class, used to compute distances between groups of cells. @@ -80,6 +89,11 @@ class Distance: Average of the classification probability of the perturbation for a binary classifier. - "classifier_cp": classifier class projection Average of the class + - "mean_var_distribution": Distance between the mean-variance distributions of cells from 2 groups. + Computed as the mean squared distance between the two distributions, estimated using Kernel Density Estimation (KDE). + - "mahalanobis": Mahalanobis distance between the means of cells from two groups. + Originally used to measure the distance between a point and a distribution; + in this context, it quantifies the difference between the mean profiles of a target group and a reference group. Attributes: metric: Name of distance metric. @@ -99,6 +113,7 @@ class Distance: def __init__( self, metric: str = "edistance", + agg_fct: Callable = np.mean, layer_key: str = None, obsm_key: str = None, cell_wise_metric: str = "euclidean", @@ -107,35 +122,37 @@ def __init__( Args: metric: Distance metric to use. + agg_fct: Aggregation function to generate pseudobulk vectors. layer_key: Name of the counts layer containing raw counts to calculate distances for. Mutually exclusive with 'obsm_key'. Is not used if `None`. obsm_key: Name of embedding in adata.obsm to use. - Mutually exclusive with 'counts_layer_key'. + Mutually exclusive with 'layer_key'. Defaults to None, but is set to "X_pca" if not explicitly set internally. cell_wise_metric: Metric from scipy.spatial.distance to use for pairwise distances between single cells.
""" metric_fct: AbstractDistance = None + self.aggregation_func = agg_fct if metric == "edistance": metric_fct = Edistance() elif metric == "euclidean": - metric_fct = EuclideanDistance() + metric_fct = EuclideanDistance(self.aggregation_func) elif metric == "root_mean_squared_error": - metric_fct = EuclideanDistance() + metric_fct = EuclideanDistance(self.aggregation_func) elif metric == "mse": - metric_fct = MeanSquaredDistance() + metric_fct = MeanSquaredDistance(self.aggregation_func) elif metric == "mean_absolute_error": - metric_fct = MeanAbsoluteDistance() + metric_fct = MeanAbsoluteDistance(self.aggregation_func) elif metric == "pearson_distance": - metric_fct = PearsonDistance() + metric_fct = PearsonDistance(self.aggregation_func) elif metric == "spearman_distance": - metric_fct = SpearmanDistance() + metric_fct = SpearmanDistance(self.aggregation_func) elif metric == "kendalltau_distance": - metric_fct = KendallTauDistance() + metric_fct = KendallTauDistance(self.aggregation_func) elif metric == "cosine_distance": - metric_fct = CosineDistance() + metric_fct = CosineDistance(self.aggregation_func) elif metric == "r2_distance": - metric_fct = R2ScoreDistance() + metric_fct = R2ScoreDistance(self.aggregation_func) elif metric == "mean_pairwise": metric_fct = MeanPairwiseDistance() elif metric == "mmd": @@ -154,14 +171,17 @@ def __init__( metric_fct = ClassifierProbaDistance() elif metric == "classifier_cp": metric_fct = ClassifierClassProjection() + elif metric == "mean_var_distribution": + metric_fct = MeanVarDistributionDistance() + elif metric == "mahalanobis": + metric_fct = MahalanobisDistance(self.aggregation_func) else: raise ValueError(f"Metric {metric} not recognized.") self.metric_fct = metric_fct if layer_key and obsm_key: raise ValueError( - "Cannot use 'counts_layer_key' and 'obsm_key' at the same time.\n" - "Please provide only one of the two keys." + "Cannot use 'layer_key' and 'obsm_key' at the same time.\n" "Please provide only one of the two keys." ) if not layer_key and not obsm_key: obsm_key = "X_pca" @@ -203,15 +223,54 @@ def __call__( return self.metric_fct(X, Y, **kwargs) + def bootstrap( + self, + X: np.ndarray, + Y: np.ndarray, + *, + n_bootstrap: int = 100, + random_state: int = 0, + **kwargs, + ) -> MeanVar: + """Bootstrap computation of mean and variance of the distance between vectors X and Y. + + Args: + X: First vector of shape (n_samples, n_features). + Y: Second vector of shape (n_samples, n_features). + n_bootstrap: Number of bootstrap samples. + random_state: Random state for bootstrapping. + + Returns: + MeanVar: Mean and variance of distance between X and Y. + + Examples: + >>> import pertpy as pt + >>> adata = pt.dt.distance_example() + >>> Distance = pt.tools.Distance(metric="edistance") + >>> X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"] + >>> Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"] + >>> D = Distance.bootstrap(X, Y) + """ + return self._bootstrap_mode( + X, + Y, + n_bootstraps=n_bootstrap, + random_state=random_state, + **kwargs, + ) + def pairwise( self, adata: AnnData, groupby: str, groups: list[str] | None = None, + bootstrap: bool = False, + n_bootstrap: int = 100, + random_state: int = 0, show_progressbar: bool = True, n_jobs: int = -1, **kwargs, - ) -> pd.DataFrame: + ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]: """Get pairwise distances between groups of cells. Args: @@ -219,12 +278,16 @@ def pairwise( groupby: Column name in adata.obs. 
groups: List of groups to compute pairwise distances for. If None, uses all groups. + bootstrap: Whether to bootstrap the distance. + n_bootstrap: Number of bootstrap samples. + random_state: Random state for bootstrapping. show_progressbar: Whether to show progress bar. n_jobs: Number of cores to use. Defaults to -1 (all). kwargs: Additional keyword arguments passed to the metric function. Returns: pd.DataFrame: Dataframe with pairwise distances. + tuple[pd.DataFrame, pd.DataFrame]: Two Dataframes, one for the mean and one for the variance of pairwise distances. Examples: >>> import pertpy as pt @@ -235,6 +298,8 @@ def pairwise( groups = adata.obs[groupby].unique() if groups is None else groups grouping = adata.obs[groupby].copy() df = pd.DataFrame(index=groups, columns=groups, dtype=float) + if bootstrap: + df_var = pd.DataFrame(index=groups, columns=groups, dtype=float) fct = track if show_progressbar else lambda iterable: iterable # Some metrics are able to handle precomputed distances. This means that @@ -250,16 +315,29 @@ def pairwise( for index_x, group_x in enumerate(fct(groups)): idx_x = grouping == group_x for group_y in groups[index_x:]: # type: ignore - if group_x == group_y: - dist = 0.0 # by distance axiom + # subset the pairwise distance matrix to the two groups + idx_y = grouping == group_y + sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y] + sub_idx = grouping[idx_x | idx_y] == group_x + if not bootstrap: + if group_x == group_y: + dist = 0.0 + else: + dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs) + df.loc[group_x, group_y] = dist + df.loc[group_y, group_x] = dist + else: - idx_y = grouping == group_y - # subset the pairwise distance matrix to the two groups - sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y] - sub_idx = grouping[idx_x | idx_y] == group_x - dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs) - df.loc[group_x, group_y] = dist - df.loc[group_y, group_x] = dist + bootstrap_output = self._bootstrap_mode_precomputed( + sub_pwd, + sub_idx, + n_bootstraps=n_bootstrap, + random_state=random_state, + **kwargs, + ) + # In the bootstrap case, distance of group to itself is a mean and can be non-zero + df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean + df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance else: if self.layer_key: embedding = adata.layers[self.layer_key] @@ -268,18 +346,39 @@ def pairwise( for index_x, group_x in enumerate(fct(groups)): cells_x = embedding[grouping == group_x].copy() for group_y in groups[index_x:]: # type: ignore - if group_x == group_y: - dist = 0.0 + cells_y = embedding[grouping == group_y].copy() + if not bootstrap: + # By distance axiom, the distance between a group and itself is 0 + dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs) + + df.loc[group_x, group_y] = dist + df.loc[group_y, group_x] = dist else: - cells_y = embedding[grouping == group_y].copy() - dist = self(cells_x, cells_y, **kwargs) - df.loc[group_x, group_y] = dist - df.loc[group_y, group_x] = dist + bootstrap_output = self.bootstrap( + cells_x, + cells_y, + n_bootstrap=n_bootstrap, + random_state=random_state, + **kwargs, + ) + # In the bootstrap case, distance of group to itself is a mean and can be non-zero + df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean + df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance + df.index.name = groupby df.columns.name = groupby df.name = f"pairwise 
{self.metric}" - return df + if not bootstrap: + return df + else: + df = df.fillna(0) + df_var.index.name = groupby + df_var.columns.name = groupby + df_var = df_var.fillna(0) + df_var.name = f"pairwise {self.metric} variance" + + return df, df_var def onesided_distances( self, @@ -287,10 +386,13 @@ def onesided_distances( groupby: str, selected_group: str | None = None, groups: list[str] | None = None, + bootstrap: bool = False, + n_bootstrap: int = 100, + random_state: int = 0, show_progressbar: bool = True, n_jobs: int = -1, **kwargs, - ) -> pd.DataFrame: + ) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]: """Get distances between one selected cell group and the remaining other cell groups. Args: @@ -299,12 +401,17 @@ def onesided_distances( selected_group: Group to compute pairwise distances to all other. groups: List of groups to compute distances to selected_group for. If None, uses all groups. + bootstrap: Whether to bootstrap the distance. + n_bootstrap: Number of bootstrap samples. + random_state: Random state for bootstrapping. show_progressbar: Whether to show progress bar. n_jobs: Number of cores to use. Defaults to -1 (all). kwargs: Additional keyword arguments passed to the metric function. Returns: pd.DataFrame: Dataframe with distances of groups to selected_group. + tuple[pd.DataFrame, pd.DataFrame]: Two Dataframes, one for the mean and one for the variance of distances of groups to selected_group. + Examples: >>> import pertpy as pt @@ -313,20 +420,30 @@ def onesided_distances( >>> pairwise_df = Distance.onesided_distances(adata, groupby="perturbation", selected_group="control") """ if self.metric == "classifier_cp": + if bootstrap: + raise NotImplementedError("Currently, ClassifierClassProjection does not support bootstrapping.") return self.metric_fct.onesided_distances( # type: ignore - adata, groupby, selected_group, groups, show_progressbar, n_jobs, **kwargs + adata, + groupby, + selected_group, + groups, + show_progressbar, + n_jobs, + **kwargs, ) groups = adata.obs[groupby].unique() if groups is None else groups grouping = adata.obs[groupby].copy() df = pd.Series(index=groups, dtype=float) + if bootstrap: + df_var = pd.Series(index=groups, dtype=float) fct = track if show_progressbar else lambda iterable: iterable # Some metrics are able to handle precomputed distances. This means that # the pairwise distances between all cells are computed once and then # passed to the metric function. This is much faster than computing the # pairwise distances for each group separately. Other metrics are not - # able to handle precomputed distances such as the PsuedobulkDistance. + # able to handle precomputed distances such as the PseudobulkDistance. 
if self.metric_fct.accepts_precomputed: # Precompute the pairwise distances if needed if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp.keys(): @@ -336,14 +453,25 @@ def onesided_distances( idx_x = grouping == group_x group_y = selected_group if group_x == group_y: - dist = 0.0 # by distance axiom + df.loc[group_x] = 0.0 # by distance axiom else: idx_y = grouping == group_y # subset the pairwise distance matrix to the two groups sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y] sub_idx = grouping[idx_x | idx_y] == group_x - dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs) - df.loc[group_x] = dist + if not bootstrap: + dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs) + df.loc[group_x] = dist + else: + bootstrap_output = self._bootstrap_mode_precomputed( + sub_pwd, + sub_idx, + n_bootstraps=n_bootstrap, + random_state=random_state, + **kwargs, + ) + df.loc[group_x] = bootstrap_output.mean + df_var.loc[group_x] = bootstrap_output.variance else: if self.layer_key: embedding = adata.layers[self.layer_key] @@ -352,15 +480,32 @@ def onesided_distances( for group_x in fct(groups): cells_x = embedding[grouping == group_x].copy() group_y = selected_group - if group_x == group_y: - dist = 0.0 + cells_y = embedding[grouping == group_y].copy() + if not bootstrap: + # By distance axiom, the distance between a group and itself is 0 + dist = 0.0 if group_x == group_y else self(cells_x, cells_y, **kwargs) + df.loc[group_x] = dist else: - cells_y = embedding[grouping == group_y].copy() - dist = self(cells_x, cells_y, **kwargs) - df.loc[group_x] = dist + bootstrap_output = self.bootstrap( + cells_x, + cells_y, + n_bootstrap=n_bootstrap, + random_state=random_state, + **kwargs, + ) + # In the bootstrap case, distance of group to itself is a mean and can be non-zero + df.loc[group_x] = bootstrap_output.mean + df_var.loc[group_x] = bootstrap_output.variance df.index.name = groupby df.name = f"{self.metric} to {selected_group}" - return df + if not bootstrap: + return df + else: + df_var.index.name = groupby + df_var = df_var.fillna(0) + df_var.name = f"pairwise {self.metric} variance to {selected_group}" + + return df, df_var def precompute_distances(self, adata: AnnData, n_jobs: int = -1) -> None: """Precompute pairwise distances between all cells, writes to adata.obsp. @@ -386,6 +531,77 @@ def precompute_distances(self, adata: AnnData, n_jobs: int = -1) -> None: pwd = pairwise_distances(cells, cells, metric=self.cell_wise_metric, n_jobs=n_jobs) adata.obsp[f"{self.obsm_key}_{self.cell_wise_metric}_predistances"] = pwd + def compare_distance( + self, + pert: np.ndarray, + pred: np.ndarray, + ctrl: np.ndarray, + mode: Literal["simple", "scaled"] = "simple", + fit_to_pert_and_ctrl: bool = False, + **kwargs, + ) -> float: + """Compute the score of simulating a perturbation. + + The score is the distance between the real perturbed and the simulated data, divided by the + distance between the control and the simulated data; lower values indicate a better simulation. + + Args: + pert: Real perturbed data. + pred: Simulated perturbed data. + ctrl: Control data. + mode: Mode to use. "simple" uses the data as is; "scaled" first min-max scales `pred` and `pert` using a scaler fitted on the control data. + fit_to_pert_and_ctrl: Scales data based on both `pert` and `ctrl` if True, otherwise only on `ctrl`. + kwargs: Additional keyword arguments passed to the metric function. + """ + if mode == "simple": + pass # nothing to be done + elif mode == "scaled": + from sklearn.preprocessing import MinMaxScaler + + scaler = MinMaxScaler().fit(np.vstack((pert, ctrl)) if fit_to_pert_and_ctrl else ctrl) + pred = scaler.transform(pred) + pert = scaler.transform(pert) + else: + raise ValueError(f"Unknown mode {mode}.
Please choose simple or scaled.") + + d1 = self.metric_fct(pert, pred, **kwargs) + d2 = self.metric_fct(ctrl, pred, **kwargs) + return d1 / d2 + + def _bootstrap_mode(self, X, Y, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar: + rng = np.random.default_rng(random_state) + + distances = [] + for _ in range(n_bootstraps): + X_bootstrapped = X[rng.choice(a=X.shape[0], size=X.shape[0], replace=True)] + Y_bootstrapped = Y[rng.choice(a=Y.shape[0], size=X.shape[0], replace=True)] + + distance = self(X_bootstrapped, Y_bootstrapped, **kwargs) + distances.append(distance) + + mean = np.mean(distances) + variance = np.var(distances) + return MeanVar(mean=mean, variance=variance) + + def _bootstrap_mode_precomputed(self, sub_pwd, sub_idx, n_bootstraps=100, random_state=0, **kwargs) -> MeanVar: + rng = np.random.default_rng(random_state) + + distances = [] + for _ in range(n_bootstraps): + # To maintain the number of cells for both groups (whatever balancing they may have), + # we sample the positive and negative indices separately + bootstrap_pos_idx = rng.choice(a=sub_idx[sub_idx].index, size=sub_idx[sub_idx].size, replace=True) + bootstrap_neg_idx = rng.choice(a=sub_idx[~sub_idx].index, size=sub_idx[~sub_idx].size, replace=True) + bootstrap_idx = np.concatenate([bootstrap_pos_idx, bootstrap_neg_idx]) + bootstrap_idx_nrs = sub_idx.index.get_indexer(bootstrap_idx) + + bootstrap_sub_idx = sub_idx[bootstrap_idx] + bootstrap_sub_pwd = sub_pwd[bootstrap_idx_nrs, :][:, bootstrap_idx_nrs] + + distance = self.metric_fct.from_precomputed(bootstrap_sub_pwd, bootstrap_sub_idx, **kwargs) + distances.append(distance) + + mean = np.mean(distances) + variance = np.var(distances) + return MeanVar(mean=mean, variance=variance) + class AbstractDistance(ABC): """Abstract class of distance metrics between two sets of vectors.""" @@ -499,12 +715,17 @@ def solve_ot_problem(self, geom: Geometry, **kwargs): class EuclideanDistance(AbstractDistance): """Euclidean distance between pseudobulk vectors.""" - def __init__(self) -> None: + def __init__(self, aggregation_func: Callable = np.mean) -> None: super().__init__() self.accepts_precomputed = False + self.aggregation_func = aggregation_func def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: - return np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0), ord=2, **kwargs) + return np.linalg.norm( + self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0), + ord=2, + **kwargs, + ) def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: raise NotImplementedError("EuclideanDistance cannot be called on a pairwise distance matrix.") @@ -513,12 +734,21 @@ def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: class MeanSquaredDistance(AbstractDistance): """Mean squared distance between pseudobulk vectors.""" - def __init__(self) -> None: + def __init__(self, aggregation_func: Callable = np.mean) -> None: super().__init__() self.accepts_precomputed = False + self.aggregation_func = aggregation_func def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: - return np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0), ord=2, **kwargs) ** 2 / X.shape[1] + return ( + np.linalg.norm( + self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0), + ord=2, + **kwargs, + ) + ** 2 + / X.shape[1] + ) def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: raise NotImplementedError("MeanSquaredDistance cannot be called on a pairwise distance matrix.") @@ -527,12 +757,20 @@ def 
from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: class MeanAbsoluteDistance(AbstractDistance): """Absolute (Norm-1) distance between pseudobulk vectors.""" - def __init__(self) -> None: + def __init__(self, aggregation_func: Callable = np.mean) -> None: super().__init__() self.accepts_precomputed = False + self.aggregation_func = aggregation_func def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: - return np.linalg.norm(X.mean(axis=0) - Y.mean(axis=0), ord=1, **kwargs) / X.shape[1] + return ( + np.linalg.norm( + self.aggregation_func(X, axis=0) - self.aggregation_func(Y, axis=0), + ord=1, + **kwargs, + ) + / X.shape[1] + ) def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: raise NotImplementedError("MeanAbsoluteDistance cannot be called on a pairwise distance matrix.") @@ -557,12 +795,13 @@ def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: class PearsonDistance(AbstractDistance): """Pearson distance between pseudobulk vectors.""" - def __init__(self) -> None: + def __init__(self, aggregation_func: Callable = np.mean) -> None: super().__init__() self.accepts_precomputed = False + self.aggregation_func = aggregation_func def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: - return 1 - pearsonr(X.mean(axis=0), Y.mean(axis=0))[0] + return 1 - pearsonr(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))[0] def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: raise NotImplementedError("PearsonDistance cannot be called on a pairwise distance matrix.") @@ -571,12 +810,13 @@ def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: class SpearmanDistance(AbstractDistance): """Spearman distance between pseudobulk vectors.""" - def __init__(self) -> None: + def __init__(self, aggregation_func: Callable = np.mean) -> None: super().__init__() self.accepts_precomputed = False + self.aggregation_func = aggregation_func def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: - return 1 - spearmanr(X.mean(axis=0), Y.mean(axis=0))[0] + return 1 - spearmanr(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0))[0] def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: raise NotImplementedError("SpearmanDistance cannot be called on a pairwise distance matrix.") @@ -585,12 +825,13 @@ def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: class KendallTauDistance(AbstractDistance): """Kendall-tau distance between pseudobulk vectors.""" - def __init__(self) -> None: + def __init__(self, aggregation_func: Callable = np.mean) -> None: super().__init__() self.accepts_precomputed = False + self.aggregation_func = aggregation_func def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: - x, y = X.mean(axis=0), Y.mean(axis=0) + x, y = self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0) n = len(x) tau_corr = kendalltau(x, y).statistic tau_dist = (1 - tau_corr) * n * (n - 1) / 4 @@ -603,12 +844,13 @@ def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: class CosineDistance(AbstractDistance): """Cosine distance between pseudobulk vectors.""" - def __init__(self) -> None: + def __init__(self, aggregation_func: Callable = np.mean) -> None: super().__init__() self.accepts_precomputed = False + self.aggregation_func = aggregation_func def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: - return 
cosine(X.mean(axis=0), Y.mean(axis=0)) + return cosine(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0)) def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: raise NotImplementedError("CosineDistance cannot be called on a pairwise distance matrix.") @@ -619,12 +861,13 @@ class R2ScoreDistance(AbstractDistance): # NOTE: This is not a distance metric but a similarity metric. - def __init__(self) -> None: + def __init__(self, aggregation_func: Callable = np.mean) -> None: super().__init__() self.accepts_precomputed = False + self.aggregation_func = aggregation_func def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: - return 1 - r2_score(X.mean(axis=0), Y.mean(axis=0)) + return 1 - r2_score(self.aggregation_func(X, axis=0), self.aggregation_func(Y, axis=0)) def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: raise NotImplementedError("R2ScoreDistance cannot be called on a pairwise distance matrix.") @@ -833,6 +1076,7 @@ def onesided_distances( Similar to the parent function, the returned dataframe contains only the specified groups. """ groups = adata.obs[groupby].unique() if groups is None else groups + fct = track if show_progressbar else lambda iterable: iterable X = adata[adata.obs[groupby] != selected_group].X labels = adata[adata.obs[groupby] != selected_group].obs[groupby].values @@ -843,7 +1087,8 @@ def onesided_distances( test_probas = reg.predict_proba(Y) df = pd.Series(index=groups, dtype=float) - for group in groups: + + for group in fct(groups): if group == selected_group: df.loc[group] = 0 else: @@ -856,3 +1101,95 @@ def onesided_distances( def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: raise NotImplementedError("ClassifierClassProjection cannot be called on a pairwise distance matrix.") + + +class MeanVarDistributionDistance(AbstractDistance): + """Distance between mean-var distributions of gene expression.""" + + def __init__(self) -> None: + super().__init__() + self.accepts_precomputed = False + + def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: + """Difference of mean-var distributions in 2 matrices. + + Args: + X: Normalized and log transformed cells x genes count matrix. + Y: Normalized and log transformed cells x genes count matrix. 
+ """ + + def _mean_var(x, log: bool = False): + mean = np.mean(x, axis=0) + var = np.var(x, axis=0) + positive = mean > 0 + mean = mean[positive] + var = var[positive] + if log: + mean = np.log(mean) + var = np.log(var) + return mean, var + + def _prep_kde_data(x, y): + return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1) + + def _grid_points(d, n_points=100): + # Make grid, add 1 bin on lower/upper end to get final n_points + d_min = d.min() + d_max = d.max() + # Compute bin size + d_bin = (d_max - d_min) / (n_points - 2) + d_min = d_min - d_bin + d_max = d_max + d_bin + return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin) + + def _parallel_score_samples(kde, samples, thread_count=int(0.875 * multiprocessing.cpu_count())): + # the thread_count is determined using the factor 0.875 as recommended here: + # https://stackoverflow.com/questions/32625094/scipy-parallel-computing-in-ipython-notebook + with multiprocessing.Pool(thread_count) as p: + return np.concatenate(p.map(kde.score_samples, np.array_split(samples, thread_count))) + + def _kde_eval(d, grid): + # Kernel choice: Gaussian is too smoothing and cosine or other kernels that do not stretch out + # can not be compared well on regions further away from the data as they are -inf + kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(d) + return _parallel_score_samples(kde, grid) + + mean_x, var_x = _mean_var(X, log=True) + mean_y, var_y = _mean_var(Y, log=True) + + x = _prep_kde_data(mean_x, var_x) + y = _prep_kde_data(mean_y, var_y) + + # Gridpoints to eval KDE on + mean_grid = _grid_points(np.concatenate([mean_x, mean_y])) + var_grid = _grid_points(np.concatenate([var_x, var_y])) + grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2) + + kde_x = _kde_eval(x, grid) + kde_y = _kde_eval(y, grid) + + kde_diff = ((kde_x - kde_y) ** 2).mean() + + return kde_diff + + def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: + raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.") + + +class MahalanobisDistance(AbstractDistance): + """Mahalanobis distance between pseudobulk vectors.""" + + def __init__(self, aggregation_func: Callable = np.mean) -> None: + super().__init__() + self.accepts_precomputed = False + self.aggregation_func = aggregation_func + + def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float: + return mahalanobis( + self.aggregation_func(X, axis=0), + self.aggregation_func(Y, axis=0), + np.linalg.inv(np.cov(X.T)), + ) + + def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float: + raise NotImplementedError("Mahalanobis cannot be called on a pairwise distance matrix.") diff --git a/pertpy/tools/_perturbation_space/_comparison.py b/pertpy/tools/_perturbation_space/_comparison.py new file mode 100644 index 00000000..2ea76cc2 --- /dev/null +++ b/pertpy/tools/_perturbation_space/_comparison.py @@ -0,0 +1,112 @@ +from typing import TYPE_CHECKING + +import numpy as np +import pynndescent +from scipy.sparse import issparse +from scipy.sparse import vstack as sp_vstack +from sklearn.base import ClassifierMixin +from sklearn.linear_model import LogisticRegression + +if TYPE_CHECKING: + from numpy.typing import NDArray + + +class PerturbationComparison: + """Comparison between real and simulated perturbations.""" + + def compare_classification( + self, + real: np.ndarray, + simulated: np.ndarray, + control: np.ndarray, + clf: ClassifierMixin | None = None, + ) -> float: + 
"""Compare classification accuracy between real and simulated perturbations. + + Trains a classifier on the real perturbation data + the control data and reports a normalized + classification accuracy on the simulated perturbation. + + Args: + real: Real perturbed data. + simulated: Simulated perturbed data. + control: Control data + clf: sklearn classifier to use, `sklearn.linear_model.LogisticRegression` if not provided. + """ + assert real.shape[1] == simulated.shape[1] == control.shape[1] + if clf is None: + clf = LogisticRegression() + n_x = real.shape[0] + data = sp_vstack((real, control)) if issparse(real) else np.vstack((real, control)) + labels = np.concatenate([np.full(real.shape[0], "comp"), np.full(control.shape[0], "ctrl")]) + + clf.fit(data, labels) + norm_score = clf.score(simulated, np.full(simulated.shape[0], "comp")) / clf.score(real, labels[:n_x]) + norm_score = min(1.0, norm_score) + + return norm_score + + def compare_knn( + self, + real: np.ndarray, + simulated: np.ndarray, + control: np.ndarray | None = None, + use_simulated_for_knn: bool = False, + n_neighbors: int = 20, + random_state: int = 0, + n_jobs: int = 1, + ) -> dict[str, float]: + """Calculate proportions of real perturbed and control data points for simulated data. + + Computes proportions of real perturbed, control and simulated (if `use_simulated_for_knn=True`) + data points for simulated data. If control (`C`) is not provided, builds the knn graph from + real perturbed + simulated perturbed. + + Args: + real: Real perturbed data. + simulated: Simulated perturbed data. + control: Control data + use_simulated_for_knn: Include simulted perturbed data (`simulated`) into the knn graph. Only valid when + control (`control`) is provided. + n_neighbors: Number of neighbors to use in k-neighbor graph. + random_state: Random state used for k-neighbor graph construction. + n_jobs: Number of cores to use. Defaults to -1 (all). 
+ + """ + assert real.shape[1] == simulated.shape[1] + if control is not None: + assert real.shape[1] == control.shape[1] + + n_y = simulated.shape[0] + + if control is None: + index_data = sp_vstack((simulated, real)) if issparse(real) else np.vstack((simulated, real)) + else: + datas = (simulated, real, control) if use_simulated_for_knn else (real, control) + index_data = sp_vstack(datas) if issparse(real) else np.vstack(datas) + + y_in_index = use_simulated_for_knn or control is None + c_in_index = control is not None + label_groups = ["comp"] + labels: NDArray[np.str_] = np.full(index_data.shape[0], "comp") + if y_in_index: + labels[:n_y] = "siml" + label_groups.append("siml") + if c_in_index: + labels[-control.shape[0] :] = "ctrl" + label_groups.append("ctrl") + + index = pynndescent.NNDescent( + index_data, + n_neighbors=max(50, n_neighbors), + random_state=random_state, + n_jobs=n_jobs, + ) + indices = index.query(simulated, k=n_neighbors)[0] + + uq, uq_counts = np.unique(labels[indices], return_counts=True) + uq_counts_norm = uq_counts / uq_counts.sum() + counts = dict(zip(label_groups, [0.0] * len(label_groups), strict=False)) + for group, count_norm in zip(uq, uq_counts_norm, strict=False): + counts[group] = count_norm + + return counts diff --git a/tests/tools/_differential_gene_expression/test_dge.py b/tests/tools/_differential_gene_expression/test_dge.py new file mode 100644 index 00000000..1f9d3eb8 --- /dev/null +++ b/tests/tools/_differential_gene_expression/test_dge.py @@ -0,0 +1,80 @@ +import numpy as np +import pandas as pd +import pertpy as pt +import pytest +from anndata import AnnData + + +@pytest.fixture +def adata(rng): + adata = AnnData(rng.normal(size=(100, 10))) + genes = np.rec.fromarrays( + [np.array([f"gene{i}" for i in range(10)])], + names=["group1", "O"], + ) + adata.uns["de_key1"] = { + "names": genes, + "scores": {"group1": rng.random(10)}, + "pvals_adj": {"group1": rng.random(10)}, + } + adata.uns["de_key2"] = { + "names": genes, + "scores": {"group1": rng.random(10)}, + "pvals_adj": {"group1": rng.random(10)}, + } + return adata + + +@pytest.fixture +def dataframe(rng): + df1 = pd.DataFrame( + { + "variable": ["gene" + str(i) for i in range(10)], + "log_fc": rng.random(10), + "adj_p_value": rng.random(10), + } + ) + df2 = pd.DataFrame( + { + "variable": ["gene" + str(i) for i in range(10)], + "log_fc": rng.random(10), + "adj_p_value": rng.random(10), + } + ) + return df1, df2 + + +def test_error_both_keys_and_dfs(adata, dataframe): + with pytest.raises(ValueError): + pt_DGE = pt.tl.DGEEVAL() + pt_DGE.compare(adata=adata, de_key1="de_key1", de_df1=dataframe[0]) + + +def test_error_missing_adata(): + with pytest.raises(ValueError): + pt_DGE = pt.tl.DGEEVAL() + pt_DGE.compare(de_key1="de_key1", de_key2="de_key2") + + +def test_error_missing_df(dataframe): + with pytest.raises(ValueError): + pt_DGE = pt.tl.DGEEVAL() + pt_DGE.compare(de_df1=dataframe[0]) + + +def test_key(adata): + pt_DGE = pt.tl.DGEEVAL() + results = pt_DGE.compare(adata=adata, de_key1="de_key1", de_key2="de_key2", shared_top=5) + assert "shared_top_genes" in results + assert "scores_corr" in results + assert "pvals_adj_corr" in results + assert "scores_ranks_corr" in results + + +def test_df(dataframe): + pt_DGE = pt.tl.DGEEVAL() + results = pt_DGE.compare(de_df1=dataframe[0], de_df2=dataframe[1], shared_top=5) + assert "shared_top_genes" in results + assert "scores_corr" in results + assert "pvals_adj_corr" in results + assert "scores_ranks_corr" in results diff --git 
a/tests/tools/_distances/test_distance_tests.py b/tests/tools/_distances/test_distance_tests.py index 22c3e93f..8e62154e 100644 --- a/tests/tools/_distances/test_distance_tests.py +++ b/tests/tools/_distances/test_distance_tests.py @@ -22,6 +22,8 @@ "classifier_proba", # "classifier_cp", # "nbll", + "mahalanobis", + "mean_var_distribution", ] count_distances = ["nb_ll"] diff --git a/tests/tools/_distances/test_distances.py b/tests/tools/_distances/test_distances.py index 95c81225..0f0e6fc4 100644 --- a/tests/tools/_distances/test_distances.py +++ b/tests/tools/_distances/test_distances.py @@ -4,6 +4,7 @@ import pytest import scanpy as sc from pandas import DataFrame, Series +from pytest import fixture, mark actual_distances = [ # Euclidean distances and related @@ -20,118 +21,188 @@ "spearman_distance", "t_test", "wasserstein", + "mahalanobis", ] semi_distances = ["r2_distance", "sym_kldiv", "ks_test"] non_distances = ["classifier_proba"] onesided_only = ["classifier_cp"] pseudo_counts_distances = ["nb_ll"] -all_distances = actual_distances + semi_distances + non_distances + pseudo_counts_distances # + onesided_only +lognorm_counts_distances = ["mean_var_distribution"] +all_distances = ( + actual_distances + semi_distances + non_distances + lognorm_counts_distances + pseudo_counts_distances +) # + onesided_only + + +@fixture +def adata(request): + low_subsample_distances = [ + "sym_kldiv", + "t_test", + "ks_test", + "classifier_proba", + "classifier_cp", + "mahalanobis", + "mean_var_distribution", + ] + no_subsample_distances = ["mahalanobis"] # mahalanobis only works on the full data without subsampling + + distance = request.node.callspec.params["distance"] - -@pytest.fixture -def all_pairwise_distances(): - all_calulated_distances = {} - no_subsample_distances = ["sym_kldiv", "t_test", "ks_test", "classifier_proba", "classifier_cp"] - - for distance in all_distances: - adata = pt.dt.distance_example() - if distance not in no_subsample_distances: - adata = sc.pp.subsample(adata, 0.001, copy=True) - else: + adata = pt.dt.distance_example() + if distance not in no_subsample_distances: + if distance in low_subsample_distances: adata = sc.pp.subsample(adata, 0.1, copy=True) - - adata.layers["lognorm"] = adata.X.copy() - adata.layers["counts"] = np.round(adata.X.toarray()).astype(int) - if "X_pca" not in adata.obsm.keys(): - sc.pp.pca(adata, n_comps=5) - - if distance in pseudo_counts_distances: - Distance = pt.tl.Distance(distance, layer_key="counts") else: - Distance = pt.tl.Distance(distance, obsm_key="X_pca") - df = Distance.pairwise(adata, groupby="perturbation", show_progressbar=True) - all_calulated_distances[distance] = df - - return all_calulated_distances + adata = sc.pp.subsample(adata, 0.001, copy=True) + adata.layers["lognorm"] = adata.X.copy() + adata.layers["counts"] = np.round(adata.X.toarray()).astype(int) + if "X_pca" not in adata.obsm.keys(): + sc.pp.pca(adata, n_comps=5) + if distance in lognorm_counts_distances: + groups = np.unique(adata.obs["perturbation"]) + # KDE is slow, subset to 3 groups for speed up + adata = adata[adata.obs["perturbation"].isin(groups[0:3])].copy() + + return adata + + +@fixture +def distance_obj(request): + distance = request.node.callspec.params["distance"] + if distance in lognorm_counts_distances: + Distance = pt.tl.Distance(distance, layer_key="lognorm") + elif distance in pseudo_counts_distances: + Distance = pt.tl.Distance(distance, layer_key="counts") + else: + Distance = pt.tl.Distance(distance, obsm_key="X_pca") + return Distance -def 
test_distance_axioms(all_pairwise_distances): - for distance in actual_distances + semi_distances: - # This is equivalent to testing for a semimetric, defined as fulfilling all axioms except triangle inequality. - df = all_pairwise_distances[distance] - # (M1) Definiteness - assert all(np.diag(df.values) == 0) # distance to self is 0 +@fixture +@mark.parametrize("distance", all_distances) +def pairwise_distance(adata, distance_obj, distance): + return distance_obj.pairwise(adata, groupby="perturbation", show_progressbar=True) - # (M2) Positivity - assert len(df) == np.sum(df.values == 0) # distance to other is not 0 - assert all(df.values.flatten() >= 0) # distance is non-negative - # (M3) Symmetry - assert np.sum(df.values - df.values.T) == 0 +@mark.parametrize("distance", actual_distances + semi_distances) +def test_distance_axioms(pairwise_distance, distance): + # This is equivalent to testing for a semimetric, defined as fulfilling all axioms except triangle inequality. + # (M1) Definiteness + assert all(np.diag(pairwise_distance.values) == 0) # distance to self is 0 + # (M2) Positivity + assert len(pairwise_distance) == np.sum(pairwise_distance.values == 0) # distance to other is not 0 + assert all(pairwise_distance.values.flatten() >= 0) # distance is non-negative -def test_triangle_inequality(all_pairwise_distances): - for distance in actual_distances: - # Test if distances are well-defined in accordance with metric axioms - df = all_pairwise_distances[distance] + # (M3) Symmetry + assert np.sum(pairwise_distance.values - pairwise_distance.values.T) == 0 - # (M4) Triangle inequality (we just probe this for a few random triplets) - for _i in range(10): - rng = np.random.default_rng() - triplet = rng.choice(df.index, size=3, replace=False) - assert df.loc[triplet[0], triplet[1]] + df.loc[triplet[1], triplet[2]] >= df.loc[triplet[0], triplet[2]] +@mark.parametrize("distance", actual_distances) +def test_triangle_inequality(pairwise_distance, distance, rng): + # Test if distances are well-defined in accordance with metric axioms + # (M4) Triangle inequality (we just probe this for a few random triplets) + for _i in range(10): + triplet = rng.choice(pairwise_distance.index, size=3, replace=False) + assert ( + pairwise_distance.loc[triplet[0], triplet[1]] + pairwise_distance.loc[triplet[1], triplet[2]] + >= pairwise_distance.loc[triplet[0], triplet[2]] + ) -def test_distance_layers(all_pairwise_distances): - for distance in all_distances: - df = all_pairwise_distances[distance] - assert isinstance(df, DataFrame) - assert df.columns.equals(df.index) - assert np.sum(df.values - df.values.T) == 0 # symmetry +@mark.parametrize("distance", all_distances) +def test_distance_layers(pairwise_distance, distance): + assert isinstance(pairwise_distance, DataFrame) + assert pairwise_distance.columns.equals(pairwise_distance.index) + assert np.sum(pairwise_distance.values - pairwise_distance.values.T) == 0 # symmetry -def test_distance_counts(all_pairwise_distances): - for distance in actual_distances + pseudo_counts_distances: - df = all_pairwise_distances[distance] +@mark.parametrize("distance", actual_distances + pseudo_counts_distances) +def test_distance_counts(adata, distance): + if distance != "mahalanobis": # skip, doesn't work because covariance matrix is a singular matrix, not invertible + Distance = pt.tl.Distance(distance, layer_key="counts") + df = Distance.pairwise(adata, groupby="perturbation") assert isinstance(df, DataFrame) assert df.columns.equals(df.index) assert np.sum(df.values - 
df.values.T) == 0 -def test_mutually_exclusive_keys(): - for distance in all_distances: - with pytest.raises(ValueError): - _ = pt.tl.Distance(distance, layer_key="counts", obsm_key="X_pca") +@mark.parametrize("distance", all_distances) +def test_mutually_exclusive_keys(distance): + with pytest.raises(ValueError): + _ = pt.tl.Distance(distance, layer_key="counts", obsm_key="X_pca") + + +@mark.parametrize("distance", actual_distances + semi_distances + non_distances) +def test_distance_output_type(distance, rng): + # Test if distances are outputting floats + Distance = pt.tl.Distance(distance) + X = rng.normal(size=(50, 10)) + Y = rng.normal(size=(50, 10)) + d = Distance(X, Y) + assert isinstance(d, float) + +@mark.parametrize("distance", all_distances + onesided_only) +def test_distance_onesided(adata, distance_obj, distance): + # Test consistency of one-sided distance results + selected_group = adata.obs.perturbation.unique()[0] + df = distance_obj.onesided_distances(adata, groupby="perturbation", selected_group=selected_group) + assert isinstance(df, Series) + assert df.loc[selected_group] == 0 # distance to self is 0 -def test_distance_output_type(all_pairwise_distances): + +def test_bootstrap_distance_output_type(rng): # Test if distances are outputting floats - for distance in all_distances: - df = all_pairwise_distances[distance] - assert df.apply(lambda col: pd.api.types.is_float_dtype(col)).all(), "Not all values are floats." + Distance = pt.tl.Distance(metric="edistance") + X = rng.normal(size=(50, 10)) + Y = rng.normal(size=(50, 10)) + d = Distance.bootstrap(X, Y, n_bootstrap=3) + assert hasattr(d, "mean") + assert hasattr(d, "variance") -def test_distance_pairwise(all_pairwise_distances): +@mark.parametrize("distance", ["edistance"]) +def test_bootstrap_distance_pairwise(adata, distance): # Test consistency of pairwise distance results - for distance in all_distances: - df = all_pairwise_distances[distance] + Distance = pt.tl.Distance(distance, obsm_key="X_pca") + bootstrap_output = Distance.pairwise(adata, groupby="perturbation", bootstrap=True, n_bootstrap=3) - assert isinstance(df, DataFrame) - assert df.columns.equals(df.index) - assert np.sum(df.values - df.values.T) == 0 # symmetry + assert isinstance(bootstrap_output, tuple) + mean = bootstrap_output[0] + var = bootstrap_output[1] -def test_distance_onesided(): - # Test consistency of one-sided distance results - adata = pt.dt.distance_example() - adata = sc.pp.subsample(adata, 0.1, copy=True) - selected_group = adata.obs.perturbation.unique()[0] + assert mean.columns.equals(mean.index) + assert np.sum(mean.values - mean.values.T) == 0 # symmetry + assert np.sum(var.values - var.values.T) == 0 # symmetry - for distance in onesided_only: - Distance = pt.tl.Distance(distance, obsm_key="X_pca") - df = Distance.onesided_distances(adata, groupby="perturbation", selected_group=selected_group) - assert isinstance(df, Series) - assert df.loc[selected_group] == 0 # distance to self is 0 +@mark.parametrize("distance", ["edistance"]) +def test_bootstrap_distance_onesided(adata, distance): + # Test consistency of one-sided distance results + selected_group = adata.obs.perturbation.unique()[0] + Distance = pt.tl.Distance(distance, obsm_key="X_pca") + bootstrap_output = Distance.onesided_distances( + adata, + groupby="perturbation", + selected_group=selected_group, + bootstrap=True, + n_bootstrap=3, + ) + + assert isinstance(bootstrap_output, tuple) + + +def test_compare_distance(rng): + X = rng.normal(size=(50, 10)) + Y = 
rng.normal(size=(50, 10)) + C = rng.normal(size=(50, 10)) + Distance = pt.tl.Distance() + res_simple = Distance.compare_distance(X, Y, C, mode="simple") + assert isinstance(res_simple, float) + res_scaled = Distance.compare_distance(X, Y, C, mode="scaled") + assert isinstance(res_scaled, float) + with pytest.raises(ValueError): + Distance.compare_distance(X, Y, C, mode="new_mode") diff --git a/tests/tools/_perturbation_space/test_comparison.py b/tests/tools/_perturbation_space/test_comparison.py new file mode 100644 index 00000000..1751a5f1 --- /dev/null +++ b/tests/tools/_perturbation_space/test_comparison.py @@ -0,0 +1,29 @@ +import pertpy as pt +import pytest + + +@pytest.fixture +def test_data(rng): + X = rng.normal(size=(100, 10)) + Y = rng.normal(size=(100, 10)) + C = rng.normal(size=(100, 10)) + return X, Y, C + + +def test_compare_class(test_data): + X, Y, C = test_data + pt_comparison = pt.tl.PerturbationComparison() + result = pt_comparison.compare_classification(X, Y, C) + assert result <= 1 + + +def test_compare_knn(test_data): + X, Y, C = test_data + pt_comparison = pt.tl.PerturbationComparison() + result = pt_comparison.compare_knn(X, Y, C) + assert isinstance(result, dict) + assert "comp" in result + assert isinstance(result["comp"], float) + + result_no_ctrl = pt_comparison.compare_knn(X, Y) + assert isinstance(result_no_ctrl, dict)
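For reviewers trying out this branch, the snippet below is a minimal usage sketch, not part of the patch itself. It assumes a pertpy build that includes this PR, reuses the pt.dt.distance_example() dataset and the "p-sgCREB1-2"/"control" groups from the docstrings above, and fabricates random DE tables and perturbation matrices (mirroring the test fixtures) purely for illustration.

import numpy as np
import pandas as pd
import pertpy as pt
import scanpy as sc

rng = np.random.default_rng(0)

# Bootstrapped e-distance between two perturbation groups.
# Distance.bootstrap returns the MeanVar namedtuple introduced in _distances.py.
adata = pt.dt.distance_example()
sc.pp.subsample(adata, 0.1)  # subsample in place to keep runtime manageable
distance = pt.tl.Distance(metric="edistance", obsm_key="X_pca")
X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]
res = distance.bootstrap(X, Y, n_bootstrap=50)
print(f"e-distance: {res.mean:.3f} (variance: {res.variance:.3f})")

# With bootstrap=True, pairwise returns two DataFrames: means and variances.
means, variances = distance.pairwise(adata, groupby="perturbation", bootstrap=True, n_bootstrap=50)

# DGEEVAL: compare two DE result tables following the pertpy DE interface
# columns ('variable', 'log_fc', 'adj_p_value'); random values for illustration only.
genes = [f"gene{i}" for i in range(10)]
de_df1 = pd.DataFrame({"variable": genes, "log_fc": rng.random(10), "adj_p_value": rng.random(10)})
de_df2 = pd.DataFrame({"variable": genes, "log_fc": rng.random(10), "adj_p_value": rng.random(10)})
metrics = pt.tl.DGEEVAL().compare(de_df1=de_df1, de_df2=de_df2, shared_top=5)
print(metrics["shared_top_genes"], metrics["scores_corr"])

# Three-way comparison of real vs. simulated perturbations against control.
real, simulated, control = (rng.normal(size=(100, 10)) for _ in range(3))
comparison = pt.tl.PerturbationComparison()
print(comparison.compare_classification(real, simulated, control))  # normalized accuracy <= 1
print(comparison.compare_knn(real, simulated, control))  # dict of neighborhood proportions

Note that pairwise(..., bootstrap=True) reports a bootstrapped mean on the diagonal as well, which can be non-zero, unlike the exact computation where the self-distance is 0 by the distance axiom.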