Skip to content

Commit

Permalink
6873 data analyzer histogram_only=True fix (#6874)
Browse files Browse the repository at this point in the history
Fixes #6873 

### Description
- fixes data analyzer
- replace `"image_stats"` with `DataStatsKeys.IMAGE_STATS`

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli authored Aug 16, 2023
1 parent e24b969 commit 617c1be
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 10 deletions.
7 changes: 5 additions & 2 deletions monai/apps/auto3dseg/data_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ def _check_data_uniformity(keys: list[str], result: dict) -> bool:
"""

if DataStatsKeys.SUMMARY not in result or DataStatsKeys.IMAGE_STATS not in result[DataStatsKeys.SUMMARY]:
return True
constant_props = [result[DataStatsKeys.SUMMARY][DataStatsKeys.IMAGE_STATS][key] for key in keys]
for prop in constant_props:
if "stdev" in prop and np.any(prop["stdev"]):
Expand Down Expand Up @@ -358,10 +360,11 @@ def _get_all_case_stats(
stats_by_cases = {
DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH],
DataStatsKeys.BY_CASE_LABEL_PATH: d[DataStatsKeys.BY_CASE_LABEL_PATH],
DataStatsKeys.IMAGE_STATS: d[DataStatsKeys.IMAGE_STATS],
}
if not self.histogram_only:
stats_by_cases[DataStatsKeys.IMAGE_STATS] = d[DataStatsKeys.IMAGE_STATS]
if self.hist_bins != 0:
stats_by_cases.update({DataStatsKeys.IMAGE_HISTOGRAM: d[DataStatsKeys.IMAGE_HISTOGRAM]})
stats_by_cases[DataStatsKeys.IMAGE_HISTOGRAM] = d[DataStatsKeys.IMAGE_HISTOGRAM]

if self.label_key is not None:
stats_by_cases.update(
Expand Down
16 changes: 10 additions & 6 deletions monai/auto3dseg/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class ImageStats(Analyzer):
"""

def __init__(self, image_key: str, stats_name: str = "image_stats") -> None:
def __init__(self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_STATS) -> None:
if not isinstance(image_key, str):
raise ValueError("image_key input must be str")

Expand Down Expand Up @@ -296,7 +296,7 @@ class FgImageStats(Analyzer):
"""

def __init__(self, image_key: str, label_key: str, stats_name: str = "image_foreground_stats"):
def __init__(self, image_key: str, label_key: str, stats_name: str = DataStatsKeys.FG_IMAGE_STATS):
self.image_key = image_key
self.label_key = label_key

Expand Down Expand Up @@ -378,7 +378,9 @@ class LabelStats(Analyzer):
"""

def __init__(self, image_key: str, label_key: str, stats_name: str = "label_stats", do_ccp: bool | None = True):
def __init__(
self, image_key: str, label_key: str, stats_name: str = DataStatsKeys.LABEL_STATS, do_ccp: bool | None = True
):
self.image_key = image_key
self.label_key = label_key
self.do_ccp = do_ccp
Expand Down Expand Up @@ -533,7 +535,7 @@ class ImageStatsSumm(Analyzer):
"""

def __init__(self, stats_name: str = "image_stats", average: bool | None = True):
def __init__(self, stats_name: str = DataStatsKeys.IMAGE_STATS, average: bool | None = True):
self.summary_average = average
report_format = {
ImageStatsKeys.SHAPE: None,
Expand Down Expand Up @@ -623,7 +625,7 @@ class FgImageStatsSumm(Analyzer):
"""

def __init__(self, stats_name: str = "image_foreground_stats", average: bool | None = True):
def __init__(self, stats_name: str = DataStatsKeys.FG_IMAGE_STATS, average: bool | None = True):
self.summary_average = average

report_format = {ImageStatsKeys.INTENSITY: None}
Expand Down Expand Up @@ -687,7 +689,9 @@ class LabelStatsSumm(Analyzer):
"""

def __init__(self, stats_name: str = "label_stats", average: bool | None = True, do_ccp: bool | None = True):
def __init__(
self, stats_name: str = DataStatsKeys.LABEL_STATS, average: bool | None = True, do_ccp: bool | None = True
):
self.summary_average = average
self.do_ccp = do_ccp

Expand Down
4 changes: 2 additions & 2 deletions monai/auto3dseg/seg_summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def __init__(
self.summary_analyzers: list[Any] = []
super().__init__()

self.add_analyzer(FilenameStats(image_key, DataStatsKeys.BY_CASE_IMAGE_PATH), None)
self.add_analyzer(FilenameStats(label_key, DataStatsKeys.BY_CASE_LABEL_PATH), None)
if not self.histogram_only:
self.add_analyzer(FilenameStats(image_key, DataStatsKeys.BY_CASE_IMAGE_PATH), None)
self.add_analyzer(FilenameStats(label_key, DataStatsKeys.BY_CASE_LABEL_PATH), None)
self.add_analyzer(ImageStats(image_key), ImageStatsSumm(average=average))

if label_key is None:
Expand Down
15 changes: 15 additions & 0 deletions tests/test_auto3dseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,21 @@ def test_data_analyzer_cpu(self, input_params):

assert len(datastat["stats_by_cases"]) == len(sim_datalist["training"])

def test_data_analyzer_histogram(self):
create_sim_data(
self.dataroot_dir, sim_datalist, [32] * 3, image_only=True, rad_max=8, rad_min=1, num_seg_classes=1
)
analyser = DataAnalyzer(
self.datalist_file,
self.dataroot_dir,
output_path=self.datastat_file,
label_key=None,
device=device,
histogram_only=True,
)
datastat = analyser.get_all_case_stats()
assert len(datastat["stats_by_cases"]) == len(sim_datalist["training"])

@parameterized.expand(SIM_GPU_TEST_CASES)
@skip_if_no_cuda
def test_data_analyzer_gpu(self, input_params):
Expand Down

0 comments on commit 617c1be

Please sign in to comment.