alr.training.utils

Classes

EarlyStopper

class alr.training.utils.EarlyStopper(model: torch.nn.modules.module.Module, patience: int, trainer: ignite.engine.engine.Engine, key: Optional[str] = 'acc', mode: Optional[str] = 'max')

Bases: object

attach(engine: ignite.engine.engine.Engine)

Attach the early stopper to engine so that it terminates the provided trainer when the metric named by key has not improved for patience epochs.

Parameters: engine (ignite.engine.Engine) – expected to be a validation evaluator. The metric named by key is extracted from it, and the best value seen so far (according to mode) is kept as the reference.
Returns: None
Return type: NoneType
reload_best()
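The stopping rule described for attach can be sketched without ignite. PatienceStopper below is hypothetical and is not the alr implementation; it assumes that "improvement" means a strictly better value of the metric named by key (higher for mode='max', lower for mode='min') and that training stops once patience consecutive evaluations fail to improve. In ignite itself, an evaluator's computed metrics live in engine.state.metrics and a run can be stopped with Engine.terminate(); both are left out so the sketch runs on plain dictionaries.

```python
from typing import Dict, List, Optional


class PatienceStopper:
    """Hypothetical, framework-free sketch of patience-based early stopping.

    Mirrors the documented parameters (patience, key, mode) but is NOT
    alr's EarlyStopper; it only illustrates the stopping rule.
    """

    def __init__(self, patience: int, key: str = "acc", mode: str = "max") -> None:
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.patience = patience
        self.key = key
        self.mode = mode
        self.best: Optional[float] = None
        self.bad_epochs = 0  # consecutive evaluations without improvement

    def _improved(self, value: float) -> bool:
        if self.best is None:
            return True
        return value > self.best if self.mode == "max" else value < self.best

    def update(self, metrics: Dict[str, float]) -> bool:
        """Feed one evaluation's metrics; return True if training should stop."""
        value = metrics[self.key]  # extract the key metric, e.g. metrics["acc"]
        if self._improved(value):
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Simulated validation accuracies for six epochs.
history: List[Dict[str, float]] = [
    {"acc": a} for a in (0.70, 0.75, 0.74, 0.76, 0.75, 0.73)
]
stopper = PatienceStopper(patience=2, key="acc", mode="max")
stopped_at = None
for epoch, metrics in enumerate(history, start=1):
    if stopper.update(metrics):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)  # 6 0.76
```

With patience=2, the dip at epoch 3 is forgiven because epoch 4 sets a new best; the two consecutive non-improving epochs 5 and 6 trigger the stop.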

PLPredictionSaver

class alr.training.utils.PLPredictionSaver(log_dir: str, compact: Optional[bool] = True, pred_transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = <function PLPredictionSaver.<lambda>>, onehot_target: Optional[bool] = False)

Bases: object

Parameters:
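  • log_dir (str) – directory in which the predictions are saved.
  • compact (bool) – if True, save only what is needed instead of all predictions (which produces huge files).
  • pred_transform (Callable[[torch.Tensor], torch.Tensor]) – transform applied to the model’s output predictions; typically used to exponentiate them.
  • onehot_target (bool) – set to True if the target label is a distribution (i.e. argmax must be applied to it to obtain the class); leave as False if targets are integers.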
attach(engine: ignite.engine.engine.Engine, output_transform: Callable[[...], tuple] = <function PLPredictionSaver.<lambda>>)
global_step_from_engine(engine: ignite.engine.engine.Engine)
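The roles of pred_transform and onehot_target can be illustrated with a hypothetical helper, prepare_for_saving, which is not part of alr. Plain Python lists stand in for torch tensors, and an exponentiating function plays the role of pred_transform, assuming the model outputs log-probabilities. What alr actually writes to log_dir (and what compact leaves out) is not documented here, so the sketch stops at producing probabilities and integer labels.

```python
import math
from typing import Callable, List, Sequence, Tuple, Union

Row = List[float]


def prepare_for_saving(
    preds: Sequence[Row],
    targets: Sequence[Union[int, Row]],
    pred_transform: Callable[[Row], Row],
    onehot_target: bool = False,
) -> Tuple[List[Row], List[int]]:
    """Hypothetical helper (not part of alr); lists stand in for tensors."""
    # Apply the transform to each prediction row, e.g. exponentiate log-probabilities.
    probs = [pred_transform(list(row)) for row in preds]
    if onehot_target:
        # Each target is a distribution: argmax recovers the class index.
        labels = [max(range(len(t)), key=lambda i: t[i]) for t in targets]
    else:
        # Targets are already integer class labels.
        labels = [int(t) for t in targets]
    return probs, labels


def exponentiate(row: Row) -> Row:
    """Turn log-probabilities back into probabilities."""
    return [math.exp(x) for x in row]


log_probs = [[math.log(0.2), math.log(0.8)], [math.log(0.9), math.log(0.1)]]
probs, labels = prepare_for_saving(
    log_probs, [[0.0, 1.0], [1.0, 0.0]], exponentiate, onehot_target=True
)
print(labels)  # [1, 0]
```

Passing integer targets with onehot_target=False yields the same labels without the argmax step.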

PerformanceTracker

class alr.training.utils.PerformanceTracker(model: torch.nn.modules.module.Module, patience: int)

Bases: object

done
reload_best()
reloaded
reset()
step(acc)
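PerformanceTracker's members are undocumented, so the following is only a guess at the pattern its names suggest: step(acc) records a validation accuracy, done becomes true after patience steps without improvement, reload_best() restores the best weights and sets reloaded, and reset() clears everything. SketchTracker and ToyModel are hypothetical; ToyModel merely mimics the state_dict()/load_state_dict() method names of torch.nn.Module so the example runs without torch.

```python
import copy
from typing import Any, Dict, Optional


class ToyModel:
    """Hypothetical stand-in for torch.nn.Module with the same two method names."""

    def __init__(self) -> None:
        self.weights: Dict[str, Any] = {"w": 0.0}

    def state_dict(self) -> Dict[str, Any]:
        # Like torch, return the live state rather than a copy.
        return self.weights

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.weights = copy.deepcopy(state)


class SketchTracker:
    """Hypothetical patience tracker; NOT alr's PerformanceTracker."""

    def __init__(self, model: ToyModel, patience: int) -> None:
        self.model = model
        self.patience = patience
        self.reset()

    def reset(self) -> None:
        # Forget the best score, the snapshot and the patience counter.
        self.best_acc = float("-inf")
        self.best_state: Optional[Dict[str, Any]] = None
        self.bad_steps = 0
        self.reloaded = False

    @property
    def done(self) -> bool:
        # Assumed meaning: patience has run out.
        return self.bad_steps >= self.patience

    def step(self, acc: float) -> None:
        if acc > self.best_acc:
            self.best_acc = acc
            # Snapshot a copy; the live state keeps changing during training.
            self.best_state = copy.deepcopy(self.model.state_dict())
            self.bad_steps = 0
        else:
            self.bad_steps += 1

    def reload_best(self) -> None:
        # Restore the best snapshot, if any, into the model.
        if self.best_state is not None:
            self.model.load_state_dict(self.best_state)
            self.reloaded = True


model = ToyModel()
tracker = SketchTracker(model, patience=2)
for w, acc in [(1.0, 0.6), (2.0, 0.8), (3.0, 0.7), (4.0, 0.75)]:
    model.weights["w"] = w  # pretend one epoch of training changed the weights
    tracker.step(acc)
    if tracker.done:
        break
tracker.reload_best()
print(model.weights, tracker.reloaded)  # {'w': 2.0} True
```

The deepcopy in step matters: torch's state_dict() returns tensors that share storage with the live parameters, so keeping it without copying would let later optimizer updates overwrite the "best" snapshot.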

Functions