# Source code for speeq.trainers.trainers

"""This module contains different trainer classes, some of which utilize
distributed data parallelism (DDP), as well as a launch_training_job function.

Trainers:

- BaseTrainer: A basic trainer module.
- BaseDistTrainer: A basic distributed data parallel trainer module that is a subclass of BaseTrainer.
- CTCTrainer: A trainer module for CTC-based models that is a subclass of BaseTrainer.
- DistCTCTrainer: A trainer module for CTC models that utilizes distributed data parallelism, which is a subclass of both BaseDistTrainer and CTCTrainer.
- Seq2SeqTrainer: A trainer module for Seq2Seq models that is a subclass of BaseTrainer.
- DistSeq2SeqTrainer: A trainer module for Seq2Seq models that utilizes distributed data parallelism, which is a subclass of both BaseDistTrainer and Seq2SeqTrainer.
- TransducerTrainer: A trainer module for transducer-based models that is a subclass of BaseTrainer.
- DistTransducerTrainer: A trainer module for transducer models that utilizes distributed data parallelism, which is a subclass of both BaseDistTrainer and TransducerTrainer.


Function:

- launch_training_job: Launches a training job, on a single GPU or across multiple GPUs, from the given trainer, data, and model configurations. It takes three arguments: trainer_config, the trainer configuration object; data_config, the data configuration object; and model_config, the model configuration object. It returns None.
"""
import os
import time
from functools import partial
from math import inf
from pathlib import Path
from typing import Tuple, Union

import torch
from torch import Tensor
from torch.distributed import ReduceOp, all_reduce, barrier, init_process_group
from torch.multiprocessing import spawn
from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from tqdm import tqdm

from speeq.config import ASRDataConfig, ModelConfig, TrainerConfig
from speeq.constants import HistoryKeys, LogCategories
from speeq.interfaces import IDataLoader, IScheduler, ITrainer
from speeq.utils.loggers import ILogger
from speeq.utils.utils import get_key_tag, has_bnorm

from .decorators import export_ckpt, step_log


