alr.data.datasets

Classes

DataDescription

class alr.data.datasets.DataDescription(n_class: int, width: int, height: int, channels: int)[source]

Bases: object

Describes the attributes of this dataset.
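As a rough illustration of what such a description carries, the sketch below mirrors the four documented constructor arguments in a plain dataclass. `DataDescriptionSketch` and its `input_size` helper are hypothetical stand-ins, not part of `alr`; whether the real class is a dataclass is an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DataDescriptionSketch:
    """Hypothetical stand-in mirroring the documented constructor arguments."""

    n_class: int
    width: int
    height: int
    channels: int

    @property
    def input_size(self) -> int:
        # Illustrative helper (not in the original API): number of scalar
        # values in one flattened image.
        return self.width * self.height * self.channels


# MNIST-like values: 10 classes, 28x28 greyscale images.
desc = DataDescriptionSketch(n_class=10, width=28, height=28, channels=1)
print(desc.input_size)
```

A description like this is typically enough to size the first and last layers of a classifier without loading any data.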

Dataset

class alr.data.datasets.Dataset[source]

Bases: enum.Enum

An enum class that provides convenient data retrieval.

Example

>>> train, test = Dataset.MNIST.get()
>>> train_load = torch.utils.data.DataLoader(train, batch_size=32)
CIFAR10 = 'CIFAR10'
CIFAR100 = 'CIFAR100'
CINIC10 = 'CINIC10'
EMNISTBalanced = 'EMNISTBalanced'
EMNISTMerge = 'EMNISTMerge'
FashionMNIST = 'FashionMNIST'
MNIST = 'MNIST'
RepeatedMNIST = 'RepeatedMNIST'
about

Returns information about this dataset, including:

  • n_class
  • width
  • height
  • channels

Returns: information about this dataset
Return type: DataDescription
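The pattern of an enum member exposing a description can be sketched as follows. This is a minimal illustration, assuming `about` is a property (the listing shows no call signature); `DatasetSketch` and `Description` are hypothetical names, and only two members are reproduced.

```python
from enum import Enum
from typing import NamedTuple


class Description(NamedTuple):
    """Hypothetical stand-in for DataDescription."""

    n_class: int
    width: int
    height: int
    channels: int


class DatasetSketch(Enum):
    # String-valued members, as in the documented enum.
    MNIST = "MNIST"
    CIFAR10 = "CIFAR10"

    @property
    def about(self) -> Description:
        # Assumption: each member maps to a fixed description.
        table = {
            DatasetSketch.MNIST: Description(10, 28, 28, 1),
            DatasetSketch.CIFAR10: Description(10, 32, 32, 3),
        }
        return table[self]


# Members can be looked up by their string value.
info = DatasetSketch("CIFAR10").about
print(info.channels, info.width, info.height)
```

Because the members are string-valued, a dataset can be selected from a command-line argument or configuration file with `DatasetSketch(name)`.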
get(root: Optional[str] = 'data', raw: Optional[bool] = False, augmentation: Optional[bool] = False) → Tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset][source]

Return (train, test) tuple of datasets.

Parameters:
  • root (str, optional) – root path where data will be read from or downloaded to
  • raw (bool, optional) – if True, the training set will not be transformed (i.e. no normalisation, ToTensor, etc.). Note that the test set will still be transformed.
  • augmentation (bool, optional) – whether to add standard augmentation: horizontal flips and random cropping.
Returns:

a 2-tuple of (train, test) datasets

Return type:

tuple
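To make the augmentation option concrete, here is a dependency-free sketch of horizontal flipping and padded random cropping on a 2-D image stored as nested lists. The flip probability, the padding size, and all function names are assumptions for illustration; the library's actual transforms are not shown in this page.

```python
import random


def horizontal_flip(image):
    """Mirror a 2-D image (list of rows) left to right."""
    return [row[::-1] for row in image]


def random_crop(image, padding, rng):
    """Zero-pad every side, then crop back to the original size at a
    random offset, so the content shifts by up to `padding` pixels."""
    height, width = len(image), len(image[0])
    blank_row = [0] * (width + 2 * padding)
    padded = (
        [blank_row[:] for _ in range(padding)]
        + [[0] * padding + list(row) + [0] * padding for row in image]
        + [blank_row[:] for _ in range(padding)]
    )
    top = rng.randint(0, 2 * padding)
    left = rng.randint(0, 2 * padding)
    return [row[left:left + width] for row in padded[top:top + height]]


def augment(image, rng, flip_prob=0.5, padding=1):
    """Flip with probability `flip_prob` (assumed value), then random-crop."""
    if rng.random() < flip_prob:
        image = horizontal_flip(image)
    return random_crop(image, padding, rng)


rng = random.Random(0)
image = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
augmented = augment(image, rng)
```

The output always has the same shape as the input, which is why this kind of augmentation can be applied to the training set without changing the model.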

get_augmentation
get_fixed(root: Optional[str] = 'data', which: Optional[int] = 0, raw: Optional[bool] = False) → Tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset][source]

Returns fixed train, pool, and test datasets. This is only used for experiments.

Parameters:
  • root (str, optional) – root path where data will be read from or downloaded to.
  • which (int, optional) – a given dataset has multiple possible sets of fixed points; this argument selects which one to use.
  • raw (bool, optional) – as in get(), if True the training set will have no transforms applied. (The test set will still have ToTensor and normalisation applied.)
Returns:

A tuple of train, pool, and test datasets.

Return type:

tuple
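One way to picture the role of `which` is as a selector for a reproducible index split. The sketch below treats it like a random seed, which is purely an assumption: the library may instead store precomputed index lists. `fixed_split` is a hypothetical helper.

```python
import random


def fixed_split(n_points, n_train, which=0):
    """Deterministically split indices 0..n_points-1 into train and pool."""
    rng = random.Random(which)  # assumption: `which` behaves like a seed
    indices = list(range(n_points))
    rng.shuffle(indices)
    return sorted(indices[:n_train]), sorted(indices[n_train:])


# The same `which` always yields the same split.
train_idx, pool_idx = fixed_split(n_points=100, n_train=20, which=0)
```

Fixing the split this way lets repeated experiments start from identical labelled and unlabelled sets.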

model

Returns a canonical model architecture for a given dataset.

Returns: a PyTorch model
Return type: torch.nn.Module
normalisation_params

Returns a tuple of the channel mean and standard deviation, computed on inputs scaled to the range [0, 1].

Returns: a 2-tuple of mean and standard deviation
Return type: tuple
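The following sketch shows how per-channel statistics of this kind could be computed and applied. It assumes each image is a list of channels, each channel a flat list of pixel values already scaled to [0, 1], and that the population standard deviation is used; `channel_stats` and `normalise` are hypothetical helpers, not library functions.

```python
from statistics import fmean, pstdev


def channel_stats(images):
    """Return (means, stds), one entry per channel, over all images."""
    n_channels = len(images[0])
    means, stds = [], []
    for c in range(n_channels):
        # Pool every pixel of channel c across the whole dataset.
        values = [v for image in images for v in image[c]]
        means.append(fmean(values))
        stds.append(pstdev(values))
    return tuple(means), tuple(stds)


def normalise(image, mean, std):
    """Subtract the channel mean and divide by the channel std."""
    return [
        [(v - m) / s for v in channel]
        for channel, m, s in zip(image, mean, std)
    ]


# One single-channel "image" with raw pixel values 0 and 255, scaled to [0, 1].
raw = [[0, 255]]
scaled = [[v / 255 for v in channel] for channel in raw]
mean, std = channel_stats([scaled])
normalised = normalise(scaled, mean, std)
```

Note that the scaling to [0, 1] happens before the statistics are applied; using these parameters on unscaled 0-255 inputs would give incorrectly normalised values.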

MNISTNet

class alr.data.datasets.MNISTNet[source]

Bases: torch.nn.modules.module.Module

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

CIFAR10Net

class alr.data.datasets.CIFAR10Net(num_classes=10, drop_prob=0.5)[source]

Bases: torch.nn.modules.module.Module

CNN from the Mean Teacher paper. Taken from: https://github.com/EricArazo/PseudoLabeling/blob/2fbbbd3ca648cae453e3659e2e2ed44f71be5906/utils_pseudoLab/ssl_networks.py

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Functions