# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional
import einops
import hydra
import torch
from lightning.pytorch import Trainer
from omegaconf import DictConfig
from nemo.collections.audio.models.audio_to_audio import AudioToAudioModel
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types import AudioSignal, LengthsType, LossType, NeuralType
from nemo.utils import logging
__all__ = [
'EncMaskDecAudioToAudioModel',
'ScoreBasedGenerativeAudioToAudioModel',
'PredictiveAudioToAudioModel',
'SchroedingerBridgeAudioToAudioModel',
'FlowMatchingAudioToAudioModel',
]
class EncMaskDecAudioToAudioModel(AudioToAudioModel):
"""Class for encoder-mask-decoder audio processing models.
The model consists of the following blocks:
- encoder: transforms input multi-channel audio signal into an encoded representation (analysis transform)
- mask_estimator: estimates a mask used by the signal processor
- mask_processor: mask-based signal processor, combines the encoded input and the estimated mask
- decoder: transforms processor output into the time domain (synthesis transform)
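Example:
A minimal config sketch of the four blocks (illustrative only; the
``_target_`` values are placeholders, not verified module paths)::

sample_rate: 16000
encoder: ...        # analysis transform, e.g. an STFT-based transform
mask_estimator: ... # neural network estimating a mask from the encoded input
mask_processor: ... # e.g. multiplicative masking in the encoded domain
decoder: ...        # synthesis transform, e.g. an inverse STFT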
"""
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
# global_rank and local_rank are set by the LightningModule in Lightning 1.2.0
self.world_size = 1
if trainer is not None:
self.world_size = trainer.world_size
super().__init__(cfg=cfg, trainer=trainer)
self.sample_rate = self._cfg.sample_rate
# Setup processing modules
self.encoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.encoder)
self.mask_estimator = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mask_estimator)
self.mask_processor = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mask_processor)
self.decoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.decoder)
if 'mixture_consistency' in self._cfg:
logging.debug('Using mixture consistency')
self.mixture_consistency = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mixture_consistency)
else:
logging.debug('Mixture consistency not used')
self.mixture_consistency = None
# Setup augmentation
if hasattr(self.cfg, 'channel_augment') and self.cfg.channel_augment is not None:
logging.debug('Using channel augmentation')
self.channel_augmentation = EncMaskDecAudioToAudioModel.from_config_dict(self.cfg.channel_augment)
else:
logging.debug('Channel augmentation not used')
self.channel_augmentation = None
# Setup optional Optimization flags
self.setup_optimization_flags()
@property
def input_types(self) -> Dict[str, NeuralType]:
return {
"input_signal": NeuralType(
('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)
), # multi-channel format, channel dimension can be 1 for single-channel audio
"input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
}
@property
def output_types(self) -> Dict[str, NeuralType]:
return {
"output_signal": NeuralType(
('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)
), # multi-channel format, channel dimension can be 1 for single-channel audio
"output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
}
@typecheck()
def forward(self, input_signal, input_length=None):
"""
Forward pass of the model.
Args:
input_signal: Tensor that represents a batch of raw audio signals,
of shape [B, C, T]. T here represents timesteps, with 1 second of audio represented as
`self.sample_rate` number of floating point values.
input_length: Vector of length B, that contains the individual lengths of the audio
sequences.
Returns:
Output signal `output` in the time domain and the length of the output signal `output_length`.
"""
batch_length = input_signal.size(-1)
# Encoder
encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length)
# Mask estimator
mask, _ = self.mask_estimator(input=encoded, input_length=encoded_length)
# Mask-based processor in the encoded domain
processed, processed_length = self.mask_processor(input=encoded, input_length=encoded_length, mask=mask)
# Mixture consistency
if self.mixture_consistency is not None:
processed = self.mixture_consistency(mixture=encoded, estimate=processed)
# Decoder
processed, processed_length = self.decoder(input=processed, input_length=processed_length)
# Trim or pad the estimated signal to match input length
processed = self.match_batch_length(input=processed, batch_length=batch_length)
return processed, processed_length
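# Illustrative usage sketch (the checkpoint path is a placeholder;
# `restore_from` is the generic NeMo ModelPT restore API):
#
#   model = EncMaskDecAudioToAudioModel.restore_from('enhancer.nemo')
#   model.eval()
#   audio = torch.randn(2, 1, 16000)  # [B, C, T], 1 s of audio at 16 kHz
#   lengths = torch.tensor([16000, 12000])
#   enhanced, enhanced_len = model(input_signal=audio, input_length=lengths)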
# PTL-specific methods
def training_step(self, batch, batch_idx):
if isinstance(batch, dict):
# lhotse batches are dictionaries
input_signal = batch['input_signal']
input_length = batch['input_length']
target_signal = batch['target_signal']
else:
input_signal, input_length, target_signal, _ = batch
# For consistency, the model uses multi-channel format, even if the channel dimension is 1
if input_signal.ndim == 2:
input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
if target_signal.ndim == 2:
target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
# Apply channel augmentation
if self.training and self.channel_augmentation is not None:
input_signal = self.channel_augmentation(input=input_signal)
# Process input
processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length)
# Calculate the loss
loss = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length)
# Logs
self.log('train_loss', loss)
self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))
# Return loss
return loss
def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
if isinstance(batch, dict):
# lhotse batches are dictionaries
input_signal = batch['input_signal']
input_length = batch['input_length']
target_signal = batch['target_signal']
else:
input_signal, input_length, target_signal, _ = batch
# For consistency, the model uses multi-channel format, even if the channel dimension is 1
if input_signal.ndim == 2:
input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
if target_signal.ndim == 2:
target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
# Process input
processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length)
# Calculate the loss
loss = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length)
# Update metrics
if hasattr(self, 'metrics') and tag in self.metrics:
# Update metrics for this (tag, dataloader_idx)
for name, metric in self.metrics[tag][dataloader_idx].items():
metric.update(preds=processed_signal, target=target_signal, input_length=input_length)
# Log global step
self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))
# Return loss
return {f'{tag}_loss': loss}
@classmethod
def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
"""
This method returns a list of pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.
Returns:
List of available pre-trained models.
"""
results = []
return results
class PredictiveAudioToAudioModel(AudioToAudioModel):
"""This models aims to directly estimate the coefficients
in the encoded domain by applying a neural model.
"""
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
super().__init__(cfg=cfg, trainer=trainer)
self.sample_rate = self._cfg.sample_rate
# Setup processing modules
self.encoder = self.from_config_dict(self._cfg.encoder)
self.decoder = self.from_config_dict(self._cfg.decoder)
# Neural estimator
self.estimator = self.from_config_dict(self._cfg.estimator)
# Normalization
self.normalize_input = self._cfg.get('normalize_input', False)
# Term added to the denominator to improve numerical stability
self.eps = self._cfg.get('eps', 1e-8)
# Setup optional Optimization flags
self.setup_optimization_flags()
logging.debug('Initialized %s', self.__class__.__name__)
logging.debug('\tnormalize_input: %s', self.normalize_input)
logging.debug('\teps: %s', self.eps)
@property
def input_types(self) -> Dict[str, NeuralType]:
return {
"input_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
"input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
}
@property
def output_types(self) -> Dict[str, NeuralType]:
return {
"output_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
"output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
}
@typecheck()
def forward(self, input_signal, input_length=None):
"""Forward pass of the model.
Args:
input_signal: time-domain signal
input_length: valid length of each example in the batch
Returns:
Output signal `output` in the time domain and the length of the output signal `output_length`.
"""
batch_length = input_signal.size(-1)
if self.normalize_input:
# max for each example in the batch
norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True)
# scale input signal
input_signal = input_signal / (norm_scale + self.eps)
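# Example: for input_signal of shape [B, C, T], amax over dims (-1, -2) with
# keepdim=True yields norm_scale of shape [B, 1, 1], i.e. one scale per
# example, broadcast over channels and time.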
# Encoder
encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length)
# Backbone
estimated, estimated_length = self.estimator(input=encoded, input_length=encoded_length)
# Decoder
output, output_length = self.decoder(input=estimated, input_length=estimated_length)
if self.normalize_input:
# rescale to the original scale
output = output * norm_scale
# Trim or pad the estimated signal to match input length
output = self.match_batch_length(input=output, batch_length=batch_length)
return output, output_length
# PTL-specific methods
def training_step(self, batch, batch_idx):
if isinstance(batch, dict):
# lhotse batches are dictionaries
input_signal = batch['input_signal']
input_length = batch['input_length']
target_signal = batch['target_signal']
else:
input_signal, input_length, target_signal, _ = batch
# For consistency, the model uses multi-channel format, even if the channel dimension is 1
if input_signal.ndim == 2:
input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
if target_signal.ndim == 2:
target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
# Estimate the signal
output_signal, _ = self.forward(input_signal=input_signal, input_length=input_length)
# Calculate the loss
loss = self.loss(estimate=output_signal, target=target_signal, input_length=input_length)
# Logs
self.log('train_loss', loss)
self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))
return loss
def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
if isinstance(batch, dict):
# lhotse batches are dictionaries
input_signal = batch['input_signal']
input_length = batch['input_length']
target_signal = batch['target_signal']
else:
input_signal, input_length, target_signal, _ = batch
# For consistency, the model uses multi-channel format, even if the channel dimension is 1
if input_signal.ndim == 2:
input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
if target_signal.ndim == 2:
target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
# Estimate the signal
output_signal, _ = self.forward(input_signal=input_signal, input_length=input_length)
# Calculate the loss
loss = self.loss(estimate=output_signal, target=target_signal, input_length=input_length)
# Update metrics
if hasattr(self, 'metrics') and tag in self.metrics:
# Update metrics for this (tag, dataloader_idx)
for name, metric in self.metrics[tag][dataloader_idx].items():
metric.update(preds=output_signal, target=target_signal, input_length=input_length)
# Log global step
self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))
return {f'{tag}_loss': loss}
class ScoreBasedGenerativeAudioToAudioModel(AudioToAudioModel):
"""This models is using a score-based diffusion process to generate
an encoded representation of the enhanced signal.
The model consists of the following blocks:
- encoder: transforms input multi-channel audio signal into an encoded representation (analysis transform)
- estimator: neural model, estimates a score for the diffusion process
- sde: stochastic differential equation (SDE) defining the forward and reverse diffusion process
- sampler: sampler for the reverse diffusion process, estimates coefficients of the target signal
- decoder: transforms sampler output into the time domain (synthesis transform)
"""
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
super().__init__(cfg=cfg, trainer=trainer)
self.sample_rate = self._cfg.sample_rate
# Setup processing modules
self.encoder = self.from_config_dict(self._cfg.encoder)
self.decoder = self.from_config_dict(self._cfg.decoder)
# Neural score estimator
self.estimator = self.from_config_dict(self._cfg.estimator)
# SDE
self.sde = self.from_config_dict(self._cfg.sde)
# Sampler
if 'sde' in self._cfg.sampler:
raise ValueError('SDE should be defined in the model config, not in the sampler config')
if 'score_estimator' in self._cfg.sampler:
raise ValueError('Score estimator should be defined in the model config, not in the sampler config')
self.sampler = hydra.utils.instantiate(self._cfg.sampler, sde=self.sde, score_estimator=self.estimator)
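# The sampler is instantiated via hydra so that the already-constructed
# `sde` and `estimator` modules can be passed in as Python objects instead of
# being re-created from the config.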
# Normalization
self.normalize_input = self._cfg.get('normalize_input', False)
# Metric evaluation
self.max_utts_evaluation_metrics = self._cfg.get('max_utts_evaluation_metrics')
if self.max_utts_evaluation_metrics is not None:
logging.warning(
'Metrics will be evaluated on first %d examples of the evaluation datasets.',
self.max_utts_evaluation_metrics,
)
# Term added to the denominator to improve numerical stability
self.eps = self._cfg.get('eps', 1e-8)
# Setup optional Optimization flags
self.setup_optimization_flags()
logging.debug('Initialized %s', self.__class__.__name__)
logging.debug('\tnormalize_input: %s', self.normalize_input)
logging.debug('\teps: %s', self.eps)
@property
def input_types(self) -> Dict[str, NeuralType]:
return {
"input_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
"input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
}
@property
def output_types(self) -> Dict[str, NeuralType]:
return {
"output_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
"output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
}
@typecheck()
@torch.inference_mode()
def forward(self, input_signal, input_length=None):
"""Forward pass of the model.
The forward pass applies the following steps:
- encoder to obtain the encoded representation of the input signal
- sampler to generate the estimated coefficients of the target signal
- decoder to transform the sampler output into the time domain
Args:
input_signal: Tensor that represents a batch of time-domain audio signals,
of shape [B, C, T]. T here represents timesteps, with 1 second of audio represented as
`self.sample_rate` number of floating point values.
input_length: Vector of length B, contains the individual lengths of the audio sequences.
Returns:
Output `output_signal` in the time domain and the length of the output signal `output_length`.
"""
batch_length = input_signal.size(-1)
if self.normalize_input:
# max for each example in the batch
norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True)
# scale input signal
input_signal = input_signal / (norm_scale + self.eps)
# Encoder
encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length)
# Sampler
generated, generated_length = self.sampler(
prior_mean=encoded, score_condition=encoded, state_length=encoded_length
)
# Decoder
output, output_length = self.decoder(input=generated, input_length=generated_length)
if self.normalize_input:
# rescale to the original scale
output = output * norm_scale
# Trim or pad the estimated signal to match input length
output = self.match_batch_length(input=output, batch_length=batch_length)
return output, output_length
@typecheck(
input_types={
"target_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
"input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
"input_length": NeuralType(tuple('B'), LengthsType()),
},
output_types={
"loss": NeuralType(None, LossType()),
},
)
def _step(self, target_signal, input_signal, input_length=None):
"""Randomly generate a time step for each example in the batch, estimate
the score and calculate the loss value.
Note that this step does not include the sampler.
"""
batch_size = target_signal.size(0)
if self.normalize_input:
# max for each example in the batch
norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True)
# scale input signal
input_signal = input_signal / (norm_scale + self.eps)
# scale the target signal
target_signal = target_signal / (norm_scale + self.eps)
# Apply encoder to both target and the input
input_enc, input_enc_len = self.encoder(input=input_signal, input_length=input_length)
target_enc, _ = self.encoder(input=target_signal, input_length=input_length)
# Generate random time steps
sde_time = self.sde.generate_time(size=batch_size, device=input_enc.device)
# Get the mean and the variance of the perturbation kernel
pk_mean, pk_std = self.sde.perturb_kernel_params(state=target_enc, prior_mean=input_enc, time=sde_time)
# Generate a random sample from a standard normal distribution
z_norm = torch.randn_like(input_enc)
# Prepare perturbed data
perturbed_enc = pk_mean + pk_std * z_norm
# Score is conditioned on the perturbed data and the input
estimator_input = torch.cat([perturbed_enc, input_enc], dim=-3)
# Estimate the score using the neural estimator
# SDE time is used to inform the estimator about the current time step
# Note:
# - some implementations use `score = -self._raw_dnn_output(x, t, y)`
# - this seems to be unimportant, and is an artifact of transferring code from Song's original repo
score_est, score_len = self.estimator(input=estimator_input, input_length=input_enc_len, condition=sde_time)
# Score loss weighting as in Section 4.2 in http://arxiv.org/abs/1907.05600
score_est = score_est * pk_std
score_ref = -z_norm
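# Why the target is -z_norm: the perturbation kernel is N(pk_mean, pk_std**2 I),
# so the score of the perturbed sample is
#   grad log p(perturbed_enc) = -(perturbed_enc - pk_mean) / pk_std**2 = -z_norm / pk_std
# and scaling the estimate by pk_std turns the matching target into -z_norm,
# i.e., a loss weighted by lambda(t) = pk_std**2 as in the reference above.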
# Score matching loss on the normalized scores
loss = self.loss(estimate=score_est, target=score_ref, input_length=score_len)
return loss
# PTL-specific methods
def training_step(self, batch, batch_idx):
if isinstance(batch, dict):
# lhotse batches are dictionaries
input_signal = batch['input_signal']
input_length = batch['input_length']
target_signal = batch['target_signal']
else:
input_signal, input_length, target_signal, _ = batch
# For consistency, the model uses multi-channel format, even if the channel dimension is 1
if input_signal.ndim == 2:
input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
if target_signal.ndim == 2:
target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
# Calculate the loss
loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length)
# Logs
self.log('train_loss', loss)
self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))
return loss
def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
if isinstance(batch, dict):
# lhotse batches are dictionaries
input_signal = batch['input_signal']
input_length = batch['input_length']
target_signal = batch['target_signal']
else:
input_signal, input_length, target_signal, _ = batch
# For consistency, the model uses multi-channel format, even if the channel dimension is 1
if input_signal.ndim == 2:
input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
if target_signal.ndim == 2:
target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
# Calculate loss
loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length)
# Update metrics
update_metrics = False
if self.max_utts_evaluation_metrics is None:
# Always update if max is not configured
update_metrics = True
# Number of examples to process
num_examples = input_signal.size(0) # batch size
else:
# Check how many examples have been used for metric calculation
first_metric_name = next(iter(self.metrics[tag][dataloader_idx]))
num_examples_evaluated = self.metrics[tag][dataloader_idx][first_metric_name].num_examples
# Update metrics if some examples were not processed
update_metrics = num_examples_evaluated < self.max_utts_evaluation_metrics
# Number of examples to process
num_examples = min(self.max_utts_evaluation_metrics - num_examples_evaluated, input_signal.size(0))
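# Example: with max_utts_evaluation_metrics=50, 48 examples already evaluated
# and a batch of 16, only min(50 - 48, 16) = 2 examples of this batch are
# used to update the metrics.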
if update_metrics:
# Generate output signal
output_signal, _ = self.forward(
input_signal=input_signal[:num_examples, ...], input_length=input_length[:num_examples]
)
# Update metrics
if hasattr(self, 'metrics') and tag in self.metrics:
# Update metrics for this (tag, dataloader_idx)
for name, metric in self.metrics[tag][dataloader_idx].items():
metric.update(
preds=output_signal,
target=target_signal[:num_examples, ...],
input_length=input_length[:num_examples],
)
# Log global step
self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))
return {f'{tag}_loss': loss}
class FlowMatchingAudioToAudioModel(AudioToAudioModel):
"""This models uses a flow matching process to generate
an encoded representation of the enhanced signal.
The model consists of the following blocks:
- encoder: transforms input multi-channel audio signal into an encoded representation (analysis transform)
- estimator: neural model, estimates the vector field for the flow matching process
- flow: ordinary differential equation (ODE) defining a flow and a vector field.
- sampler: sampler for the inference process, estimates coefficients of the target signal
- decoder: transforms sampler output into the time domain (synthesis transform)
- ssl_pretrain_masking: if defined, applies SSL pretraining masking to the input for self-reconstruction during training
"""
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
super().__init__(cfg=cfg, trainer=trainer)
self.sample_rate = self._cfg.sample_rate
# Setup processing modules
self.encoder = self.from_config_dict(self._cfg.encoder)
self.decoder = self.from_config_dict(self._cfg.decoder)
# Neural estimator
self.estimator = self.from_config_dict(self._cfg.estimator)
# Flow
self.flow = self.from_config_dict(self._cfg.flow)
# Sampler
self.sampler = hydra.utils.instantiate(self._cfg.sampler, estimator=self.estimator)
# Probability that the conditional input will be fed into the
# estimator during training
self.p_cond = self._cfg.get('p_cond', 1.0)
# Self-Supervised Pretraining
if self._cfg.get('ssl_pretrain_masking') is not None:
logging.debug('SSL-pretrain_masking is found and will be initialized')
self.ssl_pretrain_masking = self.from_config_dict(self._cfg.ssl_pretrain_masking)
else:
self.ssl_pretrain_masking = None
# Normalization
self.normalize_input = self._cfg.get('normalize_input', False)
# Metric evaluation
self.max_utts_evaluation_metrics = self._cfg.get('max_utts_evaluation_metrics')
if self.max_utts_evaluation_metrics is not None:
logging.warning(
'Metrics will be evaluated on first %d examples of the evaluation datasets.',
self.max_utts_evaluation_metrics,
)
# Regularization
self.eps = self._cfg.get('eps', 1e-8)
# Setup optional Optimization flags
self.setup_optimization_flags()
logging.debug('Initialized %s', self.__class__.__name__)
logging.debug('\tdoing SSL-pretraining: %s', (self.ssl_pretrain_masking is not None))
logging.debug('\tp_cond: %s', self.p_cond)
logging.debug('\tnormalize_input: %s', self.normalize_input)
logging.debug('\tloss: %s', self.loss)
logging.debug('\teps: %s', self.eps)
@property
def input_types(self) -> Dict[str, NeuralType]:
return {
"input_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
"input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
}
@property
def output_types(self) -> Dict[str, NeuralType]:
return {
"output_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
"output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
}
@typecheck()
@torch.inference_mode()
def forward(self, input_signal, input_length=None):
"""Forward pass of the model to generate samples from the target distribution.
This is used for inference mode only, and it explicitly disables SSL masking of the input.
Args:
input_signal: Tensor that represents a batch of raw audio signals,
of shape [B, C, T]. T here represents timesteps, with 1 second of audio represented as
`self.sample_rate` number of floating point values.
input_length: Vector of length B, that contains the individual lengths of the audio
sequences.
Returns:
Output signal `output` in the time domain and the length of the output signal `output_length`.
"""
return self.forward_internal(input_signal=input_signal, input_length=input_length, enable_ssl_masking=False)
@typecheck(
input_types={
"input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
"input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
},
output_types={
"output_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
"output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
},
)
@torch.inference_mode()
def forward_eval(self, input_signal, input_length=None):
"""Forward pass of the model to generate samples from the target distribution.
This is used for eval mode only, and it enables SSL masking of the input.
Args:
input_signal: Tensor that represents a batch of raw audio signals,
of shape [B, C, T]. T here represents timesteps, with 1 second of audio represented as
`self.sample_rate` number of floating point values.
input_length: Vector of length B, that contains the individual lengths of the audio
sequences.
Returns:
Output signal `output` in the time domain and the length of the output signal `output_length`.
"""
return self.forward_internal(input_signal=input_signal, input_length=input_length, enable_ssl_masking=True)
@torch.inference_mode()
def forward_internal(self, input_signal, input_length=None, enable_ssl_masking=False):
"""Internal forward pass of the model.
Args:
input_signal: Tensor that represents a batch of raw audio signals,
of shape [B, C, T]. T here represents timesteps, with 1 second of audio represented as
`self.sample_rate` number of floating point values.
input_length: Vector of length B, that contains the individual lengths of the audio
sequences.
enable_ssl_masking: Whether to enable SSL masking of the input. Masking is applied
only if this flag is set and an `ssl_pretrain_masking` module is configured.
Returns:
Output signal `output` in the time domain and the length of the output signal `output_length`.
"""
batch_length = input_signal.size(-1)
if self.normalize_input:
# max for each example in the batch
norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True)
# scale input signal
input_signal = input_signal / (norm_scale + self.eps)
# Encoder
encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length)
# Conditional input
if self.p_cond == 0:
# The model is trained without the conditional input
encoded = torch.zeros_like(encoded)
elif enable_ssl_masking and self.ssl_pretrain_masking is not None:
# Masking for self-supervised pretraining
encoded = self.ssl_pretrain_masking(input_spec=encoded, length=encoded_length)
# Initial process state
init_state = torch.randn_like(encoded) * self.flow.sigma_start
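# i.e., the initial state is drawn from N(0, sigma_start**2 I), with
# sigma_start a property of the configured flow module.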
# Sampler
generated, generated_length = self.sampler(
state=init_state, estimator_condition=encoded, state_length=encoded_length
)
# Decoder
output, output_length = self.decoder(input=generated, input_length=generated_length)
if self.normalize_input:
# rescale to the original scale
output = output * norm_scale
# Trim or pad the estimated signal to match input length
output = self.match_batch_length(input=output, batch_length=batch_length)
return output, output_length
@typecheck(
input_types={
"target_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
"input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
"input_length": NeuralType(tuple('B'), LengthsType()),
},
output_types={
"loss": NeuralType(None, LossType()),
},
)
def _step(self, target_signal, input_signal, input_length=None):
batch_size = target_signal.size(0)
if self.normalize_input:
# max for each example in the batch
norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True)
# scale input signal
input_signal = input_signal / (norm_scale + self.eps)
# scale the target signal
target_signal = target_signal / (norm_scale + self.eps)
# Apply encoder to both target and the input
input_enc, input_enc_len = self.encoder(input=input_signal, input_length=input_length)
target_enc, _ = self.encoder(input=target_signal, input_length=input_length)
# Self-Supervised Pretraining
if self.ssl_pretrain_masking is not None:
input_enc = self.ssl_pretrain_masking(input_spec=input_enc, length=input_enc_len)
# Drop the conditional input (input_enc) with probability (1 - p_cond).
# Dropped conditions are set to zero
keep_conditions = einops.rearrange((torch.rand(batch_size) < self.p_cond).float(), 'B -> B 1 1 1')
input_enc = input_enc * keep_conditions.to(input_enc.device)
x_start = torch.zeros_like(input_enc)
time = self.flow.generate_time(batch_size=batch_size).to(device=input_enc.device)
sample = self.flow.sample(time=time, x_start=x_start, x_end=target_enc)
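# For reference, with a linear (optimal-transport) conditional path, a common
# choice but an assumption here, sample = (1 - time) * x_start + time * target_enc
# plus path noise, and the conditional vector field reduces to target_enc - x_start.
# The exact expressions are defined by the configured `flow` module.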
# we want to get a vector field estimate given current state
# at training time, current state is sampled from the conditional path
# the vector field model is also conditioned on input signal
estimator_input = torch.cat([sample, input_enc], dim=-3)
# Estimate the vector using the neural estimator
estimate, estimate_len = self.estimator(input=estimator_input, input_length=input_enc_len, condition=time)
conditional_vector_field = self.flow.vector_field(time=time, x_start=x_start, x_end=target_enc, point=sample)
return self.loss(estimate=estimate, target=conditional_vector_field, input_length=input_enc_len)
# PTL-specific methods
def training_step(self, batch, batch_idx):
if isinstance(batch, dict):
# lhotse batches are dictionaries
input_signal = batch['input_signal']
input_length = batch['input_length']
target_signal = batch.get('target_signal', input_signal.clone())
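# If the batch provides no target signal (e.g., during SSL pretraining),
# the model learns to reconstruct the input itself.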
else:
input_signal, input_length, target_signal, _ = batch
# For consistency, the model uses multi-channel format, even if the channel dimension is 1
if input_signal.ndim == 2:
input_signal = einops.rearrange(input_signal, "B T -> B 1 T")
if target_signal.ndim == 2:
target_signal = einops.rearrange(target_signal, "B T -> B 1 T")
# Calculate the loss
loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length)
# Logs
self.log('train_loss', loss)
self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))
return loss
def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
if isinstance(batch, dict):
# lhotse batches are dictionaries
input_signal = batch['input_signal']
input_length = batch['input_length']
target_signal = batch.get('target_signal', input_signal.clone())
else:
input_signal, input_length, target_signal, _ = batch
# For consistency, the model uses multi-channel format, even if the channel dimension is 1
if input_signal.ndim == 2:
input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
if target_signal.ndim == 2:
target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
# Calculate loss
loss = self._step(
target_signal=target_signal,
input_signal=input_signal,
input_length=input_length,
)
# Update metrics
update_metrics = False
if self.max_utts_evaluation_metrics is None:
# Always update if max is not configured
update_metrics = True
# Number of examples to process
num_examples = input_signal.size(0) # batch size
else:
# Check how many examples have been used for metric calculation
first_metric_name = next(iter(self.metrics[tag][dataloader_idx]))
num_examples_evaluated = self.metrics[tag][dataloader_idx][first_metric_name].num_examples
# Update metrics if some examples were not processed
update_metrics = num_examples_evaluated < self.max_utts_evaluation_metrics
# Number of examples to process
num_examples = min(self.max_utts_evaluation_metrics - num_examples_evaluated, input_signal.size(0))
if update_metrics:
# Generate output signal
output_signal, _ = self.forward_eval(
input_signal=input_signal[:num_examples, ...], input_length=input_length[:num_examples]
)
# Update metrics
if hasattr(self, 'metrics') and tag in self.metrics:
# Update metrics for this (tag, dataloader_idx)
for name, metric in self.metrics[tag][dataloader_idx].items():
metric.update(
preds=output_signal,
target=target_signal[:num_examples, ...],
input_length=input_length[:num_examples],
)
# Log global step
self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))
return {f'{tag}_loss': loss}
class SchroedingerBridgeAudioToAudioModel(AudioToAudioModel):
"""This models is using a Schrödinger Bridge process to generate
an encoded representation of the enhanced signal.
The model consists of the following blocks:
- encoder: transforms input audio signal into an encoded representation (analysis transform)
- estimator: neural model, estimates the coefficients for the SB process
- noise_schedule: defines the path between the clean and noisy signals
- sampler: sampler for the reverse process, estimates coefficients of the target signal
- decoder: transforms sampler output into the time domain (synthesis transform)
References:
Schrödinger Bridge for Generative Speech Enhancement, https://arxiv.org/abs/2407.16074
"""
def __init__(self, cfg: DictConfig, trainer: Trainer = None):
super().__init__(cfg=cfg, trainer=trainer)
self.sample_rate = self._cfg.sample_rate
# Setup processing modules
self.encoder = self.from_config_dict(self._cfg.encoder)
self.decoder = self.from_config_dict(self._cfg.decoder)
# Neural estimator
self.estimator = self.from_config_dict(self._cfg.estimator)
self.estimator_output = self._cfg.estimator_output
# Noise schedule
self.noise_schedule = self.from_config_dict(self._cfg.noise_schedule)
# Sampler
self.sampler = hydra.utils.instantiate(
self._cfg.sampler,
noise_schedule=self.noise_schedule,
estimator=self.estimator,
estimator_output=self.estimator_output,
)
# Normalization
self.normalize_input = self._cfg.get('normalize_input', False)
# Metric evaluation
self.max_utts_evaluation_metrics = self._cfg.get('max_utts_evaluation_metrics')
if self.max_utts_evaluation_metrics is not None:
logging.warning(
'Metrics will be evaluated on first %d examples of the evaluation datasets.',
self.max_utts_evaluation_metrics,
)
# Loss in the encoded domain
if 'loss_encoded' in self._cfg:
self.loss_encoded = self.from_config_dict(self._cfg.loss_encoded)
self.loss_encoded_weight = self._cfg.get('loss_encoded_weight', 1.0)
else:
self.loss_encoded = None
self.loss_encoded_weight = 0.0
# Loss in the time domain
if 'loss_time' in self._cfg:
self.loss_time = self.from_config_dict(self._cfg.loss_time)
self.loss_time_weight = self._cfg.get('loss_time_weight', 1.0)
else:
self.loss_time = None
self.loss_time_weight = 0.0
if self.loss is not None and (self.loss_encoded is not None or self.loss_time is not None):
raise ValueError('Either ``loss`` or the pair ``loss_encoded``/``loss_time`` should be defined, not both.')
# Term added to the denominator to improve numerical stability
self.eps = self._cfg.get('eps', 1e-8)
# Setup optional optimization flags
self.setup_optimization_flags()
logging.debug('Initialized %s', self.__class__.__name__)
logging.debug('\testimator_output: %s', self.estimator_output)
logging.debug('\tnormalize_input: %s', self.normalize_input)
logging.debug('\tloss: %s', self.loss)
logging.debug('\tloss_encoded: %s', self.loss_encoded)
logging.debug('\tloss_encoded_weight: %s', self.loss_encoded_weight)
logging.debug('\tloss_time: %s', self.loss_time)
logging.debug('\tloss_time_weight: %s', self.loss_time_weight)
logging.debug('\teps: %s', self.eps)
@property
def input_types(self) -> Dict[str, NeuralType]:
# time-domain input
return {
"input_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
"input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
}
@property
def output_types(self) -> Dict[str, NeuralType]:
# time-domain output
return {
"output_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
"output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
}
@typecheck()
@torch.inference_mode()
def forward(self, input_signal, input_length=None):
"""Forward pass of the model.
The forward pass consists of the following steps:
- encoder to obtain the encoded representation of the input signal
- sampler to generate the estimated coefficients of the target signal
- decoder to transform the estimated output into the time domain
Args:
input_signal: Tensor that represents a batch of time-domain audio signals,
of shape [B, C, T]. T here represents timesteps, with 1 second of audio represented as
`self.sample_rate` number of floating point values.
input_length: Vector of length B, contains the individual lengths of the audio sequences.
Returns:
Output `output_signal` in the time domain and the length of the output signal `output_length`.
"""
batch_length = input_signal.size(-1)
if self.normalize_input:
# max for each example in the batch
norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True)
# scale input signal
input_signal = input_signal / (norm_scale + self.eps)
# Encoder
encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length)
# Sampler
generated, generated_length = self.sampler(
prior_mean=encoded, estimator_condition=encoded, state_length=encoded_length
)
# Decoder
output, output_length = self.decoder(input=generated, input_length=generated_length)
if self.normalize_input:
# rescale to the original scale
output = output * norm_scale
# Trim or pad the estimated signal to match input length
output = self.match_batch_length(input=output, batch_length=batch_length)
return output, output_length
@typecheck(
input_types={
"target_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
"input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
"input_length": NeuralType(tuple('B'), LengthsType()),
},
output_types={
"loss": NeuralType(None, LossType()),
"loss_encoded": NeuralType(None, LossType()),
"loss_time": NeuralType(None, LossType()),
},
)
def _step(self, target_signal, input_signal, input_length=None):
"""Randomly generate time step for each example in the batch, run neural estimator
to estimate the target and calculate the loss.
"""
batch_size = target_signal.size(0)
if self.normalize_input:
# max for each example in the batch
norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True)
# scale input signal
input_signal = input_signal / (norm_scale + self.eps)
# scale the target signal
target_signal = target_signal / (norm_scale + self.eps)
# Apply encoder to both target and the input
# For example, if the encoder is STFT, then _enc is the complex-valued STFT of the corresponding signal
input_enc, input_enc_len = self.encoder(input=input_signal, input_length=input_length)
target_enc, _ = self.encoder(input=target_signal, input_length=input_length)
# Generate random time steps
process_time = self.noise_schedule.generate_time(size=batch_size, device=input_enc.device)
# Prepare necessary info from the noise schedule
alpha_t, alpha_bar_t, alpha_t_max = self.noise_schedule.get_alphas(time=process_time)
sigma_t, sigma_bar_t, sigma_t_max = self.noise_schedule.get_sigmas(time=process_time)
# Marginal distribution
weight_target = alpha_t * sigma_bar_t**2 / (sigma_t_max**2 + self.eps)
weight_input = alpha_bar_t * sigma_t**2 / (sigma_t_max**2 + self.eps)
# view weights as [B, C, D, T]
weight_target = weight_target.view(-1, 1, 1, 1)
weight_input = weight_input.view(-1, 1, 1, 1)
# mean
mean_x = weight_target * target_enc + weight_input * input_enc
# standard deviation
std_x = alpha_t * sigma_bar_t * sigma_t / (sigma_t_max + self.eps)
# view as [B, C, D, T]
std_x = std_x.view(-1, 1, 1, 1)
# Generate a random sample from a standard normal distribution
z_norm = torch.randn_like(input_enc)
# Generate a random sample from the marginal distribution
x_t = mean_x + std_x * z_norm
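# Putting it together, x_t follows the Schrödinger bridge marginal
#   x_t ~ N(weight_target * target_enc + weight_input * input_enc, std_x**2 I)
# which moves from the clean (target) representation at time 0 towards the
# noisy (input) representation at the final time.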
# Estimator is conditioned on the generated sample and the original input (prior)
estimator_input = torch.cat([x_t, input_enc], dim=-3)
# Neural estimator
# Estimator input is the same data type as the encoder output
# For example, if the encoder is STFT, then the estimator input and output are complex-valued coefficients
estimate, estimate_len = self.estimator(
input=estimator_input, input_length=input_enc_len, condition=process_time
)
# Prepare output target and calculate loss
if self.estimator_output == 'data_prediction':
if self.loss is not None:
# Single loss in the encoded domain
loss = self.loss(estimate=estimate, target=target_enc, input_length=estimate_len)
loss_encoded = loss_time = None
else:
# Weighted loss between encoded and time domain
loss = 0.0
# Loss in the encoded domain
if self.loss_encoded is not None:
# Loss between the estimate and the target in the encoded domain
loss_encoded = self.loss_encoded(estimate=estimate, target=target_enc, input_length=estimate_len)
# Weighting
loss += self.loss_encoded_weight * loss_encoded
else:
loss_encoded = None
# Loss in the time domain
if self.loss_time is not None:
# Convert the estimate to the time domain
with typecheck.disable_checks():
# Note: estimate is FloatType, the decoder requires SpectrogramType
estimate_signal, _ = self.decoder(input=estimate, input_length=estimate_len)
# Match estimate length
batch_length = input_signal.size(-1)
estimate_signal = self.match_batch_length(input=estimate_signal, batch_length=batch_length)
# Loss between the estimate and the target in the time domain
loss_time = self.loss_time(
estimate=estimate_signal, target=target_signal, input_length=input_length
)
# Weighting
loss += self.loss_time_weight * loss_time
else:
loss_time = None
else:
raise NotImplementedError(f'Output type {self.estimator_output} is not implemented')
return loss, loss_encoded, loss_time
# PTL-specific methods
def training_step(self, batch, batch_idx):
if isinstance(batch, dict):
# lhotse batches are dictionaries
input_signal = batch['input_signal']
input_length = batch['input_length']
target_signal = batch['target_signal']
else:
input_signal, input_length, target_signal, _ = batch
# For consistency, the model uses multi-channel format, even if the channel dimension is 1
if input_signal.ndim == 2:
input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
if target_signal.ndim == 2:
target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
# Calculate the loss
loss, loss_encoded, loss_time = self._step(
target_signal=target_signal, input_signal=input_signal, input_length=input_length
)
# Logs
self.log('train_loss', loss)
self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))
if loss_encoded is not None:
self.log('train_loss_encoded', loss_encoded)
if loss_time is not None:
self.log('train_loss_time', loss_time)
return loss
def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
if isinstance(batch, dict):
# lhotse batches are dictionaries
input_signal = batch['input_signal']
input_length = batch['input_length']
target_signal = batch['target_signal']
else:
input_signal, input_length, target_signal, _ = batch
# For consistency, the model uses multi-channel format, even if the channel dimension is 1
if input_signal.ndim == 2:
input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
if target_signal.ndim == 2:
target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')
# Calculate loss
loss, *_ = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length)
# Update metrics
update_metrics = False
if self.max_utts_evaluation_metrics is None:
# Always update if max is not configured
update_metrics = True
# Number of examples to process
num_examples = input_signal.size(0) # batch size
else:
# Check how many examples have been used for metric calculation
first_metric_name = next(iter(self.metrics[tag][dataloader_idx]))
num_examples_evaluated = self.metrics[tag][dataloader_idx][first_metric_name].num_examples
# Update metrics if some examples were not processed
update_metrics = num_examples_evaluated < self.max_utts_evaluation_metrics
# Number of examples to process
num_examples = min(self.max_utts_evaluation_metrics - num_examples_evaluated, input_signal.size(0))
if update_metrics:
# Generate output signal
output_signal, _ = self.forward(
input_signal=input_signal[:num_examples, ...], input_length=input_length[:num_examples]
)
# Update metrics
if hasattr(self, 'metrics') and tag in self.metrics:
# Update metrics for this (tag, dataloader_idx)
for name, metric in self.metrics[tag][dataloader_idx].items():
metric.update(
preds=output_signal,
target=target_signal[:num_examples, ...],
input_length=input_length[:num_examples],
)
# Log global step
self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))
return {f'{tag}_loss': loss}