其他分享
首页 > 其他分享> > 三步搞定使用Augmentor对训练数据集进行扩增

三步搞定使用Augmentor对训练数据集进行扩增

作者:互联网

文章目录

前言

在训练模型时,有时在数据量较少情况下,避免过拟合,通常会采取人为进行数据增强来达到扩充数据集的目的,下面就介绍一种使用Augmentor来扩充数据集的方法。

实现过程

程序实现过程如下:

import numpy as np, Augmentor, cv2, sys
import os
import shutil
def del_file(path):
    ls = os.listdir(path) 
    for i in ls:
        c_path = os.path.join(path, i)
        if os.path.isdir(c_path):
            del_file(c_path)
        else:
            os.remove(c_path)

def Enhancement(filePath, rate):
    index = ngFilePath.rfind("\\")
    print(index)
    dataType = filePath.find("NG")
    dataType1 = filePath.find("OK")

    if(dataType > 0):
        enhancementDir = filePath[0:index] + "\\" + 'EnhanceImg' + '\\' + 'NG'
    if (dataType1 > 0):
        enhancementDir = filePath[0:index] + "\\" + 'EnhanceImg' + '\\' + 'OK'

    showDir = filePath[0:index] + "\\" + 'showImg'
    singleDir = filePath[0:index] + "\\" + 'sigleImg'

    isExist = os.path.exists(enhancementDir);
    if not isExist:
        os.makedirs(enhancementDir)
    else:
        del_file(enhancementDir)
    isExist = os.path.exists(showDir)

    if not isExist:
        os.makedirs(showDir)
    else:
        del_file(showDir)

    isExist = os.path.exists(singleDir)
    if not isExist:
        os.makedirs(singleDir)
    else:
        del_file(singleDir)

    sourceFiles = os.listdir(filePath)
    num = len(sourceFiles)
    sourceList = list(range(num))
    for i in sourceList:
        sourceFilesName = os.path.join(filePath,sourceFiles[i])
        src = cv2.imread(sourceFilesName, 0)
        shutil.copy2(sourceFilesName, singleDir)
        p = Augmentor.Pipeline(singleDir, showDir)
        p.random_brightness(probability= 0.7, min_factor = 0.5, max_factor= 1.2)
        # p.crop_centre(probability=0.5,160, 160)
        p.resize(probability=1, width=160, height=160)
        p.random_contrast(probability= 0.5, min_factor= 0.5, max_factor= 1.2)
        p.sample(rate)
        # shutil.copy2(sourceFilesName, enhancementDir)
        enhancedImg = os.listdir(showDir)
        enhanceImgList = list(range(len(enhancedImg)))
        sampleImgList =  []
        for j in enhanceImgList:
            fileName = enhancedImg[j]
            sampleImgList.append(fileName)
        numSampleImg = list(range(len(sampleImgList)))
        for k in numSampleImg:
            fileName = os.path.join(showDir, sampleImgList[k])
            shutil.copy2(fileName, enhancementDir)

        del_file(showDir)

if __name__ == '__main__':
    ngFilePath = "E:\\IMG\\NG"
    okFilePath = "E:\\IMG\\OK"
    rate = 10
    Enhancement(okFilePath, rate)

注:
由于这里是做二分类,所以将数据分为OK和NG,这里OK文件夹里随便放了5张图片,然后对这5张图片进行数据增强。

在这里插入图片描述
运行脚本后,后自动生成三个文件夹,数据增强后的数据会自动保存在EnhanceImg文件夹下
在这里插入图片描述在这里插入图片描述

标签:index,搞定,enhancementDir,扩增,filePath,path,showDir,os,Augmentor
来源: https://blog.csdn.net/weixin_41552975/article/details/117781491