alr.data.datasets
Classes
DataDescription
Dataset
class alr.data.datasets.Dataset
Bases: enum.Enum

An enum class that provides convenient data retrieval.
Example
>>> import torch
>>> from alr.data.datasets import Dataset
>>> train, test = Dataset.MNIST.get()
>>> train_load = torch.utils.data.DataLoader(train, batch_size=32)
CIFAR10 = 'CIFAR10'
CIFAR100 = 'CIFAR100'
CINIC10 = 'CINIC10'
EMNISTBalanced = 'EMNISTBalanced'
EMNISTMerge = 'EMNISTMerge'
FashionMNIST = 'FashionMNIST'
MNIST = 'MNIST'
RepeatedMNIST = 'RepeatedMNIST'
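Because Dataset subclasses enum.Enum with string values, the standard Enum lookups should work on it. The sketch below uses a hypothetical stand-in, `DatasetName`, that mirrors the members listed above so it runs without alr installed; it shows attribute access, lookup by name, lookup by value, and iteration.

```python
from enum import Enum


class DatasetName(Enum):
    """Hypothetical stand-in mirroring the members of alr.data.datasets.Dataset."""
    CIFAR10 = 'CIFAR10'
    CIFAR100 = 'CIFAR100'
    CINIC10 = 'CINIC10'
    EMNISTBalanced = 'EMNISTBalanced'
    EMNISTMerge = 'EMNISTMerge'
    FashionMNIST = 'FashionMNIST'
    MNIST = 'MNIST'
    RepeatedMNIST = 'RepeatedMNIST'


# Attribute access, lookup by name, and lookup by value all return the same member.
by_attr = DatasetName.MNIST
by_name = DatasetName['MNIST']    # e.g. a key read from a config file
by_value = DatasetName('MNIST')   # e.g. a string passed on the command line

# Iterate over every member in definition order.
all_names = [member.name for member in DatasetName]
```

Lookup by value is convenient when the dataset is chosen from a string, since an unknown string raises ValueError instead of silently falling back to some default.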
about
Returns information about this dataset, including:
- n_class
- width
- height
- channels
Returns: information about this dataset
Return type: DataDescription
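The source does not show how DataDescription is defined. As a rough sketch, assuming it simply carries the four fields listed above, a hypothetical `DataDescriptionSketch` could look like the following; the `flat_size` helper is also hypothetical and shows one typical use of these fields, sizing the input of a fully connected layer. The MNIST and CIFAR-10 values are the well-known image sizes of those datasets, not values read from alr.

```python
from typing import NamedTuple


class DataDescriptionSketch(NamedTuple):
    """Hypothetical container with the fields listed for `about`."""
    n_class: int
    width: int
    height: int
    channels: int

    @property
    def flat_size(self) -> int:
        # Number of input features once an image is flattened (C * H * W).
        return self.channels * self.height * self.width


# MNIST: 28x28 greyscale images, 10 classes.
mnist = DataDescriptionSketch(n_class=10, width=28, height=28, channels=1)
# CIFAR-10: 32x32 RGB images, 10 classes.
cifar10 = DataDescriptionSketch(n_class=10, width=32, height=32, channels=3)
```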
get(root: Optional[str] = 'data', raw: Optional[bool] = False, augmentation: Optional[bool] = False) → Tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]
Return a (train, test) tuple of datasets.
Parameters:
- root (str, optional) – root path where data will be read from or downloaded to.
- raw (bool, optional) – if True, the training set will not be transformed (i.e. no normalisation, ToTensor, etc.); note that the test set WILL still be transformed.
- augmentation (bool, optional) – whether to add standard augmentation: horizontal flips and random cropping.
Returns: a 2-tuple of (train, test) datasets
Return type: tuple
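The augmentation flag is described as horizontal flips plus random cropping. A minimal pure-Python sketch of those two operations on a single-channel image stored as nested lists follows; it is not alr's code (which presumably uses torchvision transforms), and the padding of 4 pixels before cropping is an assumption, chosen because it is a common setting for 32x32 images.

```python
import random
from typing import List

Image = List[List[float]]  # one channel, as rows of pixel values


def random_crop(img: Image, pad: int, rng: random.Random) -> Image:
    """Zero-pad by `pad` pixels on every side, then crop back to the original size."""
    h, w = len(img), len(img[0])
    padded = [[0.0] * (w + 2 * pad) for _ in range(pad)]
    padded += [[0.0] * pad + list(row) + [0.0] * pad for row in img]
    padded += [[0.0] * (w + 2 * pad) for _ in range(pad)]
    top = rng.randint(0, 2 * pad)   # randint is inclusive on both ends
    left = rng.randint(0, 2 * pad)
    return [row[left:left + w] for row in padded[top:top + h]]


def random_hflip(img: Image, rng: random.Random, p: float = 0.5) -> Image:
    """Mirror the image left-to-right with probability p."""
    if rng.random() < p:
        return [row[::-1] for row in img]
    return [list(row) for row in img]


def augment(img: Image, rng: random.Random, pad: int = 4) -> Image:
    """Random crop followed by a random horizontal flip."""
    return random_hflip(random_crop(img, pad, rng), rng)


rng = random.Random(0)
img = [[float(r * 8 + c) for c in range(8)] for r in range(8)]
out = augment(img, rng)
```

Both operations keep the image size unchanged, so augmented training images can be batched exactly like unaugmented ones.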
get_augmentation
get_fixed(root: Optional[str] = 'data', which: Optional[int] = 0, raw: Optional[bool] = False) → Tuple[torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset, torch.utils.data.dataset.Dataset]
Returns fixed train, pool, and test datasets. This is only used for experiments.
Parameters:
- root (str, optional) – root path where data will be read from or downloaded to.
- which (int, optional) – a given dataset has multiple possible sets of fixed points; this argument selects which of them to use.
- raw (bool, optional) – as in get(), the training set will not have any transform applied whatsoever. (The test set will still have ToTensor and normalisation.)
Returns: A tuple of train, pool, and test datasets.
Return type: tuple
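The source does not say how the fixed points are stored; alr may well load precomputed indices from disk. The following is only a sketch of the idea behind `which`, using a hypothetical `fixed_split` helper and a hypothetical `n_train` size: each value of `which` seeds a different but reproducible split of the training indices into a small train set and a pool, while the test set is left untouched.

```python
import random
from typing import List, Tuple


def fixed_split(n_examples: int, n_train: int, which: int = 0) -> Tuple[List[int], List[int]]:
    """Hypothetical: deterministically split indices 0..n_examples-1 into (train, pool)."""
    rng = random.Random(which)          # same `which` -> same split on every run
    indices = list(range(n_examples))
    rng.shuffle(indices)
    train = sorted(indices[:n_train])   # small fixed set of training points
    pool = sorted(indices[n_train:])    # everything else goes to the pool
    return train, pool


train_idx, pool_idx = fixed_split(100, 20, which=0)
```

Seeding a private random.Random instance, rather than the global generator, keeps the split reproducible regardless of what other code has drawn random numbers beforehand.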
model
Returns a canonical model architecture for the given dataset.
Returns: a PyTorch model
Return type: torch.nn.Module
normalisation_params
Returns a tuple of per-channel mean and standard deviation for 0-1-scaled inputs, i.e. the input is assumed to lie in the range [0, 1].
Returns: a 2-tuple of mean and standard deviation
Return type: tuple
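To show what the (mean, std) pair is for, here is a small sketch, independent of alr, that scales 8-bit pixels into [0, 1], computes per-channel statistics, and applies (x - mean) / std channel by channel, as a Normalize-style transform would. Whether alr's statistics use the population or the sample standard deviation is not stated; the population version here is an assumption.

```python
from statistics import fmean, pstdev
from typing import List, Sequence, Tuple


def channel_stats(channels: Sequence[Sequence[float]]) -> Tuple[Tuple[float, ...], Tuple[float, ...]]:
    """Per-channel mean and population standard deviation of 0-1-scaled values."""
    means = tuple(fmean(ch) for ch in channels)
    stds = tuple(pstdev(ch) for ch in channels)
    return means, stds


def normalise(channels: Sequence[Sequence[float]],
              mean: Sequence[float], std: Sequence[float]) -> List[List[float]]:
    """Apply (x - mean) / std separately to each channel."""
    return [[(x - m) / s for x in ch] for ch, m, s in zip(channels, mean, std)]


# Two tiny "channels" of raw 8-bit pixel values, scaled into [0, 1] first.
raw = [[0, 64, 128, 255], [10, 20, 30, 40]]
scaled = [[v / 255 for v in ch] for ch in raw]
mean, std = channel_stats(scaled)
out = normalise(scaled, mean, std)
```

If the statistics were computed on 0-255 values but applied to 0-1 inputs (or vice versa), the output would be badly off-centre, which is why the 0-1 assumption matters.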
MNISTNet

class alr.data.datasets.MNISTNet
Bases: torch.nn.modules.module.Module
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
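The note above is about the difference between `model(x)` and `model.forward(x)`. The sketch below is a heavily simplified, hypothetical stand-in for torch.nn.Module's call path (it is not PyTorch's implementation, and real hooks receive the inputs as a tuple); it only illustrates why hooks fire when the instance is called but not when forward is called directly.

```python
from typing import Callable, List


class TinyModule:
    """Hypothetical, simplified stand-in for torch.nn.Module's call path."""

    def __init__(self) -> None:
        self._forward_hooks: List[Callable] = []

    def register_forward_hook(self, hook: Callable) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x):
        raise NotImplementedError  # subclasses define the computation

    def __call__(self, x):
        out = self.forward(x)
        for hook in self._forward_hooks:  # hooks only run via __call__
            result = hook(self, x, out)
            if result is not None:        # a hook may replace the output
                out = result
        return out


class Doubler(TinyModule):
    def forward(self, x):
        return 2 * x


seen = []
m = Doubler()
m.register_forward_hook(lambda mod, inp, out: seen.append(out))
via_call = m(3)             # runs forward, then the hook
via_forward = m.forward(3)  # same result, but the hook is skipped
```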
CIFAR10Net

class alr.data.datasets.CIFAR10Net(num_classes=10, drop_prob=0.5)
Bases: torch.nn.modules.module.Module

CNN from the Mean Teacher paper, taken from: https://github.com/EricArazo/PseudoLabeling/blob/2fbbbd3ca648cae453e3659e2e2ed44f71be5906/utils_pseudoLab/ssl_networks.py
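The drop_prob parameter is not documented further. Assuming it is the drop probability handed to a dropout layer, as with torch.nn.Dropout(p), the sketch below shows inverted dropout, the scheme torch.nn.Dropout uses: during training each element is zeroed with probability p and the survivors are scaled by 1/(1-p), and in evaluation mode the input passes through unchanged. Where dropout sits inside CIFAR10Net is not stated in the source.

```python
import random
from typing import List


def dropout(xs: List[float], p: float, rng: random.Random, training: bool = True) -> List[float]:
    """Inverted dropout: zero each element with probability p, rescale survivors by 1/(1-p)."""
    if not training or p == 0.0:
        return list(xs)          # evaluation mode (or p=0): identity
    if p == 1.0:
        return [0.0] * len(xs)   # everything is dropped
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else x * scale for x in xs]


rng = random.Random(0)
out = dropout([1.0] * 1000, 0.5, rng)
```

Rescaling during training keeps the expected activation equal to its evaluation-time value, so no correction is needed at test time.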
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.