alr.training.supervised_trainer

Classes

Trainer

class alr.training.supervised_trainer.Trainer(model: torch.nn.modules.module.Module, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optimiser: str, patience: Optional[int] = None, reload_best: Optional[bool] = False, lr_scheduler: Optional[str] = None, lr_scheduler_kwargs: Optional[dict] = None, device: Union[str, torch.device, None] = None, *args, **kwargs)

Bases: object

Parameters:
  • model (torch.nn.Module) – the module to train.
  • loss (Callable) – a function fn that takes predictions and targets and returns a singleton tensor holding the loss value: loss = fn(preds, targets). E.g. F.nll_loss.
  • optimiser (str) – the name of the optimiser to use. Must match an optimiser in torch.optim (case sensitive). E.g. 'Adam'.
  • patience (int, optional) – if not None, validation accuracy is used to determine when to stop training.
  • reload_best (bool, optional) – if True, reloads the best model according to validation accuracy at the end of training. Requires patience to be non-None.
  • lr_scheduler (str, optional) – the name of a learning rate scheduler in torch.optim.lr_scheduler.
  • lr_scheduler_kwargs (dict, optional) – keyword arguments passed to the constructor of lr_scheduler.
  • device (str, torch.device, optional) – the device to use, e.g. 'cpu'.
  • *args (Any, optional) – positional arguments passed to the optimiser.
  • **kwargs (Any, optional) – keyword arguments passed to the optimiser.
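
Because optimiser and lr_scheduler are given as strings, they presumably have to be resolved to classes in torch.optim and torch.optim.lr_scheduler. The sketch below shows one way such a lookup could work. The helper build_optimiser is hypothetical and is not part of alr; how Trainer actually performs the lookup is not stated here.

```python
import torch
from torch import nn


def build_optimiser(model, optimiser, *args,
                    lr_scheduler=None, lr_scheduler_kwargs=None, **kwargs):
    """Hypothetical helper: resolve string names to torch.optim objects."""
    # Look the optimiser class up by its exact name, e.g. "Adam".
    optim_cls = getattr(torch.optim, optimiser)
    # Extra positional and keyword arguments go to the optimiser constructor.
    optim = optim_cls(model.parameters(), *args, **kwargs)

    scheduler = None
    if lr_scheduler is not None:
        # Same lookup, this time in torch.optim.lr_scheduler.
        sched_cls = getattr(torch.optim.lr_scheduler, lr_scheduler)
        scheduler = sched_cls(optim, **(lr_scheduler_kwargs or {}))
    return optim, scheduler


model = nn.Linear(4, 2)
optim, scheduler = build_optimiser(
    model,
    "Adam",
    lr=1e-3,  # forwarded to torch.optim.Adam
    lr_scheduler="StepLR",
    lr_scheduler_kwargs={"step_size": 10, "gamma": 0.5},
)
print(type(optim).__name__, type(scheduler).__name__)
```

Under this reading, any keyword accepted by the chosen optimiser (lr, weight_decay, and so on) can be supplied through **kwargs, while scheduler options travel separately in lr_scheduler_kwargs.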
evaluate(data_loader: torch.utils.data.dataloader.DataLoader) → dict
fit(train_loader: torch.utils.data.dataloader.DataLoader, val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, epochs: Optional[int] = 1, callbacks: Optional[Sequence[Callable]] = None) → Dict[str, list]
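
The interplay between patience, reload_best and the validation data passed to fit can be pictured with a small simulation. This is a sketch of a common early-stopping scheme, assuming patience counts consecutive epochs without an improvement in validation accuracy; the exact rule used by Trainer is not documented here, and train_with_patience is a hypothetical stand-in, not an alr function.

```python
import copy


def train_with_patience(val_accuracies, patience=None, reload_best=False):
    """Simulate early stopping over per-epoch validation accuracies.

    Returns (epochs_run, best_epoch, final_state).
    """
    best_acc = float("-inf")
    best_epoch = None
    best_state = None
    bad_epochs = 0  # consecutive epochs without improvement (assumed rule)
    state = None
    epochs_run = 0

    for epoch, acc in enumerate(val_accuracies, start=1):
        state = {"epoch": epoch}  # stand-in for the model weights
        epochs_run = epoch
        if acc > best_acc:
            # New best: remember it and reset the counter.
            best_acc, best_epoch = acc, epoch
            best_state = copy.deepcopy(state)
            bad_epochs = 0
        else:
            bad_epochs += 1
            if patience is not None and bad_epochs >= patience:
                break  # stop early

    # The documentation says reload_best needs a non-None patience;
    # this sketch simply restores the best snapshot when one exists.
    if reload_best and best_state is not None:
        state = best_state
    return epochs_run, best_epoch, state


epochs_run, best_epoch, state = train_with_patience(
    [0.70, 0.80, 0.79, 0.78, 0.90], patience=2, reload_best=True
)
print(epochs_run, best_epoch, state)
```

In this run training stops after the fourth epoch, because epochs three and four both fail to beat the accuracy from epoch two, and the state from epoch two is restored. The fifth value is never reached, which illustrates the usual trade-off of a small patience.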

Functions