pytorch使用(一)处理并加载自己的数据
阅读原文时间:2021年04月20日阅读:1

pytorch使用:目录


pytorch使用(一)数据处理

个人认为,数据处理或许是在完成一篇论文中最耗费时间的,特别是大多情况下,需要在很多个库上做实验。

pytorch官方支持很多库,使用torchvision来完成数据的处理,点这里可以看到支持的库并不是很多。在这里,我将结合一个实例说明如何使用pytorch来处理自己的数据,任务是一个分析双臂运动的,检测6个关节点的运动。输入是连续三帧的检测结果以及计算的光流,也就是$3*6+2*2=22$张heatmap,输出是中间帧的检测结果,也就是6张heatmap。

把原始数据处理为模型使用的数据需要3步:transforms.Compose() torchvision.datasets torch.utils.data.DataLoader()分别可以理解为数据处理格式的定义、数据处理和数据加载。

1. 数据预处理torchvision.transforms

pytorch使用torchvision.transforms实现数据的预处理,包括中心化(torchvision.transforms.CenterCrop)、随机剪切(torchvision.transforms.RandomCrop)、正则化、图片变为Tensor、tensor变为图片等,建议整体浏览一下这一部分的官方手册,非常有用,数据处理很方便。

先转换为张量,然后正则化:

import torchvision.transforms as transforms
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])
#img = transform(img)

2. 数据读取,构建Dataset子类

参考:http://blog.csdn.net/victoriaw/article/details/72356453

如果想要使用自己的数据,则必须自己构建一个torch.utils.data.Dataset的子类去读取数据。我们的将数据列表放在train.txttest.txt中,将不同类型的数据的路径放在path.txt中,所以在类的init函数中有path_file和 list_file两个变量

在定义torch.utils.data.Dataset的子类时,必须重载的两个函数是lengetitem:
- len返回数据集的大小
- getitem实现数据集的下标索引,返回对应的图像和标记(不一定非得返回图像和标记,返回元组的长度可以是任意长,这由网络需要的数据决定)。

末尾有自己写的一个Dataset子类的定义文件。

3. 数据加载

torch.utils.data.DataLoader()函数,合成数据并且提供迭代访问。主要由两部分组成:
- dataset(Dataset)。输入加载的数据,就是上面的MyDataset的实现。
- batch_size, shuffle, sampler, batch_sampler, num_worker, collate_fn, pin_memory, drop_last, timeout等参数,介绍几个比较常用的,这些在官方网站都有:

- batch-size。样本每个batch的大小,默认为1。
- shuffle。是否打乱数据,默认为False。
- num_workers。数据分为几个线程处理默认为0。
- sampler。定义一个方法来绘制样本数据,如果定义该方法,则不能使用shuffle。默认为False

使用:

import torch
from datagen import MyDataset

trainset = MyDataset(path_file=pathFile,list_file=trainList,numJoints = 6,type=False)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=8)
testset = MyDataset(path_file=pathFile,list_file=testList,numJoints = 6,type=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=8)

以下是定义class MyDataset文件datagen.py, 其中有__init__(self, path_file, list_file,numJoints,type)__getitem__(self, idx)__len__(self)三个函数,__getitem__返回一个(22,256,256)的输入和一个(6,256,256)的标签。
'''
Load data
'''

import numpy as np
from PIL import Image
#import cv2

import torch
import torch.utils.data as data
import torchvision.transforms as transforms

class MyDataset(data.Dataset):

    def __init__(self, path_file, list_file,numJoints,type):
        '''
        Args:
          path_file: (str) heatmap and optical file location
          list_file: (str) path to index file.
          numJoints: (int) number of joints
          type: (boolean) use pose flow(true) or optical flow(false)
        '''

        self.numJoints = numJoints

        # read heatmap and optical path
        with open(path_file) as f:
            paths = f.readlines()

        for path in paths:
            splited = path.strip().split()
            if splited[0]=='resPath':
                self.resPath = splited[1]
            elif splited[0]=='gtPath':
                self.gtPath = splited[1]
            elif splited[0]=='opticalFlowPath':
                self.opticalFlowPath = splited[1]
            elif splited[0]=='poseFlowPath':
                self.poseFlowPath = splited[1]
        if type:
            self.flowPath = self.poseFlowPath
        else:
            self.flowPath = self.opticalFlowPath


        #read list
        with open(list_file) as f:
            self.list = f.readlines()
            self.num_samples = len(self.list)

def __getitem__(self, idx):
    '''
    load heatmaps and optical flow and encode it to a 22 channels input and 6 channels output
    :param idx: (int) image index
    :return:
        input: a 22 channel input which integrate 2 optical flow and heatmaps of 3 image
        output: the ground truth

    '''

    input = []
    output = []
    # load heatmaps of 3 image
    for im in range(3):
        for map in range(6):
            curResPath = self.resPath + self.list[idx].rstrip('\n') + str(im + 1) + '/' + str(map + 1) + '.bmp'
            heatmap = Image.open(curResPath)
            heatmap.load()
            heatmap = np.asarray(heatmap, dtype='float') / 255
            input.append(heatmap)
    # load 2 flow
    for flow in range(2):
        curFlowXPath = self.flowPath + self.list[idx].rstrip('\n') + 'flowx/' + str(flow + 1) + '.jpg'
        flowX = Image.open(curFlowXPath)
        flowX.load()
        flowX = np.asarray(flowX, dtype='float')
        curFlowYPath = self.flowPath + self.list[idx].rstrip('\n') + 'flowy/' + str(flow + 1) + '.jpg'
        flowY = Image.open(curFlowYPath)
        flowY.load()
        flowY = np.asarray(flowY, dtype='float')
        input.append(flowX)
        input.append(flowY)
    # load groundtruth
    for map in range(6):
        curgtPath = self.resPath + self.list[idx].rstrip('\n') + str(2) + '/' + str(map + 1) + '.bmp'
        heatmap = Image.open(curResPath)
        heatmap.load()
        heatmap = np.asarray(heatmap, dtype='float') / 255
        output.append(heatmap)

    input = torch.Tensor(input)
    output = torch.Tensor(output)

    return input,output



def __len__(self):
    return self.num_samples

手机扫一扫

移动阅读更方便

阿里云服务器
腾讯云服务器
七牛云服务器

你可能感兴趣的文章