from torch.optim.lr_scheduler import ReduceLROnPlateau
from alr.training.progress_bar.ignite_progress_bar import ProgressBar
from alr.training.samplers import RandomFixedLengthSampler, MinLabelledSampler
import torch.utils.data as torchdata
import torch
from typing import Optional, Tuple, Callable, Union
from torch import nn
from alr.utils._type_aliases import _DeviceType
from alr.training.utils import EarlyStopper, PLPredictionSaver
from alr.utils import _map_device
import numpy as np
from contextlib import contextmanager
from ignite.engine import Engine, Events, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from torch.nn import functional as F
from enum import Enum
class DataMarker(Enum):
    """Tag attached to every instance to say which kind of dataset it came from."""

    # instance belongs to a dataset whose targets may be overridden by pseudo-labels
    PSEUDO_LABELLED = "pseudo_labelled"
    # instance belongs to a dataset that keeps its given targets
    LABELLED = "labelled"
class PseudoLabelledDataset(torchdata.Dataset):
    """
    Dataset wrapper whose targets can be replaced by pseudo-labels.

    The wrapped dataset must yield ``(image, target)`` pairs. By default,
    indexing returns ``(img_raw, img_aug, target, idx, mark)`` where
    ``img_raw`` is ``transform(image)`` and ``img_aug`` is
    ``transform(augmentation(image))``. The context managers below
    temporarily change what indexing returns.

    Attributes are documented inline in ``__init__``.
    """

    class IndexMarker(torchdata.Dataset):
        """
        Wraps a regular dataset such that it returns
        the data, its index, and a mark when indexed.
        This helps the training process identify which instances
        are pseudo-labelled and what their indices were so we can
        update the next iteration's pseudo-labels.
        """

        def __init__(self, dataset: torchdata.Dataset, mark):
            self.dataset = dataset
            self.mark = mark

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            # returns (x, y), idx, mark
            return self.dataset[idx], idx, self.mark

    def __init__(
        self,
        dataset: torchdata.Dataset,
        mark: "DataMarker",
        transform: Callable[[torch.Tensor], torch.Tensor],
        augmentation: Optional[Callable[[torch.Tensor], torch.Tensor]] = lambda x: x,
        target_transform: Optional[Callable] = lambda x: x,
    ):
        """
        Args:
            dataset: underlying dataset yielding ``(image, target)`` pairs.
            mark: marker returned alongside every item of this dataset.
            transform: applied to both the raw and the augmented image.
            augmentation: applied to the raw image before ``transform`` to
                produce the augmented view; ``None`` disables augmentation.
            target_transform: applied to the (possibly overridden) target;
                ``None`` leaves the target untouched.
        """
        self.dataset = PseudoLabelledDataset.IndexMarker(dataset, mark)
        self._augmentation = augmentation
        self._transform = transform
        # when True, __getitem__ also returns the raw image, index and mark
        self._with_metadata = True
        # latest pseudo-labels set through override_targets (None until then)
        self._new_targets = None
        self._target_transform = target_transform
        # when True, overridden targets are ignored in __getitem__
        self._original_labels = False
        # every tensor ever passed to override_targets, in call order
        self.label_history = []

    def __getitem__(self, idx):
        (img_raw, target), idx, mark = self.dataset[idx]
        # override target
        if self._new_targets is not None and not self._original_labels:
            target = self._new_targets[idx]
        # ``None`` means "no augmentation" (set by no_augmentation() or by the
        # caller); the original code called it unconditionally and crashed.
        if self._augmentation is None:
            img_aug = img_raw
        else:
            img_aug = self._augmentation(img_raw)
        img_raw, img_aug = map(self._transform, [img_raw, img_aug])
        if self._target_transform is not None:
            target = self._target_transform(target)
        if self._with_metadata:
            return img_raw, img_aug, target, idx, mark
        return img_aug, target

    def __len__(self):
        return len(self.dataset)

    @contextmanager
    def original_labels(self):
        """Temporarily return the wrapped dataset's own targets."""
        previous = self._original_labels
        self._original_labels = True
        try:
            yield self
        finally:
            # restore even if the body raised; nested use restores correctly
            self._original_labels = previous

    @contextmanager
    def no_fluff(self):
        """Temporarily return only ``(img_aug, target)`` when indexed."""
        previous = self._with_metadata
        self._with_metadata = False
        try:
            yield self
        finally:
            self._with_metadata = previous

    @contextmanager
    def no_augmentation(self):
        """Temporarily skip the augmentation so ``img_aug`` equals ``img_raw``
        prior to ``transform``."""
        previous = self._augmentation
        self._augmentation = None
        try:
            yield self
        finally:
            self._augmentation = previous

    def override_targets(self, new_targets: torch.Tensor):
        r"""
        Overrides the target classes for this dataset.
        Args:
            new_targets: tensor whose first dimension equals ``len(self)``
                (i.e. your pseudo-labels). ``override_accuracy`` takes the
                argmax over the last dimension, so it expects one score
                per class, i.e. shape ``[N, C]``.
        Returns: None
        """
        assert new_targets.size(0) == len(self.dataset)
        # new_targets = [N x C]
        self.label_history.append(new_targets)
        self._new_targets = new_targets

    @property
    def override_accuracy(self):
        """Fraction of overridden targets whose argmax equals the wrapped
        dataset's own target. Requires a prior ``override_targets`` call."""
        assert self._new_targets is not None
        correct = 0
        for i in range(len(self)):
            # IndexMarker returns ((x, y), idx, mark); take y
            original_target = self.dataset[i][0][-1]
            predicted_class = self._new_targets[i].argmax(dim=-1).item()
            correct += predicted_class == original_target
        return correct / len(self)
