Add CW-SSIM support for torchmetrics #2428
Comments
Hi @michael080808, thanks for opening this issue. We would be more than happy to receive a pull request with this metric (either a partial or a full implementation), but I do not think anyone on the core team has the bandwidth or the experience to implement such a complex metric at the moment. If you can point me to a specific reference implementation, maybe I can give it a stab.
There are some implementations, as mentioned.
@michael080808 could you point me to which one of these has the most promising reference implementation? In which specific files is the metric implemented?
The others are references for the DTCWT implementation. As I understand it, once the Steerable Pyramid or DTCWT has been implemented, CW-SSIM is very easy to code, because CW-SSIM just applies the new SSIM definition to the transform coefficients from SP or DTCWT. It is probably very hard to implement DTCWT without some reference code. The following two are specific to the DTCWT implementation.
I hope that this information is helpful. If there are any other questions, I would be pleased to answer them. 🙂❤
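For reference, the local CW-SSIM index over one window of complex subband coefficients (from SP or DTCWT) can be sketched like this; the helper name and the stabilizing constant `K` are illustrative, not part of any existing library:

```python
# Minimal sketch of the local CW-SSIM index for one window of complex
# subband coefficients; `cw_ssim_window` and `K` are illustrative only.
import torch

def cw_ssim_window(cx: torch.Tensor, cy: torch.Tensor, K: float = 0.0) -> torch.Tensor:
    # The numerator rewards a consistent relative phase between cx and cy;
    # the denominator normalizes by the coefficient magnitudes.
    num = 2 * torch.abs(torch.sum(cx * torch.conj(cy))) + K
    den = torch.sum(torch.abs(cx) ** 2 + torch.abs(cy) ** 2) + K
    return num / den
```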
@michael080808 thanks for providing this overview, it really helps me.
I took a quick look at the references and tried a simple version of CW-SSIM. Here are the two parts of the code, which run with PyTorch 2.2. I made some coordinate changes for better calculation when the input width or height is an odd number. Note the reliance on complex convolution support: it is a very new feature and CW-SSIM heavily depends on it (see the quick check after the listings). I hope it is helpful for understanding.

```python
# pyramid.py
"""
Map [0, Length - 1] onto [-1, 1].
I prefer to use the pixel center as the coordinate position:
+-----+-----+-----+
| | | |
| A | B | C |
| | | |
+-----+-----+ +-----+-----+-----+
| | | | | | |
| A | B | | D | O | E |
| | | | | | |
+-----O-----+ +-----+-----+-----+
| | | | | | |
| C | D | | F | G | H |
| | | | | | |
+-----+-----+ +-----+-----+-----+
Here, O is the coordinate origin.
With an even number of pixels, the coordinates of A, B, C, D take half-integer values.
With an odd number of pixels, the coordinates of A, B, C, D, E, F, G, H take integer values.
"""
import functools
import itertools
import math
import operator
from abc import ABCMeta, abstractmethod
from typing import List, Tuple, Union
import torch.fft
from torch import Tensor
from torch.types import Device
class SteerablePyramid:
class _Filter(metaclass=ABCMeta):
@staticmethod
def bound_convert_2_tuple(boundary: Union[float, Tuple[float], Tuple[float, float]]) -> Tuple[float, float]:
if isinstance(boundary, float):
boundary = (boundary,)
if isinstance(boundary, tuple) and len(boundary) == 1:
boundary = (boundary[0], boundary[0])
return boundary[0], boundary[1]
@staticmethod
def normalized_lin_spaces(length: int, device: Device = None) -> Tensor:
number = torch.arange(length, device=device)
return (number - length / 2 + 0.5) / (length // 2)
@staticmethod
def normalized_coordinate(shapes: Tuple[int, int], device: Device = None) -> Tuple[Tensor, ...]:
coords = [SteerablePyramid._Filter.normalized_lin_spaces(length, device) for length in reversed(shapes)]
return torch.meshgrid(*list(reversed(coords)), indexing='ij')
@staticmethod
def polars(shapes: Tuple[int, int], device: Device = None) -> Tuple[Tensor, Tensor]:
x, y = SteerablePyramid._Filter.normalized_coordinate(shapes, device)
return torch.sqrt(x ** 2 + y ** 2), torch.arctan2(y, x)
@staticmethod
def angles(shapes: Tuple[int, int], device: Device = None) -> Tensor:
x, y = SteerablePyramid._Filter.normalized_coordinate(shapes, device)
return torch.arctan2(y, x)
@staticmethod
def radius(shapes: Tuple[int, int], device: Device = None) -> Tensor:
x, y = SteerablePyramid._Filter.normalized_coordinate(shapes, device)
return torch.sqrt(x ** 2 + y ** 2)
@staticmethod
def bounds(shapes: Tuple[int, int], boundary: Union[float, Tuple[float], Tuple[float, float]], device: Device = None) -> Tensor:
boundary = SteerablePyramid._Filter.bound_convert_2_tuple(boundary)
return (boundary[0] * boundary[1]) / torch.sqrt((boundary[0] * torch.cos(SteerablePyramid._Filter.angles(shapes, device))) ** 2 + (boundary[1] * torch.sin(SteerablePyramid._Filter.angles(shapes, device))) ** 2)
@staticmethod
def high_band_pass_filter(shapes: Tuple[int, int], boundary: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0, device: Device = None):
diff = torch.log2(SteerablePyramid._Filter.radius(shapes, device)) - torch.log2(SteerablePyramid._Filter.bounds(shapes, boundary, device))
return torch.abs(torch.cos((torch.clamp(diff, min=-transition_width, max=0) / transition_width) * (math.pi / 2)))
@staticmethod
def bass_band_pass_filter(shapes: Tuple[int, int], boundary: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0, device: Device = None):
high = SteerablePyramid._Filter.high_band_pass_filter(shapes, boundary=boundary, transition_width=transition_width, device=device)
return torch.sqrt(1 - high ** 2)
@abstractmethod
def __init__(self):
super().__init__()
@abstractmethod
def __call__(self, shapes: Tuple[int, int], device: Device = None):
raise NotImplementedError
class BassPassFilter(_Filter):
def __init__(self, boundary: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0):
super().__init__()
self.boundary = boundary
self.transition_width = transition_width
def __call__(self, shapes: Tuple[int, int], device: Device = None):
return self.bass_band_pass_filter(shapes=shapes, boundary=self.boundary, transition_width=self.transition_width, device=device)
class HighPassFilter(_Filter):
def __init__(self, boundary: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0):
super().__init__()
self.boundary = boundary
self.transition_width = transition_width
def __call__(self, shapes: Tuple[int, int], device: Device = None):
return self.high_band_pass_filter(shapes=shapes, boundary=self.boundary, transition_width=self.transition_width, device=device)
class BandPassFilter(_Filter):
def __init__(self, boundary_high: Union[float, Tuple[float], Tuple[float, float]] = 1.0, boundary_bass: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0):
super().__init__()
assert functools.reduce(operator.__and__, itertools.starmap(operator.ge, zip((boundary_high,) if isinstance(boundary_high, float) else boundary_high, (boundary_bass,) if isinstance(boundary_bass, float) else boundary_bass))), 'All elements from "boundary_high" must be greater than or equal to the corresponding elements in "boundary_bass".'
self.boundary_bass, self.boundary_high, self.transition_width = boundary_bass, boundary_high, transition_width
def __call__(self, shapes: Tuple[int, int], device: Device = None):
return self.bass_band_pass_filter(shapes=shapes, boundary=self.boundary_high, transition_width=self.transition_width, device=device) * self.high_band_pass_filter(shapes=shapes, boundary=self.boundary_bass, transition_width=self.transition_width, device=device)
class SteeringFilter(BandPassFilter):
def __init__(self, boundary_bass: Union[float, Tuple[float], Tuple[float, float]] = 1.0, boundary_high: Union[float, Tuple[float], Tuple[float, float]] = 1.0, transition_width: float = 1.0, index: int = 0, orientations: int = 2, support_cplx: bool = False):
super().__init__(boundary_bass, boundary_high, transition_width)
            assert index < orientations, '"index" must be less than "orientations".'
self.index, self.orientations, self.support_cplx = index, orientations, support_cplx
def __call__(self, shapes: Tuple[int, int], device: Device = None):
return super().__call__(shapes, device) * self.orientation_filter(shapes, self.support_cplx, device)
@property
def constant(self):
order = self.orientations - 1
return math.pow(2, (2 * order)) * math.pow(math.factorial(order), 2) / (self.orientations * math.factorial(2 * order))
def orientation_filter(self, shapes: Tuple[int, int], u4cplx: bool = False, device: Device = None):
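            # Wrap the angular offset for this orientation into [-pi, pi) before windowing.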
angles = torch.remainder(math.pi + self.angles(shapes, device) - math.pi * self.index / self.orientations, 2 * math.pi) - math.pi
return (torch.abs(math.sqrt(self.constant) * torch.pow(torch.cos(angles), self.orientations - 1))) * (torch.lt(torch.abs(angles), math.pi / 2) if u4cplx else 1)
@staticmethod
def to_freq_domain(x: Tensor) -> Tensor:
assert x.dim() >= 2, 'Not enough dimensions to run "to_freq_domain" procedure.'
return torch.fft.fftshift(torch.fft.fft2(x, dim=[-2, -1]), dim=[-2, -1])
@staticmethod
def to_time_domain(x: Tensor) -> Tensor:
assert x.dim() >= 2, 'Not enough dimensions to run "to_time_domain" procedure.'
return torch.fft.ifft2(torch.fft.ifftshift(x, dim=[-2, -1]), dim=[-2, -1])
@staticmethod
def to_crop_region(entire: Tuple[int, int], region: Tuple[int, int]) -> Tuple[List[int], ...]:
assert functools.reduce(operator.__and__, itertools.starmap(operator.ge, zip(entire, region))) and functools.reduce(operator.__and__, itertools.starmap(lambda x, y: (x - y) % 2 == 0, zip(entire, region))), 'All elements from "shapes" must be greater than or equal to the corresponding elements in "region".'
return tuple([(shape - focal) // 2, focal, (shape - focal) // 2] for shape, focal in zip(entire, region))
@staticmethod
def to_crop_tensor(inputs: Tensor, region: Tuple[int, int]) -> Tensor:
splits = SteerablePyramid.to_crop_region(entire=(inputs.shape[-2], inputs.shape[-1]), region=region)
return torch.split(torch.split(inputs, splits[-1], dim=-1)[1], splits[-2], dim=-2)[1]
@staticmethod
def to_join_tensor(fronts: Tensor, backed: Tensor) -> Tensor:
assert fronts.dim() == backed.dim() >= 2 and fronts.shape[-1] < backed.shape[-1] and fronts.shape[-2] < backed.shape[-2] and fronts.shape[:-2] == backed.shape[:-2], 'Unable to join two tensors into one due to the shape mismatch.'
return torch.nn.functional.pad(fronts, [(backed.shape[-1] - fronts.shape[-1]) // 2] * 2 + [(backed.shape[-2] - fronts.shape[-2]) // 2] * 2, mode='constant', value=0) + backed * torch.nn.functional.pad(torch.zeros_like(fronts), [(backed.shape[-1] - fronts.shape[-1]) // 2] * 2 + [(backed.shape[-2] - fronts.shape[-2]) // 2] * 2, mode='constant', value=1)
def __init__(self, group_levels: int = 6, orientations: int = 16, support_cplx: bool = True, transition_w: float = 1.0):
super().__init__()
self.group_levels = group_levels
self.orientations = orientations
self.support_cplx = support_cplx
self.transition_w = transition_w
def region_iteration(self, shapes: Tuple[int, int]):
last = shapes
yield last
for i in range(self.group_levels):
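            # Shrink each side to roughly half while preserving its parity: L -> L - 2 * (L // 4).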
last = (last[0] - (last[0] // 4) * 2, last[1] - (last[1] // 4) * 2)
yield last
def factor_iteration(self, shapes: Tuple[int, int]):
last = shapes
yield tuple(itertools.starmap(operator.truediv, zip(last, shapes)))
for _ in range(self.group_levels):
last = (last[0] - (last[0] // 4) * 2, last[1] - (last[1] // 4) * 2)
yield tuple(itertools.starmap(operator.truediv, zip(last, shapes)))
def filter_iteration(self, shapes: Tuple[int, int], device: Device = None):
iteration = zip(itertools.pairwise(self.factor_iteration(shapes)), self.region_iteration(shapes))
for level, ((prev_f, curr_f), region) in enumerate(iteration):
if level == 0:
yield self.to_crop_tensor(self.HighPassFilter(boundary=prev_f, transition_width=self.transition_w)(shapes, device), region), f'H{level}'
yield self.to_crop_tensor(self.BassPassFilter(boundary=prev_f, transition_width=self.transition_w)(shapes, device), region), f'L{level}'
for orientation in range(self.orientations):
yield self.to_crop_tensor(self.SteeringFilter(boundary_bass=prev_f, boundary_high=curr_f, transition_width=self.transition_w, index=orientation, orientations=self.orientations, support_cplx=self.support_cplx)(shapes, device), region), f'B{level + 1}o{orientation}'
yield self.to_crop_tensor(self.BassPassFilter(boundary=curr_f, transition_width=self.transition_w)(shapes, device), region), f'L{level + 1}'
def encode_iteration(self, tensor: Tensor):
shapes = (tensor.shape[-2], tensor.shape[-1])
target, window = self.to_freq_domain(tensor), None
it_filter = self.filter_iteration(shapes, tensor.device)
# L0 HighPass Output
window = next(it_filter)
time_domain = self.to_time_domain(target * window[0])
yield time_domain if self.support_cplx else torch.real(time_domain), window[1]
# L0 BassPass Remove
window = next(it_filter)
target = target * window[0]
# yield time_domain if self.support_cplx else torch.real(time_domain), window[1] <- Removed due to definition.
# Each Level Steering BandPass
for level, (curr_r, next_r) in enumerate(itertools.pairwise(self.region_iteration(shapes))):
for orientation in range(self.orientations):
window = next(it_filter)
time_domain = self.to_time_domain(target * window[0])
yield time_domain if self.support_cplx else torch.real(time_domain), window[1]
window = next(it_filter)
target = self.to_crop_tensor(target * window[0], next_r)
# Final BassPass
time_domain = self.to_time_domain(target)
        yield time_domain if self.support_cplx else torch.real(time_domain), window[1]
```

```python
# main.py
from typing import Tuple
import skimage.transform
import skimage.util
import torch
from skimage.data import astronaut
from torch import Tensor
from torch.nn import Module
from pyramid import SteerablePyramid
class CwSSIM(Module):
def __init__(self, kernel: int = 7, k: float = 0, levels: int = 6, orientations: int = 16, transition_w: float = 1.0):
super().__init__()
self.k = k
self.kernel = torch.ones([kernel] * 2)[None, None, ...]
self.result_pyramid = SteerablePyramid(group_levels=levels, orientations=orientations, transition_w=transition_w, support_cplx=True)
self.ground_pyramid = SteerablePyramid(group_levels=levels, orientations=orientations, transition_w=transition_w, support_cplx=True)
def multidim_conv2d(self, inputs: Tensor, *args, **kwargs) -> Tensor:
if inputs.dim() <= 1:
raise ValueError('One Dimensional Input is not supported.')
channels = inputs.shape[-3] if inputs.dim() >= 3 else 1
paddings = [self.kernel.size(dim) // 2 for _ in range(2) for dim in [-1, -2]]
groups = args[4] if len(args) >= 5 else kwargs.get('groups', channels)
kwargs['groups'] = groups
shapes = inputs.shape
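        # Expand the box-averaging kernel to a conv2d weight of shape
        # (out_channels, in_channels // groups, kH, kW); groups defaults to the
        # channel count, so each channel is filtered independently (depthwise).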
kernel = self.kernel.repeat(1, 1, 1, 1) if inputs.dim() == 2 else self.kernel.repeat(channels, channels // groups, 1, 1)
if inputs.dim() >= 5:
            return torch.nn.functional.conv2d(torch.nn.functional.pad(inputs.flatten(0, -4), paddings, mode='reflect'), kernel.to(device=inputs.device, dtype=inputs.dtype), *args, **kwargs).unflatten(0, shapes[:-3])
if 2 <= inputs.dim() <= 4:
return torch.nn.functional.conv2d(torch.nn.functional.pad(inputs.repeat(1, 1, 1, 1), paddings, mode='reflect'), kernel.to(device=inputs.device, dtype=inputs.dtype), *args, **kwargs).squeeze(tuple(range(0, 4 - inputs.dim())))
def statistics(self, result: Tensor, ground: Tensor, *args, **kwargs) -> Tuple[Tensor, ...]:
conj_prods = self.multidim_conv2d(result * torch.conj(ground), *args, **kwargs)
sum_mod_sq = self.multidim_conv2d(torch.abs(result) ** 2 + torch.abs(ground) ** 2, *args, **kwargs)
return conj_prods, sum_mod_sq
def forward(self, result: Tensor, ground: Tensor, *args, **kwargs) -> Tensor:
assert result.shape == ground.shape
        result_encode_iter = self.result_pyramid.encode_iteration(result)
        ground_encode_iter = self.ground_pyramid.encode_iteration(ground)
        count, summarized = 0, torch.zeros(1, device=result.device)
for (result_encode, _), (ground_encode, _) in zip(result_encode_iter, ground_encode_iter):
conj_prods, sum_mod_sq = self.statistics(result_encode, ground_encode, *args, **kwargs)
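            # Local CW-SSIM: (2 * |sum(cx * conj(cy))| + k) / (sum(|cx|^2 + |cy|^2) + k),
            # with the windowed sums computed by multidim_conv2d above.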
_ssim = (2 * torch.abs(conj_prods) + self.k) / (sum_mod_sq + self.k)
count, summarized = count + 1, summarized + torch.mean(_ssim, dim=[-2, -1], keepdim=True)
return summarized / count
ssim = CwSSIM()
if __name__ == '__main__':
image = skimage.util.img_as_float32(astronaut())
noise = skimage.util.random_noise(image, mode='speckle')
print(ssim(torch.tensor(image).movedim(-1, 0), torch.tensor(noise).movedim(-1, 0)))
image = skimage.util.img_as_float32(astronaut())
noise = 0.8 * image
print(ssim(torch.tensor(image).movedim(-1, 0), torch.tensor(noise).movedim(-1, 0)))
image = skimage.util.img_as_float32(astronaut())
noise = skimage.transform.rotate(image, 1)
print(ssim(torch.tensor(image).movedim(-1, 0), torch.tensor(noise).movedim(-1, 0)))
```
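As a quick check of the complex convolution support this depends on (the listings above were run with PyTorch 2.2; the shapes here are illustrative):

```python
# Verify that conv2d accepts complex tensors on a recent PyTorch build.
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 8, 8, dtype=torch.complex64)
w = torch.ones(1, 1, 3, 3, dtype=torch.complex64)
y = F.conv2d(x, w, padding=1)
print(y.dtype, y.shape)  # torch.complex64, torch.Size([1, 1, 8, 8])
```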
@michael080808 is what you posted a full implementation? It does not seem to rely on any third-party package?
It's a relatively full implementation. I did not write
🚀 Feature
Complex-Wavelet Structural Similarity (also known as CW-SSIM) support, built on the Steerable Pyramid (SP) or the Dual-Tree Complex Wavelet Transform (DTCWT). Maybe support all possible Q-shift and first-level filters as well?
Motivation
I noticed that someone just mentioned this in #799
For research purposes, I found a few projects with CW-SSIM code, but they have not been updated to the latest PyTorch version. Will torchmetrics add CW-SSIM support in torchmetrics.image? Here are some collections of old code. I tried to use the latest PyTorch version (which now supports complex convolution, and that is important for my usage) to implement this function. It is very difficult for me to understand the math behind all the complex wavelet machinery; I would appreciate any suggestions on the math. Please do not confuse the Complex Wavelet Transform with the Continuous Wavelet Transform because of the shared CWT abbreviation.

Pitch
Alternatives
I tried scipy and PyWavelets, but they do not support SP or DTCWT; only the Discrete Wavelet Transform (DWT, with multi-dimensional support) and the Continuous Wavelet Transform are included. The projects listed are too old to run on the latest PyTorch.
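For example, PyWavelets covers the DWT (including 2-D) and the 1-D CWT, but as far as I know offers neither SP nor DTCWT:

```python
# What PyWavelets does provide (to my knowledge): DWT and 1-D CWT, no SP/DTCWT.
import numpy as np
import pywt

img = np.random.rand(64, 64)
cA, (cH, cV, cD) = pywt.dwt2(img, 'haar')                 # 2-D discrete wavelet transform
coefs, freqs = pywt.cwt(img[0], np.arange(1, 8), 'morl')  # 1-D continuous wavelet transform
```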
Additional context
If there is further discussion of the math behind SP or DTCWT, I will try to implement it myself and open a pull request to torchmetrics. I am still confused about the relationship between the scaling function and the wavelet function, and whether it needs to be considered in SP or DTCWT. How are the SP or DTCWT filters calculated? Sorry for my poor math on discrete-to-continuous domain conversion.