megengine.data.dataset.vision.coco 源代码

# -*- coding: utf-8 -*-
# ---------------------------------------------------------------------
# Part of the following code in this file refs to maskrcnn-benchmark
# MIT License
#
# Copyright (c) 2018 Facebook
# ---------------------------------------------------------------------
import json
import os
from collections import defaultdict

import cv2
import numpy as np

from .meta_vision import VisionDataset

# Lower bound used by has_valid_annotation when "keypoints" is requested: an
# image is kept only if its annotations contain at least this many keypoints
# whose visibility flag is positive.
min_keypoints_per_image: int = 10


def _count_visible_keypoints(anno):
    return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)


def has_valid_annotation(anno, order):
    """Decide whether an image's annotation list is usable for ``order``.

    :param anno: list of annotation dicts belonging to one image.
    :param order: the fields requested from the dataset.
    :return: ``False`` when ``anno`` is empty; when boxes (or box categories)
        are requested but the first annotation has no ``"bbox"``; or when
        keypoints are requested but the first annotation has no
        ``"keypoints"`` or the image has fewer than
        ``min_keypoints_per_image`` visible keypoints. ``True`` otherwise.
    """
    # an image without any annotation is never valid
    if not anno:
        return False

    first = anno[0]

    wants_boxes = "boxes" in order or "boxes_category" in order
    if wants_boxes and "bbox" not in first:
        return False

    if "keypoints" not in order:
        return True

    if "keypoints" not in first:
        return False
    # keypoint tasks only keep images with enough visible keypoints
    return _count_visible_keypoints(anno) >= min_keypoints_per_image