class BaseTrainer(ITrainer):
    """Builds the basic trainer module.

    Args:
        optimizer (Union[Optimizer, IScheduler]): The optimizer or the wrapped
            optimizer that will be used during the training.
        criterion (Module): The loss function that will be used during the
            training process.
        model (Module): The model.
        train_loader (IDataLoader): The loader for the training data.
        test_loader (IDataLoader): The loader for the testing data.
        epochs (int): The number of epochs.
        log_steps_frequency (int): Every this many global steps the model is
            evaluated on the test data and the running training loss is logged.
        logger (ILogger): The logger to be used.
        outdir (Union[str, Path]): The directory to save checkpoints.
        grad_acc_steps (int): The number of steps to accumulate gradients
            over. Default 1.
        grad_clip_thresh (Union[None, float]): The maximum norm of the
            gradients; no clipping is done when None. Default None.
        grad_clip_norm_type (float): The type of p-norm used. Default 2.0.
        history (Union[None, dict]): The training history, if available. It is
            updated in place during training; a fresh dict is created when
            None. Default None.
    """

    def __init__(
        self,
        optimizer: Union[Optimizer, IScheduler],
        criterion: Module,
        model: Module,
        train_loader: IDataLoader,
        test_loader: IDataLoader,
        epochs: int,
        log_steps_frequency: int,
        logger: ILogger,
        outdir: Union[str, Path],
        grad_acc_steps: int = 1,
        grad_clip_thresh: Union[None, float] = None,
        grad_clip_norm_type: float = 2.0,
        history: Union[None, dict] = None,
    ) -> None:
        super().__init__()
        self.optimizer = optimizer
        self.criterion = criterion
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.epochs = epochs
        self.log_steps_frequency = log_steps_frequency
        self.logger = logger
        self.outdir = outdir
        self.grad_clip_thresh = grad_clip_thresh
        self.grad_clip_norm_type = grad_clip_norm_type
        # A literal `{}` default would be a single dict shared by every
        # trainer and mutated below, leaking history between instances.
        self.history = {} if history is None else history
        # 1-based global step counter; drives gradient accumulation and the
        # periodic evaluation in train().
        self.counter = 1
        self.grad_acc_steps = grad_acc_steps
        if HistoryKeys.min_loss.value not in self.history:
            self.history[HistoryKeys.min_loss.value] = inf

    def backward_pass(self, loss: Tensor) -> None:
        """Back-propagates the loss and, at the end of each gradient
        accumulation window, clips the gradients and updates the parameters.

        Args:
            loss (Tensor): The loss tensor.
        """
        # Scale so the accumulated gradient is the mean over the window.
        loss = loss / self.grad_acc_steps
        loss.backward()
        if self.counter % self.grad_acc_steps == 0:
            # Clip once the window is complete so the norm is measured on the
            # full accumulated gradient rather than on partial sums.
            if self.grad_clip_thresh is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    max_norm=self.grad_clip_thresh,
                    norm_type=self.grad_clip_norm_type,
                )
            self.optimizer.step()
            self.optimizer.zero_grad()

    def fit(self):
        """Fits the model on the training data, then hands the accumulated
        history to the logger."""
        for _ in range(self.epochs):
            self.train()
        self.logger.log(self.history)

    def inline_log(self, key: str, category: str, value: float):
        """Appends ``value`` to the history under the tag built from ``key``
        and ``category``, and forwards it to ``logger.log_step``.

        Args:
            key (str): The metric key.
            category (str): The log category.
            value (float): The value to record.
        """
        tag = get_key_tag(key=key, category=category)
        if tag in self.history:
            self.history[tag].append(value)
        else:
            self.history[tag] = [value]
        self.logger.log_step(key, category, value)

    @step_log(key=HistoryKeys.train_loss.value, category=LogCategories.batches.value)
    def train_step(self, batch: Tuple[Tensor]) -> float:
        """Performs a single training step: a forward pass to compute the loss,
        followed by a backward pass.

        Args:
            batch (Tuple[Tensor]): The input batch to be processed.

        Returns:
            float: The loss value for this step.
        """
        loss = self.forward_pass(batch)
        self.backward_pass(loss)
        return loss.item()

    @step_log(key=HistoryKeys.train_loss.value, category=LogCategories.epochs.value)
    def train(self) -> float:
        """Runs one epoch over the training data.

        Every ``log_steps_frequency`` global steps the model is evaluated on
        the test data. The running average training loss of the current epoch
        is then logged.

        Returns:
            float: The average loss over all training batches of the epoch.
        """
        self.model.train()
        total_loss = 0.0
        for i, batch in enumerate(tqdm(self.train_loader)):
            loss = self.train_step(batch)
            total_loss += loss
            if self.counter % self.log_steps_frequency == 0:
                self.test()
                # test() puts the model in eval mode; switch back.
                self.model.train()
                self.inline_log(
                    key=HistoryKeys.train_loss.value,
                    category=LogCategories.steps.value,
                    value=total_loss / (i + 1),
                )
            self.counter += 1
        return total_loss / len(self.train_loader)

    @export_ckpt(key=HistoryKeys.test_loss.value, category=LogCategories.steps.value)
    @step_log(key=HistoryKeys.test_loss.value, category=LogCategories.steps.value)
    @torch.no_grad()
    def test(self) -> float:
        """Evaluates the model on the testing data without tracking gradients.

        Returns:
            float: The average test loss over the test batches.
        """
        self.model.eval()
        total_loss = 0.0
        for batch in self.test_loader:
            loss = self.forward_pass(batch)
            total_loss += loss.item()
        total_loss /= len(self.test_loader)
        return total_loss

    @property
    def is_master(self):
        """Whether this process is the master one; always True when not
        distributed."""
        return True
