import warnings
from abc import ABC, abstractmethod
from typing import Optional, Callable, Sequence
import numpy as np
import torch
import torch.distributions as dist
import torch.utils.data as torchdata
from alr.utils._type_aliases import _DeviceType
# Type of the stochastic prediction functions consumed by the acquisition
# functions in this module: a callable mapping an input tensor to a tensor of
# predictions (the class docstrings below describe the expected output as
# K x N x C: inference samples x instances x classes).
_BayesianCallable = Callable[[torch.Tensor], torch.Tensor]
def _xlogy(x, y):
    """
    Element-wise ``x * log(y)`` where every entry with ``y == 0`` is set to 0.

    :param x: multiplier tensor
    :param y: tensor whose natural logarithm is taken
    :return: tensor holding ``x * log(y)``, with 0 wherever ``y`` is exactly 0
    :raises AssertionError: if the result still contains a non-finite value
    """
    zero_mask = y == 0
    product = torch.log(y) * x
    # log(0) is -inf (and 0 * -inf is nan); overwrite those entries so that
    # terms of the form p * log(p) contribute nothing at p = 0.
    product[zero_mask] = 0.0
    assert torch.isfinite(product).all()
    return product
class AcquisitionFunction(ABC):
    """
    A base class for all acquisition functions. All subclasses must
    override the `__call__` method.
    """

    @abstractmethod
    def __call__(self, X_pool: torchdata.Dataset, b: int) -> np.ndarray:
        """
        Given unlabelled data pool `X_pool`, return the best `b`
        points for labelling by an oracle, where the best points
        are determined by this acquisition function and its parameters.

        :param X_pool: Unlabelled dataset
        :type X_pool: `torch.utils.data.Dataset`
        :param b: number of points to acquire
        :type b: int
        :return: array of indices to `X_pool`.
        :rtype: `np.ndarray`
        """
class RandomAcquisition(AcquisitionFunction):
    """
    Implements random acquisition. Uniformly sample `b` indices.
    """

    def __call__(self, X_pool: torchdata.Dataset, b: int) -> np.ndarray:
        """
        Draw `b` distinct indices into `X_pool` uniformly at random.

        :param X_pool: Unlabelled dataset; only its length is used.
        :type X_pool: `torch.utils.data.Dataset`
        :param b: number of points to acquire
        :type b: int
        :return: array of `b` distinct indices to `X_pool`.
        :rtype: `np.ndarray`
        """
        # replace=False guarantees that no index is returned twice.
        return np.random.choice(len(X_pool), b, replace=False)
