# Source code for speeq.utils.loggers

from pathlib import Path
from typing import Union

from torch.utils.tensorboard import SummaryWriter

from speeq.interfaces import ILogger

from .utils import clear


class TBLogger(ILogger):
    """Logger that writes scalars to TensorBoard and prints recent history.

    Args:
        log_dir (Union[str, Path]): Directory handed to ``SummaryWriter``.
        n_logs (int): How many of the most recent entries of each list-valued
            history item :meth:`log` prints.
        clear_screen (bool): If ``True``, :meth:`log` calls ``clear()`` before
            printing.
        *args, **kwargs: Accepted for interface compatibility; unused.
    """

    def __init__(
        self,
        log_dir: Union[str, Path],
        n_logs: int,
        clear_screen: bool,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        self.writer = SummaryWriter(log_dir)
        # Maps each "key/category" tag to the last global step used for it.
        self.__counters = dict()
        self.n_logs = n_logs
        self.clear_screen = clear_screen
        print("Started!")

    def log_step(self, key: str, category: str, value: Union[int, float]) -> None:
        """Write ``value`` to TensorBoard under the tag ``"{key}/{category}"``.

        Each tag keeps its own step counter. The first call for a tag uses
        global step 0, and every later call for that tag adds one.

        Args:
            key (str): First component of the tag.
            category (str): Second component of the tag.
            value (Union[int, float]): Scalar to record.
        """
        tag = f"{key}/{category}"
        counter = self.__counters.get(tag, -1) + 1
        self.__counters[tag] = counter
        self.writer.add_scalar(tag, value, global_step=counter)

    def log(self, history: dict):
        """Print the last ``n_logs`` entries of every list-valued item.

        Non-list values in ``history`` are skipped.

        Args:
            history (dict): Mapping of names to logged values.
        """
        # Use an explicit start index: ``value[-0:]`` would return the whole
        # list when ``n_logs`` is 0 instead of an empty one.
        logs = {
            key: value[max(len(value) - self.n_logs, 0):]
            for key, value in history.items()
            if isinstance(value, list)
        }
        if self.clear_screen is True:
            clear()  # cleaning the screen up
        print(logs)
def get_logger(
    name: str, log_dir: Union[str, Path], n_logs: int, *args, **kwargs
) -> ILogger:
    """Build the logger identified by ``name``.

    Args:
        name (str): Logger identifier; only ``"tb"`` (``TBLogger``) is
            supported.
        log_dir (Union[str, Path]): Directory the logger writes to.
        n_logs (int): Number of recent entries the logger prints per item.
        *args, **kwargs: Forwarded to the logger constructor after
            ``log_dir`` and ``n_logs`` (e.g. ``clear_screen``).

    Returns:
        ILogger: The constructed logger.

    Raises:
        NotImplementedError: If ``name`` is not a supported logger.
    """
    # Exact comparison: a substring test (``in``) would also accept
    # "", "t" and "b".
    if name == "tb":
        # Pass log_dir/n_logs positionally so that extra positional arguments
        # bind to the following parameters instead of colliding with log_dir.
        return TBLogger(log_dir, n_logs, *args, **kwargs)
    raise NotImplementedError(f"Unsupported logger: {name!r}")