Skip to content

Commit

Permalink
Merge pull request #27 from OpenCOMPES/fft-tool
Browse files Browse the repository at this point in the history
FFT tool interface
  • Loading branch information
rettigl authored Apr 19, 2024
2 parents 88c8891 + 7382d01 commit 046bd3f
Show file tree
Hide file tree
Showing 9 changed files with 425 additions and 25 deletions.
219 changes: 219 additions & 0 deletions specsanalyzer/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,222 @@ def cropit(val): # pylint: disable=unused-argument
plt.show()
if apply:
cropit("")

def fft_tool(
self,
raw_image: np.ndarray,
apply: bool = False,
**kwds,
):
"""FFT tool to play around with the peak parameters in the Fourier plane. Built to filter
out the meshgrid appearing in the raw data images. The optimized parameters are stored in
the class config dict under fft_filter_peaks.
Args:
raw_image (np.ndarray): The source image
apply (bool, optional): Option to directly apply the settings. Defaults to False.
**kwds: Keyword arguments:
- fft_tool_params (dict): Dictionary of parameters for fft_tool, containing keys
`amplitude`: Normalized amplitude of subtraction
`pos_x`: horzontal spatial frequency of th mesh
`pos_y`: vertical spatial frequency of the mesh
`sigma_x`: horizontal frequency width
`sigma_y`: vertical frequency width
"""
matplotlib.use("module://ipympl.backend_nbagg")
try:
fft_tool_params = (
kwds["fft_tool_params"]
if "fft_tool_params" in kwds
else self._correction_matrix_dict["fft_tool_params"]
)
(amp, pos_x, pos_y, sig_x, sig_y) = (
fft_tool_params["amplitude"],
fft_tool_params["pos_x"],
fft_tool_params["pos_y"],
fft_tool_params["sigma_x"],
fft_tool_params["sigma_y"],
)
except KeyError:
(amp, pos_x, pos_y, sig_x, sig_y) = (0.95, 86, 116, 13, 22)

fft_filter_peaks = create_fft_params(amp, pos_x, pos_y, sig_x, sig_y)
try:
img = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="fft")
fft_filtered = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="filtered_fft")

mask = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="mask")

filtered = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="filtered")
except IndexError:
print("Load the scan first!")
raise

fig = plt.figure()
ax = fig.add_subplot(3, 2, 1)
im_fft = ax.imshow(np.abs(img).T, origin="lower", aspect="auto")
fig.colorbar(im_fft)

ax.set_title("FFT")
cont = ax.contour(mask.T)

# Plot raw image
ax2 = fig.add_subplot(3, 2, 2)
fft_filt = ax2.imshow(np.abs(fft_filtered).T, origin="lower", aspect="auto")
ax2.set_title("Filtered FFT")
fig.colorbar(fft_filt)

# Plot fft filtered image
ax3 = fig.add_subplot(2, 2, 3)
filt = ax3.imshow(filtered.T, origin="lower", aspect="auto")
ax3.set_title("Filtered Image")
fig.colorbar(filt)

ax4 = fig.add_subplot(3, 2, 4)
(edc,) = ax4.plot(np.sum(filtered, 0), label="EDC")
ax4.legend()

ax5 = fig.add_subplot(3, 2, 6)
(mdc,) = ax5.plot(np.sum(filtered, 1), label="MDC")
ax5.legend()
# plt.tight_layout()

posx_slider = ipw.FloatSlider(
description="pos_x",
value=pos_x,
min=0,
max=128,
step=1,
)
posy_slider = ipw.FloatSlider(
description="pos_y",
value=pos_y,
min=0,
max=150,
step=1,
)
sigx_slider = ipw.FloatSlider(
description="sig_x",
value=sig_x,
min=0,
max=50,
step=1,
)
sigy_slider = ipw.FloatSlider(
description="sig_y",
value=sig_y,
min=0,
max=50,
step=1,
)
amp_slider = ipw.FloatSlider(
description="Amplitude",
value=amp,
min=0,
max=1,
step=0.01,
)
clim_slider = ipw.FloatLogSlider(
description="colorbar limits",
value=int(np.abs(img).max() / 500),
base=10,
min=-1,
max=int(np.log10(np.abs(img).max())) + 1,
)

def update(v_vals, pos_x, pos_y, sig_x, sig_y, amp):
fft_filter_peaks = create_fft_params(amp, pos_x, pos_y, sig_x, sig_y)
msk = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="mask")
filtered_new = fourier_filter_2d(raw_image, peaks=fft_filter_peaks, ret="filtered")

fft_filtered_new = fourier_filter_2d(
raw_image,
peaks=fft_filter_peaks,
ret="filtered_fft",
)

