Source code for alr.training.ephemeral_trainer

from collections import defaultdict

from alr import ALRModel
from alr.data import RelabelDataset, PseudoLabelDataset, UnlabelledDataset
from alr.training.utils import EarlyStopper, PLPredictionSaver
from alr.utils._type_aliases import _DeviceType, _Loss_fn
from typing import Optional, Callable, Union

import numpy as np
import torch
from torch import nn
import torch.utils.data as torchdata
from ignite.engine import create_supervised_evaluator, Events, Engine
from ignite.metrics import Accuracy, Loss
from alr.training.progress_bar.ignite_progress_bar import ProgressBar
from alr.training import Trainer
from alr.training.samplers import MinLabelledSampler, RandomFixedLengthSampler
import copy


[docs]class PseudoLabelManager: def __init__( self, pool: UnlabelledDataset, model: nn.Module, threshold: float, init_pseudo_labelled: Optional[torchdata.Dataset] = None, log_dir: Optional[str] = None, device: _DeviceType = None, **kwargs, ): bs = kwargs.pop("batch_size", 1024) shuffle = kwargs.pop("shuffle", False) assert not shuffle self._pool = pool self._loader = torchdata.DataLoader( pool, batch_size=bs, shuffle=shuffle, **kwargs ) self._model = model self._log_dir = log_dir self._device = device self._threshold = threshold self.acquired_sizes = [] # keep a copy of the latest pseudo-labelled dataset self._init_pld = init_pseudo_labelled self.history = [init_pseudo_labelled]
[docs] def attach(self, engine: Engine): engine.add_event_handler(Events.STARTED, self._initialise) # could also be EPOCH_COMPLETED since there's only one iteration in each epoch engine.add_event_handler(Events.ITERATION_COMPLETED, self._load_labels)
def _load_labels(self, engine: Engine): evaluator = create_supervised_evaluator( self._model, metrics=None, device=self._device ) plc = PseudoLabelCollector( self._threshold, log_dir=self._log_dir, ) plc.attach(evaluator, batch_size=self._loader.batch_size) plc.global_step_from_engine(engine) evaluator.run(self._loader) indices, pseudo_labels = ( evaluator.state.pl_indices.cpu().numpy(), evaluator.state.pl_plabs.cpu().numpy(), ) self.acquired_sizes.append(indices.shape[0]) if indices.shape[0]: confident_points = torchdata.Subset(copy.copy(self._pool), indices) if self._pool.debug: # pool returns target labels too pld = RelabelDataset(confident_points, pseudo_labels) else: pld = PseudoLabelDataset(confident_points, pseudo_labels) else: pld = None engine.state.pseudo_labelled_dataset = pld self.history.append(pld) def _initialise(self, engine: Engine): engine.state.pseudo_labelled_dataset = self._init_pld
[docs]class PseudoLabelCollector: def __init__( self, threshold: float, log_dir: Optional[str] = None, pred_transform: Callable[[torch.Tensor], torch.Tensor] = lambda x: x.exp(), ): self._indices = [] self._plabs = [] self._pred_transform = pred_transform self._output_transform = lambda x: x self._thresh = threshold self._targets = [] self._preds = [] if log_dir: self._saver = PLPredictionSaver(log_dir, pred_transform=pred_transform) else: self._saver = None self._batch_size = None def _parse(self, engine: Engine): preds, targets = self._output_transform(engine.state.output) # state.iteration starts with 1 iteration = engine.state.iteration - 1 offset = iteration * self._batch_size with torch.no_grad(): preds = self._pred_transform(preds) preds_max, plabs = torch.max(preds, dim=-1) mask = torch.nonzero(preds_max >= self._thresh).flatten() if mask.shape[0]: # plabs = [N,] self._plabs.append(plabs[mask]) self._indices.append(mask + offset) def _flush(self, engine: Engine): if self._indices and self._plabs: engine.state.pl_indices = torch.cat(self._indices) engine.state.pl_plabs = torch.cat(self._plabs) else: engine.state.pl_indices = torch.Tensor([]) engine.state.pl_plabs = torch.Tensor([]) self._indices = [] self._plabs = []
[docs] def attach(self, engine: Engine, batch_size: int, output_transform=lambda x: x): r""" Args: engine (Engine): ignite engine object batch_size (int): engine's batch size output_transform (Callable): if engine.state.output is not (preds, target), then output_transform should return aforementioned tuple. Returns: NoneType: None """ engine.add_event_handler(Events.ITERATION_COMPLETED, self._parse) engine.add_event_handler(Events.COMPLETED, self._flush) self._output_transform = output_transform self._batch_size = batch_size if self._saver: self._saver.attach(engine, output_transform=output_transform)
[docs] def global_step_from_engine(self, engine: Engine): if self._saver: self._saver.global_step_from_engine(engine)
def _update_dataloader_batch_sampler( loader: torchdata.DataLoader, pseudo_labelled_dataset: torchdata.Dataset, min_labelled: Union[float, int], ): labelled_dataset = loader.dataset # attributes that usually go in dataloader's constructor attrs = [k for k in loader.__dict__.keys() if not k.startswith("_")] drop = [ "dataset", "sampler", "batch_sampler", "dataset_kind", "batch_size", "shuffle", "drop_last", ] kwargs = {k: getattr(loader, k) for k in attrs if k not in drop} kwargs["dataset"] = torchdata.ConcatDataset( (labelled_dataset, pseudo_labelled_dataset) ) kwargs["batch_sampler"] = MinLabelledSampler( labelled_dataset, pseudo_labelled_dataset, loader.batch_size, min_labelled=min_labelled, ) return torchdata.DataLoader(**kwargs) def _update_dataloader( loader: torchdata.DataLoader, dataset: torchdata.Dataset, sampler: Optional[torchdata.Sampler] = None, ): # attributes that usually go in dataloader's constructor attrs = [k for k in loader.__dict__.keys() if not k.startswith("_")] drop = ["dataset", "sampler", "batch_sampler", "dataset_kind"] kwargs = {k: getattr(loader, k) for k in attrs if k not in drop} if not isinstance( loader.sampler, ( torchdata.SequentialSampler, torchdata.RandomSampler, RandomFixedLengthSampler, ), ): raise ValueError( f"Only sequential, random, and random fixed length samplers " f"are supported in _update_dataloader" ) kwargs["dataset"] = dataset # Sequential and Random will be automatically determined if sampler is None (depending on shuffle) kwargs["sampler"] = sampler return torchdata.DataLoader(**kwargs)
[docs]def create_pseudo_label_trainer( model: ALRModel, loss: _Loss_fn, optimiser: str, train_loader: torchdata.DataLoader, val_loader: torchdata.DataLoader, pseudo_label_manager: PseudoLabelManager, rfls_len: Optional[int] = None, min_labelled: Optional[Union[float, int]] = None, patience: Optional[int] = None, reload_best: Optional[bool] = None, epochs: Optional[int] = 1, lr_scheduler: Optional[str] = None, lr_scheduler_kwargs: Optional[dict] = None, device: _DeviceType = None, *args, **kwargs, ): assert ( not min_labelled or not rfls_len ), "rfls_len and min_labelled are mutually exclusive" def _step(engine: Engine, _): model.reset_weights() # update loader accordingly: if pld is not none, concatenate them new_loader = train_loader pld = engine.state.pseudo_labelled_dataset if pld is not None: train_ds = torchdata.ConcatDataset((train_loader.dataset, pld)) # update dataloader's dataset attribute if rfls_len: new_loader = _update_dataloader( train_loader, train_ds, RandomFixedLengthSampler(train_ds, length=rfls_len, shuffle=True), ) elif min_labelled: new_loader = _update_dataloader_batch_sampler( train_loader, pld, min_labelled ) else: new_loader = _update_dataloader(train_loader, train_ds) # begin supervised training trainer = Trainer( model, loss, optimiser, patience, reload_best, lr_scheduler=lr_scheduler, lr_scheduler_kwargs=lr_scheduler_kwargs, device=device, *args, **kwargs, ) history = trainer.fit( new_loader, val_loader=val_loader, epochs=epochs, ) # if early stopping was applied w/ patience, then the actual train acc and loss should be # -patience from the final loss/acc UNLESS we reached the maximum number of epochs. if patience and len(history["train_loss"]) != epochs: return history["train_loss"][-patience], history["train_acc"][-patience] return history["train_loss"][-1], history["train_acc"][-1] e = Engine(_step) pseudo_label_manager.attach(e) return e
[docs]class EphemeralTrainer: def __init__( self, model: ALRModel, pool: UnlabelledDataset, loss: _Loss_fn, optimiser: str, threshold: float, min_labelled: Optional[Union[float, int]] = None, random_fixed_length_sampler_length: Optional[int] = None, log_dir: Optional[str] = None, patience: Optional[Union[int, tuple]] = None, reload_best: Optional[bool] = False, lr_scheduler: Optional[str] = None, lr_scheduler_kwargs: Optional[dict] = None, init_pseudo_label_dataset: Optional[torchdata.Dataset] = None, device: _DeviceType = None, pool_loader_kwargs: Optional[dict] = None, *args, **kwargs, ): assert ( not min_labelled or not random_fixed_length_sampler_length ), "random_fixed_length_sampler_length and min_labelled are mutually exclusive" self._pool = pool self._model = model self._loss = loss self._optimiser = optimiser self._patience = patience self._reload_best = reload_best self._device = device self._args = args self._kwargs = kwargs self._threshold = threshold self._log_dir = log_dir self._pool_loader_kwargs = ( pool_loader_kwargs if pool_loader_kwargs is not None else {} ) self._min_labelled = min_labelled self._rfls_len = random_fixed_length_sampler_length self._lr_scheduler = lr_scheduler self._lr_scheduler_kwargs = ( lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {} ) self._init_pseudo_label_dataset = init_pseudo_label_dataset self.last_pseudo_label_dataset = None
[docs] def fit( self, train_loader: torchdata.DataLoader, val_loader: Optional[torchdata.DataLoader] = None, iterations: Optional[int] = 1, epochs: Optional[int] = 1, ): if self._patience: if isinstance(self._patience, int): pat1 = pat2 = self._patience else: pat1, pat2 = self._patience else: pat1 = pat2 = None if self._patience and val_loader is None: raise ValueError( "If patience is specified, then val_loader must be provided in .fit()." ) val_evaluator = create_supervised_evaluator( self._model, metrics={"acc": Accuracy(), "loss": Loss(self._loss)}, device=self._device, ) history = defaultdict(list) pbar = ProgressBar(desc=lambda _: "Ephemeral") def _log_metrics(engine: Engine): # train_loss and train_acc are moving averages of the last epoch # in the supervised training loop train_loss, train_acc = engine.state.output history[f"train_loss"].append(train_loss) history[f"train_acc"].append(train_acc) pbar.log_message( f"Eph. iteration {engine.state.epoch}/{engine.state.max_epochs}\n" f"\ttrain acc = {train_acc}, train loss = {train_loss}" ) if val_loader is None: return # job done # val loader - save to history and print metrics. Also, add handlers to # evaluator (e.g. early stopping, model checkpointing that depend on val_acc) metrics = val_evaluator.run(val_loader).metrics history[f"val_acc"].append(metrics["acc"]) history[f"val_loss"].append(metrics["loss"]) pbar.log_message( f"\tval acc = {metrics['acc']}, val loss = {metrics['loss']}" ) pseudo_label_manager = PseudoLabelManager( pool=self._pool, model=self._model, threshold=self._threshold, init_pseudo_labelled=self._init_pseudo_label_dataset, log_dir=self._log_dir, device=self._device, **self._pool_loader_kwargs, ) trainer = create_pseudo_label_trainer( model=self._model, loss=self._loss, optimiser=self._optimiser, train_loader=train_loader, val_loader=val_loader, pseudo_label_manager=pseudo_label_manager, rfls_len=self._rfls_len, min_labelled=self._min_labelled, patience=pat1, reload_best=self._reload_best, epochs=epochs, lr_scheduler=self._lr_scheduler, lr_scheduler_kwargs=self._lr_scheduler_kwargs, device=self._device, *self._args, **self._kwargs, ) # output of trainer are running averages of train_loss and train_acc (from the # last epoch of the supervised trainer) pbar.attach(trainer) if val_loader is not None and self._patience: es = EarlyStopper(self._model, pat2, trainer, key="acc", mode="max") es.attach(val_evaluator) trainer.add_event_handler(Events.EPOCH_COMPLETED, _log_metrics) trainer.run( range(iterations), max_epochs=iterations, epoch_length=1, ) if val_loader is not None and self._patience and self._reload_best: es.reload_best() history["train_size"] = np.array(pseudo_label_manager.acquired_sizes) + len( train_loader.dataset ) if trainer.state.epoch == epochs: self.last_pseudo_label_dataset = pseudo_label_manager.history[-2] else: self.last_pseudo_label_dataset = pseudo_label_manager.history[-(pat2 + 2)] return history
[docs] def evaluate(self, data_loader: torchdata.DataLoader) -> dict: evaluator = create_supervised_evaluator( self._model, metrics={"acc": Accuracy(), "loss": Loss(self._loss)}, device=self._device, ) return evaluator.run(data_loader).metrics