【pytorch】图片分类问题处理一般数据集,使其满足torchvision.datasets.ImageFolder调用结构
阿里云国内75折 回扣 微信号:monov8 |
阿里云国际,腾讯云国际,低至75折。AWS 93折 免费开户实名账号 代冲值 优惠多多 微信号:monov8 飞机:@monov6 |
torchvision.datasets.ImageFolder调用结构
对于简单的图像分类任务并不需要自己定义一个 Dataset类可以直接调用 torchvision.datasets.ImageFolder 返回训练数据与标签。
数据集应满足pytorch的格式要求即将数据集分割为训练集和测试集并将数据和标签分别放入不同的文件夹
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
同时应兼顾按比例划分训练集测试集及验证集的需求。
下面的函数将人眼睁闭数据集转换为pytorch指定的结构
原始数据集
调用代码示例
import os
import shutil
import random
class PictureClassifier(object):
def __init__(self, img_dir, target_dir, categories, train_percent, validate_percent, test_percent):
self.img_dir = img_dir
self.target_dir = target_dir
self.categories = categories
self.train_percent = train_percent
self.validate_percent = validate_percent
self.test_percent = test_percent
for category in categories:
os.makedirs(os.path.join(target_dir, 'train', category))
os.makedirs(os.path.join(target_dir, 'validate', category))
os.makedirs(os.path.join(target_dir, 'test', category))
#定义通过图片名获取标签的方法返回标签
def getLabelByFileName(self, filename):
pass
#检验被遍历对象是否为需要处理图片的方法返回true或false
def isPic(self, filename):
pass
#遍历img_dir下的所有文件逐一进行操作
def classify(self):
for root, dirs, files in os.walk(self.img_dir):
for file in files:
# 打印所有文件对象路径
# print(os.path.join(root, file))
# 该file所在的路径
# print(root)
fileName = file
if self.isPic(fileName):
label = self.getLabelByFileName(fileName)
if random.random() < self.train_percent:
shutil.copy(os.path.join(root, file), os.path.join(self.target_dir, 'train', label, file))
elif random.random() < self.validate_percent:
shutil.copy(os.path.join(root, file), os.path.join(self.target_dir, 'validate', label, file))
else:
shutil.copy(os.path.join(root, file), os.path.join(self.target_dir, 'test', label, file))
else:
continue
class MyPictureClassifier(PictureClassifier):
def __init__(self, img_dir, target_dir, categories,train_percent, validate_percent, test_percent):
super(MyPictureClassifier, self).__init__(img_dir, target_dir, categories,train_percent, validate_percent, test_percent)
def getLabelByFileName(self, filename):
#数据集第四个位置为标签名
num_str = filename.split('_')[4]
if num_str=="0":
return 'close'
else:
return 'open'
def isPic(self, filename):
return filename.endswith('.png')
# 图片所在的文件夹
img_dir = 'D:\mrlEyes_2018_01'
# 将图片转换后存放的文件夹
target_dir = 'D:\eyeDataSet'
# 类别信息
categories = ['open', 'close']
worker=MyPictureClassifier(img_dir,target_dir,categories,0.8,0.1,0.1)
worker.classify()
转换后