# from https://github.com/facebookresearch/mixup-cifar10/blob/master/train.py#L119
def mixup(
    x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0, device: _DeviceType = None
):
    """
    Returns mixed inputs, pairs of targets, and lambda.

    Args:
        x: batch of inputs; the first dimension is the batch dimension.
        y: batch of targets, indexed along its first dimension.
        alpha: both parameters of the Beta distribution that the mixing
            coefficient is drawn from; if not positive, the coefficient is 1.
        device: if truthy, the permutation indices are moved to this device.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``y_a`` is ``y`` and ``y_b`` is
        ``y`` permuted in the same way as the inputs mixed into ``x``.
    """
    # mixing coefficient
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # random pairing of every instance with another instance of the batch
    perm = torch.randperm(x.size(0))
    if device:
        perm = perm.to(device)

    partner_x = x[perm, :]
    mixed_x = lam * x + (1 - lam) * partner_x
    return mixed_x, y, y[perm], lam
def reg_nll_loss(coef: Optional[Tuple[float, float]] = (0.8, 0.4)):
    """
    Builds a cross-entropy loss regularised by a prior term and an entropy term.

    Args:
        coef: ``(prior_weight, entropy_weight)`` multiplying the two
            regularisers before they are added to the cross entropy.

    Returns:
        A function ``(pred, target) -> loss``. ``pred`` is exponentiated to
        obtain probabilities, so it is treated as log-probabilities with
        classes along dim 1; ``target`` has one weight per class in its
        last dimension.
    """

    def _reg_nll_loss(pred: torch.Tensor, target: torch.Tensor):
        num_classes = target.size(-1)
        probs = pred.exp()

        # standard cross entropy loss: H[target, pred]
        cross_entropy = -(target * pred).sum(dim=1).mean()

        # prior loss: KL(uniform prior || mini-batch mean prediction) with the
        # constant sum(prior * log[prior]) term dropped
        uniform_prior = target.new_ones(num_classes) / num_classes
        batch_mean_probs = probs.mean(dim=0)
        prior_term = -(uniform_prior * batch_mean_probs.log()).sum()

        # entropy loss: mean over the batch of the predictive entropy
        entropy_term = -(probs * pred).sum(dim=1).mean()

        prior_weight, entropy_weight = coef[0], coef[1]
        return cross_entropy + prior_weight * prior_term + entropy_weight * entropy_term

    return _reg_nll_loss
