from pathlib import Path
from typing import Union
from torch.utils.tensorboard import SummaryWriter
from speeq.interfaces import ILogger
from .utils import clear
class TBLogger(ILogger):
    """Logger that writes per-step scalars to TensorBoard and prints recent history.

    Scalars passed to :meth:`log_step` are written through a
    ``SummaryWriter`` under the tag ``"{key}/{category}"``, each tag keeping
    its own step counter. :meth:`log` prints (optionally after clearing the
    screen) the last ``n_logs`` entries of every list-valued item of a
    history dict.

    Args:
        log_dir: Directory passed to ``SummaryWriter``.
        n_logs: How many trailing entries of each history list :meth:`log`
            prints. A value ``<= 0`` prints empty lists.
        clear_screen: Whether :meth:`log` calls ``clear()`` before printing.
            Defaults to ``False``.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        log_dir: Union[str, Path],
        n_logs: int,
        clear_screen: bool = False,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        self.writer = SummaryWriter(log_dir)
        # Per-tag step counters; the first value logged for a tag gets step 0.
        self.__counters = {}
        self.n_logs = n_logs
        self.clear_screen = clear_screen
        print("Started!")

    def log_step(self, key: str, category: str, value: Union[int, float]) -> None:
        """Write one scalar to TensorBoard under the tag ``"{key}/{category}"``.

        Each tag has its own step counter: the first value written for a tag
        is recorded at ``global_step=0``, the next at 1, and so on.

        Args:
            key: First component of the tag.
            category: Second component of the tag.
            value: Scalar value to record.
        """
        tag = f"{key}/{category}"
        # Unseen tags start from -1 so their first step is 0.
        step = self.__counters.get(tag, -1) + 1
        self.__counters[tag] = step
        self.writer.add_scalar(tag, value, global_step=step)

    def log(self, history: dict) -> None:
        """Print the most recent entries of every list-valued item in ``history``.

        Non-list values are skipped. Each list is truncated to its last
        ``self.n_logs`` entries. When ``n_logs <= 0`` the lists are printed
        empty; a bare ``value[-0:]`` slice would return the whole list.

        Args:
            history: Mapping whose list values are truncated and printed.
        """
        count = self.n_logs
        logs = {
            key: value[-count:] if count > 0 else []
            for key, value in history.items()
            if isinstance(value, list)
        }
        if self.clear_screen:
            clear()  # cleaning the screen up
        print(logs)
def get_logger(
    name: str, log_dir: Union[str, Path], n_logs: int, *args, **kwargs
) -> ILogger:
    """Build the logger registered under ``name``.

    Args:
        name: Logger identifier. Only ``"tb"`` (:class:`TBLogger`) is
            supported.
        log_dir: Directory forwarded to the logger.
        n_logs: Forwarded to the logger as its ``n_logs`` argument.
        *args, **kwargs: Forwarded to the logger's constructor after
            ``log_dir`` and ``n_logs`` (e.g. ``clear_screen``).

    Returns:
        The constructed logger.

    Raises:
        NotImplementedError: If ``name`` is not a supported logger.
    """
    # Exact match: a substring test (``name in "tb"``) would also accept
    # "t", "b" and "".
    if name == "tb":
        # Pass log_dir/n_logs positionally so extra positional args bind to the
        # following parameters instead of colliding with log_dir.
        return TBLogger(log_dir, n_logs, *args, **kwargs)
    raise NotImplementedError(f"Unsupported logger: {name!r}")