# -*- coding: utf-8 -*-
# BSD 3-Clause License
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
# ---------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved.
# ---------------------------------------------------------------------
import os
from typing import Dict, List, Tuple
import cv2
import numpy as np
from .meta_vision import VisionDataset
from .utils import is_img
class ImageFolder(VisionDataset):
    r"""ImageFolder is a class for loading image data and labels from an organized folder.

    The folder is expected to be organized as follows: ``root/cls/xxx.img_ext``.
    Labels are the indices of the sorted class directories in the root directory.
    Images are read with OpenCV (``cv2.imread`` with ``cv2.IMREAD_COLOR``).

    Args:
        root: root directory of an image folder. A leading ``~`` is expanded
            to the user's home directory when scanning the folder.
        check_valid_func: a function used to check if files in the folder are
            expected image files; if ``None``, the default function that
            checks file extensions (``is_img``) will be called.
        class_name: if ``True``, return the class name instead of the class index.
    """

    def __init__(self, root: str, check_valid_func=None, class_name: bool = False):
        super().__init__(root, order=("image", "image_category"))
        self.root = root
        # Fall back to the extension-based check when no validator is supplied.
        self.check_valid = check_valid_func if check_valid_func is not None else is_img
        self.class_name = class_name
        # Classes must be collected first: collect_samples reads self.class_dict.
        self.class_dict = self.collect_class()
        self.samples = self.collect_samples()

    def collect_samples(self) -> List:
        """Collect ``(path, label)`` pairs for every valid file under each class directory.

        Class directories are visited in sorted name order and walked
        recursively (following symlinks). Walk entries and file names are
        sorted, so the sample order does not depend on ``os.walk`` ordering.

        Returns:
            list of ``(path, label)`` tuples, where ``label`` is the class name
            if ``self.class_name`` is ``True`` and the class index otherwise.
        """
        samples = []
        directory = os.path.expanduser(self.root)
        for key in sorted(self.class_dict.keys()):
            class_dir = os.path.join(directory, key)
            if not os.path.isdir(class_dir):
                continue
            # The label is the same for every file of this class; compute it once.
            label = key if self.class_name else self.class_dict[key]
            # NOTE: followlinks=True can recurse indefinitely if a symlink points
            # to one of its own parent directories (see the os.walk docs).
            for dirpath, _, filenames in sorted(os.walk(class_dir, followlinks=True)):
                for name in sorted(filenames):
                    path = os.path.join(dirpath, name)
                    if self.check_valid(path):
                        samples.append((path, label))
        return samples

    def collect_class(self) -> Dict:
        """Map each immediate sub-directory name of the root to a class index.

        Returns:
            dict from class-directory name to an ``np.int32`` index, assigned
            in sorted name order starting at 0.
        """
        # Expand "~" the same way collect_samples does; os.scandir does not do
        # it by itself, so an unexpanded root would fail here.
        directory = os.path.expanduser(self.root)
        classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
        return {name: np.int32(index) for index, name in enumerate(classes)}

    def __getitem__(self, index: int) -> Tuple:
        """Return ``(image, label)`` for the sample at ``index``.

        The image is read with ``cv2.imread(path, cv2.IMREAD_COLOR)``.

        Raises:
            OSError: if OpenCV cannot read the image file.
        """
        path, label = self.samples[index]
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        # cv2.imread reports failure by returning None rather than raising;
        # surface it here instead of passing None downstream as an image.
        if img is None:
            raise OSError("failed to read image: {}".format(path))
        return img, label

    def __len__(self):
        """Return the number of collected samples."""
        return len(self.samples)