def reg_mixup_loss(coef: Optional[Tuple[float, float]] = (0.8, 0.4)):
    """
    Builds the mixup counterpart of the regularised cross-entropy loss.

    Args:
        coef: ``(prior_weight, entropy_weight)`` multiplying the two
            regularisers before they are added to the mixup loss.

    Returns:
        A function ``(pred, y1, y2, lamb) -> loss``.
    """

    def _reg_mixup_loss(
        pred: torch.Tensor, y1: torch.Tensor, y2: torch.Tensor, lamb: int
    ):
        """
        pred is log_softmax,
        y1 and y2 are softmax probabilities
        """
        num_classes = y1.size(-1)
        assert y2.size(-1) == num_classes

        # NxC
        probs = pred.exp()

        # cross entropy against each of the two mixed targets, weighted by lamb
        ce_first = -(y1 * pred).sum(dim=1).mean()
        ce_second = -(y2 * pred).sum(dim=1).mean()
        mixup_term = lamb * ce_first + (1 - lamb) * ce_second

        # uniform prior vs. the mini-batch mean prediction (a C-vector)
        uniform_prior = y2.new_ones(num_classes) / num_classes
        batch_mean_probs = probs.mean(dim=0)
        prior_term = -(uniform_prior * batch_mean_probs.log()).sum()

        # mean predictive entropy over the batch
        entropy_term = -(probs * pred).sum(dim=1).mean()

        return mixup_term + coef[0] * prior_term + coef[1] * entropy_term

    return _reg_mixup_loss
def create_warmup_trainer(model: nn.Module, optimiser, device: _DeviceType = None):
    """
    Creates an ignite ``Engine`` that trains ``model`` with ``reg_nll_loss``.

    Args:
        model: network to optimise; its output is fed to ``reg_nll_loss``
            as ``pred``.
        optimiser: optimiser stepping the model's parameters.
        device: passed to ``_map_device`` together with each batch.

    Returns:
        An ``Engine`` whose per-iteration output is the loss as a float.
    """
    criterion = reg_nll_loss()

    def _step(engine: Engine, batch):
        model.train()
        # prepare batch: expects an (inputs, targets) pair
        inputs, targets = batch
        inputs, targets = _map_device([inputs, targets], device)
        # forward pass and loss
        loss = criterion(model(inputs), targets)
        # backward pass and parameter update
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        return loss.item()

    return Engine(_step)
