alr.training.pseudo_label_trainer
Classes
WraparoundLoader
Annealer
VanillaPLTrainer
class alr.training.pseudo_label_trainer.VanillaPLTrainer(model: torch.nn.modules.module.Module, labelled_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], unlabelled_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optimiser: str, use_soft_labels: Optional[bool] = False, patience: Optional[int] = None, reload_best: Optional[bool] = False, track_pl_metrics: Optional[str] = None, T1: Optional[int] = 0, T2: Optional[int] = 40, step_interval: Optional[int] = 50, device: Union[str, torch.device, None] = None, *args, **kwargs)
Bases: object
A vanilla pseudo-label training object.
Parameters: - model (torch.nn.Module) – module object
- labelled_loss (Callable) – should be a function fn that takes preds and targets and returns a singleton tensor with the loss value: loss = fn(preds, targets). E.g. F.nll_loss.
- unlabelled_loss (Callable) – similar to labelled_loss, but will be used on the pool dataset instead.
- optimiser (str) – a string that corresponds to the type of optimiser to use. Must be an optimiser from torch.optim (case sensitive). E.g. ‘Adam’.
- use_soft_labels (bool, optional) – if True, then unlabelled_loss is presumed to be a function that calculates the loss on soft labels instead. Examples of such loss functions are soft_nll_loss() or soft_cross_entropy().
- patience (int, optional) – if not None, then validation accuracy will be used to determine when to stop.
- reload_best (bool, optional) – if True, reloads the best model according to validation accuracy at the end of training. patience must be non-None when this is set to True.
- track_pl_metrics (str, optional) – if a string is provided, it is treated as a directory: the training procedure saves raw predictions and targets on the pool dataset there, at the end of each epoch in stage 1 and during training in stage 2. In other words, this lets you track the quality of the pseudo-labels while stage 2 is training.
- T1 (int, optional) – the step at which the weight coefficient starts to take effect. 0 means it takes effect immediately; this is probably what you want, since the model is already warm-started.
- T2 (int, optional) – when the weight coefficient starts plateauing. For example, if step_interval is 50 and the number of iterations per epoch is 200, then there will be a total of 4 steps per epoch. If T2 is 40, then on the 10th epoch onwards, the coefficient is maxed out and plateaus at 3.
- step_interval (int, optional) – how often the annealer increments a step, counted in iterations rather than epochs. T1 and T2 are expressed in these steps.
- device (str, None, torch.device) – device type.
- *args (Any, optional) – arguments to be passed into the optimiser.
- **kwargs (Any, optional) – keyword arguments to be passed into the optimiser.
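The relationship between T1, T2, and step_interval can be sketched as follows. This is a minimal sketch, not the library's Annealer: it assumes a linear ramp (zero before T1, linear between T1 and T2, constant from T2 onwards) and a maximum coefficient of 3, which matches the plateau value quoted for T2 above. The names `steps_elapsed`, `annealed_weight`, and `max_weight` are hypothetical.

```python
def steps_elapsed(iteration: int, step_interval: int = 50) -> int:
    """Number of annealer steps taken after `iteration` training iterations."""
    return iteration // step_interval


def annealed_weight(step: int, T1: int = 0, T2: int = 40,
                    max_weight: float = 3.0) -> float:
    """Assumed schedule: 0 before T1, linear ramp on [T1, T2), max_weight from T2."""
    if step < T1:
        return 0.0
    if step >= T2:
        return max_weight
    return max_weight * (step - T1) / (T2 - T1)


# Worked example from the T2 description: 200 iterations per epoch and
# step_interval = 50 give 4 annealer steps per epoch, so T2 = 40 steps
# is reached after 10 epochs.
iters_per_epoch = 200
steps_per_epoch = steps_elapsed(iters_per_epoch, step_interval=50)
plateau_epoch = 40 // steps_per_epoch
```

Under these assumptions the coefficient is halfway to its maximum after 20 steps (5 epochs) and stays constant from step 40 onwards.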
Functions
soft_nll_loss
alr.training.pseudo_label_trainer.soft_nll_loss(preds: torch.Tensor, target: torch.Tensor) → torch.Tensor
Calculates the soft negative log-likelihood loss.
Parameters: - preds (torch.Tensor) – predictions. This is expected to be log-softmax scores.
- target (torch.Tensor) – target. This is expected to be log-softmax scores.
Returns: a singleton tensor with the loss value
Return type: torch.Tensor
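One common way to define a soft negative log-likelihood is the batch mean of the per-sample cross entropy between the soft-label distribution and the predicted distribution, with both arguments given as log-probabilities. The sketch below assumes that definition and a mean reduction; the documentation here states neither, so the library function may differ. It uses plain Python lists instead of tensors to stay dependency-free, and `log_softmax` and `soft_nll_sketch` are hypothetical helper names.

```python
import math
from typing import List, Sequence


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of one row of scores."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def soft_nll_sketch(preds: Sequence[Sequence[float]],
                    target: Sequence[Sequence[float]]) -> float:
    """Assumed definition: mean over samples of -sum_c exp(target_c) * preds_c.

    Both `preds` and `target` hold log-softmax scores, one row per sample.
    """
    per_sample = [
        -sum(math.exp(t) * p for p, t in zip(p_row, t_row))
        for p_row, t_row in zip(preds, target)
    ]
    return sum(per_sample) / len(per_sample)


# A uniform prediction scored against a uniform soft label over two classes.
uniform = [log_softmax([0.0, 0.0])]
loss = soft_nll_sketch(uniform, uniform)
```

When the prediction equals the soft label, this quantity reduces to the entropy of the soft label, which is why it is only zero for one-hot targets.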
soft_cross_entropy
alr.training.pseudo_label_trainer.soft_cross_entropy(logits: torch.Tensor, target: torch.Tensor) → torch.Tensor
Calculates the soft cross entropy loss. This combines log-softmax with soft_nll_loss.
Parameters: - logits (torch.Tensor) – predictions. This is expected to be logits.
- target (torch.Tensor) – target. This is expected to be logits.
Returns: a singleton tensor with the loss value
Return type: torch.Tensor
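The composition described above, log-softmax followed by a soft negative log-likelihood, can be sketched as below. Two assumptions are made and may not match the library: the soft NLL is taken to be the batch mean of the cross entropy between the two distributions, and, because target is documented as logits, the sketch normalises target with log-softmax as well. Plain Python lists stand in for tensors, and `log_softmax` and `soft_cross_entropy_sketch` are hypothetical names.

```python
import math
from typing import List, Sequence


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def soft_cross_entropy_sketch(logits: Sequence[Sequence[float]],
                              target: Sequence[Sequence[float]]) -> float:
    """Assumed behaviour: log-softmax both inputs, then average the soft NLL."""
    total = 0.0
    for logit_row, target_row in zip(logits, target):
        log_p = log_softmax(logit_row)   # predicted log-probabilities
        log_q = log_softmax(target_row)  # soft-label log-probabilities (assumption)
        total += -sum(math.exp(q) * p for p, q in zip(log_p, log_q))
    return total / len(logits)


# Uniform logits against uniform soft labels over two classes.
value = soft_cross_entropy_sketch([[0.0, 0.0]], [[0.0, 0.0]])
```

Because log-softmax is invariant to adding a constant to every logit in a row, shifting either input by a constant leaves the result of this sketch unchanged.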
create_semisupervised_trainer
alr.training.pseudo_label_trainer.create_semisupervised_trainer(model: torch.nn.modules.module.Module, optimiser, lloss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], uloss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], annealer: alr.training.pseudo_label_trainer.Annealer, train_iterable: alr.training.pseudo_label_trainer.WraparoundLoader, pl_saver: Optional[alr.training.utils.PLPredictionSaver] = None, use_soft_labels: bool = False, device: Union[str, torch.device, None] = None)