From 2495a5beaf7d6ec0802a109dc4c7a092e2228106 Mon Sep 17 00:00:00 2001 From: Alejandro Gaston Alvarez Franceschi Date: Wed, 15 Nov 2023 16:39:57 +0100 Subject: [PATCH] Fixes --- .../mil/frontend/torch/test/test_torch_ops.py | 20 ++++--- coremltools/converters/mil/mil/input_type.py | 2 +- .../mil/mil/ops/defs/complex_dialect_ops.py | 6 +-- .../passes/defs/lower_complex_dialect_ops.py | 54 ++++++++++--------- 4 files changed, 42 insertions(+), 40 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index aea09e0b9..b3fb470da 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -9504,9 +9504,8 @@ def test_stft(self, compute_unit, backend, input_shape, complex, n_fft, hop_leng class STFTModel(torch.nn.Module): def forward(self, x): applied_window = window(win_length) if window and win_length else None - x = torch.complex(x, x) if complex else x x = torch.stft( - x, + torch.complex(x, x) if complex else x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, @@ -9534,28 +9533,26 @@ class TestISTFT(TorchBaseTest): compute_units, backends, [(1, 32, 9), (32, 9), (3, 32, 9)], # input shape - [False, True], # complex [16], # n_fft [None, 4, 5], # hop_length [None, 16, 9], # win_length [None, torch.hann_window], # window [None, False, True], # center - ["constant", "reflect", "replicate"], # pad mode [False, True], # normalized [None, False, True], # onesided [None, 60], # length + [False, True], # return_complex ) ) - def test_istft(self, compute_unit, backend, input_shape, complex, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided): - if complex and onesided: - pytest.skip("Onesided stft not possible for complex inputs") + def test_istft(self, compute_unit, backend, input_shape, n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex): + if return_complex and onesided: + pytest.skip("Complex output is incompatible with onesided") class ISTFTModel(torch.nn.Module): def forward(self, x): applied_window = window(win_length) if window and win_length else None - x = torch.complex(x, x) x = torch.istft( - x, + torch.complex(x, x), n_fft=n_fft, hop_length=hop_length, win_length=win_length, @@ -9564,8 +9561,9 @@ def forward(self, x): normalized=normalized, onesided=onesided, length=length, - return_complex=True) - x = torch.stack([torch.real(x), torch.imag(x)], dim=0) + return_complex=return_complex) + if return_complex: + x = torch.stack([torch.real(x), torch.imag(x)], dim=0) return x TorchBaseTest.run_compare_torch( diff --git a/coremltools/converters/mil/mil/input_type.py b/coremltools/converters/mil/mil/input_type.py index 97f674974..51d5e9351 100644 --- a/coremltools/converters/mil/mil/input_type.py +++ b/coremltools/converters/mil/mil/input_type.py @@ -251,7 +251,7 @@ class TensorInputType(_InputType): class conv(Operation): input_spec = InputSpec( x=TensorInputType(type_domain="T"), - weight=TensorInputType(type_domain="U"), + weight=TensorInputType(type_domain="T"), ) type_domains = { diff --git a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py index 44d262d25..ea9c13ce4 100644 --- a/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/ops/defs/complex_dialect_ops.py @@ -893,6 +893,7 @@ class complex_istft(Operation): Attributes ---------- + V: complex64 T: fp32, complex64 References @@ -901,7 +902,7 @@ class complex_istft(Operation): """ input_spec = InputSpec( - input=TensorInputType(type_domain="T"), + input=TensorInputType(type_domain="V"), n_fft=TensorInputType(const=True, type_domain=types.int32), hop_length=TensorInputType(const=True, optional=True, type_domain=types.int32), win_length=TensorInputType(const=True, optional=True, type_domain=types.int32), @@ -912,7 +913,7 @@ class complex_istft(Operation): ) type_domains = { - "T": (types.fp32, types.complex64), + "V": types.complex64, } def default_inputs(self): @@ -937,7 +938,6 @@ def type_inference(self): output_shape += [self.length] return types.tensor(output_type, tuple(output_shape)) - n_frames = self.input.shape[-1] output_shape = self.n_fft.val + self.hop_length.val * (n_frames - 1) diff --git a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py index 2c37c9b1e..0e14f7249 100644 --- a/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py +++ b/coremltools/converters/mil/mil/passes/defs/lower_complex_dialect_ops.py @@ -325,7 +325,7 @@ def _stft( We can write STFT in terms of convolutions with a DFT kernel. At the end: * The real part output is: cos_base * input_real + sin_base * input_imag - * The imaginary part output is: - (sin_base * input_real - cos_base * input_imag) + * The imaginary part output is: cos_base * input_imag - sin_base * input_real Adapted from: https://github.com/adobe-research/convmelspec/blob/main/convmelspec/mil.py """ hop_length = hop_length or mb.floor_div(x=n_fft, y=4, before_op=before_op) @@ -358,12 +358,13 @@ def _stft( if input_imaginary: signal_imaginary = mb.expand_dims(x=input_imaginary, axes=(1,), before_op=before_op) - # conv with DFT kernel across the input signal - # The DFT matrix is obtained with the equation e^(2pi/N i) but the definition is: - # DFT(x) => X[k] = Σx[n]*e^(-2kpi/N i) - # If x is complex then x[n]=(a+i*b) - # So the real part = Σ(a*cos(2kpi/N)+b*sin(2kpi/N)) - # So the imag part = Σ(b*cos(2kpi/N)-a*sin(2kpi/N)) + # Convolve the DFT kernel with the input signal + # DFT(x[n]) --> X[k] = Σx[n]*e^(-2π*n/N*k), then if x is complex x[n]=(a[n]+i*b[n]) + # real(X[k]) = Σ(a[n]*cos(2π*n/N*k)+b[n]*sin(2π*n/N*k)) + # imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)-a[n]*sin(2π*n/N*k)) + # But because our DFT matrix is obtained with the conjugate --> e^(2π*n/N*k): + # real(X[k]) = Σ(a[n]*cos(2π*n/N*k)-b[n]*sin(2π*n/N*k)) + # imag(X[k]) = Σ(b[n]*cos(2π*n/N*k)+a[n]*sin(2π*n/N*k)) cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) if input_imaginary: @@ -372,11 +373,11 @@ def _stft( # add everything together if input_imaginary: - real_result = mb.add(x=cos_windows_real, y=sin_windows_imag, before_op=before_op) - imag_result = mb.sub(x=cos_windows_imag, y=sin_windows_real, before_op=before_op) + real_result = mb.sub(x=cos_windows_real, y=sin_windows_imag, before_op=before_op) + imag_result = mb.add(x=cos_windows_imag, y=sin_windows_real, before_op=before_op) else: real_result = cos_windows_real - imag_result = mb.sub(x=0., y=sin_windows_real, before_op=before_op) + imag_result = sin_windows_real # reduce the rank of the output if should_increase_rank: @@ -417,10 +418,10 @@ def _istft( # By default, use the entire frame win_length = win_length or n_fft - input_shape = mb.shape(x=x, before_op=before_op) + input_shape = mb.shape(x=input_real, before_op=before_op) n_frames = input_shape.val[-1] fft_size = input_shape.val[-2] - # expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1) + expected_output_signal_len = n_fft.val + hop_length.val * (n_frames - 1) is_onesided = onesided.val if onesided else fft_size != n_fft cos_base, sin_base = _calculate_dft_matrix(n_fft, onesided=is_onesided, before_op=before_op) @@ -447,14 +448,13 @@ def _istft( signal_real = mb.mul(x=signal_real, y=multiplier, before_op=before_op) signal_imaginary = mb.mul(x=signal_imaginary, y=multiplier, before_op=before_op) - # Conv with DFT kernel across the input signal - # We can describe the IDFT in terms of DFT just by swapping the input and output + # Convolve the DFT kernel with the input signal + # We can describe the IDFT in terms of DFT just by swapping the input and output. # ref: https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Expressing_the_inverse_DFT_in_terms_of_the_DFT - # So IDFT(x) = (1/N) * swap(DFT(swap(x))) - # and DFT(x) = X[k] = Σx[n]*e^(-2kpi/N i) but we are using the conjugate e^(2kpi/N i) - # If x is complex then x[n]=(a+i*b) - # then real part = (1/N)*Σ(a*cos(2kpi/N)+b*sin(2kpi/N)) - # then imag part = (1/N)*Σ(b*cos(2kpi/N)-a*sin(2kpi/N)) + # IDFT(X[K]) --> x[n]=(1/N)*swap(DFT(swap(X[k]))), and K=N + # So using the definition in stft function, we get: + # real(x[n]) = Σ(a[k]*cos(2π*k/K*n)+b[k]*sin(2π*k/K*n)) + # imag(x[n]) = Σ(b[k]*cos(2π*k/K*n)-a[k]*sin(2π*k/K*n)) cos_windows_real = mb.conv(x=signal_real, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) sin_windows_real = mb.conv(x=signal_real, weight=sin_base, strides=hop_size, pad_type='valid', before_op=before_op) cos_windows_imag = mb.conv(x=signal_imaginary, weight=cos_base, strides=hop_size, pad_type='valid', before_op=before_op) @@ -750,17 +750,21 @@ def _lower_complex_istft(op: Operation): is_complex = types.is_complex(op.input.dtype) # check parameters for validity + if is_complex: + raise ValueError("Only complex inputs are allowed") if op.win_length and op.win_length.val > op.n_fft.val: raise ValueError("Window length must be less than or equal to n_fft") - if is_complex and op.onesided and op.onesided.val: - raise ValueError("Onesided is only valid for real inputs") + if op.return_complex and op.onesided and op.onesided.val: + raise ValueError("Complex output is not compatible with onesided") real, imag = _istft( - op.input.real if is_complex else op.input, - op.input.imag if is_complex else None, - op.n_fft, op.hop_length, op.win_length, op.window, op.normalized, op.onesided, before_op=op) + op.input.real, op.input.imag, + op.n_fft, op.hop_length, op.win_length, op.window, op.normalized, op.onesided, op.length, before_op=op) - return _wrap_complex_output(op.outputs[0], real, imag) + if op.return_complex: + return _wrap_complex_output(op.outputs[0], real, imag) + else + return real @LowerComplex.register_lower_func(op_type="complex_shape")