# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections.abc
import math
from abc import ABC, abstractmethod
from typing import Any, Generator, Iterator, List, Union
import numpy as np
import megengine.distributed as dist
class Sampler(ABC):
    r"""Abstract base class for every sampler in this module.

    ``__init__`` is declared abstract, so ``Sampler`` itself cannot be
    instantiated; concrete subclasses must define their own ``__init__``.
    """

    @abstractmethod
    def __init__(self):
        # Nothing to set up at this level.
        ...
class MapSampler(Sampler):
    r"""Sampler for map dataset.

    Args:
        dataset: dataset to sample from.
        batch_size: batch size for batch method.
        drop_last: set ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch_size, then the last batch will
            be smaller. Default: False
        num_samples: total number of samples across all ranks; each rank is
            assigned ``ceil(num_samples / world_size)`` of them.
            Default: ``len(dataset)``.
        world_size: number of ranks. Default: the distributed world size when
            running distributed, otherwise 1.
        rank: rank id, an integer within ``[0, world_size)``. Default: the
            distributed rank when running distributed, otherwise 0.
        seed: seed for random operators. When ``world_size > 1`` and no seed
            is given, 0 is used so that every rank uses the same seed.

    Raises:
        ValueError: if ``batch_size``, ``drop_last``, ``num_samples``,
            ``world_size`` or ``rank`` is invalid.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        num_samples=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        if (
            not isinstance(batch_size, int)
            or isinstance(batch_size, bool)
            or batch_size <= 0
        ):
            raise ValueError(
                "batch_size should be a positive integer value, "
                "but got batch_size={}".format(batch_size)
            )
        if not isinstance(drop_last, bool):
            raise ValueError(
                "drop_last should be a boolean value, but got "
                "drop_last={}".format(drop_last)
            )
        if num_samples is not None and (
            not isinstance(num_samples, int)
            or isinstance(num_samples, bool)
            or num_samples <= 0
        ):
            raise ValueError(
                "num_samples should be a positive integer "
                "value, but got num_samples={}".format(num_samples)
            )
        self.batch_size = batch_size
        self.dataset = dataset
        self.drop_last = drop_last

        if world_size is None:
            world_size = dist.get_world_size() if dist.is_distributed() else 1
        # Validate early; world_size == 0 would otherwise fail below with a
        # ZeroDivisionError.
        if world_size < 1:
            raise ValueError(
                "world_size should be a positive integer value, "
                "but got world_size={}".format(world_size)
            )
        self.world_size = world_size

        if rank is None:
            rank = dist.get_rank() if dist.is_distributed() else 0
        # rank is only consumed by scatter(), which batch() calls only when
        # world_size > 1; an out-of-range rank would otherwise make scatter()
        # return a wrong-sized shard and fail only at iteration time.
        if self.world_size > 1 and not 0 <= rank < self.world_size:
            raise ValueError(
                "rank should be an integer within [0, world_size), "
                "but got rank={} with world_size={}".format(rank, self.world_size)
            )
        self.rank = rank

        if num_samples is None:
            num_samples = len(self.dataset)
        # Per-rank share, rounded up; scatter() pads the full index list to
        # num_samples * world_size so that every rank gets exactly this many.
        self.num_samples = int(math.ceil(num_samples / self.world_size))

        # Make sure seeds are the same at each rank.
        if seed is None and self.world_size > 1:
            seed = 0
        self.rng = np.random.RandomState(seed)

    def __iter__(self) -> Union[Generator, Iterator]:
        return self.batch()

    def __len__(self) -> int:
        # Number of batches batch() produces for this rank.
        if self.drop_last:
            return self.num_samples // self.batch_size
        return int(math.ceil(self.num_samples / self.batch_size))

    def sample(self):
        r"""Return all sample indices (for every rank). Subclasses must override."""
        raise NotImplementedError

    def scatter(self, indices) -> List:
        r"""Split ``indices`` into per-rank subsets and return this rank's subset.

        The list is first padded, by repeating indices from its start, up to
        ``num_samples * world_size`` entries so that it divides evenly; rank
        ``r`` then receives every ``world_size``-th index starting at
        position ``r``. Override this method to customize how indices are
        assigned to ranks.

        Args:
            indices: all sample indices, e.g. as returned by :meth:`sample`.
                The argument itself is not modified.

        Returns:
            The ``num_samples`` indices assigned to ``self.rank``.

        Raises:
            ValueError: if ``indices`` is empty but padding is required.
        """
        total_size = self.num_samples * self.world_size
        # Work on a copy so the caller's list is not mutated.
        indices = list(indices)
        # add extra indices to make it evenly divisible
        padding = total_size - len(indices)
        if padding > 0:
            if not indices:
                raise ValueError(
                    "cannot pad an empty index list to {} samples".format(total_size)
                )
            # Repeat the list as many times as needed: one slice is not enough
            # when the list is shorter than the padding (e.g. fewer samples
            # than ranks).
            repeats = -(-padding // len(indices))
            indices += (indices * repeats)[:padding]
        assert len(indices) == total_size
        # subsample
        indices = indices[self.rank : total_size : self.world_size]
        assert len(indices) == self.num_samples
        return indices

    def batch(self) -> Iterator[List[Any]]:
        r"""Return an iterator over batches (lists) of this rank's indices."""
        indices = list(self.sample())
        # user might pass the world_size parameter without dist,
        # so dist.is_distributed() should not be used
        if self.world_size > 1:
            indices = self.scatter(indices)
        step = self.batch_size
        batches = [indices[i : i + step] for i in range(0, len(indices), step)]
        # Guard against an empty index list before inspecting the last batch.
        if self.drop_last and batches and len(batches[-1]) < step:
            batches.pop()
        return iter(batches)
