import numpy as np
from typing import Optional, Tuple, Callable, Union, List
import torch.utils.data as torchdata
from torch.optim.lr_scheduler import ReduceLROnPlateau
from alr.training.pl_mixup import (
mixup,
reg_mixup_loss,
PseudoLabelledDataset,
onehot_transform,
create_warmup_trainer,
DataMarker,
)
from alr.training.utils import EarlyStopper, PerformanceTracker
from alr.utils._type_aliases import _DeviceType
from alr.training.samplers import RandomFixedLengthSampler, MinLabelledSampler
from alr.utils import _map_device
from torch import nn
from torch.nn import functional as F
import torch
from ignite.engine import create_supervised_evaluator, Events
from ignite.metrics import Accuracy, Loss
from pathlib import Path
class PLMixupEnsembleTrainer:
    """Two-stage pseudo-label + mixup trainer for an ensemble of models.

    Stage 1 warm-starts every model on the labelled set only, with early
    stopping on validation accuracy. The ensemble then pseudo-labels the
    pool, and stage 2 trains each model with mixup on the concatenation of
    the labelled and pseudo-labelled sets, refreshing the pseudo-labels with
    the ensemble after every epoch.

    Validation uses ``F.nll_loss`` and :class:`Ensemble` exponentiates model
    outputs, so models are expected to return log-softmax values.
    """

    def __init__(
        self,
        models: List[nn.Module],
        optimiser: str,
        train_transform: Callable,
        test_transform: Callable,
        optimiser_kwargs: dict,
        loader_kwargs: dict,
        rfls_length: int,
        log_dir: Optional[str] = None,
        alpha: Optional[float] = 1.0,
        min_labelled: Optional[Union[int, float]] = 16,
        num_classes: Optional[int] = 10,
        data_augmentation: Optional[Callable] = None,
        batch_size: Optional[int] = 100,
        patience: Optional[Union[Tuple[int, int], int]] = (5, 25),
        lr_patience: Optional[int] = 10,
        device: _DeviceType = None,
    ):
        """Store the training configuration; no work is done until ``fit``.

        Args:
            models: Ensemble members; each is trained independently.
            optimiser: Name of an optimiser class in ``torch.optim``
                (looked up with ``getattr``).
            train_transform: Transform applied to the labelled and pool sets.
            test_transform: Transform applied to the validation set.
            optimiser_kwargs: Keyword arguments for the optimiser constructor.
            loader_kwargs: Extra keyword arguments forwarded to every
                ``DataLoader`` built by ``fit``.
            rfls_length: Length handed to ``RandomFixedLengthSampler`` for
                the stage-1 loader.
            log_dir: Stored on the instance; not read anywhere in this class.
            alpha: Mixup ``alpha`` forwarded to ``plmixup_train``.
            min_labelled: Forwarded to ``MinLabelledSampler`` in stage 2.
            num_classes: Number of classes for the one-hot target transform.
            data_augmentation: Augmentation applied to the labelled and pool
                sets.
            batch_size: Batch size for the stage-1 loader and the stage-2
                batch sampler.
            patience: Early-stopping patience. A single int is used for both
                stages; a pair is ``(stage 1, stage 2)``.
            lr_patience: ``ReduceLROnPlateau`` patience used in stage 2.
            device: Device used for training and evaluation.
        """
        self._models = models
        self._train_transform = train_transform
        self._test_transform = test_transform
        self._data_augmentation = data_augmentation
        self._optim_kwargs = optimiser_kwargs
        self._optimiser = optimiser
        self._device = device
        self._batch_size = batch_size
        self._patience = patience
        self._loader_kwargs = loader_kwargs
        self._rfls_length = rfls_length
        self._min_labelled = min_labelled
        self._num_classes = num_classes
        self._alpha = alpha
        self._lr_patience = lr_patience
        self._log_dir = log_dir
        # Populated by ``fit``: stacked history of the pool's soft labels.
        self.soft_label_history = None

    def _instantiate_optimiser(self, model: nn.Module):
        """Build a fresh optimiser of the configured type for ``model``."""
        optimiser_cls = getattr(torch.optim, self._optimiser)
        return optimiser_cls(model.parameters(), **self._optim_kwargs)

    def _make_evaluator(self, model):
        """Return an Ignite evaluator reporting ``acc`` and NLL ``loss``."""
        return create_supervised_evaluator(
            model,
            metrics={"acc": Accuracy(), "loss": Loss(F.nll_loss)},
            device=self._device,
        )

    def fit(
        self,
        train: torchdata.Dataset,
        val: torchdata.Dataset,
        pool: torchdata.Dataset,
        epochs: Optional[Tuple[int, int]] = (50, 400),
    ):
        """Run both training stages and record the pool's soft labels.

        Args:
            train: Labelled training set.
            val: Validation set used for early stopping and LR scheduling.
            pool: Pool whose targets are overridden with ensemble
                pseudo-labels.
            epochs: Maximum number of epochs for ``(stage 1, stage 2)``.

        Returns:
            A dict with ``"val_loss"`` and ``"val_acc"`` (one list per model,
            spanning both stages) and ``"override_acc"`` (pseudo-label
            accuracy after stage 1 and after every stage-2 epoch).
        """
        if isinstance(self._patience, int):
            pat1 = pat2 = self._patience
        else:
            pat1, pat2 = self._patience[0], self._patience[1]

        train = PseudoLabelledDataset(
            train,
            mark=DataMarker.LABELLED,
            transform=self._train_transform,
            augmentation=self._data_augmentation,
            target_transform=onehot_transform(self._num_classes),
        )
        pool = PseudoLabelledDataset(
            pool,
            mark=DataMarker.PSEUDO_LABELLED,
            transform=self._train_transform,
            augmentation=self._data_augmentation,
        )
        val = PseudoLabelledDataset(
            val,
            mark=DataMarker.LABELLED,
            transform=self._test_transform,
        )
        val._with_metadata = False

        train_loader = torchdata.DataLoader(
            train,
            batch_size=self._batch_size,
            sampler=RandomFixedLengthSampler(train, self._rfls_length, shuffle=True),
            **self._loader_kwargs,
        )
        pool_loader = torchdata.DataLoader(
            pool, batch_size=512, shuffle=False, **self._loader_kwargs
        )
        val_loader = torchdata.DataLoader(
            val, batch_size=512, shuffle=False, **self._loader_kwargs
        )

        models = self._models
        optimisers = [self._instantiate_optimiser(m) for m in models]
        history = {
            "val_loss": [[] for _ in range(len(models))],
            "val_acc": [[] for _ in range(len(models))],
            "override_acc": [],
        }

        # ---- stage 1: warm-up on labelled data only ----
        print("Commencing stage 1 ...")
        with train.no_fluff():
            for idx, (m, o) in enumerate(zip(models, optimisers)):
                print(f"\tTraining model {idx + 1} of {len(models)}")
                val_eval = self._make_evaluator(m)
                trainer = create_warmup_trainer(
                    m,
                    optimiser=o,
                    device=self._device,
                )
                es = EarlyStopper(
                    m, patience=pat1, trainer=trainer, key="acc", mode="max"
                )
                es.attach(val_eval)

                # The handler only fires inside ``trainer.run`` below, i.e.
                # within this iteration, so closing over the loop variables
                # ``idx`` and ``val_eval`` is safe.
                @trainer.on(Events.EPOCH_COMPLETED)
                def _log(_):
                    metrics = val_eval.run(val_loader).metrics
                    acc, loss = metrics["acc"], metrics["loss"]
                    history["val_acc"][idx].append(acc)
                    history["val_loss"][idx].append(loss)

                trainer.run(train_loader, max_epochs=epochs[0])
                es.reload_best()
                print(
                    f"\tModel {idx + 1} of {len(models)} done, "
                    f"acc = {max(history['val_acc'][idx])}"
                )

        # pseudo-label points
        plab_acc = self._override_pool_labels(pool, pool_loader)
        print(f"End of stage 1: overridden labels' acc: {plab_acc}")
        history["override_acc"].append(plab_acc)

        # ---- stage 2: mixup on labelled + pseudo-labelled data ----
        full_dataset = torchdata.ConcatDataset((train, pool))
        fds_loader = torchdata.DataLoader(
            full_dataset,
            batch_sampler=MinLabelledSampler(
                train,
                pool,
                batch_size=self._batch_size,
                min_labelled=self._min_labelled,
            ),
            **self._loader_kwargs,
        )
        # reset optimiser
        optimisers = [self._instantiate_optimiser(m) for m in models]
        schedulers = [
            ReduceLROnPlateau(
                op,
                mode="max",
                factor=0.1,
                patience=self._lr_patience,
                verbose=True,
                min_lr=1e-3,
            )
            for op in optimisers
        ]
        model_tracker = [PerformanceTracker(m, pat2) for m in models]
        # One evaluator per model, built once and re-run every epoch instead
        # of being rebuilt for each model on each epoch.
        evaluators = [self._make_evaluator(m) for m in models]
        current_epoch = 0
        print("Commencing stage 2 ...")
        while any(not mt.done for mt in model_tracker) and current_epoch < epochs[1]:
            current_epoch += 1
            for idx, (m, o, mt, scheduler, evaluator) in enumerate(
                zip(models, optimisers, model_tracker, schedulers, evaluators)
            ):
                if mt.done:
                    continue
                # train model m for one epoch
                plmixup_train(fds_loader, m, o, alpha=self._alpha, device=self._device)
                # get val acc for model m
                metrics = evaluator.run(val_loader).metrics
                acc, loss = metrics["acc"], metrics["loss"]
                mt.step(acc)
                scheduler.step(acc)
                history["val_acc"][idx].append(acc)
                history["val_loss"][idx].append(loss)
                print(f"\tModel {idx + 1} val acc at epoch {current_epoch} = {acc:.4f}")

            # reload best weights if haven't done so
            for idx, mt in enumerate(model_tracker):
                if mt.done and not mt.reloaded:
                    print(f"\tModel {idx + 1} converged, reloading weights")
                    mt.reload_best()

            plab_acc = self._override_pool_labels(pool, pool_loader)
            history["override_acc"].append(plab_acc)
            print(
                f"\tEpoch {current_epoch}/{epochs[1]}: "
                f"mean val acc = {np.mean([h[-1] for h in history['val_acc']]):.4f}; "
                f"pseudo-label acc = {plab_acc:.4f}"
            )

        # The last entry of pool.label_history comes from the final ensemble
        # state. NOTE(review): a model that is still not `done` when the
        # epoch budget runs out is never passed to `reload_best`, so it keeps
        # its last-epoch weights - confirm this is intended.
        self.soft_label_history = torch.stack(pool.label_history, dim=0)
        return history

    def _override_pool_labels(self, pool, pool_loader):
        """Replace the pool's targets with ensemble predictions.

        Returns:
            ``pool.override_accuracy`` as reported after the override.
        """
        ensemble = Ensemble(self._models)
        pseudo_labels = []
        with pool.no_augmentation(), pool.no_fluff(), torch.no_grad():
            for x, _ in pool_loader:
                x = x.to(self._device)
                # NOTE: ensemble's forward call returns softmax probabilities
                pseudo_labels.append(ensemble(x).detach().cpu())
        pool.override_targets(torch.cat(pseudo_labels))
        return pool.override_accuracy

    def evaluate(self, data_loader: torchdata.DataLoader) -> dict:
        """Evaluate the whole ensemble on ``data_loader``.

        Returns:
            Ignite metrics dict with ``"acc"`` and ``"loss"``.
        """
        # return_log=True because the loss metric is F.nll_loss, which
        # expects log-probabilities
        ensemble = Ensemble(self._models, return_log=True)
        return self._make_evaluator(ensemble).run(data_loader).metrics
