
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import List, Optional, Set, Union

import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import MISSING, DictConfig, ListConfig, OmegaConf

from nemo.collections.asr.parts.submodules.jasper import (
    JasperBlock,
    MaskedConv1d,
    ParallelBlock,
    SqueezeExcite,
    init_weights,
    jasper_activations,
)
from nemo.collections.asr.parts.submodules.tdnn_attention import (
    AttentivePoolLayer,
    StatsPoolLayer,
    TDNNModule,
    TDNNSEModule,
)
from nemo.collections.asr.parts.utils import adapter_utils
from nemo.core.classes.common import typecheck
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.mixins import AccessMixin, adapter_mixins
from nemo.core.classes.module import NeuralModule
from nemo.core.neural_types import (
    AcousticEncodedRepresentation,
    LengthsType,
    LogitsType,
    LogprobsType,
    NeuralType,
    SpectrogramType,
)
from nemo.utils import logging

__all__ = ['ConvASRDecoder', 'ConvASREncoder', 'ConvASRDecoderClassification']


class ConvASREncoder(NeuralModule, Exportable, AccessMixin):
    """
    Convolutional encoder for ASR models. With this class you can implement JasperNet and QuartzNet models.

    Based on these papers:
        https://arxiv.org/pdf/1904.03288.pdf
        https://arxiv.org/pdf/1910.10261.pdf
    """

    def _prepare_for_export(self, **kwargs):
        m_count = 0
        for name, m in self.named_modules():
            if isinstance(m, MaskedConv1d):
                m.use_mask = False
                m_count += 1
        Exportable._prepare_for_export(self, **kwargs)
        logging.warning(f"Turned off {m_count} masked convolutions")
    def input_example(self, max_batch=1, max_dim=8192):
        """
        Generates input examples for tracing etc.

        Returns:
            A tuple of input examples.
        """
        device = next(self.parameters()).device
        input_example = torch.randn(max_batch, self._feat_in, max_dim, device=device)
        lens = torch.full(size=(input_example.shape[0],), fill_value=max_dim, device=device)
        return tuple([input_example, lens])
    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        return OrderedDict(
            {
                "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
                "length": NeuralType(tuple('B'), LengthsType()),
            }
        )

    @property
    def output_types(self):
        """Returns definitions of module output ports."""
        return OrderedDict(
            {
                "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
                "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
            }
        )

    def __init__(
        self,
        jasper,
        activation: str,
        feat_in: int,
        normalization_mode: str = "batch",
        residual_mode: str = "add",
        norm_groups: int = -1,
        conv_mask: bool = True,
        frame_splicing: int = 1,
        init_mode: Optional[str] = 'xavier_uniform',
        quantize: bool = False,
    ):
        super().__init__()
        if isinstance(jasper, ListConfig):
            jasper = OmegaConf.to_container(jasper)

        activation = jasper_activations[activation]()

        # If the activation can be executed in place, do so.
        if hasattr(activation, 'inplace'):
            activation.inplace = True

        feat_in = feat_in * frame_splicing

        self._feat_in = feat_in

        residual_panes = []
        encoder_layers = []
        self.dense_residual = False
        self._subsampling_factor = 1
        for layer_idx, lcfg in enumerate(jasper):
            dense_res = []
            if lcfg.get('residual_dense', False):
                residual_panes.append(feat_in)
                dense_res = residual_panes
                self.dense_residual = True
            groups = lcfg.get('groups', 1)
            separable = lcfg.get('separable', False)
            heads = lcfg.get('heads', -1)
            residual_mode = lcfg.get('residual_mode', residual_mode)
            se = lcfg.get('se', False)
            se_reduction_ratio = lcfg.get('se_reduction_ratio', 8)
            se_context_window = lcfg.get('se_context_size', -1)
            se_interpolation_mode = lcfg.get('se_interpolation_mode', 'nearest')
            kernel_size_factor = lcfg.get('kernel_size_factor', 1.0)
            stride_last = lcfg.get('stride_last', False)
            future_context = lcfg.get('future_context', -1)
            encoder_layers.append(
                JasperBlock(
                    feat_in,
                    lcfg['filters'],
                    repeat=lcfg['repeat'],
                    kernel_size=lcfg['kernel'],
                    stride=lcfg['stride'],
                    dilation=lcfg['dilation'],
                    dropout=lcfg['dropout'],
                    residual=lcfg['residual'],
                    groups=groups,
                    separable=separable,
                    heads=heads,
                    residual_mode=residual_mode,
                    normalization=normalization_mode,
                    norm_groups=norm_groups,
                    activation=activation,
                    residual_panes=dense_res,
                    conv_mask=conv_mask,
                    se=se,
                    se_reduction_ratio=se_reduction_ratio,
                    se_context_window=se_context_window,
                    se_interpolation_mode=se_interpolation_mode,
                    kernel_size_factor=kernel_size_factor,
                    stride_last=stride_last,
                    future_context=future_context,
                    quantize=quantize,
                    layer_idx=layer_idx,
                )
            )
            feat_in = lcfg['filters']
            self._subsampling_factor *= (
                int(lcfg['stride'][0]) if isinstance(lcfg['stride'], List) else int(lcfg['stride'])
            )

        self._feat_out = feat_in
        self.encoder = torch.nn.Sequential(*encoder_layers)
        self.apply(lambda x: init_weights(x, mode=init_mode))
        self.max_audio_length = 0
    @typecheck()
    def forward(self, audio_signal, length):
        self.update_max_sequence_length(seq_length=audio_signal.size(2), device=audio_signal.device)
        s_input, length = self.encoder(([audio_signal], length))
        if length is None:
            return s_input[-1]

        return s_input[-1], length
    def update_max_sequence_length(self, seq_length: int, device):
        """
        Find the global max audio length across all nodes in distributed training and update max_audio_length.
        """
        if torch.distributed.is_initialized():
            global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device)

            # Update across all ranks in the distributed system
            torch.distributed.all_reduce(global_max_len, op=torch.distributed.ReduceOp.MAX)

            seq_length = global_max_len.int().item()

        if seq_length > self.max_audio_length:
            if seq_length < 5000:
                seq_length = seq_length * 2
            elif seq_length < 10000:
                seq_length = seq_length * 1.5
            self.max_audio_length = seq_length

            device = next(self.parameters()).device
            seq_range = torch.arange(0, self.max_audio_length, device=device)
            if hasattr(self, 'seq_range'):
                self.seq_range = seq_range
            else:
                self.register_buffer('seq_range', seq_range, persistent=False)

            # Update all submodules
            for name, m in self.named_modules():
                if isinstance(m, MaskedConv1d):
                    m.update_masked_length(self.max_audio_length, seq_range=self.seq_range)
                elif isinstance(m, SqueezeExcite):
                    m.set_max_len(self.max_audio_length, seq_range=self.seq_range)
    @property
    def subsampling_factor(self) -> int:
        return self._subsampling_factor
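
# Usage sketch (editor's illustration, not part of the original source): building a
# minimal single-block ConvASREncoder and running a forward pass. The block config
# below is a hypothetical toy setup, not an official Jasper/QuartzNet recipe.
#
#     encoder = ConvASREncoder(
#         jasper=[
#             {'filters': 256, 'repeat': 1, 'kernel': [11], 'stride': [1],
#              'dilation': [1], 'dropout': 0.0, 'residual': False}
#         ],
#         activation='relu',
#         feat_in=64,
#     )
#     spec = torch.randn(4, 64, 128)                 # [B, D, T] spectrogram features
#     lens = torch.full((4,), 128)                   # valid length of each example
#     out, out_lens = encoder(audio_signal=spec, length=lens)
#     # out: [4, 256, 128]; encoder.subsampling_factor == 1 because all strides are 1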


class ParallelConvASREncoder(NeuralModule, Exportable):
    """
    Convolutional encoder for ASR models with parallel blocks. CarneliNet can be implemented with this class.
    """

    def _prepare_for_export(self):
        m_count = 0
        for m in self.modules():
            if isinstance(m, MaskedConv1d):
                m.use_mask = False
                m_count += 1
        logging.warning(f"Turned off {m_count} masked convolutions")

    def input_example(self, max_batch=1, max_dim=256):
        """
        Generates input examples for tracing etc.

        Returns:
            A tuple of input examples.
        """
        input_example = torch.randn(max_batch, self._feat_in, max_dim).to(next(self.parameters()).device)
        return tuple([input_example])

    @property
    def disabled_deployment_input_names(self):
        """Implement this method to return a set of input names disabled for export"""
        return set(["length"])

    @property
    def disabled_deployment_output_names(self):
        """Implement this method to return a set of output names disabled for export"""
        return set(["encoded_lengths"])

    def save_to(self, save_path: str):
        pass

    @classmethod
    def restore_from(cls, restore_path: str):
        pass

    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        return OrderedDict(
            {
                "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
                "length": NeuralType(tuple('B'), LengthsType()),
            }
        )

    @property
    def output_types(self):
        """Returns definitions of module output ports."""
        return OrderedDict(
            {
                "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
                "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
            }
        )

    def __init__(
        self,
        jasper,
        activation: str,
        feat_in: int,
        normalization_mode: str = "batch",
        residual_mode: str = "add",
        norm_groups: int = -1,
        conv_mask: bool = True,
        frame_splicing: int = 1,
        init_mode: Optional[str] = 'xavier_uniform',
        aggregation_mode: Optional[str] = None,
        quantize: bool = False,
    ):
        super().__init__()
        if isinstance(jasper, ListConfig):
            jasper = OmegaConf.to_container(jasper)

        activation = jasper_activations[activation]()
        feat_in = feat_in * frame_splicing

        self._feat_in = feat_in

        residual_panes = []
        encoder_layers = []
        self.dense_residual = False
        for lcfg in jasper:
            dense_res = []
            if lcfg.get('residual_dense', False):
                residual_panes.append(feat_in)
                dense_res = residual_panes
                self.dense_residual = True
            groups = lcfg.get('groups', 1)
            separable = lcfg.get('separable', False)
            heads = lcfg.get('heads', -1)
            residual_mode = lcfg.get('residual_mode', residual_mode)
            se = lcfg.get('se', False)
            se_reduction_ratio = lcfg.get('se_reduction_ratio', 8)
            se_context_window = lcfg.get('se_context_size', -1)
            se_interpolation_mode = lcfg.get('se_interpolation_mode', 'nearest')
            kernel_size_factor = lcfg.get('kernel_size_factor', 1.0)
            stride_last = lcfg.get('stride_last', False)
            aggregation_mode = lcfg.get('aggregation_mode', 'sum')
            block_dropout = lcfg.get('block_dropout', 0.0)
            parallel_residual_mode = lcfg.get('parallel_residual_mode', 'sum')

            parallel_blocks = []
            for kernel_size in lcfg['kernel']:
                parallel_blocks.append(
                    JasperBlock(
                        feat_in,
                        lcfg['filters'],
                        repeat=lcfg['repeat'],
                        kernel_size=[kernel_size],
                        stride=lcfg['stride'],
                        dilation=lcfg['dilation'],
                        dropout=lcfg['dropout'],
                        residual=lcfg['residual'],
                        groups=groups,
                        separable=separable,
                        heads=heads,
                        residual_mode=residual_mode,
                        normalization=normalization_mode,
                        norm_groups=norm_groups,
                        activation=activation,
                        residual_panes=dense_res,
                        conv_mask=conv_mask,
                        se=se,
                        se_reduction_ratio=se_reduction_ratio,
                        se_context_window=se_context_window,
                        se_interpolation_mode=se_interpolation_mode,
                        kernel_size_factor=kernel_size_factor,
                        stride_last=stride_last,
                        quantize=quantize,
                    )
                )
            if len(parallel_blocks) == 1:
                encoder_layers.append(parallel_blocks[0])
            else:
                encoder_layers.append(
                    ParallelBlock(
                        parallel_blocks,
                        aggregation_mode=aggregation_mode,
                        block_dropout_prob=block_dropout,
                        residual_mode=parallel_residual_mode,
                        in_filters=feat_in,
                        out_filters=lcfg['filters'],
                    )
                )
            feat_in = lcfg['filters']

        self._feat_out = feat_in
        self.encoder = torch.nn.Sequential(*encoder_layers)
        self.apply(lambda x: init_weights(x, mode=init_mode))

    @typecheck()
    def forward(self, audio_signal, length=None):
        s_input, length = self.encoder(([audio_signal], length))
        if length is None:
            return s_input[-1]

        return s_input[-1], length
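
# Usage sketch (editor's illustration): a parallel block is created whenever a layer
# config lists more than one kernel size; each kernel size gets its own JasperBlock and
# the outputs are combined by ParallelBlock. The config below is hypothetical.
#
#     encoder = ParallelConvASREncoder(
#         jasper=[
#             {'filters': 256, 'repeat': 1, 'kernel': [11, 13], 'stride': [1],
#              'dilation': [1], 'dropout': 0.0, 'residual': False}
#         ],
#         activation='relu',
#         feat_in=64,
#     )
#     # kernel=[11, 13] -> two parallel JasperBlocks, aggregated per 'aggregation_mode'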


class ConvASRDecoder(NeuralModule, Exportable, adapter_mixins.AdapterModuleMixin):
    """Simple ASR Decoder for use with CTC-based models such as JasperNet and QuartzNet

    Based on these papers:
        https://arxiv.org/pdf/1904.03288.pdf
        https://arxiv.org/pdf/1910.10261.pdf
        https://arxiv.org/pdf/2005.04290.pdf
    """

    @property
    def input_types(self):
        return OrderedDict({"encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())})

    @property
    def output_types(self):
        return OrderedDict({"logprobs": NeuralType(('B', 'T', 'D'), LogprobsType())})

    def __init__(self, feat_in, num_classes, init_mode="xavier_uniform", vocabulary=None, add_blank=True):
        super().__init__()

        if vocabulary is None and num_classes < 0:
            raise ValueError("Neither vocabulary nor num_classes is set! At least one of them must be set.")

        if num_classes <= 0:
            num_classes = len(vocabulary)
            logging.info(f"num_classes of ConvASRDecoder is set to the size of the vocabulary: {num_classes}.")

        if vocabulary is not None:
            if num_classes != len(vocabulary):
                raise ValueError(
                    f"If vocabulary is specified, its length should be equal to num_classes. "
                    f"Instead got: num_classes={num_classes} and len(vocabulary)={len(vocabulary)}"
                )
            self.__vocabulary = vocabulary
        self._feat_in = feat_in
        # Add 1 for the blank char
        self._num_classes = num_classes + 1 if add_blank else num_classes

        self.decoder_layers = torch.nn.Sequential(
            torch.nn.Conv1d(self._feat_in, self._num_classes, kernel_size=1, bias=True)
        )
        self.apply(lambda x: init_weights(x, mode=init_mode))

        accepted_adapters = [adapter_utils.LINEAR_ADAPTER_CLASSPATH]
        self.set_accepted_adapter_types(accepted_adapters)

        # to change, requires running ``model.temperature = T`` explicitly
        self.temperature = 1.0
    @typecheck()
    def forward(self, encoder_output):
        # Adapter module forward step
        if self.is_adapter_available():
            encoder_output = encoder_output.transpose(1, 2)  # [B, T, C]
            encoder_output = self.forward_enabled_adapters(encoder_output)
            encoder_output = encoder_output.transpose(1, 2)  # [B, C, T]

        if self.temperature != 1.0:
            return torch.nn.functional.log_softmax(
                self.decoder_layers(encoder_output).transpose(1, 2) / self.temperature, dim=-1
            )
        return torch.nn.functional.log_softmax(self.decoder_layers(encoder_output).transpose(1, 2), dim=-1)
    def input_example(self, max_batch=1, max_dim=256):
        """
        Generates input examples for tracing etc.

        Returns:
            A tuple of input examples.
        """
        input_example = torch.randn(max_batch, self._feat_in, max_dim).to(next(self.parameters()).device)
        return tuple([input_example])
    def _prepare_for_export(self, **kwargs):
        m_count = 0
        for m in self.modules():
            if type(m).__name__ == "MaskedConv1d":
                m.use_mask = False
                m_count += 1
        if m_count > 0:
            logging.warning(f"Turned off {m_count} masked convolutions")
        Exportable._prepare_for_export(self, **kwargs)

    # Adapter method overrides
    def add_adapter(self, name: str, cfg: DictConfig):
        # Update the config with the correct input dim
        cfg = self._update_adapter_cfg_input_dim(cfg)
        # Add the adapter
        super().add_adapter(name=name, cfg=cfg)
    def _update_adapter_cfg_input_dim(self, cfg: DictConfig):
        cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self._feat_in)
        return cfg

    @property
    def vocabulary(self):
        return self.__vocabulary

    @property
    def num_classes_with_blank(self):
        return self._num_classes
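
# Usage sketch (editor's illustration): a CTC head on top of 1024-dim encoder features.
# With num_classes < 0 the class count is taken from the vocabulary, and one extra
# output is appended for the CTC blank (add_blank=True by default).
#
#     decoder = ConvASRDecoder(feat_in=1024, num_classes=-1, vocabulary=['a', 'b', 'c'])
#     enc = torch.randn(2, 1024, 50)             # [B, D, T] encoder output
#     log_probs = decoder(encoder_output=enc)    # [2, 50, 4] = [B, T, 3 chars + blank]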


class ConvASRDecoderReconstruction(NeuralModule, Exportable):
    """ASR Decoder for reconstructing masked regions of spectrogram"""

    @property
    def input_types(self):
        return OrderedDict({"encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())})

    @property
    def output_types(self):
        if self.apply_softmax:
            return OrderedDict({"out": NeuralType(('B', 'T', 'D'), LogprobsType())})
        else:
            return OrderedDict({"out": NeuralType(('B', 'T', 'D'), AcousticEncodedRepresentation())})

    def __init__(
        self,
        feat_in,
        feat_out,
        feat_hidden,
        stride_layers=0,
        non_stride_layers=0,
        kernel_size=11,
        init_mode="xavier_uniform",
        activation="relu",
        stride_transpose=True,
        apply_softmax=False,
    ):
        super().__init__()

        if ((stride_layers + non_stride_layers) > 0) and (kernel_size < 3 or kernel_size % 2 == 0):
            raise ValueError("Kernel size in this decoder needs to be >= 3 and odd when using at least 1 conv layer.")

        activation = jasper_activations[activation]()

        self.feat_in = feat_in
        self.feat_out = feat_out
        self.feat_hidden = feat_hidden

        self.decoder_layers = [nn.Conv1d(self.feat_in, self.feat_hidden, kernel_size=1, bias=True)]
        for i in range(stride_layers):
            self.decoder_layers.append(activation)
            if stride_transpose:
                self.decoder_layers.append(
                    nn.ConvTranspose1d(
                        self.feat_hidden,
                        self.feat_hidden,
                        kernel_size,
                        stride=2,
                        padding=(kernel_size - 3) // 2 + 1,
                        output_padding=1,
                        bias=True,
                        groups=self.feat_hidden,
                    )
                )
            else:
                self.decoder_layers.append(
                    nn.Conv1d(
                        self.feat_hidden,
                        self.feat_hidden,
                        kernel_size,
                        stride=2,
                        padding=(kernel_size - 1) // 2,
                        bias=True,
                        groups=self.feat_hidden,
                    )
                )
            self.decoder_layers.append(nn.Conv1d(self.feat_hidden, self.feat_hidden, kernel_size=1, bias=True))
            self.decoder_layers.append(nn.BatchNorm1d(self.feat_hidden, eps=1e-3, momentum=0.1))
        for i in range(non_stride_layers):
            self.decoder_layers.append(activation)
            self.decoder_layers.append(
                nn.Conv1d(
                    self.feat_hidden,
                    self.feat_hidden,
                    kernel_size,
                    bias=True,
                    groups=self.feat_hidden,
                    padding=kernel_size // 2,
                )
            )
            self.decoder_layers.append(nn.Conv1d(self.feat_hidden, self.feat_hidden, kernel_size=1, bias=True))
            self.decoder_layers.append(nn.BatchNorm1d(self.feat_hidden, eps=1e-3, momentum=0.1))

        self.decoder_layers.append(activation)
        self.decoder_layers.append(nn.Conv1d(self.feat_hidden, self.feat_out, kernel_size=1, bias=True))

        self.decoder_layers = nn.Sequential(*self.decoder_layers)
        self.apply_softmax = apply_softmax

        self.apply(lambda x: init_weights(x, mode=init_mode))

    @typecheck()
    def forward(self, encoder_output):
        out = self.decoder_layers(encoder_output).transpose(-2, -1)
        if self.apply_softmax:
            out = torch.nn.functional.log_softmax(out, dim=-1)
        return out

    def input_example(self, max_batch=1, max_dim=256):
        """
        Generates input examples for tracing etc.

        Returns:
            A tuple of input examples.
        """
        # Note: this module stores its input width as ``self.feat_in`` (no underscore),
        # unlike the other decoders in this file.
        input_example = torch.randn(max_batch, self.feat_in, max_dim).to(next(self.parameters()).device)
        return tuple([input_example])

    def _prepare_for_export(self, **kwargs):
        m_count = 0
        for m in self.modules():
            if type(m).__name__ == "MaskedConv1d":
                m.use_mask = False
                m_count += 1
        if m_count > 0:
            logging.warning(f"Turned off {m_count} masked convolutions")
        Exportable._prepare_for_export(self, **kwargs)
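
# Length arithmetic for the reconstruction decoder (editor's note): with
# stride_transpose=True, each stride layer uses ConvTranspose1d(stride=2,
# padding=(k - 3) // 2 + 1, output_padding=1). For odd k that padding equals
# (k - 1) // 2, so the standard transposed-conv length formula gives
#
#     L_out = (L - 1) * 2 - 2 * ((k - 1) // 2) + (k - 1) + 1 + 1 = 2 * L,
#
# i.e. every stride layer exactly doubles the time dimension, undoing one 2x
# subsampling step of the encoder. With stride_transpose=False, the strided Conv1d
# with padding=(k - 1) // 2 roughly halves the length instead.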


class ConvASRDecoderClassification(NeuralModule, Exportable):
    """Simple ASR Decoder for use with classification models such as JasperNet and QuartzNet

    Based on these papers:
        https://arxiv.org/pdf/2005.04290.pdf
    """
    def input_example(self, max_batch=1, max_dim=256):
        """
        Generates input examples for tracing etc.

        Returns:
            A tuple of input examples.
        """
        input_example = torch.randn(max_batch, self._feat_in, max_dim).to(next(self.parameters()).device)
        return tuple([input_example])
    @property
    def input_types(self):
        return OrderedDict({"encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())})

    @property
    def output_types(self):
        return OrderedDict({"logits": NeuralType(('B', 'D'), LogitsType())})

    def __init__(
        self,
        feat_in: int,
        num_classes: int,
        init_mode: Optional[str] = "xavier_uniform",
        return_logits: bool = True,
        pooling_type='avg',
    ):
        super().__init__()

        self._feat_in = feat_in
        self._return_logits = return_logits
        self._num_classes = num_classes

        if pooling_type == 'avg':
            self.pooling = torch.nn.AdaptiveAvgPool1d(1)
        elif pooling_type == 'max':
            self.pooling = torch.nn.AdaptiveMaxPool1d(1)
        else:
            raise ValueError('Pooling type chosen is not valid. Must be either `avg` or `max`')

        self.decoder_layers = torch.nn.Sequential(torch.nn.Linear(self._feat_in, self._num_classes, bias=True))
        self.apply(lambda x: init_weights(x, mode=init_mode))
    def forward(self, encoder_output, **kwargs):
        batch, in_channels, timesteps = encoder_output.size()

        encoder_output = self.pooling(encoder_output).view(batch, in_channels)  # [B, C]
        logits = self.decoder_layers(encoder_output)  # [B, num_classes]

        if self._return_logits:
            return logits

        return torch.nn.functional.softmax(logits, dim=-1)
    @property
    def num_classes(self):
        return self._num_classes
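
# Usage sketch (editor's illustration): a classification head that pools over time and
# produces one logit vector per utterance.
#
#     head = ConvASRDecoderClassification(feat_in=256, num_classes=30)
#     enc = torch.randn(8, 256, 100)            # [B, D, T] encoder output
#     logits = head(encoder_output=enc)         # [8, 30]; pass return_logits=False at
#                                               # construction to get softmax probabilities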


class ECAPAEncoder(NeuralModule, Exportable):
    """
    Modified ECAPA Encoder layer without the Res2Net module, for faster training and inference,
    which achieves better numbers on speaker diarization tasks.

    Reference: ECAPA-TDNN Embeddings for Speaker Diarization (https://arxiv.org/pdf/2104.01466.pdf)

    input:
        feat_in: input feature shape (mel spec feature shape)
        filters: list of filter shapes for SE_TDNN modules
        kernel_sizes: list of kernel shapes for SE_TDNN modules
        dilations: list of dilations for group conv se layer
        scale: scale value to group wider conv channels (default: 8)

    output:
        outputs: encoded output
        output_length: masked output lengths
    """

    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        return OrderedDict(
            {
                "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
                "length": NeuralType(tuple('B'), LengthsType()),
            }
        )

    @property
    def output_types(self):
        """Returns definitions of module output ports."""
        return OrderedDict(
            {
                "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
                "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
            }
        )

    def __init__(
        self,
        feat_in: int,
        filters: list,
        kernel_sizes: list,
        dilations: list,
        scale: int = 8,
        init_mode: str = 'xavier_uniform',
    ):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(TDNNModule(feat_in, filters[0], kernel_size=kernel_sizes[0], dilation=dilations[0]))

        for i in range(len(filters) - 2):
            self.layers.append(
                TDNNSEModule(
                    filters[i],
                    filters[i + 1],
                    group_scale=scale,
                    se_channels=128,
                    kernel_size=kernel_sizes[i + 1],
                    dilation=dilations[i + 1],
                )
            )
        self.feature_agg = TDNNModule(filters[-1], filters[-1], kernel_sizes[-1], dilations[-1])
        self.apply(lambda x: init_weights(x, mode=init_mode))

    def forward(self, audio_signal, length=None):
        x = audio_signal
        outputs = []

        for layer in self.layers:
            x = layer(x, length=length)
            outputs.append(x)

        x = torch.cat(outputs[1:], dim=1)
        x = self.feature_agg(x)
        return x, length
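
# Usage sketch (editor's illustration): a typical ECAPA-style configuration. The channel
# math follows the code above: the outputs of the three TDNNSEModules (512 channels each)
# are concatenated to 1536 channels before feature aggregation. The values below are
# assumptions for illustration, not an official recipe.
#
#     ecapa = ECAPAEncoder(
#         feat_in=80,
#         filters=[512, 512, 512, 512, 1536],
#         kernel_sizes=[5, 3, 3, 3, 1],
#         dilations=[1, 2, 3, 4, 1],
#     )
#     feats = torch.randn(4, 80, 300)           # [B, D, T] mel features
#     out, lens = ecapa(feats, length=torch.full((4,), 300))
#     # out: [4, 1536, 300]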


class SpeakerDecoder(NeuralModule, Exportable):
    """
    Speaker Decoder creates the final neural layers that map from the outputs
    of the Jasper encoder to the embedding layer, followed by a speaker-based softmax loss.

    Args:
        feat_in (int): Number of channels being input to this module
        num_classes (int): Number of unique speakers in the dataset
        emb_sizes (list): shapes of intermediate embedding layers (speaker embeddings are
            taken from the first of these layers). Defaults to [1024, 1024]
        pool_mode (str): Pooling strategy type. Options are 'xvector', 'tap', 'attention'.
            Defaults to 'xvector'.
            xvector: mean and variance (statistics) pooling
            tap: temporal average pooling (just the mean)
            attention: attention-based pooling
        init_mode (str): Describes how neural network parameters are initialized.
            Options are ['xavier_uniform', 'xavier_normal', 'kaiming_uniform', 'kaiming_normal'].
            Defaults to "xavier_uniform".
    """
    def input_example(self, max_batch=1, max_dim=256):
        """
        Generates input examples for tracing etc.

        Returns:
            A tuple of input examples.
        """
        input_example = torch.randn(max_batch, self.input_feat_in, max_dim).to(next(self.parameters()).device)
        return tuple([input_example])
    @property
    def input_types(self):
        return OrderedDict(
            {
                "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
                "length": NeuralType(('B',), LengthsType(), optional=True),
            }
        )

    @property
    def output_types(self):
        return OrderedDict(
            {
                "logits": NeuralType(('B', 'D'), LogitsType()),
                "embs": NeuralType(('B', 'D'), AcousticEncodedRepresentation()),
            }
        )

    def __init__(
        self,
        feat_in: int,
        num_classes: int,
        emb_sizes: Optional[Union[int, list]] = 256,
        pool_mode: str = 'xvector',
        angular: bool = False,
        attention_channels: int = 128,
        init_mode: str = "xavier_uniform",
    ):
        super().__init__()
        self.angular = angular
        self.emb_id = 2
        bias = False if self.angular else True
        emb_sizes = [emb_sizes] if type(emb_sizes) is int else emb_sizes
        # Stored so that input_example() can recover the expected input width.
        self.input_feat_in = feat_in

        self._num_classes = num_classes
        self.pool_mode = pool_mode.lower()
        if self.pool_mode == 'xvector' or self.pool_mode == 'tap':
            self._pooling = StatsPoolLayer(feat_in=feat_in, pool_mode=self.pool_mode)
            affine_type = 'linear'
        elif self.pool_mode == 'attention':
            self._pooling = AttentivePoolLayer(inp_filters=feat_in, attention_channels=attention_channels)
            affine_type = 'conv'

        shapes = [self._pooling.feat_in]
        for size in emb_sizes:
            shapes.append(int(size))

        emb_layers = []
        for shape_in, shape_out in zip(shapes[:-1], shapes[1:]):
            layer = self.affine_layer(shape_in, shape_out, learn_mean=False, affine_type=affine_type)
            emb_layers.append(layer)

        self.emb_layers = nn.ModuleList(emb_layers)

        self.final = nn.Linear(shapes[-1], self._num_classes, bias=bias)

        self.apply(lambda x: init_weights(x, mode=init_mode))
    def affine_layer(
        self,
        inp_shape,
        out_shape,
        learn_mean=True,
        affine_type='conv',
    ):
        if affine_type == 'conv':
            layer = nn.Sequential(
                nn.BatchNorm1d(inp_shape, affine=True, track_running_stats=True),
                nn.Conv1d(inp_shape, out_shape, kernel_size=1),
            )
        else:
            layer = nn.Sequential(
                nn.Linear(inp_shape, out_shape),
                nn.BatchNorm1d(out_shape, affine=learn_mean, track_running_stats=True),
                nn.ReLU(),
            )

        return layer
    @typecheck()
    def forward(self, encoder_output, length=None):
        pool = self._pooling(encoder_output, length)
        embs = []

        for layer in self.emb_layers:
            pool, emb = layer(pool), layer[: self.emb_id](pool)
            embs.append(emb)

        pool = pool.squeeze(-1)
        if self.angular:
            # F.normalize is not in-place; assign back through .data so the classifier
            # weights are actually unit-normalized for angular-margin scoring.
            for W in self.final.parameters():
                W.data = F.normalize(W.data, p=2, dim=1)
            pool = F.normalize(pool, p=2, dim=1)

        out = self.final(pool)

        return out, embs[-1].squeeze(-1)
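
# Usage sketch (editor's illustration): attention-pooled speaker embeddings on top of
# ECAPA-style encoder output. The sizes below are assumptions for illustration.
#
#     spk = SpeakerDecoder(feat_in=1536, num_classes=100, emb_sizes=192, pool_mode='attention')
#     enc = torch.randn(4, 1536, 300)           # [B, D, T] encoder output
#     logits, embs = spk(encoder_output=enc, length=torch.full((4,), 300))
#     # logits: [4, 100] class scores; embs: [4, 192] speaker embeddings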


class ConvASREncoderAdapter(ConvASREncoder, adapter_mixins.AdapterModuleMixin):

    # Higher level forwarding
    def add_adapter(self, name: str, cfg: dict):
        for jasper_block in self.encoder:  # type: adapter_mixins.AdapterModuleMixin
            cfg = self._update_adapter_cfg_input_dim(jasper_block, cfg)
            jasper_block.set_accepted_adapter_types(self.get_accepted_adapter_types())
            jasper_block.add_adapter(name, cfg)

    def is_adapter_available(self) -> bool:
        return any([jasper_block.is_adapter_available() for jasper_block in self.encoder])

    def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True):
        for jasper_block in self.encoder:  # type: adapter_mixins.AdapterModuleMixin
            jasper_block.set_enabled_adapters(name=name, enabled=enabled)

    def get_enabled_adapters(self) -> List[str]:
        names = set([])
        for jasper_block in self.encoder:  # type: adapter_mixins.AdapterModuleMixin
            names.update(jasper_block.get_enabled_adapters())

        names = sorted(list(names))
        return names

    def _update_adapter_cfg_input_dim(self, block: JasperBlock, cfg):
        cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=block.planes)
        return cfg

    def get_accepted_adapter_types(self) -> Set[type]:
        types = super().get_accepted_adapter_types()

        if len(types) == 0:
            self.set_accepted_adapter_types(
                [
                    adapter_utils.LINEAR_ADAPTER_CLASSPATH,
                ]
            )
            types = self.get_accepted_adapter_types()
        return types


@dataclass
class JasperEncoderConfig:
    filters: int = MISSING
    repeat: int = MISSING
    kernel: List[int] = MISSING
    stride: List[int] = MISSING
    dilation: List[int] = MISSING
    dropout: float = MISSING
    residual: bool = MISSING

    # Optional arguments
    groups: int = 1
    separable: bool = False
    heads: int = -1
    residual_mode: str = "add"
    residual_dense: bool = False
    se: bool = False
    se_reduction_ratio: int = 8
    se_context_size: int = -1
    se_interpolation_mode: str = 'nearest'
    kernel_size_factor: float = 1.0
    stride_last: bool = False


@dataclass
class ConvASREncoderConfig:
    _target_: str = 'nemo.collections.asr.modules.ConvASREncoder'
    jasper: Optional[List[JasperEncoderConfig]] = field(default_factory=list)
    activation: str = MISSING
    feat_in: int = MISSING
    normalization_mode: str = "batch"
    residual_mode: str = "add"
    norm_groups: int = -1
    conv_mask: bool = True
    frame_splicing: int = 1
    init_mode: Optional[str] = "xavier_uniform"


@dataclass
class ConvASRDecoderConfig:
    _target_: str = 'nemo.collections.asr.modules.ConvASRDecoder'
    feat_in: int = MISSING
    num_classes: int = MISSING
    init_mode: Optional[str] = "xavier_uniform"
    vocabulary: Optional[List[str]] = field(default_factory=list)


@dataclass
class ConvASRDecoderClassificationConfig:
    _target_: str = 'nemo.collections.asr.modules.ConvASRDecoderClassification'
    feat_in: int = MISSING
    num_classes: int = MISSING
    init_mode: Optional[str] = "xavier_uniform"
    return_logits: bool = True
    pooling_type: str = 'avg'


"""
Register any additional information
"""
if adapter_mixins.get_registered_adapter(ConvASREncoder) is None:
    adapter_mixins.register_adapter(base_class=ConvASREncoder, adapter_class=ConvASREncoderAdapter)
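
# Usage sketch (editor's illustration): the dataclasses above are structured configs
# whose MISSING fields must be filled before use. A hypothetical example with OmegaConf:
#
#     cfg = OmegaConf.structured(
#         ConvASREncoderConfig(
#             activation='relu',
#             feat_in=64,
#             jasper=[
#                 JasperEncoderConfig(
#                     filters=256, repeat=1, kernel=[11], stride=[1],
#                     dilation=[1], dropout=0.0, residual=False,
#                 )
#             ],
#         )
#     )
#     # The `_target_` field names the class to build, so the config can be turned into
#     # a module with Hydra-style instantiation (e.g. hydra.utils.instantiate(cfg)).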