class StreamSampler(Sampler):
    r"""Sampler for stream dataset.

    The sampler is its own endless iterator: every ``next()`` returns a fresh
    ``iter(range(batch_size))``.

    Args:
        batch_size: number of samples per batch; must be a positive integer.

    Raises:
        ValueError: if ``batch_size`` is not a positive integer.

    Warning:
        With multiple machines, each worker should receive different data,
        but this class cannot arrange that yet; build your own dataset and
        sampler to achieve it.

        Usually, :meth:`~.StreamDataset.__iter__` can return a different
        iterator depending on ``rank = dist.get_rank()``, so that each rank
        gets different data.
    """

    def __init__(self, batch_size=1):
        # Same validation as MapSampler; a zero or negative batch_size would
        # otherwise yield empty batches forever.
        if (
            not isinstance(batch_size, int)
            or isinstance(batch_size, bool)
            or batch_size <= 0
        ):
            raise ValueError(
                "batch_size should be a positive integer value, "
                "but got batch_size={}".format(batch_size)
            )
        self.batch_size = batch_size

    def __iter__(self):
        return self

    def __next__(self):
        return iter(range(self.batch_size))
class SequentialSampler(MapSampler):
    r"""Sample elements sequentially.

    Args:
        dataset: dataset to sample from.
        batch_size: batch size for batch method.
        drop_last: set ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch_size, then the last batch will
            be smaller. Default: False
        indices: optional sequence of sample indices to visit in order,
            instead of ``range(len(dataset))``.
        world_size: number of ranks.
        rank: rank id, an integer within ``[0, world_size)``.

    Raises:
        ValueError: if ``indices`` is neither ``None`` nor a sequence.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        indices=None,
        world_size=None,
        rank=None,
    ):
        super().__init__(dataset, batch_size, drop_last, None, world_size, rank)
        if indices is not None and not isinstance(indices, collections.abc.Sequence):
            raise ValueError(
                "indices should be None or a sequence, "
                "but got indices={}".format(indices)
            )
        self.indices = indices
        if indices is not None:
            # The base class sized the per-rank share from len(dataset); with an
            # explicit index list, size it from that list instead so that
            # __len__ and scatter() agree with what sample() returns.
            self.num_samples = int(math.ceil(len(indices) / self.world_size))

    def sample(self) -> Iterator[Any]:
        r"""Return the indices to visit, in order.

        Returns an iterator over ``range(len(dataset))`` when no ``indices``
        were given, otherwise the ``indices`` sequence itself.
        """
        if self.indices is None:
            return iter(range(len(self.dataset)))
        return self.indices
class RandomSampler(MapSampler):
    r"""Sample elements randomly without replacement.

    Args:
        dataset: dataset to sample from.
        batch_size: batch size for batch method.
        drop_last: set ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch_size, then the last batch will
            be smaller. Default: False
        indices: optional sequence of sample indices to shuffle, instead of
            ``range(len(dataset))``.
        world_size: number of ranks.
        rank: rank id, an integer within ``[0, world_size)``.
        seed: seed for random operators.

    Raises:
        ValueError: if ``indices`` is neither ``None`` nor a sequence.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        indices=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        super().__init__(dataset, batch_size, drop_last, None, world_size, rank, seed)
        if indices is not None and not isinstance(indices, collections.abc.Sequence):
            raise ValueError(
                "indices should be None or a sequence, "
                "but got indices={}".format(indices)
            )
        self.indices = indices
        if indices is not None:
            # The base class sized the per-rank share from len(dataset); with an
            # explicit index list, size it from that list instead so that
            # __len__ and scatter() agree with what sample() returns.
            self.num_samples = int(math.ceil(len(indices) / self.world_size))

    def sample(self) -> List:
        r"""Return a fresh random permutation of the indices as a list.

        Each call advances ``self.rng``, so successive calls produce different
        orders; samplers created with the same seed produce the same sequence
        of permutations.
        """
        population = len(self.dataset) if self.indices is None else self.indices
        return self.rng.permutation(population).tolist()
