torchvision.datasets.ImageFolder使用详解
阅读原文时间:2023年08月09日阅读:1

一、数据集组织方式
ImageFolder是一个通用的数据加载器,它要求我们以下面这种格式来组织数据集的训练、验证或者测试图片。

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

对于上面的root,假设data文件夹在.py文件的同级目录中,那么root一般都是如下这种形式:./data/train 和 ./data/valid


二、ImageFolder参数详解
 

dataset=torchvision.datasets.ImageFolder(
                       root, transform=None, 
                       target_transform=None, 
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp;loader=<function default_loader>,&nbsp;
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp;is_valid_file=None)

参数详解:

  • root:图片存储的根目录,即各类别文件夹所在目录的上一级目录。
  • transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。
  • target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。 如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
  • loader:表示数据集加载方式,通常默认加载方式即可。
  • is_valid_file:获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件)

返回的dataset都有以下三种属性:

  • self.classes:用一个 list 保存类别名称
  • self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
  • self.imgs:保存(img-path, class) tuple的 list

三、程序案例
 

from torchvision.datasets import ImageFolder
from torchvision import transforms

#加上transforms
normalize=transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
transform=transforms.Compose([
&nbsp; &nbsp; transforms.RandomCrop(180),
&nbsp; &nbsp; transforms.RandomHorizontalFlip(),
&nbsp; &nbsp; transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
&nbsp; &nbsp; normalize
])

dataset=ImageFolder('./data/train',transform=transform)

我们得到的dataset,它的结构就是[(img_data,class_id),(img_data,class_id),…],下面我们打印第一个元素:

print(dataset[0])
'''
输出:
(tensor([[[-0.5137, -0.4667, -0.4902, &nbsp;..., -0.0980, -0.0980, -0.0902],
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;[-0.5922, -0.5529, -0.5059, &nbsp;..., -0.0902, -0.0980, -0.0667],
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;[-0.5373, -0.5294, -0.4824, &nbsp;..., -0.0588, -0.0824, -0.0196],
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;...,
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;[-0.3098, -0.3882, -0.3725, &nbsp;..., -0.4353, -0.4510, -0.4196],
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;[-0.2863, -0.3647, -0.3725, &nbsp;..., -0.4431, -0.4118, -0.4196],
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;[-0.3412, -0.3569, -0.3882, &nbsp;..., -0.4667, -0.4588, -0.4196]],

&nbsp; &nbsp; &nbsp; &nbsp; [[-0.6157, -0.5686, -0.5922, &nbsp;..., -0.2863, -0.2784, -0.2706],
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;[-0.6941, -0.6549, -0.6078, &nbsp;..., -0.2784, -0.2784, -0.2471],
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;[-0.6392, -0.6314, -0.5843, &nbsp;..., -0.2471, -0.2706, -0.2078],
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;...,
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;[-0.4431, -0.5059, -0.5059, &nbsp;..., -0.5608, -0.5765, -0.5451],
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;[-0.4196, -0.4824, -0.5059, &nbsp;..., -0.5686, -0.5373, -0.5451],
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;[-0.4745, -0.4902, -0.5294, &nbsp;..., -0.5922, -0.5843, -0.5451]],

&nbsp; &nbsp; &nbsp; &nbsp; [[-0.6627, -0.6157, -0.6549, &nbsp;..., -0.5059, -0.5216, -0.5137],
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;[-0.7412, -0.7020, -0.6706, &nbsp;..., -0.4980, -0.5216, -0.4902],
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;[-0.6863, -0.6784, -0.6471, &nbsp;..., -0.4667, -0.4902, -0.4275],
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;...,
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;[-0.6000, -0.6549, -0.6627, &nbsp;..., -0.6784, -0.6941, -0.6627],
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;[-0.5765, -0.6314, -0.6471, &nbsp;..., -0.6863, -0.6549, -0.6627],
&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;[-0.6314, -0.6314, -0.6392, &nbsp;..., -0.7098, -0.7020, -0.6627]]]), 0)
'''

下面我们再看一下dataset的三个属性:

print(dataset.classes) &nbsp;#根据分的文件夹的名字来确定的类别
print(dataset.class_to_idx) #按顺序为这些类别定义索引为0,1...
print(dataset.imgs) #返回从所有文件夹中得到的图片的路径以及其类别
'''
输出:
['cat', 'dog']
{'cat': 0, 'dog': 1}
[('./data/train\\cat\\1.jpg', 0),&nbsp;
&nbsp;('./data/train\\cat\\2.jpg', 0),&nbsp;
&nbsp;('./data/train\\dog\\1.jpg', 1),&nbsp;
&nbsp;('./data/train\\dog\\2.jpg', 1)]
'''