alr.training.ephemeral_trainer

Classes

PseudoLabelManager

class alr.training.ephemeral_trainer.PseudoLabelManager(pool: alr.data.UnlabelledDataset, model: torch.nn.modules.module.Module, threshold: float, init_pseudo_labelled: Optional[torch.utils.data.dataset.Dataset] = None, log_dir: Optional[str] = None, device: Union[str, torch.device, None] = None, **kwargs)[source]

Bases: object

attach(engine: ignite.engine.engine.Engine)[source]

PseudoLabelCollector

class alr.training.ephemeral_trainer.PseudoLabelCollector(threshold: float, log_dir: Optional[str] = None, pred_transform: Callable[[torch.Tensor], torch.Tensor] = <function PseudoLabelCollector.<lambda>>)[source]

Bases: object

attach(engine: ignite.engine.engine.Engine, batch_size: int, output_transform=<function PseudoLabelCollector.<lambda>>)[source]
Parameters:
  • engine (Engine) – ignite engine object
  • batch_size (int) – engine’s batch size
  • output_transform (Callable) – if engine.state.output is not a (preds, target) tuple, then output_transform should convert it into the aforementioned tuple.
Returns:

None

Return type:

NoneType

global_step_from_engine(engine: ignite.engine.engine.Engine)[source]

EphemeralTrainer

class alr.training.ephemeral_trainer.EphemeralTrainer(model: alr.ALRModel, pool: alr.data.UnlabelledDataset, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optimiser: str, threshold: float, min_labelled: Union[float, int, None] = None, random_fixed_length_sampler_length: Optional[int] = None, log_dir: Optional[str] = None, patience: Union[int, tuple, None] = None, reload_best: Optional[bool] = False, lr_scheduler: Optional[str] = None, lr_scheduler_kwargs: Optional[dict] = None, init_pseudo_label_dataset: Optional[torch.utils.data.dataset.Dataset] = None, device: Union[str, torch.device, None] = None, pool_loader_kwargs: Optional[dict] = None, *args, **kwargs)[source]

Bases: object

evaluate(data_loader: torch.utils.data.dataloader.DataLoader) → dict[source]
fit(train_loader: torch.utils.data.dataloader.DataLoader, val_loader: Optional[torch.utils.data.dataloader.DataLoader] = None, iterations: Optional[int] = 1, epochs: Optional[int] = 1)[source]

Functions

create_pseudo_label_trainer

alr.training.ephemeral_trainer.create_pseudo_label_trainer(model: alr.ALRModel, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optimiser: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, pseudo_label_manager: alr.training.ephemeral_trainer.PseudoLabelManager, rfls_len: Optional[int] = None, min_labelled: Union[float, int, None] = None, patience: Optional[int] = None, reload_best: Optional[bool] = None, epochs: Optional[int] = 1, lr_scheduler: Optional[str] = None, lr_scheduler_kwargs: Optional[dict] = None, device: Union[str, torch.device, None] = None, *args, **kwargs)[source]