alr.training.plmixup_ensemble

Classes

PLMixupEnsembleTrainer

class alr.training.plmixup_ensemble.PLMixupEnsembleTrainer(models: List[torch.nn.modules.module.Module], optimiser: str, train_transform: Callable, test_transform: Callable, optimiser_kwargs: dict, loader_kwargs: dict, rfls_length: int, log_dir: Optional[str] = None, alpha: Optional[float] = 1.0, min_labelled: Union[float, int, None] = 16, num_classes: Optional[int] = 10, data_augmentation: Optional[Callable] = None, batch_size: Optional[int] = 100, patience: Union[Tuple[int, int], int, None] = (5, 25), lr_patience: Optional[int] = 10, device: Union[str, torch.device, None] = None)

Bases: object
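The defaults `min_labelled=16` and `batch_size=100` suggest that mini-batches drawn from labelled and pseudo-labelled data reserve a minimum number of slots for truly labelled samples, a common safeguard when training on pseudo-labels. This is an assumption about the trainer's internals, as is reading a float `min_labelled` as a fraction of the batch. The hypothetical helpers `resolve_min_labelled` and `compose_batch` below are a minimal sketch of the idea, not the library's sampler.

```python
import random
from typing import List, Sequence, Union


def resolve_min_labelled(min_labelled: Union[float, int, None], batch_size: int) -> int:
    """Hypothetical: turn min_labelled into a slot count.

    Assumptions: None means no reserved slots; a float is a fraction of the batch.
    """
    if min_labelled is None:
        return 0
    if isinstance(min_labelled, float):
        return int(round(min_labelled * batch_size))
    return int(min_labelled)


def compose_batch(
    labelled: Sequence[int],
    pseudo: Sequence[int],
    batch_size: int,
    min_labelled: Union[float, int, None],
    rng: random.Random,
) -> List[int]:
    """Hypothetical helper: draw one batch of indices with at least k labelled items."""
    k = min(resolve_min_labelled(min_labelled, batch_size), batch_size)
    # Reserve k slots for labelled indices, sampled with replacement so a
    # labelled set smaller than k can still fill them.
    batch = [rng.choice(labelled) for _ in range(k)]
    # Fill the remaining slots uniformly from labelled + pseudo-labelled indices.
    combined = list(labelled) + list(pseudo)
    batch += [rng.choice(combined) for _ in range(batch_size - k)]
    rng.shuffle(batch)  # avoid a fixed labelled-first ordering
    return batch
```

Sampling the reserved slots with replacement lets a labelled set smaller than `min_labelled` still fill them; an actual implementation might instead oversample the labelled dataset.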

evaluate(data_loader: torch.utils.data.dataloader.DataLoader) → dict
fit(train: torch.utils.data.dataset.Dataset, val: torch.utils.data.dataset.Dataset, pool: torch.utils.data.dataset.Dataset, epochs: Optional[Tuple[int, int]] = (50, 400))
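The paired defaults `epochs=(50, 400)` and `patience=(5, 25)` read as two training stages, for instance a shorter first stage and a longer second stage that also uses the pseudo-labelled `pool`, each with its own early-stopping patience. That pairing, and treating a bare `int` patience as applying to both stages, are assumptions. The hypothetical `normalise_patience` and `run_stage` below only sketch patience-based early stopping over a stream of validation scores.

```python
from typing import Callable, Optional, Tuple, Union


def normalise_patience(
    patience: Union[Tuple[int, int], int, None]
) -> Tuple[Optional[int], Optional[int]]:
    """Hypothetical: an int applies to both stages; None disables early stopping."""
    if patience is None:
        return (None, None)
    if isinstance(patience, int):
        return (patience, patience)
    return patience


def run_stage(
    val_score: Callable[[int], float], max_epochs: int, patience: Optional[int]
) -> Tuple[int, float]:
    """Run up to max_epochs; stop after `patience` epochs without a new best score.

    val_score(epoch) stands in for one training epoch followed by validation.
    Returns (epochs_run, best_score).
    """
    best = float("-inf")
    since_best = 0
    epochs_run = 0
    for epoch in range(max_epochs):
        score = val_score(epoch)
        epochs_run += 1
        if score > best:
            best, since_best = score, 0  # new best: reset the counter
        else:
            since_best += 1
            if patience is not None and since_best >= patience:
                break
    return epochs_run, best
```

Under this reading, `p1, p2 = normalise_patience((5, 25))` followed by `run_stage(score_fn, 50, p1)` and `run_stage(score_fn, 400, p2)` would mirror the two default epoch budgets.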

Ensemble

class alr.training.plmixup_ensemble.Ensemble(models: list, return_log: bool = False)

Bases: object

__call__(x)

Call self as a function.

eval()
evaluate(loader, device)
forward(x)
get_preds(x)
save_weights(prefix: str)
train(mode=True)
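`Ensemble` wraps a list of models and takes a `return_log` flag, which suggests that `forward` combines the members' predictions and can return log-probabilities. The sketch below assumes the combination is an average of per-member softmax probabilities and that `train`/`eval` forward the mode to every member, as `torch.nn.Module.train(mode)` does for submodules. It uses plain Python callables returning logits in place of PyTorch models, and the class name `ToyEnsemble` is hypothetical.

```python
import math
from typing import Callable, List, Sequence

Model = Callable[[Sequence[float]], Sequence[float]]


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


class ToyEnsemble:
    """Hypothetical stand-in for Ensemble: averages members' class probabilities."""

    def __init__(self, models: List[Model], return_log: bool = False):
        self.models = models
        self.return_log = return_log
        self.training = True

    def train(self, mode: bool = True) -> "ToyEnsemble":
        # Propagate the mode to members that support it (e.g. nn.Module).
        self.training = mode
        for model in self.models:
            if hasattr(model, "train"):
                model.train(mode)
        return self

    def eval(self) -> "ToyEnsemble":
        return self.train(False)

    def forward(self, x: Sequence[float]) -> List[float]:
        probs = [softmax(model(x)) for model in self.models]
        n = len(probs)
        # Average probabilities (not logits) class by class.
        avg = [sum(p[c] for p in probs) / n for c in range(len(probs[0]))]
        return [math.log(p) for p in avg] if self.return_log else avg

    # Calling the ensemble delegates to forward, as with nn.Module.
    __call__ = forward
```

Averaging probabilities rather than logits keeps each member's output a proper distribution before combining; taking the log afterwards gives log-probabilities suitable for a negative log-likelihood loss.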

Functions

plmixup_train

alr.training.plmixup_ensemble.plmixup_train(loader, model, optimiser, alpha, device)
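Given the module name and the `alpha` argument, `plmixup_train` presumably runs mixup training, in which each batch is blended with a shuffled copy of itself using a weight drawn from Beta(alpha, alpha). The sketch below shows only that mixing step on plain Python lists with one-hot or soft targets. The function name `mixup_batch` is hypothetical, and the loss that `plmixup_train` actually uses is not documented here.

```python
import random
from typing import List, Sequence, Tuple

Vector = List[float]


def mixup_batch(
    xs: Sequence[Sequence[float]],
    ys: Sequence[Sequence[float]],
    alpha: float,
    rng: random.Random,
) -> Tuple[List[Vector], List[Vector], float]:
    """Blend each example (and its target distribution) with a randomly paired one.

    xs: batch of feature vectors; ys: matching one-hot or soft targets
    (pseudo-labels may be soft). Returns (mixed_xs, mixed_ys, lam).
    """
    # Assumption: alpha <= 0 disables mixing (lam = 1).
    lam = rng.betavariate(alpha, alpha) if alpha > 0 else 1.0
    perm = list(range(len(xs)))
    rng.shuffle(perm)  # pair example i with example perm[i]

    def blend(a: Sequence[float], b: Sequence[float]) -> Vector:
        return [lam * ai + (1.0 - lam) * bi for ai, bi in zip(a, b)]

    mixed_xs = [blend(xs[i], xs[j]) for i, j in enumerate(perm)]
    mixed_ys = [blend(ys[i], ys[j]) for i, j in enumerate(perm)]
    return mixed_xs, mixed_ys, lam
```

Because cross-entropy is linear in the target, training against `mixed_ys` gives the same loss as `lam * CE(y_a) + (1 - lam) * CE(y_b)` on the two original targets.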