class BaseDistTrainer(BaseTrainer):
    """Builds the basic distributed data parallel trainer module.

    Args:
        optimizer (Union[Optimizer, IScheduler]): The optimizer or the wrapped
            optimizer that will be used during the training.
        criterion (Module): The loss function that will be used during the
            training process.
        model (Module): The model.
        train_loader (IDataLoader): The loader for the training data.
        test_loader (IDataLoader): The loader for the testing data.
        epochs (int): The number of epochs.
        log_steps_frequency (int): The frequency at which to log results.
        logger (ILogger): The logger to be used.
        outdir (Union[str, Path]): The directory to save checkpoints.
        rank (int): The process index.
        world_size (int): The number of nodes/processes.
        dist_address (str): The address of the master node. Either a bare
            host (e.g. "localhost") or a full init URL (e.g. "tcp://host:port").
        dist_port (int): The port of the master node.
        dist_backend (str): The backend used for DDP.
        grad_acc_steps (int): The number of steps to accumulate gradients
            over. Default 1.
        grad_clip_thresh (Union[None, float]): The maximum norm of the
            gradients; no clipping is done when None. Default None.
        grad_clip_norm_type (float): The type of p-norm used. Default 2.0.
        history (Union[None, dict]): The training history, if available. A
            fresh dict is created when None. Default None.
    """

    def __init__(
        self,
        optimizer: Union[Optimizer, IScheduler],
        criterion: Module,
        model: Module,
        train_loader: IDataLoader,
        test_loader: IDataLoader,
        epochs: int,
        log_steps_frequency: int,
        logger: ILogger,
        outdir: Union[str, Path],
        rank: int,
        world_size: int,
        dist_address: str,
        dist_port: int,
        dist_backend: str,
        grad_acc_steps: int = 1,
        grad_clip_thresh: Union[None, float] = None,
        grad_clip_norm_type: float = 2.0,
        history: Union[None, dict] = None,
    ) -> None:
        # Avoid a shared mutable default; always hand BaseTrainer a dict.
        history = {} if history is None else history
        BaseTrainer.__init__(
            self,
            optimizer=optimizer,
            criterion=criterion,
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            epochs=epochs,
            log_steps_frequency=log_steps_frequency,
            logger=logger,
            outdir=outdir,
            grad_acc_steps=grad_acc_steps,
            grad_clip_thresh=grad_clip_thresh,
            grad_clip_norm_type=grad_clip_norm_type,
            history=history,
        )
        self.rank = rank
        self.world_size = world_size
        self.dist_port = dist_port
        self.dist_address = dist_address
        self.dist_backend = dist_backend
        self.init_dist()
        # Checked on the raw model, before SyncBatchNorm conversion / DDP wrap.
        self.has_bnorm = has_bnorm(self.model)
        self.model.to(f"cuda:{rank}")
        self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
        self.model = DistributedDataParallel(self.model, device_ids=[self.rank])

    def init_dist(self):
        """Initializes the distributed process group for this process."""
        os.environ["MASTER_ADDR"] = self.dist_address
        os.environ["MASTER_PORT"] = str(self.dist_port)
        # init_method must be a URL ("tcp://host:port", "env://", ...). A bare
        # host has no scheme and is rejected by the rendezvous, so use
        # "env://", which reads the MASTER_ADDR/MASTER_PORT set above. A value
        # that already is a URL is passed through unchanged.
        init_method = (
            self.dist_address if "://" in str(self.dist_address) else "env://"
        )
        init_process_group(
            backend=self.dist_backend,
            init_method=init_method,
            world_size=self.world_size,
            rank=self.rank,
        )

    @property
    def is_master(self):
        """Whether this process has rank 0."""
        return self.rank == 0

    def _all_reduce_loss(self, total_loss: float, counter: int) -> Tensor:
        """Averages each process's mean loss (``total_loss / counter``) across
        all processes.

        This is a collective call, so every process must call it.

        Args:
            total_loss (float): The summed loss on this process.
            counter (int): The number of batches summed into ``total_loss``.

        Returns:
            Tensor: A one-element tensor holding the cross-process mean.
        """
        total = torch.tensor([total_loss / counter]).cuda(self.rank)
        all_reduce(total, op=ReduceOp.SUM)
        return total / self.world_size

    def backward_pass(self, loss: Tensor) -> None:
        """Back-propagates the loss and, at the end of each gradient
        accumulation window, clips the gradients and updates the parameters.

        Args:
            loss (Tensor): The loss tensor.
        """
        loss = loss / self.grad_acc_steps
        loss.backward()
        if self.counter % self.grad_acc_steps == 0:
            # DistributedDataParallel already averages gradients across
            # processes during backward(). Dividing them by world_size again
            # would shrink the effective learning rate by that factor, and
            # would crash on parameters whose .grad is None.
            if self.grad_clip_thresh is not None:
                # Clip once, on the fully accumulated gradient.
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    max_norm=self.grad_clip_thresh,
                    norm_type=self.grad_clip_norm_type,
                )
            self.optimizer.step()
            self.optimizer.zero_grad()

    @step_log(key=HistoryKeys.train_loss.value, category=LogCategories.epochs.value)
    def train(self) -> float:
        """Runs one epoch over this process's share of the training data.

        Returns:
            float: The average training loss, averaged across all processes.
        """
        self.model.train()
        total_loss = 0.0
        for i, batch in enumerate(tqdm(self.train_loader)):
            loss = self.train_step(batch)
            total_loss += loss
            if self.counter % self.log_steps_frequency == 0:
                # Collective: every process must take part in this reduction.
                total = self._all_reduce_loss(total_loss, i + 1)
                # With DDP + batch norm, evaluation only works if it runs on
                # every process, hence the extra has_bnorm condition. See the
                # similar issue at
                # discuss.pytorch.org/t/validation-hangs-up-when-using-ddp-and-syncbatchnorm/104831
                if self.is_master or self.has_bnorm is True:
                    self.inline_log(
                        key=HistoryKeys.train_loss.value,
                        category=LogCategories.steps.value,
                        value=total.item(),
                    )
                    self.test()
                    self.model.train()
                if self.has_bnorm is False:
                    # Non-master processes wait here while the master evaluates.
                    barrier()
            self.counter += 1
        return self._all_reduce_loss(total_loss, len(self.train_loader)).item()

    def fit(self):
        """Fits the model on the training data and logs the history on the
        master process (or on every process when the model has batch norm)."""
        for _ in range(self.epochs):
            self.train()
        # Same batch-norm workaround as in train(); see
        # discuss.pytorch.org/t/validation-hangs-up-when-using-ddp-and-syncbatchnorm/104831
        if self.is_master or self.has_bnorm is True:
            self.logger.log(self.history)
        barrier()
