Skip to content

Commit

Permalink
Merge pull request #34 from OpenCOMPES/crop_tool_kwds
Browse files Browse the repository at this point in the history
pass kwd arguments to convert_image in the crop tool
  • Loading branch information
rettigl authored Apr 19, 2024
2 parents e13a19e + 73ab04f commit 0e18ac1
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 78 deletions.
177 changes: 108 additions & 69 deletions specsanalyzer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ def crop_tool(
- ek_range_max
- ang_range_min
- ang_range_max
Other parameters are passed to ``convert_image()``.
"""
data_array = self.convert_image(
raw_img=raw_img,
Expand All @@ -390,6 +392,7 @@ def crop_tool(
pass_energy=pass_energy,
work_function=work_function,
crop=False,
**kwds,
)

matplotlib.use("module://ipympl.backend_nbagg")
Expand Down Expand Up @@ -581,31 +584,29 @@ def fft_tool(
apply (bool, optional): Option to directly apply the settings. Defaults to False.
**kwds: Keyword arguments:
- fft_tool_params (dict): Dictionary of parameters for fft_tool, containing keys
`amplitude`: Normalized amplitude of subtraction
`pos_x`: horzontal spatial frequency of th mesh
`pos_y`: vertical spatial frequency of the mesh
`sigma_x`: horizontal frequency width
`sigma_y`: vertical frequency width
- `amplitude`: Normalized amplitude of subtraction
- `pos_x`: horzontal spatial frequency of th mesh
- `pos_y`: vertical spatial frequency of the mesh
- `sigma_x`: horizontal frequency width
- `sigma_y`: vertical frequency width
"""
matplotlib.use("module://ipympl.backend_nbagg")
try:
fft_tool_params = (
kwds["fft_tool_params"]
if "fft_tool_params" in kwds
else self._correction_matrix_dict["fft_tool_params"]
)
(amp, pos_x, pos_y, sig_x, sig_y) = (
fft_tool_params["amplitude"],
fft_tool_params["pos_x"],
fft_tool_params["pos_y"],
fft_tool_params["sigma_x"],
fft_tool_params["sigma_y"],
)
except KeyError:
(amp, pos_x, pos_y, sig_x, sig_y) = (0.95, 86, 116, 13, 22)
stored_parameters = self._correction_matrix_dict.get("fft_tool_params", {})
if not stored_parameters:
stored_parameters = {
"amplitude": 0.95,
"pos_x": 86,
"pos_y": 116,
"sigma_x": 13,
"sigma_y": 22,
}
amplitude = kwds.get("amplitude", stored_parameters["amplitude"])
pos_x = kwds.get("pos_x", stored_parameters["pos_x"])
pos_y = kwds.get("pos_y", stored_parameters["pos_y"])
sigma_x = kwds.get("sigma_x", stored_parameters["sigma_x"])
sigma_y = kwds.get("sigma_y", stored_parameters["sigma_y"])

fft_filter_peaks = create_fft_params(amp, pos_x, pos_y, sig_x, sig_y)
fft_filter_peaks = create_fft_params(amplitude, pos_x, pos_y, sigma_x, sigma_y)
try:
img = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="fft")
fft_filtered = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="filtered_fft")
Expand Down Expand Up @@ -646,37 +647,37 @@ def fft_tool(
ax5.legend()
# plt.tight_layout()

posx_slider = ipw.FloatSlider(
pos_x_slider = ipw.FloatSlider(
description="pos_x",
value=pos_x,
min=0,
max=128,
step=1,
)
posy_slider = ipw.FloatSlider(
pos_y_slider = ipw.FloatSlider(
description="pos_y",
value=pos_y,
min=0,
max=150,
step=1,
)
sigx_slider = ipw.FloatSlider(
sigma_x_slider = ipw.FloatSlider(
description="sig_x",
value=sig_x,
value=sigma_x,
min=0,
max=50,
step=1,
)
sigy_slider = ipw.FloatSlider(
sigma_y_slider = ipw.FloatSlider(
description="sig_y",
value=sig_y,
value=sigma_y,
min=0,
max=50,
step=1,
)
amp_slider = ipw.FloatSlider(
amplitude_slider = ipw.FloatSlider(
description="Amplitude",
value=amp,
value=amplitude,
min=0,
max=1,
step=0.01,
Expand All @@ -689,8 +690,8 @@ def fft_tool(
max=int(np.log10(np.abs(img).max())) + 1,
)

def update(v_vals, pos_x, pos_y, sig_x, sig_y, amp):
fft_filter_peaks = create_fft_params(amp, pos_x, pos_y, sig_x, sig_y)
def update(v_vals, pos_x, pos_y, sigma_x, sigma_y, amplitude):
fft_filter_peaks = create_fft_params(amplitude, pos_x, pos_y, sigma_x, sigma_y)
msk = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="mask")
filtered_new = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="filtered")

Expand Down Expand Up @@ -718,39 +719,39 @@ def update(v_vals, pos_x, pos_y, sig_x, sig_y, amp):

ipw.interact(
update,
amp=amp_slider,
pos_x=posx_slider,
pos_y=posy_slider,
sig_x=sigx_slider,
sig_y=sigy_slider,
amplitude=amplitude_slider,
pos_x=pos_x_slider,
pos_y=pos_y_slider,
sigma_x=sigma_x_slider,
sigma_y=sigma_y_slider,
v_vals=clim_slider,
)

def apply_fft(apply: bool): # pylint: disable=unused-argument
amp = amp_slider.value
pos_x = posx_slider.value
pos_y = posy_slider.value
sig_x = sigx_slider.value
sig_y = sigy_slider.value
amplitude = amplitude_slider.value
pos_x = pos_x_slider.value
pos_y = pos_y_slider.value
sigma_x = sigma_x_slider.value
sigma_y = sigma_y_slider.value
self._correction_matrix_dict["fft_tool_params"] = {
"amplitude": amp,
"amplitude": amplitude,
"pos_x": pos_x,
"pos_y": pos_y,
"sigma_x": sig_x,
"sigma_y": sig_y,
"sigma_x": sigma_x,
"sigma_y": sigma_y,
}
self.config["fft_filter_peaks"] = create_fft_params(
amp,
amplitude,
pos_x,
pos_y,
sig_x,
sig_y,
sigma_x,
sigma_y,
)
amp_slider.close()
posx_slider.close()
posy_slider.close()
sigx_slider.close()
sigy_slider.close()
amplitude_slider.close()
pos_x_slider.close()
pos_y_slider.close()
sigma_x_slider.close()
sigma_y_slider.close()
clim_slider.close()
apply_button.close()

Expand All @@ -762,25 +763,63 @@ def apply_fft(apply: bool): # pylint: disable=unused-argument
apply_fft(True)


def create_fft_params(amp, pos_x, pos_y, sig_x, sig_y) -> list[dict]:
"""Function to create fft filter peaks list using the
provided Gaussian peak parameters. The peaks are defined
relative to each other such that they are periodically
aranged in a 256 x 150 Fourier space.
def create_fft_params(
amplitude: float,
pos_x: float,
pos_y: float,
sigma_x: float,
sigma_y: float,
) -> list[dict]:
"""Function to create fft filter peaks list using the provided Gaussian peak parameters.
The peaks are placed at +-x, y=0, and +-x, y=y, with width corresponding to the sigma
values.
Args:
amp: Gaussian peak amplitude
pos_x: x-position
pos_y: y-position
sig_x: FWHM in x-axis
sig_y: FWHM in y-axis
amplitude (float): Gaussian peak amplitude
pos_x (float): horizontal spatial frequency
pos_y (float): vertical spatial frequency
sigma_x (float): horizontal width
sigma_y (float): vertical width
Returns:
list[dict]: A list of the defined filter parameters
"""

fft_filter_peaks = [
{"amplitude": amp, "pos_x": -pos_x, "pos_y": 0, "sigma_x": sig_x, "sigma_y": sig_y},
{"amplitude": amp, "pos_x": pos_x, "pos_y": 0, "sigma_x": sig_x, "sigma_y": sig_y},
{"amplitude": amp, "pos_x": 0, "pos_y": pos_y, "sigma_x": sig_x, "sigma_y": sig_y},
{"amplitude": amp, "pos_x": -pos_x, "pos_y": pos_y, "sigma_x": sig_x, "sigma_y": sig_y},
{"amplitude": amp, "pos_x": pos_x, "pos_y": pos_y, "sigma_x": sig_x, "sigma_y": sig_y},
{
"amplitude": amplitude,
"pos_x": -pos_x,
"pos_y": 0,
"sigma_x": sigma_x,
"sigma_y": sigma_y,
},
{
"amplitude": amplitude,
"pos_x": pos_x,
"pos_y": 0,
"sigma_x": sigma_x,
"sigma_y": sigma_y,
},
{
"amplitude": amplitude,
"pos_x": 0,
"pos_y": pos_y,
"sigma_x": sigma_x,
"sigma_y": sigma_y,
},
{
"amplitude": amplitude,
"pos_x": -pos_x,
"pos_y": pos_y,
"sigma_x": sigma_x,
"sigma_y": sigma_y,
},
{
"amplitude": amplitude,
"pos_x": pos_x,
"pos_y": pos_y,
"sigma_x": sigma_x,
"sigma_y": sigma_y,
},
]

return fft_filter_peaks
16 changes: 16 additions & 0 deletions specsscan/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,22 @@ def crop_tool(self, scan: int = None, path: Path | str = "", **kwds):
)

def fft_tool(self, scan: int = None, path: Path | str = "", **kwds):
"""FFT tool to play around with the peak parameters in the Fourier plane. Built to filter
out the meshgrid appearing in the raw data images. The optimized parameters are stored in
the class config dict under fft_filter_peaks.
Args:
scan (int, optional): Scan number to load. Defaults to the previously loaded scan.
path (Path | str): Path from where to load the data. Defaults to config value.
**kwds: Keyword arguments passed to ``SpecsAnalyzer.fft_tool()``:
- `apply`: Option to directly apply the settings.
- `amplitude`: Normalized amplitude of subtraction
- `pos_x`: horzontal spatial frequency of th mesh
- `pos_y`: vertical spatial frequency of the mesh
- `sigma_x`: horizontal frequency width
- `sigma_y`: vertical frequency width
"""
matplotlib.use("module://ipympl.backend_nbagg")
if scan is not None:
scan_path = get_scan_path(path, scan, self._config["data_path"])
Expand Down
8 changes: 6 additions & 2 deletions tests/test_specsscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

package_dir = os.path.dirname(specsscan.__file__)
test_dir = package_dir + "/../tests/data/"
fft_filter_peaks = create_fft_params(amp=1, pos_x=82, pos_y=116, sig_x=15, sig_y=23)
fft_filter_peaks = create_fft_params(amplitude=1, pos_x=82, pos_y=116, sigma_x=15, sigma_y=23)


def test_version():
Expand Down Expand Up @@ -289,7 +289,11 @@ def test_fft_tool():
np.testing.assert_almost_equal(res_xarray.data.sum(), 62197237155.50347, decimal=3)

sps.fft_tool(
fft_tool_params={"amplitude": 1, "pos_x": 82, "pos_y": 116, "sigma_x": 15, "sigma_y": 23},
amplitude=1,
pos_x=82,
pos_y=116,
sigma_x=15,
sigma_y=23,
apply=True,
)
assert sps.config["spa_params"]["fft_filter_peaks"] == fft_filter_peaks
Expand Down
12 changes: 5 additions & 7 deletions tutorial/2_specsscan_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,11 @@
"outputs": [],
"source": [
"sps.fft_tool(\n",
" fft_tool_params={\n",
" \"amplitude\": 1,\n",
" \"pos_x\": 82,\n",
" \"pos_y\": 116,\n",
" \"sigma_x\": 15,\n",
" \"sigma_y\": 23\n",
" },\n",
" amplitude=1,\n",
" pos_x=82,\n",
" pos_y=116,\n",
" sigma_x=15,\n",
" sigma_y=23,\n",
" apply=True # Use apply=False for interactive mode\n",
")"
]
Expand Down

0 comments on commit 0e18ac1

Please sign in to comment.