class BALD(AcquisitionFunction):
    def __init__(
        self,
        pred_fn: _BayesianCallable,
        subset: Optional[int] = -1,
        device: _DeviceType = None,
        debug: Optional[bool] = False,
        **data_loader_params,
    ):
        r"""
        Implements `BALD <https://arxiv.org/abs/1112.5745>`_.

        .. math::

            \begin{align}
                -\sum_c\left(\frac{1}{T}\sum_t\hat{p}^t_c \right)
                \log \left( \frac{1}{T}\sum_t\hat{p}^t_c \right) +
                \frac{1}{T}\sum_{c,t}\hat{p}^t_c \log \hat{p}^t_c
            \end{align}

        where :math:`\hat{p}^t_c` is the softmax output of class :math:`c`
        on the :math:`t^{th}` stochastic iteration.

        .. code:: python

            model = MCDropout(...)
            bald = BALD(eval_fwd_exp(model), subset=-1, device=device,
                        batch_size=512, pin_memory=True,
                        num_workers=2)
            bald(X_pool, b=10)

        :param pred_fn: A callable that returns a tensor of shape :math:`K \times N \times C` where
            :math:`K` is the number of inference samples,
            :math:`N` is the number of instances,
            and :math:`C` is the number of classes.
            **This function should return probabilities, not *log* probabilities!**
        :type pred_fn: `Callable`
        :param subset: Size of the subset of `X_pool`. Use -1 (or `None`) to denote the entire pool.
        :type subset: int, optional
        :param device: Move data to specified device when passing input data into `pred_fn`.
        :type device: `None`, `str`, `torch.device`
        :param debug: Save additional information to recent_score (requires more space).
        :type debug: `bool`, optional
        :param data_loader_params: params to be passed into `DataLoader` when
            iterating over `X_pool`.

        .. warning::
            Do not set `shuffle=True` in `data_loader_params`! The indices will be
            incorrect if the `DataLoader` object shuffles `X_pool`!
        """
        self._pred_fn = pred_fn
        self._device = device
        self._subset = subset
        self._dl_params = data_loader_params
        # store recent scores
        self.recent_score = None
        self._debug = debug
        assert not self._dl_params.get("shuffle", False)

    def __call__(self, X_pool: torchdata.Dataset, b: int) -> np.ndarray:
        """
        Score (a random subset of) `X_pool` with BALD and return the indices of
        the `b` highest-scoring points.

        As a side effect `recent_score` is overwritten with the scores of the
        points that were scored (in scoring order, i.e. aligned with the random
        subset when `subset` is in effect). With `debug=True` it is a dict with
        the keys `average_entropy`, `predictive_entropy`, `bald_score`,
        `confidence` and `class`; otherwise it is the array of BALD scores.
        When the candidate set contains exactly `b` points they are returned
        without scoring and `recent_score` is left untouched.

        :param X_pool: Unlabelled dataset
        :type X_pool: `torch.utils.data.Dataset`
        :param b: number of points to acquire
        :type b: int
        :return: array of indices to `X_pool`.
        :rtype: `np.ndarray`
        """
        pool_size = len(X_pool)
        # idxs[i] is the position in the *original* pool of the i-th scored point
        idxs = np.arange(pool_size)
        if self._subset is not None and self._subset != -1:
            r = min(self._subset, pool_size)
            assert b <= r, "Can't acquire more points that pool size"
            if b == r == pool_size:
                # every point of the pool is requested: nothing to rank
                return idxs
            idxs = np.random.choice(pool_size, r, replace=False)
            if b == r:
                # the whole random subset is requested: return exactly those
                # b points instead of ranking them
                return idxs
            X_pool = torchdata.Subset(X_pool, idxs)
            # from here on only the r sub-sampled points are scored, so all
            # shape checks below must refer to the subset size
            pool_size = r

        dl = torchdata.DataLoader(X_pool, **self._dl_params)
        with torch.no_grad():
            # concatenate the batches along the instance axis -> [K x N x C]
            mc_preds: torch.Tensor = torch.cat(
                [self._pred_fn(x.to(self._device) if self._device else x) for x in dl],
                dim=1,
            )
        mc_preds = mc_preds.double()
        assert mc_preds.size()[1] == pool_size

        mean_mc_preds = mc_preds.mean(dim=0)
        # entropy of the averaged prediction
        predictive_entropy = -(_xlogy(mean_mc_preds, mean_mc_preds)).sum(dim=1)
        # *negative* average entropy of the individual stochastic predictions
        neg_average_entropy = (_xlogy(mc_preds, mc_preds)).sum(dim=2).mean(dim=0)
        bald_score = (predictive_entropy + neg_average_entropy).cpu()
        assert torch.isfinite(bald_score).all()
        assert bald_score.shape == (pool_size,)

        result = torch.argsort(bald_score, descending=True).numpy()
        if self._debug:
            confidence, argmax = mean_mc_preds.max(dim=1)
            confidence, argmax = confidence.cpu().numpy(), argmax.cpu().numpy()
            self.recent_score = {
                "average_entropy": -neg_average_entropy.cpu().numpy(),
                "predictive_entropy": predictive_entropy.cpu().numpy(),
                "bald_score": bald_score.numpy(),
                "confidence": confidence,
                "class": argmax,
            }
        else:
            self.recent_score = bald_score.numpy()
        # map positions within the scored set back to indices of the full pool
        return idxs[result[:b]]
