使用 Dataset 定义数据集#

这个世界上的数据集五花八门,并且总是以不同的格式(比如 png, HDF5, npy 等)分布在不同地方, 或许一些数据集的存储和组织形式已经成为了参考标准,但并不是所有数据一开始都用 MegEngine 所支持格式的进行存储, 很多时候需要我们写脚本借助一些库或框架对原始数据进行处理,并且转换成 MegEngine 中可用的数据集对象。

如何创建一个 Tensor 小节中,我们提到了 ndarray 是 Python 数据科学社区中较为通用支持的格式, 因此 MegEngine 中的数据集相关操作均以 NumPy 的 ndarray 作为处理对象 (此时假设用户已经通过一些途径将原始数据转换成了 ndarray 格式,能够用作后续数据集封装), 整个过程中数据格式不会自动转换,因此需要注意:后续如果要进行 Tensor 计算时,则需要 将 NumPy ndarray 转化为 MegEngine Tensor

参见

通过搜索引擎搜索类似 “如何用 NumPy 加载 xxx 类型数据?” 等问题会很有帮助。

备注

  • DataLoader 初始化时必须提供 dataset 参数,通过传入一个数据集对象,告知如何加载每个数据;

  • MegEngine 中可以 使用已经实现的数据集接口 (如 PascalVOC, ImageNet 等) 替用户完成一些主流数据集的获取、切分等工作。但一些时候这些实现不能满足需求,或者我们需要使用自己采集和标注好的数据集, 因此在使用 DataLoader 之前,通常需要将要用到的数据集人为地封装。

数据集类型#

根据样本的访问方式, MegEngine 中的数据集类型可分为 Map-style 和 Iterable-style 两种:

类型

Map-style

Iterable-style

抽象基类 [1]

Dataset / ArrayDataset

StreamDataset

访问方式

支持随机访问

仅能顺序迭代

适用情景

Numpy 数组、字典、磁盘文件 [2]

生成器、来自网络的流数据

Map-style#

Dataset (也可被称为 MapDataset

MegEngine 中所有数据集的抽象基类。 对应数据集类型为 Map-style, 即表示从索引/键到数据样例的映射,具有随机访问能力。 例如使用 dataset[idx] 可以从磁盘上的文件夹中读取到第 idx 个图像及其相应的标签。 使用时需要实现 __getitem__()__len__() 协议。

下面的代码展示了如何生成一个由 0 到 5 这五个数组成的数据集(不带标签):

from megengine.data.dataset import Dataset

class CustomMapDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)
>>> data = list(range(0, 5))
>>> map_dataset = CustomMapDataset(data)
>>> print(len(map_dataset))
>>> print(map_dataset[2])
5
2

警告

  • 请注意,为了避免在加载大型数据集时试图一次性将数据加载到内存而导致 OOM(Out Of Memory), 我们建议将实际的数据读取操作实现在 __getitem__ 方法中,而不是 __init__ 方法中, 后者仅记录映射关系中的索引/键内容(可能是文件名或路径组成的列表),这可以极大程度地减少内存占用。 具体的例子可参考 举例:加载图像分类数据

  • 如果数据规模已经大到了无法将索引等 meta 信息加载进内存,则需要考虑流式获取方式;

  • 但情况并不总是唯一的。如果我们的数据集规模比较小,可以常驻在内存中, 那么就可以考虑在初始化对象时就加载好整个数据集,减少反复从硬盘或其它位置读取数据到内存的次数。 例如在不同的 Epoch 中,同一个样本会被用来训练多次,此时从内存中直接读取会更加高效。

ArrayDataset

Dataset 类的进一步封装,适用于 NumPy ndarray 数据,无需实现 __getitem__()__len__() 协议。

下面的代码展示了如何生成随机一个具有 100 个样本,每张样本为 32 x 32 像素的 RGB 图片的数据集(标签为随机值) 这也是我们在处理图像时经常遇到的 (N, C, H, W) 格式:

import numpy as np
from megengine.data.dataset import ArrayDataset

data = np.random.random((100, 3, 32, 32))
target = np.random.random((100, 1))
dataset = ArrayDataset(data, target)
>>> print(len(dataset))
>>> print(type(dataset[0]), len(dataset[0]))
>>> print(dataset[0][0].shape)
100
<class 'tuple'> 2
(3, 32, 32)
ConcatDataset

由多个数据集组成的数据集。 这个数据集用于将多个映射式(map-style)组合为一个新的数据集。

下面的代码展示了如何将多个 ArrayDataset 组合成一个 ConcatDataset 数据集:

