# Source code for speeq.models.transducers
"""The transducer module provides implementations for different models used
in speech recognition based on the transducer architecture.

Classes:

- RNNTransducer: An implementation of the RNN transducer model.
- ConformerTransducer: An implementation of the Conformer transducer model.
- ContextNet: An implementation of the ContextNet transducer model.
- VGGTransformerTransducer: An implementation of the VGGTransformer transducer model with truncated self attention.
- TransformerTransducer: An implementation of the Transformer transducer model.
"""
from typing import List, Tuple, Union
import torch
from torch import Tensor, nn
from speeq.constants import (
DECODER_OUT_KEY,
ENC_OUT_KEY,
HIDDEN_STATE_KEY,
PREDS_KEY,
PREV_HIDDEN_STATE_KEY,
SPEECH_IDX_KEY,
)
from .decoders import TransducerRNNDecoder, TransformerTransducerDecoder
from .encoders import (
ConformerEncoder,
ContextNetEncoder,
RNNEncoder,
TransformerTransducerEncoder,
VGGTransformerEncoder,
)
class _BaseTransducer(nn.Module):
    """Shared training and greedy-decoding logic for the transducer models.

    This base class only creates the joint projection ``join_net``. The
    methods below use ``self.encoder`` and ``self.decoder``, which are not
    created here and therefore have to be set by the subclass.

    Args:
        feat_size (int): The input feature size of the joint network.
        n_classes (int): The number of classes/vocabulary.
    """

    def __init__(self, feat_size: int, n_classes: int) -> None:
        super().__init__()
        self.join_net = nn.Linear(in_features=feat_size, out_features=n_classes)

    def forward(
        self,
        speech: Tensor,
        speech_mask: Tensor,
        text: Tensor,
        text_mask: Tensor,
        *args,
        **kwargs
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Passes the input to the model

        Args:
            speech (Tensor): The input speech of shape [B, M, d]
            speech_mask (Tensor): The speech mask of shape [B, M]
            text (Tensor): The text input of shape [B, N]
            text_mask (Tensor): The text mask of shape [B, N]

        Returns:
            Tuple[Tensor, Tensor, Tensor]: A tuple of 3 tensors where the first
            is the predictions of shape [B, M, N, C], the last two tensor are
            the speech and text length of shape [B]
        """
        speech, speech_len = self.encoder(speech, speech_mask)
        text, text_len = self.decoder(text, text_mask)
        result = self._join(encoder_out=speech, deocder_out=text)
        # The lengths are moved to the device of the encoder output so all
        # three returned tensors live on the same device.
        speech_len, text_len = (
            speech_len.to(speech.device),
            text_len.to(speech.device),
        )
        return result, speech_len, text_len

    def _join(self, encoder_out: Tensor, deocder_out: Tensor, training=True) -> Tensor:
        """Combines the encoder and decoder outputs and projects the sum.

        Args:
            encoder_out (Tensor): The encoder output.
            deocder_out (Tensor): The decoder output. The spelling of the
                parameter name is kept because it is passed by keyword.
            training (bool): If True, a singleton axis is inserted before the
                last axis of `encoder_out` and at axis 1 of `deocder_out`, so
                the addition broadcasts over every (speech, text) position
                pair. If False, the inputs are added as they are.

        Returns:
            Tensor: The output of the joint network.
        """
        if training:
            encoder_out = encoder_out.unsqueeze(-2)
            deocder_out = deocder_out.unsqueeze(1)
        result = encoder_out + deocder_out
        result = self.join_net(result)
        return result

    def predict(self, x: Tensor, mask: Tensor, state: dict) -> dict:
        """Performs a single greedy decoding step.

        On the first call, i.e. when `state` has no encoder output yet, the
        speech is encoded once and cached in `state`, the speech index is set
        to 0 and the hidden state to None. This method never advances the
        speech index itself.

        Args:
            x (Tensor): The input speech, only used when the encoder output
                is not cached in `state` yet.
            mask (Tensor): The speech mask, used together with `x`.
            state (dict): The decoding state. The running predictions must be
                available under PREDS_KEY by the time the decoder step
                returns, as the new prediction is concatenated to them.

        Returns:
            dict: The state returned by `self.decoder.predict`, with the new
            prediction appended to the predictions and the hidden state
            that was stored before the decoder step saved under
            PREV_HIDDEN_STATE_KEY.
        """
        if ENC_OUT_KEY not in state:
            state[ENC_OUT_KEY], _ = self.encoder(x, mask)
            state[SPEECH_IDX_KEY] = 0
            state[HIDDEN_STATE_KEY] = None
        # Captured before the decoder step so it can be handed back below.
        last_hidden_state = state[HIDDEN_STATE_KEY]
        # NOTE(review): the decoder step is expected to provide
        # DECODER_OUT_KEY; confirm against the decoder implementation.
        state = self.decoder.predict(state)
        speech_idx = state[SPEECH_IDX_KEY]
        # The slice keeps the time axis so the frame adds to the decoder out.
        enc_frame = state[ENC_OUT_KEY][:, speech_idx : speech_idx + 1, :]
        logits = self.join_net(state[DECODER_OUT_KEY] + enc_frame)
        # log-softmax is monotonic, hence the argmax is taken over the logits
        # directly and the normalisation is skipped.
        out = torch.argmax(logits, dim=-1)
        state[PREDS_KEY] = torch.cat([state[PREDS_KEY], out], dim=-1)
        state[PREV_HIDDEN_STATE_KEY] = last_hidden_state
        return state
class RNNTransducer(_BaseTransducer):
    """The RNN transducer model introduced in
    https://arxiv.org/abs/1211.3711

    The model is composed of an RNN encoder, an RNN decoder (predictor) and
    the joint network created by the base class.

    Args:
        in_features (int): The size of the input features.
        n_classes (int): The number of classes/vocabulary.
        emb_dim (int): The size of the decoder's embedding layer.
        n_layers (int): The number of RNN layers in the encoder.
        n_dec_layers (int): The number of RNN layers in the decoder (predictor).
        hidden_size (int): The hidden size of the RNN layers, it is used as
            the joint network input size as well.
        bidirectional (bool): Whether the encoder's RNN is bidirectional or not.
        rnn_type (str): The RNN type, the same type is used for the encoder
            and the decoder.
        p_dropout (float): The dropout rate, it is passed to the encoder.
    """

    def __init__(
        self,
        in_features: int,
        n_classes: int,
        emb_dim: int,
        n_layers: int,
        n_dec_layers: int,
        hidden_size: int,
        bidirectional: bool,
        rnn_type: str,
        p_dropout: float,
    ) -> None:
        super().__init__(n_classes=n_classes, feat_size=hidden_size)
        # The encoder is built before the decoder, keep this order so the
        # modules are registered and initialised in the same sequence.
        encoder = RNNEncoder(
            in_features=in_features,
            hidden_size=hidden_size,
            n_layers=n_layers,
            rnn_type=rnn_type,
            bidirectional=bidirectional,
            p_dropout=p_dropout,
        )
        self.encoder = encoder
        decoder = TransducerRNNDecoder(
            vocab_size=n_classes,
            emb_dim=emb_dim,
            hidden_size=hidden_size,
            n_layers=n_dec_layers,
            rnn_type=rnn_type,
        )
        self.decoder = decoder
class ConformerTransducer(RNNTransducer):
    """The conformer transducer model introduced in
    https://arxiv.org/abs/2005.08100

    The parent constructor is called with a single-layer, unidirectional RNN
    encoder configuration whose hidden size is `d_model`; the encoder it
    creates is then replaced by a conformer encoder, while the decoder and
    the joint network created by the parent are kept.

    Args:
        d_model (int): The model dimension.
        n_conf_layers (int): The number of conformer blocks.
        n_dec_layers (int): The number of RNNs in the decoder (predictor).
        ff_expansion_factor (int): The feed-forward expansion factor.
        h (int): The number of attention heads.
        kernel_size (int): The convolution module kernel size.
        ss_kernel_size (int): The subsampling layer kernel size.
        ss_stride (int): The subsampling layer stride size.
        ss_num_conv_layers (int): The number of subsampling convolutional layers.
        in_features (int): The input/speech feature size.
        res_scaling (float): The residual connection multiplier.
        n_classes (int): The number of classes/vocabulary.
        emb_dim (int): The embedding layer's size.
        rnn_type (str): The RNN type it has to be one of rnn, gru or lstm.
        p_dropout (float): The dropout rate.
    """

    def __init__(
        self,
        d_model: int,
        n_conf_layers: int,
        n_dec_layers: int,
        ff_expansion_factor: int,
        h: int,
        kernel_size: int,
        ss_kernel_size: int,
        ss_stride: int,
        ss_num_conv_layers: int,
        in_features: int,
        res_scaling: float,
        n_classes: int,
        emb_dim: int,
        rnn_type: str,
        p_dropout: float,
    ) -> None:
        # Keyword form of the parent call; the values are the same as the
        # positional ones the parent signature receives.
        super().__init__(
            in_features=in_features,
            n_classes=n_classes,
            emb_dim=emb_dim,
            n_layers=1,
            n_dec_layers=n_dec_layers,
            hidden_size=d_model,
            bidirectional=False,
            rnn_type=rnn_type,
            p_dropout=p_dropout,
        )
        # Swap the encoder built by the parent for the conformer encoder.
        conformer = ConformerEncoder(
            in_features=in_features,
            d_model=d_model,
            n_conf_layers=n_conf_layers,
            h=h,
            ff_expansion_factor=ff_expansion_factor,
            kernel_size=kernel_size,
            ss_kernel_size=ss_kernel_size,
            ss_stride=ss_stride,
            ss_num_conv_layers=ss_num_conv_layers,
            res_scaling=res_scaling,
            p_dropout=p_dropout,
        )
        self.encoder = conformer
class ContextNet(_BaseTransducer):
    """The ContextNet transducer model introduced in
    https://arxiv.org/abs/2005.03191

    The joint network input size and the decoder hidden size are both set to
    the last entry of `out_channels` when it is a list, and to `out_channels`
    itself otherwise.

    Args:
        in_features (int): The input feature size.
        n_classes (int): The number of classes/vocabulary.
        emb_dim (int): The embedding layer's size.
        n_layers (int): The number of ContextNet blocks.
        n_dec_layers (int): The number of RNNs in the decoder (predictor).
        n_sub_layers (Union[int, List[int]]): The number of convolutional
            layers per block. If list is passed, it has to be of length equal
            to `n_layers`.
        stride (Union[int, List[int]]): The stride of the last convolutional
            layers per block. If list is passed, it has to be of length equal
            to `n_layers`.
        out_channels (Union[int, List[int]]): The channels size of the
            convolutional layers per block. If list is passed, it has to be of
            length equal to `n_layers`.
        kernel_size (int): The convolutional layers kernel size.
        reduction_factor (int): The feature reduction size of the
            Squeeze-and-excitation module.
        rnn_type (str): The RNN type it has to be one of rnn, gru or lstm.
    """

    def __init__(
        self,
        in_features: int,
        n_classes: int,
        emb_dim: int,
        n_layers: int,
        n_dec_layers: int,
        n_sub_layers: Union[int, List[int]],
        stride: Union[int, List[int]],
        out_channels: Union[int, List[int]],
        kernel_size: int,
        reduction_factor: int,
        rnn_type: str,
    ) -> None:
        # Resolved once and shared by the joint network and the decoder.
        if isinstance(out_channels, list):
            last_channels = out_channels[-1]
        else:
            last_channels = out_channels
        super().__init__(last_channels, n_classes)
        self.encoder = ContextNetEncoder(
            in_features=in_features,
            n_layers=n_layers,
            n_sub_layers=n_sub_layers,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            reduction_factor=reduction_factor,
        )
        self.decoder = TransducerRNNDecoder(
            vocab_size=n_classes,
            emb_dim=emb_dim,
            hidden_size=last_channels,
            n_layers=n_dec_layers,
            rnn_type=rnn_type,
        )
class VGGTransformerTransducer(RNNTransducer):
    """Implements the Transformer-Transducer model as described in
    https://arxiv.org/abs/1910.12977

    The parent constructor is called with a single-layer, unidirectional RNN
    encoder configuration whose hidden size is `d_model`; afterwards both the
    encoder and the joint network created by the parent are replaced, the
    former by a VGG-transformer encoder and the latter by a two-layer
    feed-forward network with a ReLU in between. The decoder created by the
    parent is kept.

    Args:
        in_features (int): The input feature size.
        n_classes (int): The number of classes/vocabulary.
        emb_dim (int): The embedding layer's size.
        n_layers (int): The number of transformer encoder layers with truncated
            self attention.
        n_dec_layers (int): The number of RNNs in the decoder (predictor).
        rnn_type (str): The RNN type.
        n_vgg_blocks (int): The number of VGG blocks to use.
        n_conv_layers_per_vgg_block (List[int]): A list of integers that
            specifies the number of convolution layers in each block.
        kernel_sizes_per_vgg_block (List[List[int]]): A list of lists that
            contains the kernel size for each layer in each block. The length
            of the outer list should match `n_vgg_blocks`, and each inner list
            should be the same length as the corresponding block's number of
            layers.
        n_channels_per_vgg_block (List[List[int]]): A list of lists that
            contains the number of channels for each convolution layer in each
            block. This argument should also have length equal to
            `n_vgg_blocks`, and each sublist should have length equal to the
            number of layers in the corresponding block.
        vgg_pooling_kernel_size (List[int]): A list of integers that specifies
            the size of the max pooling layer in each block. The length of
            this list should be equal to `n_vgg_blocks`.
        d_model (int): The model dimensionality.
        ff_size (int): The feed forward inner layer dimensionality.
        h (int): The number of heads in the attention mechanism.
        joint_size (int): The joint layer feature size (denoted as do in the
            paper).
        left_size (int): The size of the left window that each time step is
            allowed to look at.
        right_size (int): The size of the right window that each time step is
            allowed to look at.
        p_dropout (float): The dropout rate. It is only passed to the parent
            constructor, the VGG-transformer encoder built here does not
            receive it.
        masking_value (float, optional): The value to use for masking padded
            elements. Defaults to -1e15.
    """

    def __init__(
        self,
        in_features: int,
        n_classes: int,
        emb_dim: int,
        n_layers: int,
        n_dec_layers: int,
        rnn_type: str,
        n_vgg_blocks: int,
        n_conv_layers_per_vgg_block: List[int],
        kernel_sizes_per_vgg_block: List[List[int]],
        n_channels_per_vgg_block: List[List[int]],
        vgg_pooling_kernel_size: List[int],
        d_model: int,
        ff_size: int,
        h: int,
        joint_size: int,
        left_size: int,
        right_size: int,
        p_dropout: float,
        masking_value: float = -1e15,
    ) -> None:
        super().__init__(
            in_features=in_features,
            n_classes=n_classes,
            emb_dim=emb_dim,
            n_layers=1,
            n_dec_layers=n_dec_layers,
            hidden_size=d_model,
            bidirectional=False,
            rnn_type=rnn_type,
            p_dropout=p_dropout,
        )
        # Replace the encoder built by the parent constructor.
        self.encoder = VGGTransformerEncoder(
            in_features=in_features,
            n_layers=n_layers,
            n_vgg_blocks=n_vgg_blocks,
            n_conv_layers_per_vgg_block=n_conv_layers_per_vgg_block,
            kernel_sizes_per_vgg_block=kernel_sizes_per_vgg_block,
            n_channels_per_vgg_block=n_channels_per_vgg_block,
            vgg_pooling_kernel_size=vgg_pooling_kernel_size,
            d_model=d_model,
            ff_size=ff_size,
            h=h,
            left_size=left_size,
            right_size=right_size,
            masking_value=masking_value,
        )
        # Replace the single linear joint layer of the base class by a
        # projection to `joint_size` followed by ReLU and the output layer.
        self.join_net = nn.Sequential(
            nn.Linear(in_features=d_model, out_features=joint_size),
            nn.ReLU(),
            nn.Linear(in_features=joint_size, out_features=n_classes),
        )
class TransformerTransducer(nn.Module):
    """Implements the Transformer-Transducer model as described in
    https://arxiv.org/abs/2002.02562

    The encoder and decoder outputs are each projected to `joint_size`,
    added together, passed through a tanh and finally projected to
    `n_classes`.

    Args:
        in_features (int): The input feature size.
        n_classes (int): The number of classes/vocabulary.
        n_layers (int): The number of transformer encoder layers with truncated
            self attention.
        n_dec_layers (int): The number of layers in the decoder (predictor).
        d_model (int): The model dimensionality.
        ff_size (int): The feed forward inner layer dimensionality.
        h (int): The number of heads in the attention mechanism.
        joint_size (int): The joint layer feature size.
        enc_left_size (int): The size of the left window that each time step is
            allowed to look at in the encoder.
        enc_right_size (int): The size of the right window that each time step
            is allowed to look at in the encoder.
        dec_left_size (int): The size of the left window that each time step is
            allowed to look at in the decoder.
        dec_right_size (int): The size of the right window that each time step
            is allowed to look at in the decoder.
        p_dropout (float): The dropout rate.
        stride (int): The stride of the convolution layer in the prenet.
            Default 1.
        kernel_size (int): The kernel size of the convolution layer in the
            prenet. Default 1.
        masking_value (float, optional): The value to use for masking padded
            elements. Defaults to -1e15.
    """

    def __init__(
        self,
        in_features: int,
        n_classes: int,
        n_layers: int,
        n_dec_layers: int,
        d_model: int,
        ff_size: int,
        h: int,
        joint_size: int,
        enc_left_size: int,
        enc_right_size: int,
        dec_left_size: int,
        dec_right_size: int,
        p_dropout: float,
        stride: int = 1,
        kernel_size: int = 1,
        masking_value: float = -1e15,
    ) -> None:
        super().__init__()
        self.encoder = TransformerTransducerEncoder(
            in_features=in_features,
            n_layers=n_layers,
            d_model=d_model,
            ff_size=ff_size,
            h=h,
            left_size=enc_left_size,
            right_size=enc_right_size,
            p_dropout=p_dropout,
            stride=stride,
            kernel_size=kernel_size,
            masking_value=masking_value,
        )
        self.decoder = TransformerTransducerDecoder(
            vocab_size=n_classes,
            n_layers=n_dec_layers,
            d_model=d_model,
            ff_size=ff_size,
            h=h,
            left_size=dec_left_size,
            right_size=dec_right_size,
            p_dropout=p_dropout,
            masking_value=masking_value,
        )
        # Separate projections bring both streams to the joint feature size.
        self.audio_fc = nn.Linear(in_features=d_model, out_features=joint_size)
        self.text_fc = nn.Linear(in_features=d_model, out_features=joint_size)
        self.tanh = nn.Tanh()
        self.join_net = nn.Linear(in_features=joint_size, out_features=n_classes)

    def _join(self, encoder_out: Tensor, deocder_out: Tensor, training=True) -> Tensor:
        """Combines the projected encoder and decoder outputs.

        Args:
            encoder_out (Tensor): The projected encoder output.
            deocder_out (Tensor): The projected decoder output. The spelling
                of the parameter name is kept because it is passed by keyword.
            training (bool): If True, a singleton axis is inserted before the
                last axis of `encoder_out` and at axis 1 of `deocder_out`, so
                the addition broadcasts over every (speech, text) position
                pair. If False, the inputs are added as they are.

        Returns:
            Tensor: The joint network output computed on the tanh of the sum.
        """
        if training:
            encoder_out = encoder_out.unsqueeze(-2)
            deocder_out = deocder_out.unsqueeze(1)
        result = encoder_out + deocder_out
        result = self.tanh(result)
        result = self.join_net(result)
        return result

    def forward(
        self,
        speech: Tensor,
        speech_mask: Tensor,
        text: Tensor,
        text_mask: Tensor,
        *args,
        **kwargs
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Passes the input to the model

        Args:
            speech (Tensor): The input speech of shape [B, M, d]
            speech_mask (Tensor): The speech mask of shape [B, M]
            text (Tensor): The text input of shape [B, N]
            text_mask (Tensor): The text mask of shape [B, N]

        Returns:
            Tuple[Tensor, Tensor, Tensor]: A tuple of 3 tensors where the first
            is the predictions of shape [B, M, N, C], the last two tensor are
            the speech and text length of shape [B]
        """
        speech, speech_len = self.encoder(speech, speech_mask)
        text, text_len = self.decoder(text, text_mask)
        # Project both streams to the joint size before combining them.
        speech = self.audio_fc(speech)
        text = self.text_fc(text)
        result = self._join(encoder_out=speech, deocder_out=text)
        # The lengths are moved to the device of the projected speech so all
        # three returned tensors live on the same device.
        speech_len, text_len = (
            speech_len.to(speech.device),
            text_len.to(speech.device),
        )
        return result, speech_len, text_len