Source code for alr.training.pseudo_label_trainer

from collections import defaultdict
from typing import Optional, Dict, Sequence, Union

import numpy as np
import torch
import torch.utils.data as torchdata
from alr.training.progress_bar.ignite_progress_bar import ProgressBar
from ignite.engine import Engine, Events, create_supervised_evaluator
from ignite.metrics import Loss, Accuracy
from torch import nn

from alr.data import UnlabelledDataset
from alr.training import Trainer
from alr.training.utils import EarlyStopper, PLPredictionSaver
from alr.utils import _map_device
from alr.utils.math import cross_entropy
from alr.utils._type_aliases import _DeviceType, _Loss_fn


class WraparoundLoader:
    """Iterator that loops over a data loader without ever stopping on its own.

    Each time the wrapped loader is exhausted, a fresh iterator is requested
    from it and iteration resumes from that new iterator. Iteration only ends
    if a freshly created iterator yields nothing at all.
    """

    def __init__(self, ds: torchdata.DataLoader):
        # keep the loader itself so a new iterator can be built on exhaustion
        self._ds = ds
        self._iter = iter(ds)

    def __iter__(self):
        return self

    def __next__(self) -> torch.Tensor:
        exhausted = object()
        item = next(self._iter, exhausted)
        if item is not exhausted:
            return item
        # current pass is finished: start over on a brand new iterator.
        # If that one is empty as well, StopIteration propagates to the caller.
        self._iter = iter(self._ds)
        return next(self._iter)
class Annealer:
    """Step-based schedule for the weight applied to the unlabelled loss.

    The weight is ``0`` while the internal step counter is below ``T1``,
    ``alpha`` once the counter has reached ``T2``, and is linearly
    interpolated between the two in the range ``T1 <= step < T2``.

    Args:
        step (int, optional): initial value of the step counter.
        T1 (int, optional): step at which the weight starts to increase.
        T2 (int, optional): step from which the weight equals `alpha`.
        alpha (float, optional): maximum value of the weight.
        step_interval (int, optional): when attached to an engine, the step
            counter is incremented once every `step_interval` completed
            iterations.
    """

    def __init__(
        self,
        step: Optional[int] = 0,
        T1: Optional[int] = 100,
        T2: Optional[int] = 700,
        alpha: Optional[float] = 3.0,
        step_interval: Optional[int] = 50,
    ):
        self._step = step
        self._T1 = T1
        self._T2 = T2
        self._alpha = alpha
        self._step_interval = step_interval

    def step(self, _):
        """Advance the step counter by one; the argument is ignored."""
        self._step += 1

    @property
    def weight(self):
        """Current weight according to the step counter."""
        if self._step < self._T1:
            return 0
        # `>=` rather than `>`: the interpolation below already evaluates to
        # `alpha` at `step == T2`, and handling that point here avoids a
        # division by zero when `T1 == T2`.
        if self._step >= self._T2:
            return self._alpha
        progress = (self._step - self._T1) / (self._T2 - self._T1)
        return progress * self._alpha

    def attach(self, engine: "Engine"):
        """Register :meth:`step` on `engine`, fired every `step_interval`
        completed iterations."""
        engine.add_event_handler(
            Events.ITERATION_COMPLETED(every=self._step_interval), self.step
        )
def soft_nll_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    r"""
    Soft negative log-likelihood loss between two sets of log-softmax scores.

    The per-class terms come from
    ``cross_entropy(target, preds, mode="logsoftmax")``; they are summed over
    dimension 1 and the result is averaged.

    Args:
        preds (torch.Tensor): predictions. This is expected to be log-softmax scores.
        target (torch.Tensor): target. This is expected to be log-softmax scores.

    Returns:
        torch.Tensor: a singleton tensor with the loss value
    """
    elementwise = cross_entropy(target, preds, mode="logsoftmax")
    per_sample = elementwise.sum(dim=1)
    loss = per_sample.mean()
    # a NaN/inf loss would silently poison training, so fail loudly instead
    assert torch.isfinite(loss)
    return loss
def soft_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    r"""
    Soft cross entropy loss between two sets of logits.

    The per-class terms come from
    ``cross_entropy(target, logits, mode="logits")``; they are summed over
    dimension 1 and the result is averaged.

    Args:
        logits (torch.Tensor): predictions. This is expected to be logits.
        target (torch.Tensor): target. This is expected to be logits.

    Returns:
        torch.Tensor: a singleton tensor with the loss value
    """
    elementwise = cross_entropy(target, logits, mode="logits")
    per_sample = elementwise.sum(dim=1)
    loss = per_sample.mean()
    # a NaN/inf loss would silently poison training, so fail loudly instead
    assert torch.isfinite(loss)
    return loss