import numpy as np
from megengine.data.dataset import ArrayDataset, ConcatDataset

data1 = np.random.randint(0, 255, size=(100, 3, 32, 32), dtype=np.uint8)
data2 = np.random.randint(0, 255, size=(100, 3, 32, 32), dtype=np.uint8)
label1 = np.random.randint(0, 10, size=(100,), dtype=int)
label2 = np.random.randint(0, 10, size=(100,), dtype=int)

dataset1 = ArrayDataset(data1, label1)
dataset2 = ArrayDataset(data2, label2)
dataset = ConcatDataset([dataset1, dataset2])
>>> print(len(dataset))
>>> print(type(dataset[0]), len(dataset[0]))
>>> print(dataset[0][0].shape)
200
<class 'tuple'> 2
(3, 32, 32)

Iterable-style#

StreamDataset (也可被称为 IterableDataset

Iterable-style 数据集,适用于流式数据,即迭代式地访问数据, 例如使用 iter(dataset) 可以返回从数据库、远程服务器甚至实时生成的日志中读取的数据流, DataLoader 将使用 next 来不断获取数据。

这种类型的数据集特别适用于:

  • 随机读取成本过高,或者数据规模太大,无法支持随机访问的情况;

  • 批量大小情况实际取决于获取数据的情况,即必须根据流数据才能判断当前批是否已经完整。

使用时需要实现 __iter__() 协议。

下面的代码展示了如何生成一个由 0 到 5 这五个数组成的数据集(不带标签):

from megengine.data.dataset import StreamDataset

class CustomIterableDataset(StreamDataset):
    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return iter(self.data)
>>> data = list(range(0, 5))
>>> iter_dataset = CustomIterableDataset(data)
>>> it = iter(iter_dataset)
>>> print(type(it))
list_iterator
>>> print(next(it))
0
>>> print(next(it))
1

显然,流式数据集不支持获取长度以及通过索引值进行随机访问:

>>> iterable_dataset[0]
AssertionError: can not get item from StreamDataset by index
>>> len(iterable_dataset)
AssertionError: StreamDataset does not have length

这个例子无法体现 StreamDataset 加载的真实需求情景,但方便和 MapDataset 进行对比。

为何设计流式数据集#

Map-style

>>> for data in map_dataset:
...     print(data)
0
1
2
3
4

Iterable-style

>>> for data in iter_dataset:
...     print(data)
0
1
2
3
4

根据上面的例子可以发现,使用相同的原始 List 数据来生成两种类型的数据集并迭代访问, Map-style 和 Iterable-style 的数据集都能够返回相同的结果,那么区别在哪里呢? 从高层视角看,每次 DataLoader 从 Map-style 数据集中返回批数据时, 它都会先对数据索引进行采样得到一批索引 idx, 并使用 map_dataset[idx] 获得批数据. 相反,对于 Iterable-style 数据集,DataLoader 不断地调用 next(it) 来顺序获取下一个数据, 直到它获取到一个完整的批次。这也是为什么我们说 Iterable-style 数据集更适合将数据提供给顺序模型。

参见

参考 使用 Sampler 定义抽样规则 ,了解如何从样本容量为 N 的数据集中得到长度为 B 的一批索引。

使用已经实现的数据集接口#

dataset 子模块中,除了提供了一些抽象基类待用户自定义子类进行实现, 还提供了一些基于主流数据集封装好的接口,比如常被用于教学和练习用途的 MNIST 数据集:

>>> from megengine.data.dataset import MNIST
>>> train_set = MNIST(root="path/to/data/", train=True, download=False)
>>> test_set = MNIST(root="path/to/data/", train=False, download=False)

借助于封装好的接口,我们可以快速的获取 MNIST 数据集的训练集 train_set 和测试集 test_set , 其中 download 参数可以控制是否要从数据集官方提供的地址进行下载。更多细节请参考 API 文档。

警告

这些数据集都是从它们自己的官方站点进行下载的,MegEngine 不提供镜像或加速服务。

备注

  • 一些数据集由于许可协议中的规定将不提供原始数据的下载接口(如 ImageNet ),需手动下载;

  • 下载速度受到网络环境和带宽的影响,用户也可以选择使用其它的脚本或工具下载原始数据集;

  • 这些数据集接口源码是非常不错的参考,对于帮助用户学习如何设计数据集接口会很有帮助。

如何添加新的数据集#

目前 MegEngine 中提供了一些常见的主流数据集接口,也欢迎用户为我们提供更多的接口实现。

但目前我们还没有提供明确的设计规格和要求,因此建议用户先尝试与官方维护人员进行交流。