问题记录2:关于使用torch.utils.data.TensorDataset()和torch.utils.data.DataLoader()时遇到的一些问题
阅读原文时间:2021年04月20日阅读:1

本文记录学习过程中遇到的问题、我的解决过程以及学习心得,如有错误之处,欢迎指正!

在学习用pytorch进行数据批处理的过程中用到了torch.utils.data.TensorDataset()和torch.utils.data.DataLoader()函数,练习的代码如下:

import torch
import torch.utils.data as Data

torch.manual_seed(1)    # 这句有关生成随机数,他会使得随机生成的结果是确定的


BATCH_SIZE= 5   # 设置批次训练数量

# 定义数据
x = torch.linspace(1, 10, steps=10)    # torch.linspace()线性等分向量,前两个参数是向量的开始和结束值,steps是分割出的点数,缺省值100
y = torch.linspace(10, 1, steps=10)    # x,y都是十维向量

torch_dataset = Data.TensorDataset(x, y)    # x,y对应整合进数据集,应该是一个二维数据的队列(10*2矩阵)

loader = Data.DataLoader(
    dataset=torch_dataset,      # 加载数据集
    batch_size=BATCH_SIZE,      # 批次大小
    shuffle=True,                # 是否打乱顺序训练
    num_workers=2               # 设置线程数
)

def show_batch():
    for epoch in range(3):   # 进行三轮训练
        for step, (batch_x, batch_y) in enumerate(loader):  # 每轮训练进行两批(一批5个数据)
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())

if __name__== '__main__':
    show_batch()

问题1:TensorDataset() 形参名称出错 

莫烦pytorch课程中整合数据集这一步骤用到的代码是:


torch_dataset = Data.TensorDataset(data_tensor=x, target_tensor=y)

这一句在运行过程中会报错:
TypeError: __init__() got an unexpected keyword argument 'data_tensor'

查看TensorDataset的声明:

class TensorDataset(Dataset):

    def __init__(self, *tensors):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)

再查看Dataset的的声明(篇幅太长,略去代码展示)确实不存在data_tensor或者target_tensor的参数,论坛上有人说这是由于版本不同造成的。去掉形参名直接赋值即可解决问题。

问题2:show_batch()函数调用出错/对变量loader迭代时报错

DEBUG之前的代码没有定义函数show_batch,也没有if __name__=='__main__'语句,直接执行show_batch()函数中的内容。或者定义show_batch()函数并直接调用也会导致出错。错误代码:

RuntimeError: DataLoader worker (pid(s) 12384, 10160) exited unexpectedly

这类错误是由于多线程处理造成的,若要直接对loader迭代,则需要去掉对loader赋值时Data.DataLoader()函数中的num_workers=2语句,或者赋值0.

如果需要采用多线程处理的话,最好采用最上面代码的方法定义主函数并调用步骤,直接在主函数中执行show_batch()函数的执行语句同样可以实现。

问题3:如何读出TensorDataset类中的数据

为了方便理解,我采用了如下代码,希望了解各个类中数据时如何存储的:

print(x)
print(y)
print(torch_dataset)
print(loader)

'''
输出结果为:
tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
tensor([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.,  2.,  1.])
<torch.utils.data.dataset.TensorDataset object at 0x0000016D635A86A0>
<torch.utils.data.dataloader.DataLoader object at 0x0000016D635A8DA0>
'''

由输出结果可见,类的数据不能直接读取,查看了pytorch中文文档中关于torch.utils.data的描述之后了解到TensorDataset通过第一个维度索引两个张量来恢复每个样本,可以通过torch_dataset[index]来读出TensorDataset中的数据,也可以通过for循环遍历数据:

for each in torch_dataset:
    print(each)

'''
输出结果:
(tensor(1.), tensor(10.))
(tensor(2.), tensor(9.))
(tensor(3.), tensor(8.))
(tensor(4.), tensor(7.))
(tensor(5.), tensor(6.))
(tensor(6.), tensor(5.))
(tensor(7.), tensor(4.))
(tensor(8.), tensor(3.))
(tensor(9.), tensor(2.))
(tensor(10.), tensor(1.))
'''

而在研究DataLoader中数据存储时无法通过遍历或者迭代实现数据读出,代码和错误信息如下:

dataiter = iter(loader)
data, labels = next(loader)
print(data, labels)

# TypeError: 'DataLoader' object is not an iterator


print(loader[0])
for each in loader:
    print(each)

# TypeError: 'DataLoader' object does not support indexing

看文档说DataLoader类是可以迭代的,但是错误信息指出DataLoader类不能索引或者迭代。

奈何本人半路出家,学艺不精…这个问题我仍然没有解决,并且暂时没有查找到相关资料,这里先挖一个坑,等解决了再来更新吧…