class CTCTrainer(BaseTrainer):
    """A trainer module for CTC-based models.

    Args:
        optimizer (Union[Optimizer, IScheduler]): The optimizer or the wrapped
            optimizer that will be used during the training.
        criterion (Module): The loss function that will be used during the
            training process.
        model (Module): The model.
        train_loader (IDataLoader): The loader for the training data.
        test_loader (IDataLoader): The loader for the testing data.
        epochs (int): The number of epochs.
        log_steps_frequency (int): The frequency at which to log results.
        device (str): The device the model and every batch are moved to.
        logger (ILogger): The logger to be used.
        outdir (Union[str, Path]): The directory to save checkpoints.
        grad_acc_steps (int): The number of steps to accumulate gradients
            over. Default 1.
        grad_clip_thresh (Union[None, float]): The maximum norm of the
            gradients. Default None.
        grad_clip_norm_type (float): The type of p-norm used. Default 2.0.
        history (Union[None, dict]): The training history, if available. A
            fresh dict is created when None. Default None.
    """

    def __init__(
        self,
        optimizer: Union[Optimizer, IScheduler],
        criterion: Module,
        model: Module,
        train_loader: IDataLoader,
        test_loader: IDataLoader,
        epochs: int,
        log_steps_frequency: int,
        device: str,
        logger: ILogger,
        outdir: Union[str, Path],
        grad_acc_steps: int = 1,
        grad_clip_thresh: Union[None, float] = None,
        grad_clip_norm_type: float = 2.0,
        history: Union[None, dict] = None,
    ) -> None:
        # Avoid a shared mutable default; always hand BaseTrainer a dict.
        history = {} if history is None else history
        BaseTrainer.__init__(
            self,
            optimizer=optimizer,
            criterion=criterion,
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            epochs=epochs,
            log_steps_frequency=log_steps_frequency,
            logger=logger,
            outdir=outdir,
            grad_acc_steps=grad_acc_steps,
            grad_clip_thresh=grad_clip_thresh,
            grad_clip_norm_type=grad_clip_norm_type,
            history=history,
        )
        self.device = device
        self.model.to(device)

    def forward_pass(self, batch: Tuple[Tensor]) -> Tensor:
        """Runs the model on a batch and computes the CTC loss.

        Args:
            batch (Tuple[Tensor]): The speech, speech mask, text, and text
                mask tensors, in that order.

        Returns:
            Tensor: The loss.
        """
        batch = [item.to(self.device) for item in batch]
        [speech, speech_mask, text, text_mask] = batch
        preds, lengths = self.model(speech, speech_mask)  # preds of shape [T, B, C]
        # text_mask.sum(dim=-1) serves as the target lengths (assumes the mask
        # marks valid tokens with 1 — TODO confirm).
        loss = self.criterion(preds, text, lengths, text_mask.sum(dim=-1))
        return loss