class ICAL(AcquisitionFunction):
    def __init__(
        self,
        pred_fn: _BayesianCallable,
        kernel_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        subset: Optional[int] = 200,
        greedy_acquire: Optional[int] = 1,
        use_one_hot: Optional[bool] = True,
        sample_softmax: Optional[bool] = True,
        device: _DeviceType = None,
        **data_loader_params,
    ):
        r"""
        Implements 'normal' `ICAL <https://arxiv.org/abs/2002.07916>`_. :math:`R` points
        are randomly drawn from the pool and the average of the candidate batch's kernels
        is used instead. Thus, the dependency measure reduces to :math:`d = 2`.

        .. math::

            \frac{1}{|\mathcal{R}|} d\text{HSIC}(\displaystyle\sum_{x'\in\mathcal{R}} k^{x'},
            \frac{1}{B} \displaystyle\sum_{i = 1}^{B} k^{x_i})

        .. code:: python

            model = MCDropout(...)
            ical = ICAL(eval_fwd_exp(model), device=device,
                        batch_size=512,
                        pin_memory=True, num_workers=2)
            ical(X_pool, b=10)

        :param pred_fn: A callable that returns a tensor of shape :math:`K \times N \times C` where
            :math:`K` is the number of inference samples,
            :math:`N` is the number of instances,
            and :math:`C` is the number of classes.
            **This function should return probabilities, not *log* probabilities!**
        :type pred_fn: `Callable`
        :param kernel_fn: Kernel function, see static methods of :class:`ICAL`. Defaults to
            a weighted rational quadratic kernel. This is the default kernel in the paper.
        :type kernel_fn: Callable[[torch.Tensor], torch.Tensor]], optional
        :param subset: Normal ICAL uses a subset of `X_pool`. `subset` specifies the
            size of this subset (:math:`|\mathcal{R}|` in the paper).
            Use -1 (or `None`) to denote the entire pool. Values larger than the
            pool are clamped to the pool size.
        :type subset: int, optional
        :param greedy_acquire: how many points to acquire at once in each acquisition step.
            Must not be 0.
        :type greedy_acquire: int, optional
        :param use_one_hot: use one_hot_encoding when calculating kernel matrix. This is the
            default behaviour in the paper.
        :type use_one_hot: bool, optional
        :param sample_softmax: sample the softmax probabilities. If this is `True`, then
            `use_one_hot` is automatically overridden to be `True`. This is the default
            behaviour in the paper.
        :type sample_softmax: bool, optional
        :param device: Move data to specified device when passing input data into `pred_fn`.
        :type device: `None`, `str`, `torch.device`
        :param data_loader_params: params to be passed into `DataLoader` when
            iterating over `X_pool`.

        .. warning::
            Do not set `shuffle=True` in `data_loader_params`! The indices will be
            incorrect if the `DataLoader` object shuffles `X_pool`!
        """
        self._r = subset
        self._pred_fn = pred_fn
        self._dl_params = data_loader_params
        self._device = device
        self._l = greedy_acquire
        self._use_oh = True if sample_softmax else use_one_hot
        self._sample_softmax = sample_softmax
        if kernel_fn is None:
            self._kernel = ICAL.rational_quadratic()
        else:
            self._kernel = kernel_fn
        assert not self._dl_params.get("shuffle", False)
        assert subset != 0
        # a step size of 0 would never add a point and loop forever in __call__
        assert greedy_acquire != 0

    def __call__(self, X_pool: torchdata.Dataset, b: int) -> np.ndarray:
        """
        Greedily build a batch of `b` points whose averaged kernel has the
        highest dHSIC dependency with the kernel of a random pool subset.

        :param X_pool: Unlabelled dataset
        :type X_pool: `torch.utils.data.Dataset`
        :param b: number of points to acquire; capped at the pool size so that
            no index is returned twice.
        :type b: int
        :return: array of indices to `X_pool`.
        :rtype: `np.ndarray`
        """
        step = self._l
        dl = torchdata.DataLoader(X_pool, **self._dl_params)
        with torch.no_grad():
            mc_preds = torch.cat(
                [self._pred_fn(x.to(self._device) if self._device else x) for x in dl],
                dim=1,
            )
        mc_preds = mc_preds.detach_()
        n_forward, pool_size, C = mc_preds.size()

        # size of the random reference subset; sampling without replacement
        # cannot draw more points than the pool holds
        if self._r is None or self._r == -1:
            r = pool_size
        else:
            r = min(self._r, pool_size)

        if self._sample_softmax:
            assert self._use_oh
            cat_dist = dist.Categorical(
                probs=mc_preds.reshape(n_forward * pool_size, -1)
            )
            # mc_preds is now a vector of sampled class idx
            mc_preds = cat_dist.sample()
            assert mc_preds.size() == (n_forward * pool_size,)
        if self._use_oh:
            if not self._sample_softmax:
                mc_preds = mc_preds.reshape(n_forward * pool_size, -1).argmax(dim=-1)
                assert mc_preds.size() == (n_forward * pool_size,)
            # build the identity on the same device as the class indices:
            # a CPU lookup table cannot be indexed by indices living elsewhere
            one_hot_table = torch.eye(C, device=mc_preds.device)
            mc_preds = one_hot_table[mc_preds].view(n_forward, pool_size, C)
        assert mc_preds.size() == (n_forward, pool_size, C)

        kernel_matrices = self._kernel(mc_preds)
        assert kernel_matrices.size() == (n_forward, n_forward, pool_size)
        # [Pool_size x N x N]
        kernel_matrices = kernel_matrices.permute(2, 0, 1)

        # asking for more points than exist would only append already-chosen
        # (masked) indices again, so cap the request
        n_acquire = min(b, pool_size)
        # indices of points currently in the batch (at most n_acquire by the
        # end of the iteration)
        batch_idxs = []
        while len(batch_idxs) < n_acquire:
            # always re-sample subset (what if we don't?)
            random_subset = np.random.choice(pool_size, size=r, replace=False)
            # a la theorem 2 - it suggested sum but we're using mean here - shouldn't make a difference
            pool_kernel = kernel_matrices[random_subset].mean(0)  # [N x N]
            # normal ICAL uses average batch kernels: the kernel of every
            # candidate averaged with the kernels already in the batch
            batch_kernels = (
                kernel_matrices + kernel_matrices[batch_idxs].sum(0, keepdim=True)
            ) / (
                len(batch_idxs) + 1
            )  # [Pool_size x N x N]
            scores = self._dHSIC(
                torch.stack(
                    # expand_as is a view, so the pool kernel is only copied
                    # once (by stack) instead of twice (repeat + cat)
                    [pool_kernel.expand_as(batch_kernels), batch_kernels],
                    dim=-1,
                )  # [Pool_size x N x N x 2]
            )
            assert scores.size() == (pool_size,)
            # mask chosen indices
            scores[batch_idxs] = -np.inf
            # greedily take top `step` scores
            ranked = torch.argsort(scores, descending=True)
            batch_idxs.extend(ranked[:step].tolist())
        # greedily taking top `step` might sometimes acquire extra points if
        # b is not divisible by it, hence, truncate the output
        return np.array(batch_idxs[:n_acquire], dtype=int)

    @staticmethod
    def rational_quadratic(
        alphas: Optional[Sequence[float]] = (0.2, 0.5, 1, 2, 5),
        weights: Optional[Sequence[float]] = None,
    ) -> Callable:
        """
        Build a weighted mixture of rational quadratic kernels
        :math:`(1 + d / (2\\alpha))^{-\\alpha}`, one component per entry of `alphas`.

        :param alphas: shape parameter of each mixture component
        :type alphas: Sequence[float], optional
        :param weights: weight of each component; must have as many entries as
            `alphas`. Uniform weights are used if `None` or empty.
        :type weights: Sequence[float], optional
        :return: kernel function mapping a tensor of shape [N x M x C] to a
            tensor of shape [N x N x M]
        :rtype: Callable
        """

        def _rational_quadratic(x: torch.Tensor) -> torch.Tensor:
            """
            :param x: tensor of shape [N x M x C]
            :return: tensor of shape [N x N x M]
            """
            N, M, _ = x.size()
            _alphas = x.new_tensor(alphas).view(-1, 1, 1, 1)
            n_alphas = _alphas.size(0)
            # explicit None/emptiness check: truthiness is ambiguous for
            # array-like weights
            if weights is not None and len(weights) > 0:
                _weights = x.new_tensor(weights)
            else:
                _weights = x.new_full((n_alphas,), 1.0 / n_alphas)
            assert _weights.size(0) == n_alphas
            # pairwise squared euclidean distance between inference samples
            distances = (x.unsqueeze(0) - x.unsqueeze(1)).pow_(2).sum(-1)
            assert distances.size() == (N, N, M)
            distances = distances.unsqueeze_(0)  # 1 N N M
            # TODO: is logspace really necessary?
            log = torch.log1p(distances / (2 * _alphas))
            assert torch.isfinite(log).all()
            res = torch.einsum("w,wijk->ijk", _weights, torch.exp(-_alphas * log))
            assert torch.isfinite(res).all()
            return res

        return _rational_quadratic

    @staticmethod
    def _dHSIC(x: torch.Tensor) -> torch.Tensor:
        r"""
        Computes HSIC for d-variables in a batch of size :math:`K`.

        .. note::
            While the values are computed in logspace for numerical stability,
            the returned value is casted back to its original space.

        :param x: tensor of shape :math:`K \times N \times N \times D` where:

            * :math:`K` is the batch size
            * :math:`N` is the number of samples in each variable
            * :math:`D` is the number of variables
        :return: dHSIC scores, a tensor of shape :math:`K`.
        """
        K, N, N2, D = x.size()
        assert N == N2
        # trivial case, definition 2.6 https://arxiv.org/pdf/1603.00285.pdf
        if N < 2 * D:
            warnings.warn(
                "The number of samples is lesser than twice "
                "the number of variables in dHISC. Trivial "
                "case of 0; this may or may not be intended."
            )
            return x.new_zeros(size=(K,))
        # https://github.com/NiklasPfister/dHSIC/blob/master/dHSIC/R/dhsic.R
        # logspace
        x = torch.log(x)
        logn = np.log(N)
        # log of: mean over (i, j) of the product over variables
        term1 = torch.sum(x, dim=-1).logsumexp(dim=(1, 2)) - 2 * logn
        # log of: product over variables of the mean over (i, j)
        term2 = torch.logsumexp(x, dim=(1, 2)).sum(dim=-1) - (2 * D * logn)
        # log of: 2 / N^(D + 1) * sum over one sample axis of the product over
        # variables of the sums along the other sample axis
        term3 = (
            torch.logsumexp(x, dim=1).sum(dim=-1).logsumexp(dim=-1)
            + np.log(2)
            - (D + 1) * logn
        )
        assert term1.size() == term2.size() == term3.size() == (K,)
        # subtract max for numerical stabilisation
        term_max = torch.stack([term1, term2, term3], dim=0).max(dim=0)[0]
        assert term_max.size() == (K,)
        res = (
            (term1 - term_max).exp_()
            + (term2 - term_max).exp_()
            - (term3 - term_max).exp_()
        )
        # undo the stabilising shift (term_max is not needed afterwards, so it
        # is exponentiated in place)
        res *= term_max.exp_()
        assert torch.isfinite(res).all()
        return res
