Source code for speeq.models.decoders

"""This module contains various pre-implemented decoders used in differnet models.

Classes:

- GlobAttRNNDecoder: Implements a RNN decoder with global attention mechanism.
- LocationAwareAttDecoder: Implements a RNN decoder with location aware attention mechanism.
- TransducerRNNDecoder: Implements a simple RNN decoder with an embedding layer and a single RNN layer.
- TransformerDecoder: Implements a transformer decoder.
- TransformerTransducerDecoder: Implements a Transformer-Transducer decoder.
"""
from typing import Tuple, Union

import torch
from torch import Tensor, nn

from speeq.constants import DECODER_OUT_KEY, ENC_OUT_KEY, HIDDEN_STATE_KEY, PREDS_KEY

from .layers import (
    GlobalMulAttention,
    LocAwareGlobalAddAttention,
    PositionalEmbedding,
    PredModule,
    SpeechTransformerDecLayer,
    TransformerDecLayer,
    TransformerTransducerLayer,
)


[docs]class GlobAttRNNDecoder(nn.Module): """Implements RNN decoder with global attention. Args: embed_dim (int): The size of the embedding. hidden_size (int): The size of the RNN hidden state. n_layers (int): The number of RNN layers. n_classes (int): The number of output classes. pred_activation (Module): An instance of an activation function. teacher_forcing_rate (float): The teacher forcing rate. Default 0.0 rnn_type (str): The RNN type it has to be one of rnn, gru or lstm. Default 'rnn'. """ def __init__( self, embed_dim: int, hidden_size: int, n_layers: int, n_classes: int, pred_activation: nn.Module, teacher_forcing_rate: float = 0.0, rnn_type: str = "rnn", ) -> None: super().__init__() self.emb = nn.Embedding(num_embeddings=n_classes, embedding_dim=embed_dim) from .registry import RNN_REGISTRY self.rnn_layers = nn.ModuleList( [ RNN_REGISTRY[rnn_type]( input_size=hidden_size, hidden_size=hidden_size, batch_first=True, bidirectional=False, ) for i in range(n_layers) ] ) self.fc_layers = nn.ModuleList( [ nn.Linear( in_features=hidden_size + embed_dim if i == 0 else 2 * hidden_size, out_features=hidden_size, ) for i in range(n_layers) ] ) self.att_layers = nn.ModuleList( [ GlobalMulAttention(enc_feat_size=hidden_size, dec_feat_size=hidden_size) for _ in range(n_layers) ] ) self.pred_net = PredModule( in_features=hidden_size, n_classes=n_classes, activation=pred_activation ) self.hidden_size = hidden_size self.n_classes = n_classes self.is_lstm = rnn_type == "lstm" self.teacher_forcing_rate = teacher_forcing_rate def _apply_teacher_forcing(self, y: Tensor, preds: Tensor) -> Tensor: # y of shape [B, 1] # preds of shape [B, 1] """Applies teacher forcing on the decoder's input. Args: y (Tensor): The original target labels. preds (Tensor): The latest prediction. Returns: Tensor: The new decoder input tensor. """ mask = torch.rand(y.shape[0]) <= self.teacher_forcing_rate mask = mask.to(y.device) mask = mask.unsqueeze(dim=-1) return mask * y + (~mask) * preds def _init_hidden_state(self, batch_size, device): if self.is_lstm: return ( torch.zeros(1, batch_size, self.hidden_size).to(device), torch.zeros(1, batch_size, self.hidden_size).to(device), ) return torch.zeros(1, batch_size, self.hidden_size).to(device)
[docs] def forward( self, h: Union[Tensor, Tuple[Tensor, Tensor], None], enc_out: Tensor, enc_mask: Tensor, dec_inp: Tensor, *args, **kwargs ) -> Tensor: """Decodes the input regressivly. Args: h (Union[Tensor, Tuple[Tensor, Tensor], None]): The last hidden state of the encoder. If not provided, set as None. Its shape is [1, B, hidden_size]. enc_out (Tensor): The encoder output tensor of shape [B, M, h]. enc_mask (Tensor): The encoder mask tensor of shape [B, M], where True denotes data positions and False denotes padding ones. dec_inp (Tensor): The decoder input tensor of shape [B, M_dec]. Returns: Tensor: A tensor of shape [B, M_dec, C], representing the output of the forward pass. """ batch_size, max_len = dec_inp.shape if h is None: h = self._init_hidden_state(batch_size=batch_size, device=dec_inp.device) results = None out = self.emb(dec_inp[:, 0:1]) h = [h] * len(self.rnn_layers) for i in range(max_len): layers = enumerate(zip(self.fc_layers, self.rnn_layers, self.att_layers)) for j, (fc, rnn, att) in layers: h_ = h[j] if self.is_lstm: (h_, c_) = h_ h_ = h_.permute(1, 0, 2) out = torch.cat([out, h_], dim=-1) out = fc(out) out = att(key=enc_out, query=out, mask=enc_mask) out, h[j] = rnn(out, h[j]) out = self.pred_net(out) results = out if results is None else torch.cat([results, out], dim=1) y = torch.argmax(out, dim=-1) if self.teacher_forcing_rate > 0: y = self._apply_teacher_forcing(y=dec_inp[:, i : i + 1], preds=y) out = self.emb(y) return results
[docs] def predict(self, state: dict) -> Tuple[Tensor, dict, Tensor]: enc_out = state[ENC_OUT_KEY] preds = state[PREDS_KEY] # [B, M] h = state[HIDDEN_STATE_KEY] last_pred = preds[:, -1:] if isinstance(h, list) is False: # for the first prediction iteration h = [h] * len(self.rnn_layers) out = self.emb(last_pred) layers = enumerate(zip(self.fc_layers, self.rnn_layers, self.att_layers)) for i, (fc, rnn, att) in layers: h_ = h[i] if self.is_lstm: (h_, c_) = h_ h_ = h_.permute(1, 0, 2) out = torch.cat([out, h_], dim=-1) out = fc(out) out = att(key=enc_out, query=out, mask=None) out, h[i] = rnn(out, h[i]) out = self.pred_net(out) state[PREDS_KEY] = torch.cat( [state[PREDS_KEY], torch.argmax(out, dim=-1)], dim=-1 ) state[HIDDEN_STATE_KEY] = h return state
[docs]class LocationAwareAttDecoder(GlobAttRNNDecoder): """Implements RNN decoder with location aware attention. Args: embed_dim (int): The embedding size. hidden_size (int): The RNN hidden size. n_layers (int): The number of RNN layers. n_classes (int): The number of classes. pred_activation (Module): An activation function instance. kernel_size (int): The attention kernel size. activation (str): The activation function to use. it can be either softmax or sigmax. inv_temperature (Union[float, int]): The inverse temperature value. Default 1. teacher_forcing_rate (float): The teacher forcing rate. Default 0.0 rnn_type (str): The RNN type it has to be one of rnn, gru or lstm. Default 'rnn'. """ def __init__( self, embed_dim: int, hidden_size: int, n_layers: int, n_classes: int, pred_activation: nn.Module, kernel_size: int, activation: str, inv_temperature: Union[float, int] = 1, teacher_forcing_rate: float = 0.0, rnn_type: str = "rnn", ) -> None: super().__init__( embed_dim=embed_dim, hidden_size=hidden_size, n_layers=n_layers, n_classes=n_classes, pred_activation=pred_activation, teacher_forcing_rate=teacher_forcing_rate, rnn_type=rnn_type, ) self.att_layers = nn.ModuleList( [ LocAwareGlobalAddAttention( enc_feat_size=hidden_size, dec_feat_size=hidden_size, kernel_size=kernel_size, activation=activation, inv_temperature=inv_temperature, ) for _ in range(n_layers) ] )
[docs] def forward( self, h: Union[Tensor, Tuple[Tensor, Tensor], None], enc_out: Tensor, enc_mask: Tensor, dec_inp: Tensor, *args, **kwargs ) -> Tensor: """Runs the forward pass on the input. Args: h (Union[Tensor, Tuple[Tensor, Tensor], None]): The last hidden state of the encoder if provided, which is of shape [1, B_enc, hidden_size]. enc_out (Tensor): The encoder outputs of shape [B, M, h]. enc_mask (Tensor): The encoder mask of shape [B, M], which is True for the data positions and False for the padding ones. dec_inp (Tensor): The decoder input of shape [B, M_dec]. Returns: A tensor of shape [B, M_dec, C], which is the output of the LocationAwareAttDecoder module. """ batch_size, max_len = dec_inp.shape results = None if h is None: h = self._init_hidden_state(batch_size=batch_size, device=dec_inp.device) alpha = torch.zeros(batch_size, 1, enc_out.shape[1]).to(enc_out.device) out = self.emb(dec_inp[:, 0:1]) h = [h] * len(self.rnn_layers) for i in range(max_len): for j, (fc, rnn, att) in enumerate( zip(self.fc_layers, self.rnn_layers, self.att_layers) ): h_ = h[j] if self.is_lstm: (h_, c_) = h_ h_ = h_.permute(1, 0, 2) out = torch.cat([out, h_], dim=-1) out = fc(out) out, alpha = att(key=enc_out, query=out, alpha=alpha, mask=enc_mask) out, h[j] = rnn(out, h[j]) out = self.pred_net(out) results = out if results is None else torch.cat([results, out], dim=1) y = torch.argmax(out, dim=-1) if self.teacher_forcing_rate > 0: y = self._apply_teacher_forcing(y=dec_inp[:, i : i + 1], preds=y) out = self.emb(y) return results
[docs] def predict(self, state: dict) -> Tuple[Tensor, dict, Tensor]: alpha_key = "alpha" enc_out = state[ENC_OUT_KEY] batch_size, _, hidden_size = enc_out.shape last_pred = state[PREDS_KEY][:, -1:] h = state[HIDDEN_STATE_KEY] alpha = state.get(alpha_key, torch.zeros(batch_size, 1, hidden_size)) alpha = alpha.to(enc_out.device) if isinstance(h, list) is False: h = [h] * len(self.rnn_layers) out = self.emb(last_pred) for i, (fc, rnn, att) in enumerate( zip(self.fc_layers, self.rnn_layers, self.att_layers) ): h_ = h[i] if self.is_lstm: (h_, c_) = h_ h_ = h_.permute(1, 0, 2) out = torch.cat([out, h_], dim=-1) out = fc(out) out, alpha = att(key=enc_out, query=out, alpha=alpha, mask=None) out, h[i] = rnn(out, h[i]) out = self.pred_net(out) state[PREDS_KEY] = torch.cat( [state[PREDS_KEY], torch.argmax(out, dim=-1)], dim=-1 ) state[HIDDEN_STATE_KEY] = h state[alpha_key] = alpha return state
[docs]class TransducerRNNDecoder(nn.Module): """Builds a simple RNN-decoder that contains embedding layer and a single RNN layer Args: vocab_size (int): The vocabulary size. emb_dim (int): The embedding dimension. hidden_size (int): The RNN's hidden size. rnn_type (str): The RNN type it has to be one of rnn, gru or lstm. Default 'rnn'. n_layers (int): The number of RNN layers to use. Default 1. """ def __init__( self, vocab_size: int, emb_dim: int, hidden_size: int, rnn_type: str, n_layers: int = 1, ) -> None: super().__init__() self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim) from .registry import PACKED_RNN_REGISTRY self.layers = nn.ModuleList( [ PACKED_RNN_REGISTRY[rnn_type]( input_size=emb_dim if i == 0 else hidden_size, hidden_size=hidden_size, batch_first=True, enforce_sorted=False, bidirectional=False, ) for i in range(n_layers) ] )
[docs] def forward( self, x: Tensor, mask: Tensor, ) -> Tuple[Tensor, Tensor]: """Runs the input tensor through the RNN transducer decoder. Args: x (Tensor): The input tensor of shape [B, M]. mask (Tensor): The encoder mask of shape [B, M]. It is True for data positions and False for padding ones. Returns: Tuple[Tensor, Tensor]: A tuple containing two tensors. The first tensor is the output tensor of shape [B, M, hidden_size] and the second tensor is the length tensor of shape [B], representing the actual length of each input sequence in the batch. """ lengths = mask.sum(dim=-1).cpu() out = self.emb(x) for rnn in self.layers: out, _, lens = rnn(out, lengths) return out, lens
[docs] def predict(self, state: dict) -> dict: h = state[HIDDEN_STATE_KEY] if not isinstance(h, list): h = [h] * len(self.layers) last_pred = state[PREDS_KEY][:, -1:] lens = torch.ones(last_pred.shape[0], dtype=torch.long) out = self.emb(last_pred) for i, rnn in enumerate(self.layers): out, h[i], _ = rnn(out, lens, h[i]) state[HIDDEN_STATE_KEY] = h state[DECODER_OUT_KEY] = out return state
[docs]class TransformerDecoder(nn.Module): """Implements the transformer decoder as described in https://arxiv.org/abs/1706.03762 Args: n_classes (int): The number of classes the model will predict. n_layers (int): The number of decoder layers. d_model (int): The model dimensionality. ff_size (int): The dimensionality of the feed-forward inner layer. h (int): The number of attention heads. pred_activation (Module): An activation function instance. masking_value (int): The attentin masking value. Default -1e15 """ def __init__( self, n_classes: int, n_layers: int, d_model: int, ff_size: int, h: int, pred_activation: nn.Module, masking_value: int = -1e15, ) -> None: super().__init__() self.emb = PositionalEmbedding(vocab_size=n_classes, embed_dim=d_model) self.layers = nn.ModuleList( [ TransformerDecLayer( d_model=d_model, ff_size=ff_size, h=h, masking_value=masking_value ) for _ in range(n_layers) ] ) self.pred_net = PredModule( in_features=d_model, n_classes=n_classes, activation=pred_activation )
[docs] def forward( self, enc_out: Tensor, enc_mask: Union[Tensor, None], dec_inp: Tensor, dec_mask: Union[Tensor, None], *args, **kwargs ) -> Tensor: """Passes the inputs through the transformer decoder. Args: enc_out (Tensor): The output tensor of the encoder of shape [B, M_enc, d]. enc_mask (Union[Tensor, None]): The encoder mask of shape [B, M_enc], which is True for the data positions and False for the padding ones. dec_inp (Tensor): The input tensor of the decoder of shape [B, M_dec]. dec_mask (Union[Tensor, None]): The decoder mask of shape [B, M_dec], which is True for the data positions and False for the padding ones. Returns: Tensor: The output tensor of the transformer decoder of shape [B, M_dec, C]. """ out = self.emb(dec_inp) for layer in self.layers: out = layer( enc_out=enc_out, enc_mask=enc_mask, dec_inp=out, dec_mask=dec_mask ) out = self.pred_net(out) return out
[docs] def predict(self, state: dict) -> dict: preds = state[PREDS_KEY] out = self.emb(preds) for layer in self.layers: out = layer( enc_out=state[ENC_OUT_KEY], enc_mask=None, dec_inp=out, dec_mask=None ) out = self.pred_net(out[:, -1:, :]) last_pred = torch.argmax(out, dim=-1) state[PREDS_KEY] = torch.cat([state[PREDS_KEY], last_pred], dim=-1) return state
[docs]class SpeechTransformerDecoder(TransformerDecoder): """Implements the speech transformer decoder as described in https://ieeexplore.ieee.org/document/8462506 Args: n_classes (int): The number of classes the model will predict. n_layers (int): The number of decoder layers. d_model (int): The model dimensionality. ff_size (int): The dimensionality of the feed-forward inner layer. h (int): The number of attention heads. pred_activation (Module): An activation function instance. masking_value (int): The attentin masking value. Default -1e15 """ def __init__( self, n_classes: int, n_layers: int, d_model: int, ff_size: int, h: int, pred_activation: nn.Module, masking_value: int = -1e15, ) -> None: super().__init__( n_classes, n_layers, d_model, ff_size, h, pred_activation, masking_value ) self.layers = nn.ModuleList( [ SpeechTransformerDecLayer( d_model=d_model, ff_size=ff_size, h=h, masking_value=masking_value ) for _ in range(n_layers) ] ) self.layer_norm = nn.LayerNorm(normalized_shape=d_model)
[docs] def forward( self, enc_out: Tensor, enc_mask: Union[Tensor, None], dec_inp: Tensor, dec_mask: Union[Tensor, None], *args, **kwargs ) -> Tensor: """Passes the inputs through the speech transformer decoder. Args: enc_out (Tensor): The output tensor of the encoder of shape [B, M_enc, d]. enc_mask (Union[Tensor, None]): The encoder mask of shape [B, M_enc], which is True for the data positions and False for the padding ones. dec_inp (Tensor): The input tensor of the decoder of shape [B, M_dec]. dec_mask (Union[Tensor, None]): The decoder mask of shape [B, M_dec], which is True for the data positions and False for the padding ones. Returns: Tensor: The output tensor of the speech transformer decoder of shape [B, M_dec, C]. """ out = self.emb(dec_inp) for layer in self.layers: out = layer( enc_out=enc_out, enc_mask=enc_mask, dec_inp=out, dec_mask=dec_mask ) out = self.layer_norm(out) out = self.pred_net(out) return out
[docs] def predict(self, state: dict) -> dict: preds = state[PREDS_KEY] out = self.emb(preds) for layer in self.layers: out = layer( enc_out=state[ENC_OUT_KEY], enc_mask=None, dec_inp=out, dec_mask=None ) out = self.layer_norm(out) out = self.pred_net(out[:, -1:, :]) last_pred = torch.argmax(out, dim=-1) state[PREDS_KEY] = torch.cat([state[PREDS_KEY], last_pred], dim=-1) return state
[docs]class TransformerTransducerDecoder(nn.Module): """Implements the Transformer-Transducer decoder with relative truncated multi-head self attention as described in https://arxiv.org/abs/2002.02562 Args: vocab_size (int): The vocabulary size. n_layers (int): The number of transformer encoder layers with truncated self attention and relative positional encoding. d_model (int): The model dimensionality. ff_size (int): The feed forward inner layer dimensionality. h (int): The number of heads in the attention mechanism. left_size (int): The size of the left window that each time step is allowed to look at. right_size (int): The size of the right window that each time step is allowed to look at. p_dropout (float): The dropout rate. masking_value (float, optional): The value to use for masking padded elements. Defaults to -1e15. """ def __init__( self, vocab_size: int, n_layers: int, d_model: int, ff_size: int, h: int, left_size: int, right_size: int, p_dropout: float, masking_value: int = -1e15, ) -> None: super().__init__() self.emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model) self.enc_layers = nn.ModuleList( [ TransformerTransducerLayer( d_model=d_model, ff_size=ff_size, h=h, left_size=left_size, right_size=right_size, p_dropout=p_dropout, masking_value=masking_value, ) for _ in range(n_layers) ] )
[docs] def forward(self, x: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]: """Passes the input `x` through the decoder layers. Args: x (Tensor): The input tensor of shape [B, M] mask (Tensor): The input boolean mask of shape [B, M], where it's True if there is no padding. Returns: Tuple[Tensor, Tensor]: A tuple where the first element is the encoded text of shape [B, M, d_model] and the second element is the lengths of shape [B]. """ lengths = mask.sum(dim=-1) out = self.emb(x) for layer in self.enc_layers: out = layer(out, mask) return out, lengths