class DistCTCTrainer(BaseDistTrainer, CTCTrainer):
    """A trainer module for CTC models that utilizes distributed data parallelism.

    Args:
        optimizer (Union[Optimizer, IScheduler]): The optimizer or the wrapped
            optimizer that will be used during the training.
        criterion (Module): The loss function that will be used during the
            training process.
        model (Module): The model.
        train_loader (IDataLoader): The loader for the training data.
        test_loader (IDataLoader): The loader for the testing data.
        epochs (int): The number of epochs.
        logger (ILogger): The logger to be used.
        outdir (Union[str, Path]): The directory to save checkpoints.
        log_steps_frequency (int): The frequency at which to log results.
        rank (int): The process index; the model runs on ``cuda:{rank}``.
        world_size (int): The number of nodes/processes.
        dist_address (str): The address of the master node.
        dist_port (int): The port of the master node.
        dist_backend (str): The backend used for DDP.
        grad_acc_steps (int): The number of steps to accumulate gradients
            over. Default 1.
        grad_clip_thresh (Union[None, float]): The maximum norm of the
            gradients. Default None.
        grad_clip_norm_type (float): The type of p-norm used. Default 2.0.
        history (Union[None, dict]): The training history, if available. A
            fresh dict is created when None. Default None.
    """

    def __init__(
        self,
        optimizer: Union[Optimizer, IScheduler],
        criterion: Module,
        model: Module,
        train_loader: IDataLoader,
        test_loader: IDataLoader,
        epochs: int,
        logger: ILogger,
        outdir: Union[str, Path],
        log_steps_frequency: int,
        rank: int,
        world_size: int,
        dist_address: str,
        dist_port: int,
        dist_backend: str,
        grad_acc_steps: int = 1,
        grad_clip_thresh: Union[None, float] = None,
        grad_clip_norm_type: float = 2.0,
        history: Union[None, dict] = None,
    ) -> None:
        # Resolve once so both parent initializers receive the same dict.
        history = {} if history is None else history
        # Device-specific setup first (self.device, model placement)...
        CTCTrainer.__init__(
            self,
            optimizer=optimizer,
            criterion=criterion,
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            epochs=epochs,
            log_steps_frequency=log_steps_frequency,
            device=f"cuda:{rank}",
            logger=logger,
            outdir=outdir,
            grad_acc_steps=grad_acc_steps,
            grad_clip_thresh=grad_clip_thresh,
            grad_clip_norm_type=grad_clip_norm_type,
            history=history,
        )
        # ...then process-group initialization and the DDP wrap.
        BaseDistTrainer.__init__(
            self,
            optimizer=optimizer,
            criterion=criterion,
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            epochs=epochs,
            log_steps_frequency=log_steps_frequency,
            logger=logger,
            outdir=outdir,
            rank=rank,
            world_size=world_size,
            dist_address=dist_address,
            dist_port=dist_port,
            dist_backend=dist_backend,
            grad_acc_steps=grad_acc_steps,
            grad_clip_thresh=grad_clip_thresh,
            grad_clip_norm_type=grad_clip_norm_type,
            history=history,
        )
