"""Batch sampler from MinkLoc method.
Code adopted from repository: https://github.com/jac99/MinkLocMultimodal, MIT License
"""
import logging
from typing import Iterator, List, Optional
import numpy as np
import torch
from numpy.random import default_rng
from torch.utils.data import Sampler
from opr.datasets.base import BasePlaceRecognitionDataset
class BatchSampler(Sampler):
    """Sampler returning list of indices to form a mini-batch.

    Samples elements in groups consisting of k similar elements (positives), where k is
    ``positives_per_group`` (2 by default).
    Batch has the following structure: item1_1, ..., item1_k, item2_1, ... item2_k, itemn_1, ..., itemn_k
    """

    # TODO: refactor this class to be more readable
    # TODO: separate private members from public members

    # Set to True at the end of `generate_batches` and back to False once `__iter__` has yielded
    # every batch, so that the next iteration re-generates the batches.
    is_batches_generated: bool = False

    def __init__(
        self,
        dataset: BasePlaceRecognitionDataset,
        batch_size: int,
        batch_size_limit: Optional[int] = None,
        batch_expansion_rate: Optional[float] = None,
        max_batches: Optional[int] = None,
        positives_per_group: int = 2,
        seed: Optional[int] = None,
        drop_last: bool = True,
    ) -> None:
        """Sampler returning list of indices to form a mini-batch.

        Note:
            The dynamic batch size option is implemented. You can read more about it
            in the MinkLoc paper: https://arxiv.org/abs/2011.04530

        Args:
            dataset (BasePlaceRecognitionDataset): Dataset from which to sample.
            batch_size (int): Initial batch size.
            batch_size_limit (int, optional): Maximum batch size if dynamic batch sizing
                is enabled (see MinkLoc paper for details). Defaults to None.
            batch_expansion_rate (float, optional): Batch expansion rate if dynamic batch sizing
                is enabled (see MinkLoc paper for details). Defaults to None.
            max_batches (int, optional): Maximum number of batches to generate in epoch. If None, then
                no limit will be applied. Defaults to None.
            positives_per_group (int): Number of positive elements to sample in group. Defaults to 2.
            seed (int, optional): Random seed. Defaults to None.
            drop_last (bool): If True, the sampler will drop the last batch if its size would be less
                than batch_size. Defaults to True.

        Raises:
            ValueError: If batch_size_limit is not specified when batch_expansion_rate is specified.
            ValueError: If batch_expansion_rate is less or equal to 1.0.
            ValueError: If batch_size_limit is less or equal to batch_size.
            ValueError: If positives_per_group is less than 2.
            ValueError: If batch_size exceeds batch_size_limit after the limit was rounded down
                to a multiple of positives_per_group.
        """
        self.logger = logging.getLogger(self.__class__.__name__)

        # Dynamic batch sizing is only validated when an expansion rate is requested.
        if batch_expansion_rate is not None:
            if batch_size_limit is None:
                raise ValueError("batch_size_limit must be specified if batch_expansion_rate is specified")
            if batch_expansion_rate <= 1.0:
                raise ValueError("batch_expansion_rate must be greater than 1.0")
            if batch_size_limit <= batch_size:
                raise ValueError("batch_size_limit must be greater than batch_size")

        self.batch_size = batch_size
        self.batch_size_limit = batch_size_limit
        self.batch_expansion_rate = batch_expansion_rate
        self.max_batches = max_batches
        self.drop_last = drop_last
        self.dataset = dataset

        if positives_per_group < 2:
            raise ValueError("positives_per_group must be greater or equal to 2")
        self.positives_per_group = positives_per_group

        # A batch needs at least two groups and must consist of whole groups only.
        if self.batch_size < 2 * self.positives_per_group:
            self.batch_size = 2 * self.positives_per_group
            self.logger.warning(f"Batch too small. Batch size increased to {self.batch_size}.")
        elif self.batch_size % self.positives_per_group != 0:
            self.batch_size = self.batch_size - (self.batch_size % self.positives_per_group)
            self.logger.warning(
                "Batch size must be divisible by number of positives per group. "
                f"Batch size decreased to {self.batch_size} "
                f"(positives_per_group={self.positives_per_group}).",
            )

        if self.batch_size_limit is not None and (self.batch_size_limit % self.positives_per_group != 0):
            self.batch_size_limit = self.batch_size_limit - (self.batch_size_limit % self.positives_per_group)
            self.logger.warning(
                "Batch size limit must be divisible by number of positives per group. "
                f"Batch size limit decreased to {self.batch_size_limit} "
                f"(positives_per_group={self.positives_per_group}).",
            )
            # Only reachable with a concrete (non-None) limit: rounding the limit down may have
            # pushed it below the (possibly adjusted) batch size.
            if self.batch_size > self.batch_size_limit:
                raise ValueError("batch_size must be less or equal to batch_size_limit")

        self.batch_idx: List[List[int]] = []  # Index of elements in each batch (re-generated every epoch)
        self.elems_ndx = np.arange(len(self.dataset))  # array of indexes
        self.rng = default_rng(seed=seed)
        self.generate_batches()  # generate initial batches list (to make __len__ work properly)

    def __iter__(self) -> Iterator[List[int]]:  # noqa: D105
        if not self.is_batches_generated:
            self.generate_batches()  # re-generate batches on every epoch
        for batch in self.batch_idx:
            yield batch
        # Mark the current batches as consumed so that the next iteration builds new ones.
        self.is_batches_generated = False

    def __len__(self) -> int:  # noqa: D105
        return len(self.batch_idx)

    def expand_batch(self) -> None:
        """Batch expansion method. See MinkLoc paper for details about dynamic batch sizing.

        Multiplies the batch size by ``batch_expansion_rate``, keeps it a multiple of
        ``positives_per_group``, caps it at ``batch_size_limit`` and re-generates the batches.
        Does nothing (apart from a warning) if dynamic batch sizing is disabled, and nothing at all
        if the limit has already been reached.
        """
        if self.batch_expansion_rate is None or self.batch_size_limit is None:
            self.logger.warning("Dynamic batch sizing is disabled but 'expand_batch' method was called.")
            return

        if self.batch_size >= self.batch_size_limit:
            return

        old_batch_size = self.batch_size
        self.batch_size = int(self.batch_size * self.batch_expansion_rate)
        # ensure that it is still divisible by number of positives per group:
        self.batch_size = self.batch_size - (self.batch_size % self.positives_per_group)
        # but if batch_expansion_rate is small - we may decrease it back to previous batch size:
        if self.batch_size == old_batch_size:
            self.batch_size += self.positives_per_group  # smallest possible step
        # then check if it is smaller than the limit
        self.batch_size = min(self.batch_size, self.batch_size_limit)

        self.logger.info(f"=> Batch size increased from: {old_batch_size} to {self.batch_size}")
        self.generate_batches()

    def generate_batches(self) -> None:  # noqa: C901  # TODO: refactor to reduce complexity
        """Generate training/evaluation batches.

        Repeatedly draws a random not-yet-used element, adds it together with
        ``positives_per_group - 1`` of its positives (preferring positives that are still unused)
        to the current batch, and flushes the batch into ``batch_idx`` once it reaches
        ``batch_size``. Elements with fewer than ``positives_per_group - 1`` positives are skipped.

        Raises:
            ValueError: If a generated batch does not consist of whole groups of positives.
        """
        # batch_idx holds indexes of elements in each batch as a list of lists
        self.batch_idx = []

        unused_elements_ndx = np.copy(self.elems_ndx)
        # Mirror of `unused_elements_ndx` used only for O(1) membership tests; the array itself is
        # still what the RNG samples from, so the drawn sequence is not affected.
        unused_lookup = set(unused_elements_ndx.tolist())
        current_batch: List[int] = []

        while True:
            if len(current_batch) >= self.batch_size or len(unused_elements_ndx) == 0:
                # Flush out batch, when it has a desired size, or a smaller batch, when there's no more
                # elements to process
                if len(current_batch) >= 2 * self.positives_per_group:
                    # Ensure there're at least two groups of similar elements, otherwise, it would not be possible
                    # to find negative examples in the batch
                    if len(current_batch) % self.positives_per_group != 0:
                        raise ValueError("Batch size must be divisible by number of positives per group.")
                    if self.drop_last and len(current_batch) < self.batch_size:
                        # Drop last batch if it is smaller than batch_size
                        break
                    self.batch_idx.append(current_batch)
                    current_batch = []
                    if (self.max_batches is not None) and (len(self.batch_idx) >= self.max_batches):
                        break
                if len(unused_elements_ndx) == 0:
                    break

            # Add k similar elements to the batch
            selected_element = self.rng.choice(unused_elements_ndx)
            unused_elements_ndx = np.delete(
                unused_elements_ndx, np.argwhere(unused_elements_ndx == selected_element)
            )
            unused_lookup.discard(int(selected_element))

            positives = self.dataset.positives_index[selected_element].numpy()
            if len(positives) < (self.positives_per_group - 1):
                # we need at least k-1 positive examples
                continue

            unused_positives = [e for e in positives if e in unused_lookup]
            used_positives = [e for e in positives if e not in unused_lookup]

            # If there're unused elements similar to selected_element, sample from them
            # otherwise sample from all similar elements
            current_batch += [selected_element]
            for _ in range(self.positives_per_group - 1):
                if len(unused_positives) > 0:
                    pos_i = self.rng.choice(len(unused_positives))
                    another_positive = unused_positives.pop(pos_i)
                    unused_elements_ndx = np.delete(
                        unused_elements_ndx, np.argwhere(unused_elements_ndx == another_positive)
                    )
                    unused_lookup.discard(int(another_positive))
                else:
                    pos_i = self.rng.choice(len(used_positives))
                    another_positive = used_positives.pop(pos_i)
                current_batch += [another_positive]

        # Final sanity check: every stored batch must consist of whole groups.
        for batch in self.batch_idx:
            if len(batch) % self.positives_per_group != 0:
                raise ValueError(f"Incorrect bach size: {len(batch)}")

        self.is_batches_generated = True
