# Source code for alr.training.supervised_trainer

import torch
from torch import nn
from ignite.engine import (
    Engine,
    Events,
    create_supervised_evaluator,
    create_supervised_trainer,
)
from ignite.metrics import Loss, Accuracy, RunningAverage
import torch.utils.data as torchdata
from alr.training.progress_bar.ignite_progress_bar import ProgressBar
from ignite.contrib.handlers.param_scheduler import LRScheduler
import numpy as np

from collections import defaultdict
from typing import Optional, Dict, Callable, Sequence

from alr.utils._type_aliases import _DeviceType, _Loss_fn
from alr.training.utils import EarlyStopper


class Trainer:
    r"""Supervised training loop for an ``nn.Module`` built on pytorch-ignite.

    Wraps ignite's ``create_supervised_trainer`` / ``create_supervised_evaluator``
    and adds optional early stopping on validation accuracy, optional reloading
    of the best model, and an optional learning-rate scheduler stepped once per
    epoch.
    """

    def __init__(
        self,
        model: nn.Module,
        loss: _Loss_fn,
        optimiser: str,
        patience: Optional[int] = None,
        reload_best: Optional[bool] = False,
        lr_scheduler: Optional[str] = None,
        lr_scheduler_kwargs: Optional[dict] = None,
        device: _DeviceType = None,
        *args,
        **kwargs,
    ):
        r"""
        Args:
            model (torch.nn.Module): module object
            loss (Callable): should be a function `fn` that takes `preds` and `targets` and returns
                a singleton tensor with the loss value: `loss = fn(preds, targets)`. E.g. F.nll_loss.
            optimiser (str): a string that corresponds to the type of optimiser to use.
                Must be an optimiser from `torch.optim` (case sensitive). E.g. 'Adam'.
            patience (int, optional): if not `None`, then validation accuracy will be used to
                determine when to stop. Must be positive.
            reload_best (bool, optional): patience must be non-`None` if this is set to `True`:
                reloads the best model according to validation accuracy at the end of training.
            lr_scheduler (str, optional): a string that corresponds to the type of learning rate
                scheduler in `torch.optim.lr_scheduler`.
            lr_scheduler_kwargs (dict, optional): arguments to the constructor of `lr_scheduler`
                (ignored when `lr_scheduler` is `None`).
            device (str, None, torch.device): device type.
            *args (Any, optional): arguments to be passed into the optimiser.
            **kwargs (Any, optional): keyword arguments to be passed into the optimiser.

        Raises:
            ValueError: if `patience` is given but not positive, or if `reload_best`
                is `True` while `patience` is `None`.
        """
        # Validate explicitly and before building anything: ``assert`` statements
        # are stripped under ``python -O`` and would let a bad config through.
        if patience is not None and patience <= 0:
            raise ValueError(f"patience must be a positive integer, got {patience}.")
        if reload_best and patience is None:
            raise ValueError("reload_best=True requires patience to be specified.")

        self._loss = loss
        self._optim = getattr(torch.optim, optimiser)(
            model.parameters(), *args, **kwargs
        )
        self._patience = patience
        self._reload_best = reload_best
        self._device = device
        self._model = model

        # Holds either None or an instantiated torch.optim.lr_scheduler object.
        self._lr_scheduler = None
        if lr_scheduler is not None:
            scheduler_kwargs = {} if lr_scheduler_kwargs is None else lr_scheduler_kwargs
            scheduler_cls = getattr(torch.optim.lr_scheduler, lr_scheduler)
            self._lr_scheduler = scheduler_cls(self._optim, **scheduler_kwargs)

    def fit(
        self,
        train_loader: torchdata.DataLoader,
        val_loader: Optional[torchdata.DataLoader] = None,
        epochs: Optional[int] = 1,
        callbacks: Optional[Sequence[Callable]] = None,
    ) -> Dict[str, list]:
        r"""Train the model and record per-epoch metrics.

        Args:
            train_loader (DataLoader): training data.
            val_loader (DataLoader, optional): if given, evaluated at the end of every
                epoch; its accuracy drives early stopping when `patience` was set.
            epochs (int, optional): passed to ignite as `max_epochs`.
            callbacks (Sequence[Callable], optional): each is registered as an
                `Events.EPOCH_COMPLETED` handler on the ignite trainer engine,
                after the metric-logging handler.

        Returns:
            dict mapping ``"train_acc"`` and ``"train_loss"`` (plus ``"val_acc"`` and
            ``"val_loss"`` when `val_loader` is given) to lists with one entry per
            completed epoch.

        Raises:
            ValueError: if `patience` was set but `val_loader` is `None`.
        """
        if self._patience and val_loader is None:
            raise ValueError(
                "If patience is specified, then val_loader must be provided in .fit()."
            )

        pbar = ProgressBar(desc=lambda _: "Training")
        history = defaultdict(list)
        val_evaluator = self._make_evaluator()

        def _log_metrics(engine: Engine):
            # Running averages attached to the trainer below under these names.
            train_acc = engine.state.metrics["train_acc"]
            train_loss = engine.state.metrics["train_loss"]
            history["train_acc"].append(train_acc)
            history["train_loss"].append(train_loss)
            pbar.log_message(
                f"epoch {engine.state.epoch}/{engine.state.max_epochs}\n"
                f"\ttrain acc = {train_acc}, train loss = {train_loss}"
            )
            if val_loader is None:
                return  # job done
            # Running the evaluator also fires any handler attached to it
            # (e.g. the EarlyStopper below, which monitors val "acc").
            metrics = val_evaluator.run(val_loader).metrics
            history["val_acc"].append(metrics["acc"])
            history["val_loss"].append(metrics["loss"])
            pbar.log_message(
                f"\tval acc = {metrics['acc']}, val loss = {metrics['loss']}"
            )

        trainer = create_supervised_trainer(
            self._model,
            optimizer=self._optim,
            loss_fn=self._loss,
            device=self._device,
            # (loss, y_pred, y) is consumed by the two RunningAverage metrics below.
            output_transform=lambda x, y, y_pred, loss: (loss.item(), y_pred, y),
        )
        pbar.attach(trainer)

        if self._lr_scheduler is not None:
            # Stepped once per epoch, before metrics are logged.
            scheduler = LRScheduler(self._lr_scheduler)
            trainer.add_event_handler(Events.EPOCH_COMPLETED, scheduler)

        RunningAverage(Accuracy(output_transform=lambda out: (out[1], out[2]))).attach(
            trainer, "train_acc"
        )
        RunningAverage(output_transform=lambda out: out[0]).attach(trainer, "train_loss")

        early_stopper = None
        if val_loader is not None and self._patience:
            early_stopper = EarlyStopper(
                self._model, self._patience, trainer, key="acc", mode="max"
            )
            early_stopper.attach(val_evaluator)

        trainer.add_event_handler(Events.EPOCH_COMPLETED, _log_metrics)
        if callbacks is not None:
            for callback in callbacks:
                trainer.add_event_handler(Events.EPOCH_COMPLETED, callback)

        trainer.run(train_loader, max_epochs=epochs)

        if early_stopper is not None and self._reload_best:
            early_stopper.reload_best()
        return history

    def evaluate(self, data_loader: torchdata.DataLoader) -> dict:
        r"""Evaluate the model on `data_loader`.

        Returns:
            the ignite evaluator's metrics dict, with keys ``"acc"`` and ``"loss"``.
        """
        return self._make_evaluator().run(data_loader).metrics

    def _make_evaluator(self) -> Engine:
        # Shared by `fit` (validation) and `evaluate`: computes "acc" and "loss".
        return create_supervised_evaluator(
            self._model,
            metrics={"acc": Accuracy(), "loss": Loss(self._loss)},
            device=self._device,
        )