alr.training.ephemeral_trainer
Classes
PseudoLabelManager

class alr.training.ephemeral_trainer.PseudoLabelManager(pool: alr.data.UnlabelledDataset, model: torch.nn.modules.module.Module, threshold: float, init_pseudo_labelled: Optional[torch.utils.data.dataset.Dataset] = None, log_dir: Optional[str] = None, device: Union[str, torch.device, None] = None, **kwargs)

Bases: object
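This page does not say how `threshold` is used. Going by the class names, one plausible reading is a confidence cut-off: a pool point is pseudo-labelled only when the model's highest predicted class probability reaches `threshold`. The sketch below illustrates that assumed rule in plain Python. `select_pseudo_labels` is a hypothetical helper, not part of `alr`, and the real class may use a different comparison (for example a strict `>`) or work on tensors directly.

```python
from typing import List, Sequence, Tuple


def select_pseudo_labels(
    probs: Sequence[Sequence[float]], threshold: float
) -> List[Tuple[int, int]]:
    """Hypothetical helper: return (pool_index, predicted_class) for every
    pool point whose highest class probability is at least `threshold`.

    Assumes `threshold` is a confidence cut-off; this is not confirmed by
    the alr documentation.
    """
    selected = []
    for idx, row in enumerate(probs):
        # argmax over classes, without external dependencies
        label = max(range(len(row)), key=row.__getitem__)
        if row[label] >= threshold:
            selected.append((idx, label))
    return selected


# Predicted class probabilities for three unlabelled pool points
pool_probs = [
    [0.05, 0.95],  # confident: class 1
    [0.60, 0.40],  # not confident enough at threshold=0.9
    [0.92, 0.08],  # confident: class 0
]
print(select_pseudo_labels(pool_probs, threshold=0.9))
# [(0, 1), (2, 0)]
```

Raising the threshold trades fewer pseudo-labels for (in principle) cleaner ones; lowering it adds more pool points at the risk of confirming the model's own mistakes.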
PseudoLabelCollector

class alr.training.ephemeral_trainer.PseudoLabelCollector(threshold: float, log_dir: Optional[str] = None, pred_transform: Callable[[torch.Tensor], torch.Tensor] = <function PseudoLabelCollector.<lambda>>)

Bases: object

attach(engine: ignite.engine.engine.Engine, batch_size: int, output_transform=<function PseudoLabelCollector.<lambda>>)

Parameters:
- engine (Engine) – the Ignite engine to attach the collector to
- batch_size (int) – the engine’s batch size
- output_transform (Callable) – if engine.state.output is not a (preds, target) tuple, output_transform should map it to one.

Returns: None
Return type: NoneType
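A minimal sketch of an `output_transform` for `attach`, assuming a hypothetical engine whose `engine.state.output` is a dict with `"loss"`, `"y_pred"` and `"y"` keys rather than a `(preds, target)` tuple. The dict keys and the name `to_preds_target` are illustrative assumptions; only the `(preds, target)` contract comes from the parameter description of `output_transform`. Plain lists stand in for tensors so the sketch runs without PyTorch.

```python
from typing import Any, Dict, Tuple

# Hypothetical engine output: a dict rather than a (preds, target) tuple.
# In a real engine these values would typically be tensors.
engine_output: Dict[str, Any] = {
    "loss": 0.42,
    "y_pred": [[0.1, 0.9], [0.8, 0.2]],
    "y": [1, 0],
}


def to_preds_target(output: Dict[str, Any]) -> Tuple[Any, Any]:
    """Map the (assumed) dict output to the (preds, target) tuple
    that PseudoLabelCollector.attach expects."""
    return output["y_pred"], output["y"]


preds, target = to_preds_target(engine_output)
print(preds, target)
```

Under that assumption, the function would be passed as `attach(engine, batch_size, output_transform=to_preds_target)`. If the engine already produces `(preds, target)`, no transform should be needed.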
EphemeralTrainer

class alr.training.ephemeral_trainer.EphemeralTrainer(model: alr.ALRModel, pool: alr.data.UnlabelledDataset, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optimiser: str, threshold: float, min_labelled: Union[float, int, None] = None, random_fixed_length_sampler_length: Optional[int] = None, log_dir: Optional[str] = None, patience: Union[int, tuple, None] = None, reload_best: Optional[bool] = False, lr_scheduler: Optional[str] = None, lr_scheduler_kwargs: Optional[dict] = None, init_pseudo_label_dataset: Optional[torch.utils.data.dataset.Dataset] = None, device: Union[str, torch.device, None] = None, pool_loader_kwargs: Optional[dict] = None, *args, **kwargs)

Bases: object
Functions

create_pseudo_label_trainer

alr.training.ephemeral_trainer.create_pseudo_label_trainer(model: alr.ALRModel, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optimiser: str, train_loader: torch.utils.data.dataloader.DataLoader, val_loader: torch.utils.data.dataloader.DataLoader, pseudo_label_manager: alr.training.ephemeral_trainer.PseudoLabelManager, rfls_len: Optional[int] = None, min_labelled: Union[float, int, None] = None, patience: Optional[int] = None, reload_best: Optional[bool] = None, epochs: Optional[int] = 1, lr_scheduler: Optional[str] = None, lr_scheduler_kwargs: Optional[dict] = None, device: Union[str, torch.device, None] = None, *args, **kwargs)