import gzip
import os
import struct
from typing import Tuple

import numpy as np
from tqdm import tqdm

from ....logger import get_logger
from .meta_vision import VisionDataset
from .utils import _default_dataset_root, load_raw_data_from_url

logger = get_logger(__name__)


class MNIST(VisionDataset):
    r"""MNIST dataset.

    The MNIST_ database (Modified National Institute of Standards and Technology
    database) is a large database of handwritten digits that is commonly used for
    training various image processing systems. The database is also widely used
    for training and testing in the field of machine learning.
    It was created by "re-mixing" the samples from `NIST`_'s original datasets.
    Furthermore, the black and white images from NIST were normalized to fit into
    a 28x28 pixel bounding box and anti-aliased, which introduced grayscale levels.

    The MNIST database contains 60,000 training images and 10,000 testing images.

    The above introduction comes from `MNIST database - Wikipedia
    <https://en.wikipedia.org/wiki/MNIST_database>`_.

    Args:
        root: Path for MNIST dataset downloading or loading. If it's ``None``,
            it will be set to ``~/.cache/megengine`` (the default root path).
        train: If ``True``, use the training set; otherwise use the test set.
        download: If ``True``, downloads the dataset from the internet and puts it
            in the ``root`` directory. If the dataset is already downloaded, it is
            not downloaded again.

    Returns:
        The MNIST :class:`~.Dataset` that can work with :class:`~.DataLoader`.

    Example:
        >>> from megengine.data.dataset import MNIST  # doctest: +SKIP
        >>> mnist = MNIST("/data/datasets/MNIST")  # Set the root path  # doctest: +SKIP
        >>> image, label = mnist[0]  # doctest: +SKIP
        >>> image.shape  # doctest: +SKIP
        (28, 28, 1)

    .. versionchanged:: 1.11

       *"Please refrain from accessing these files from automated scripts with
       high frequency. Make copies!"*

       As requested by the original provider of the MNIST dataset, the original
       URL has been replaced and the dataset is now downloaded from the mirror
       site: https://ossci-datasets.s3.amazonaws.com/mnist/

    .. seealso::

       * The MNIST dataset is used in the :ref:`megengine-quick-start` tutorial
         as an example.
       * You can find many machine learning projects using the MNIST dataset
         on the internet.

    .. _MNIST: http://yann.lecun.com/exdb/mnist/
    .. _NIST: https://www.nist.gov/data
    """
    url_path = "https://ossci-datasets.s3.amazonaws.com/mnist/"
    raw_file_name = [
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
    ]
    raw_file_md5 = [
        "f68b3c2dcbeaaa9fbdd348bbdeb94873",
        "d53e105ee54ea40749a09fcbcd1e9432",
        "9fb629c4189551a2d022fa330f9573f3",
        "ec29112dd5afa0611ce80d1b7f02629c",
    ]

    def __init__(
        self, root: str = None, train: bool = True, download: bool = True,
    ):
        super().__init__(root, order=("image", "image_category"))

        # process the root path
        if root is None:
            self.root = self._default_root
            if not os.path.exists(self.root):
                os.makedirs(self.root)
        else:
            self.root = root
            if not os.path.exists(self.root):
                if download:
                    logger.debug(
                        "dir %s does not exist, will be automatically created",
                        self.root,
                    )
                    os.makedirs(self.root)
                else:
                    raise ValueError("dir %s does not exist" % self.root)

        if self._check_raw_files():
            self.process(train)
        elif download:
            self.download()
            self.process(train)
        else:
            raise ValueError(
                "root does not contain valid raw files, please set download=True"
            )

    def __getitem__(self, index: int) -> Tuple:
        return tuple(array[index] for array in self.arrays)

    def __len__(self) -> int:
        return len(self.arrays[0])

    @property
    def _default_root(self):
        return os.path.join(_default_dataset_root(), self.__class__.__name__)

    @property
    def meta(self):
        return self._meta_data

    def _check_raw_files(self):
        return all(
            [
                os.path.exists(os.path.join(self.root, path))
                for path in self.raw_file_name
            ]
        )

    def download(self):
        for file_name, md5 in zip(self.raw_file_name, self.raw_file_md5):
            url = self.url_path + file_name
            load_raw_data_from_url(url, file_name, md5, self.root)

    def process(self, train):
        # load raw files and transform them into meta data and datasets Tuple(np.array)
        logger.info("process the raw files of %s set...", "train" if train else "test")
        if train:
            meta_data_images, images = parse_idx3(
                os.path.join(self.root, self.raw_file_name[0])
            )
            meta_data_labels, labels = parse_idx1(
                os.path.join(self.root, self.raw_file_name[1])
            )
        else:
            meta_data_images, images = parse_idx3(
                os.path.join(self.root, self.raw_file_name[2])
            )
            meta_data_labels, labels = parse_idx1(
                os.path.join(self.root, self.raw_file_name[3])
            )

        self._meta_data = {
            "images": meta_data_images,
            "labels": meta_data_labels,
        }
        self.arrays = (images, labels.astype(np.int32))
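

# Note for readers: the two parsers below decode the IDX format described on the
# original MNIST page (http://yann.lecun.com/exdb/mnist/). Headers are big-endian
# int32 fields: idx3 image files carry (magic, #images, rows, cols) and idx1
# label files carry (magic, #items); the magic number is expected to be
# 2051 (0x00000803) for images and 2049 (0x00000801) for labels. A minimal
# sketch of reading just an idx3 header (the local file name here is only an
# illustrative assumption):
#
#     with gzip.open("train-images-idx3-ubyte.gz", "rb") as f:
#         magic, imgs, height, width = struct.unpack(">iiii", f.read(16))
#     assert magic == 2051  # 0x00000803, idx3 image file
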
def parse_idx3(idx3_file):
    # parse idx3 file to meta data and data in numpy array (images)
    logger.debug("parse idx3 file %s ...", idx3_file)
    assert idx3_file.endswith(".gz")
    with gzip.open(idx3_file, "rb") as f:
        bin_data = f.read()

    # parse meta data
    offset = 0
    fmt_header = ">iiii"
    magic, imgs, height, width = struct.unpack_from(fmt_header, bin_data, offset)
    meta_data = {"magic": magic, "imgs": imgs, "height": height, "width": width}

    # parse images
    image_size = height * width
    offset += struct.calcsize(fmt_header)
    fmt_image = ">" + str(image_size) + "B"
    images = []
    bar = tqdm(total=meta_data["imgs"], ncols=80)
    for image in struct.iter_unpack(fmt_image, bin_data[offset:]):
        images.append(np.array(image, dtype=np.uint8).reshape((height, width, 1)))
        bar.update()
    bar.close()
    return meta_data, images


def parse_idx1(idx1_file):
    # parse idx1 file to meta data and data in numpy array (labels)
    logger.debug("parse idx1 file %s ...", idx1_file)
    assert idx1_file.endswith(".gz")
    with gzip.open(idx1_file, "rb") as f:
        bin_data = f.read()

    # parse meta data
    offset = 0
    fmt_header = ">ii"
    magic, imgs = struct.unpack_from(fmt_header, bin_data, offset)
    meta_data = {"magic": magic, "imgs": imgs}

    # parse labels
    offset += struct.calcsize(fmt_header)
    fmt_label = ">B"
    labels = np.empty(imgs, dtype=int)
    bar = tqdm(total=meta_data["imgs"], ncols=80)
    for i, label in enumerate(struct.iter_unpack(fmt_label, bin_data[offset:])):
        labels[i] = label[0]
        bar.update()
    bar.close()
    return meta_data, labels
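

# A minimal end-to-end usage sketch, illustrative only: it assumes MegEngine's
# ``DataLoader`` and ``RandomSampler`` from ``megengine.data``, and the dataset
# path below is a placeholder for a writable local directory.
#
#     from megengine.data import DataLoader, RandomSampler
#     from megengine.data.dataset import MNIST
#
#     train_set = MNIST("/data/datasets/MNIST", train=True, download=True)
#     sampler = RandomSampler(train_set, batch_size=64)
#     for images, labels in DataLoader(train_set, sampler=sampler):
#         ...  # images: (64, 28, 28, 1) uint8, labels: (64,) int32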