def _bald_score(pred_fn, dataloader, device):
    """
    Compute the BALD score of every instance yielded by `dataloader`.

    For research debugging only.

    :param pred_fn: callable applied to each input batch; its outputs are
        concatenated along dim 1 and reduced over dims 0 and 2, i.e. they are
        treated as [samples x instances x classes] tensors
    :param dataloader: iterable yielding `(x, _)` pairs; the second element
        is ignored
    :param device: if truthy, each `x` is moved to this device before
        `pred_fn` is called
    :return: numpy array with one BALD score per instance
    """
    with torch.no_grad():
        per_batch = []
        for inputs, _ in dataloader:
            if device:
                inputs = inputs.to(device)
            per_batch.append(pred_fn(inputs))
        stacked: torch.Tensor = torch.cat(per_batch, dim=1)
    stacked = stacked.double()
    averaged = stacked.mean(dim=0)
    # entropy of the averaged prediction ...
    entropy_of_mean = -(_xlogy(averaged, averaged)).sum(dim=1)
    # ... plus the (negative) average entropy of the individual predictions
    neg_mean_entropy = (_xlogy(stacked, stacked)).sum(dim=2).mean(dim=0)
    scores = (entropy_of_mean + neg_mean_entropy).cpu()
    assert torch.isfinite(scores).all()
    return scores.numpy()
