| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import contextlib |
| import math |
| from collections import defaultdict |
| from typing import Dict, List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
|
|
|
|
class SamePad(nn.Module):
    """Drop trailing steps on the third axis after a padded 1-D convolution.

    Args:
        kernel_size: kernel size of the preceding convolution.
        causal: when True, remove ``kernel_size - 1`` trailing steps; otherwise
            remove one step for even kernel sizes and none for odd ones.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        even_kernel = kernel_size % 2 == 0
        self.remove = kernel_size - 1 if causal else int(even_kernel)

    def forward(self, x):
        # Only the third axis (x[:, :, ...]) is trimmed.
        return x[:, :, : -self.remove] if self.remove > 0 else x
|
|
|
|
class TransposeLast(nn.Module):
    """Swap the last axis with ``tranpose_dim`` (default: second-to-last).

    Args:
        deconstruct_idx: if not None, the input is first indexed with this
            value (e.g. to pick one element of a tuple) before transposing.
        tranpose_dim: the axis swapped with the last one.
    """

    def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx
        self.tranpose_dim = tranpose_dim

    def forward(self, x):
        source = x if self.deconstruct_idx is None else x[self.deconstruct_idx]
        return source.transpose(self.tranpose_dim, -1)
|
|
|
|
class Swish(nn.Module):
    """Swish activation: ``inputs * sigmoid(inputs)``."""

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(inputs)
        return inputs * gate
|
|
|
|
class GLU(nn.Module):
    """Gated linear unit along ``dim``.

    The input is split into two chunks along ``dim``; the first chunk is
    multiplied by the sigmoid of the second.
    """

    def __init__(self, dim: int) -> None:
        super(GLU, self).__init__()
        self.dim = dim

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        value, gate = torch.chunk(inputs, 2, dim=self.dim)
        return value * torch.sigmoid(gate)
|
|
|
|
class ResidualConnectionModule(nn.Module):
    """Scaled residual wrapper: ``module(x) * module_factor + x * input_factor``.

    Args:
        module: the wrapped branch.
        module_factor: scale applied to the branch output.
        input_factor: scale applied to the skip connection.
    """

    def __init__(
        self,
        module: nn.Module,
        module_factor: float = 1.0,
        input_factor: float = 1.0,
    ):
        super(ResidualConnectionModule, self).__init__()
        self.module = module
        self.module_factor = module_factor
        self.input_factor = input_factor

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Run the branch first, then build the skip term.
        branch = self.module(inputs) * self.module_factor
        skip = inputs * self.input_factor
        return branch + skip
|
|
|
|
class Linear(nn.Module):
    """``nn.Linear`` with Xavier-uniform weights and (when present) zero bias.

    The wrapped layer is exposed as ``self.linear``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        layer = self.linear
        nn.init.xavier_uniform_(layer.weight)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)
|
|
|
|
class View(nn.Module):
    """Reshape the input with ``Tensor.view(*shape)``.

    Args:
        shape: target shape passed to ``view``.
        contiguous: call ``.contiguous()`` first (needed for non-contiguous inputs).
    """

    def __init__(self, shape: tuple, contiguous: bool = False):
        super(View, self).__init__()
        self.shape = shape
        self.contiguous = contiguous

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        source = x.contiguous() if self.contiguous else x
        return source.view(*self.shape)
|
|
|
|
class Transpose(nn.Module):
    """Swap the two axes given in ``shape`` (a pair of dims)."""

    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.transpose(x, *self.shape)
