r"""
Useful samplers when training with small datasets
"""
import torch.utils.data as torchdata
import numpy as np
from typing import Optional, Union
from itertools import chain
def _safe_ceil(num, denom):
return (num // denom) + ((num % denom) != 0)
class EpochExtender(torchdata.Sampler):
    def __init__(self, dataset: torchdata.Dataset, by: int):
        r"""
        Sampler that makes a single epoch consist of `by` consecutive shuffled
        passes over `dataset`. The random state depends on numpy's RNG.

        Args:
            dataset (torch.utils.data.Dataset): dataset object
            by (int): number of passes over the dataset in one epoch. Must be at least 1.
        """
        super().__init__(dataset)
        assert by >= 1
        self._dataset = dataset
        self._by = by

    def __iter__(self):
        # Each pass draws its own fresh permutation, and only once the
        # previous pass has been fully consumed.
        for _ in range(self._by):
            yield from np.random.permutation(len(self._dataset))

    def __len__(self):
        # One epoch covers the dataset `by` times.
        return self._by * len(self._dataset)
class RandomFixedLengthSampler(torchdata.Sampler):
    # Adapted from BatchBALD Redux with modifications
    # https://github.com/BlackHC/batchbald_redux/blob/110161db3208d4df1d47146a7ac76a9794d1cab7/batchbald_redux/active_learning.py#L120
    def __init__(
        self, dataset: torchdata.Dataset, length: int, shuffle: Optional[bool] = False
    ):
        r"""
        Extends the epoch to `length` samples when the dataset is smaller than
        `length`. The number of samples in one epoch is `max(length, len(dataset))`.

        If `length` > `len(dataset)`, the indices are a random permutation of
        `range(length)` taken modulo `len(dataset)`, so data points are recycled
        until `length` indices have been produced.

        Otherwise (`len(dataset)` >= `length`), every index is yielded exactly once:
        in order, like a `SequentialSampler`, if `shuffle` is `False`, and in random
        order, like a `RandomSampler`, if `shuffle` is `True`.

        The random state depends on numpy's RNG.

        Args:
            dataset (torch.utils.data.Dataset): dataset object
            length (int): the target length to achieve. Must be positive.
            shuffle (bool, optional): shuffle the indices if `len(dataset)` >= `length`. This
                parameter is ignored otherwise (default = `False`).
        """
        super().__init__(dataset)
        assert length > 0, "What are you trying to pull?"
        self._dataset = dataset
        self._length = length
        self._shuffle = shuffle

    def __len__(self):
        return max(self._length, len(self._dataset))

    def __iter__(self):
        if self._length > len(self._dataset):
            # Oversample: shuffle `length` slots, then fold them back onto
            # valid dataset indices.
            indices = np.random.permutation(self._length) % len(self._dataset)
        elif self._shuffle:
            indices = np.random.permutation(len(self._dataset))
        else:
            indices = range(len(self._dataset))
        return iter(indices)
class MinLabelledSampler(torchdata.BatchSampler):
    def __init__(
        self,
        labelled: torchdata.Dataset,
        pseudo_labelled: torchdata.Dataset,
        batch_size: int,
        min_labelled: Union[int, float],
    ):
        r"""
        Given labelled and pseudo_labelled datasets, returns a batch sampler that always yields
        exactly `min_labelled` points from the `labelled` dataset in every batch. If there are
        not enough points in `labelled`, then the points are recycled. Note that all the data
        points are shuffled; the random state depends on numpy's RNG. The concatenated dataset
        is assumed to have labelled followed by pseudo_labelled, i.e.
        `torch.utils.data.ConcatDataset((labelled, pseudo_labelled))`.

        One epoch walks through `pseudo_labelled` exactly once, so the final batch may
        contain fewer pseudo-labelled points than the others.

        Args:
            labelled (torch.utils.data.Dataset): labelled dataset
            pseudo_labelled (torch.utils.data.Dataset): pseudo_labelled dataset
            batch_size (int): batch size
            min_labelled (int, float): number of points per batch that come from `labelled`.
                Must be smaller than `batch_size`. An integer (Python or numpy) is used as an
                absolute count. Any other value is treated as a proportion of `batch_size`
                and rounded up to the next whole point.
        """
        if isinstance(min_labelled, (int, np.integer)):
            # Absolute count; normalise numpy integers to a plain int.
            min_labelled = int(min_labelled)
        else:
            # Proportion of the batch, rounded up. A true ceiling is used rather than
            # `round(x + 0.5)`, because `round` resolves exact halves to the nearest
            # even integer and would inflate whole-number products that are odd
            # (e.g. 0.5 * 6 would give 4 instead of 3).
            min_labelled = int(np.ceil(min_labelled * batch_size))
        assert batch_size > min_labelled
        # NOTE: BatchSampler.__init__ is not called; this class does not wrap
        # an underlying sampler.
        self._labelled = labelled
        self._pseudo_labelled = pseudo_labelled
        self._batch_size = batch_size
        self._min_labelled = min_labelled
        self._unlabelled_batch_size = batch_size - min_labelled
        # because ignite 0.3.0 :(
        # (presumably it reads a `sampler` attribute on batch samplers -- TODO confirm)
        self.sampler = None

    def __len__(self):
        # Number of batches needed to cover every pseudo-labelled point once.
        # Integer ceiling division keeps this safe from floating point rounding errors.
        return _safe_ceil(len(self._pseudo_labelled), self._unlabelled_batch_size)

    def __iter__(self):
        num_batches = len(self)
        num_unlabelled = self._unlabelled_batch_size
        # Shuffle one slot per labelled point needed over the whole epoch, then fold the
        # slots onto valid labelled indices so that labelled points are recycled.
        labelled_indices = np.random.permutation(
            num_batches * self._min_labelled
        ) % len(self._labelled)
        unlabelled_indices = np.random.permutation(len(self._pseudo_labelled))
        for i in range(num_batches):
            r1 = labelled_indices[i * self._min_labelled : (i + 1) * self._min_labelled]
            # Pseudo-labelled points sit after the labelled ones in the concatenated
            # dataset, hence the offset.
            r2 = unlabelled_indices[
                i * num_unlabelled : (i + 1) * num_unlabelled
            ] + len(self._labelled)
            yield np.r_[r1, r2]