# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import pickle
import tarfile
from typing import Tuple
import numpy as np
from ....logger import get_logger
from .meta_vision import VisionDataset
from .utils import _default_dataset_root, load_raw_data_from_url
# Module-level logger, named after this module, used by the dataset classes below.
logger = get_logger(__name__)
class CIFAR10(VisionDataset):
    r""":class:`~.Dataset` for CIFAR10 meta data.

    On construction the dataset locates (or downloads and extracts) the raw
    CIFAR-10 python archive under ``root`` and loads either the training or
    the test batches into memory. Each sample is a ``(image, label)`` tuple.
    """

    # Location and integrity data of the raw archive.
    url_path = "http://www.cs.utoronto.ca/~kriz/"
    raw_file_name = "cifar-10-python.tar.gz"
    raw_file_md5 = "c58f30108f718f92721af3b95e74349a"
    # Directory (relative to ``root``) holding the extracted batch files.
    raw_file_dir = "cifar-10-batches-py"
    # Names of the pickled batch files inside ``raw_file_dir``.
    train_batch = [
        "data_batch_1",
        "data_batch_2",
        "data_batch_3",
        "data_batch_4",
        "data_batch_5",
    ]
    test_batch = ["test_batch"]
    meta_info = {"name": "batches.meta"}

    def __init__(
        self,
        root: str = None,
        train: bool = True,
        download: bool = True,
        timeout: int = 500,
    ):
        """Locate the dataset on disk (downloading if allowed) and load it.

        Args:
            root: directory that contains (or will contain) the extracted
                dataset. When ``None``, ``self._default_root`` is used and
                created if missing.
            train: load the training batches if ``True``, otherwise the test
                batches.
            download: whether a missing ``root`` directory may be created and
                a missing dataset may be downloaded.
            timeout: forwarded to ``load_raw_data_from_url`` when downloading.

        Raises:
            ValueError: if ``root`` is given but does not exist, or the
                extracted dataset directory is missing, while ``download``
                is ``False``.
        """
        super().__init__(root, order=("image", "image_category"))
        self.timeout = timeout

        # process the root path
        if root is None:
            self.root = self._default_root
            # exist_ok avoids failing when another process creates the
            # directory between an existence check and the creation.
            os.makedirs(self.root, exist_ok=True)
        else:
            self.root = root
            if not os.path.exists(self.root):
                if not download:
                    raise ValueError("dir %s does not exist" % self.root)
                logger.debug(
                    "dir %s does not exist, will be automatically created",
                    self.root,
                )
                os.makedirs(self.root, exist_ok=True)

        self.target_file = os.path.join(self.root, self.raw_file_dir)

        # If the extracted directory already exists it is loaded no matter
        # what ``download`` is set to; otherwise it is fetched first (when
        # allowed).
        if not os.path.exists(self.target_file):
            if not download:
                raise ValueError(
                    "dir does not contain target file %s, please set download=True"
                    % (self.target_file)
                )
            self.download()

        batches = self.train_batch if train else self.test_batch
        self.arrays = self.bytes2array(batches)

    def __getitem__(self, index: int) -> Tuple:
        """Return the ``index``-th element of every loaded array as a tuple."""
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.arrays[0])

    @property
    def _default_root(self):
        """Default dataset directory: ``<default dataset root>/<class name>``."""
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def meta(self):
        """Unpickle and return the dataset's meta file (keys are ``bytes``)."""
        meta_path = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"])
        with open(meta_path, "rb") as f:
            # NOTE: pickle can execute arbitrary code; only use on trusted files.
            meta = pickle.load(f, encoding="bytes")
        return meta

    def download(self):
        """Fetch the raw archive into ``self.root`` and extract it."""
        url = self.url_path + self.raw_file_name
        load_raw_data_from_url(
            url, self.raw_file_name, self.raw_file_md5, self.root, self.timeout
        )
        self.process()

    def untar(self, file_path, dirs):
        """Extract the ``.tar.gz`` archive at ``file_path`` into ``dirs``."""
        assert file_path.endswith(".tar.gz")
        logger.debug("untar file %s to %s", file_path, dirs)
        # The context manager guarantees the archive handle is closed even
        # when extraction fails.
        with tarfile.open(file_path) as archive:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from the network: where the interpreter
                # supports extraction filters, refuse members that would
                # escape ``dirs`` (absolute paths, ``..``, unsafe links).
                archive.extractall(path=dirs, filter="data")
            else:
                archive.extractall(path=dirs)

    def bytes2array(self, filenames):
        """Load the pickled batch files and return ``(images, labels)``.

        Args:
            filenames: names of batch files inside ``raw_file_dir``.

        Returns:
            A tuple ``(data, label)``: ``data`` is a list of ``(32, 32, 3)``
            image arrays and ``label`` is an ``int32`` array with one entry
            per image.
        """
        data = []
        label = []
        for filename in filenames:
            path = os.path.join(self.root, self.raw_file_dir, filename)
            logger.debug("unpickle file %s", path)
            with open(path, "rb") as fo:
                # NOTE: pickle can execute arbitrary code. The file comes
                # from the downloaded archive; presumably its integrity is
                # checked against ``raw_file_md5`` by the download helper --
                # verify before pointing ``root`` at untrusted data.
                dic = pickle.load(fo, encoding="bytes")
            # Rows are flat images: reshape to (N, 3, 32, 32), then move the
            # channel axis last to get (N, 32, 32, 3).
            batch_data = dic[b"data"].reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
            # Reverse the channel order (CIFAR documents its planes as
            # R, G, B, so this yields BGR); extend() adds one image per row.
            data.extend(batch_data[..., [2, 1, 0]])
            label.extend(dic[b"labels"])
        label = np.array(label, dtype=np.int32)
        return (data, label)

    def process(self):
        """Extract the downloaded raw archive into ``self.root``."""
        logger.info("process raw data ...")
        self.untar(os.path.join(self.root, self.raw_file_name), self.root)