def create_semisupervised_trainer(
    model: nn.Module,
    optimiser,
    lloss_fn: _Loss_fn,
    uloss_fn: _Loss_fn,
    annealer: Annealer,
    train_iterable: WraparoundLoader,
    pl_saver: Optional[PLPredictionSaver] = None,
    use_soft_labels: bool = False,
    device: _DeviceType = None,
):
    """
    Build an ignite engine that performs one pseudo-label update per batch.

    The engine is meant to be run over the pool (unlabelled) data. For every
    pool batch it:

    1. predicts on the batch in eval mode without gradients to obtain
       pseudo-labels (raw outputs if `use_soft_labels`, otherwise the argmax
       over dimension 1),
    2. computes `uloss_fn` on a train-mode forward pass of the same batch
       against those pseudo-labels,
    3. draws the next labelled batch from `train_iterable` and computes
       `lloss_fn` on it,
    4. optimises ``l_loss + annealer.weight * u_loss``.

    Each iteration outputs ``(loss value, raw eval-mode predictions, targets)``
    where targets is `None` if the pool batch did not carry any.

    Args:
        model (torch.nn.Module): model to train.
        optimiser: optimiser instance used to update `model`.
        lloss_fn (Callable): loss for labelled batches.
        uloss_fn (Callable): loss for pseudo-labelled batches.
        annealer (Annealer): provides the weight of the unlabelled loss; it is
            attached to the returned engine.
        train_iterable (WraparoundLoader): source of labelled batches.
        pl_saver (PLPredictionSaver, optional): if given, it is attached to the
            engine and receives ``(raw_preds, targets)`` from each output.
        use_soft_labels (bool): whether `uloss_fn` takes soft targets.
        device: device the inputs are moved to.

    Returns:
        Engine: the configured engine.
    """

    def _step(_, batch):
        if isinstance(batch, (list, tuple)):
            # targets are only passed through to the output, so they are
            # deliberately left on their original device
            pool_x, pool_targets = batch
            pool_x = pool_x.to(device)
        else:
            pool_x, pool_targets = batch.to(device), None

        # pseudo-labels come from an eval-mode, gradient-free forward pass
        with torch.no_grad():
            model.eval()
            raw_preds = model(pool_x)
            # soft labels: hand the raw outputs to uloss_fn;
            # hard labels: uloss_fn receives class indices
            pseudo_labels = (
                raw_preds if use_soft_labels else torch.argmax(raw_preds, dim=1)
            )

        # unlabelled loss: train-mode forward pass against the pseudo-labels
        model.train()
        pool_out = model(pool_x)
        u_loss = uloss_fn(pool_out, pseudo_labels)

        # labelled loss on the next batch of training data
        model.train()
        train_x, train_y = _map_device(next(train_iterable), device)
        train_out = model(train_x)
        l_loss = lloss_fn(train_out, train_y)

        total = l_loss + annealer.weight * u_loss
        assert torch.isfinite(total)

        optimiser.zero_grad()
        total.backward()
        optimiser.step()
        return total.item(), raw_preds, pool_targets

    def _preds_and_targets(output):
        # (loss, raw_preds, targets) -> (raw_preds, targets)
        return output[1], output[2]

    engine = Engine(_step)
    annealer.attach(engine)
    if pl_saver is not None:
        pl_saver.attach(engine, output_transform=_preds_and_targets)
    return engine
