alr.data

Classes

UnlabelledDataset

class alr.data.UnlabelledDataset(dataset: torch.utils.data.dataset.Dataset, label_fn: Optional[Callable[[torch.utils.data.dataset.Dataset], torch.utils.data.dataset.Dataset]] = None, debug: Optional[bool] = False)

Bases: torch.utils.data.dataset.Dataset

A wrapper class that manages the unlabelled dataset by providing a simple interface to label() specific points and remove them from the underlying dataset. Note that it doesn’t physically remove points from dataset; rather, it provides an abstraction over it that removes them logically. Furthermore, if label_fn is not provided, this class infers that the provided “unlabelled” dataset is, in fact, labelled. This is especially useful for benchmarking studies.

Parameters:
  • dataset (torch.utils.data.Dataset) – unlabelled dataset
  • label_fn (Callable[[Dataset], Dataset], optional) – a function that takes an unlabelled dataset and returns another dataset that’s fully labelled. If this is not provided, then dataset must already be labelled.
  • debug (bool, optional) – Turn debug mode on. If True, indexing this dataset returns both (x, y) instead of just x; this is useful for research purposes. Note that label_fn must be None, otherwise an error is raised.
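The constructor rules above (label_fn omitted means the pool is already labelled; debug exposes y and requires label_fn=None) can be sketched as follows. This is a minimal illustration, not alr’s code: the class name SketchPool is hypothetical, and a plain Python list of (x, y) pairs stands in for a torch Dataset.

```python
from typing import Any, Callable, Optional, Sequence


class SketchPool:
    """Hypothetical stand-in for UnlabelledDataset's constructor logic (not alr's code)."""

    def __init__(self, dataset: Sequence[Any],
                 label_fn: Optional[Callable] = None,
                 debug: bool = False) -> None:
        if debug and label_fn is not None:
            # Mirrors the documented rule: debug mode requires label_fn=None.
            raise ValueError("debug=True requires label_fn to be None")
        self._dataset = dataset
        self._label_fn = label_fn
        self._debug = debug

    def __len__(self) -> int:
        return len(self._dataset)

    def __getitem__(self, idx: int) -> Any:
        item = self._dataset[idx]
        if self._label_fn is None:
            # No label_fn: the wrapped data is assumed to yield (x, y) pairs.
            x, y = item
            return (x, y) if self._debug else x
        # With a label_fn, the wrapped data is assumed to yield x only.
        return item


data = [(0.1, 0), (0.2, 1)]                  # a labelled pool of (x, y) pairs
print(SketchPool(data)[1])                   # 0.2
print(SketchPool(data, debug=True)[1])       # (0.2, 1)
```

Hiding y by default keeps benchmark code honest: the acquisition step only ever sees features unless debug mode is explicitly enabled.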
convert_idx(idxs: numpy.array) → numpy.array

Given a set of indices relative to the current state of UnlabelledDataset, return the true (absolute) indices into the original pool dataset.

Parameters:idxs (np.array) – sequence of indices

Returns:absolute index
Return type:np.array
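The relative-to-absolute mapping can be illustrated with NumPy. This sketch assumes, purely for illustration, that the pool state is tracked as a boolean mask (True where a point is still unlabelled); alr may store its state differently, and the two-argument convert_idx here is a hypothetical free function rather than the documented method.

```python
import numpy as np


def convert_idx(mask: np.ndarray, idxs: np.ndarray) -> np.ndarray:
    """Map indices relative to the remaining pool onto absolute pool indices.

    `mask` is a hypothetical boolean array, True where a point is still unlabelled.
    """
    remaining = np.flatnonzero(mask)  # absolute indices still in the pool, ascending
    return remaining[np.asarray(idxs)]


mask = np.ones(6, dtype=bool)
mask[[1, 3]] = False                           # points 1 and 3 have been labelled
print(convert_idx(mask, np.array([0, 1, 2])))  # [0 2 4]
```

After points 1 and 3 leave the pool, relative index 1 refers to absolute point 2, which is why indices must be converted before they are reported against the original dataset.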
label(idxs: Sequence[int]) → torch.utils.data.dataset.Dataset

Label and return the points specified by idxs according to the provided label_fn. These labelled points will no longer be part of this dataset. Note, however, that this is just an abstraction: the original dataset provided to the constructor is not modified. In other words, it does not lose points as a result of being labelled.

Parameters:idxs (Sequence[int]) – indices of points to label
Returns:a labelled dataset containing the points specified by idxs, labelled by label_fn.
Return type:torch.utils.data.Dataset
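The logical removal performed by label() can be sketched with a list of remaining absolute indices. LogicalPool is a hypothetical name and plain lists replace torch datasets; the point is that the wrapped data is never mutated while the pool view shrinks.

```python
from typing import Any, Callable, List, Sequence


class LogicalPool:
    """Hypothetical sketch of label(): points leave the pool logically, not physically."""

    def __init__(self, dataset: Sequence[Any],
                 label_fn: Callable[[List[Any]], List[Any]]) -> None:
        self._dataset = dataset                       # never modified
        self._label_fn = label_fn
        self._remaining = list(range(len(dataset)))   # absolute indices still unlabelled

    def __len__(self) -> int:
        return len(self._remaining)

    def __getitem__(self, idx: int) -> Any:
        return self._dataset[self._remaining[idx]]

    def label(self, idxs: Sequence[int]) -> List[Any]:
        absolute = [self._remaining[i] for i in idxs]  # relative -> absolute
        chosen = [self._dataset[i] for i in absolute]
        drop = set(absolute)
        self._remaining = [i for i in self._remaining if i not in drop]
        return self._label_fn(chosen)


pool = LogicalPool(["a", "b", "c", "d"],
                   label_fn=lambda xs: [(x, x.upper()) for x in xs])
print(pool.label([1, 2]))   # [('b', 'B'), ('c', 'C')]
print(len(pool), pool[1])   # 2 d
```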
labelled_classes

Return a list of classes that were labelled by the user (label_fn).

Returns:list of classes
Return type:list
labelled_indices

Returns a list of indices that were labelled in the past.

Returns:all the indices that were labelled by label()
Return type:list
reset() → None

Reset to the initial state: all labelled points become unlabelled again and are reintroduced into the pool.

Returns:None
Return type:NoneType
true_labels()

When the dataset is indexed within this context, it also returns the label. This is useful for debugging and evaluation purposes. It assumes label_fn was None, which indicates that the dataset came with labels to begin with. If label_fn is not None, this method does nothing: the indexed datum contains only the features, as usual.

Returns:self
Return type:UnlabelledDataset
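Because true_labels() returns self and is described as a context, one plausible reading is that it is used in a with statement. The sketch below assumes exactly that; the class PoolWithTrueLabels and its flag are hypothetical and do not reproduce alr’s implementation.

```python
from typing import Any, Sequence, Tuple


class PoolWithTrueLabels:
    """Hypothetical sketch of a true_labels() context (not alr's implementation)."""

    def __init__(self, dataset: Sequence[Tuple[Any, Any]]) -> None:
        self._dataset = dataset   # assumed to yield (x, y) because label_fn is None
        self._with_labels = False

    def true_labels(self) -> "PoolWithTrueLabels":
        self._with_labels = True
        return self               # returned so it can be used in a with statement

    def __enter__(self) -> "PoolWithTrueLabels":
        return self

    def __exit__(self, *exc: Any) -> None:
        self._with_labels = False  # restore features-only indexing

    def __getitem__(self, idx: int) -> Any:
        x, y = self._dataset[idx]
        return (x, y) if self._with_labels else x


pool = PoolWithTrueLabels([(0.5, 1), (0.7, 0)])
print(pool[0])            # 0.5
with pool.true_labels():
    print(pool[0])        # (0.5, 1)
print(pool[0])            # 0.5
```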

DataManager

class alr.data.DataManager(labelled: torch.utils.data.dataset.Dataset, unlabelled: alr.data.UnlabelledDataset, acquisition_fn: alr.acquisition.AcquisitionFunction)

Bases: object

A stateful data manager class

The labelled and unlabelled datasets are updated according to the points acquired by acquire(). acquisition_fn dictates which points should be chosen from the unlabelled pool. Similar to UnlabelledDataset, the original dataset, labelled, will not be modified as this class provides a logical abstraction. Whilst unlabelled is modified, the dataset that it is providing an abstraction over is not (see UnlabelledDataset).

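Parameters:
  • labelled (torch.utils.data.Dataset) – initial labelled dataset; it is not modified
  • unlabelled (UnlabelledDataset) – pool of unlabelled points to acquire from
  • acquisition_fn (alr.acquisition.AcquisitionFunction) – decides which points are chosen from the unlabelled pool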
acquire(b: int, transform=None) → Tuple[numpy.array, torch.utils.data.dataset.Dataset]

Acquire b points from the unlabelled dataset and add them to the labelled dataset.

Parameters:
  • b (int) – number of points to acquire at once
  • transform (Callable, optional) – transform the unlabelled dataset before giving it to the acquisition function. The function is expected to take and return a dataset.
Returns:

A tuple consisting of
  1. numpy array with indices that were selected by the acquisition function; and
  2. Subset-type dataset with the b points that were freshly labelled

Return type:

Tuple[np.array, torch.utils.data.Dataset]

Notes

The returned numpy array of indices indexes the original pool dataset, i.e., these are “absolute” indices relative to the original pool set rather than to the current unlabelled dataset.
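The acquire() contract (score a possibly transformed view of the pool, pick b points, return their absolute indices together with the freshly labelled points) can be sketched in plain Python. SketchManager and random_acquisition are hypothetical; a uniform random acquisition function stands in for alr’s AcquisitionFunction, and the pool is assumed to be already labelled (the label_fn=None case).

```python
import random
from typing import Any, Callable, List, Sequence, Tuple

Pair = Tuple[Any, Any]


def random_acquisition(pool: Sequence[Any], b: int, rng: random.Random) -> List[int]:
    """Hypothetical acquisition function: b distinct relative indices, uniformly at random."""
    return rng.sample(range(len(pool)), b)


class SketchManager:
    """Hypothetical sketch of DataManager.acquire(); not alr's implementation."""

    def __init__(self, labelled: List[Pair], pool: List[Pair], seed: int = 0) -> None:
        self._labelled = labelled                  # never mutated
        self._pool = pool                          # never mutated; assumed already labelled
        self._remaining = list(range(len(pool)))   # absolute indices still unlabelled
        self._acquired: List[int] = []             # absolute indices acquired so far
        self._rng = random.Random(seed)

    @property
    def n_labelled(self) -> int:
        return len(self._labelled) + len(self._acquired)

    @property
    def n_unlabelled(self) -> int:
        return len(self._remaining)

    def acquire(self, b: int,
                transform: Callable[[List[Pair]], List[Pair]] = lambda d: d
                ) -> Tuple[List[int], List[Pair]]:
        view = transform([self._pool[i] for i in self._remaining])  # view used for scoring only
        relative = random_acquisition(view, b, self._rng)
        absolute = [self._remaining[j] for j in relative]           # index the original pool
        chosen = set(absolute)
        self._remaining = [i for i in self._remaining if i not in chosen]
        self._acquired.extend(absolute)
        return absolute, [self._pool[i] for i in absolute]


pool = [(x, x % 2) for x in range(10)]
dm = SketchManager(labelled=[(100, 0)], pool=pool)
idxs, fresh = dm.acquire(b=3)
print(dm.n_labelled, dm.n_unlabelled)   # 4 7
```

Returning absolute indices means that indices from successive rounds can be compared or logged directly, even though the pool shrinks between calls.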

append_to_labelled(dataset: torch.utils.data.dataset.Dataset)

Logically appends given dataset to the labelled dataset. Again, this does not physically modify the provided dataset.

Parameters:dataset (torch.utils.data.Dataset) – dataset object
Returns:None
Return type:NoneType
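A logical append can be implemented as a read-only concatenated view that stores references instead of copying data. The sketch below uses a hypothetical ConcatView class over plain sequences; torch.utils.data.ConcatDataset offers similar behaviour for torch datasets, though whether alr uses it is not stated here.

```python
from typing import Any, List, Sequence


class ConcatView:
    """Hypothetical read-only view that appends datasets logically (no copying)."""

    def __init__(self, *parts: Sequence[Any]) -> None:
        self._parts: List[Sequence[Any]] = list(parts)

    def append(self, dataset: Sequence[Any]) -> None:
        self._parts.append(dataset)   # stores a reference; dataset itself is untouched

    def __len__(self) -> int:
        return sum(len(p) for p in self._parts)

    def __getitem__(self, idx: int) -> Any:
        n = len(self)
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            raise IndexError("index out of range")
        for part in self._parts:      # walk the parts until idx falls inside one
            if idx < len(part):
                return part[idx]
            idx -= len(part)
        raise IndexError("index out of range")


labelled = ConcatView([("x0", 0)])
extra = [("x1", 1), ("x2", 0)]
labelled.append(extra)
print(len(labelled), labelled[2])   # 3 ('x2', 0)
print(extra)                        # [('x1', 1), ('x2', 0)]
```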
labelled

The current labelled dataset after considering previous acquisitions.

Returns:labelled dataset
Return type:torch.utils.data.Dataset
n_labelled

Current number of labelled points.

Returns:size of dataset
Return type:int
n_unlabelled

Current number of unlabelled points.

Returns:size of dataset
Return type:int
reset() → None

Resets the state of this data manager. All acquired points are removed from the labelled dataset and added back into the unlabelled dataset.

Returns:None
Return type:NoneType
unlabelled

The current unlabelled dataset after considering previous acquisitions.

Returns:unlabelled dataset
Return type:torch.utils.data.Dataset

PseudoLabelDataset

class alr.data.PseudoLabelDataset(dataset: torch.utils.data.dataset.Dataset, pseudo_labels: Sequence[T_co])

Bases: torch.utils.data.dataset.Dataset

Pairs an unlabelled dataset with pseudo-labels. The dataset’s __getitem__ is expected to return x only (i.e. without targets). Use RelabelDataset if dataset is labelled.

Parameters:
  • dataset (torch.utils.data.Dataset) – dataset object
  • pseudo_labels (Sequence) – pseudo-labels
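The pairing of features-only items with pseudo-labels can be sketched as below. PseudoLabelSketch is hypothetical, its length check is an added assumption, and the argmax over made-up class probabilities is just one common way pseudo-labels are produced.

```python
from typing import Any, Sequence, Tuple


class PseudoLabelSketch:
    """Hypothetical sketch: pairs each unlabelled x with a pseudo-label."""

    def __init__(self, dataset: Sequence[Any], pseudo_labels: Sequence[Any]) -> None:
        if len(dataset) != len(pseudo_labels):
            # Assumed check; the documented class may not validate lengths.
            raise ValueError("dataset and pseudo_labels must have the same length")
        self._dataset = dataset
        self._pseudo_labels = pseudo_labels

    def __len__(self) -> int:
        return len(self._dataset)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        x = self._dataset[idx]               # dataset yields x only
        return x, self._pseudo_labels[idx]


probs = [[0.1, 0.9], [0.8, 0.2]]                                  # hypothetical model outputs
labels = [max(range(len(p)), key=p.__getitem__) for p in probs]   # argmax per row
ds = PseudoLabelSketch(["img0", "img1"], labels)
print(ds[0], ds[1])   # ('img0', 1) ('img1', 0)
```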

RelabelDataset

class alr.data.RelabelDataset(dataset: torch.utils.data.dataset.Dataset, labels: Sequence[T_co])

Bases: torch.utils.data.dataset.Dataset

Overrides the labels of dataset. The dataset’s __getitem__ is expected to return (x, y) (i.e. with targets). Use PseudoLabelDataset if dataset is unlabelled.

Parameters:
  • dataset (torch.utils.data.Dataset) – dataset object
  • labels (Sequence) – new labels
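By contrast with the pseudo-label case, relabelling discards an existing target. A minimal sketch under that reading, with the hypothetical class RelabelSketch over a list of (x, y) pairs:

```python
from typing import Any, Sequence, Tuple


class RelabelSketch:
    """Hypothetical sketch of label overriding for an already-labelled dataset."""

    def __init__(self, dataset: Sequence[Tuple[Any, Any]], labels: Sequence[Any]) -> None:
        self._dataset = dataset
        self._labels = labels

    def __len__(self) -> int:
        return len(self._dataset)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        x, _ = self._dataset[idx]   # dataset yields (x, y); the old y is discarded
        return x, self._labels[idx]


noisy = [("a", 0), ("b", 0)]
ds = RelabelSketch(noisy, labels=[1, 0])
print(ds[0], noisy[0])   # ('a', 1) ('a', 0)
```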

TransformedDataset

class alr.data.TransformedDataset(raw_dataset: torch.utils.data.dataset.Dataset, transform: Optional[list] = None, augmentation: Optional[list] = None)

Bases: torch.utils.data.dataset.Dataset

Transforms and augments an untransformed and unaugmented dataset.
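A possible shape for this class is sketched below under several assumptions not stated in the signature: raw_dataset yields (x, y), transform and augmentation are lists of callables applied in order, augmentation runs before transform, and a boolean attribute (here the hypothetical augment flag) switches augmentation off. TransformedSketch is a hypothetical name.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


class TransformedSketch:
    """Hypothetical sketch of TransformedDataset; ordering and the flag are assumptions."""

    def __init__(self, raw_dataset: Sequence[Tuple[Any, Any]],
                 transform: Optional[List[Callable]] = None,
                 augmentation: Optional[List[Callable]] = None) -> None:
        self._raw = raw_dataset
        self._transform = transform or []
        self._augmentation = augmentation or []
        self.augment = True   # hypothetical switch a disable-augmentation helper could flip

    def __len__(self) -> int:
        return len(self._raw)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        x, y = self._raw[idx]
        if self.augment:
            for aug in self._augmentation:   # e.g. random flips or crops
                x = aug(x)
        for t in self._transform:            # e.g. tensor conversion, normalisation
            x = t(x)
        return x, y


ds = TransformedSketch([(2, 0)], transform=[lambda v: v * 10],
                       augmentation=[lambda v: v + 1])
print(ds[0])        # (30, 0)
ds.augment = False
print(ds[0])        # (20, 0)
```

Keeping augmentation separate from the deterministic transform is what makes it possible to score points without random perturbations while still training on augmented data.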

Functions

disable_augmentation

alr.data.disable_augmentation(dataset: alr.data.UnlabelledDataset)

When using UnlabelledDataset and TransformedDataset together, this function can be passed as the transform argument of DataManager.acquire() to disable augmentation before points are scored by the acquisition function (without modifying dataset).

Parameters:dataset (UnlabelledDataset) – unlabelled dataset to transform. Assumes dataset._dataset is of type TransformedDataset.
Returns:a new transformed unlabelled dataset that has no augmentation, leaving the original unlabelled dataset (dataset) untouched.
Return type:UnlabelledDataset
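One way to return an augmentation-free view without touching the original is to shallow-copy both wrappers and rebind the augmentation list on the copy only. The sketch below assumes that approach; AugmentedData, Pool and disable_augmentation_sketch are hypothetical stand-ins for TransformedDataset, UnlabelledDataset and the documented function.

```python
import copy
from typing import Any, Callable, List, Sequence


class AugmentedData:
    """Hypothetical stand-in for TransformedDataset: applies augmentation on access."""

    def __init__(self, raw: Sequence[Any], augmentation: List[Callable]) -> None:
        self.raw = raw
        self.augmentation = augmentation

    def __len__(self) -> int:
        return len(self.raw)

    def __getitem__(self, idx: int) -> Any:
        x = self.raw[idx]
        for aug in self.augmentation:
            x = aug(x)
        return x


class Pool:
    """Hypothetical stand-in for UnlabelledDataset; wraps data in `_dataset`."""

    def __init__(self, dataset: AugmentedData) -> None:
        self._dataset = dataset

    def __getitem__(self, idx: int) -> Any:
        return self._dataset[idx]


def disable_augmentation_sketch(pool: Pool) -> Pool:
    inner = copy.copy(pool._dataset)   # shallow copy: raw data is shared, not duplicated
    inner.augmentation = []            # rebind (not mutate) so the original keeps its list
    new_pool = copy.copy(pool)
    new_pool._dataset = inner
    return new_pool


pool = Pool(AugmentedData([1, 2, 3], augmentation=[lambda v: -v]))
plain = disable_augmentation_sketch(pool)
print(pool[0], plain[0])   # -1 1
```

A shallow copy also shares any other mutable state on the wrapper, which is acceptable here because the returned view is only read while the acquisition function scores points.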