class COCO(VisionDataset):
    r"""`MS COCO <http://cocodataset.org/#home>`_ Dataset.

    Images are read from ``root`` and annotations from the COCO-format JSON
    file ``ann_file``. Each item is a tuple whose entries follow ``order``
    (a subset of :attr:`supported_order`):

    * ``"image"``: the image file loaded with ``cv2.imread(..., cv2.IMREAD_COLOR)``.
    * ``"boxes"``: ``float32`` array of shape ``(N, 4)`` in ``(x0, y0, x1, y1)``
      form, converted from the annotation's ``(x, y, w, h)``.
    * ``"boxes_category"``: ``int32`` array of shape ``(N,)`` with contiguous
      category ids starting at 1.
    * ``"keypoints"``: ``float32`` array of shape ``(N, len(keypoint_names), 3)``.
    * ``"info"``: ``[height, width, file_name]`` of the image.

    :param root: directory that the images' ``file_name`` entries are relative to.
    :param ann_file: path to the COCO annotation JSON file.
    :param remove_images_without_annotations: if True, drop crowd annotations
        and annotations whose box has non-positive width or height, then drop
        images left without a valid annotation for the requested ``order``.
    :param order: the fields (and their order) returned by ``__getitem__``.
    """

    supported_order = (
        "image",
        "boxes",
        "boxes_category",
        "keypoints",
        # TODO: need to check
        # "polygons",
        "info",
    )

    def __init__(
        self, root, ann_file, remove_images_without_annotations=False, *, order=None
    ):
        super().__init__(root, order=order, supported_order=self.supported_order)

        # Be explicit about the encoding so loading does not depend on the
        # platform's default locale encoding.
        with open(ann_file, "r", encoding="utf-8") as f:
            dataset = json.load(f)

        # Boxes are returned only for these fields...
        returns_bbox = "boxes" in self.order or "boxes_category" in self.order
        # ...but the degenerate-box filter below also reads them, so they must
        # survive loading whenever filtering is requested.
        keep_bbox = returns_bbox or remove_images_without_annotations

        self.imgs = {}
        for img in dataset["images"]:
            # for saving memory: drop per-image fields this class never reads
            for key in ("license", "coco_url", "date_captured", "flickr_url"):
                img.pop(key, None)
            self.imgs[img["id"]] = img

        self.img_to_anns = defaultdict(list)
        for ann in dataset["annotations"]:
            # for saving memory
            if not keep_bbox:
                ann.pop("bbox", None)
            if "polygons" not in self.order:
                ann.pop("segmentation", None)
            self.img_to_anns[ann["image_id"]].append(ann)

        self.cats = {cat["id"]: cat for cat in dataset["categories"]}

        self.ids = sorted(self.imgs)

        # filter images without detection annotations
        if remove_images_without_annotations:
            ids = []
            for img_id in self.ids:
                # drop crowd annotations and boxes with non-positive width/height
                anno = [
                    obj
                    for obj in self.img_to_anns[img_id]
                    if obj["iscrowd"] == 0
                    and obj["bbox"][2] > 0
                    and obj["bbox"][3] > 0
                ]
                # validate against the resolved order (the ``order`` argument
                # may be None)
                if has_valid_annotation(anno, self.order):
                    if not returns_bbox:
                        # boxes were only kept for the filter above
                        for obj in anno:
                            obj.pop("bbox", None)
                    ids.append(img_id)
                    self.img_to_anns[img_id] = anno
                else:
                    del self.imgs[img_id]
                    del self.img_to_anns[img_id]
            self.ids = ids

        # JSON category ids are sparse (see classes_originID), so map them to
        # contiguous ids starting at 1, and keep the inverse mapping.
        self.json_category_id_to_contiguous_id = {
            cat_id: i for i, cat_id in enumerate(sorted(self.cats), start=1)
        }
        self.contiguous_category_id_to_json_id = {
            v: k for k, v in self.json_category_id_to_contiguous_id.items()
        }

    def __getitem__(self, index):
        """Return the requested fields of the ``index``-th image, ordered like ``self.order``."""
        img_id = self.ids[index]
        anno = self.img_to_anns[img_id]

        target = []
        for k in self.order:
            if k == "image":
                file_name = self.imgs[img_id]["file_name"]
                path = os.path.join(self.root, file_name)
                # NOTE(review): cv2.imread returns None instead of raising when
                # the file cannot be read; consider validating here.
                image = cv2.imread(path, cv2.IMREAD_COLOR)
                target.append(image)
            elif k == "boxes":
                boxes = [obj["bbox"] for obj in anno]
                boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
                # transfer boxes from xywh to xyxy
                boxes[:, 2:] += boxes[:, :2]
                target.append(boxes)
            elif k == "boxes_category":
                boxes_category = [
                    self.json_category_id_to_contiguous_id[obj["category_id"]]
                    for obj in anno
                ]
                boxes_category = np.array(boxes_category, dtype=np.int32)
                target.append(boxes_category)
            elif k == "keypoints":
                keypoints = [obj["keypoints"] for obj in anno]
                keypoints = np.array(keypoints, dtype=np.float32).reshape(
                    -1, len(self.keypoint_names), 3
                )
                target.append(keypoints)
            elif k == "polygons":
                polygons = [obj["segmentation"] for obj in anno]
                polygons = [
                    [np.array(p, dtype=np.float32).reshape(-1, 2) for p in ps]
                    for ps in polygons
                ]
                target.append(polygons)
            elif k == "info":
                info = self.imgs[img_id]
                info = [info["height"], info["width"], info["file_name"]]
                target.append(info)
            else:
                raise NotImplementedError

        return tuple(target)

    def __len__(self):
        """Return the number of images kept in the dataset."""
        return len(self.ids)

    def get_img_info(self, index):
        """Return the stored image-metadata dict of the ``index``-th image."""
        img_id = self.ids[index]
        img_info = self.imgs[img_id]
        return img_info

    class_names = (
        "person",
        "bicycle",
        "car",
        "motorcycle",
        "airplane",
        "bus",
        "train",
        "truck",
        "boat",
        "traffic light",
        "fire hydrant",
        "stop sign",
        "parking meter",
        "bench",
        "bird",
        "cat",
        "dog",
        "horse",
        "sheep",
        "cow",
        "elephant",
        "bear",
        "zebra",
        "giraffe",
        "backpack",
        "umbrella",
        "handbag",
        "tie",
        "suitcase",
        "frisbee",
        "skis",
        "snowboard",
        "sports ball",
        "kite",
        "baseball bat",
        "baseball glove",
        "skateboard",
        "surfboard",
        "tennis racket",
        "bottle",
        "wine glass",
        "cup",
        "fork",
        "knife",
        "spoon",
        "bowl",
        "banana",
        "apple",
        "sandwich",
        "orange",
        "broccoli",
        "carrot",
        "hot dog",
        "pizza",
        "donut",
        "cake",
        "chair",
        "couch",
        "potted plant",
        "bed",
        "dining table",
        "toilet",
        "tv",
        "laptop",
        "mouse",
        "remote",
        "keyboard",
        "cell phone",
        "microwave",
        "oven",
        "toaster",
        "sink",
        "refrigerator",
        "book",
        "clock",
        "vase",
        "scissors",
        "teddy bear",
        "hair drier",
        "toothbrush",
    )

    # class name -> original (sparse) COCO JSON category id
    classes_originID = {
        "person": 1,
        "bicycle": 2,
        "car": 3,
        "motorcycle": 4,
        "airplane": 5,
        "bus": 6,
        "train": 7,
        "truck": 8,
        "boat": 9,
        "traffic light": 10,
        "fire hydrant": 11,
        "stop sign": 13,
        "parking meter": 14,
        "bench": 15,
        "bird": 16,
        "cat": 17,
        "dog": 18,
        "horse": 19,
        "sheep": 20,
        "cow": 21,
        "elephant": 22,
        "bear": 23,
        "zebra": 24,
        "giraffe": 25,
        "backpack": 27,
        "umbrella": 28,
        "handbag": 31,
        "tie": 32,
        "suitcase": 33,
        "frisbee": 34,
        "skis": 35,
        "snowboard": 36,
        "sports ball": 37,
        "kite": 38,
        "baseball bat": 39,
        "baseball glove": 40,
        "skateboard": 41,
        "surfboard": 42,
        "tennis racket": 43,
        "bottle": 44,
        "wine glass": 46,
        "cup": 47,
        "fork": 48,
        "knife": 49,
        "spoon": 50,
        "bowl": 51,
        "banana": 52,
        "apple": 53,
        "sandwich": 54,
        "orange": 55,
        "broccoli": 56,
        "carrot": 57,
        "hot dog": 58,
        "pizza": 59,
        "donut": 60,
        "cake": 61,
        "chair": 62,
        "couch": 63,
        "potted plant": 64,
        "bed": 65,
        "dining table": 67,
        "toilet": 70,
        "tv": 72,
        "laptop": 73,
        "mouse": 74,
        "remote": 75,
        "keyboard": 76,
        "cell phone": 77,
        "microwave": 78,
        "oven": 79,
        "toaster": 80,
        "sink": 81,
        "refrigerator": 82,
        "book": 84,
        "clock": 85,
        "vase": 86,
        "scissors": 87,
        "teddy bear": 88,
        "hair drier": 89,
        "toothbrush": 90,
    }

    # keypoint names; their count fixes the middle axis of the "keypoints" array
    keypoint_names = (
        "nose",
        "left_eye",
        "right_eye",
        "left_ear",
        "right_ear",
        "left_shoulder",
        "right_shoulder",
        "left_elbow",
        "right_elbow",
        "left_wrist",
        "right_wrist",
        "left_hip",
        "right_hip",
        "left_knee",
        "right_knee",
        "left_ankle",
        "right_ankle",
    )