alr.training.pseudo_label_trainer

Classes

WraparoundLoader

class alr.training.pseudo_label_trainer.WraparoundLoader(ds: torch.utils.data.dataloader.DataLoader)

Bases: object

Annealer

class alr.training.pseudo_label_trainer.Annealer(step: Optional[int] = 0, T1: Optional[int] = 100, T2: Optional[int] = 700, alpha: Optional[float] = 3.0, step_interval: Optional[int] = 50)

Bases: object

attach(engine: ignite.engine.engine.Engine)
step(_)
weight

VanillaPLTrainer

class alr.training.pseudo_label_trainer.VanillaPLTrainer(model: torch.nn.modules.module.Module, labelled_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], unlabelled_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optimiser: str, use_soft_labels: Optional[bool] = False, patience: Optional[int] = None, reload_best: Optional[bool] = False, track_pl_metrics: Optional[str] = None, T1: Optional[int] = 0, T2: Optional[int] = 40, step_interval: Optional[int] = 50, device: Union[str, torch.device, None] = None, *args, **kwargs)

Bases: object

A vanilla pseudo-label training object.

Parameters:
  • model (torch.nn.Module) – module object
  • labelled_loss (Callable) – should be a function fn that takes preds and targets and returns a singleton tensor with the loss value: loss = fn(preds, targets). E.g. F.nll_loss.
  • unlabelled_loss (Callable) – similar to labelled_loss, but will be used on the pool dataset instead.
  • optimiser (str) – a string that corresponds to the type of optimiser to use. Must be an optimiser from torch.optim (case sensitive). E.g. ‘Adam’.
  • use_soft_labels (bool, optional) – if True, then unlabelled_loss is presumed to be a function that calculates the loss of soft-labels instead. Examples of such loss functions are soft_nll_loss() or soft_cross_entropy().
  • patience (int, optional) – if not None, then validation accuracy will be used to determine when to stop.
  • reload_best (bool, optional) – if True, reloads the best model according to validation accuracy at the end of training. patience must be non-None when this is set to True.
  • track_pl_metrics (str, optional) – if a string is provided, it is treated as a directory: the training procedure saves raw predictions and targets on the pool dataset there at the end of each epoch in stage 1 and throughout training in stage 2. In other words, this lets you track the quality of pseudo-labels as the model trains in stage 2.
  • T1 (int, optional) – the step at which the weight coefficient starts to take effect. 0 means it takes effect immediately; this is probably what you want, since the model is already warm-started.
  • T2 (int, optional) – when the weight coefficient starts plateauing. For example, if step_interval is 50 and the number of iterations per epoch is 200, then there will be a total of 4 steps per epoch. If T2 is 40, then on the 10th epoch onwards, the coefficient is maxed out and plateaus at 3.
  • step_interval (int, optional) – how often the annealer increments a step (counted in number of iterations, not epochs); this value is related to T1 and T2.
  • device (str, None, torch.device) – device type.
  • *args (Any, optional) – arguments to be passed into the optimiser.
  • **kwargs (Any, optional) – keyword arguments to be passed into the optimiser.
evaluate(data_loader: torch.utils.data.dataloader.DataLoader) → dict
fit(train_loader: torch.utils.data.dataloader.DataLoader, pool_loader: torch.utils.data.dataloader.DataLoader, val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, epochs: Union[int, Sequence[int]] = 1) → Dict[str, Dict[str, list]]
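
The interplay between T1, T2 and step_interval described in the parameter list can be checked with a small sketch. The arithmetic follows the T2 example directly; the linear ramp between T1 and T2 is an assumption borrowed from the usual pseudo-label weighting schedule, since this page does not state the exact formula Annealer uses. The names annealed_weight and annealer_step are hypothetical helpers, not part of the library.

```python
def annealer_step(iteration: int, step_interval: int = 50) -> int:
    """Annealer steps completed after `iteration` training iterations."""
    return iteration // step_interval


def annealed_weight(step: int, T1: int = 0, T2: int = 40, alpha: float = 3.0) -> float:
    """Assumed schedule: 0 before T1, linear ramp up to alpha at T2, flat afterwards."""
    if step < T1:
        return 0.0
    if step >= T2:
        return alpha
    return alpha * (step - T1) / (T2 - T1)


# Worked numbers from the T2 description: step_interval = 50, 200 iterations per epoch.
iters_per_epoch = 200
steps_per_epoch = iters_per_epoch // 50        # 4 annealer steps per epoch
epochs_to_plateau = 40 // steps_per_epoch      # T2 = 40 is reached after 10 epochs

# Weight after 5 epochs (step 20) and after 10 epochs (step 40), under the assumed ramp.
halfway = annealed_weight(annealer_step(5 * iters_per_epoch))
plateau = annealed_weight(annealer_step(10 * iters_per_epoch))
```

Under these assumptions the coefficient is halfway to its maximum after five epochs and stays at alpha once ten epochs' worth of iterations have run.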

Functions

soft_nll_loss

alr.training.pseudo_label_trainer.soft_nll_loss(preds: torch.Tensor, target: torch.Tensor) → torch.Tensor

Calculates the soft negative log-likelihood loss

Parameters:
  • preds (torch.Tensor) – predictions. This is expected to be log-softmax scores.
  • target (torch.Tensor) – target. This is expected to be log-softmax scores.
Returns:

a singleton tensor with the loss value

Return type:

torch.Tensor

soft_cross_entropy

alr.training.pseudo_label_trainer.soft_cross_entropy(logits: torch.Tensor, target: torch.Tensor) → torch.Tensor

Calculates the soft cross entropy loss. This combines log-softmax with soft_nll_loss.

Parameters:
  • logits (torch.Tensor) – predictions. This is expected to be logits.
  • target (torch.Tensor) – target. This is expected to be logits.
Returns:

a singleton tensor with the loss value

Return type:

torch.Tensor
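
As a rough illustration of how the two loss functions relate, the pure-Python sketch below computes a soft negative log-likelihood for a single example and then builds soft cross entropy by applying log-softmax first. It assumes the soft target arrives as log-probabilities that are exponentiated before weighting the predictions, and that soft cross entropy applies log-softmax to both arguments; neither detail is confirmed by this page. It works on plain lists rather than tensors, the helper names are hypothetical, and the reduction over a batch is not shown.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one example."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def soft_nll(log_preds: List[float], log_target: List[float]) -> float:
    """Assumed form: -sum_k exp(log_target[k]) * log_preds[k]."""
    return -sum(math.exp(t) * p for p, t in zip(log_preds, log_target))


def soft_ce(logits: List[float], target_logits: List[float]) -> float:
    """Log-softmax on both inputs, followed by the soft NLL above."""
    return soft_nll(log_softmax(logits), log_softmax(target_logits))


# Uniform predictions cost log(K) regardless of the soft target.
uniform_loss = soft_ce([0.0, 0.0], [1.0, 2.0])

# Predictions that agree with the soft target cost less than ones that disagree.
agree = soft_ce([2.0, 0.0], [2.0, 0.0])
disagree = soft_ce([0.0, 2.0], [2.0, 0.0])
```

With a one-hot target this reduces to the ordinary negative log-likelihood of the labelled class, which is why the same trainer can switch between hard and soft pseudo-labels by swapping the loss function.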

create_semisupervised_trainer

alr.training.pseudo_label_trainer.create_semisupervised_trainer(model: torch.nn.modules.module.Module, optimiser, lloss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], uloss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], annealer: alr.training.pseudo_label_trainer.Annealer, train_iterable: alr.training.pseudo_label_trainer.WraparoundLoader, pl_saver: Optional[alr.training.utils.PLPredictionSaver] = None, use_soft_labels: bool = False, device: Union[str, torch.device, None] = None)