alr.training.samplers
Samplers that are useful when training with small datasets.
Classes
EpochExtender
RandomFixedLengthSampler
class alr.training.samplers.RandomFixedLengthSampler(dataset: torch.utils.data.dataset.Dataset, length: int, shuffle: Optional[bool] = False)
Bases: torch.utils.data.sampler.Sampler
Extends the epoch by sampling with replacement from the provided dataset until length samples are drawn. The number of samples in one epoch is max(length, len(dataset)). Consequently, if len(dataset) > length, no extension is needed and this sampler behaves exactly like a SequentialSampler if shuffle is False, and like a RandomSampler if shuffle is True. Randomness comes from numpy’s RNG, so results depend on numpy’s seed.
Parameters:
- dataset (torch.utils.data.Dataset) – dataset to sample from.
- length (int) – target number of samples per epoch.
- shuffle (bool, optional) – shuffle the indices if len(dataset) > length; ignored otherwise (default = False).
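The epoch-length rule above can be illustrated with a minimal, stdlib-only sketch. FixedLengthSamplerSketch and its seed argument are hypothetical names, not part of alr; the class only implements the __iter__/__len__ protocol that PyTorch samplers follow, and it draws randomness from Python's random.Random instead of numpy's RNG, so numpy's seed does not affect it.

```python
import random
from typing import Iterator, Optional, Sized


class FixedLengthSamplerSketch:
    """Hypothetical stand-in for RandomFixedLengthSampler (not the alr code).

    Implements the sampler protocol (__iter__ and __len__) using Python's
    `random` module rather than numpy's RNG.
    """

    def __init__(self, dataset: Sized, length: int, shuffle: bool = False,
                 seed: Optional[int] = None) -> None:
        self.n = len(dataset)
        self.length = length
        self.shuffle = shuffle
        self.rng = random.Random(seed)  # hypothetical seed argument

    def __len__(self) -> int:
        # Samples per epoch: max(length, len(dataset)).
        return max(self.length, self.n)

    def __iter__(self) -> Iterator[int]:
        if self.length > self.n:
            # Extend the epoch: draw `length` indices with replacement.
            return iter([self.rng.randrange(self.n) for _ in range(self.length)])
        if self.shuffle:
            # One pass over the dataset in random order (RandomSampler-like).
            indices = list(range(self.n))
            self.rng.shuffle(indices)
            return iter(indices)
        # One pass in order (SequentialSampler-like).
        return iter(range(self.n))


small = list(range(10))
print(len(FixedLengthSamplerSketch(small, length=25, seed=0)))  # 25
print(list(FixedLengthSamplerSketch(small, length=5)))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

PyTorch's DataLoader accepts any iterable of indices that implements __len__ as its sampler argument, so an object like this could be passed as DataLoader(dataset, sampler=...); with length=25 and a 10-item dataset, one epoch then visits 25 (possibly repeated) items.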
MinLabelledSampler
class alr.training.samplers.MinLabelledSampler(labelled: torch.utils.data.dataset.Dataset, pseudo_labelled: torch.utils.data.dataset.Dataset, batch_size: int, min_labelled: Union[int, float])
Bases: torch.utils.data.sampler.BatchSampler
Given labelled and pseudo_labelled datasets, this batch sampler yields batches that each contain exactly min_labelled points from the labelled dataset. If there are not enough points in labelled, its points are recycled. All data points are shuffled. The concatenated dataset is assumed to place labelled before pseudo_labelled, i.e. torch.utils.data.ConcatDataset((labelled, pseudo_labelled)).
Parameters:
- labelled (torch.utils.data.Dataset) – labelled dataset.
- pseudo_labelled (torch.utils.data.Dataset) – pseudo-labelled dataset.
- batch_size (int) – batch size.
- min_labelled (int, float) – minimum number of points in each batch that come from labelled. Must be smaller than batch_size. If a float is provided, it is treated as a proportion of batch_size.
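The batching rule can be sketched the same way. MinLabelledBatchSketch is a hypothetical illustration, not the alr implementation; the rounding of a float min_labelled (down, via int), the epoch length (one pass over the pseudo-labelled points, so the last batch may be short) and the seed argument are all assumptions. Indices follow the ConcatDataset((labelled, pseudo_labelled)) layout, so pseudo-labelled indices start at len(labelled).

```python
import math
import random
from typing import Iterator, List, Optional, Sized, Union


class MinLabelledBatchSketch:
    """Hypothetical batch sampler mirroring MinLabelledSampler's documented behaviour."""

    def __init__(self, labelled: Sized, pseudo_labelled: Sized, batch_size: int,
                 min_labelled: Union[int, float], seed: Optional[int] = None) -> None:
        if isinstance(min_labelled, float):
            # Assumption: a float is a proportion of batch_size, rounded down.
            min_labelled = int(min_labelled * batch_size)
        if not 0 <= min_labelled < batch_size:
            raise ValueError("min_labelled must be smaller than batch_size")
        if min_labelled > 0 and len(labelled) == 0:
            raise ValueError("labelled dataset is empty")
        self.n_lab = len(labelled)
        self.n_pseudo = len(pseudo_labelled)
        self.k = min_labelled
        self.per_batch_pseudo = batch_size - min_labelled
        self.rng = random.Random(seed)  # hypothetical seed argument

    def __len__(self) -> int:
        # Assumption: one epoch is a single pass over the pseudo-labelled points.
        return math.ceil(self.n_pseudo / self.per_batch_pseudo)

    def _labelled_stream(self) -> Iterator[int]:
        # Shuffled pass over labelled indices, reshuffled and recycled when exhausted.
        while True:
            order = list(range(self.n_lab))
            self.rng.shuffle(order)
            yield from order

    def __iter__(self) -> Iterator[List[int]]:
        # Pseudo-labelled points start at offset len(labelled) in the ConcatDataset.
        pseudo = [self.n_lab + i for i in range(self.n_pseudo)]
        self.rng.shuffle(pseudo)
        lab = self._labelled_stream()
        for start in range(0, self.n_pseudo, self.per_batch_pseudo):
            batch = [next(lab) for _ in range(self.k)]
            batch += pseudo[start:start + self.per_batch_pseudo]
            self.rng.shuffle(batch)  # mix labelled and pseudo-labelled points
            yield batch


sampler = MinLabelledBatchSketch(range(3), range(10), batch_size=4, min_labelled=0.5, seed=0)
for batch in sampler:
    print(batch)  # each batch: 2 indices below 3 (labelled), 2 from 3..12 (pseudo-labelled)
```

DataLoader's batch_sampler argument likewise accepts an iterable of index lists, so such an object could be combined with DataLoader(ConcatDataset((labelled, pseudo_labelled)), batch_sampler=...). Here only 3 labelled points serve 5 batches, so they are recycled across batches.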