alr.training.utils
Classes
EarlyStopper
class alr.training.utils.EarlyStopper(model: torch.nn.modules.module.Module, patience: int, trainer: ignite.engine.engine.Engine, key: Optional[str] = 'acc', mode: Optional[str] = 'max')
Bases: object
attach(engine: ignite.engine.engine.Engine)
Attaches an early stopper to engine that terminates the provided trainer when the metric named by key does not improve for patience epochs.
Parameters: engine (ignite.engine.Engine) – expected to be a validation evaluator. The key metric is extracted from it and the best value is used.
Returns: None
Return type: NoneType
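The patience rule that attach describes can be sketched without torch or ignite. The class below, PatienceTracker, is hypothetical and is not alr's implementation. It assumes that mode='max' means higher values of the key metric are better ('min' the opposite) and that "improve" means strictly better than the best value seen so far. Each call to step stands in for one validation run; in the real class, a stop decision ends training by terminating trainer (ignite engines expose Engine.terminate() for this).

```python
from typing import Optional


class PatienceTracker:
    """Hypothetical stand-in for the patience logic; not alr's implementation."""

    def __init__(self, patience: int, mode: str = "max") -> None:
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.patience = patience
        self.mode = mode  # assumed: 'max' = higher is better, 'min' = lower is better
        self.best: Optional[float] = None
        self.bad_epochs = 0  # consecutive epochs without improvement

    def _improved(self, value: float) -> bool:
        # The first observation always counts as an improvement.
        if self.best is None:
            return True
        return value > self.best if self.mode == "max" else value < self.best

    def step(self, value: float) -> bool:
        """Record one validation result; return True if training should stop."""
        if self._improved(value):
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Simulated validation accuracies, one per epoch (made-up numbers).
history = [0.60, 0.65, 0.64, 0.66, 0.66, 0.65, 0.63]
tracker = PatienceTracker(patience=3, mode="max")
stopped_at = None
for epoch, acc in enumerate(history, start=1):
    if tracker.step(acc):
        stopped_at = epoch
        break
print(stopped_at, tracker.best)  # 7 0.66
```

Here the accuracy last improves at epoch 4 (a tie at epoch 5 does not count), so after three non-improving epochs the tracker signals a stop at epoch 7.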
PLPredictionSaver
class alr.training.utils.PLPredictionSaver(log_dir: str, compact: Optional[bool] = True, pred_transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = <function PLPredictionSaver.<lambda>>, onehot_target: Optional[bool] = False)
Bases: object

Parameters:
- log_dir – directory where the predictions are saved.
- compact – save only what you need instead of all predictions (saving everything produces huge files).
- pred_transform – transform applied to the model's output predictions; typically used to exponentiate them.
- onehot_target – set to True if the target label is a distribution (i.e., argmax must be applied to it to obtain the class); leave as False if targets are integers.
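To illustrate what pred_transform and onehot_target do to a batch, here is a dependency-free sketch that uses plain lists instead of torch.Tensor. The helper prepare_batch and its exponentiating default are hypothetical: the real default is an unnamed lambda whose body the signature does not show. The sketch also says nothing about what compact keeps, since the documentation does not specify it.

```python
import math
from typing import Callable, List, Sequence, Union

Target = Union[int, Sequence[float]]


def argmax(values: Sequence[float]) -> int:
    # Index of the largest entry (the first one on ties).
    return max(range(len(values)), key=lambda i: values[i])


def prepare_batch(
    log_preds: List[List[float]],
    targets: List[Target],
    # Assumed default: exponentiate log-probabilities (the real default lambda may differ).
    pred_transform: Callable[[List[float]], List[float]] = lambda row: [math.exp(x) for x in row],
    onehot_target: bool = False,
):
    """Hypothetical helper: transform predictions and normalise targets to class indices."""
    preds = [pred_transform(row) for row in log_preds]
    if onehot_target:
        # Targets are distributions: recover the class index with argmax.
        labels = [argmax(t) for t in targets]
    else:
        # Targets are already integer class labels.
        labels = [int(t) for t in targets]
    return preds, labels


# Usage with made-up log-probabilities for a 2-class problem.
log_p = [[math.log(0.7), math.log(0.3)], [math.log(0.2), math.log(0.8)]]
preds, labels = prepare_batch(log_p, [[1.0, 0.0], [0.1, 0.9]], onehot_target=True)
print(labels)  # [0, 1]
```

With tensors, the same two steps would typically be preds.exp() and target.argmax(dim=-1).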