class Seq2SeqTrainer(BaseTrainer):
    """A trainer module for Seq2Seq models.

    Args:
        optimizer (Union[Optimizer, IScheduler]): The optimizer or the wrapped
            optimizer that will be used during the training.
        criterion (Module): The loss function that will be used during the
            training process.
        model (Module): The model.
        train_loader (IDataLoader): The loader for the training data.
        test_loader (IDataLoader): The loader for the testing data.
        epochs (int): The number of epochs.
        log_steps_frequency (int): The frequency at which to log results.
        device (str): The device the model and every batch are moved to.
        logger (ILogger): The logger to be used.
        outdir (Union[str, Path]): The directory to save checkpoints.
        grad_acc_steps (int): The number of steps to accumulate gradients
            over. Default 1.
        grad_clip_thresh (Union[None, float]): The maximum norm of the
            gradients. Default None.
        grad_clip_norm_type (float): The type of p-norm used. Default 2.0.
        history (Union[None, dict]): The training history, if available. A
            fresh dict is created when None. Default None.
    """

    def __init__(
        self,
        optimizer: Union[Optimizer, IScheduler],
        criterion: Module,
        model: Module,
        train_loader: IDataLoader,
        test_loader: IDataLoader,
        epochs: int,
        log_steps_frequency: int,
        device: str,
        logger: ILogger,
        outdir: Union[str, Path],
        grad_acc_steps: int = 1,
        grad_clip_thresh: Union[None, float] = None,
        grad_clip_norm_type: float = 2.0,
        history: Union[None, dict] = None,
    ) -> None:
        # Avoid a shared mutable default; always hand BaseTrainer a dict.
        history = {} if history is None else history
        BaseTrainer.__init__(
            self,
            optimizer=optimizer,
            criterion=criterion,
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            epochs=epochs,
            log_steps_frequency=log_steps_frequency,
            logger=logger,
            outdir=outdir,
            grad_acc_steps=grad_acc_steps,
            grad_clip_thresh=grad_clip_thresh,
            grad_clip_norm_type=grad_clip_norm_type,
            history=history,
        )
        self.device = device
        self.model.to(device)

    def forward_pass(self, batch: Tuple[Tensor]) -> Tensor:
        """Runs the Seq2Seq model on a batch and computes the loss.

        The text and its mask are fed to both the model and the criterion.

        Args:
            batch (Tuple[Tensor]): The speech, speech mask, text, and text
                mask tensors, in that order.

        Returns:
            Tensor: The loss.
        """
        batch = [item.to(self.device) for item in batch]
        [speech, speech_mask, text, text_mask] = batch
        preds = self.model(speech, speech_mask, text, text_mask)
        loss = self.criterion(preds, text, text_mask)
        return loss
class DistSeq2SeqTrainer(BaseDistTrainer, Seq2SeqTrainer):
    """A trainer module for Seq2Seq models that utilizes distributed data parallelism.

    Args:
        optimizer (Union[Optimizer, IScheduler]): The optimizer or the wrapped
            optimizer that will be used during the training.
        criterion (Module): The loss function that will be used during the
            training process.
        model (Module): The model.
        train_loader (IDataLoader): The loader for the training data.
        test_loader (IDataLoader): The loader for the testing data.
        epochs (int): The number of epochs.
        logger (ILogger): The logger to be used.
        outdir (Union[str, Path]): The directory to save checkpoints.
        log_steps_frequency (int): The frequency at which to log results.
        rank (int): The process index; the model runs on ``cuda:{rank}``.
        world_size (int): The number of nodes/processes.
        dist_address (str): The address of the master node.
        dist_port (int): The port of the master node.
        dist_backend (str): The backend used for DDP.
        grad_acc_steps (int): The number of steps to accumulate gradients
            over. Default 1.
        grad_clip_thresh (Union[None, float]): The maximum norm of the
            gradients. Default None.
        grad_clip_norm_type (float): The type of p-norm used. Default 2.0.
        history (Union[None, dict]): The training history, if available. A
            fresh dict is created when None. Default None.
    """

    def __init__(
        self,
        optimizer: Union[Optimizer, IScheduler],
        criterion: Module,
        model: Module,
        train_loader: IDataLoader,
        test_loader: IDataLoader,
        epochs: int,
        logger: ILogger,
        outdir: Union[str, Path],
        log_steps_frequency: int,
        rank: int,
        world_size: int,
        dist_address: str,
        dist_port: int,
        dist_backend: str,
        grad_acc_steps: int = 1,
        grad_clip_thresh: Union[None, float] = None,
        grad_clip_norm_type: float = 2.0,
        history: Union[None, dict] = None,
    ) -> None:
        # Resolve once so both parent initializers receive the same dict.
        history = {} if history is None else history
        # Device-specific setup first (self.device, model placement)...
        Seq2SeqTrainer.__init__(
            self,
            optimizer=optimizer,
            criterion=criterion,
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            epochs=epochs,
            log_steps_frequency=log_steps_frequency,
            device=f"cuda:{rank}",
            logger=logger,
            outdir=outdir,
            grad_acc_steps=grad_acc_steps,
            grad_clip_thresh=grad_clip_thresh,
            grad_clip_norm_type=grad_clip_norm_type,
            history=history,
        )
        # ...then process-group initialization and the DDP wrap.
        BaseDistTrainer.__init__(
            self,
            optimizer=optimizer,
            criterion=criterion,
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            epochs=epochs,
            log_steps_frequency=log_steps_frequency,
            logger=logger,
            outdir=outdir,
            rank=rank,
            world_size=world_size,
            dist_address=dist_address,
            dist_port=dist_port,
            dist_backend=dist_backend,
            grad_acc_steps=grad_acc_steps,
            grad_clip_thresh=grad_clip_thresh,
            grad_clip_norm_type=grad_clip_norm_type,
            history=history,
        )