class Ensemble:
    """Inference-only average of several classifiers.

    Each member is assumed to return log-softmax outputs; they are
    exponentiated and averaged over the members. The class mimics the parts
    of the ``nn.Module`` interface (``__call__``, ``eval``, ``train``) that
    evaluation code relies on, but it can never be put in training mode.
    """

    # Added to the averaged probabilities before taking the log, so that a
    # class with zero probability does not yield ``-inf``.
    _LOG_EPS = 1e-5

    def __init__(self, models: list, return_log: bool = False):
        """
        Args:
            models: Ensemble members (assumed to return log-softmax outputs).
            return_log: If true, ``forward`` returns log-probabilities
                instead of probabilities.
        """
        # assumes models return log-softmax probabilities
        self.models = models
        self.return_log = return_log

    def forward(self, x):
        """Return the member-averaged probabilities (or their log) for ``x``."""
        mean_probs = self.get_preds(x).mean(dim=0)
        if self.return_log:
            return torch.log(mean_probs + self._LOG_EPS)
        return mean_probs

    def get_preds(self, x):
        """Return every member's probabilities stacked along a new dim 0.

        Each member is switched to eval mode before it is called, and no
        gradients are tracked.
        """
        member_probs = []
        with torch.no_grad():
            for m in self.models:
                m.eval()
                member_probs.append(m(x).exp())
        return torch.stack(member_probs)

    def evaluate(self, loader, device):
        """Return the ensemble's accuracy over ``loader``.

        Args:
            loader: Iterable of ``(x, y)`` batches; ``y`` holds class indices.
            device: Device the batches are moved to.

        Raises:
            ZeroDivisionError: If ``loader`` yields no samples.
        """
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                correct += (self.forward(x).argmax(dim=1) == y).sum().item()
                total += y.size(0)
        return correct / total

    def __call__(self, x):
        return self.forward(x)

    def eval(self):
        """No-op kept for ``nn.Module`` compatibility; returns ``self``."""
        return self.train(mode=False)

    def train(self, mode=True):
        """Reject training mode; ``train(False)`` is a no-op returning ``self``.

        Raises:
            AssertionError: If ``mode`` is truthy.
        """
        # should never ever be in training mode; raised explicitly (rather
        # than via ``assert``) so the guard survives ``python -O``
        if mode:
            raise AssertionError("Ensemble cannot be put in training mode")
        return self

    def save_weights(self, prefix: str):
        """Save each member's ``state_dict`` to ``{prefix}_model_{i}.pt`` (1-based)."""
        for mi, m in enumerate(self.models, 1):
            torch.save(m.state_dict(), f"{prefix}_model_{mi}.pt")
def plmixup_train(loader, model, optimiser, alpha, device):
    """Train ``model`` for one pass over ``loader`` with mixup.

    Args:
        loader: Iterable of 5-tuples; only the second item (the augmented
            image batch) and the third (the targets) are used.
        model: Model to optimise; it is put in training mode first.
        optimiser: Optimiser stepping ``model``'s parameters.
        alpha: ``alpha`` argument forwarded to ``mixup``.
        device: Device the batch is moved to before the forward pass.
    """
    model.train()
    # The loss factory is called without arguments, so build the criterion
    # once rather than once per batch.
    # NOTE(review): it is built with its own defaults - confirm these agree
    # with the trainer's configured number of classes.
    criterion = reg_mixup_loss()
    for _, img_aug, target, _, _ in loader:
        img_aug, target = _map_device([img_aug, target], device)
        xp, y1, y2, lamb = mixup(img_aug, target, alpha=alpha)
        preds = model(xp)
        loss = criterion(preds, y1, y2, lamb)
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()