# Source code for speeq.data.padders

"""The `padders` module provides two classes for padding input sequences:
`DynamicPadder` and `StaticPadder`.

`DynamicPadder` pads an input sequence along a specified dimension up to a
maximum length supplied on each call to `pad`, while `StaticPadder` is a
subclass of `DynamicPadder` that fixes the maximum length once, at
construction time.

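Both classes have a `pad` method that returns the padded tensor together with
the length of the padding that was added. `DynamicPadder.pad` takes the input
tensor and the maximum length to pad to; `StaticPadder.pad` takes only the
input tensor and pads to the `max_len` given to its constructor.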


Usage:


.. code-block:: python

    import torch
    from speeq.data.padders import DynamicPadder, StaticPadder

    # create a dummy input
    input_tensor = torch.randn(1, 3, 7)

    # Example usage of DynamicPadder
    dynamic_padder = DynamicPadder(dim=1, pad_val=0)
    padded_tensor, padding_length = dynamic_padder.pad(input_tensor, max_len=10)

    # Example usage of StaticPadder
    static_padder = StaticPadder(dim=1, pad_val=0, max_len=10)
    padded_tensor, padding_length = static_padder.pad(input_tensor)

"""
from typing import Tuple, Union

import torch
from torch import Tensor

from speeq.interfaces import IPadder


class DynamicPadder(IPadder):
    """Pads the input sequence across a dim up to a given maximum length.

    Args:
        dim (int): The dimension to do the padding across. Negative values
            index from the end, as with ``Tensor.shape``.
        pad_val (Union[int, Tensor, float]): The value used to fill the
            padding. It is added to a zero tensor shaped like the padding, so
            a Tensor value must be broadcastable to that shape.
        left_pad (bool): If True, the padding is placed before the sequence
            along ``dim``; otherwise it is placed after it. Default False.
    """

    def __init__(
        self,
        dim: int,
        pad_val: Union[int, Tensor, float],
        left_pad: bool = False,
        *args,
        **kwargs
    ) -> None:
        # Extra positional/keyword arguments are accepted but not used.
        super().__init__()
        self.dim = dim
        self.left_pad = left_pad
        self.pad_val = pad_val

    def pad(self, x: Tensor, max_len: int) -> Tuple[Tensor, int]:
        """Pads the input tensor to ``max_len`` along the pre-defined dimension.

        Args:
            x (Tensor): The input tensor to be padded.
            max_len (int): The length to pad the input tensor to along
                ``self.dim``.

        Returns:
            Tuple[Tensor, int]: The padded tensor and the length of the
            padding added. When no padding is needed, ``x`` itself (not a
            copy) is returned with a padding length of 0.

        Raises:
            ValueError: If ``max_len`` is smaller than ``x.shape[self.dim]``.
        """
        seq_len = x.shape[self.dim]
        pad_len = max_len - seq_len
        if pad_len < 0:
            # Explicit check instead of `assert`, which is stripped under -O.
            raise ValueError(
                f"Cannot pad a sequence of length {seq_len} along dim "
                f"{self.dim} to max_len={max_len}: max_len is smaller than "
                "the sequence length."
            )
        if pad_len == 0:
            return x, pad_len
        # Copy the input shape and replace only the padded dimension. List
        # indexing handles negative dims (e.g. dim=-1) correctly, whereas
        # slicing with ``1 + dim`` yields the whole shape when dim == -1.
        pad_shape = list(x.shape)
        pad_shape[self.dim] = pad_len
        # Allocate directly on the input's device instead of copying from CPU.
        pad = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
        # NOTE: the fill is done by addition, so PyTorch type promotion
        # applies (e.g. a float pad_val on an integer input gives a float
        # padding block).
        pad = pad + self.pad_val
        if self.left_pad:
            x = torch.cat([pad, x], dim=self.dim)
        else:
            x = torch.cat([x, pad], dim=self.dim)
        return x, pad_len
class StaticPadder(DynamicPadder):
    """Pads input sequences along one dimension to a length fixed at construction.

    This specializes `DynamicPadder`: the target length is stored once as
    ``max_len`` and reused by every call to `pad`.

    Args:
        dim (int): The dimension to pad across.
        pad_val (Union[int, Tensor, float]): The value used to fill the
            padding.
        max_len (int): The length every padded sequence is brought up to.
        left_pad (bool): If True, the padding goes before the sequence along
            ``dim``; otherwise it goes after it. Default False.
    """

    def __init__(
        self,
        dim: int,
        pad_val: Union[int, Tensor, float],
        max_len: int,
        left_pad=False,
        *args,
        **kwargs
    ) -> None:
        # Only dim, pad_val and left_pad are forwarded to the parent; any
        # extra positional/keyword arguments are accepted and discarded.
        super().__init__(dim, pad_val, left_pad)
        self.max_len = max_len

    def pad(self, x: Tensor, *args, **kwargs):
        """Pads ``x`` to the stored ``max_len`` along the configured dimension.

        Any additional positional or keyword arguments are ignored; the
        length given at construction is always used as the target.

        Args:
            x (Tensor): The input tensor to be padded.

        Returns:
            Tuple[Tensor, int]: The padded tensor and the length of the
            padding added.
        """
        target_len = self.max_len
        return super().pad(x, target_len)