# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Tuple
import torch
from einops import rearrange
from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like
from nemo.core.classes import NeuralModule, typecheck
from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType, SpectrogramType
from nemo.utils import logging
class AudioToSpectrogram(NeuralModule):
"""Transform a batch of input multi-channel signals into a batch of
STFT-based spectrograms.
Args:
fft_length: length of FFT
hop_length: length of hops/shifts of the sliding window
magnitude_power: Transform magnitude of the spectrogram as x^magnitude_power.
scale: Positive scaling of the spectrogram.
"""
def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0):
super().__init__()
# For now, assume FFT length is divisible by two
if fft_length % 2 != 0:
raise ValueError(f'fft_length = {fft_length} must be divisible by 2')
self.fft_length = fft_length
self.hop_length = hop_length
self.pad_mode = 'constant'
window = torch.hann_window(self.win_length)
self.register_buffer('window', window)
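        # One-sided STFT (onesided=True below) keeps only the non-negative
        # frequency bins, hence fft_length // 2 + 1 subbands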
self.num_subbands = fft_length // 2 + 1
if magnitude_power <= 0:
raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}')
self.magnitude_power = magnitude_power
if scale <= 0:
raise ValueError(f'Scale needs to be positive: current value {scale}')
self.scale = scale
logging.debug('Initialized %s with:', self.__class__.__name__)
logging.debug('\tfft_length: %s', fft_length)
logging.debug('\thop_length: %s', hop_length)
logging.debug('\tmagnitude_power: %s', magnitude_power)
logging.debug('\tscale: %s', scale)
@property
def win_length(self) -> int:
return self.fft_length
def stft(self, x: torch.Tensor):
"""Apply STFT as in torchaudio.transforms.Spectrogram(power=None)
Args:
            x: Input time-domain signal, shape (..., T)
        Returns:
            Complex-valued spectrogram ``x_spec = STFT(x)``, shape (..., F, N).
"""
# pack batch
B, C, T = x.size()
x = rearrange(x, 'B C T -> (B C) T')
x_spec = torch.stft(
input=x,
n_fft=self.fft_length,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
center=True,
pad_mode=self.pad_mode,
normalized=False,
onesided=True,
return_complex=True,
)
# unpack batch
x_spec = rearrange(x_spec, '(B C) F N -> B C F N', B=B, C=C)
return x_spec
@property
def input_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports."""
return {
"input": NeuralType(('B', 'C', 'T'), AudioSignal()),
"input_length": NeuralType(('B',), LengthsType(), optional=True),
}
@property
def output_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports."""
return {
"output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
"output_length": NeuralType(('B',), LengthsType()),
}
@typecheck()
def forward(
self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Convert a batch of C-channel input signals
into a batch of complex-valued spectrograms.
Args:
input: Time-domain input signal with C channels, shape (B, C, T)
input_length: Length of valid entries along the time dimension, shape (B,)
Returns:
Output spectrogram with F subbands and N time frames, shape (B, C, F, N)
and output length with shape (B,).
"""
B, T = input.size(0), input.size(-1)
input = input.view(B, -1, T)
# STFT output (B, C, F, N)
with torch.amp.autocast(input.device.type, enabled=False):
output = self.stft(input.float())
if self.magnitude_power != 1:
# apply power on the magnitude
output = torch.pow(output.abs(), self.magnitude_power) * torch.exp(1j * output.angle())
if self.scale != 1:
# apply scaling of the coefficients
output = self.scale * output
if input_length is not None:
# Mask padded frames
output_length = self.get_output_length(input_length=input_length)
length_mask: torch.Tensor = make_seq_mask_like(
lengths=output_length, like=output, time_dim=-1, valid_ones=False
)
output = output.masked_fill(length_mask, 0.0)
else:
# Assume all frames are valid for all examples in the batch
output_length = output.size(-1) * torch.ones(B, device=output.device).long()
return output, output_length
def get_output_length(self, input_length: torch.Tensor) -> torch.Tensor:
"""Get length of valid frames for the output.
Args:
input_length: number of valid samples, shape (B,)
Returns:
Number of valid frames, shape (B,)
"""
# centered STFT results in (T // hop_length + 1) frames for T samples (cf. torch.stft)
output_length = input_length.div(self.hop_length, rounding_mode='floor').add(1).long()
return output_length
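
# A minimal usage sketch for AudioToSpectrogram (illustration only, not part of
# the original module). Shapes follow the formulas above: F = fft_length // 2 + 1
# subbands and N = T // hop_length + 1 frames for the centered STFT.
#
#   analysis = AudioToSpectrogram(fft_length=512, hop_length=128)
#   audio = torch.randn(2, 1, 16000)  # (B, C, T)
#   spec, spec_len = analysis(input=audio, input_length=torch.tensor([16000, 12000]))
#   # spec.shape == (2, 1, 257, 126), spec_len == tensor([126, 94])
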
class SpectrogramToAudio(NeuralModule):
"""Transform a batch of input multi-channel spectrograms into a batch of
time-domain multi-channel signals.
Args:
fft_length: length of FFT
hop_length: length of hops/shifts of the sliding window
magnitude_power: Transform magnitude of the spectrogram as x^(1/magnitude_power).
scale: Spectrogram will be scaled with 1/scale before the inverse transform.
"""
def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0):
super().__init__()
# For now, assume FFT length is divisible by two
if fft_length % 2 != 0:
raise ValueError(f'fft_length = {fft_length} must be divisible by 2')
self.fft_length = fft_length
self.hop_length = hop_length
window = torch.hann_window(self.win_length)
self.register_buffer('window', window)
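        # One-sided spectrogram (onesided=True below) has only the non-negative
        # frequency bins, hence fft_length // 2 + 1 subbands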
self.num_subbands = fft_length // 2 + 1
if magnitude_power <= 0:
raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}')
self.magnitude_power = magnitude_power
if scale <= 0:
raise ValueError(f'Scale needs to be positive: current value {scale}')
self.scale = scale
logging.debug('Initialized %s with:', self.__class__.__name__)
logging.debug('\tfft_length: %s', fft_length)
logging.debug('\thop_length: %s', hop_length)
logging.debug('\tmagnitude_power: %s', magnitude_power)
logging.debug('\tscale: %s', scale)
@property
def win_length(self) -> int:
return self.fft_length
def istft(self, x_spec: torch.Tensor):
"""Apply iSTFT as in torchaudio.transforms.InverseSpectrogram
Args:
x_spec: Input complex-valued spectrogram, shape (..., F, N)
Returns:
Time-domain signal ``x = iSTFT(x_spec)``, shape (..., T).
"""
# pack batch
B, C, F, N = x_spec.size()
x_spec = rearrange(x_spec, 'B C F N -> (B C) F N')
x = torch.istft(
input=x_spec,
n_fft=self.fft_length,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
center=True,
normalized=False,
onesided=True,
length=None,
return_complex=False,
)
# unpack batch
x = rearrange(x, '(B C) T -> B C T', B=B, C=C)
return x
@property
def input_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports."""
return {
"input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
"input_length": NeuralType(('B',), LengthsType(), optional=True),
}
@property
def output_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports."""
return {
"output": NeuralType(('B', 'C', 'T'), AudioSignal()),
"output_length": NeuralType(('B',), LengthsType()),
}
@typecheck()
    def forward(
        self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Convert input complex-valued spectrogram to a time-domain
signal. Multi-channel IO is supported.
Args:
input: Input spectrogram for C channels, shape (B, C, F, N)
input_length: Length of valid entries along the time dimension, shape (B,)
Returns:
            Time-domain signal with T samples and C channels, shape (B, C, T),
and output length with shape (B,).
"""
B, F, N = input.size(0), input.size(-2), input.size(-1)
assert F == self.num_subbands, f'Number of subbands F={F} not matching self.num_subbands={self.num_subbands}'
input = input.view(B, -1, F, N)
if not input.is_complex():
raise ValueError("Expected `input` to be complex dtype.")
# iSTFT output (B, C, T)
with torch.amp.autocast(input.device.type, enabled=False):
output = input.cfloat()
if self.scale != 1:
# apply 1/scale on the coefficients
output = output / self.scale
if self.magnitude_power != 1:
# apply 1/power on the magnitude
output = torch.pow(output.abs(), 1 / self.magnitude_power) * torch.exp(1j * output.angle())
output = self.istft(output)
if input_length is not None:
# Mask padded samples
output_length = self.get_output_length(input_length=input_length)
length_mask: torch.Tensor = make_seq_mask_like(
lengths=output_length, like=output, time_dim=-1, valid_ones=False
)
output = output.masked_fill(length_mask, 0.0)
else:
            # Assume all samples are valid for all examples in the batch
output_length = output.size(-1) * torch.ones(B, device=output.device).long()
return output, output_length
def get_output_length(self, input_length: torch.Tensor) -> torch.Tensor:
"""Get length of valid samples for the output.
Args:
input_length: number of valid frames, shape (B,)
Returns:
Number of valid samples, shape (B,)
"""
# centered STFT results in ((N-1) * hop_length) time samples for N frames (cf. torch.istft)
output_length = input_length.sub(1).mul(self.hop_length).long()
return output_length
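
# A round-trip sketch (illustration only, not part of the original module): with
# matching parameters, SpectrogramToAudio inverts AudioToSpectrogram and the
# output length follows (N - 1) * hop_length for N valid frames.
#
#   analysis = AudioToSpectrogram(fft_length=512, hop_length=128)
#   synthesis = SpectrogramToAudio(fft_length=512, hop_length=128)
#   audio = torch.randn(1, 1, 16000)
#   spec, spec_len = analysis(input=audio, input_length=torch.tensor([16000]))
#   audio_hat, audio_len = synthesis(input=spec, input_length=spec_len)
#   # audio_hat.shape == (1, 1, 16000), audio_len == tensor([16000])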