im_fft.set_clim(vmax=v_vals)
fft_filt.set_clim(vmax=v_vals)

filt.set_data(filtered_new.T)
fft_filt.set_data(np.abs(fft_filtered_new.T))

nonlocal cont
for i in range(len(cont.collections)):
cont.collections[i].remove()
cont = ax.contour(msk.T)

edc.set_ydata(np.sum(filtered_new, 0))
mdc.set_ydata(np.sum(filtered_new, 1))

fig.canvas.draw_idle()

ipw.interact(
update,
amp=amp_slider,
pos_x=posx_slider,
pos_y=posy_slider,
sig_x=sigx_slider,
sig_y=sigy_slider,
v_vals=clim_slider,
)

def apply_fft(apply: bool): # pylint: disable=unused-argument
amp = amp_slider.value
pos_x = posx_slider.value
pos_y = posy_slider.value
sig_x = sigx_slider.value
sig_y = sigy_slider.value
self._correction_matrix_dict["fft_tool_params"] = {
"amplitude": amp,
"pos_x": pos_x,
"pos_y": pos_y,
"sigma_x": sig_x,
"sigma_y": sig_y,
}
self.config["fft_filter_peaks"] = create_fft_params(
amp,
pos_x,
pos_y,
sig_x,
sig_y,
)
amp_slider.close()
posx_slider.close()
posy_slider.close()
sigx_slider.close()
sigy_slider.close()
clim_slider.close()
apply_button.close()

apply_button = ipw.Button(description="Apply")
display(apply_button)
apply_button.on_click(apply_fft)
plt.show()
if apply:
apply_fft(True)


def create_fft_params(amp, pos_x, pos_y, sig_x, sig_y) -> list[dict]:
"""Function to create fft filter peaks list using the
provided Gaussian peak parameters. The peaks are defined
relative to each other such that they are periodically
aranged in a 256 x 150 Fourier space.
Args:
amp: Gaussian peak amplitude
pos_x: x-position
pos_y: y-position
sig_x: FWHM in x-axis
sig_y: FWHM in y-axis
"""

fft_filter_peaks = [
{"amplitude": amp, "pos_x": -pos_x, "pos_y": 0, "sigma_x": sig_x, "sigma_y": sig_y},
{"amplitude": amp, "pos_x": pos_x, "pos_y": 0, "sigma_x": sig_x, "sigma_y": sig_y},
{"amplitude": amp, "pos_x": 0, "pos_y": pos_y, "sigma_x": sig_x, "sigma_y": sig_y},
{"amplitude": amp, "pos_x": -pos_x, "pos_y": pos_y, "sigma_x": sig_x, "sigma_y": sig_y},
{"amplitude": amp, "pos_x": pos_x, "pos_y": pos_y, "sigma_x": sig_x, "sigma_y": sig_y},
]

