# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# `DistributedSampler` below taken from Detectron 2, licensed under Apache 2.0.
# Changes:
# - Docstring to match the rest of the library
# - Calls to other subroutines which do not exist in DIRECT.
"""Module containing all sampler logic."""
import itertools
import logging
import math
import random
from typing import List, Optional
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
from direct.utils import chunks, communication
class DistributedSampler(Sampler):
"""In training, we only care about the "infinite stream" of training data.
So this sampler produces an infinite stream of indices and all workers cooperate to correctly shuffle the indices
    and sample different indices. The sampler in each worker effectively produces `indices[worker_id::num_workers]`
where `indices` is an infinite stream of indices consisting of `shuffle(range(size)) + shuffle(range(size)) + ...`
(if shuffle is True) or `range(size) + range(size) + ...` (if shuffle is False)
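
    Examples
    --------
    A minimal usage sketch; the size, seed, and number of drawn indices below are illustrative only:

    >>> sampler = DistributedSampler(size=100, shuffle=True, seed=42)
    >>> stream = iter(sampler)
    >>> indices_for_this_rank = [int(next(stream)) for _ in range(4)]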
"""
def __init__(
self,
size: int,
shuffle: bool = True,
seed: Optional[int] = None,
):
"""
Parameters
----------
size: int
Size of underlying dataset.
shuffle: bool
If true, the indices will be shuffled.
        seed: int, optional
            Initial seed of the shuffle, must be the same across all workers. If None, a random seed
            shared across workers is used.
"""
super().__init__(data_source=None)
self._size = size
        if self._size <= 0:
            raise AssertionError(f"size should be a positive integer, got {size}.")
self._shuffle = shuffle
if seed is None:
seed = communication.shared_random_seed()
self._seed = int(seed)
self._rank = communication.get_rank()
self._world_size = communication.get_world_size()
def __iter__(self):
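        # Each rank takes every `world_size`-th index from the shared infinite stream,
        # starting at its own rank, so that ranks see disjoint indices.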
start = self._rank
yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
def _infinite_indices(self):
g = torch.Generator()
g.manual_seed(self._seed)
while True:
if self._shuffle:
yield from torch.randperm(self._size, generator=g)
else:
yield from torch.arange(self._size)
class DistributedSequentialSampler(Sampler):
"""Sequential Sampler that restricts data loading to a subset of the dataset.
    It is useful during evaluation, especially when combined with
    `torch.nn.parallel.DistributedDataParallel`, as each process then gets a subpart of the dataset.

    Notes
    -----
Dataset is assumed to be of constant size.
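
    Examples
    --------
    A rough sketch, assuming `dataset` exposes a `volume_indices` OrderedDict mapping each volume
    filename to a `slice` over the dataset indices (as this sampler requires):

    >>> sampler = DistributedSequentialSampler(dataset)
    >>> loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=1)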
"""
def __init__(
self,
dataset,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
        limit_number_of_volumes: Optional[int] = None,
):
super().__init__(dataset)
if num_replicas is None:
num_replicas = communication.get_world_size()
if rank is None:
rank = communication.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
filenames = list(self.dataset.volume_indices.keys()) # This is an OrderedDict
if limit_number_of_volumes:
filenames = filenames[:limit_number_of_volumes]
chunked_filenames = list(chunks(filenames, self.num_replicas))
        # Guard against ranks that received no chunk of filenames (more replicas than volumes).
        filenames = chunked_filenames[self.rank] if self.rank < len(chunked_filenames) else []
# Create volume indices for this sampler.
self.volume_indices = {k: self.dataset.volume_indices[k] for k in filenames}
# Collect the indices belonging to these filenames.
self.indices = []
if self.rank < len(chunked_filenames): # Otherwise there is nothing to fill.
for filename in filenames:
self.indices.extend(list(self.dataset.volume_indices[filename]))
def __iter__(self):
return iter(self.indices)
def __len__(self):
return len(self.indices)
class BatchVolumeSampler(Sampler):
"""Wraps another sampler to yield a mini-batch of indices which all belong to the same volume. This can mean that
    some batches have fewer samples than the requested batch size.
    Based on PyTorch 1.5.1 BatchSampler:
https://pytorch.org/docs/1.5.1/_modules/torch/utils/data/sampler.html#BatchSampler
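
    Examples
    --------
    A rough sketch, assuming `dataset` exposes `volume_indices` as required by
    `DistributedSequentialSampler`; batches never cross volume boundaries:

    >>> sampler = DistributedSequentialSampler(dataset)
    >>> batch_sampler = BatchVolumeSampler(sampler, batch_size=8)
    >>> loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)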
"""
def __init__(self, sampler: Sampler, batch_size: int):
super().__init__(sampler) # type: ignore
if not isinstance(sampler, Sampler):
raise ValueError(f"Sampler should be an instance of torch.utils.data.Sampler, but got sampler={sampler}.")
self.sampler = sampler
self.batch_size = batch_size
# Create a reverse lookup when we need to switch to a new batch
end_of_volume = []
self.__num_batches = 0
for filename in self.sampler.volume_indices: # type: ignore
curr_slice = self.sampler.volume_indices[filename] # type: ignore
end_of_volume.append(curr_slice.stop)
num_indices = curr_slice.stop - curr_slice.start
self.__num_batches += math.ceil(num_indices / batch_size)
self.end_of_volume = iter(end_of_volume[1:])
self._next_value = end_of_volume[0]
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if (len(batch) == self.batch_size) or (idx == self._next_value - 1):
yield batch
batch = []
if idx == self._next_value - 1:
try:
self._next_value = next(self.end_of_volume)
except StopIteration:
pass
if len(batch) > 0:
yield batch
def __len__(self):
return self.__num_batches
class ConcatDatasetBatchSampler(Sampler):
"""This sampler takes a ConcatDataset and samples complete batches of one of the underlying datasets randomly based
on the total size of the dataset.
Based on Pytorch 1.5.1 BatchSampler:
https://pytorch.org/docs/1.5.1/_modules/torch/utils/data/sampler.html#BatchSampler
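
    Examples
    --------
    A rough sketch with two hypothetical map-style datasets; every yielded batch contains indices from a
    single dataset, offset into the concatenated index space (e.g. of a `torch.utils.data.ConcatDataset`
    built from the same list):

    >>> batch_sampler = ConcatDatasetBatchSampler([dataset_a, dataset_b], batch_size=4, seed=0)
    >>> batches = [next(batch_sampler) for _ in range(10)]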
"""
def __init__(self, datasets: List, batch_size: int, seed: Optional[int] = None):
super().__init__(datasets)
self.logger = logging.getLogger(type(self).__name__)
if not isinstance(batch_size, int) or isinstance(batch_size, bool) or batch_size <= 0:
raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}")
self.datasets = datasets
self.seed = seed
self.samplers = [DistributedSampler(len(_), shuffle=True, seed=seed) for _ in datasets]
self.batch_size = batch_size
self.weights = np.asarray([len(_) for _ in datasets])
self.cumulative_sizes = self.cumsum(datasets)
self.logger.info(
"Sampling batches with weights %s with cumulative sizes %s.", self.weights, self.cumulative_sizes
)
self._batch_samplers = [
self.batch_sampler(sampler, 0 if idx == 0 else self.cumulative_sizes[idx - 1])
for idx, sampler in enumerate(self.samplers)
]
def batch_sampler(self, sampler, sampler_offset):
batch = []
for batch_idx in sampler:
batch.append(batch_idx + sampler_offset)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0:
yield batch
def __next__(self):
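        # Pick the dataset to draw from, weighted by dataset size, then take the
        # next pre-offset batch from that dataset's batch sampler.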
iterator_idx = random.choices(range(len(self.weights)), weights=self.weights / self.weights.sum())[0]
return next(self._batch_samplers[iterator_idx])
def __iter__(self):
return self
@staticmethod
def cumsum(sequence):
# From Pytorch 1.5.1: torch.utils.data.dataset.ConcatDataset
r, s = [], 0
for e in sequence:
curr_len = len(e)
r.append(curr_len + s)
s += curr_len
return r
def __len__(self):
# This does not make sense for this sampler.
raise ValueError("length does not make sense for ConcatDatasetBatchSampler.")