从COCO数据集(YOLOv5格式)中生成子类数据集

从COCO数据集(YOLOv5格式)中生成子类数据集：按图片数量选取前14个类别，将标签中的类别编号重新映射为子类编号，并以软链接方式复用原始图片目录。

import os
from tqdm import tqdm

# Class names in label order: names[i] is the name of the class id ``i`` that appears
# as the first field of each line in the source label files.
names = [
    # ids 0-9
    'person', 'bicycle', 'car', 'motorcycle', 'airplane',
    'bus', 'train', 'truck', 'boat', 'traffic light',
    # ids 10-19
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
    'cat', 'dog', 'horse', 'sheep', 'cow',
    # ids 20-29
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
    'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    # ids 30-39
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    # ids 40-49
    'wine glass', 'cup', 'fork', 'knife', 'spoon',
    'bowl', 'banana', 'apple', 'sandwich', 'orange',
    # ids 50-59
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
    'cake', 'chair', 'couch', 'potted plant', 'bed',
    # ids 60-69
    'dining table', 'toilet', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # ids 70-79
    'toaster', 'sink', 'refrigerator', 'book', 'clock',
    'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush',
]

# 类别	图片数量	标注框数量
per_class_size = {"person" : (64115, 262465),
"bicycle" : (3252, 7113),
"car" : (12251, 43867),
"motorcycle" : (3502, 8725),
"airplane" : (2986, 5135),
"bus" : (3952, 6069),
"train" : (3588, 4571),
"truck" : (6127, 9973),
"boat" : (3025, 10759),
"traffic light" : (4139, 12884),
"fire hydrant" : (1711, 1865),
"stop sign" : (1734, 1983),
"parking meter" : (705, 1285),
"bench" : (5570, 9838),
"bird" : (3237, 10806),
"cat" : (4114, 4768),
"dog" : (4385, 5508),
"horse" : (2941, 6587),
"sheep" : (1529, 9509),
"cow" : (1968, 8147),
"elephant" : (2143, 5513),
"bear" : (960, 1294),
"zebra" : (1916, 5303),
"giraffe" : (2546, 5131),
"backpack" : (5528, 8720),
"umbrella" : (3968, 11431),
"handbag" : (6841, 12354),
"tie" : (3810, 6496),
"suitcase" : (2402, 6192),
"frisbee" : (2184, 2682),
"skis" : (3082, 6646),
"snowboard" : (1654, 2685),
"sports ball" : (4262, 6347),
"kite" : (2261, 9076),
"baseball bat" : (2506, 3276),
"baseball glove" : (2629, 3747),
"skateboard" : (3476, 5543),
"surfboard" : (3486, 6126),
"tennis racket" : (3394, 4812),
"bottle" : (8501, 24342),
"wine glass" : (2533, 7913),
"cup" : (9189, 20650),
"fork" : (3555, 5479),
"knife" : (4326, 7770),
"spoon" : (3529, 6165),
"bowl" : (7111, 14358),
"banana" : (2243, 9458),
"apple" : (1586, 5851),
"sandwich" : (2365, 4373),
"orange" : (1699, 6399),
"broccoli" : (1939, 7308),
"carrot" : (24, 51719),
"hot dog" : (11, 8426),
"pizza" : (3166, 5821),
"donut" : (1523, 7179),
"cake" : (2925, 6353),
"chair" : (12774, 38491),
"couch" : (4423, 5779),
"potted plant" : (4452, 8652),
"bed" : (3682, 4192),
"dining table" : (11837, 15714),
"toilet" : (3353, 4157),
"tv" : (4561, 5805),
"laptop" : (3524, 4970),
"mouse" : (1876, 2262),
"remote" : (3076, 5703),
"keyboard" : (2115, 2855),
"cell phone" : (4803, 6434),
"microwave" : (1547, 1673),
"oven" : (2877, 3334),
"toaster" : (217, 225),
"sink" : (4678, 5610),
"refrigerator" : (2360, 2637),
"book" : (5332, 24715),
"clock" : (4659, 6334),
"vase" : (3593, 6613),
"scissors" : (947, 1481),
"teddy bear" : (16, 6087),
"hair drier" : (189, 198),
"toothbrush" : (1007, 1954),}

SKIP_NONE_IMG = True

per_class_size = sorted(per_class_size.items(), key=lambda d:d[1][0], reverse = True)
sub_names = [key[0] for key in per_class_size[:14]]
sub_names_index_dict = {key: i for i, key in enumerate(sub_names)}
print(sub_names_index_dict)