|
|
|
|
class FeedForwardModule(nn.Module):
    """Conformer feed-forward block.

    LayerNorm -> Linear(encoder_dim -> encoder_dim * expansion_factor) -> Swish
    -> Dropout -> Linear(back to encoder_dim) -> Dropout.

    Args:
        encoder_dim: input/output width.
        expansion_factor: hidden width multiplier.
        dropout_p: dropout probability for both dropout layers.
    """

    def __init__(
        self,
        encoder_dim: int = 512,
        expansion_factor: int = 4,
        dropout_p: float = 0.1,
    ) -> None:
        super(FeedForwardModule, self).__init__()
        hidden_dim = encoder_dim * expansion_factor
        layers = [
            nn.LayerNorm(encoder_dim),
            Linear(encoder_dim, hidden_dim, bias=True),
            Swish(),
            nn.Dropout(p=dropout_p),
            Linear(hidden_dim, encoder_dim, bias=True),
            nn.Dropout(p=dropout_p),
        ]
        self.sequential = nn.Sequential(*layers)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.sequential(inputs)
|
|
|
|
class DepthwiseConv1d(nn.Module):
    """Depthwise 1-D convolution (``groups == in_channels``).

    Args:
        in_channels: input channels; also the number of groups.
        out_channels: output channels; must be a multiple of ``in_channels``.
        kernel_size, stride, padding, bias: forwarded to ``nn.Conv1d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert (
            out_channels % in_channels == 0
        ), "out_channels should be constant multiple of in_channels"
        conv_kwargs = dict(
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        self.conv = nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, **conv_kwargs
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.conv(inputs)
|
|
|
|
class PointwiseConv1d(nn.Module):
    """1x1 (pointwise) ``nn.Conv1d``.

    Args:
        in_channels, out_channels, stride, padding, bias: forwarded to
            ``nn.Conv1d`` with ``kernel_size=1``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super(PointwiseConv1d, self).__init__()
        conv_kwargs = dict(kernel_size=1, stride=stride, padding=padding, bias=bias)
        self.conv = nn.Conv1d(
            in_channels=in_channels, out_channels=out_channels, **conv_kwargs
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.conv(inputs)
|
|
|
|
class ConformerConvModule(nn.Module):
    """Conformer convolution module.

    LayerNorm -> transpose(1, 2) -> pointwise conv (x2 channels) -> GLU
    -> depthwise conv ('same' padding) -> BatchNorm -> Swish -> pointwise conv
    -> Dropout, then transpose back.

    Assumes inputs are ``(batch, time, in_channels)`` (LayerNorm over the last
    axis, then axes 1 and 2 are swapped) — confirm against callers.

    Args:
        in_channels: channel width.
        kernel_size: depthwise kernel size; must be odd.
        expansion_factor: must be 2 (GLU halves the expanded channels).
        dropout_p: output dropout probability.
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 31,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super(ConformerConvModule, self).__init__()
        assert (
            kernel_size - 1
        ) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"

        same_padding = (kernel_size - 1) // 2
        expanded_channels = in_channels * expansion_factor
        self.sequential = nn.Sequential(
            nn.LayerNorm(in_channels),
            Transpose(shape=(1, 2)),
            PointwiseConv1d(
                in_channels, expanded_channels, stride=1, padding=0, bias=True
            ),
            GLU(dim=1),
            DepthwiseConv1d(
                in_channels, in_channels, kernel_size, stride=1, padding=same_padding
            ),
            nn.BatchNorm1d(in_channels),
            Swish(),
            PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # The sequential stack works channel-first; swap axes 1 and 2 back.
        channel_first = self.sequential(inputs)
        return channel_first.transpose(1, 2)
|
|
|
|
class FramewiseConv2dSubampling(nn.Module):
    """Two-layer Conv2d front-end that subsamples time by 2 or 4.

    Both convolutions have kernel 3 and stride 2 on the last (frequency) axis
    with no frequency padding. The second conv has time stride 2 for rate 4,
    or time stride 1 with padding 1 for rate 2.

    Args:
        out_channels: channels of both convolutions.
        subsample_rate: 2 or 4.
    """

    def __init__(self, out_channels: int, subsample_rate: int = 2) -> None:
        super(FramewiseConv2dSubampling, self).__init__()
        assert subsample_rate in {2, 4}, "subsample_rate should be 2 or 4"
        self.subsample_rate = subsample_rate
        rate_four = subsample_rate == 4
        second_stride = (2, 2) if rate_four else (1, 2)
        second_padding = (0, 0) if rate_four else (1, 0)
        self.cnn = nn.Sequential(
            nn.Conv2d(1, out_channels, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=second_stride,
                padding=second_padding,
            ),
            nn.ReLU(),
        )

    def forward(
        self, inputs: torch.Tensor, input_lengths: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Subsample ``inputs`` (indexed as batch, time, freq).

        Returns ``(features, lengths)`` with features shaped
        ``(batch, frames, out_channels * reduced_freq)``.
        """
        if self.subsample_rate == 2 and inputs.shape[1] % 2 == 0:
            # Pad one frame so the time axis is odd; the output then has T // 2 frames.
            inputs = F.pad(inputs, (0, 0, 0, 1), "constant", 0)
        feats = self.cnn(inputs.unsqueeze(1))
        batch, chans, frames, freq = feats.size()
        feats = feats.permute(0, 2, 1, 3).contiguous().view(batch, frames, chans * freq)

        if self.subsample_rate != 4:
            lengths = input_lengths >> 1
        else:
            lengths = (((input_lengths - 1) >> 1) - 1) >> 1
        return feats, lengths
|
|
|
|
class PatchwiseConv2dSubampling(nn.Module):
    """Patch-embedding front-end: non-overlapping (time, freq) patches.

    A strided Conv2d cuts the ``(time, mel_dim)`` plane into patches. Two
    3x3 conv+ReLU layers follow. The patch grid is then flattened time-major
    into a token sequence.

    Args:
        mel_dim: expected size of ``inputs.shape[2]``.
        out_channels: token width.
        patch_size_time, patch_size_freq: patch size (and stride).
    """

    def __init__(
        self,
        mel_dim: int,
        out_channels: int,
        patch_size_time: int = 16,
        patch_size_freq: int = 16,
    ) -> None:
        super(PatchwiseConv2dSubampling, self).__init__()
        self.mel_dim = mel_dim
        self.patch_size_time = patch_size_time
        self.patch_size_freq = patch_size_freq

        patch = (patch_size_time, patch_size_freq)
        self.proj = nn.Conv2d(1, out_channels, kernel_size=patch, stride=patch, padding=0)
        refine = []
        for _ in range(2):
            refine.append(
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
            )
            refine.append(nn.ReLU())
        self.cnn = nn.Sequential(*refine)

    @property
    def subsample_rate(self) -> int:
        """Input frames per output token: ``patch_time * patch_freq // mel_dim``."""
        return self.patch_size_time * self.patch_size_freq // self.mel_dim

    def forward(
        self, inputs: torch.Tensor, input_lengths: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Return ``(tokens, lengths)``; tokens are ``(batch, n_tokens, out_channels)``."""
        assert (
            inputs.shape[2] == self.mel_dim
        ), "inputs.shape[2] should be equal to mel_dim"

        grid = self.cnn(self.proj(inputs.unsqueeze(1)))
        # (B, C, T', F') -> (B, T' * F', C)
        tokens = torch.flatten(grid, 2, 3).transpose(1, 2)

        tokens_per_step = self.mel_dim // self.patch_size_freq
        output_lengths = (input_lengths // self.patch_size_time) * tokens_per_step
        return tokens, output_lengths
|
|
|
|
class RelPositionalEncoding(nn.Module):
    """Sinusoidal table covering relative positions ``-(L-1) .. +(L-1)``.

    ``self.pe`` has shape ``(1, 2L - 1, d_model)``. Row ``L - 1`` is relative
    position 0. Rows before it are the positive positions in descending order;
    rows after it are the negative positions. The table is a plain attribute
    (not a registered buffer); ``extend_pe`` rebuilds or moves it on demand.

    Args:
        d_model: embedding width (sin/cos interleaved on even/odd columns).
        max_len: length the table is pre-built for.
    """

    def __init__(self, d_model: int, max_len: int = 10000) -> None:
        super(RelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.pe = None
        # Pre-build the table using a dummy (1, max_len) input.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor) -> None:
        """Ensure ``self.pe`` covers ``x.size(1)`` and matches ``x``'s dtype/device."""
        length = x.size(1)
        cached = self.pe
        if cached is not None and cached.size(1) >= length * 2 - 1:
            if cached.dtype != x.dtype or cached.device != x.device:
                self.pe = cached.to(dtype=x.dtype, device=x.device)
            return

        positions = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )

        forward_table = torch.zeros(length, self.d_model)
        forward_table[:, 0::2] = torch.sin(positions * inv_freq)
        forward_table[:, 1::2] = torch.cos(positions * inv_freq)

        backward_table = torch.zeros(length, self.d_model)
        backward_table[:, 0::2] = torch.sin(-1 * positions * inv_freq)
        backward_table[:, 1::2] = torch.cos(-1 * positions * inv_freq)

        # Positive offsets reversed (largest first), then negatives without the
        # duplicated zero row.
        table = torch.cat(
            [torch.flip(forward_table, [0]), backward_table[1:]], dim=0
        ).unsqueeze(0)
        self.pe = table.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``(1, 2 * x.size(1) - 1, d_model)`` window centered on offset 0."""
        self.extend_pe(x)
        center = self.pe.size(1) // 2
        span = x.size(1)
        return self.pe[:, center - span + 1 : center + span]
|
|
|
|
class RelativeMultiHeadAttention(nn.Module):
    """Multi-head attention with Transformer-XL style relative-position scores.

    Logits are ``(content + position) / sqrt(d_head)``, where
    ``content = (q + u_bias) @ k^T`` and ``position = (q + v_bias) @ p^T``
    realigned by :meth:`_relative_shift`.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout_p: dropout applied to the attention weights.
    """

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 16,
        dropout_p: float = 0.1,
    ):
        super(RelativeMultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = d_model
        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(self.d_head)

        self.query_proj = Linear(d_model, d_model)
        self.key_proj = Linear(d_model, d_model)
        self.value_proj = Linear(d_model, d_model)
        self.pos_proj = Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(p=dropout_p)
        # Learned per-head query biases for the content (u) and position (v) terms.
        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)

        self.out_proj = Linear(d_model, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_embedding: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend ``query`` over ``key``/``value`` with relative position scores.

        Args:
            query, key, value: ``(batch, T, d_model)`` tensors; each is projected
                and reshaped to ``(batch, T, num_heads, d_head)``.
            pos_embedding: ``(batch, P, d_model)`` relative embeddings. After the
                relative shift the position scores keep ``P // 2 + 1`` columns,
                which must equal the key length (e.g. ``P = 2 * T - 1``).
            mask: optional boolean tensor, ``True`` = masked out. It is
                unsqueezed at dim 1 (head axis) and must then broadcast against
                ``(batch, num_heads, T_q, T_k)``, e.g. ``(batch, T_q, T_k)``.

        Returns:
            ``(output, attn)``: ``(batch, T_q, d_model)`` output and
            ``(batch, num_heads, T_q, T_k)`` attention weights.
        """
        batch_size = value.size(0)

        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = (
            self.key_proj(key)
            .view(batch_size, -1, self.num_heads, self.d_head)
            .permute(0, 2, 1, 3)
        )
        value = (
            self.value_proj(value)
            .view(batch_size, -1, self.num_heads, self.d_head)
            .permute(0, 2, 1, 3)
        )
        pos_embedding = self.pos_proj(pos_embedding).view(
            batch_size, -1, self.num_heads, self.d_head
        )

        content_score = torch.matmul(
            (query + self.u_bias).transpose(1, 2), key.transpose(2, 3)
        )
        pos_score = torch.matmul(
            (query + self.v_bias).transpose(1, 2),
            pos_embedding.permute(0, 2, 3, 1),
        )
        pos_score = self._relative_shift(pos_score)

        score = (content_score + pos_score) / self.sqrt_dim

        if mask is not None:
            mask = mask.unsqueeze(1)
            # Most negative finite value of the score dtype: a fixed -1e9 does
            # not fit in float16 and makes masked_fill_ raise under half precision.
            score.masked_fill_(mask, torch.finfo(score.dtype).min)

        attn = F.softmax(score, -1)
        attn = self.dropout(attn)

        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(context), attn

    def _relative_shift(self, pos_score: torch.Tensor) -> torch.Tensor:
        """Realign ``(B, H, T1, T2)`` position scores to relative offsets.

        Pads one zero column on the left, reinterprets the buffer as
        ``(B, H, T2 + 1, T1)``, drops the first row, views back to the input
        shape and keeps the first ``T2 // 2 + 1`` columns.
        """
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)

        padded_pos_score = padded_pos_score.view(
            batch_size, num_heads, seq_length2 + 1, seq_length1
        )
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)[
            :, :, :, : seq_length2 // 2 + 1
        ]

        return pos_score
