Skip to content

Commit

Permalink
Add PerformanceComparisonPlot (#477)
Browse files Browse the repository at this point in the history
  • Loading branch information
klieret authored Dec 20, 2023
1 parent b14fca7 commit 960f83f
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
55 changes: 54 additions & 1 deletion src/gnn_tracking/analysis/efficiencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
df (pd.DataFrame): Dataframe with values. Errors should be in columns named with suffix ``_err``.
df_ul (_type_, optional): Dataframe with values for upper limit. Defaults to None.
x_label (regexp, optional): x label
y_label (str, optional): y abel
y_label (str, optional): y label
**kwargs: Passed to `Plot`
"""
super().__init__(**kwargs)
Expand Down Expand Up @@ -130,3 +130,56 @@ def add_legend(self, **kwargs) -> None:
all_handles = [item[0] for item in self._legend_items]
all_labels = [item[1] for item in self._legend_items]
self.ax.legend(all_handles, all_labels, **kwargs)


class PerformanceComparisonPlot(Plot):
def __init__(
self,
xs: np.ndarray,
var: str,
x_label: str,
ylabel: str = "Efficiency",
**kwargs,
):
"""Similar to `PerforamncePlot`, except that we use the same x axis for
plots of different models (and supply the dataframes directly to `plot_var`).
Args:
xs (np.ndarray): x values (e.g., pt or eta). Length must be one longer than the dataframe
to account for bin edges.
var (str): Name of variable
x_label (regexp, optional): x label
y_label (str, optional): y label
**kwargs: Passed to `Plot`
"""
super().__init__(**kwargs)
self.xs = xs
self.var = var
self.ax.set_xlabel(x_label)
self.ax.set_ylabel(ylabel)
self._legend_items = []

def plot_var(self, df: pd.DataFrame, label: str, color: str) -> None:
stairs = self.ax.stairs(df[self.var], edges=self.xs, color=color, lw=1.5)
mids = (self.xs[:-1] + self.xs[1:]) / 2
bar = self.ax.errorbar(
mids,
self.var,
yerr=f"{self.var}_err",
ls="none",
color=color,
data=df,
)
self._legend_items.append(((stairs, bar), label))

def add_legend(self, **kwargs) -> None:
all_handles = [item[0] for item in self._legend_items]
all_labels = [item[1] for item in self._legend_items]
self.ax.legend(all_handles, all_labels, **kwargs)

def add_blocked(self, a, b, label="Not trained for") -> None:
"""Used to mark low pt as "not trained for"."""
span = self.ax.axvspan(
a, b, alpha=0.3, color="gray", label=label, linestyle="none"
)
self._legend_items.append(((span,), label))
20 changes: 19 additions & 1 deletion tests/test_efficiency_plots.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import numpy as np
import pandas as pd

from gnn_tracking.analysis.efficiencies import PerformancePlot, TracksVsDBSCANPlot
from gnn_tracking.analysis.efficiencies import (
PerformanceComparisonPlot,
PerformancePlot,
TracksVsDBSCANPlot,
)


def test_track_vs_dbscan_parameters():
Expand Down Expand Up @@ -33,3 +37,17 @@ def test_performance_plot():
p.add_blocked(1, 2)
p.plot_var("test", color="red")
p.add_legend()


def test_performance_comparison_plot():
df = pd.DataFrame(
{
"test": [1, 0.5, 0.25],
"test_err": [0.5, 0.25, 0.125],
}
)
pt = [1, 2, 3, 4]
p = PerformanceComparisonPlot(var="test", x_label="test", xs=np.array(pt))
p.add_blocked(1, 2)
p.plot_var(df, color="red", label="test")
p.add_legend()

0 comments on commit 960f83f

Please sign in to comment.