基于Jittor框架实现LSGAN图像生成对抗网络
阅读原文时间:2021年07月18日阅读:1

基于Jittor框架实现LSGAN图像生成对抗网络

生成对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。GAN模型由生成器(Generator)和判别器(Discriminator)两个部分组成。在训练过程中,生成器的目标就是尽量生成真实的图片去欺骗判别器。而判别器的目标就是尽量把生成器生成的图片和真实的图片分别开来。这样,生成器和判别器构成了一个动态的“博弈过程”。许多相关的研究工作表明GAN能够产生效果非常真实的生成效果。

使用Jittor框架实现了一种经典GAN模型LSGANLSGAN将GAN的目标函数由交叉熵损失替换成最小二乘损失,以此拒绝了标准GAN生成的图片质量不高以及训练过程不稳定这两个缺陷。通过LSGAN的实现介绍了Jittor数据加载、模型定义、模型训练的使用方法。

LSGAN论文:https://arxiv.org/abs/1611.04076

使用两种数据集进行LSGAN的训练,分别是Jittor自带的数据集MNIST,和用户构建的数据集CelebA。您可以通过以下链接下载CelebA数据集。

使用Jittor自带的MNIST数据加载器方法如下。使用jittor.transform可以进行数据归一化及数据增强,这里通过transform将图片归一化到[0,1]区间,并resize到标准大小112*112。。通过set_attrs函数可以修改数据集的相关参数,如batch_sizeshuffletransform等。

from jittor.dataset.mnist import MNIST

import jittor.transform as transform

transform = transform.Compose([

    

    

])

train_loader = MNIST (train=True, transform=transform)

        

val_loader = MNIST (train=False, transform = transform)

        

使用用户构建的CelebA数据集方法如下,通过通用数据加载器jittor.dataset.dataset.ImageFolder,输入数据集路径即可构建用户数据集。

from jittor.dataset.dataset import ImageFolder

import jittor.transform as transform

transform = transform.Compose([

    

    

])

train_dir = './data/celebA_train'

train_loader = ImageFolder(train_dir)

        

val_dir = './data/celebA_eval'

val_loader = ImageFolder(val_dir)

        

2.1.网络结构

使用LSGAN进行图像生成,下图为LSGAN论文给出的网络架构图,其中(a)为生成器,(b)为判别器。生成器网络输入一个1024维的向量,生成分辨率为112*112的图像;判别器网络输入112*112的图像,输出一个数字表示输入图像为真实图像的可信程度。

受到VGG模型的启发,生成器在与DCGAN的结构基础上在前两个反卷积层之后增加了两个步长=1的反卷积层。除使用最小二乘损失函数外判别器的结构与DCGAN中的结构相同。与DCGAN相同,生成器和判别器分别使用了ReLU激活函数和LeakyReLU激活函数。

下面将介绍如何使用Jittor定义一个网络模型。定义模型需要继承基类jittor.Module,并实现__init__execute函数。__init__函数在模型声明时会被调用,用于进行模型内部op或其他模型的声明及参数的初始化。该模型初始化时输入参数dim表示训练图像的通道数,对于MNIST数据集dim为1,对于CelebA数据集dim为3。

execute函数在网络前向传播时会被调用,用于定义前向传播的计算图,通过autograd机制在训练时Jittor会自动构建反向计算图。

import jittor as jt

from jittor import nn, Module

class generator(Module):

    

        

        

        

        

        

        

        

        

        

        

        

        

        

        

        

        

        

        

    

        

        

        

        

        

        

        

        

        

class discriminator(nn.Module):

    

        

        

        

        

        

        

        

        

        

        

    

        

        

        

        

        

        

        

2.2.损失函数

损失函数采用最小二乘损失函数,其中判别器损失函数如下。其中x为真实图像,z为服从正态分布的1024维向量,a取值为1,b取值为0。

生成器损失函数如下。其中z为服从正态分布的1024维向量,c取值为1。

具体实现如下,x为生成器的输出值,b表示该图像是否希望被判别为真。

def ls_loss(x, b):

    

    

    

    

        

    

        

3.1.参数设定

参数设定如下。

# 通过use_cuda设置在GPU上进行训练

jt.flags.use_cuda = 1

# 批大小

batch_size = 128

# 学习率

lr = 0.0002

# 训练轮数

train_epoch = 50

# 训练图像标准大小

img_size = 112

# Adam_优化器参数_

betas = (0.5,0.999)

# 数据集图像通道数,MNIST为1,CelebA为3

dim = 1 if task=="MNIST" else 3

3.2.模型、优化器声明

分别声明生成器和判别器,并使用Adam作为优化器。

# 生成器

G = generator (dim)

# 判别器

D = discriminator (dim)

# 生成器优化器

G_optim = nn.Adam(G.parameters(), lr, betas=betas)

# 判别器优化器

D_optim = nn.Adam(D.parameters(), lr, betas=betas)

3.3.训练

for epoch in range(train_epoch):

    

         

        

        

        

        

        

        

        

        

        

        

        

        

        

        

        

        

        

        

            

            

4.1.生成结果

分别使用MNISTCelebA数据集进行了50个epoch的训练。训练完成后各随机采样了25张图像,结果如下。

4.2.速度对比

使用Jittor与主流的深度学习框架PyTorch进行了训练速度的对比,下表为PyTorch(是/否打开benchmark)及Jittor在两种数据集上进行1次训练迭带的使用时间。得益于Jittor特有的元算子融合技术,其训练速度比PyTorch快了40%~55%。