|
|
|
|
class MultiHeadedSelfAttentionModule(nn.Module):
    """Pre-norm relative-position self-attention followed by dropout.

    Args:
        d_model: model width.
        num_heads: attention heads.
        dropout_p: dropout for attention weights and the output.
    """

    def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1):
        super(MultiHeadedSelfAttentionModule, self).__init__()
        self.positional_encoding = RelPositionalEncoding(d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(
        self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(dropout(attention_output), attention_weights)``."""
        # The positional table depends only on the input length/dtype/device;
        # it is tiled once per batch element.
        rel_pos = self.positional_encoding(inputs).repeat(inputs.size(0), 1, 1)

        normed = self.layer_norm(inputs)
        context, weights = self.attention(
            normed, normed, normed, pos_embedding=rel_pos, mask=mask
        )
        return self.dropout(context), weights
|
|
|
|
class ConformerBlock(nn.Module):
    """A single Conformer layer.

    Default layout::

        x = x + ffn_factor * FFN1(x)
        x = x + MHSA(x)
        x = x + Conv(x)
        x = x + ffn_factor * FFN2(x)
        x = LayerNorm(x)

    With ``transformer_style=True`` the conv module and second FFN are not
    built. The layer then computes ``x1 = x + FFN1(x)`` and returns
    ``LayerNorm(x1 + MHSA(x1))``, with the FFN residual factor forced to 1.

    Args:
        encoder_dim: model width.
        attention_type: ``"mhsa"`` (relative-position multi-head self-attention).
            ``"mamba"`` is recognised but not implemented here.
        num_attention_heads: heads for MHSA.
        mamba_d_state, mamba_d_conv, mamba_expand, mamba_bidirectional:
            accepted for config compatibility; unused by this implementation.
        feed_forward_expansion_factor: FFN hidden-width multiplier.
        conv_expansion_factor: conv-module expansion (must be 2).
        feed_forward_dropout_p, attention_dropout_p, conv_dropout_p: dropouts.
        conv_kernel_size: depthwise kernel size (odd).
        half_step_residual: scale FFN outputs by 0.5 (not applied when
            ``transformer_style``).
        transformer_style: see above.

    Raises:
        ValueError: if ``attention_type`` is not ``"mhsa"`` or ``"mamba"``.
        NotImplementedError: if ``attention_type == "mamba"``.
    """

    def __init__(
        self,
        encoder_dim: int = 512,
        attention_type: str = "mhsa",
        num_attention_heads: int = 8,
        mamba_d_state: int = 16,
        mamba_d_conv: int = 4,
        mamba_expand: int = 2,
        mamba_bidirectional: bool = True,
        feed_forward_expansion_factor: int = 4,
        conv_expansion_factor: int = 2,
        feed_forward_dropout_p: float = 0.1,
        attention_dropout_p: float = 0.1,
        conv_dropout_p: float = 0.1,
        conv_kernel_size: int = 31,
        half_step_residual: bool = True,
        transformer_style: bool = False,
    ):
        super(ConformerBlock, self).__init__()

        # Validate before building any submodule. Previously "mamba" passed the
        # check and then failed with an UnboundLocalError on `attention`.
        if attention_type not in ("mhsa", "mamba"):
            raise ValueError(
                f"attention_type must be 'mhsa' or 'mamba', got {attention_type!r}"
            )
        if attention_type == "mamba":
            raise NotImplementedError(
                "attention_type='mamba' is not implemented in this module"
            )

        self.transformer_style = transformer_style
        self.attention_type = attention_type

        if half_step_residual and not transformer_style:
            self.feed_forward_residual_factor = 0.5
        else:
            self.feed_forward_residual_factor = 1

        attention = MultiHeadedSelfAttentionModule(
            d_model=encoder_dim,
            num_heads=num_attention_heads,
            dropout_p=attention_dropout_p,
        )

        self.ffn_1 = FeedForwardModule(
            encoder_dim=encoder_dim,
            expansion_factor=feed_forward_expansion_factor,
            dropout_p=feed_forward_dropout_p,
        )
        self.attention = attention
        if not transformer_style:
            self.conv = ConformerConvModule(
                in_channels=encoder_dim,
                kernel_size=conv_kernel_size,
                expansion_factor=conv_expansion_factor,
                dropout_p=conv_dropout_p,
            )
            self.ffn_2 = FeedForwardModule(
                encoder_dim=encoder_dim,
                expansion_factor=feed_forward_expansion_factor,
                dropout_p=feed_forward_dropout_p,
            )
        self.layernorm = nn.LayerNorm(encoder_dim)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, Union[torch.Tensor, None]]]:
        """Run the layer.

        Returns:
            ``(x, intermediates)`` where ``intermediates`` maps ``"ffn_1"``,
            ``"attn"``, ``"conv"`` and ``"ffn_2"`` to the corresponding branch
            outputs (attention weights for ``"attn"``). Stages that were not
            run, and ``"attn"`` for non-MHSA modules, are ``None``.
        """
        ffn_1_out = self.ffn_1(x)
        x = ffn_1_out * self.feed_forward_residual_factor + x

        if isinstance(self.attention, MultiHeadedSelfAttentionModule):
            attn_out, attn = self.attention(x)
        else:
            # Other attention modules are called expecting a single output tensor.
            attn_out = self.attention(x)
            attn = None
        x = attn_out + x

        if self.transformer_style:
            x = self.layernorm(x)
            return x, {
                "ffn_1": ffn_1_out,
                "attn": attn,
                "conv": None,
                "ffn_2": None,
            }

        conv_out = self.conv(x)
        x = conv_out + x

        ffn_2_out = self.ffn_2(x)
        x = ffn_2_out * self.feed_forward_residual_factor + x
        x = self.layernorm(x)

        other = {
            "ffn_1": ffn_1_out,
            "attn": attn,
            "conv": conv_out,
            "ffn_2": ffn_2_out,
        }
        return x, other
|
|
|
|
class ConformerEncoder(nn.Module):
    """Conformer encoder with frame-wise and/or patch-wise subsampling front-ends.

    Inputs are indexed as ``(batch, time, input_dim)``. Each enabled front-end
    maps them to ``(batch, frames, encoder_dim)``. When both are enabled, their
    outputs are cropped to a common length and summed. An optional stack of
    grouped positional convolutions (``cfg.conv_pos``) is added residually,
    followed by ``cfg.num_layers`` ``ConformerBlock`` layers.

    Args:
        cfg: attribute-style config.
            Front-end attributes: ``use_framewise_subsample``,
            ``use_patchwise_subsample``, ``input_dim``, ``encoder_dim``,
            ``conv_subsample_channels``, ``input_dropout_p``, plus
            ``conv_subsample_rate`` (framewise) and ``patch_size_time`` /
            ``patch_size_freq`` (patchwise).
            Layer attributes: everything passed to ``ConformerBlock`` below,
            and ``num_layers``.
            Optional attributes (read with ``getattr``):
            ``subsample_normalization``, ``conv_pos`` (with ``conv_pos_depth``,
            ``conv_pos_width``, ``conv_pos_groups``), ``transformer_style``.
    """

    def __init__(self, cfg):
        super(ConformerEncoder, self).__init__()

        self.cfg = cfg
        self.framewise_subsample = None
        self.patchwise_subsample = None
        self.framewise_in_proj = None
        self.patchwise_in_proj = None
        assert (
            cfg.use_framewise_subsample or cfg.use_patchwise_subsample
        ), "At least one subsampling method should be used"

        if cfg.use_framewise_subsample:
            self.framewise_subsample = FramewiseConv2dSubampling(
                out_channels=cfg.conv_subsample_channels,
                subsample_rate=cfg.conv_subsample_rate,
            )
            # Both framewise convs use kernel 3 / stride 2 / no padding on the
            # frequency axis, so the reduced width is ((F - 1) // 2 - 1) // 2.
            reduced_freq = ((cfg.input_dim - 1) // 2 - 1) // 2
            self.framewise_in_proj = nn.Sequential(
                Linear(cfg.conv_subsample_channels * reduced_freq, cfg.encoder_dim),
                nn.Dropout(p=cfg.input_dropout_p),
            )

        if cfg.use_patchwise_subsample:
            self.patchwise_subsample = PatchwiseConv2dSubampling(
                mel_dim=cfg.input_dim,
                out_channels=cfg.conv_subsample_channels,
                patch_size_time=cfg.patch_size_time,
                patch_size_freq=cfg.patch_size_freq,
            )
            self.patchwise_in_proj = nn.Sequential(
                Linear(cfg.conv_subsample_channels, cfg.encoder_dim),
                nn.Dropout(p=cfg.input_dropout_p),
            )
            # When both front-ends are used, their time reductions must agree.
            assert not cfg.use_framewise_subsample or (
                cfg.conv_subsample_rate == self.patchwise_subsample.subsample_rate
            ), (
                f"conv_subsample_rate ({cfg.conv_subsample_rate}) != "
                f"patchwise_subsample.subsample_rate "
                f"({self.patchwise_subsample.subsample_rate})"
            )

        self.framewise_norm, self.patchwise_norm = None, None
        if getattr(cfg, "subsample_normalization", False):
            if cfg.use_framewise_subsample:
                self.framewise_norm = nn.LayerNorm(cfg.encoder_dim)
            if cfg.use_patchwise_subsample:
                self.patchwise_norm = nn.LayerNorm(cfg.encoder_dim)

        self.conv_pos = None
        if getattr(cfg, "conv_pos", False):
            num_pos_layers = cfg.conv_pos_depth
            k = max(3, cfg.conv_pos_width // num_pos_layers)
            self.conv_pos = nn.Sequential(
                TransposeLast(),
                *[
                    nn.Sequential(
                        nn.Conv1d(
                            cfg.encoder_dim,
                            cfg.encoder_dim,
                            kernel_size=k,
                            padding=k // 2,
                            groups=cfg.conv_pos_groups,
                        ),
                        SamePad(k),
                        TransposeLast(),
                        nn.LayerNorm(cfg.encoder_dim, elementwise_affine=False),
                        TransposeLast(),
                        nn.GELU(),
                    )
                    for _ in range(num_pos_layers)
                ],
                TransposeLast(),
            )
            self.conv_pos_post_ln = nn.LayerNorm(cfg.encoder_dim)

        self.layers = nn.ModuleList(
            [
                ConformerBlock(
                    encoder_dim=cfg.encoder_dim,
                    attention_type=cfg.attention_type,
                    num_attention_heads=cfg.num_attention_heads,
                    mamba_d_state=cfg.mamba_d_state,
                    mamba_d_conv=cfg.mamba_d_conv,
                    mamba_expand=cfg.mamba_expand,
                    mamba_bidirectional=cfg.mamba_bidirectional,
                    feed_forward_expansion_factor=cfg.feed_forward_expansion_factor,
                    conv_expansion_factor=cfg.conv_expansion_factor,
                    feed_forward_dropout_p=cfg.feed_forward_dropout_p,
                    attention_dropout_p=cfg.attention_dropout_p,
                    conv_dropout_p=cfg.conv_dropout_p,
                    conv_kernel_size=cfg.conv_kernel_size,
                    half_step_residual=cfg.half_step_residual,
                    transformer_style=getattr(cfg, "transformer_style", False),
                )
                for _ in range(cfg.num_layers)
            ]
        )

    def count_parameters(self) -> int:
        """Count trainable parameters of the encoder."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def update_dropout(self, dropout_p: float) -> None:
        """Set ``p`` on every ``nn.Dropout`` in the encoder, including nested ones.

        ``named_children()`` only yields direct children, none of which is an
        ``nn.Dropout``, so the previous loop never changed anything.
        """
        for module in self.modules():
            if isinstance(module, nn.Dropout):
                module.p = dropout_p

    def _embed_inputs(
        self, inputs: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the enabled front-ends and merge them into ``(features, lengths)``."""
        frame_feat = patch_feat = None
        frame_lengths = patch_lengths = None

        if self.framewise_subsample is not None:
            frame_feat, frame_lengths = self.framewise_subsample(inputs, input_lengths)
            frame_feat = self.framewise_in_proj(frame_feat)
            if self.framewise_norm is not None:
                frame_feat = self.framewise_norm(frame_feat)

        if self.patchwise_subsample is not None:
            patch_feat, patch_lengths = self.patchwise_subsample(inputs, input_lengths)
            patch_feat = self.patchwise_in_proj(patch_feat)
            if self.patchwise_norm is not None:
                patch_feat = self.patchwise_norm(patch_feat)

        if patch_feat is None:
            return frame_feat, frame_lengths
        if frame_feat is None:
            return patch_feat, patch_lengths

        min_len = min(frame_feat.size(1), patch_feat.size(1))
        features = frame_feat[:, :min_len] + patch_feat[:, :min_len]
        # A position is valid only where both streams are valid, so take the
        # per-sample minimum. Elementwise ops avoid a host sync via .item().
        output_lengths = torch.minimum(frame_lengths, patch_lengths).clamp(max=min_len)
        return features, output_lengths

    def forward(
        self,
        inputs: torch.Tensor,
        input_lengths: Optional[torch.Tensor] = None,
        return_hidden: bool = False,
        freeze_input_layers: bool = False,
        target_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, List[torch.Tensor]]]:
        """Encode ``inputs``.

        Args:
            inputs: ``(batch, time, input_dim)`` features.
            input_lengths: valid lengths per sample; defaults to ``inputs.size(1)``.
            return_hidden: collect per-layer outputs and block intermediates.
            freeze_input_layers: run the subsampling front-ends under ``torch.no_grad()``.
            target_layer: stop after the layer with this index.

        Returns:
            ``(outputs, output_lengths, layer_results)``. ``layer_results`` maps
            ``"hidden_states"``, ``"ffn_1"``, ``"attn"``, ``"conv"`` and
            ``"ffn_2"`` to per-layer lists; it is empty unless ``return_hidden``.
        """
        if input_lengths is None:
            input_lengths = torch.full(
                (inputs.size(0),),
                inputs.size(1),
                dtype=torch.long,
                device=inputs.device,
            )

        grad_ctx = torch.no_grad() if freeze_input_layers else contextlib.nullcontext()
        with grad_ctx:
            features, output_lengths = self._embed_inputs(inputs, input_lengths)

        # conv_pos runs outside the no_grad scope, so it stays trainable when
        # the input layers are frozen.
        if self.conv_pos is not None:
            features = features + self.conv_pos(features)
            features = self.conv_pos_post_ln(features)

        layer_results = defaultdict(list)
        outputs = features
        for i, layer in enumerate(self.layers):
            outputs, other = layer(outputs)
            if return_hidden:
                layer_results["hidden_states"].append(outputs)
                for k, v in other.items():
                    layer_results[k].append(v)
            if target_layer is not None and i == target_layer:
                break

        return outputs, output_lengths, layer_results
|
|