# -*- coding: utf-8 -*-
import gzip
import os
import struct
from typing import Tuple
import numpy as np
from tqdm import tqdm
from ....logger import get_logger
from .meta_vision import VisionDataset
from .utils import _default_dataset_root, load_raw_data_from_url
# Module-level logger, named after this module, used by the dataset class and
# the IDX parsing helpers in this file.
logger = get_logger(__name__)
class MNIST(VisionDataset):
    r"""MNIST dataset.

    The MNIST_ database (Modified National Institute of Standards and Technology database)
    is a large database of handwritten digits that is commonly used for training various image processing systems.
    The database is also widely used for training and testing in the field of machine learning.
    It was created by "re-mixing" the samples from `NIST`_'s original datasets.
    Furthermore, the black and white images from NIST were normalized to fit into a 28x28 pixel
    bounding box and anti-aliased, which introduced grayscale levels.

    The MNIST database contains 60,000 training images and 10,000 testing images.

    The above introduction comes from `MNIST database - Wikipedia
    <https://en.wikipedia.org/wiki/MNIST_database>`_.

    Args:
        root: Path for MNIST dataset downloading or loading. If it's ``None``,
            it will be set to an ``MNIST`` subdirectory of the default dataset root
            (``~/.cache/megengine``).
        train: If ``True``, use the training set; otherwise use the test set.
        download: If ``True``, downloads the dataset from the internet and puts it in ``root`` directory.
            If dataset is already downloaded, it is not downloaded again.

    Returns:
        The MNIST :class:`~.Dataset` that can work with :class:`~.DataLoader`.

    Example:

        >>> from megengine.data.dataset import MNIST # doctest: +SKIP
        >>> mnist = MNIST("/data/datasets/MNIST") # Set the root path # doctest: +SKIP
        >>> image, label = mnist[0] # doctest: +SKIP
        >>> image.shape # doctest: +SKIP
        (28, 28, 1)

    .. versionchanged:: 1.11 The original URL has been updated to a mirror URL

        *"Please refrain from accessing these files from automated scripts with high frequency. Make copies!"*
        As requested by the original provider of the MNIST dataset,
        now the dataset will be downloaded from the mirror site:
        https://ossci-datasets.s3.amazonaws.com/mnist/

    .. seealso::

        * MNIST dataset is used in :ref:`megengine-quick-start` tutorial as an example.
        * You can find a lot of machine learning projects using MNIST dataset on the internet.

    .. _MNIST: http://yann.lecun.com/exdb/mnist/
    .. _NIST: https://www.nist.gov/data
    """

    url_path = "https://ossci-datasets.s3.amazonaws.com/mnist/"
    # Ordered as (train images, train labels, test images, test labels);
    # ``process`` relies on this pairing when selecting a split.
    raw_file_name = [
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    ]
    # Expected MD5 digests, aligned index-by-index with ``raw_file_name``.
    raw_file_md5 = [
        "f68b3c2dcbeaaa9fbdd348bbdeb94873",
        "d53e105ee54ea40749a09fcbcd1e9432",
        "9fb629c4189551a2d022fa330f9573f3",
        "ec29112dd5afa0611ce80d1b7f02629c",
    ]

    def __init__(
        self, root: str = None, train: bool = True, download: bool = True,
    ):
        super().__init__(root, order=("image", "image_category"))

        # process the root path
        if root is None:
            self.root = self._default_root
            # exist_ok=True: several processes (e.g. distributed ranks) may race to
            # create the same default directory; an exists()/makedirs() pair would
            # let the loser crash with FileExistsError.
            os.makedirs(self.root, exist_ok=True)
        else:
            self.root = root
            if not os.path.exists(self.root):
                if not download:
                    raise ValueError("dir %s does not exist" % self.root)
                logger.debug(
                    "dir %s does not exist, will be automatically created",
                    self.root,
                )
                os.makedirs(self.root, exist_ok=True)

        if self._check_raw_files():
            self.process(train)
        elif download:
            self.download()
            self.process(train)
        else:
            raise ValueError(
                "root does not contain valid raw files, please set download=True"
            )

    def __getitem__(self, index: int) -> Tuple:
        """Return ``(image, label)`` for sample ``index``, in the order of ``self.arrays``."""
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        """Return the number of samples in the selected split."""
        return len(self.arrays[0])

    @property
    def _default_root(self):
        """Directory named after this class under the default dataset root."""
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def meta(self):
        """Header metadata of the parsed image and label files (set by ``process``)."""
        return self._meta_data

    def _check_raw_files(self):
        """Return ``True`` if every raw archive exists under ``self.root``.

        Only existence is checked here; checksums are not verified.
        """
        return all(
            os.path.exists(os.path.join(self.root, path))
            for path in self.raw_file_name
        )

    def download(self):
        """Fetch every raw archive from ``url_path`` into ``self.root``.

        Each file's expected MD5 is passed to ``load_raw_data_from_url``
        (presumably used there for integrity checking).
        """
        for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
            url = self.url_path + file_name
            load_raw_data_from_url(url, file_name, md5, self.root)

    def process(self, train):
        """Parse the selected split into ``self.arrays`` and ``self._meta_data``.

        Args:
            train: if ``True`` parse the training archives, otherwise the test archives.
        """
        logger.info("process the raw files of %s set...", "train" if train else "test")
        # raw_file_name stores (images, labels) pairs: indices 0/1 for train, 2/3 for test.
        first = 0 if train else 2
        meta_data_images, images = parse_idx3(
            os.path.join(self.root, self.raw_file_name[first])
        )
        meta_data_labels, labels = parse_idx1(
            os.path.join(self.root, self.raw_file_name[first + 1])
        )
        self._meta_data = {
            "images": meta_data_images,
            "labels": meta_data_labels,
        }
        self.arrays = (images, labels.astype(np.int32))
