"""This module provides various scheduler classes for adjusting learning rates during
training, including the base `Scheduler` class and its implementation `NoamScheduler`.
The `SqueezeformerNoamScheduler` is a modified version of the `NoamScheduler` specific
to the Squeezeformer model.

Classes:

- Scheduler: Implements the base scheduler class.
- NoamScheduler: Implements the Noam scheduler.
- SqueezeformerNoamScheduler: Implements the Noam scheduler with modifications for the
  Squeezeformer model.
"""
from math import sqrt
from numbers import Number
from typing import Iterable
from speeq.constants import OPTIMIZER_STATE_KEY
from speeq.interfaces import IScheduler
class Scheduler(IScheduler):
    """Implements the base scheduler class.

    The scheduler owns an optimizer, looked up by name in the ``OPTIMIZERS``
    registry, and overwrites the learning rate of every parameter group after
    each optimizer step.

    This base class defines neither ``counter`` nor ``get_lr``; ``_update_lr``
    uses both, so subclasses are expected to provide them.

    Args:
        params (Iterable): The model's parameters.
        optimizer (str): The name of the optimizer to be used.
        optimizer_args (dict): The optimizer's arguments.
    """

    def __init__(self, params: Iterable, optimizer: str, optimizer_args: dict) -> None:
        super().__init__()
        # Kept as a function-scope import, exactly as in the original code
        # (presumably to avoid a circular import with the registry -- confirm).
        from .registry import OPTIMIZERS

        self.optimizer = OPTIMIZERS[optimizer](params, **optimizer_args)

    def state_dict(self):
        """Return the state dict of the wrapped optimizer.

        NOTE(review): this returns the raw optimizer state, whereas
        ``load_state_dict`` reads the optimizer state from under
        ``OPTIMIZER_STATE_KEY``. ``NoamScheduler`` overrides this method to
        produce that layout; confirm whether the base class is ever
        round-tripped directly.
        """
        return self.optimizer.state_dict()

    def zero_grad(self) -> None:
        """Reset the gradients through the wrapped optimizer."""
        self.optimizer.zero_grad()

    def _update_lr(self) -> None:
        """Advance the step counter and apply the new learning rate.

        The counter is incremented *before* ``get_lr`` is called, so
        ``get_lr`` sees the already-advanced step number.
        """
        self.counter += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def step(self) -> None:
        """Run one optimizer step, then update the learning rate."""
        self.optimizer.step()
        self._update_lr()

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore the optimizer and scheduler state.

        The entry stored under ``OPTIMIZER_STATE_KEY`` is handed to the
        optimizer; every other entry is written onto this instance as an
        attribute.

        Args:
            state_dict (dict): The state to restore. It is not modified, so
                the same dictionary can be loaded again or re-saved.

        Raises:
            KeyError: If ``OPTIMIZER_STATE_KEY`` is missing from
                ``state_dict``.
        """
        self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_KEY])
        # Filter into a new dict instead of popping the key, so the caller's
        # dictionary (e.g. a loaded checkpoint) keeps its optimizer entry.
        scheduler_state = {
            key: value
            for key, value in state_dict.items()
            if key != OPTIMIZER_STATE_KEY
        }
        self.__dict__.update(scheduler_state)
class NoamScheduler(Scheduler):
    """Implements the noam scheduler proposed in
    https://arxiv.org/abs/1706.03762

    The learning rate at step ``counter`` is::

        (1 / sqrt(d_model)) * min(1 / sqrt(counter), counter * warmup_staps ** -1.5)

    Args:
        params (Iterable): The model's parameters.
        optimizer (str): The name of the optimizer.
        optimizer_args (dict): The optimizer's arguments.
        warmup_staps (int): The warmup steps. Must be positive.
        d_model (int): The model dimension. Must be positive.

    Raises:
        ValueError: If ``warmup_staps`` or ``d_model`` is not positive.
    """

    def __init__(
        self,
        params,
        optimizer: str,
        optimizer_args: dict,
        warmup_staps: int,
        d_model: int,
        *args,
        **kwargs
    ) -> None:
        # Fail fast with a clear message. Without these checks the same
        # inputs still fail further down, but opaquely: pow(0, -1.5) and
        # 1 / sqrt(0) raise ZeroDivisionError, and a negative warmup value
        # turns the pow() result into a complex number that min() rejects.
        if warmup_staps <= 0:
            raise ValueError(
                "warmup_staps must be positive, got {}".format(warmup_staps)
            )
        if d_model <= 0:
            raise ValueError("d_model must be positive, got {}".format(d_model))
        super().__init__(
            params=params, optimizer=optimizer, optimizer_args=optimizer_args
        )
        self.peak = 1 / sqrt(d_model)
        self.counter = 0
        self.warmup_staps = warmup_staps
        # Moves the counter to 1 and sets the initial learning rate on the
        # optimizer's parameter groups.
        self._update_lr()

    def get_lr(self) -> float:
        """Return the learning rate for the current value of ``counter``.

        ``counter`` must be at least 1; ``_update_lr`` increments it before
        calling this method.
        """
        return self.peak * min(
            1 / sqrt(self.counter), self.counter * pow(self.warmup_staps, -1.5)
        )

    def state_dict(self) -> dict:
        """Return the scheduler state together with the optimizer state.

        The optimizer state is stored under ``OPTIMIZER_STATE_KEY``.
        """
        return {
            "peak": self.peak,
            "warmup_staps": self.warmup_staps,
            "counter": self.counter,
            OPTIMIZER_STATE_KEY: self.optimizer.state_dict(),
        }