coco2017_root_dir = "/home/dzhang/data/mscoco/"
coco2017_coco_train_dir = os.path.join(coco2017_root_dir, "train2017")
coco2017_coco_val_dir = os.path.join(coco2017_root_dir, "val2017")
coco2017_train_img_dir = os.path.join(coco2017_coco_train_dir, "images")
coco2017_val_img_dir = os.path.join(coco2017_coco_val_dir, "images")
sub_coco_root_dir = coco2017_root_dir.replace("mscoco", "mscoco_%dobj" % (len(sub_names)))
sub_coco_train_dir = os.path.join(sub_coco_root_dir, "train2017")
sub_coco_val_dir = os.path.join(sub_coco_root_dir, "val2017")
sub_coco_train_img_dir = os.path.join(sub_coco_train_dir, "images")
sub_coco_val_img_dir = os.path.join(sub_coco_val_dir, "images")
sub_coco_train_label_dir = os.path.join(sub_coco_train_dir, "labels")
sub_coco_val_label_dir = os.path.join(sub_coco_val_dir, "labels")
if not os.path.exists(sub_coco_root_dir):
    os.system("mkdir -p %s" % (sub_coco_train_label_dir))
    os.system("mkdir -p %s" % (sub_coco_val_label_dir))
    os.system("ln -s %s %s" % (coco2017_train_img_dir, sub_coco_train_img_dir))
    os.system("ln -s %s %s" % (coco2017_val_img_dir, sub_coco_val_img_dir))


def _label_path_for(img_path):
    """Return the label-file path that belongs to image path ``img_path``.

    The deepest directory component named exactly "images" is replaced by "labels", and the
    file extension is replaced by ".txt". A path with no "images" directory keeps its
    directory, so the label is looked up next to the image.
    """
    *dir_parts, file_name = img_path.split("/")
    if "images" in dir_parts:
        # Swap only the deepest "images" directory. A whole-string replace would also
        # rewrite "images" appearing elsewhere in the path.
        idx = len(dir_parts) - 1 - dir_parts[::-1].index("images")
        dir_parts[idx] = "labels"
    # splitext handles any image extension, not only ".jpg".
    return "/".join(dir_parts + [os.path.splitext(file_name)[0] + ".txt"])


def _remap_label_file(label_path):
    """Read a label file and rewrite it with sub-dataset class ids.

    Each non-blank line is "<class id> <other fields...>". The class id is resolved
    through ``names``. Lines whose class is not in ``sub_names_index_dict`` are dropped.
    The rest get the new id from ``sub_names_index_dict``, with fields re-joined by a
    single space.

    Returns:
        The remapped label lines (without newlines), or None if ``label_path`` does not
        exist.
    """
    try:
        with open(label_path, encoding="utf-8") as f:
            label_lines = f.read().splitlines()
    except FileNotFoundError:
        return None
    sub_label_lines = []
    for label_str in label_lines:
        fields = label_str.split()
        if not fields:
            continue  # blank line; int("") would raise
        new_id = sub_names_index_dict.get(names[int(fields[0])])
        if new_id is None:
            continue  # class not selected for the sub dataset
        fields[0] = "%d" % new_id
        sub_label_lines.append(" ".join(fields))
    return sub_label_lines


def create_sub(file_path, root_dir, root_img_dir, root_label_dir):
    """Build one split (e.g. train or val) of the sub-class dataset.

    Reads the image list ``file_path`` from ``coco2017_root_dir`` (one image path per
    line). For each image it remaps the label file (see ``_label_path_for``) to the
    ``sub_names`` class ids and writes the result into ``root_label_dir``. It then writes
    a new image list to ``os.path.join(root_dir, file_path)``, whose entries are
    ``root_img_dir/<image file name>``.

    Images whose remapped labels are empty are dropped when ``SKIP_NONE_IMG`` is True;
    otherwise they are kept with an empty label file. An image whose label file does not
    exist is treated as having no labels, and the number of such images is printed at
    the end.

    Args:
        file_path: name of the image-list file, e.g. "train2017.txt".
        root_dir: root of the sub dataset; the new image list is written here.
        root_img_dir: directory the new list's image paths point into.
        root_label_dir: directory the remapped label files are written to.
    """
    with open(os.path.join(coco2017_root_dir, file_path), encoding="utf-8") as f:
        # Drop blank lines: an empty entry would otherwise be opened as a label path.
        img_lines = [line.strip() for line in f if line.strip()]

    sub_lines = []
    missing_labels = 0
    for img_line in tqdm(img_lines):
        sub_label_lines = _remap_label_file(_label_path_for(img_line))
        if sub_label_lines is None:
            missing_labels += 1
            sub_label_lines = []
        if not sub_label_lines and SKIP_NONE_IMG:
            continue
        img_name = os.path.basename(img_line)
        sub_lines.append(os.path.join(root_img_dir, img_name))
        label_name = os.path.splitext(img_name)[0] + ".txt"
        with open(os.path.join(root_label_dir, label_name), "w", encoding="utf-8") as f:
            f.writelines("%s\n" % label for label in sub_label_lines)

    with open(os.path.join(root_dir, file_path), "w", encoding="utf-8") as f:
        f.writelines("%s\n" % line for line in sub_lines)

    if missing_labels:
        print("%s: %d images had no label file and were treated as unlabeled"
              % (file_path, missing_labels))
# Build both splits of the sub dataset; they share one output root.
create_sub(file_path="train2017.txt",
           root_dir=sub_coco_root_dir,
           root_img_dir=sub_coco_train_img_dir,
           root_label_dir=sub_coco_train_label_dir)
create_sub(file_path="val2017.txt",
           root_dir=sub_coco_root_dir,
           root_img_dir=sub_coco_val_img_dir,
           root_label_dir=sub_coco_val_label_dir)