megengine.data.dataloader source code

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import gc
import math
import multiprocessing
import platform
import queue
import random
import threading
import time
from typing import Callable, Union

import numpy as np

from ..device import _sh, get_default_device
from ..functional.tensor import copy
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from ..tensor import Tensor
from .collator import Collator
from .dataset import Dataset, StreamDataset
from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

try:
    import thread
except:
    import _thread as thread


# Module-level logger shared by the dataloader iterators and worker loops.
logger = get_logger(__name__)


# Timeout in seconds passed to ``queue.get(timeout=...)`` by the worker loops.
# When it expires they loop again and re-check the shared shutdown flag.
GLOBAL_TIMEOUT = 5


def raise_timeout_error():
    """Default ``timeout_event`` callback used by :class:`DataLoader`.

    Raises:
        RuntimeError: always, with the message ``"dataloader timeout"``.
    """
    message = "dataloader timeout"
    raise RuntimeError(message)


class DataLoader:
    r"""Provides a convenient way to iterate on a given dataset.

    DataLoader combines a dataset with
    :class:`~.Sampler`, :class:`~.Transform` and :class:`~.Collator`,
    making it flexible to get minibatches continually from a dataset.

    Args:
        dataset: dataset from which to load the minibatch.
        sampler: defines the strategy to sample data from the dataset.
            Default: None (a ``StreamSampler(batch_size=1)`` for a
            :class:`StreamDataset`, otherwise a
            ``SequentialSampler(dataset, batch_size=1, drop_last=False)``).
        transform: defines the transforming strategy for a sampled batch.
            Default: None
        collator: defines the merging strategy for a transformed batch.
            Default: None
        num_workers: the number of sub-processes used to load, transform and
            collate the batch. ``0`` means using a single process. Default: 0
        timeout: if positive, the timeout value (seconds) for collecting a
            batch from workers. Default: 0
        timeout_event: callback function triggered by timeout; defaults to
            raising a runtime error.
        divide: defines the parallelising strategy in multi-processing mode.
            ``True`` means one batch is divided into :attr:`num_workers` pieces,
            and the workers process these pieces in parallel. ``False`` means
            different sub-processes process different batches. Default: False
        preload: whether to apply the preloading strategy of the dataloader.
            Preloading overlaps the host-to-device copy with kernel execution
            to improve loading speed. Default: False. When enabled, the output
            changes from ``np.ndarray`` to tensor. The supported types for
            preload are int, float, list[int, float] and tuple[int, float];
            other types are not supported.

    Raises:
        ValueError: if ``num_workers`` or ``timeout`` is negative, if
            ``divide`` is set with ``num_workers <= 1``, or (for map-style
            datasets in divide mode) if the batch size is not larger than
            ``num_workers``.
    """
    __initialized = False

    def __init__(
        self,
        dataset: Dataset,
        sampler: Sampler = None,
        transform: Transform = None,
        collator: Collator = None,
        num_workers: int = 0,
        timeout: int = 0,
        timeout_event: Callable = raise_timeout_error,
        divide: bool = False,
        preload: bool = False,
    ):
        if num_workers < 0:
            raise ValueError("num_workers should not be negative")

        if timeout < 0:
            raise ValueError("timeout should not be negative")

        if divide and num_workers <= 1:
            raise ValueError("divide should not be set to True when num_workers <= 1")

        self.dataset = dataset

        self.num_workers = num_workers
        self.timeout = timeout
        self.timeout_event = timeout_event

        self.divide = divide
        self.preload = preload

        # Compare against None explicitly: a MapSampler defines __len__, so an
        # empty sampler is falsy and must not be silently swapped for the default.
        if isinstance(dataset, StreamDataset):
            self.sampler = (
                sampler if sampler is not None else StreamSampler(batch_size=1)
            )
            assert isinstance(
                self.sampler, StreamSampler
            ), "types of dataset and sampler do not match"
        else:
            assert isinstance(
                dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(dataset)
            self.sampler = (
                sampler
                if sampler is not None
                else SequentialSampler(dataset, batch_size=1, drop_last=False)
            )
            assert isinstance(
                self.sampler, MapSampler
            ), "types of dataset and sampler do not match"

            if divide:
                # NOTE(review): `<=` also rejects batch_size == num_workers,
                # which the message ("must not smaller") does not suggest.
                if self.sampler.batch_size <= self.num_workers:
                    raise ValueError(
                        "batch size must not smaller than num_workers in divide mode."
                    )
                elif self.sampler.batch_size % self.num_workers:
                    logger.warning(
                        "batch size is not divisible by num_workers, may lose performance in divide mode."
                    )

        if transform is None:
            self.transform = PseudoTransform()
        else:
            self.transform = transform

        if collator is None:
            self.collator = Collator()
        else:
            self.collator = collator

        self.__initialized = True

    def __iter__(self):
        """Return a fresh iterator matching the dataset kind and worker count.

        On Windows, ``num_workers`` is forced to 0 (persistently on this
        loader) because the parallel iterators are not supported there.
        """
        if platform.system() == "Windows" and self.num_workers > 0:
            # Report through the module logger, like the other warnings here.
            logger.warning(
                "pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero"
            )
            self.num_workers = 0
        if isinstance(self.dataset, StreamDataset):
            if not self.num_workers:
                return _SerialStreamDataLoaderIter(self, self.preload)
            else:
                return _ParallelStreamDataLoaderIter(self, self.preload)
        else:
            assert isinstance(
                self.dataset, Dataset
            ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
            if not self.num_workers:
                return _SerialMapDataLoaderIter(self, self.preload)
            else:
                return _ParallelMapDataLoaderIter(self, self.preload)

    def __len__(self):
        """Number of batches, as reported by the sampler."""
        return len(self.sampler)
class PreLoader: def __init__(self, preload): if preload: self.default_device = get_default_device() self.pre_load_device = self.default_device + ":" + str(_sh.get_next()) self.pre_load_device_cache = None self.preload = preload """ strategy one: load from numpy data, and generate dtype tensor """ def _load_tensor(self, batch, cached=True): if isinstance(batch, np.ndarray): device = self.pre_load_device if cached else self.default_device return Tensor(batch, device=device) elif isinstance(batch, collections.abc.Mapping): return {k: self._load_tensor(v, cached) for k, v in batch.items()} elif isinstance(batch, tuple) and hasattr(batch, "_fields"): # namedtuple return type(batch)(*(self._load_tensor(value, cached) for value in batch)) elif isinstance(batch, collections.abc.Sequence): return [self._load_tensor(value, cached) for value in batch] else: return batch """ strategy two: load from cache that is already tensor just do d2d copy """ def _load_cache(self, data): if isinstance(data, Tensor): if data.device == self.default_device: return data return copy(data, device=self.default_device) elif isinstance(data, collections.abc.Mapping): return {k: self._load_cache(v) for k, v in data.items()} elif isinstance(data, tuple) and hasattr(data, "_fields"): # namedtuple return type(data)(*(self._load_cache(value) for value in data)) elif isinstance(data, collections.abc.Sequence): return [self._load_cache(value) for value in data] else: return data def _swap_out_cache(self): out = self._load_cache(self.pre_load_device_cache) self.pre_load_device_cache = None # clean cache return out class _BaseMapDataLoaderIter(PreLoader): def __init__(self, loader, preload): super().__init__(preload) self.dataset = loader.dataset self.sampler = loader.sampler self.seed = _random_seed_generator().__next__() self.transform = loader.transform self.collator = loader.collator self.num_workers = loader.num_workers self.timeout = loader.timeout self.timeout_event = loader.timeout_event 
self.divide = loader.divide self.num_processed = 0 def _get_next_batch(self): raise NotImplementedError def __len__(self): return len(self.sampler) def __iter__(self): return self def __next__(self): if self.preload: cached = self.pre_load_device_cache if cached is None: # first and last if self.num_processed >= len(self): # last raise StopIteration elif self.num_processed == 0: # first self._try_load_tensor(cached=False) # first do the h2d out = self._swap_out_cache() self._try_load_tensor() return out else: if self.num_processed >= len(self): raise StopIteration minibatch = self._get_next_batch() self.num_processed += 1 return minibatch def _try_load_tensor(self, cached=True): if self.num_processed >= len(self): return else: self.num_processed += 1 batch = self._get_next_batch() self.pre_load_device_cache = self._load_tensor(batch, cached) class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter): def __init__(self, loader, preload): super(_SerialMapDataLoaderIter, self).__init__(loader, preload) self.indices_iter = iter(self.sampler) def _get_next_batch(self): indices = next(self.indices_iter) items = [self.dataset[idx] for idx in indices] trans_items = self.transform.apply_batch(items) return self.collator.apply(trans_items) class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter): __initialized = False def __init__(self, loader, preload): super(_ParallelMapDataLoaderIter, self).__init__(loader, preload) self.task_queues = [ multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers) ] self.feed_batch_idx = multiprocessing.Value("i", 0) self.target_batch_idx = multiprocessing.Value("i", 0) self.shutdown_flag = multiprocessing.Value("i", 0) self.trans_data_queues = [ multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) ] # use shared-memory queue implemented by pyarrow plasma store. 
from .tools._queue import PlasmaShmQueue self.batch_queue = PlasmaShmQueue(maxsize=2) self.task_feeding_worker = multiprocessing.Process( target=_task_feeding_loop, args=( iter(self.sampler), self.task_queues, self.num_workers, self.divide, self.shutdown_flag, self.feed_batch_idx, ), daemon=True, ) gc.collect() self.task_feeding_worker.start() self.workers = [] for worker_id in range(self.num_workers): worker = multiprocessing.Process( target=_worker_loop, args=( self.dataset, self.task_queues[worker_id], self.trans_data_queues[worker_id], self.transform, self.seed + worker_id + 1, self.shutdown_flag, ), daemon=True, ) gc.collect() worker.start() self.workers.append(worker) if self.divide: self.data_collecting_worker = multiprocessing.Process( target=_data_gathering_loop, args=( self.trans_data_queues, self.batch_queue, self.collator, len(self), self.num_workers, self.shutdown_flag, self.target_batch_idx, ), daemon=True, ) else: self.data_collecting_worker = multiprocessing.Process( target=_data_selecting_loop, args=( self.trans_data_queues, self.batch_queue, self.collator, len(self), self.num_workers, self.shutdown_flag, self.target_batch_idx, ), daemon=True, ) gc.collect() self.data_collecting_worker.start() self.__initialized = True def _check_workers(self): # Check the status of each worker. if not self.data_collecting_worker.is_alive(): exitcode = self.data_collecting_worker.exitcode if exitcode != 0: raise RuntimeError("data collecting worker died. {}".format(exitcode)) if not self.task_feeding_worker.is_alive(): exitcode = self.task_feeding_worker.exitcode if exitcode != 0: raise RuntimeError("task feeding worker died. {}".format(exitcode)) for worker_id, worker in enumerate(self.workers): if not worker.is_alive(): exitcode = worker.exitcode if exitcode != 0: raise RuntimeError("worker:{} died. 
{}".format(worker_id, exitcode)) logger.debug("all workers are alive.") def _get_next_batch(self): start_time = time.time() while True: self._check_workers() try: return self.batch_queue.get(timeout=1) except queue.Empty: logger.debug("batch queue empty!") waited_time = time.time() - start_time if self.timeout > 0: if waited_time > self.timeout: raise RuntimeError("get_next_batch timeout!") def _shutdown(self): with self.shutdown_flag.get_lock(): self.shutdown_flag.value = 1 if self.task_feeding_worker.is_alive(): self.task_feeding_worker.terminate() self.task_feeding_worker.join() if self.data_collecting_worker.is_alive(): self.data_collecting_worker.terminate() self.data_collecting_worker.join() for worker in self.workers: if worker.is_alive(): worker.terminate() worker.join() for q in self.trans_data_queues: q.cancel_join_thread() q.close() for q in self.task_queues: q.cancel_join_thread() q.close() self.batch_queue.cancel_join_thread() self.batch_queue.close() def __del__(self): if self.__initialized: self._shutdown() class _BaseStreamDataLoaderIter(PreLoader): def __init__(self, loader, preload): super().__init__(preload) self.dataset = loader.dataset self.sampler = loader.sampler self.transform = loader.transform self.collator = loader.collator self.num_workers = loader.num_workers self.timeout = loader.timeout self.timeout_event = loader.timeout_event def _get_next_batch(self): raise NotImplementedError def _process_raw_data(self, raw_data): assert len(raw_data) == 2 and isinstance( raw_data[0], bool ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched." 
if not raw_data[0]: data = list((x,) for x in raw_data[1]) else: data = raw_data[1] ret = [] for idx in range(len(data[0])): ret.append(tuple(e[idx] for e in data)) return ret def __iter__(self): return self def __next__(self): if self.preload: if self.pre_load_device_cache is None: self._try_load_tensor(cached=False) # load in current out = self._swap_out_cache() self._try_load_tensor() # load in cached return out else: return self._get_next_batch() def _try_load_tensor(self, cached=True): batch = self._get_next_batch() self.pre_load_device_cache = self._load_tensor(batch, cached) class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter): def __init__(self, loader, preload): super().__init__(loader, preload) self.dataset_iter = iter(self.dataset) self.idx = 0 self.unused = [] def _try_get_raw_data(self, start_time): raw_data = None while not raw_data: try: if self.timeout > 0: timer = threading.Timer(self.timeout, thread.interrupt_main) timer.start() raw_data = next(self.dataset_iter) if self.timeout > 0: timer.cancel() except KeyboardInterrupt: raw_data = self.timeout_event() except: if self.timeout > 0: timer.cancel() waited_time = time.time() - start_time if waited_time > self.timeout: raw_data = self.timeout_event() return raw_data def _get_next_batch(self): ret = [] start_time = time.time() while len(ret) < self.sampler.batch_size: if len(self.unused) != 0: batch_data = self.unused else: raw_data = self._try_get_raw_data(start_time) batch_data = self._process_raw_data(raw_data) while len(batch_data) != 0 and len(ret) < self.sampler.batch_size: data = batch_data.pop() ret.append(self.transform.apply(data)) self.unused = batch_data return self.collator.apply(ret) class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter): __initialized = False def __init__(self, loader, preload): super().__init__(loader, preload) self.shutdown_flag = multiprocessing.Value("i", 0) self.raw_data_queues = [ multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) ] 
self.trans_data_queues = [ multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers) ] # shared-memory queue implemented by pyarrow plasma store from .tools._queue import PlasmaShmQueue self.batch_queue = PlasmaShmQueue(maxsize=2) self.recieve_worker = multiprocessing.Process( target=self._worker_to_raw_data_queues, daemon=True ) gc.collect() self.recieve_worker.start() self.transform_workers = [] for worker_id in range(self.num_workers): worker = multiprocessing.Process( target=self._worker_to_trans_data_queues, args=(worker_id,), daemon=True ) gc.collect() worker.start() self.transform_workers.append(worker) self.collect_worker = multiprocessing.Process( target=self._worker_to_batch_queue, daemon=True ) gc.collect() self.collect_worker.start() self.__initialized = True def _put_raw_data_queues(self, raw_data, qidx): batch_data = self._process_raw_data(raw_data) for data in batch_data: while True: qidx = qidx % self.num_workers try: self.raw_data_queues[qidx].put(data) break except queue.Full: if self.shutdown_flag.value == 1: break logger.debug("raw data queue %d is full" % qidx) finally: qidx += 1 return qidx def _worker_to_raw_data_queues(self): dataset_iter = iter(self.dataset) qidx = 0 while True: if self.shutdown_flag.value == 1: break raw_data = next(dataset_iter) qidx = self._put_raw_data_queues(raw_data, qidx) def _worker_to_trans_data_queues(self, worker_id): while True: if self.shutdown_flag.value == 1: break try: data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT) except queue.Empty: continue trans_data = self.transform.apply(data) while True: try: self.trans_data_queues[worker_id].put(trans_data) break except queue.Full: if self.shutdown_flag.value == 1: break logger.debug("batch queue if full") def _worker_to_batch_queue(self): cnt = -1 trans_items = [] while True: if self.shutdown_flag.value == 1: break cnt += 1 queue_id = cnt % self.num_workers try: trans_item = self.trans_data_queues[queue_id].get( timeout=GLOBAL_TIMEOUT ) 
except queue.Empty: continue trans_items.append(trans_item) if len(trans_items) == self.sampler.batch_size: batch_data = self.collator.apply(trans_items) while True: try: self.batch_queue.put(batch_data, timeout=1) break except queue.Full: if self.shutdown_flag.value == 1: break logger.debug("batch queue is full") trans_items = [] def _check_workers(self): if not self.collect_worker.is_alive(): exitcode = self.collect_worker.exitcode if exitcode != 0: raise RuntimeError("collator worker died. {}".format(exitcode)) for worker_id, worker in enumerate(self.transform_workers): if not worker.is_alive(): exitcode = worker.exitcode if exitcode != 0: raise RuntimeError( "worker: {} died. {}".format(worker_id, exitcode) ) def _get_next_batch(self): start_time = time.time() while True: self._check_workers() try: return self.batch_queue.get(timeout=1) except queue.Empty: logger.debug("batch queue empty!") waited_time = time.time() - start_time if self.timeout > 0 and waited_time > self.timeout: self._put_raw_data_queues(self.timeout_event(), 0) def _shutdown(self): with self.shutdown_flag.get_lock(): self.shutdown_flag.value = 1 if self.recieve_worker.is_alive(): self.recieve_worker.terminate() self.recieve_worker.join() if self.collect_worker.is_alive(): self.collect_worker.terminate() self.collect_worker.join() for worker in self.transform_workers: if worker.is_alive(): worker.terminate() worker.join() for q in self.raw_data_queues: q.cancel_join_thread() q.close() for q in self.trans_data_queues: q.cancel_join_thread() q.close() self.batch_queue.cancel_join_thread() self.batch_queue.close() def __del__(self): if self.__initialized: self._shutdown() def _task_feeding_loop( indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx ): # Feed the indices into the task queues while True: if shutdown_flag.value == 1: break batch_idx = feed_batch_idx.value try: indices = next(indices_iter) except StopIteration: break if divide: # make sure all task_queues is 
ready for put while any([q.full() for q in task_queues]): if shutdown_flag.value == 1: return # divide into small pieces, feed to different workers. sub_num = math.ceil(len(indices) / num_workers) for worker_id in range(num_workers): sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num] task_queues[worker_id].put((batch_idx, sub_indices)) else: # distribute tasks to different workers uniformly. target_id = batch_idx % num_workers while task_queues[target_id].full(): if shutdown_flag.value == 1: return task_queues[target_id].put((batch_idx, indices)) with feed_batch_idx.get_lock(): feed_batch_idx.value += 1 def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag): # Get dataset items and do the transform random.seed(seed) np.random.seed(seed) while True: if shutdown_flag.value == 1: break try: batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT) except queue.Empty: continue if len(indices) > 0: items = [dataset[idx] for idx in indices] trans_items = transform.apply_batch(items) else: # in case of incomplete last batch trans_items = () while True: try: trans_data_queue.put((batch_idx, trans_items), timeout=1) break except queue.Full: if shutdown_flag.value == 1: break logger.debug("batch part queue is full!") def _data_gathering_loop( trans_data_queues, batch_queue, collator, length, num_workers, shutdown_flag, target_idx, ): # Gathering the small pieces of batch data into full batch data while True: if shutdown_flag.value == 1: break target_batch_idx = target_idx.value if target_batch_idx >= length: break full_trans_items = [] for worker_id in range(num_workers): while True: try: batch_idx, trans_items = trans_data_queues[worker_id].get( timeout=GLOBAL_TIMEOUT ) break except queue.Empty: if shutdown_flag.value == 1: break logger.debug( "worker:{} data queue get timeout! 
target batch idx:{}".format( worker_id, target_batch_idx ) ) if batch_idx != target_batch_idx: raise RuntimeError( "Unexperted batch_idx in data gathering loop. worker_id:{}.".format( worker_id ) ) else: full_trans_items.extend(trans_items) # Merge different parts into a batch. full_batch = collator.apply(full_trans_items) while True: try: batch_queue.put(full_batch, timeout=1) break except queue.Full: if shutdown_flag.value == 1: break logger.debug("batch queue is full!") with target_idx.get_lock(): target_idx.value += 1 batch_queue.disconnect_client() def _data_selecting_loop( trans_data_queues, batch_queue, collator, length, num_workers, shutdown_flag, target_idx, ): # Make sure that batch is generated exactly with the same order as generated indices while True: if shutdown_flag.value == 1: break target_batch_idx = target_idx.value if target_batch_idx >= length: break target_worker_id = target_batch_idx % num_workers while True: try: batch_idx, trans_items = trans_data_queues[target_worker_id].get( timeout=GLOBAL_TIMEOUT ) batch_data = collator.apply(trans_items) break except queue.Empty: if shutdown_flag.value == 1: break logger.debug( "worker:{} data queue get timeout! target batch idx:{}".format( target_worker_id, target_batch_idx ) ) if batch_idx != target_batch_idx: raise RuntimeError( "batch_idx {} mismatch the target_batch_idx {}".format( batch_idx, target_batch_idx ) ) while True: try: batch_queue.put(batch_data, timeout=1) break except queue.Full: if shutdown_flag.value == 1: break logger.debug("batch queue is full!") with target_idx.get_lock(): target_idx.value += 1 batch_queue.disconnect_client()