ArrayDataset
- class ArrayDataset(*arrays)
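A Dataset class for NumPy ndarray data.
It is initialized with one or more NumPy arrays, and all arrays must have the same size along the dimension that indexes the samples.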
- Parameters
arrays (dataset and labels) – the data and labels to be returned iteratively.
- Returns
A tuple of raw data and the corresponding label.
- Return type
Tuple
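To make this contract concrete, here is a minimal pure-NumPy sketch of the behavior described above: construction checks that every array has the same number of samples, `len()` gives that number, and indexing returns one tuple with the corresponding entry from each array. `SimpleArrayDataset` is a hypothetical stand-in, not MegEngine's implementation, and it assumes the sample dimension is axis 0 (as in the example, where both `rand_data` and `label` start with `sample_num`).

```python
import numpy as np


class SimpleArrayDataset:
    """Hypothetical stand-in mimicking the documented ArrayDataset behavior."""

    def __init__(self, *arrays):
        if not arrays:
            raise ValueError("at least one array is required")
        n = len(arrays[0])  # assumed sample dimension: axis 0
        # Every array must hold the same number of samples along axis 0.
        if any(len(a) != n for a in arrays):
            raise ValueError("all arrays must have the same length along axis 0")
        self.arrays = arrays

    def __getitem__(self, index):
        # One sample: the index-th entry of every array, returned as a tuple.
        return tuple(a[index] for a in self.arrays)

    def __len__(self):
        return len(self.arrays[0])


images = np.zeros((4, 1, 32, 32), dtype=np.uint8)
labels = np.arange(4)
ds = SimpleArrayDataset(images, labels)
img, lbl = ds[2]
print(len(ds), img.shape, lbl)  # → 4 (1, 32, 32) 2
```

Passing arrays with different lengths along axis 0, e.g. `SimpleArrayDataset(np.zeros(3), np.zeros(4))`, raises `ValueError` in this sketch.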
Examples
```python
import numpy as np

from megengine.data.dataset import ArrayDataset
from megengine.data.dataloader import DataLoader
from megengine.data.sampler import SequentialSampler

sample_num = 10  # number of samples; this value is arbitrary

# Random "images" with shape (sample_num, 1, 32, 32) and one integer label per image.
rand_data = np.random.randint(0, 255, size=(sample_num, 1, 32, 32), dtype=np.uint8)
label = np.random.randint(0, 10, size=(sample_num,), dtype=int)
dataset = ArrayDataset(rand_data, label)

# Visit samples in index order, two per batch.
seque_sampler = SequentialSampler(dataset, batch_size=2)
dataloader = DataLoader(dataset, sampler=seque_sampler, num_workers=3)

for step, data in enumerate(dataloader):
    print(data)
```
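Each `data` printed by the loop above is one batch. The sketch below approximates what `SequentialSampler(dataset, batch_size=2)` combined with the loader's default collation produces: samples taken in index order, two at a time, with each field stacked along a new leading axis. `sequential_batches` is a hypothetical helper written with plain NumPy, not MegEngine's code; the exact container type and collation details in MegEngine may differ, and the last, smaller batch is assumed to be kept (i.e. `drop_last=False`).

```python
import numpy as np


def sequential_batches(arrays, batch_size):
    """Hypothetical helper: yield consecutive batches in index order,
    roughly like a sequential sampler plus a stacking collator."""
    n = len(arrays[0])
    for start in range(0, n, batch_size):
        indices = range(start, min(start + batch_size, n))
        # Gather one sample tuple per index (as ArrayDataset indexing would).
        samples = [tuple(a[i] for a in arrays) for i in indices]
        # Stack each field (images, labels, ...) along a new axis 0.
        yield tuple(np.stack(field) for field in zip(*samples))


sample_num = 5
rand_data = np.random.randint(0, 255, size=(sample_num, 1, 32, 32), dtype=np.uint8)
label = np.random.randint(0, 10, size=(sample_num,), dtype=int)

batches = list(sequential_batches((rand_data, label), batch_size=2))
for step, (images, labels) in enumerate(batches):
    print(step, images.shape, labels.shape)
```

With 5 samples and a batch size of 2, this yields two full batches of shape `(2, 1, 32, 32)` / `(2,)` and a final batch holding the single remaining sample.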