Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stitching bugfix #158

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 108 additions & 23 deletions mantis/analysis/stitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@


def estimate_shift(
im0: np.ndarray, im1: np.ndarray, percent_overlap: float, direction: Literal["row", "col"]
im0: np.ndarray,
im1: np.ndarray,
percent_overlap: float,
direction: Literal["row", "col"],
add_offset: bool = False,
):
"""
Estimate the shift between two images based on a given percentage overlap and direction.
Expand All @@ -29,6 +33,9 @@ def estimate_shift(
The percentage of overlap between the two images. Must be between 0 and 1.
direction : Literal["row", "col"]
The direction of the shift. Can be either "row" or "col". See estimate_zarr_fov_shifts
add_offset : bool
Add offsets to shift-x and shift-y when stitching data from ISS microscope.
Not clear why we need to do that. By default False

Returns
-------
Expand All @@ -52,31 +59,31 @@ def estimate_shift(
if direction == "row":
y_roi = int(sizeY * np.minimum(percent_overlap + 0.05, 1))
shift, _, _ = phase_cross_correlation(
im0[-y_roi:, :], im1[:y_roi, :], upsample_factor=10
im0[..., -y_roi:, :], im1[..., :y_roi, :], upsample_factor=1
)
shift[0] += sizeY - y_roi
shift[-2] += sizeY
if add_offset:
shift[-2] -= y_roi
elif direction == "col":
x_roi = int(sizeX * np.minimum(percent_overlap + 0.05, 1))
shift, _, _ = phase_cross_correlation(
im0[:, -x_roi:], im1[:, :x_roi], upsample_factor=10
im0[..., :, -x_roi:], im1[..., :, :x_roi], upsample_factor=1
)
shift[1] += sizeX - x_roi
shift[-1] += sizeX
if add_offset:
shift[-1] -= x_roi

# TODO: we shouldn't need to flip the order
# TODO: we shouldn't need to flip the order, will cause problems in 3D
return shift[::-1]


def get_grid_rows_cols(dataset_path: str):
def get_grid_rows_cols(fov_names: list[str]):
grid_rows = set()
grid_cols = set()

with open_ome_zarr(dataset_path) as dataset:

_, well = next(dataset.wells())
for position_name, _ in well.positions():
fov_name = Path(position_name).parts[-1]
grid_rows.add(fov_name[3:]) # 1-Pos<COL>_<ROW> syntax
grid_cols.add(fov_name[:3])
for fov_name in fov_names:
grid_rows.add(fov_name[3:]) # 1-Pos<COL>_<ROW> syntax
grid_cols.add(fov_name[:3])

return sorted(grid_rows), sorted(grid_cols)

Expand Down Expand Up @@ -397,7 +404,7 @@ def estimate_zarr_fov_shifts(
im0 = np.flipud(im0)
im1 = np.flipud(im1)

shift = estimate_shift(im0, im1, percent_overlap, direction)
shift = estimate_shift(im0, im1, percent_overlap, direction, add_offset=flipud)

df = pd.DataFrame(
{
Expand Down Expand Up @@ -512,8 +519,93 @@ def compute_total_translation(csv_filepath: str) -> pd.DataFrame:
# create 'row' and 'col' number columns and sort the dataframe by 'fov1'
df['row'] = df['fov1'].str[-3:].astype(int)
df['col'] = df['fov1'].str[:3].astype(int)
df_row = df[(df['direction'] == 'row')]
df_col = df[(df['direction'] == 'col')]
row_anchors = sorted(df_row['fov0'][~df_row['fov0'].isin(df_row['fov1'])].unique())
col_anchors = sorted(df_col['fov0'][~df_col['fov0'].isin(df_col['fov1'])].unique())
row_col_anchors = sorted(set(row_anchors).intersection(col_anchors))
df['fov0'] = df[['well', 'fov0']].agg('/'.join, axis=1)
df['fov1'] = df[['well', 'fov1']].agg('/'.join, axis=1)
df.set_index('fov1', inplace=True)
df.sort_index(inplace=True)

for well in df['well'].unique():
# add anchors
df = pd.concat(
(
pd.DataFrame(
{
'well': well,
'shift-x': 0,
'shift-y': 0,
'direction': 'row',
'row': [int(a[3:]) for a in row_anchors],
'col': [int(a[:3]) for a in row_anchors],
},
index=['/'.join((well, a)) for a in row_anchors],
),
pd.DataFrame(
{
'well': well,
'shift-x': 0,
'shift-y': 0,
'direction': 'col',
'row': [int(a[3:]) for a in col_anchors],
'col': [int(a[:3]) for a in col_anchors],
},
index=['/'.join((well, a)) for a in col_anchors],
),
df,
)
)

for anchor in row_col_anchors[::-1]:
df_well = df[df['well'] == well]
df_well_col = df_well[df_well['direction'] == 'col']
df_well_row = df_well[df_well['direction'] == 'row']

_row = int(anchor[3:])
idx1 = df_well_col[df_well_col['row'] == _row].index
idx_out = ['/'.join((well, a)) for a in row_anchors if a[3:] == anchor[3:]]
idx_in = sorted(idx1[~idx1.isin(idx_out)])

if len(idx_in) > 0: # will be zero for first row
shift_x = df_well_row[
(df_well_row['row'] <= _row)
& (df_well_row['col'] == int(idx_in[0][-6:-3]))
]['shift-x'].sum()
shift_y = df_well_row[
(df_well_row['row'] <= _row)
& (df_well_row['col'] == int(idx_in[0][-6:-3]))
]['shift-y'].sum()

df_well_row.loc[idx_out, ['shift-x', 'shift-y']] = (shift_x, shift_y)

df[(df['direction'] == 'row') & (df['well'] == well)] = df_well_row

for anchor in col_anchors:
_col = int(anchor[:3])

shift_x = (
df_well_col[(df_well_col['col'] <= _col) & (df_well_col['shift-x'] != 0)]
.groupby('col')['shift-x']
.median()
.sum()
)
shift_y = (
df_well_col[(df_well_col['col'] <= _col) & (df_well_col['shift-y'] != 0)]
.groupby('col')['shift-y']
.median()
.sum()
)

df_well_col.loc['/'.join((well, anchor)), ['shift-x', 'shift-y']] = (
shift_x,
shift_y,
)

df[(df['direction'] == 'col') & (df['well'] == well)] = df_well_col

df.sort_index(inplace=True) # TODO: remember to sort index after any additions

total_shift = []
for well in df['well'].unique():
Expand All @@ -522,18 +614,11 @@ def compute_total_translation(csv_filepath: str) -> pd.DataFrame:
col_shifts = _df.groupby('row')[['shift-x', 'shift-y']].cumsum()
_df = df[(df['direction'] == 'row') & (df['well'] == well)]
row_shifts = _df.groupby('col')[['shift-x', 'shift-y']].cumsum()
# total shift is the sum of row and column shifts
_total_shift = col_shifts.add(row_shifts, fill_value=0)

# add row 000000
_total_shift = pd.concat(
[pd.DataFrame({'shift-x': 0, 'shift-y': 0}, index=['000000']), _total_shift]
)

# add global offset to remove negative values
_total_shift['shift-x'] += -np.minimum(_total_shift['shift-x'].min(), 0)
_total_shift['shift-y'] += -np.minimum(_total_shift['shift-y'].min(), 0)
_total_shift.set_index(well + '/' + _total_shift.index, inplace=True)
total_shift.append(_total_shift)

return pd.concat(total_shift)
27 changes: 19 additions & 8 deletions mantis/cli/estimate_stitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,24 @@ def estimate_stitch(
# here we assume that all wells have the same fov grid
click.echo('Indexing input zarr store')
wells = list(set([Path(*p.parts[-3:-1]) for p in input_position_dirpaths]))
grid_rows, grid_cols = get_grid_rows_cols(input_zarr_path)
row_fov0 = [col + row for row in grid_rows[:-1] for col in grid_cols]
row_fov1 = [col + row for row in grid_rows[1:] for col in grid_cols]
col_fov0 = [col + row for col in grid_cols[:-1] for row in grid_rows]
col_fov1 = [col + row for col in grid_cols[1:] for row in grid_rows]
fov_names = set([p.name for p in input_position_dirpaths])
grid_rows, grid_cols = get_grid_rows_cols(fov_names)

# account for non-square grids
row_fov_pairs, col_fov_pairs = [], []
for col in grid_cols:
for row0, row1 in zip(grid_rows[:-1], grid_rows[1:]):
fov0 = col + row0
fov1 = col + row1
if fov0 in fov_names and fov1 in fov_names:
row_fov_pairs.append((fov0, fov1))
for row in grid_rows:
for col0, col1 in zip(grid_cols[:-1], grid_cols[1:]):
fov0 = col0 + row
fov1 = col1 + row
if fov0 in fov_names and fov1 in fov_names:
col_fov_pairs.append((fov0, fov1))

estimate_shift_params = {
"tcz_index": tcz_idx,
"percent_overlap": percent_overlap,
Expand Down Expand Up @@ -134,9 +147,7 @@ def estimate_stitch(
click.echo('Estimating FOV shifts...')
shifts, jobs = [], []
for well_name in wells:
for direction, fovs in zip(
("row", "col"), (zip(row_fov0, row_fov1), zip(col_fov0, col_fov1))
):
for direction, fovs in zip(("row", "col"), (row_fov_pairs, col_fov_pairs)):
for fov0, fov1 in fovs:
fov0_zarr_path = Path(input_zarr_path, well_name, fov0)
fov1_zarr_path = Path(input_zarr_path, well_name, fov1)
Expand Down
Loading