def parse_idx3(idx3_file):
    r"""Parse a gzip-compressed IDX3 (image) file.

    The decompressed file is expected to hold a big-endian header of four
    int32 values ``(magic, count, height, width)`` followed by
    ``count * height * width`` unsigned bytes.

    Args:
        idx3_file: path to the ``.gz`` file.

    Returns:
        ``(meta_data, images)`` where ``meta_data`` is a dict with keys
        ``"magic"``, ``"imgs"``, ``"height"`` and ``"width"``, and ``images`` is a
        list of ``count`` ``uint8`` arrays, each of shape ``(height, width, 1)``.

    Raises:
        ValueError: if the path does not end with ``.gz``, or the header is
            invalid / declares more pixel bytes than the file contains.
    """
    logger.debug("parse idx3 file %s ...", idx3_file)
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not idx3_file.endswith(".gz"):
        raise ValueError("expected a .gz idx3 file, got %s" % idx3_file)
    with gzip.open(idx3_file, "rb") as f:
        bin_data = f.read()

    # parse meta data
    fmt_header = ">iiii"
    magic, imgs, height, width = struct.unpack_from(fmt_header, bin_data, 0)
    meta_data = {"magic": magic, "imgs": imgs, "height": height, "width": width}

    # parse images: decode the whole pixel block with one vectorized call
    # rather than unpacking each image into a Python tuple of ints.
    offset = struct.calcsize(fmt_header)
    count = imgs * height * width
    if min(imgs, height, width) < 0 or len(bin_data) - offset < count:
        raise ValueError(
            "idx3 file %s is truncated or has an invalid header" % idx3_file
        )
    pixels = np.frombuffer(bin_data, dtype=np.uint8, count=count, offset=offset)
    # frombuffer over ``bytes`` is read-only; copy so callers get writable arrays.
    pixels = pixels.reshape((imgs, height, width, 1)).copy()
    images = list(pixels)
    return meta_data, images
def parse_idx1(idx1_file):
    r"""Parse a gzip-compressed IDX1 (label) file.

    The decompressed file is expected to hold a big-endian header of two
    int32 values ``(magic, count)`` followed by ``count`` unsigned bytes.

    Args:
        idx1_file: path to the ``.gz`` file.

    Returns:
        ``(meta_data, labels)`` where ``meta_data`` is a dict with keys
        ``"magic"`` and ``"imgs"`` (the label count), and ``labels`` is a 1-D
        array of dtype ``int`` with ``count`` entries.

    Raises:
        ValueError: if the path does not end with ``.gz``, or the header is
            invalid / declares more labels than the file contains.
    """
    logger.debug("parse idx1 file %s ...", idx1_file)
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not idx1_file.endswith(".gz"):
        raise ValueError("expected a .gz idx1 file, got %s" % idx1_file)
    with gzip.open(idx1_file, "rb") as f:
        bin_data = f.read()

    # parse meta data
    fmt_header = ">ii"
    magic, imgs = struct.unpack_from(fmt_header, bin_data, 0)
    meta_data = {"magic": magic, "imgs": imgs}

    # parse labels. Validating the length up front matters: the previous
    # np.empty-and-fill approach silently returned uninitialized values when the
    # file held fewer labels than the header declared.
    offset = struct.calcsize(fmt_header)
    if imgs < 0 or len(bin_data) - offset < imgs:
        raise ValueError(
            "idx1 file %s is truncated or has an invalid header" % idx1_file
        )
    raw = np.frombuffer(bin_data, dtype=np.uint8, count=imgs, offset=offset)
    # astype returns a fresh, writable array with the same dtype as before.
    labels = raw.astype(int)
    return meta_data, labels