class PLMixupTrainer:
    """
    Two-stage trainer combining pseudo-labelling with mixup.

    Stage 1 ("warm up") trains on the labelled set only, using
    ``create_warmup_trainer``. The resulting model then pseudo-labels the
    pool. Stage 2 trains on the labelled set and the pseudo-labelled pool
    together using ``create_plmixup_trainer``, which refreshes the pool's
    pseudo-labels at the end of every epoch.

    After ``fit``, ``soft_label_history`` and ``soft_augmented_label_history``
    hold the stacked per-epoch soft predictions on the pool (without and with
    the data augmentation applied, respectively).
    """

    def __init__(
        self,
        model: nn.Module,
        optimiser: str,
        train_transform: Callable,
        test_transform: Callable,
        optimiser_kwargs: dict,
        loader_kwargs: dict,
        rfls_length: int,
        log_dir: Optional[str] = None,
        alpha: Optional[float] = 1.0,
        min_labelled: Optional[Union[int, float]] = 16,
        num_classes: Optional[int] = 10,
        data_augmentation: Optional[Callable] = None,
        batch_size: Optional[int] = 100,
        patience: Optional[Union[Tuple[int, int], int]] = (5, 25),
        lr_patience: Optional[int] = 10,
        device: _DeviceType = None,
    ):
        """
        Args:
            model: network to train; assumed to output log-softmax.
            optimiser: name of a class in ``torch.optim``.
            train_transform: transform for the training and pool datasets.
            test_transform: transform for the validation dataset.
            optimiser_kwargs: keyword arguments for the optimiser class.
            loader_kwargs: extra keyword arguments for every ``DataLoader``.
            rfls_length: length given to ``RandomFixedLengthSampler`` in stage 1.
            log_dir: forwarded to ``create_plmixup_trainer``.
            alpha: mixup parameter used in stage 2.
            min_labelled: forwarded to ``MinLabelledSampler`` in stage 2.
            num_classes: number of classes.
            data_augmentation: augmentation for the training and pool datasets.
            batch_size: batch size of the training loaders.
            patience: early-stopping patience; a single int is used for both
                stages, a pair gives ``(stage 1, stage 2)``.
            lr_patience: patience of the stage-2 ``ReduceLROnPlateau`` scheduler.
            device: device the model and batches live on.
        """
        # for now, assume model returns logsoftmax - ceebs.
        self._model = model
        self._train_transform = train_transform
        self._test_transform = test_transform
        self._data_augmentation = data_augmentation
        self._optim_kwargs = optimiser_kwargs
        self._optimiser = optimiser
        self._device = device
        self._batch_size = batch_size
        self._patience = patience
        self._loader_kwargs = loader_kwargs
        self._rfls_length = rfls_length
        self._min_labelled = min_labelled
        self._num_classes = num_classes
        self._alpha = alpha
        self._lr_patience = lr_patience
        self._log_dir = log_dir
        self.soft_label_history = None
        # like soft_label_history but on augmented input to induce noise
        self.soft_augmented_label_history = None

    def _instantiate_optimiser(self):
        """Creates a fresh optimiser of the configured class for the model."""
        return getattr(torch.optim, self._optimiser)(
            self._model.parameters(), **self._optim_kwargs
        )

    def _predict_pool(self, loader: torchdata.DataLoader) -> torch.Tensor:
        """
        Runs the model over ``loader`` and returns the concatenated softmax
        probabilities on the CPU. ``loader`` must yield ``(input, target)``
        pairs, i.e. the pool must be in ``no_fluff`` mode while iterating.
        Leaves the model in eval mode.
        """
        predictions = []
        with torch.no_grad():
            self._model.eval()
            for x, _ in loader:
                x = x.to(self._device)
                # model outputs logsoftmax, use .exp() here to get probs
                predictions.append(self._model(x).exp().detach().cpu())
        return torch.cat(predictions)

    def fit(
        self,
        train: torchdata.Dataset,
        val: torchdata.Dataset,
        pool: torchdata.Dataset,
        epochs: Optional[Tuple[int, int]] = (50, 400),
    ):
        """
        Trains the model in two stages.

        Args:
            train: labelled training dataset.
            val: validation dataset used for logging, early stopping and
                learning-rate scheduling.
            pool: dataset to be pseudo-labelled; its own targets are only
                used to report pseudo-label accuracy.
            epochs: maximum number of epochs for ``(stage 1, stage 2)``.

        Returns:
            dict with lists under the keys ``val_loss``, ``val_acc`` and
            ``override_acc``.
        """
        if isinstance(self._patience, int):
            pat1 = pat2 = self._patience
        else:
            pat1, pat2 = self._patience[0], self._patience[1]
        history = {
            "val_loss": [],
            "val_acc": [],
            "override_acc": [],
        }
        optimiser = self._instantiate_optimiser()
        # NOTE(review): `onehot_transform` is neither defined nor imported in
        # the visible part of this file -- confirm it is in scope.
        train = PseudoLabelledDataset(
            train,
            mark=DataMarker.LABELLED,
            transform=self._train_transform,
            augmentation=self._data_augmentation,
            target_transform=onehot_transform(self._num_classes),
        )
        pool = PseudoLabelledDataset(
            pool,
            mark=DataMarker.PSEUDO_LABELLED,
            transform=self._train_transform,
            augmentation=self._data_augmentation,
        )
        val = PseudoLabelledDataset(
            val,
            mark=DataMarker.LABELLED,
            transform=self._test_transform,
        )
        # validation only ever needs (input, target)
        val._with_metadata = False
        train_loader = torchdata.DataLoader(
            train,
            batch_size=self._batch_size,
            sampler=RandomFixedLengthSampler(train, self._rfls_length, shuffle=True),
            **self._loader_kwargs,
        )
        pool_loader = torchdata.DataLoader(
            pool, batch_size=512, shuffle=False, **self._loader_kwargs
        )
        val_loader = torchdata.DataLoader(
            val, batch_size=512, shuffle=False, **self._loader_kwargs
        )
        pbar = ProgressBar(desc=lambda _: "Stage 1")

        # warm up
        with train.no_fluff():
            val_eval = create_supervised_evaluator(
                self._model,
                metrics={"acc": Accuracy(), "loss": Loss(F.nll_loss)},
                device=self._device,
            )
            trainer = create_warmup_trainer(
                self._model,
                optimiser=optimiser,
                device=self._device,
            )
            es = EarlyStopper(
                self._model, patience=pat1, trainer=trainer, key="acc", mode="max"
            )
            es.attach(val_eval)

            @trainer.on(Events.EPOCH_COMPLETED)
            def _log(e: Engine):
                metrics = val_eval.run(val_loader).metrics
                acc, loss = metrics["acc"], metrics["loss"]
                pbar.log_message(
                    f"\tStage 1 epoch {e.state.epoch}/{e.state.max_epochs} "
                    f"[val] acc, loss = "
                    f"{acc:.4f}, {loss:.4f}"
                )
                history["val_acc"].append(acc)
                history["val_loss"].append(loss)

            pbar.attach(trainer)
            trainer.run(train_loader, max_epochs=epochs[0])
            es.reload_best()

        # pseudo-label points
        with pool.no_augmentation():
            with pool.no_fluff():
                pseudo_labels = self._predict_pool(pool_loader)
        pool.override_targets(pseudo_labels)
        plab_acc = pool.override_accuracy
        pbar.log_message(f"\t*End of stage 1*: overridden labels' acc: {plab_acc}")
        history["override_acc"].append(plab_acc)

        # start training with PL
        full_dataset = torchdata.ConcatDataset((train, pool))
        fds_loader = torchdata.DataLoader(
            full_dataset,
            batch_sampler=MinLabelledSampler(
                train,
                pool,
                batch_size=self._batch_size,
                min_labelled=self._min_labelled,
            ),
            **self._loader_kwargs,
        )
        val_eval = create_supervised_evaluator(
            self._model,
            metrics={"acc": Accuracy(), "loss": Loss(F.nll_loss)},
            device=self._device,
        )
        optimiser = self._instantiate_optimiser()
        scheduler = ReduceLROnPlateau(
            optimiser,
            mode="max",
            factor=0.1,
            patience=self._lr_patience,
            verbose=True,
            min_lr=1e-3,
        )
        trainer = create_plmixup_trainer(
            self._model,
            optimiser,
            pool,
            alpha=self._alpha,
            num_classes=self._num_classes,
            log_dir=self._log_dir,
            device=self._device,
        )
        es = EarlyStopper(
            self._model, patience=pat2, trainer=trainer, key="acc", mode="max"
        )
        es.attach(val_eval)
        pbar = ProgressBar(desc=lambda _: "Stage 2")
        soft_augmented_label_history = []

        @trainer.on(Events.EPOCH_COMPLETED)
        def _log(e: Engine):
            metrics = val_eval.run(val_loader).metrics
            acc, loss = metrics["acc"], metrics["loss"]
            pbar.log_message(
                f"\tEpoch {e.state.epoch}/{e.state.max_epochs} "
                f"[val] acc, loss = "
                f"{acc:.4f}, {loss:.4f}"
            )
            history["val_acc"].append(acc)
            history["val_loss"].append(loss)
            history["override_acc"].append(pool.override_accuracy)
            scheduler.step(acc)
            # augmentation stays enabled here on purpose: these are the
            # (softmax) probabilities on augmented pool inputs
            with pool.no_fluff():
                soft_augmented_label_history.append(self._predict_pool(pool_loader))

        pbar.attach(trainer)
        trainer.run(fds_loader, max_epochs=epochs[1])
        es.reload_best()

        soft_label_history = pool.label_history
        # If training stopped early, drop the last `pat2` epochs' entries.
        # `pat2 > 0` guards against `[:-0]`, which would empty the lists and
        # make `torch.stack` fail.
        if trainer.state.epoch != trainer.state.max_epochs and pat2 > 0:
            soft_label_history = soft_label_history[:-pat2]
            soft_augmented_label_history = soft_augmented_label_history[:-pat2]
        self.soft_label_history = torch.stack(soft_label_history, dim=0)
        self.soft_augmented_label_history = torch.stack(
            soft_augmented_label_history, dim=0
        )
        return history

    def evaluate(self, data_loader: torchdata.DataLoader) -> dict:
        """Runs the model over ``data_loader`` and returns its ``acc`` and
        ``loss`` (NLL) metrics."""
        evaluator = create_supervised_evaluator(
            self._model,
            metrics={"acc": Accuracy(), "loss": Loss(F.nll_loss)},
            device=self._device,
        )
        return evaluator.run(data_loader).metrics
