alr.data¶
Classes¶
UnlabelledDataset¶
-
class
alr.data.UnlabelledDataset(dataset: torch.utils.data.dataset.Dataset, label_fn: Optional[Callable[[torch.utils.data.dataset.Dataset], torch.utils.data.dataset.Dataset]] = None, debug: Optional[bool] = False)[source]¶ Bases:
torch.utils.data.dataset.Dataset
A wrapper class that manages the unlabelled dataset through a simple interface to label() specific points and remove them from the underlying dataset. Note that it does not physically remove points from dataset; rather, it provides an abstraction over it that removes them logically. Furthermore, if label_fn is not provided, this class infers that the provided “unlabelled” dataset is, in fact, labelled. This is especially useful for benchmarking studies!
Parameters:
- dataset (torch.utils.data.Dataset) – unlabelled dataset
- label_fn (Callable: Dataset \(\rightarrow\) Dataset, optional) – a function that takes an unlabelled dataset and returns another dataset that’s fully labelled. If this is not provided, then dataset should be labelled.
- debug (bool, optional) – Turn debug mode on. If True, then indexing this dataset will return (x, y) instead of just x; this is useful for research purposes. Note that label_fn must be None, otherwise an error will be raised.
-
convert_idx(idxs: numpy.array) → numpy.array[source]¶ Given a set of indices relative to the current state of UnlabelledDataset, return the true/absolute indices in the original pool dataset.
Parameters: idxs (np.array) – sequence of indices
Returns: absolute indices Return type: np.array
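The relative-to-absolute mapping described above can be illustrated with a boolean mask over the original pool. This is a minimal sketch in plain Python, not alr's implementation; the standalone function and its `mask` argument are assumptions made for illustration.

```python
def convert_idx(mask, idxs):
    """Map indices relative to the remaining pool to absolute pool indices.

    mask[i] is True while point i of the original pool is still unlabelled.
    (Hypothetical helper: alr's method works on the dataset's own state.)
    """
    # Absolute positions of the points still in the pool, in order.
    remaining = [i for i, keep in enumerate(mask) if keep]
    return [remaining[i] for i in idxs]


# The original pool has 6 points; points 1 and 3 were labelled earlier.
mask = [True, False, True, False, True, True]
absolute = convert_idx(mask, [0, 1, 3])
print(absolute)
```

Relative index 3 refers to the fourth point still in the pool, which sits at absolute position 5 in the original dataset.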
-
label(idxs: Sequence[int]) → torch.utils.data.dataset.Dataset[source]¶ Label and return points specified by idxs according to provided label_fn. These labelled points will no longer be part of this dataset. Note, however, that this is just an abstraction and the original provided dataset in the constructor will not be modified. In other words, the dataset will not lose points as a result of being labelled.
Parameters: idxs (Sequence[int]) – indices of points to label
Returns: a labelled dataset where each point is specified by idxs and labelled by label_fn.
Return type: torch.utils.data.Dataset
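To make the idea of logical removal concrete, here is a small stand-in written in plain Python. It is a sketch under stated assumptions rather than alr's code: the class name `PoolView` is hypothetical, and for brevity its `label_fn` labels one point at a time, whereas the documented `label_fn` maps a dataset to a dataset.

```python
class PoolView:
    """Hypothetical stand-in for a logically shrinking pool (not alr's code)."""

    def __init__(self, data, label_fn):
        self._data = data                # the original data is never modified
        self._label_fn = label_fn        # assumption: labels a single point
        self._mask = [True] * len(data)  # True = still unlabelled

    def _remaining(self):
        # Absolute indices of the points that are still in the pool.
        return [i for i, keep in enumerate(self._mask) if keep]

    def __len__(self):
        return sum(self._mask)

    def __getitem__(self, idx):
        return self._data[self._remaining()[idx]]

    def label(self, idxs):
        remaining = self._remaining()
        absolute = [remaining[i] for i in idxs]
        for i in absolute:
            self._mask[i] = False        # logical removal only
        return [(self._data[i], self._label_fn(self._data[i])) for i in absolute]

    def reset(self):
        # Every point becomes unlabelled again.
        self._mask = [True] * len(self._data)


data = [10, 11, 12, 13]
pool = PoolView(data, label_fn=lambda x: x % 2)
labelled = pool.label([1, 2])
print(labelled, len(pool), data)
```

After the call, the view exposes two points while `data` still holds all four, which mirrors the behaviour described for label().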
-
labelled_classes¶ Return a list of classes that were labelled by the user (label_fn).
Returns: list of classes Return type: list
-
labelled_indices¶ Returns a list of indices that were labelled in the past.
Returns: all the indices that were labelled by label()
Return type: list
-
reset() → None[source]¶ Reset to initial state – all labelled points are unlabelled and introduced back into the pool.
Returns: None Return type: NoneType
-
true_labels()[source]¶ When the dataset is indexed within this context, it returns the label as well. It’s useful for debugging/evaluation purposes. This assumes label_fn was None, which indicates that the dataset came with labels to begin with. If label_fn is not None, then this method does nothing: the indexed datum will only contain features, as usual.
Returns: self Return type: UnlabelledDataset
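The phrase "within this context", together with the return value being self, suggests use as a context manager. The following plain-Python sketch shows that pattern under this assumption; the class `LabelToggle` is hypothetical and does not reproduce alr's internals.

```python
class LabelToggle:
    """Hypothetical stand-in: yields x normally and (x, y) inside the context."""

    def __init__(self, pairs):
        self._pairs = pairs          # labelled data, i.e. label_fn is None
        self._with_labels = False

    def true_labels(self):
        return self                  # the object itself acts as the context

    def __enter__(self):
        self._with_labels = True
        return self

    def __exit__(self, exc_type, exc, tb):
        self._with_labels = False
        return False                 # do not swallow exceptions

    def __getitem__(self, idx):
        x, y = self._pairs[idx]
        return (x, y) if self._with_labels else x


ds = LabelToggle([("a", 0), ("b", 1)])
outside = ds[1]
with ds.true_labels():
    inside = ds[1]
after = ds[1]
print(outside, inside, after)
```

Indexing returns features only before and after the `with` block, and a feature/label pair inside it.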
DataManager¶
-
class
alr.data.DataManager(labelled: torch.utils.data.dataset.Dataset, unlabelled: alr.data.UnlabelledDataset, acquisition_fn: alr.acquisition.AcquisitionFunction)[source]¶ Bases:
object
A stateful data manager class.
The labelled and unlabelled datasets are updated according to the points acquired by acquire(). acquisition_fn dictates which points should be chosen from the unlabelled pool. Similar to UnlabelledDataset, the original dataset, labelled, will not be modified as this class provides a logical abstraction. Whilst unlabelled is modified, the dataset that it is providing an abstraction over is not (see UnlabelledDataset).
Parameters:
- labelled (Dataset) – training data with labelled points
- unlabelled (UnlabelledDataset) – unlabelled pool
- acquisition_fn (AcquisitionFunction) – acquisition function
-
acquire(b: int, transform=None) → Tuple[numpy.array, torch.utils.data.dataset.Dataset][source]¶ Acquire b points from the
unlabelled dataset and add them to the labelled dataset.
Parameters:
- b (int) – number of points to acquire at once
- transform (Callable, optional) – transform the unlabelled dataset before giving it to the acquisition function. The function is expected to take and return a dataset.
Returns: A tuple consisting of
- a numpy array with the indices that were selected by the acquisition function; and
- a Subset-type dataset with the b points that were freshly labelled
Return type: Tuple[np.array, torch.utils.data.Dataset]
Notes
The returned numpy array of indices indexes the original pool dataset, i.e. it holds the “absolute” indices relative to the original pool set.
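The note about absolute indices matters once the pool has shrunk. The sketch below is a plain-Python illustration, not alr's implementation: the standalone `acquire` function, its `pool_mask` argument, and the top-b scoring rule are all assumptions. It scores only the remaining points and reports the selection in original pool coordinates.

```python
def acquire(pool_mask, scores, b):
    """Pick the b highest-scoring remaining points; return absolute indices.

    pool_mask[i] is True while point i is still unlabelled.
    scores[j] is the score of the j-th *remaining* point (relative index).
    (Hypothetical helper; a real acquisition function may rank differently.)
    """
    remaining = [i for i, keep in enumerate(pool_mask) if keep]
    # Rank relative positions by descending score.
    ranked = sorted(range(len(remaining)), key=lambda j: scores[j], reverse=True)
    chosen = [remaining[j] for j in ranked[:b]]
    for i in chosen:
        pool_mask[i] = False         # logically move the point out of the pool
    return chosen


mask = [True] * 5
first = acquire(mask, [0.1, 0.9, 0.3, 0.8, 0.2], b=2)
# Only three points remain, so only three scores are supplied.
second = acquire(mask, [0.5, 0.7, 0.6], b=1)
print(first, second)
```

In the second round the best score sits at relative position 1, which corresponds to absolute index 2 in the original pool.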
-
append_to_labelled(dataset: torch.utils.data.dataset.Dataset)[source]¶ Logically appends given dataset to the labelled dataset. Again, this does not physically modify the provided dataset.
Parameters: dataset (torch.utils.data.Dataset) – dataset object Returns: None Return type: NoneType
-
labelled¶ The current labelled dataset after considering previous acquisitions.
Returns: labelled dataset Return type: torch.utils.data.Dataset
-
n_unlabelled¶ Current number of
unlabelled points.
Returns: size of dataset Return type: int
-
reset() → None[source]¶ Resets the state of this data manager. All acquired points are removed from the
labelled dataset and added back into the unlabelled dataset.
Returns: None Return type: NoneType
-
unlabelled¶ The current unlabelled dataset after considering previous acquisitions.
Returns: unlabelled dataset Return type: torch.utils.data.Dataset
PseudoLabelDataset¶
-
class
alr.data.PseudoLabelDataset(dataset: torch.utils.data.dataset.Dataset, pseudo_labels: Sequence[T_co])[source]¶ Bases:
torch.utils.data.dataset.Dataset
Provides dataset with pseudo-labels. Dataset’s __getitem__ is expected to return x only (i.e. without targets). Use RelabelDataset if dataset is labelled.
Parameters:
- dataset (torch.utils.data.Dataset) – dataset object
- pseudo_labels (Sequence) – pseudo-labels
RelabelDataset¶
-
class
alr.data.RelabelDataset(dataset: torch.utils.data.dataset.Dataset, labels: Sequence[T_co])[source]¶ Bases:
torch.utils.data.dataset.Dataset
Overrides dataset labels. Dataset’s __getitem__ is expected to return (x, y) (i.e. with targets). Use PseudoLabelDataset if dataset is unlabelled.
Parameters:
- dataset (torch.utils.data.Dataset) – dataset object
- labels (Sequence) – new labels
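The difference between PseudoLabelDataset and RelabelDataset comes down to what the wrapped __getitem__ returns. The two duck-typed classes below are a minimal sketch of that contrast; the names `PseudoLabelled` and `Relabelled` are hypothetical, they do not subclass torch's Dataset, and the length check is an assumption.

```python
class PseudoLabelled:
    """Pairs each x from an unlabelled dataset with a supplied label."""

    def __init__(self, dataset, pseudo_labels):
        if len(dataset) != len(pseudo_labels):
            raise ValueError("one pseudo-label is needed per data point")
        self._dataset = dataset
        self._labels = pseudo_labels

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        # The wrapped dataset returns x only.
        return self._dataset[idx], self._labels[idx]


class Relabelled:
    """Replaces the target of each (x, y) pair with a supplied label."""

    def __init__(self, dataset, labels):
        if len(dataset) != len(labels):
            raise ValueError("one label is needed per data point")
        self._dataset = dataset
        self._labels = labels

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        x, _ = self._dataset[idx]    # discard the original target
        return x, self._labels[idx]


pseudo = PseudoLabelled([1.0, 2.0], [0, 1])
relabelled = Relabelled([(1.0, 9), (2.0, 9)], [0, 1])
print(pseudo[1], relabelled[0])
```

Both wrappers return (x, label) pairs; only the expected shape of the wrapped item differs.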
Functions¶
disable_augmentation¶
-
alr.data.disable_augmentation(dataset: alr.data.UnlabelledDataset)[source]¶ When using
UnlabelledDataset and TransformedDataset, this function can be used as the transform argument of DataManager.acquire() to disable augmentation before scoring with an acquisition function (without modifying dataset).
Parameters: dataset (UnlabelledDataset) – unlabelled dataset to transform. Assumes dataset._dataset is of type TransformedDataset.
Returns: a new transformed unlabelled dataset that has no augmentation, leaving the original unlabelled dataset (dataset) untouched.
Return type: UnlabelledDataset
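One way to return an augmentation-free dataset while leaving the original untouched is to work on a shallow copy. The sketch below illustrates that idea in plain Python. The `Augmented` class, its `enabled` flag, and the `without_augmentation` helper are hypothetical, since this page does not describe the TransformedDataset interface.

```python
import copy


class Augmented:
    """Hypothetical dataset that applies an augmentation when enabled."""

    def __init__(self, data, augment):
        self._data = data
        self._augment = augment
        self.enabled = True

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        x = self._data[idx]
        return self._augment(x) if self.enabled else x


def without_augmentation(dataset):
    # A shallow copy shares the underlying data but carries its own flag,
    # so switching the flag off does not affect the original object.
    clone = copy.copy(dataset)
    clone.enabled = False
    return clone


augmented = Augmented([1, 2, 3], augment=lambda x: x * 10)
plain = without_augmentation(augmented)
print(augmented[0], plain[0])
```

The original object keeps augmenting after the call, so the same dataset can be used for training while the copy is used for scoring.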