class BatchBALD(AcquisitionFunction):
    def __init__(
        self,
        pred_fn: _BayesianCallable,
        device: _DeviceType = None,
        num_samples: int = 10_000,
        **data_loader_params,
    ):
        r"""
        Acquisition function that delegates batch selection to
        `batchbald_redux.batchbald.get_batchbald_batch`.

        :param pred_fn: A callable whose outputs are concatenated along dim 1
            into a :math:`K \times N \times C` tensor (inference samples x
            instances x classes) before being handed to `get_batchbald_batch`.
        :type pred_fn: `Callable`
        :param device: Move data to specified device when passing input data into
            `pred_fn`; also forwarded to `get_batchbald_batch`.
        :type device: `None`, `str`, `torch.device`
        :param num_samples: forwarded to `get_batchbald_batch` as `num_samples`.
        :type num_samples: int
        :param data_loader_params: params to be passed into `DataLoader` when
            iterating over `X_pool`.

        .. warning::
            Do not set `shuffle=True` in `data_loader_params`! The indices will be
            incorrect if the `DataLoader` object shuffles `X_pool`!
        """
        self._pred_fn = pred_fn
        self._device = device
        self._dl_params = data_loader_params
        self._num_samples = num_samples
        # store recent scores
        self.recent_score = None
        assert not self._dl_params.get("shuffle", False)

    def __call__(self, X_pool: torchdata.Dataset, b: int) -> np.ndarray:
        """
        Select `b` points from `X_pool` with `get_batchbald_batch`.

        As a side effect `recent_score` is overwritten with the scores reported
        for the selected batch.

        :param X_pool: Unlabelled dataset
        :type X_pool: `torch.utils.data.Dataset`
        :param b: number of points to acquire
        :type b: int
        :return: array of indices to `X_pool`.
        :rtype: `np.ndarray`
        """
        # imported lazily so that the module can be used without the
        # batchbald_redux package as long as BatchBALD is never called
        from batchbald_redux.batchbald import get_batchbald_batch

        dl = torchdata.DataLoader(X_pool, **self._dl_params)
        with torch.no_grad():
            mc_preds_K_N_C: torch.Tensor = torch.cat(
                [self._pred_fn(x.to(self._device) if self._device else x) for x in dl],
                dim=1,
            )
        # get_batchbald_batch expects the instance axis first
        mc_preds_N_K_C = mc_preds_K_N_C.double().permute((1, 0, 2))
        assert mc_preds_N_K_C.size()[0] == len(X_pool)
        # NOTE(review): batchbald_redux names this argument `log_probs_N_K_C`;
        # confirm that `pred_fn` returns *log*-probabilities here (unlike the
        # probability-returning `pred_fn` of the other acquisition functions).
        candidate_batch = get_batchbald_batch(
            mc_preds_N_K_C,
            batch_size=b,
            num_samples=self._num_samples,
            dtype=torch.double,
            device=self._device,
        )
        self.recent_score = candidate_batch.scores
        return np.array(candidate_batch.indices)