class TransducerTrainer(BaseTrainer):
    """A trainer module for transducer-based models.

    Args:
        optimizer (Union[Optimizer, IScheduler]): The optimizer or the wrapped
            optimizer that will be used during the training.
        criterion (Module): The loss function that will be used during the
            training process.
        model (Module): The model.
        train_loader (IDataLoader): The loader for the training data.
        test_loader (IDataLoader): The loader for the testing data.
        epochs (int): The number of epochs.
        log_steps_frequency (int): The frequency at which to log results.
        device (str): The device the model and every batch are moved to.
        logger (ILogger): The logger to be used.
        outdir (Union[str, Path]): The directory to save checkpoints.
        grad_acc_steps (int): The number of steps to accumulate gradients
            over. Default 1.
        grad_clip_thresh (Union[None, float]): The maximum norm of the
            gradients. Default None.
        grad_clip_norm_type (float): The type of p-norm used. Default 2.0.
        history (Union[None, dict]): The training history, if available. A
            fresh dict is created when None. Default None.
    """

    def __init__(
        self,
        optimizer: Union[Optimizer, IScheduler],
        criterion: Module,
        model: Module,
        train_loader: IDataLoader,
        test_loader: IDataLoader,
        epochs: int,
        log_steps_frequency: int,
        device: str,
        logger: ILogger,
        outdir: Union[str, Path],
        grad_acc_steps: int = 1,
        grad_clip_thresh: Union[None, float] = None,
        grad_clip_norm_type: float = 2.0,
        history: Union[None, dict] = None,
    ) -> None:
        # Avoid a shared mutable default; always hand BaseTrainer a dict.
        history = {} if history is None else history
        BaseTrainer.__init__(
            self,
            optimizer=optimizer,
            criterion=criterion,
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            epochs=epochs,
            log_steps_frequency=log_steps_frequency,
            logger=logger,
            outdir=outdir,
            grad_acc_steps=grad_acc_steps,
            grad_clip_thresh=grad_clip_thresh,
            grad_clip_norm_type=grad_clip_norm_type,
            history=history,
        )
        self.device = device
        self.model.to(device)

    def forward_pass(self, batch: Tuple[Tensor]) -> Tensor:
        """Conducts a forward pass on the transducer model and computes the loss.

        Args:
            batch (Tuple[Tensor]): The input batch containing the speech,
                speech mask, text, and text mask tensors, in that order.

        Returns:
            Tensor: A tensor representing the loss.
        """
        batch = [item.to(self.device) for item in batch]
        [speech, speech_mask, text, text_mask] = batch
        preds, speech_len, text_len = self.model(speech, speech_mask, text, text_mask)
        # Targets and lengths are cast to int32 before the loss (presumably
        # required by the transducer loss implementation — TODO confirm).
        text, speech_len, text_len = (text.int(), speech_len.int(), text_len.int())
        loss = self.criterion(preds, speech_len, text, text_len)
        return loss
