go networking vuejs2 split proxy sms pmp视频教程 less使用 jquery的点击事件 android小程序源代码 bootstrap中文api文档 matlab取实部 mysql时间戳转日期 mysql配置远程连接 python参考手册 python的开发工具 python插件 python函数返回 java入门代码 java自定义异常 梦幻西游手游助手 mac画图工具 gunzip 桌面数字时钟 斑驳纹理 蜘蛛皮肤 js取余数 内存条有什么用 lol世界第一 饥荒黄油 maya骨骼绑定教程 cad打散 杨辉三角python ucs怎么用 dnf女柔道加点 大家来goldwave使用教程 yy打不开 千千静听老版本下载 方正小标宋gbk 大图打印
当前位置: 首页 > 学习教程  > 编程语言

pytorchYOLOV4 训练数据生成

2020/7/24 11:21:12 文章标签: 测试文章如有侵权请发送至邮箱809451989@qq.com投诉后文章立即删除

目录

step1:目录结构

step2:运行create.py 生成目标文件目录

step3:运行trans.py 得到训练所需的txt文件


本文参考darknet 生成训练数据的博文中的代码来修改得到 ,原文:https://blog.csdn.net/dcrmg/article/details/81296520,感谢原作者。

pytorch版本yolov4源码地址,感谢源码作者:

https://github.com/Tianxiaomo/pytorch-YOLOv4

上一篇博客写的是使用pytorch yolov4 的步骤可参考:https://blog.csdn.net/h649070/article/details/107492649

step1:目录结构

dataset

----------trainImage

----------validateImage

----------trainImageXML

----------validateImageXML

----------create.py

----------trans.py

step2:运行create.py 生成目标文件目录

create.py 文件代码如下

在数据库同级目录下直接运行

# -*- coding: utf-8 -*-
'''
作者:-牧野-
来源:CSDN
原文:https://blog.csdn.net/dcrmg/article/details/81296520
版权声明:本文为博主原创文章,转载请附上博文链接!
'''
import os
import shutil

def listname(path,idtxtpath):
    filelist = os.listdir(path)  # 该文件夹下所有的文件(包括文件夹)
    filelist.sort()
    f = open(idtxtpath, 'w')
    for files in filelist:  # 遍历所有文件
        Olddir = os.path.join(path, files)  # 原来的文件路径
        if os.path.isdir(Olddir):  # 如果是文件夹则跳过
            continue
        f.write(files)
        f.write('\n')
    f.close()
 
savepath = os.getcwd()
imgidtxttrainpath = savepath+"/trainImageId.txt"
imgidtxtvalpath = savepath + "/validateImageId.txt"

imgtrainpath = os.path.join(os.getcwd(),'trainImage')
imgvalpath = os.path.join(os.getcwd(),'validateImage')
listname(imgtrainpath,imgidtxttrainpath)
listname(imgvalpath,imgidtxtvalpath)

print ("trainImageId.txt && validateImageId.txt have been created!")

step3:运行trans.py 得到训练所需的txt文件

在同级目录下运行,代码如下:


import xml.etree.ElementTree as ET
import pickle
import string
import os
import shutil
from os import listdir, getcwd
from os.path import join
import cv2

sets = [('2012', 'train')]

classes = ['class1','class2']


wd = getcwd()
out_val_file = open(os.path.join(wd,'valset.txt'), 'w') 
out_train_file = open(os.path.join(wd,'trainset.txt'), 'w') 

def convert(size, box):
    dw = 1. / size[0]
    dh = 1. / size[1]
    x = (box[0] + box[1]) / 2.0
    y = (box[2] + box[3]) / 2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x * dw
    w = w * dw
    y = y * dh
    h = h * dh
    return (x, y, w, h)


def convert_annotation(image_id, flag, savepath):
    if flag == 0:
        in_file = open(savepath + '/trainImageXML/%s.xml' % (os.path.splitext(image_id)[0]),'r',encoding='utf-8')
        # out_file = open(savepath + '/trainImage/%s.txt' % (os.path.splitext(image_id)[0]), 'w')
        print(in_file)
        tree = ET.parse(in_file)
        root = tree.getroot()
        size = root.find('size')

        img = cv2.imread('./trainImage/' + str(image_id))
        h = img.shape[0]
        w = img.shape[1]
        out_train_file.write(savepath + '/trainImage/%s.jpg' % (os.path.splitext(image_id)[0]))
        for obj in root.iter('object'):
            difficult = obj.find('difficult').text
            cls = obj.find('name').text
            if cls not in classes or int(difficult) == 1:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            # b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
            #      float(xmlbox.find('ymax').text))
            b = (int(xmlbox.find('xmin').text), int(xmlbox.find('xmax').text), int(xmlbox.find('ymin').text),
                 int(xmlbox.find('ymax').text))
            # bb = convert((w, h), b)
            bb = b
            out_train_file.write(" " + ",".join([str(a) for a in bb]) + ',' + str(cls_id))
        out_train_file.write('\n')
    elif flag == 1:
        in_file = open(savepath + '/validateImageXML/%s.xml' % (os.path.splitext(image_id)[0]),'r',encoding='utf-8')
        # out_file = open(savepath + '/validateImage/%s.txt' % (os.path.splitext(image_id)[0]), 'w')

        tree = ET.parse(in_file)
        root = tree.getroot()
        size = root.find('size')

        img = cv2.imread('./validateImage/' + str(image_id))
        h = img.shape[0]
        w = img.shape[1]
        out_val_file.write(savepath + '/validateImage/%s.jpg' % (os.path.splitext(image_id)[0]))
        for obj in root.iter('object'):
            difficult = obj.find('difficult').text
            cls = obj.find('name').text
            if cls not in classes or int(difficult) == 1:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            # b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
            #      float(xmlbox.find('ymax').text))
            b = (int(xmlbox.find('xmin').text), int(xmlbox.find('xmax').text), int(xmlbox.find('ymin').text),
                 int(xmlbox.find('ymax').text))
            # bb = convert((w, h), b)
            bb = b
            out_val_file.write(" " + ",".join([str(a) for a in bb]) + ',' + str(cls_id))
        out_val_file.write('\n')
    return



for year, image_set in sets:
    savepath = wd#os.getcwd();

    idtxt = savepath + "/validateImageId.txt"
    pathtxt = savepath + "/validateImagePath.txt"
    image_ids = open(idtxt).read().strip().split()
    list_file = open(pathtxt, 'w')
    s = '\xef\xbb\xbf'
    for image_id in image_ids:
        nPos = image_id.find(s)
        if nPos >= 0:
            image_id = image_id[3:]
        list_file.write('%s/validateImage/%s\n' % (wd, image_id))
        print(image_id)
        convert_annotation(image_id, 1, savepath)
    list_file.close()

    idtxt = savepath + "/trainImageId.txt"
    pathtxt = savepath + "/trainImagePath.txt"
    image_ids = open(idtxt).read().strip().split()
    list_file = open(pathtxt, 'w')
    s = '\xef\xbb\xbf'
    for image_id in image_ids:
        nPos = image_id.find(s)
        if nPos >= 0:
            image_id = image_id[3:]
        list_file.write('%s/trainImage/%s\n' % (wd, image_id))
        print(image_id)
        convert_annotation(image_id, 0, savepath)
    list_file.close()


out_train_file.close()
out_val_file.close()

最后生成的valset.txt 和 trainset.txt即为最终使用的文件了;

转载请留言


本文链接: http://www.dtmao.cc/news_show_50408.shtml

附件下载

相关教程

    暂无相关的数据...

共有条评论 网友评论

验证码: 看不清楚?