# Source code for megengine.data.dataset.vision.voc

# -*- coding: utf-8 -*-
# ---------------------------------------------------------------------
# Part of the following code in this file refs to torchvision
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
# ---------------------------------------------------------------------
import collections.abc
import os
import xml.etree.ElementTree as ET

import cv2
import numpy as np

from .meta_vision import VisionDataset


class PascalVOC(VisionDataset):
    r"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset.

    Supports either the detection targets (``"boxes"`` / ``"boxes_category"``)
    or the segmentation target (``"mask"``), but not both at once.

    Args:
        root: root directory of the VOC data (containing ``JPEGImages``,
            ``Annotations``, ``ImageSets`` and the segmentation folders).
        image_set: name of the split file under ``ImageSets`` (e.g. ``"train"``);
            if it contains ``"aug"``, masks are read from ``SegmentationClass_aug``.
        order: sequence of target keys to return per sample; must be a subset of
            ``supported_order``.
    """

    supported_order = (
        "image",
        "boxes",
        "boxes_category",
        "mask",
        "info",
    )

    def __init__(self, root, image_set, *, order=None):
        # Fail early with a clear message: the membership tests below would
        # otherwise raise an opaque TypeError on ``None``.
        if order is None:
            raise ValueError(
                "order must be specified, e.g. order=['image', 'boxes', 'boxes_category']"
            )
        if ("boxes" in order or "boxes_category" in order) and "mask" in order:
            raise ValueError(
                "PascalVOC only supports boxes & boxes_category or mask, not both."
            )

        super().__init__(root, order=order, supported_order=self.supported_order)

        if not os.path.isdir(self.root):
            raise RuntimeError("Dataset not found or corrupted.")

        self.image_set = image_set
        image_dir = os.path.join(self.root, "JPEGImages")

        if "boxes" in order or "boxes_category" in order:
            # Detection: pair each split entry with its image and XML annotation.
            annotation_dir = os.path.join(self.root, "Annotations")
            splitdet_dir = os.path.join(self.root, "ImageSets/Main")
            split_f = os.path.join(splitdet_dir, image_set.rstrip("\n") + ".txt")
            with open(split_f, "r") as f:
                self.file_names = [x.strip() for x in f.readlines()]
            self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names]
            self.annotations = [
                os.path.join(annotation_dir, x + ".xml") for x in self.file_names
            ]
            assert len(self.images) == len(self.annotations)
        elif "mask" in order:
            # Segmentation: "aug" splits use the pre-encoded augmented masks.
            if "aug" in image_set:
                mask_dir = os.path.join(self.root, "SegmentationClass_aug")
            else:
                mask_dir = os.path.join(self.root, "SegmentationClass")
            splitmask_dir = os.path.join(self.root, "ImageSets/Segmentation")
            split_f = os.path.join(splitmask_dir, image_set.rstrip("\n") + ".txt")
            with open(split_f, "r") as f:
                self.file_names = [x.strip() for x in f.readlines()]
            self.images = [os.path.join(image_dir, x + ".jpg") for x in self.file_names]
            self.masks = [os.path.join(mask_dir, x + ".png") for x in self.file_names]
            assert len(self.images) == len(self.masks)
        else:
            raise NotImplementedError

        # Lazy cache of per-image metadata, filled by get_img_info.
        self.img_infos = dict()

    def __getitem__(self, index):
        """Return the targets listed in ``self.order`` for the sample at ``index``.

        Returns:
            tuple with one entry per key in ``self.order``:
            ``image`` is a BGR ``uint8`` array; ``boxes`` is a float32
            ``(N, 4)`` array in xyxy order; ``boxes_category`` is an int32
            ``(N,)`` array of 1-based class indices; ``mask`` is a ``uint8``
            ``(H, W, 1)`` label map; ``info`` is ``[height, width, file_name]``.
        """
        target = []
        # Initialized so the "info" branch is safe even when "image" is not
        # requested (or is ordered after "info"); get_img_info re-reads the
        # file itself when given None.
        image = None
        for k in self.order:
            if k == "image":
                image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
                target.append(image)
            elif k == "boxes":
                anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
                boxes = [obj["bndbox"] for obj in anno["annotation"]["object"]]
                # boxes type xyxy
                boxes = [
                    (bb["xmin"], bb["ymin"], bb["xmax"], bb["ymax"]) for bb in boxes
                ]
                boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
                target.append(boxes)
            elif k == "boxes_category":
                anno = self.parse_voc_xml(ET.parse(self.annotations[index]).getroot())
                boxes_category = [obj["name"] for obj in anno["annotation"]["object"]]
                # +1 reserves label 0 for the implicit background class.
                boxes_category = [
                    self.class_names.index(bc) + 1 for bc in boxes_category
                ]
                boxes_category = np.array(boxes_category, dtype=np.int32)
                target.append(boxes_category)
            elif k == "mask":
                if "aug" in self.image_set:
                    # Augmented masks already store class indices directly.
                    mask = cv2.imread(self.masks[index], cv2.IMREAD_GRAYSCALE)
                else:
                    # Standard masks are color-coded; map colors to indices.
                    mask = cv2.imread(self.masks[index], cv2.IMREAD_COLOR)
                    mask = self._trans_mask(mask)
                mask = mask[:, :, np.newaxis]
                target.append(mask)
            elif k == "info":
                info = self.get_img_info(index, image)
                info = [info["height"], info["width"], info["file_name"]]
                target.append(info)
            else:
                raise NotImplementedError
        return tuple(target)

    def __len__(self):
        return len(self.images)

    def get_img_info(self, index, image=None):
        """Return ``{"height", "width", "file_name"}`` for the sample, caching it.

        ``image`` may be passed to avoid re-reading the file when the caller
        already decoded it; otherwise the image is loaded on a cache miss.
        """
        if index not in self.img_infos:
            if image is None:
                image = cv2.imread(self.images[index], cv2.IMREAD_COLOR)
            self.img_infos[index] = dict(
                height=image.shape[0],
                width=image.shape[1],
                file_name=self.file_names[index],
            )
        return self.img_infos[index]

    def _trans_mask(self, mask):
        """Convert a BGR color-coded VOC mask to a uint8 class-index map.

        Pixels whose color is not in ``class_colors`` stay 255 (ignore label).
        """
        label = np.ones(mask.shape[:2]) * 255
        for i in range(len(self.class_colors)):
            b, g, r = self.class_colors[i]
            label[
                (mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r)
            ] = i
        return label.astype(np.uint8)

    def parse_voc_xml(self, node):
        """Recursively convert a VOC annotation XML element into nested dicts.

        Repeated child tags are collected into lists; under ``annotation`` the
        ``object`` entry is always a list, even for a single object.
        """
        voc_dict = {}
        children = list(node)
        if children:
            def_dic = collections.defaultdict(list)
            for dc in map(self.parse_voc_xml, children):
                for ind, v in dc.items():
                    def_dic[ind].append(v)
            if node.tag == "annotation":
                def_dic["object"] = [def_dic["object"]]
            voc_dict = {
                node.tag: {
                    ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()
                }
            }
        if node.text:
            text = node.text.strip()
            if not children:
                voc_dict[node.tag] = text
        return voc_dict

    class_names = (
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
    )

    # BGR palette for the color-coded segmentation masks; index i is class i,
    # with index 0 reserved for background.
    class_colors = [
        [0, 0, 0],  # background
        [0, 0, 128],
        [0, 128, 0],
        [0, 128, 128],
        [128, 0, 0],
        [128, 0, 128],
        [128, 128, 0],
        [128, 128, 128],
        [0, 0, 64],
        [0, 0, 192],
        [0, 128, 64],
        [0, 128, 192],
        [128, 0, 64],
        [128, 0, 192],
        [128, 128, 64],
        [128, 128, 192],
        [0, 64, 0],
        [0, 64, 128],
        [0, 192, 0],
        [0, 192, 128],
        [128, 64, 0],
    ]