class DistributedBatchSamplerWrapper(Sampler):
    """Wrapper for BatchSampler that supports distributed batch sampling.

    Every replica iterates over the batches of the wrapped sampler and yields only its own
    contiguous slice of each batch.
    """

    def __init__(
        self, sampler: BatchSampler, num_replicas: Optional[int] = None, rank: Optional[int] = None
    ) -> None:
        """Wrapper for BatchSampler that supports distributed batch sampling.

        Args:
            sampler (BatchSampler): BatchSampler instance to wrap.
            num_replicas (int, optional): Number of processes participating in distributed training.
                If None, then torch.distributed.get_world_size() will be used. Defaults to None.
            rank (int, optional): Process rank. If None, then torch.distributed.get_rank() will be used.
                Defaults to None.

        Raises:
            ValueError: If sampler has drop_last=False.
            RuntimeError: If distributed package is not available.
            ValueError: If rank is out of range [0, num_replicas-1].
            ValueError: If batch size is not divisible by the number of replicas.
        """
        self.sampler = sampler
        if not self.sampler.drop_last:
            raise ValueError(
                "DistributedBatchSamplerWrapper currently requires sampler to have drop_last=True"
            )

        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")

        self.num_replicas = num_replicas
        self.rank = rank

        # NOTE: the slice bounds below are derived from the sampler's batch size at construction
        # time; they are not recomputed if the sampler's batch size changes afterwards.
        self.global_batch_size = self.sampler.batch_size
        if self.global_batch_size % self.num_replicas != 0:
            raise ValueError("Batch size should be divisible by the number of replicas")
        self.local_batch_size = self.global_batch_size // self.num_replicas
        self.start_end_indices = self.local_batch_size * self.rank, self.local_batch_size * (self.rank + 1)

    def __iter__(self) -> Iterator[List[int]]:  # noqa: D105
        start_idx, end_idx = self.start_end_indices
        if not self.sampler.is_batches_generated:
            self.sampler.generate_batches()  # re-generate batches on every epoch
        for batch in self.sampler.batch_idx:
            yield batch[start_idx:end_idx]
        # Reset the flag on the wrapped sampler (not on this wrapper): it is the sampler's flag
        # that is checked above, so this is what triggers re-generation on the next epoch.
        self.sampler.is_batches_generated = False

    def __len__(self) -> int:  # noqa: D105
        return len(self.sampler)