Source code for speeq.trainers.criterions

"""
Contains different loss functions used in various speech recognition models.

- CTCLoss: Connectionist Temporal Classification loss function.
- CrossEntropyLoss: Cross-entropy loss function.
- NLLLoss: Negative log-likelihood loss function.
- RNNTLoss: Recurrent Neural Network Transducer loss function.

"""
from typing import Tuple

from torch import Tensor, nn
from torchaudio import transforms


class CTCLoss(nn.CTCLoss):
    """Connectionist Temporal Classification (CTC) loss.

    Thin wrapper around :class:`torch.nn.CTCLoss` that exposes the blank
    token index under the name ``blank_id``.

    Args:
        blank_id (int): Index of the blank token.
        reduction (str, optional): Reduction applied to the output.
            Defaults to "mean".
        zero_infinity (bool, optional): Whether to zero infinite losses and
            the associated gradients. Defaults to False. Infinite losses
            mainly occur when the inputs are too short to be aligned to the
            targets.
        *args, **kwargs: Accepted so the class can be built from a generic
            config; they are ignored.
    """

    def __init__(
        self, blank_id: int, reduction="mean", zero_infinity=False, *args, **kwargs
    ):
        options = dict(reduction=reduction, zero_infinity=zero_infinity)
        super().__init__(blank=blank_id, **options)
def remove_positionals(input: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
    """Remove the SOS token from the target and the EOS prediction from the input.

    The final time step of the predictions and the first token of the
    targets are dropped, leaving both sequences one step shorter.

    Args:
        input (Tensor): The predictions of shape [B, M, C].
        target (Tensor): The target tensor of shape [B, M].

    Returns:
        Tuple[Tensor, Tensor]: Contiguous copies/views of the input, of shape
        [B, M - 1, C], and of the target, of shape [B, M - 1].
    """
    shifted_input = input[:, :-1, :].contiguous()
    shifted_target = target[:, 1:].contiguous()
    return shifted_input, shifted_target
def get_flatten_results(input: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
    """Flatten the batch and time dimensions of predictions and targets.

    The input of shape [B, M, C] becomes [B * M, C] and the target of shape
    [B, M] becomes [B * M], the layout consumed by the class-index losses
    (``nn.CrossEntropyLoss`` / ``nn.NLLLoss``) in this module.

    Args:
        input (Tensor): The predictions of shape [B, M, C].
        target (Tensor): The target tensor of shape [B, M].

    Returns:
        Tuple[Tensor, Tensor]: A tuple of the flattened input and target.
    """
    # ``reshape`` (unlike ``view``) also accepts non-contiguous tensors; it
    # still returns a view without copying whenever the memory layout allows.
    target = target.reshape(-1)
    input = input.reshape(-1, input.shape[-1])
    return input, target
class CrossEntropyLoss(nn.CrossEntropyLoss):
    """Compute the cross entropy loss between input logits and target.

    Before delegating to :class:`torch.nn.CrossEntropyLoss`, the predictions
    and targets are shifted by ``remove_positionals`` (last prediction step
    and first target token dropped) and flattened by ``get_flatten_results``.

    Args:
        pad_id (int): The padding id, forwarded as ``ignore_index``.
        reduction (str, optional): Reduction applied to the output.
            Defaults to "mean".
        label_smoothing (float, optional): A float in [0.0, 1.0] giving the
            amount of smoothing when computing the loss. Defaults to 0.0.
        *args, **kwargs: Accepted so the class can be built from a generic
            config; they are ignored.
    """

    def __init__(
        self, pad_id: int, reduction="mean", label_smoothing=0.0, *args, **kwargs
    ) -> None:
        options = {
            "ignore_index": pad_id,
            "reduction": reduction,
            "label_smoothing": label_smoothing,
        }
        super().__init__(**options)

    def forward(self, input, target, *args, **kwargs):
        """Compute the loss.

        Args:
            input (Tensor): Logits of shape [B, M, C].
            target (Tensor): Target ids of shape [B, M].
            *args, **kwargs: Ignored.

        Returns:
            Tensor: The loss, reduced according to ``reduction``.
        """
        aligned_input, aligned_target = remove_positionals(input, target)
        flat_input, flat_target = get_flatten_results(aligned_input, aligned_target)
        return super().forward(flat_input, flat_target)
class NLLLoss(nn.NLLLoss):
    """Compute the negative log likelihood loss.

    Before delegating to :class:`torch.nn.NLLLoss`, the predictions and
    targets are shifted by ``remove_positionals`` (last prediction step and
    first target token dropped) and flattened by ``get_flatten_results``.

    Args:
        pad_id (int): The padding id, forwarded as ``ignore_index``.
        reduction (str, optional): Reduction applied to the output.
            Defaults to "mean".
        *args, **kwargs: Accepted so the class can be built from a generic
            config; they are ignored.
    """

    def __init__(self, pad_id: int, reduction="mean", *args, **kwargs) -> None:
        super().__init__(reduction=reduction, ignore_index=pad_id)

    def forward(self, input, target, *args, **kwargs):
        """Compute the loss.

        Args:
            input (Tensor): Predictions of shape [B, M, C].
            target (Tensor): Target ids of shape [B, M].
            *args, **kwargs: Ignored.

        Returns:
            Tensor: The loss, reduced according to ``reduction``.
        """
        shifted = remove_positionals(input, target)
        flat_input, flat_target = get_flatten_results(*shifted)
        return super().forward(flat_input, flat_target)
class RNNTLoss(transforms.RNNTLoss):
    """Compute the Recurrent Neural Network Transducer (RNN-T) loss.

    Wraps :class:`torchaudio.transforms.RNNTLoss`, exposing the blank token
    as ``blank_id`` and stripping the leading SOS token from the targets
    before the loss is computed.

    Args:
        blank_id (int): The blank id.
        reduction (str, optional): Reduction applied to the output.
            Defaults to "mean".
        *args, **kwargs: Accepted so the class can be built from a generic
            config; they are ignored.
    """

    def __init__(self, blank_id: int, reduction="mean", *args, **kwargs) -> None:
        super().__init__(reduction=reduction, blank=blank_id)

    def forward(
        self, logits: Tensor, logits_len: Tensor, targets: Tensor, target_len: Tensor
    ) -> Tensor:
        """Compute the loss.

        Args:
            logits (Tensor): Logits of shape [B, Ts, Tt, C].
            logits_len (Tensor): Length of each sequence in ``logits``.
            targets (Tensor): Target ids of shape [B, Tt]; every row starts
                with the SOS token.
            target_len (Tensor): Target lengths including the SOS token; one
                is subtracted here to account for its removal.

        Returns:
            Tensor: The loss, reduced according to ``reduction``.
        """
        # Drop the SOS column and shorten every length by one to match.
        sos_free_targets = targets[:, 1:].contiguous()
        sos_free_len = target_len - 1
        return super().forward(
            logits=logits,
            logit_lengths=logits_len,
            targets=sos_free_targets,
            target_lengths=sos_free_len,
        )