alr.training.pl_mixup

Classes

IndexMarker

PDS

PLMixupTrainer

class alr.training.pl_mixup.PLMixupTrainer(model: torch.nn.modules.module.Module, optimiser: str, train_transform: Callable, test_transform: Callable, optimiser_kwargs: dict, loader_kwargs: dict, rfls_length: int, log_dir: Optional[str] = None, alpha: Optional[float] = 1.0, min_labelled: Union[float, int, None] = 16, num_classes: Optional[int] = 10, data_augmentation: Optional[Callable] = None, batch_size: Optional[int] = 100, patience: Union[Tuple[int, int], int, None] = (5, 25), lr_patience: Optional[int] = 10, device: Union[str, torch.device, None] = None)

Bases: object

evaluate(data_loader: torch.utils.data.dataloader.DataLoader) → dict

fit(train: torch.utils.data.dataset.Dataset, val: torch.utils.data.dataset.Dataset, pool: torch.utils.data.dataset.Dataset, epochs: Optional[Tuple[int, int]] = (50, 400))

PLUpdater

class alr.training.pl_mixup.PLUpdater(model: torch.nn.modules.module.Module, pool: alr.training.pl_mixup.PseudoLabelledDataset, log_dir: str, num_class: int, device=None)

Bases: object

attach(engine: ignite.engine.engine.Engine)

Functions

mixup

alr.training.pl_mixup.mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0, device: Union[str, torch.device, None] = None)

Returns the mixed inputs, the pair of targets that were mixed, and the mixing coefficient lambda.
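The following is a minimal sketch of the mixup step. It uses NumPy arrays rather than torch tensors so that it stays self-contained. It assumes the standard mixup recipe: lambda is drawn from Beta(alpha, alpha), and each example is blended with a randomly permuted partner from the same batch. The `rng` argument is a hypothetical stand-in for the library's `device` parameter and is not part of the documented signature.

```python
import numpy as np


def mixup(x, y, alpha=1.0, rng=None):
    """Sketch of mixup: blend a batch with a shuffled copy of itself.

    Returns (mixed_x, y_a, y_b, lam), where y_a are the original targets,
    y_b the targets of the shuffled partners and lam the mixing weight.
    NOTE: `rng` is a hypothetical argument used here instead of `device`.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Draw the mixing coefficient from Beta(alpha, alpha); alpha <= 0 disables mixing.
    lam = float(rng.beta(alpha, alpha)) if alpha > 0 else 1.0
    # Pair every example with a random partner from the same batch.
    index = rng.permutation(x.shape[0])
    mixed_x = lam * x + (1.0 - lam) * x[index]
    return mixed_x, y, y[index], lam
```

Because the targets are returned as a pair with lambda rather than blended, the caller can weight the loss on each target set separately.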

reg_nll_loss

alr.training.pl_mixup.reg_nll_loss(coef: Optional[Tuple[float, float]] = (0.8, 0.4))
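The two coefficients are not described here. One plausible reading is a negative log-likelihood loss with two weighted regularisers:

- A prior term that pulls the batch-averaged prediction towards a uniform class distribution, discouraging collapse onto a few classes.
- An entropy term that penalises uncertain individual predictions.

The sketch below implements that reading with NumPy. The factory-returns-a-callable shape, the soft-target argument and both regulariser definitions are assumptions, not the library's confirmed behaviour.

```python
import numpy as np


def reg_nll_loss(coef=(0.8, 0.4)):
    """Hypothetical regularised NLL: returns loss(log_probs, targets).

    ASSUMPTION: coef = (prior_weight, entropy_weight); targets are (N, C)
    one-hot or soft label vectors; log_probs are (N, C) log-softmax outputs.
    """
    prior_coef, entropy_coef = coef

    def loss(log_probs, targets):
        n_classes = log_probs.shape[1]
        # Cross-entropy against (possibly soft) targets.
        nll = -(targets * log_probs).sum(axis=1).mean()
        probs = np.exp(log_probs)
        # Prior term: KL(uniform || mean batch prediction).
        prior = np.full(n_classes, 1.0 / n_classes)
        mean_probs = probs.mean(axis=0)
        prior_term = np.sum(prior * np.log(prior / mean_probs))
        # Entropy term: mean per-example prediction entropy.
        entropy_term = -(probs * log_probs).sum(axis=1).mean()
        return nll + prior_coef * prior_term + entropy_coef * entropy_term

    return loss
```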

reg_mixup_loss

alr.training.pl_mixup.reg_mixup_loss(coef: Optional[Tuple[float, float]] = (0.8, 0.4))
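Assuming this is the mixup counterpart of `reg_nll_loss`, a natural form weights the NLL on each target set by lambda and 1 - lambda, as returned by `mixup`, and then adds the same two regularisers. The sketch below follows that assumption. The `loss(log_probs, y_a, y_b, lam)` call shape and the regulariser definitions are hypothetical.

```python
import numpy as np


def reg_mixup_loss(coef=(0.8, 0.4)):
    """Hypothetical regularised mixup loss: returns loss(log_probs, y_a, y_b, lam).

    ASSUMPTION: coef = (prior_weight, entropy_weight); y_a and y_b are (N, C)
    one-hot or soft targets; lam is the mixing coefficient from mixup.
    """
    prior_coef, entropy_coef = coef

    def loss(log_probs, y_a, y_b, lam):
        n_classes = log_probs.shape[1]
        # Mixup objective: lambda-weighted NLL against both target sets.
        nll_a = -(y_a * log_probs).sum(axis=1).mean()
        nll_b = -(y_b * log_probs).sum(axis=1).mean()
        mixed_nll = lam * nll_a + (1.0 - lam) * nll_b
        probs = np.exp(log_probs)
        # Prior term: KL(uniform || mean batch prediction).
        prior = np.full(n_classes, 1.0 / n_classes)
        prior_term = np.sum(prior * np.log(prior / probs.mean(axis=0)))
        # Entropy term: mean per-example prediction entropy.
        entropy_term = -(probs * log_probs).sum(axis=1).mean()
        return mixed_nll + prior_coef * prior_term + entropy_coef * entropy_term

    return loss
```

Because the NLL is linear in the targets, weighting the two losses is equivalent to computing a single NLL against the blended target `lam * y_a + (1 - lam) * y_b`.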

onehot_transform

alr.training.pl_mixup.onehot_transform(n)
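Judging by its name and its single argument `n`, `onehot_transform` most likely builds a target transform that turns an integer class label into a one-hot vector of length `n`. This would let hard labels share a format with soft pseudo-labels. The following is a minimal sketch of that assumed behaviour. It returns NumPy arrays rather than tensors.

```python
import numpy as np


def onehot_transform(n):
    """Hypothetical sketch: return a callable mapping a class index to a one-hot vector."""

    def transform(label):
        # Length-n float vector with a single 1 at the label's position.
        vec = np.zeros(n, dtype=np.float32)
        vec[int(label)] = 1.0
        return vec

    return transform
```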

create_warmup_trainer

alr.training.pl_mixup.create_warmup_trainer(model: torch.nn.modules.module.Module, optimiser, device: Union[str, torch.device, None] = None)

create_plmixup_trainer

alr.training.pl_mixup.create_plmixup_trainer(model, optimiser, pool, alpha, num_classes, log_dir, device)

temp_ds_transform

alr.training.pl_mixup.temp_ds_transform(transform, has_targets: bool = False)