Source code for nemo.collections.audio.modules.transforms

# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Tuple

import torch
from einops import rearrange

from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like
from nemo.core.classes import NeuralModule, typecheck
from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType, SpectrogramType
from nemo.utils import logging


class AudioToSpectrogram(NeuralModule):
    """Transform a batch of input multi-channel signals into a batch of STFT-based spectrograms.

    Args:
        fft_length: length of FFT
        hop_length: length of hops/shifts of the sliding window
        magnitude_power: Transform magnitude of the spectrogram as x^magnitude_power.
        scale: Positive scaling of the spectrogram.
    """

    def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0):
        super().__init__()

        # For now, assume FFT length is divisible by two
        if fft_length % 2 != 0:
            raise ValueError(f'fft_length = {fft_length} must be divisible by 2')

        self.fft_length = fft_length
        self.hop_length = hop_length
        self.pad_mode = 'constant'

        window = torch.hann_window(self.win_length)
        self.register_buffer('window', window)

        self.num_subbands = fft_length // 2 + 1

        if magnitude_power <= 0:
            raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}')
        self.magnitude_power = magnitude_power

        if scale <= 0:
            raise ValueError(f'Scale needs to be positive: current value {scale}')
        self.scale = scale

        logging.debug('Initialized %s with:', self.__class__.__name__)
        logging.debug('\tfft_length: %s', fft_length)
        logging.debug('\thop_length: %s', hop_length)
        logging.debug('\tmagnitude_power: %s', magnitude_power)
        logging.debug('\tscale: %s', scale)

    @property
    def win_length(self) -> int:
        return self.fft_length

    def stft(self, x: torch.Tensor):
        """Apply STFT as in torchaudio.transforms.Spectrogram(power=None)

        Args:
            x: Input time-domain signal, shape (..., T)

        Returns:
            Spectrogram ``x_spec = STFT(x)``, shape (..., F, N).
        """
        # pack batch
        B, C, T = x.size()
        x = rearrange(x, 'B C T -> (B C) T')

        x_spec = torch.stft(
            input=x,
            n_fft=self.fft_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=True,
            pad_mode=self.pad_mode,
            normalized=False,
            onesided=True,
            return_complex=True,
        )

        # unpack batch
        x_spec = rearrange(x_spec, '(B C) F N -> B C F N', B=B, C=C)

        return x_spec

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "input": NeuralType(('B', 'C', 'T'), AudioSignal()),
            "input_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @typecheck()
    def forward(
        self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert a batch of C-channel input signals into a batch of complex-valued spectrograms.

        Args:
            input: Time-domain input signal with C channels, shape (B, C, T)
            input_length: Length of valid entries along the time dimension, shape (B,)

        Returns:
            Output spectrogram with F subbands and N time frames, shape (B, C, F, N),
            and output length with shape (B,).
        """
        B, T = input.size(0), input.size(-1)
        input = input.view(B, -1, T)

        # STFT output (B, C, F, N)
        with torch.amp.autocast(input.device.type, enabled=False):
            output = self.stft(input.float())

            if self.magnitude_power != 1:
                # apply power on the magnitude
                output = torch.pow(output.abs(), self.magnitude_power) * torch.exp(1j * output.angle())

            if self.scale != 1:
                # apply scaling of the coefficients
                output = self.scale * output

        if input_length is not None:
            # Mask padded frames
            output_length = self.get_output_length(input_length=input_length)

            length_mask: torch.Tensor = make_seq_mask_like(
                lengths=output_length, like=output, time_dim=-1, valid_ones=False
            )
            output = output.masked_fill(length_mask, 0.0)
        else:
            # Assume all frames are valid for all examples in the batch
            output_length = output.size(-1) * torch.ones(B, device=output.device).long()

        return output, output_length

    def get_output_length(self, input_length: torch.Tensor) -> torch.Tensor:
        """Get length of valid frames for the output.

        Args:
            input_length: number of valid samples, shape (B,)

        Returns:
            Number of valid frames, shape (B,)
        """
        # centered STFT results in (T // hop_length + 1) frames for T samples (cf. torch.stft)
        output_length = input_length.div(self.hop_length, rounding_mode='floor').add(1).long()
        return output_length
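
# Usage sketch (illustrative, not part of the NeMo module): analysis of a
# random two-example batch with assumed parameters fft_length=512 and
# hop_length=128. It shows the (B, C, T) -> (B, C, F, N) shape change and the
# frame count N = T // hop_length + 1 computed by get_output_length above.
#
#     audio2spec = AudioToSpectrogram(fft_length=512, hop_length=128)
#     x = torch.randn(2, 1, 16000)  # (B, C, T): batch of mono signals
#     x_len = torch.tensor([16000, 12000])  # valid samples per example
#     spec, spec_len = audio2spec(input=x, input_length=x_len)
#     assert spec.shape == (2, 1, 257, 126)  # F = 512 // 2 + 1, N = 16000 // 128 + 1
#     assert spec_len.tolist() == [126, 94]  # 94 = 12000 // 128 + 1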

class SpectrogramToAudio(NeuralModule):
    """Transform a batch of input multi-channel spectrograms into a batch of time-domain multi-channel signals.

    Args:
        fft_length: length of FFT
        hop_length: length of hops/shifts of the sliding window
        magnitude_power: Transform magnitude of the spectrogram as x^(1/magnitude_power).
        scale: Spectrogram will be scaled with 1/scale before the inverse transform.
    """

    def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0):
        super().__init__()

        # For now, assume FFT length is divisible by two
        if fft_length % 2 != 0:
            raise ValueError(f'fft_length = {fft_length} must be divisible by 2')

        self.fft_length = fft_length
        self.hop_length = hop_length

        window = torch.hann_window(self.win_length)
        self.register_buffer('window', window)

        self.num_subbands = fft_length // 2 + 1

        if magnitude_power <= 0:
            raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}')
        self.magnitude_power = magnitude_power

        if scale <= 0:
            raise ValueError(f'Scale needs to be positive: current value {scale}')
        self.scale = scale

        logging.debug('Initialized %s with:', self.__class__.__name__)
        logging.debug('\tfft_length: %s', fft_length)
        logging.debug('\thop_length: %s', hop_length)
        logging.debug('\tmagnitude_power: %s', magnitude_power)
        logging.debug('\tscale: %s', scale)

    @property
    def win_length(self) -> int:
        return self.fft_length

    def istft(self, x_spec: torch.Tensor):
        """Apply iSTFT as in torchaudio.transforms.InverseSpectrogram

        Args:
            x_spec: Input complex-valued spectrogram, shape (..., F, N)

        Returns:
            Time-domain signal ``x = iSTFT(x_spec)``, shape (..., T).
        """
        # pack batch
        B, C, F, N = x_spec.size()
        x_spec = rearrange(x_spec, 'B C F N -> (B C) F N')

        x = torch.istft(
            input=x_spec,
            n_fft=self.fft_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=True,
            normalized=False,
            onesided=True,
            length=None,
            return_complex=False,
        )

        # unpack batch
        x = rearrange(x, '(B C) T -> B C T', B=B, C=C)

        return x

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'C', 'T'), AudioSignal()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @typecheck()
    def forward(
        self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert input complex-valued spectrogram to a time-domain signal. Multi-channel IO is supported.

        Args:
            input: Input spectrogram for C channels, shape (B, C, F, N)
            input_length: Length of valid entries along the time dimension, shape (B,)

        Returns:
            Time-domain signal with T time-domain samples and C channels, shape (B, C, T),
            and output length with shape (B,).
        """
        B, F, N = input.size(0), input.size(-2), input.size(-1)
        assert F == self.num_subbands, f'Number of subbands F={F} not matching self.num_subbands={self.num_subbands}'
        input = input.view(B, -1, F, N)

        if not input.is_complex():
            raise ValueError("Expected `input` to be complex dtype.")

        # iSTFT output (B, C, T)
        with torch.amp.autocast(input.device.type, enabled=False):
            output = input.cfloat()

            if self.scale != 1:
                # apply 1/scale on the coefficients
                output = output / self.scale

            if self.magnitude_power != 1:
                # apply 1/power on the magnitude
                output = torch.pow(output.abs(), 1 / self.magnitude_power) * torch.exp(1j * output.angle())

            output = self.istft(output)

        if input_length is not None:
            # Mask padded samples
            output_length = self.get_output_length(input_length=input_length)

            length_mask: torch.Tensor = make_seq_mask_like(
                lengths=output_length, like=output, time_dim=-1, valid_ones=False
            )
            output = output.masked_fill(length_mask, 0.0)
        else:
            # Assume all samples are valid for all examples in the batch
            output_length = output.size(-1) * torch.ones(B, device=output.device).long()

        return output, output_length

    def get_output_length(self, input_length: torch.Tensor) -> torch.Tensor:
        """Get length of valid samples for the output.

        Args:
            input_length: number of valid frames, shape (B,)

        Returns:
            Number of valid samples, shape (B,)
        """
        # centered iSTFT results in ((N - 1) * hop_length) time samples for N frames (cf. torch.istft)
        output_length = input_length.sub(1).mul(self.hop_length).long()
        return output_length
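
if __name__ == '__main__':
    # Round-trip sketch (illustrative, not part of the NeMo module): analysis
    # followed by synthesis with matching, assumed parameters. With a Hann
    # window, centered frames, and hop_length = fft_length // 4, torch.istft
    # gives near-perfect reconstruction of the valid samples.
    audio2spec = AudioToSpectrogram(fft_length=512, hop_length=128)
    spec2audio = SpectrogramToAudio(fft_length=512, hop_length=128)

    x = torch.randn(1, 2, 16000)  # (B, C, T): one example with two channels
    x_len = torch.tensor([16000])

    spec, spec_len = audio2spec(input=x, input_length=x_len)
    y, y_len = spec2audio(input=spec, input_length=spec_len)

    # centered iSTFT returns (N - 1) * hop_length samples for N frames
    assert y_len.item() == (spec_len.item() - 1) * 128
    num_samples = min(x.size(-1), y.size(-1))
    print('max reconstruction error:', (x[..., :num_samples] - y[..., :num_samples]).abs().max().item())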