ArrayDataset

class ArrayDataset(*arrays)

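A Dataset class for NumPy ndarray data.

One or more NumPy arrays are required to initialize the dataset, and the dimension that represents the number of samples (the first axis) must have the same length in every array.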
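The length requirement above can be illustrated with a minimal sketch in plain NumPy. The helper name `check_same_length` is hypothetical and is not part of MegEngine; it only shows the kind of check the description implies, assuming the sample count is the length of the first axis.

```python
import numpy as np


def check_same_length(*arrays):
    """Return the shared sample count, or raise ValueError on a mismatch.

    Hypothetical helper for illustration only; not MegEngine's API.
    """
    if not arrays:
        raise ValueError("at least one array is required")
    # len() of an ndarray is the size of its first axis.
    lengths = [len(array) for array in arrays]
    if any(n != lengths[0] for n in lengths):
        raise ValueError(f"inconsistent sample counts: {lengths}")
    return lengths[0]


data = np.zeros((4, 1, 32, 32), dtype=np.uint8)
labels = np.zeros((4,), dtype=int)
num_samples = check_same_length(data, labels)
```

Arrays may differ in every other dimension, as `data` and `labels` do here; only the first axis has to agree.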

Parameters

arrays (dataset and labels) – the data and labels to be returned iteratively.

Returns

A tuple containing the raw data and its corresponding label.

Return type

Tuple
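To make the tuple return value concrete, here is a small stand-in class that indexes every array at the same position and returns the results as a tuple. The class name `TupleArrayDataset` is hypothetical, and this is a sketch of the described behavior under that assumption, not MegEngine's implementation.

```python
import numpy as np


class TupleArrayDataset:
    """Hypothetical stand-in that mimics the documented return shape."""

    def __init__(self, *arrays):
        self.arrays = arrays

    def __len__(self):
        # The sample count is the length of the first axis.
        return len(self.arrays[0])

    def __getitem__(self, index):
        # One entry per input array, all taken at the same sample index.
        return tuple(array[index] for array in self.arrays)


images = np.arange(3 * 2 * 2, dtype=np.uint8).reshape(3, 2, 2)
labels = np.array([7, 8, 9])
dataset = TupleArrayDataset(images, labels)
sample = dataset[1]  # (image of shape (2, 2), label 8)
```

With two input arrays the tuple has two entries; passing more arrays would extend the tuple accordingly in this sketch.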

Examples

import numpy as np

from megengine.data.dataset import ArrayDataset
from megengine.data.dataloader import DataLoader
from megengine.data.sampler import SequentialSampler

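# Number of samples; an illustrative value assumed for this example.
sample_num = 100
rand_data = np.random.randint(0, 255, size=(sample_num, 1, 32, 32), dtype=np.uint8)
label = np.random.randint(0, 10, size=(sample_num,), dtype=int)
# Both arrays share the same first-axis length (sample_num).
dataset = ArrayDataset(rand_data, label)
seque_sampler = SequentialSampler(dataset, batch_size=2)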

dataloader = DataLoader(
    dataset,
    sampler=seque_sampler,
    num_workers=3,
)

for step, data in enumerate(dataloader):
    print(data)