class CIFAR100(CIFAR10):
    r""":class:`~.Dataset` for CIFAR100 meta data.

    Reuses the loading machinery of :class:`CIFAR10`, pointing it at the
    CIFAR-100 archive and returning two label arrays (fine and coarse) per
    image instead of one.
    """

    url_path = "http://www.cs.utoronto.ca/~kriz/"
    raw_file_name = "cifar-100-python.tar.gz"
    raw_file_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    raw_file_dir = "cifar-100-python"
    train_batch = ["train"]
    test_batch = ["test"]
    meta_info = {"name": "meta"}

    @property
    def meta(self):
        """Unpickle and return the dataset's meta file (keys are ``bytes``)."""
        location = os.path.join(self.root, self.raw_file_dir, self.meta_info["name"])
        with open(location, "rb") as stream:
            return pickle.load(stream, encoding="bytes")

    def bytes2array(self, filenames):
        """Load the pickled files and return ``(images, fine, coarse)``.

        Args:
            filenames: names of the pickled files inside ``raw_file_dir``.

        Returns:
            A tuple of a list of ``(32, 32, 3)`` image arrays, an ``int32``
            array of fine labels and an ``int32`` array of coarse labels.
        """
        images, fine, coarse = [], [], []
        for name in filenames:
            batch_path = os.path.join(self.root, self.raw_file_dir, name)
            logger.debug("unpickle file %s", batch_path)
            with open(batch_path, "rb") as stream:
                batch = pickle.load(stream, encoding="bytes")
            # (N, 3, 32, 32) -> (N, 32, 32, 3), then reverse the channel axis.
            planar = batch[b"data"].reshape(-1, 3, 32, 32)
            channels_last = planar.transpose((0, 2, 3, 1))
            images += list(channels_last[..., [2, 1, 0]])
            fine += batch[b"fine_labels"]
            coarse += batch[b"coarse_labels"]
        fine_label = np.array(fine, dtype=np.int32)
        coarse_label = np.array(coarse, dtype=np.int32)
        return images, fine_label, coarse_label