From 6f5f0333a435e79bc7b1ee814c1a28a7bdf59857 Mon Sep 17 00:00:00 2001 From: rettigl Date: Tue, 15 Oct 2024 22:54:43 +0200 Subject: [PATCH] use matplotlib also for momentum correction --- .cspell/custom-dictionary.txt | 2 ++ src/sed/calibrator/momentum.py | 22 ++++++++++++++-------- src/sed/core/processor.py | 20 ++++++++++++-------- src/sed/diagnostics.py | 23 +++++++++++++---------- 4 files changed, 41 insertions(+), 26 deletions(-) diff --git a/.cspell/custom-dictionary.txt b/.cspell/custom-dictionary.txt index 38f8c89c..88ff0b12 100644 --- a/.cspell/custom-dictionary.txt +++ b/.cspell/custom-dictionary.txt @@ -404,6 +404,7 @@ xpos xratio xrng xscale +xticks xtrans Xuser xval @@ -414,6 +415,7 @@ ylabel ypos yratio yscale +yticks ytrans zain Zenodo diff --git a/src/sed/calibrator/momentum.py b/src/sed/calibrator/momentum.py index 36a41b86..534ad81f 100644 --- a/src/sed/calibrator/momentum.py +++ b/src/sed/calibrator/momentum.py @@ -1348,7 +1348,9 @@ def view( if annotated: tsr, tsc = kwds.pop("textshift", (3, 3)) - txtsize = kwds.pop("textsize", 12) + txtsize = kwds.pop("textsize", 10) + + title = kwds.pop("title", "") # Handle unexpected kwds: handled_kwds = {"figsize"} @@ -1358,7 +1360,7 @@ def view( ) if backend == "matplotlib": - fig_plt, ax = plt.subplots(figsize=figsize) + _, ax = plt.subplots(figsize=figsize) ax.imshow(image.T, origin=origin, cmap=cmap, **imkwds) if cross: @@ -1368,15 +1370,12 @@ def view( # Add annotation to the figure if annotated: - for ( - p_keys, # pylint: disable=unused-variable - p_vals, - ) in points.items(): + for p_keys, p_vals in points.items(): try: - ax.scatter(p_vals[:, 0], p_vals[:, 1], **scatterkwds) + ax.scatter(p_vals[:, 0], p_vals[:, 1], s=15, **scatterkwds) except IndexError: try: - ax.scatter(p_vals[0], p_vals[1], **scatterkwds) + ax.scatter(p_vals[0], p_vals[1], s=15, **scatterkwds) except IndexError: pass @@ -1389,6 +1388,13 @@ def view( fontsize=txtsize, ) + if crosshair and self.pcent is not None: + for radius in crosshair_radii: + circle = plt.Circle(self.pcent, radius, color="k", fill=False) + ax.add_patch(circle) + + ax.set_title(title) + elif backend == "bokeh": output_notebook(hide_banner=True) colors = it.cycle(ColorCycle[10]) diff --git a/src/sed/core/processor.py b/src/sed/core/processor.py index 6c51acfb..5888855d 100644 --- a/src/sed/core/processor.py +++ b/src/sed/core/processor.py @@ -649,24 +649,28 @@ def generate_splinewarp( self.mc.spline_warp_estimate(use_center=use_center, **kwds) if self.mc.slice is not None and self._verbose: - print("Original slice with reference features") - self.mc.view(annotated=True, backend="bokeh", crosshair=True) + self.mc.view( + annotated=True, + backend="matplotlib", + crosshair=True, + title="Original slice with reference features", + ) - print("Corrected slice with target features") self.mc.view( image=self.mc.slice_corrected, annotated=True, points={"feats": self.mc.ptargs}, - backend="bokeh", + backend="matplotlib", crosshair=True, + title="Corrected slice with target features", ) - print("Original slice with target features") self.mc.view( image=self.mc.slice, points={"feats": self.mc.ptargs}, annotated=True, - backend="bokeh", + backend="matplotlib", + title="Original slice with target features", ) # 3a. Save spline-warp parameters to config file. @@ -2384,7 +2388,7 @@ def view_event_histogram( bins: Sequence[int] = None, axes: Sequence[str] = None, ranges: Sequence[tuple[float, float]] = None, - backend: str = "bokeh", + backend: str = "matplotlib", legend: bool = True, histkwds: dict = None, legkwds: dict = None, @@ -2403,7 +2407,7 @@ def view_event_histogram( ranges (Sequence[tuple[float, float]], optional): Value ranges of all specified axes. Defaults to config["histogram"]["ranges"]. backend (str, optional): Backend of the plotting library - ('matplotlib' or 'bokeh'). Defaults to "bokeh". + ("matplotlib" or "bokeh"). Defaults to "matplotlib". legend (bool, optional): Option to include a legend in the histogram plots. Defaults to True. histkwds (dict, optional): Keyword arguments for histograms diff --git a/src/sed/diagnostics.py b/src/sed/diagnostics.py index 25fddd9a..6ec8e507 100644 --- a/src/sed/diagnostics.py +++ b/src/sed/diagnostics.py @@ -59,7 +59,7 @@ def grid_histogram( rvs: Sequence, rvbins: Sequence, rvranges: Sequence[tuple[float, float]], - backend: str = "bokeh", + backend: str = "matplotlib", legend: bool = True, histkwds: dict = None, legkwds: dict = None, @@ -73,22 +73,22 @@ def grid_histogram( rvs (Sequence): List of names for the random variables (rvs). rvbins (Sequence): Bin values for all random variables. rvranges (Sequence[tuple[float, float]]): Value ranges of all random variables. - backend (str, optional): Backend for making the plot ('matplotlib' or 'bokeh'). - Defaults to "bokeh". + backend (str, optional): Backend for making the plot ("matplotlib" or "bokeh"). + Defaults to "matplotlib". legend (bool, optional): Option to include a legend in each histogram plot. Defaults to True. histkwds (dict, optional): Keyword arguments for histogram plots. Defaults to None. legkwds (dict, optional): Keyword arguments for legends. Defaults to None. **kwds: - - *figsize*: Figure size. Defaults to (14, 8) + - *figsize*: Figure size. Defaults to (6, 4) """ if histkwds is None: histkwds = {} if legkwds is None: legkwds = {} - figsz = kwds.pop("figsize", (10, 7)) + figsz = kwds.pop("figsize", (6, 4)) if len(kwds) > 0: raise TypeError(f"grid_histogram() got unexpected keyword arguments {kwds.keys()}.") @@ -96,7 +96,7 @@ def grid_histogram( if backend == "matplotlib": nrv = len(rvs) nrow = int(np.ceil(nrv / ncol)) - histtype = kwds.pop("histtype", "step") + histtype = kwds.pop("histtype", "bar") fig, ax = plt.subplots(nrow, ncol, figsize=figsz) otherax = ax.copy() @@ -114,7 +114,7 @@ def grid_histogram( **histkwds, ) if legend: - ax[axind].legend(fontsize=15, **legkwds) + ax[axind].legend(fontsize=10, **legkwds) otherax[axind] = None @@ -128,13 +128,16 @@ def grid_histogram( **histkwds, ) if legend: - ax[i].legend(fontsize=15, **legkwds) + ax[i].legend(fontsize=10, **legkwds) otherax[i] = None for oax in otherax.flatten(): if oax is not None: fig.delaxes(oax) + plt.xticks(fontsize=8) + plt.yticks(fontsize=8) + plt.tight_layout() elif backend == "bokeh": output_notebook(hide_banner=True) @@ -163,7 +166,7 @@ def grid_histogram( gridplot( plots, # type: ignore ncols=ncol, - width=figsz[0] * 30, - height=figsz[1] * 28, + width=figsz[0] * 100 // ncol, + height=figsz[1] * 100 // (len(plots) // ncol), ), )