def create_plmixup_trainer(model, optimiser, pool, alpha, num_classes, log_dir, device):
    """
    Creates an ignite ``Engine`` that trains ``model`` with mixup and
    ``reg_mixup_loss``, with a ``PLUpdater`` attached to it.

    Args:
        model: network to optimise.
        optimiser: optimiser stepping the model's parameters.
        pool: pseudo-labelled dataset handed to the ``PLUpdater``.
        alpha: mixup parameter.
        num_classes: number of classes, handed to the ``PLUpdater``.
        log_dir: handed to the ``PLUpdater``.
        device: passed to ``_map_device`` with each batch and to the
            ``PLUpdater``.

    Returns:
        An ``Engine`` whose per-iteration output is
        ``(loss, img_raw, img_aug, target, idx, mark)``.
    """
    criterion = reg_mixup_loss()

    def _step(engine: Engine, batch):
        model.train()
        # batch layout: raw image, augmented image, target, index, mark
        img_raw, img_aug, target, idx, mark = batch
        img_raw, img_aug, target, idx, mark = _map_device(
            [img_raw, img_aug, target, idx, mark], device
        )
        # only the augmented images are mixed and used for the update
        mixed_inputs, targets_a, targets_b, lam = mixup(img_aug, target, alpha=alpha)
        loss = criterion(model(mixed_inputs), targets_a, targets_b, lam)
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        return loss.item(), img_raw, img_aug, target, idx, mark

    engine = Engine(_step)
    updater = PLUpdater(
        model, pool, log_dir=log_dir, num_class=num_classes, device=device
    )
    updater.attach(engine)
    return engine