class DistTransducerTrainer(BaseDistTrainer, TransducerTrainer):
    """A trainer module for transducer models that utilizes distributed data parallelism.

    Args:
        optimizer (Union[Optimizer, IScheduler]): The optimizer or the wrapped
            optimizer that will be used during the training.
        criterion (Module): The loss function that will be used during the
            training process.
        model (Module): The model.
        train_loader (IDataLoader): The loader for the training data.
        test_loader (IDataLoader): The loader for the testing data.
        epochs (int): The number of epochs.
        logger (ILogger): The logger to be used.
        outdir (Union[str, Path]): The directory to save checkpoints.
        log_steps_frequency (int): The frequency at which to log results.
        rank (int): The process index; the model runs on ``cuda:{rank}``.
        world_size (int): The number of nodes/processes.
        dist_address (str): The address of the master node.
        dist_port (int): The port of the master node.
        dist_backend (str): The backend used for DDP.
        grad_acc_steps (int): The number of steps to accumulate gradients
            over. Default 1.
        grad_clip_thresh (Union[None, float]): The maximum norm of the
            gradients. Default None.
        grad_clip_norm_type (float): The type of p-norm used. Default 2.0.
        history (Union[None, dict]): The training history, if available. A
            fresh dict is created when None. Default None.
    """

    def __init__(
        self,
        optimizer: Union[Optimizer, IScheduler],
        criterion: Module,
        model: Module,
        train_loader: IDataLoader,
        test_loader: IDataLoader,
        epochs: int,
        logger: ILogger,
        outdir: Union[str, Path],
        log_steps_frequency: int,
        rank: int,
        world_size: int,
        dist_address: str,
        dist_port: int,
        dist_backend: str,
        grad_acc_steps: int = 1,
        grad_clip_thresh: Union[None, float] = None,
        grad_clip_norm_type: float = 2.0,
        history: Union[None, dict] = None,
    ) -> None:
        # Resolve once so both parent initializers receive the same dict.
        history = {} if history is None else history
        # Device-specific setup first (self.device, model placement)...
        TransducerTrainer.__init__(
            self,
            optimizer=optimizer,
            criterion=criterion,
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            epochs=epochs,
            log_steps_frequency=log_steps_frequency,
            device=f"cuda:{rank}",
            logger=logger,
            outdir=outdir,
            grad_acc_steps=grad_acc_steps,
            grad_clip_thresh=grad_clip_thresh,
            grad_clip_norm_type=grad_clip_norm_type,
            history=history,
        )
        # ...then process-group initialization and the DDP wrap.
        BaseDistTrainer.__init__(
            self,
            optimizer=optimizer,
            criterion=criterion,
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            epochs=epochs,
            log_steps_frequency=log_steps_frequency,
            logger=logger,
            outdir=outdir,
            rank=rank,
            world_size=world_size,
            dist_address=dist_address,
            dist_port=dist_port,
            dist_backend=dist_backend,
            grad_acc_steps=grad_acc_steps,
            grad_clip_thresh=grad_clip_thresh,
            grad_clip_norm_type=grad_clip_norm_type,
            history=history,
        )
def _run_trainer(
    rank: int,
    world_size: int,
    trainer_config: TrainerConfig,
    data_config: ASRDataConfig,
    model_config: ModelConfig,
) -> None:
    """Builds the ASR trainer for a single process and runs it.

    Args:
        rank (int): The index of the current process.
        world_size (int): The total number of processes.
        trainer_config (TrainerConfig): The trainer configuration.
        data_config (ASRDataConfig): The data configuration.
        model_config (ModelConfig): The model configuration.
    """
    is_master = rank == 0
    if not is_master:
        # Give the master process a head start so that any dependencies it
        # creates exist before the other processes need them. This could be
        # replaced by passing the rank to the factories that depend on the
        # master node.
        time.sleep(5)
    from .registry import get_asr_trainer

    get_asr_trainer(
        rank=rank,
        world_size=world_size,
        trainer_config=trainer_config,
        data_config=data_config,
        model_config=model_config,
    ).fit()
def launch_training_job(
    trainer_config: object, data_config: object, model_config: object
) -> None:
    """Launches an ASR training job on one or more GPUs.

    A trainer is built from the given configurations in every process. When
    ``trainer_config.dist_config`` is None, a single process runs as rank 0 of
    a world of size 1. Otherwise ``dist_config.n_gpus`` processes are spawned.

    Args:
        trainer_config (object): Trainer configuration object.
        data_config (object): Data configuration object.
        model_config (object): Model configuration object.
    """
    run_one = partial(
        _run_trainer,
        trainer_config=trainer_config,
        data_config=data_config,
        model_config=model_config,
    )
    dist_config = trainer_config.dist_config
    if dist_config is not None:
        n_procs = dist_config.n_gpus
        # spawn invokes run_one(process_index, *args), so each process gets
        # its rank as the first positional argument.
        spawn(run_one, nprocs=n_procs, args=(n_procs,))
        return
    run_one(rank=0, world_size=1)