class VanillaPLTrainer:
    def __init__(
        self,
        model: nn.Module,
        labelled_loss: _Loss_fn,
        unlabelled_loss: _Loss_fn,
        optimiser: str,
        use_soft_labels: Optional[bool] = False,
        patience: Optional[int] = None,
        reload_best: Optional[bool] = False,
        track_pl_metrics: Optional[str] = None,
        T1: Optional[int] = 0,
        T2: Optional[int] = 40,
        step_interval: Optional[int] = 50,
        device: _DeviceType = None,
        *args,
        **kwargs,
    ):
        r"""
        A vanilla pseudo-label training object.

        Training happens in two stages (see :meth:`fit`): a purely supervised
        stage on the labelled data, followed by a pseudo-label stage that
        iterates over the pool data.

        Args:
            model (torch.nn.Module): module object
            labelled_loss (Callable): a function `fn` that takes `preds` and
                `targets` and returns a singleton tensor with the loss value:
                `loss = fn(preds, targets)`. E.g. F.nll_loss.
            unlabelled_loss (Callable): similar to `labelled_loss`, but used on
                the pool dataset instead.
            optimiser (str): name of the optimiser to use. Must be an attribute
                of `torch.optim` (case sensitive). E.g. 'Adam'.
            use_soft_labels (bool, optional): if `True`, `unlabelled_loss` is
                presumed to compute the loss of soft-labels. Examples of such
                loss functions are
                :func:`~alr.training.pseudo_label_trainer.soft_nll_loss` or
                :func:`~alr.training.pseudo_label_trainer.soft_cross_entropy`.
            patience (int, optional): if not `None`, validation accuracy is
                used to determine when to stop.
            reload_best (bool, optional): reloads the best model according to
                validation accuracy at the end of training. `patience` must be
                non-`None` if this is `True`.
            track_pl_metrics (str, optional): if a string is provided, the
                training procedure saves raw predictions and targets on the
                pool dataset at the *end* of each epoch in stage 1 and *during*
                training in stage 2 in this directory. In other words, this
                lets you track the quality of pseudo-labels while stage 2 is
                training.
            T1 (int, optional): annealer step at which the weight coefficient
                starts kicking in. 0 implies it takes effect immediately; this
                is probably what you want since the model is already
                warm-started by stage 1.
            T2 (int, optional): annealer step at which the weight coefficient
                plateaus. For example, if `step_interval` is 50 and there are
                200 iterations per epoch, there are 4 steps per epoch. If T2 is
                40, then from the 10th epoch onwards the coefficient is maxed
                out and plateaus at 3.
            step_interval (int, optional): how often the annealer increments a
                step (counted in iterations, *not* epochs); this value is
                related to `T1` and `T2`.
            device (str, None, torch.device): device type.
            *args (Any, optional): arguments to be passed into the optimiser.
            **kwargs (Any, optional): keyword arguments to be passed into the optimiser.
        """
        # sanity checks on the configuration
        assert hasattr(torch.optim, optimiser)
        assert patience is None or patience > 0
        assert not reload_best or patience is not None

        # model, losses and optimiser name
        self._model = model
        self._lloss = labelled_loss
        self._uloss = unlabelled_loss
        self._optim = optimiser

        # hyper-parameters
        self._use_soft_labels = use_soft_labels
        self._patience = patience
        self._reload_best = reload_best
        self._track_pl_metrics = track_pl_metrics
        self._T1 = T1
        self._T2 = T2
        self._step_interval = step_interval
        self._device = device

        # forwarded verbatim to the optimiser constructor
        self._args = args
        self._kwargs = kwargs

    def _new_metric_evaluator(self):
        """Create an evaluator reporting accuracy and labelled loss."""
        return create_supervised_evaluator(
            self._model,
            metrics={"acc": Accuracy(), "loss": Loss(self._lloss)},
            device=self._device,
        )

    def fit(
        self,
        train_loader: torchdata.DataLoader,
        pool_loader: torchdata.DataLoader,
        val_loader: Optional[torchdata.DataLoader] = None,
        epochs: Union[int, Sequence[int]] = 1,
    ) -> Dict[str, Dict[str, list]]:
        """
        Run supervised training (stage 1) followed by pseudo-label training
        (stage 2).

        Args:
            train_loader (DataLoader): labelled training data.
            pool_loader (DataLoader): pool data iterated over in stage 2.
            val_loader (DataLoader, optional): validation data; required when
                `patience` was specified.
            epochs (int or pair of ints): number of epochs for stage 1 and
                stage 2. A single int is used for both stages.

        Returns:
            dict: ``{"stage1": ..., "stage2": ...}`` where each value maps a
            metric name to a numpy array of its recorded values.

        Raises:
            ValueError: if `track_pl_metrics` was given but the pool dataset is
                not an `UnlabelledDataset` with debug on, or if `patience` was
                specified without a `val_loader`.
        """
        tracking = self._track_pl_metrics is not None

        if tracking:
            pool_ds = pool_loader.dataset
            if not (isinstance(pool_ds, UnlabelledDataset) and pool_ds.debug):
                raise ValueError(
                    "If track_pl_metrics is provided, then the dataset in pool_loader "
                    "must be of the type UnlabelledDataset with debug on."
                )
        if self._patience and val_loader is None:
            raise ValueError(
                "If patience is specified, then val_loader must be provided in .fit()."
            )

        if isinstance(epochs, int):
            epochs = (epochs, epochs)
        assert len(epochs) == 2
        stage1_epochs, stage2_epochs = epochs

        # ---- stage 1: supervised training on the labelled data ----
        callbacks = None
        if tracking:
            pool_evaluator = create_supervised_evaluator(
                self._model, metrics=None, device=self._device
            )
            stage1_saver = PLPredictionSaver(
                log_dir=(self._track_pl_metrics + "/stage1")
            )
            stage1_saver.attach(pool_evaluator)

            def _save_pl_metrics(e: Engine):
                # the step must come from the engine that trains the model,
                # not from the pool evaluator created above
                stage1_saver.global_step_from_engine(e)
                pool_evaluator.run(pool_loader)

            callbacks = [_save_pl_metrics]

        supervised_trainer = Trainer(
            self._model,
            self._lloss,
            self._optim,
            *self._args,
            patience=self._patience,
            reload_best=self._reload_best,
            device=self._device,
            **self._kwargs,
        )
        supervised_history = supervised_trainer.fit(
            train_loader,
            val_loader,
            epochs=stage1_epochs,
            callbacks=callbacks,
        )

        # ---- stage 2: pseudo-label training over the pool ----
        pl_history = defaultdict(list)
        pbar = ProgressBar(desc=lambda _: "Vanilla PL trainer")
        train_evaluator = self._new_metric_evaluator()
        val_evaluator = self._new_metric_evaluator()

        def _log_metrics(engine: Engine):
            # `engine` is the stage 2 trainer this handler is registered on
            train_metrics = train_evaluator.run(train_loader).metrics
            pl_history["train_acc"].append(train_metrics["acc"])
            pl_history["train_loss"].append(train_metrics["loss"])
            pbar.log_message(
                f"epoch {engine.state.epoch}/{engine.state.max_epochs}\n"
                f"\ttrain acc = {train_metrics['acc']}, train loss = {train_metrics['loss']}"
            )
            if val_loader is not None:
                val_metrics = val_evaluator.run(val_loader).metrics
                pl_history["val_acc"].append(val_metrics["acc"])
                pl_history["val_loss"].append(val_metrics["loss"])
                pbar.log_message(
                    f"\tval acc = {val_metrics['acc']}, val loss = {val_metrics['loss']}"
                )

        # a fresh optimiser instance is built for stage 2
        optimiser_cls = getattr(torch.optim, self._optim)
        stage2_optimiser = optimiser_cls(
            self._model.parameters(), *self._args, **self._kwargs
        )
        annealer = Annealer(
            step=1, T1=self._T1, T2=self._T2, step_interval=self._step_interval
        )
        labelled_batches = WraparoundLoader(train_loader)
        stage2_saver = None
        if tracking:
            stage2_saver = PLPredictionSaver(self._track_pl_metrics + "/stage2")

        ssl_trainer = create_semisupervised_trainer(
            model=self._model,
            optimiser=stage2_optimiser,
            lloss_fn=self._lloss,
            uloss_fn=self._uloss,
            annealer=annealer,
            train_iterable=labelled_batches,
            pl_saver=stage2_saver,
            use_soft_labels=self._use_soft_labels,
            device=self._device,
        )

        early_stopper = None
        if val_loader is not None and self._patience:
            early_stopper = EarlyStopper(
                self._model,
                patience=self._patience,
                trainer=ssl_trainer,
                key="acc",
                mode="max",
            )
            # the stopper watches the validation evaluator run by _log_metrics
            early_stopper.attach(val_evaluator)

        ssl_trainer.add_event_handler(Events.EPOCH_COMPLETED, _log_metrics)
        pbar.attach(ssl_trainer)
        ssl_trainer.run(
            pool_loader,
            max_epochs=stage2_epochs,
        )

        if early_stopper is not None and self._reload_best:
            early_stopper.reload_best()

        stage1_result = {k: np.array(v) for k, v in supervised_history.items()}
        stage2_result = {k: np.array(v) for k, v in pl_history.items()}
        return {"stage1": stage1_result, "stage2": stage2_result}

    def evaluate(self, data_loader: torchdata.DataLoader) -> dict:
        """
        Evaluate the model on `data_loader`.

        Args:
            data_loader (DataLoader): data to evaluate on.

        Returns:
            dict: the evaluator's metrics, with keys ``"acc"`` and ``"loss"``.
        """
        evaluator = self._new_metric_evaluator()
        final_state = evaluator.run(data_loader)
        return final_state.metrics