class ReplacementSampler(MapSampler):
    r"""Sample elements randomly with replacement.

    Args:
        dataset: dataset to sample from.
        batch_size: batch size for batch method.
        drop_last: set ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch_size, then the last batch will
            be smaller. Default: False
        num_samples: total number of samples to draw across all ranks; each
            rank gets ``ceil(num_samples / world_size)`` of them.
            Default: ``len(dataset)``.
        weights: sampling weight for each element (a sequence or 1-D array of
            length ``len(dataset)``). They need not be normalized, but must be
            non-negative with a positive sum. Default: uniform sampling.
        world_size: number of ranks.
        rank: rank id, an integer within ``[0, world_size)``.
        seed: seed for random operators.

    Raises:
        ValueError: if ``weights`` is not a sequence/array, has the wrong
            length, is not 1-D, has negative entries, or does not sum to a
            positive value.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        num_samples=None,
        weights=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        super().__init__(
            dataset, batch_size, drop_last, num_samples, world_size, rank, seed
        )
        normalized = None
        if weights is not None:
            if not isinstance(weights, (collections.abc.Sequence, np.ndarray)):
                raise ValueError(
                    "weights should be None or a sequence, "
                    "but got weights={}".format(weights)
                )
            if len(weights) != len(dataset):
                raise ValueError(
                    "len(dataset)={} should be equal to "
                    "len(weights)={}".format(len(dataset), len(weights))
                )
            w = np.asarray(weights, dtype=np.float64)
            total = w.sum()
            # `not total > 0` also rejects a NaN sum.
            if w.ndim != 1 or (w < 0).any() or not total > 0:
                raise ValueError(
                    "weights should be a 1-D sequence of non-negative values "
                    "with a positive sum, but got weights={}".format(weights)
                )
            normalized = w / total
        # Normalized probabilities (ndarray) or None for uniform sampling.
        self.weights = normalized

    def sample(self) -> List:
        r"""Draw indices with replacement for all ranks.

        ``num_samples * world_size`` indices are drawn, so that
        :meth:`scatter` can give every rank exactly ``num_samples`` of them
        without padding.

        Returns:
            A list of indices in ``[0, len(dataset))``.
        """
        n = len(self.dataset)
        size = self.num_samples * self.world_size
        if size == 0:
            return []
        if self.weights is None:
            return self.rng.randint(n, size=size).tolist()
        # choice() returns the sampled indices themselves; multinomial() would
        # return per-element counts instead.
        return self.rng.choice(n, size=size, replace=True, p=self.weights).tolist()
class Infinite(MapSampler):
    r"""Infinite sampler wrapper around a basic sampler.

    Each ``next()`` returns the next item of the wrapped sampler's iterator.
    When that iterator is exhausted, a new one is obtained with
    ``iter(sampler)`` and iteration continues from its start. If the wrapped
    sampler yields nothing at all, ``StopIteration`` propagates and iteration
    ends.

    Args:
        sampler: the sampler to wrap; it must be re-iterable via ``iter()``.
    """

    def __init__(self, sampler):
        # MapSampler.__init__ is deliberately not called: all sampling is
        # delegated to the wrapped sampler.
        self.sampler = sampler
        self.sampler_iter = iter(sampler)

    def sample(self):
        raise NotImplementedError("sample method not supported in Infinite")

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.sampler_iter)
        except StopIteration:
            pass
        # Wrapped sampler exhausted: restart it and take its first item.
        self.sampler_iter = iter(self.sampler)
        return next(self.sampler_iter)

    def __len__(self):
        # Effectively unbounded length.
        return np.iinfo(np.int64).max