alr.training.supervised_trainer
Classes
Trainer
class alr.training.supervised_trainer.Trainer(model: torch.nn.modules.module.Module, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optimiser: str, patience: Optional[int] = None, reload_best: Optional[bool] = False, lr_scheduler: Optional[str] = None, lr_scheduler_kwargs: Optional[dict] = None, device: Union[str, torch.device, None] = None, *args, **kwargs)
Bases: object
Parameters:
- model (torch.nn.Module) – the PyTorch module to train.
- loss (Callable) – a function fn that takes predictions and targets and returns a single-element tensor holding the loss value, i.e. loss = fn(preds, targets); e.g. F.nll_loss.
- optimiser (str) – name of the optimiser class to use. It must exactly match (case-sensitive) an optimiser in torch.optim, e.g. ‘Adam’.
- patience (int, optional) – if not None, validation accuracy is monitored and used to decide when to stop training early.
- reload_best (bool, optional) – if True, the model with the best validation accuracy is reloaded at the end of training. Requires patience to be set (not None).
- lr_scheduler (str, optional) – name of a learning rate scheduler class in torch.optim.lr_scheduler, e.g. ‘StepLR’.
- lr_scheduler_kwargs (dict, optional) – keyword arguments passed to the constructor of lr_scheduler.
- device (str, torch.device, optional) – device to train on, e.g. ‘cpu’ or ‘cuda’.
- *args (Any, optional) – additional positional arguments passed to the optimiser constructor.
- **kwargs (Any, optional) – additional keyword arguments passed to the optimiser constructor, e.g. lr=1e-3.
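Because optimiser and lr_scheduler are given as strings, they have to be resolved to classes by name. The sketch below shows one way to do that with getattr on torch.optim and torch.optim.lr_scheduler. It is an illustration under assumptions, not alr's implementation: build_optimiser is a hypothetical helper, and the way it forwards *args, **kwargs and lr_scheduler_kwargs is inferred from the parameter descriptions above.

```python
import torch
import torch.nn as nn
from typing import Any, Optional, Tuple


def build_optimiser(
    model: nn.Module,
    optimiser: str,
    *args: Any,
    lr_scheduler: Optional[str] = None,
    lr_scheduler_kwargs: Optional[dict] = None,
    **kwargs: Any,
) -> Tuple[torch.optim.Optimizer, Optional[object]]:
    # Hypothetical helper (not part of alr): resolve the optimiser by its
    # exact, case-sensitive class name in torch.optim, e.g. "Adam".
    opt_cls = getattr(torch.optim, optimiser, None)
    # Reject names that are missing or do not refer to an Optimizer subclass,
    # so a wrongly cased name such as "adam" fails with a clear error.
    if not (isinstance(opt_cls, type) and issubclass(opt_cls, torch.optim.Optimizer)):
        raise ValueError(f"{optimiser!r} is not an optimiser class in torch.optim")
    # Remaining positional and keyword arguments go to the optimiser constructor.
    opt = opt_cls(model.parameters(), *args, **kwargs)

    sched = None
    if lr_scheduler is not None:
        sched_cls = getattr(torch.optim.lr_scheduler, lr_scheduler, None)
        if not isinstance(sched_cls, type):
            raise ValueError(f"{lr_scheduler!r} is not a class in torch.optim.lr_scheduler")
        # The scheduler wraps the optimiser; lr_scheduler_kwargs supplies the rest.
        sched = sched_cls(opt, **(lr_scheduler_kwargs or {}))
    return opt, sched


model = nn.Linear(4, 2)
opt, sched = build_optimiser(
    model, "SGD", lr=0.1, momentum=0.9,
    lr_scheduler="StepLR", lr_scheduler_kwargs={"step_size": 10},
)
```

In this sketch the scheduler options are keyword-only so that extra positional arguments can only reach the optimiser; in the Trainer signature above, positional arguments after optimiser would instead bind to patience, reload_best and so on before reaching *args.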
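The patience and reload_best options describe validation-based early stopping. The sketch below simulates that behaviour on a precomputed list of per-epoch validation accuracies. The function simulate_early_stopping is hypothetical, and its stopping rule (stop once validation accuracy has failed to strictly improve for patience consecutive epochs) is an assumption about how Trainer uses patience, not something confirmed by its source.

```python
from typing import List, Optional, Tuple


def simulate_early_stopping(
    val_accs: List[float],
    patience: Optional[int] = None,
    reload_best: bool = False,
) -> Tuple[int, int]:
    """Return (epochs actually run, epoch whose weights the model ends with)."""
    # Mirrors the documented constraint: reload_best needs patience to be set.
    if reload_best and patience is None:
        raise ValueError("reload_best=True requires patience to be set")

    best_acc = float("-inf")
    best_epoch = -1
    bad_epochs = 0  # consecutive epochs without a strict improvement (assumed rule)
    epochs_run = len(val_accs)
    for epoch, acc in enumerate(val_accs):
        if acc > best_acc:
            # A real trainer would snapshot the model weights here.
            best_acc, best_epoch = acc, epoch
            bad_epochs = 0
        else:
            bad_epochs += 1
            if patience is not None and bad_epochs >= patience:
                epochs_run = epoch + 1
                break

    # With reload_best the best snapshot is restored; otherwise the last weights are kept.
    final_epoch = best_epoch if reload_best else epochs_run - 1
    return epochs_run, final_epoch


accs = [0.50, 0.60, 0.58, 0.59, 0.61, 0.60, 0.60]
print(simulate_early_stopping(accs, patience=2))                    # (4, 3)
print(simulate_early_stopping(accs, patience=2, reload_best=True))  # (4, 1)
```

With patience=2, training stops after epoch index 3 because epochs 2 and 3 do not beat the 0.60 reached at epoch 1; reload_best then restores the epoch-1 weights instead of keeping the final ones.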