return fft_filter_peaks
6 changes: 4 additions & 2 deletions specsanalyzer/img_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def fourier_filter_2d(

# Do Fourier Transform of the (real-valued) image
image_fft = np.fft.rfft2(image)
# shift fft axis to have 0 in the center
image_fft = np.fft.fftshift(image_fft, axes=0)
mask = np.ones(image_fft.shape)
xgrid, ygrid = np.meshgrid(
range(image_fft.shape[0]),
Expand All @@ -73,7 +75,7 @@ def fourier_filter_2d(
mask -= peak["amplitude"] * gauss2d(
xgrid,
ygrid,
peak["pos_x"],
image_fft.shape[0] / 2 + peak["pos_x"],
peak["pos_y"],
peak["sigma_x"],
peak["sigma_y"],
Expand All @@ -85,7 +87,7 @@ def fourier_filter_2d(
) from exc

# apply mask to the FFT, and transform back
filtered = np.fft.irfft2(image_fft * mask)
filtered = np.fft.irfft2(np.fft.ifftshift(image_fft * mask, axes=0))
# strip negative values
filtered = filtered.clip(min=0)
if ret == "filtered":
Expand Down
9 changes: 2 additions & 7 deletions specsscan/config/example_config_FHI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ spa_params:
sigma_x: 8
sigma_y: 8
- amplitude: 1
pos_x: 176
pos_x: -80
pos_y: 0
sigma_x: 8
sigma_y: 8
Expand All @@ -106,12 +106,7 @@ spa_params:
sigma_x: 5
sigma_y: 5
- amplitude: 1
pos_x: 175
pos_x: -81
pos_y: 108
sigma_x: 5
sigma_y: 5
- amplitude: 1
pos_x: 254
pos_y: 109
sigma_x: 5
sigma_y: 8
45 changes: 41 additions & 4 deletions specsscan/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ def load_scan(
else:
res_xarray = res_xarray.transpose("Angle", "Ekin", dim)

slow_axes = {dim} if dim else set()
fast_axes = set(res_xarray.dims) - slow_axes
projection = "reciprocal" if "Angle" in fast_axes else "real"
conversion_metadata = res_xarray.attrs.pop("conversion_parameters")

# rename coords and store mapping information, if available
Expand All @@ -238,6 +241,13 @@ def load_scan(
if k in res_xarray.dims
}
res_xarray = res_xarray.rename(rename_dict)
for k, v in coordinate_mapping.items():
if k in fast_axes:
fast_axes.remove(k)
fast_axes.add(v)
if k in slow_axes:
slow_axes.remove(k)
slow_axes.add(v)
self._scan_info["coordinate_depends"] = depends_dict

axis_dict = {
Expand All @@ -262,8 +272,9 @@ def load_scan(
df_lut,
self._scan_info,
self.config,
fast_axis="Angle" if "Angle" in res_xarray.dims else "Position",
slow_axis=dim,
fast_axes=list(fast_axes), # type: ignore
slow_axes=list(slow_axes),
projection=projection,
metadata=copy.deepcopy(metadata),
collect_metadata=collect_metadata,
),
Expand Down Expand Up @@ -312,6 +323,27 @@ def crop_tool(self, scan: int = None, path: Path | str = "", **kwds):
**kwds,
)

def fft_tool(self, scan: int = None, path: Path | str = "", **kwds):
matplotlib.use("module://ipympl.backend_nbagg")
if scan is not None:
scan_path = get_scan_path(path, scan, self._config["data_path"])

data = load_images(
scan_path=scan_path,
tqdm_enable_nested=self._config["enable_nested_progress_bar"],
)
image = data[0]
else:
try:
image = self.metadata["loader"]["raw_data"][0]
except KeyError as exc:
raise ValueError("No image loaded, load image first!") from exc

self.spa.fft_tool(
image,
**kwds,
)

def check_scan(
self,
scan: int,
Expand Down Expand Up @@ -411,13 +443,18 @@ def check_scan(
except KeyError:
pass

slow_axes = {"Iteration"}
fast_axes = set(res_xarray.dims) - slow_axes
projection = "reciprocal" if "Angle" in fast_axes else "real"

self.metadata.update(
**handle_meta(
df_lut,
self._scan_info,
self.config,
fast_axis="Angle" if "Angle" in res_xarray.dims else "Position",
slow_axis=dims[1],
fast_axes=list(fast_axes), # type: ignore
slow_axes=list(slow_axes),
projection=projection,
metadata=metadata,
collect_metadata=collect_metadata,
),
Expand Down
14 changes: 7 additions & 7 deletions specsscan/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,9 @@ def handle_meta(
df_lut: pd.DataFrame,
scan_info: dict,
config: dict,
fast_axis: str,
slow_axis: str,
fast_axes: list[str],
slow_axes: list[str],
projection: str,
metadata: dict = None,
collect_metadata: bool = False,
) -> dict:
Expand All @@ -360,8 +361,8 @@ def handle_meta(
from ``parse_lut_to_df()``
scan_info (dict): scan_info class dict containing containing the contents of info.txt file
config (dict): config dictionary containing the contents of config.yaml file
fast_axis (str): The fast-axis dimension of the scan
slow_axis (str): The slow-axis dimension of the scan
fast_axes (list[str]): The fast-axis dimensions of the scan
slow_axes (list[str]): The slow-axis dimensions of the scan
metadata (dict, optional): Metadata dictionary with additional metadata for the scan.
Defaults to empty dictionary.
collect_metadata (bool, optional): Option to collect further metadata e.g. from EPICS
Expand Down Expand Up @@ -470,14 +471,13 @@ def handle_meta(

metadata["scan_info"]["energy_scan_mode"] = energy_scan_mode

projection = "reciprocal" if fast_axis in {"Anlge", "angular0", "angular1"} else "real"
metadata["scan_info"]["projection"] = projection
metadata["scan_info"]["scheme"] = (
"angular dispersive" if projection == "reciprocal" else "spatial dispersive"
)

metadata["scan_info"]["slow_axes"] = slow_axis
metadata["scan_info"]["fast_axes"] = ["Ekin", fast_axis]
metadata["scan_info"]["slow_axes"] = slow_axes
metadata["scan_info"]["fast_axes"] = fast_axes

print("Done!")

Expand Down
Loading

0 comments on commit 046bd3f

Please sign in to comment.