# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABC, abstractmethod
from typing import Tuple
class Dataset(ABC):
    r"""Abstract base class that every dataset derives from.

    Concrete subclasses must provide ``__init__``, ``__getitem__`` and
    ``__len__``; a subclass that leaves any of them abstract cannot be
    instantiated (``abc`` raises ``TypeError``).
    """

    @abstractmethod
    def __init__(self):
        """Set up the dataset; must be overridden by subclasses."""

    @abstractmethod
    def __getitem__(self, index):
        """Return the sample at ``index``; must be overridden by subclasses."""

    @abstractmethod
    def __len__(self):
        """Return the number of samples; must be overridden by subclasses."""
class _StreamDatasetAccessError(AssertionError, TypeError):
    r"""Raised when indexing or ``len()`` is attempted on a StreamDataset.

    Subclassing ``AssertionError`` keeps existing ``except AssertionError``
    handlers working. Subclassing ``TypeError`` lets the length-hint protocol
    used by ``list()``/``tuple()`` treat the missing length as "unknown" and
    fall back to plain iteration, instead of aborting the conversion.
    """


class StreamDataset(Dataset):
    r"""Abstract base class for datasets that can only be consumed as a stream.

    Subclasses must implement ``__init__`` and ``__iter__``. Indexing and
    ``len()`` are deliberately unsupported: both raise an error that is an
    ``AssertionError`` (as before) and also a ``TypeError``.
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def __iter__(self):
        pass

    def __getitem__(self, idx):
        # Streams offer no random access.
        raise _StreamDatasetAccessError("can not get item from StreamDataset by index")

    def __len__(self):
        # The TypeError ancestry makes ``list(stream)`` ignore this failure and
        # simply iterate; a bare AssertionError would propagate out of list().
        raise _StreamDatasetAccessError("StreamDataset does not have length")
class ArrayDataset(Dataset):
    r"""Map-style dataset backed by one or more in-memory arrays (e.g. numpy).

    Sample ``i`` is the tuple ``(arrays[0][i], arrays[1][i], ...)``, so every
    array must report the same ``len()`` (the number of samples).

    Args:
        *arrays: one or more sized, indexable objects such as numpy arrays.

    Raises:
        ValueError: if no array is given or the array lengths differ.
    """

    def __init__(self, *arrays):
        super().__init__()
        if not arrays:
            # Without this guard ``arrays[0]`` below fails with an opaque IndexError.
            raise ValueError("at least one array is needed to build an ArrayDataset")
        num_samples = len(arrays[0])
        if any(len(array) != num_samples for array in arrays):
            raise ValueError("lengths of input arrays are inconsistent")
        self.arrays = arrays

    def __getitem__(self, index: int) -> Tuple:
        """Return a tuple with the ``index``-th item of every array."""
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        """Return the number of samples, i.e. the length of the first array."""
        return len(self.arrays[0])