class PLUpdater:
    """
    Keeps a pool's pseudo-labels up to date while a trainer engine runs.

    After every iteration the model re-predicts the raw pseudo-labelled
    images of that batch and the softmax outputs are stored in a buffer.
    At the end of every epoch the buffer becomes the pool's new targets and,
    if ``log_dir`` is set, ``_calib_metrics`` is run on the pool.
    """

    def __init__(
        self,
        model: nn.Module,
        pool: PseudoLabelledDataset,
        log_dir: str,
        num_class: int,
        device=None,
    ):
        """
        Args:
            model: network producing log-softmax outputs.
            pool: dataset whose targets are overridden every epoch.
            log_dir: forwarded to ``_calib_metrics``; ``None`` disables it.
            num_class: number of classes (width of the pseudo-label buffer).
            device: forwarded to ``_calib_metrics``.
        """
        # CPU buffer holding one row of class probabilities per pool instance
        self._pseudo_labels = torch.empty(size=(len(pool), num_class))
        self._model = model
        self._pool = pool
        self._log_dir = log_dir
        self._device = device
        # tracks which pool instances were re-labelled during the current epoch
        self._sanity_check_mask = torch.zeros(len(pool), dtype=torch.bool)

    def attach(self, engine: Engine):
        """Registers the per-iteration and per-epoch handlers on ``engine``."""
        engine.add_event_handler(Events.ITERATION_COMPLETED, self._on_iteration_end)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self._on_epoch_end)

    def _on_iteration_end(self, engine: Engine):
        """Re-predicts the pseudo-labelled instances of the finished batch.

        Leaves the model in eval mode.
        """
        # after iteration ended
        _, img_raw, img_aug, target, idx, mark = engine.state.output
        with torch.no_grad():
            self._model.eval()
            # NOTE(review): assumes `mark` compares element-wise into a boolean
            # mask usable for tensor indexing -- confirm against the collate fn
            pld_mask = mark == DataMarker.PSEUDO_LABELLED
            # unaugmented, raw, pseudo-labelled images
            pld_img = img_raw[pld_mask]
            # get *softmax* predictions -- exponentiate the output!
            new_pld = self._model(pld_img).exp().detach().cpu()
            # The buffers live on the CPU, so the indices must too: the batch
            # may have been moved to another device by the training step.
            pool_idx = idx[pld_mask].cpu()
            self._pseudo_labels[pool_idx] = new_pld
            self._sanity_check_mask[pool_idx] = True

    def _on_epoch_end(self, engine: Engine):
        """Pushes the refreshed pseudo-labels to the pool and optionally logs
        calibration metrics.

        Raises:
            AssertionError: if some pool instance was not re-labelled during
                the epoch.
        """
        # sanity check!
        # assert that all unlabelled data has been pseudo-labeled in this epoch
        assert (
            self._sanity_check_mask.all()
        ), "Some instances in pool are not pseudo-labelled. Something went wrong."
        # reset mask
        self._sanity_check_mask = torch.zeros(
            self._pseudo_labels.size(0), dtype=torch.bool
        )
        # clone so that later in-place buffer updates do not alter the
        # tensor stored by the pool
        self._pool.override_targets(self._pseudo_labels.clone())
        if self._log_dir is not None:
            # original pool labels w/o augmentation and metadata from PseudoLabelledDataset
            with self._pool.no_augmentation():
                with self._pool.no_fluff():
                    with self._pool.original_labels():
                        _calib_metrics(
                            self._model,
                            self._pool,
                            self._log_dir,
                            other_engine=engine,
                            device=self._device,
                        )
def _calib_metrics(
    model,
    ds,
    log_dir,
    other_engine=None,
    device=None,
    pred_transform=lambda x: x.exp(),
):
    """
    Runs ``model`` over ``ds`` once with a ``PLPredictionSaver`` attached.

    Args:
        model: network to evaluate.
        ds: dataset to iterate, in order, in batches of 512.
        log_dir: handed to the ``PLPredictionSaver``.
        other_engine: if given, the saver takes its global step from it.
        device: device for the evaluator.
        pred_transform: handed to the ``PLPredictionSaver``; exponentiates
            the predictions by default.
    """
    # given a model and dataset, runs one epoch to calculate the calibration metrics
    if torch.cuda.is_available():
        loader_options = dict(num_workers=4, pin_memory=True)
    else:
        loader_options = {}
    loader = torchdata.DataLoader(ds, shuffle=False, batch_size=512, **loader_options)

    evaluator = create_supervised_evaluator(model, metrics=None, device=device)
    saver = PLPredictionSaver(
        log_dir=log_dir,
        pred_transform=pred_transform,
    )
    saver.attach(evaluator)
    if other_engine is not None:
        saver.global_step_from_engine(other_engine)

    evaluator.run(loader)
class _WithTransform(torchdata.Dataset):
    """
    Dataset view that applies ``transform`` to the input of every item.

    With ``has_targets`` set, items of the wrapped dataset are ``(x, y)``
    pairs and only ``x`` is transformed; otherwise the whole item is
    passed to ``transform``.
    """

    def __init__(self, dataset: torchdata.Dataset, transform, has_targets):
        super().__init__()
        self._dataset = dataset
        self._transform = transform
        self._has_targets = has_targets

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        item = self._dataset[idx]
        if not self._has_targets:
            # (x,) only
            return self._transform(item)
        features, label = item
        return self._transform(features), label
def dataset_transform_functor(transform, has_targets: bool = False):
    """
    Returns a function that wraps a dataset in ``_WithTransform``.

    Args:
        transform: callable applied to every item's input.
        has_targets: whether items of the dataset to be wrapped are
            ``(x, y)`` pairs rather than bare inputs.
    """

    def _wrap(dataset: torchdata.Dataset) -> torchdata.Dataset:
        wrapped = _WithTransform(dataset, transform, has_targets)
        return wrapped

    return _wrap
# Edited on Sat Jan 2 21:17:41 GMT 2021
# `temp_ds_transform` and `dataset_transform_functor` are two names bound to
# the same function object; both are kept for backward-compatibility
# because naming functions is hard.
temp_ds_transform = dataset_transform_functor