Adversarial Examples Improve Image Recognition
阅读原文时间:2022年05月26日阅读:1

Xie C, Tan M, Gong B, et al. Adversarial Examples Improve Image Recognition.[J]. arXiv: Computer Vision and Pattern Recognition, 2019.

@article{xie2019adversarial,

title={Adversarial Examples Improve Image Recognition.},

author={Xie, Cihang and Tan, Mingxing and Gong, Boqing and Wang, Jiang and Yuille, Alan L and Le, Quoc V},

journal={arXiv: Computer Vision and Pattern Recognition},

year={2019}}

为了让网络更稳定(更鲁棒), 作者考虑如下优化目标:

\[\arg \min_{\theta} \mathbb{E}_{(x, y)\sim \mathbb{D}}\Big[L(\theta,x,y)+\max_{\epsilon \in \mathbb{S}}L(\theta,x+\epsilon,y)\Big],
\]

实际上就是一种对抗训练. 但如果只是简单地把普通训练样本和对应的adversarial samples混合在一起训练, 效果并不好. 作者认为这种情况是Batchnorm造成的, 只需要在训练时为普通样本和对抗样本分别设置不同的batchnorm模块即可.

作者认为, 普通训练样本和对抗训练样本来自不同的分布, 此时用同一个batchnorm同时归一化二者效果不好. 所以作者提出在训练时添加一个额外的(辅助)batchnorm, 专门用于对抗训练样本; 而在非训练阶段(测试时), 只使用普通的batchnorm.

每一次训练步骤如下:

  1. 从普通训练样本中采样 batch \(x^c\)以及对应的标签\(y\);
  2. 根据相应算法(本文采用PGD)生成对抗训练样本\(x^a\)(额外的batchnorm);
  3. 计算损失\(L^c(\theta, x^c, y)\)(普通的batchnorm);
  4. 计算损失\(L^a(\theta, x^a, y)\)(额外的batchnorm);
  5. backward: \(L^c(\theta, x^c, y) + L^a(\theta, x^a, y)\), 并更新\(\theta\).

实验概述

数据集: ImageNet-A, ImageNet-C, Stylized-ImageNet.

5.2: AdvProp, 85.2% top-1 accuracy on ImageNet(Fig4);

在损坏(corrupted)的ImageNet数据集(ImageNet-C)上测试(Table4, 指标为mCE(mean Corruption Error), lower is better);

探究adversarial attacks强度对网络分类正确率的影响: 当一个网络的“适应性”(即模型容量)较弱时, 攻击强度小反而效果好; “适应性”较强时, 攻击强度高更好(Table2);

比较AdvProp与一般的对抗训练的效果差异(Fig5);

“适应性”越强的网络, AdvProp带来的提升越小;

AutoAugment 与 AdvProp的比较(Table 6);

不同的adversarial attacks的影响(Table 7);

代码未经测试.

"""
white-box attacks:
iFGSM
PGD
"""

import torch
import torch.nn as nn

class WhiteBox:
    """White-box attacks driven by the gradient of ``criterion`` w.r.t. the input.

    :param net: model under attack, called as ``net(x)``; its output must be
        accepted by ``criterion`` and by ``argmax(dim=1)``.
    :param epsilon: step size; every iteration moves the input by
        ``epsilon * sign(gradient)``.
    :param times: maximum number of iterations per sample.
    :param criterion: loss whose input gradient drives the attack; defaults to
        ``nn.CrossEntropyLoss()``.
    """

    def __init__(self, net, epsilon: float, times: int, criterion=None):
        self.net = net
        self.epsilon = epsilon
        self.times = times
        # Explicit None test: a user-supplied criterion must never be replaced
        # because of how it happens to evaluate in a boolean context.
        self.criterion = nn.CrossEntropyLoss() if criterion is None else criterion

    @staticmethod
    def calc_jacobian(loss, inp):
        """Return d(loss)/d(inp), keeping the graph for further gradient calls."""
        return torch.autograd.grad(loss, inp, retain_graph=True)[0]

    @staticmethod
    def sgn(matrix):
        """Element-wise sign (-1, 0 or 1)."""
        return torch.sign(matrix)

    @staticmethod
    def pre(out):
        """Predicted class index along dim 1."""
        return torch.argmax(out, dim=1)

    @staticmethod
    def _stack(samples, inps):
        """Concatenate successful samples, or return an empty ``(0, *inps.shape[1:])`` tensor."""
        if samples:
            return torch.cat(samples)
        return inps.new_empty((0, *inps.shape[1:]))

    def fgsm(self, inp, y):
        """Step along the sign of the input gradient until the prediction flips.

        The gradient is computed once, at ``inp``; the sample is then moved by
        ``epsilon * sign(grad)`` up to ``times`` times, stopping at the first
        step whose predicted class differs from ``y``.

        :param inp: a batch holding a single sample (the ``!=`` test below is
            converted to a Python bool, which needs exactly one prediction).
        :param y: label tensor for that sample.
        :return: ``(flag, inp_new)``: ``flag`` is True if the prediction flipped,
            ``inp_new`` is the last perturbed input (no autograd history).
        """
        # Work on a detached copy so the caller's tensor is not switched to requires_grad.
        inp = inp.detach().clone().requires_grad_(True)
        loss = self.criterion(self.net(inp), y)
        delta = self.sgn(self.calc_jacobian(loss, inp))
        inp_new = inp.detach()
        # Evaluation forwards need no graph.
        with torch.no_grad():
            for _ in range(self.times):
                inp_new = inp_new + self.epsilon * delta
                if self.pre(self.net(inp_new)) != y:
                    return True, inp_new
        return False, inp_new

    def ifgsm(self, inps, ys):
        """Run :meth:`fgsm` on every sample of a batch.

        :return: ``(adversarial, rate)``. ``adversarial`` holds the successful
            samples concatenated along dim 0; it is empty if none succeeded.
            ``rate`` is the fraction of flipped samples (0.0 for an empty batch).
        """
        total = len(inps)
        found = []
        for i in range(total):
            # ``[[i]]`` keeps the batch dimension, so each call sees a batch of one.
            flag, inp_new = self.fgsm(inps[[i]], ys[[i]])
            if flag:
                found.append(inp_new)
        return self._stack(found, inps), (len(found) / total if total else 0.0)

    def pgd(self, inp, y, perturb):
        """Projected gradient descent (L-infinity) attack on a single sample.

        Each iteration does three things:

        - recomputes the input gradient at the current point;
        - moves by ``epsilon * sign(grad)``;
        - projects back into ``[inp - perturb, inp + perturb]``.

        It stops as soon as the predicted class differs from ``y``.

        :param inp: a batch holding a single sample.
        :param y: label tensor for that sample.
        :param perturb: radius of the L-infinity ball around ``inp``.
        :return: ``(flag, inp_new)`` as in :meth:`fgsm`.
        """
        origin = inp.detach()
        boundary_low = origin - perturb
        boundary_up = origin + perturb
        inp_new = origin.clone()
        for _ in range(self.times):
            inp_new.requires_grad_(True)
            loss = self.criterion(self.net(inp_new), y)
            delta = self.sgn(self.calc_jacobian(loss, inp_new))
            with torch.no_grad():
                inp_new = torch.clamp(
                    inp_new + self.epsilon * delta,
                    boundary_low,
                    boundary_up
                )
                if self.pre(self.net(inp_new)) != y:
                    return True, inp_new
        return False, inp_new

    def ipgd(self, inps, ys, perturb):
        """Run :meth:`pgd` on every sample of a batch.

        :return: ``(adversarial, rate)`` with the same conventions as :meth:`ifgsm`.
        """
        total = len(inps)
        found = []
        for i in range(total):
            flag, inp_new = self.pgd(inps[[i]], ys[[i]], perturb)
            if flag:
                found.append(inp_new)
        return self._stack(found, inps), (len(found) / total if total else 0.0)




"""
black-box attack
see Practical Black-Box Attacks against Machine Learning.
"""

import  torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class Synthetic(Dataset):
    """Map-style dataset that pairs ``data[index]`` with ``labels[index]``."""

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        # Size follows ``data``; ``labels`` is expected to be index-aligned with it.
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        label = self.labels[index]
        return sample, label

class Blackbox:
    """State of the substitute-model black-box attack.

    Attributes:

    - ``oracle``: the attacked model, queried only for labels.
    - ``substitute``: the model being trained.
    - ``data``: the list of labelled datasets collected so far.
    - ``trainer``: the callable that fits the substitute to ``data``.
    """

    def __init__(self, oracle, substitute, data, trainer, lamb):
        """
        :param oracle: callable mapping a batch of inputs to labels
            (presumably class indices, as the trainer indexes outputs with them; verify)
        :param substitute: model trained to imitate the oracle
        :param data: initial inputs; labelled immediately via :meth:`update`
        :param trainer: callable ``trainer(substitute, datasets, lamb)``
        :param lamb: step size forwarded to the trainer for synthetic samples
        """
        self.oracle = oracle
        self.substitute = substitute
        self.data = []
        self.trainer = trainer
        self.lamb = lamb
        self.update(data)

    def update(self, data):
        """Label ``data`` with the oracle and append it as a new dataset."""
        labels = self.oracle(data)
        self.data.append(Synthetic(data, labels))

    def train(self, augment=False):
        """Train the substitute on every collected dataset.

        :param augment: if True, the synthetic samples returned by the trainer
            are labelled by the oracle and appended through :meth:`update`.
        :return: the trainer's return value, i.e. the new synthetic samples
            (or None if the trainer produced none).
        """
        new_data = self.trainer(self.substitute, self.data, self.lamb)
        if augment and new_data is not None:
            self.update(new_data)
        return new_data

class Trainer:
    """Callable that trains a substitute model on a sequence of datasets.

    Calling the instance trains one pass over each dataset in order. On the
    last dataset it also returns the Jacobian-based synthetic samples computed
    by :meth:`newdata`.
    """

    def __init__(self,
                 lr, weight_decay,
                 batch_size, shuffle=True, **kwargs):
        """
        :param lr: learning rate for SGD
        :param weight_decay: weight decay for SGD
        :param batch_size: batch_size for the DataLoader
        :param shuffle: shuffle flag for the DataLoader
        :param kwargs: any other DataLoader keyword arguments
        """
        self.kwargs = {"batch_size": batch_size,
                       "shuffle": shuffle}
        self.kwargs.update(kwargs)
        # A loss *instance*: calling the class would construct a module, not a loss.
        self.criterion = nn.CrossEntropyLoss()
        # Factory: self.opti(parameters) -> torch.optim.SGD
        self.opti = self.optim(lr=lr, weight_decay=weight_decay)

    def optim(self, **kwargs):
        """Return a factory that builds ``torch.optim.SGD(parameters, **kwargs)``.

        Optimizer creation is deferred because the substitute model (and hence
        its parameters) is only supplied when the trainer is called.

        :param kwargs: keyword arguments forwarded to ``torch.optim.SGD``
        :return: callable taking ``net.parameters()`` and returning the optimizer
        """
        def build(parameters):
            return torch.optim.SGD(parameters, **kwargs)
        return build

    def dataloader(self, dataset):
        """Wrap ``dataset`` in a DataLoader configured with ``self.kwargs``."""
        return DataLoader(dataset, **self.kwargs)

    @staticmethod
    def calc_jacobian(out, inp):
        """Return d(out)/d(inp), keeping the graph for later backward calls."""
        return torch.autograd.grad(out, inp, retain_graph=True)[0]

    @staticmethod
    def sgn(matrix):
        """Element-wise sign (-1, 0 or 1)."""
        return torch.sign(matrix)

    def newdata(self, outs, inps, labels, lamb):
        """Jacobian-based augmentation: ``x + lamb * sign(d outs[i, labels[i]] / d x)``.

        Works on a copy. Writing in place into ``inps.data`` would silently
        change the inputs that autograd saved for the weight gradient of the
        subsequent ``loss.backward()``.

        :param outs: network outputs for ``inps`` (graph still alive)
        :param inps: input batch with ``requires_grad=True``
        :param labels: label per row, used to pick the output coordinate
        :param lamb: step size
        :return: new samples, same shape as ``inps``, without autograd history
        """
        data = inps.detach().clone()
        for i, label in enumerate(labels):
            # The gradient is w.r.t. the whole batch. Rows other than i only get
            # nonzero entries if the net couples samples (e.g. BatchNorm in training).
            data += lamb * self.sgn(self.calc_jacobian(outs[i, label], inps))
        return data

    def train(self, net, criterion, opti, dataloader, lamb=None,
              update=False):
        """
        One pass of SGD over ``dataloader``.

        :param net: model to train
        :param criterion: loss function
        :param opti: optimizer over ``net``'s parameters
        :param dataloader: yields ``(inputs, labels)`` batches
        :param lamb: lambda for update S (required when ``update`` is True)
        :param update: if True, train will return the new data
        :return: concatenated synthetic samples when ``update`` is True, else None
        :raises ValueError: if ``update`` is True and ``lamb`` is None
        """
        if update and lamb is None:
            raise ValueError("lamb needed when updating")
        new_chunks = []
        for inps, labels in dataloader:
            inps.requires_grad_(True)
            outs = net(inps)
            loss = criterion(outs, labels)

            if update:
                # Must run before backward(), which frees the graph of ``outs``.
                new_chunks.append(self.newdata(outs, inps, labels, lamb))

            opti.zero_grad()
            loss.backward()
            opti.step()
        if update:
            return torch.cat(new_chunks) if new_chunks else torch.tensor([])

    def __call__(self, substitute, data, lamb):
        """Train ``substitute`` on each dataset of ``data`` in order.

        :param substitute: model to train
        :param data: sequence of datasets
        :param lamb: step size for the synthetic samples
        :return: synthetic samples generated during the pass over the last
            dataset; None if ``data`` is empty
        """
        opti = self.opti(substitute.parameters())
        last = len(data) - 1
        for i, item in enumerate(data):
            dataloader = self.dataloader(item)
            if i == last:
                return self.train(substitute, self.criterion,
                                  opti, dataloader, lamb, True)
            self.train(substitute, self.criterion, opti, dataloader)

def quireone(func):  # a decorator, for easy to define optimizer
    """Defer the first positional parameter of ``func``.

    ``quireone(func)(*args, **kwargs)`` returns a one-argument callable ``g``
    with ``g(arg) == func(arg, *args, **kwargs)``. The remaining arguments are
    bound first, and the *first* parameter is supplied later.

    Both wrapper levels carry ``func``'s metadata (name, docstring, ...).

    Caution: on a method, ``self`` arrives inside ``*args`` and is passed
    *after* ``arg``, so the deferred argument lands in the ``self`` slot.
    """
    import functools

    @functools.wraps(func)
    def wrapper1(*args, **kwargs):
        @functools.wraps(func)
        def wrapper2(arg):
            return func(arg, *args, **kwargs)
        return wrapper2
    return wrapper1





"""
Adversarial Examples Improve Image Recognition
"""

import torch
import torch.nn as nn

class Mixturenorm1d(nn.Module):
    """BatchNorm1d with a second, auxiliary set of statistics for adversarial inputs.

    ``norm1`` (main BN) handles clean inputs and everything outside training.
    ``norm2`` (auxiliary BN) is used only when ``rel.adv`` is truthy and ``rel``
    is in training mode.

    :param rel: owner object whose ``adv`` flag and ``training`` mode select the
        branch (e.g. the network that contains this layer). It is stored as a
        plain attribute, not a submodule; see ``__setattr__``.
    :param num_features: forwarded, with any extra ``*args``/``**kwargs``, to
        both ``nn.BatchNorm1d`` layers.
    """

    def __init__(self, rel, num_features: int, *args, **kwargs):
        super(Mixturenorm1d, self).__init__()
        self.norm1 = nn.BatchNorm1d(num_features,
                                    *args, **kwargs)
        self.norm2 = nn.BatchNorm1d(num_features,
                                    *args, **kwargs)
        self.rel = rel

    def forward(self, x):
        # The owner's mode (rel.training), not this layer's, decides the branch.
        if self.rel.adv and self.rel.training:
            return self.norm2(x)
        else:
            return self.norm1(x)

    def __setattr__(self, name, value):
        """
        Store ``rel`` as a plain attribute instead of registering it as a child.

        If ``rel`` is the network that contains this layer, registering it would
        create a module cycle. Recursive traversals such as ``train()`` or
        ``repr()`` would then follow that cycle and raise ``RecursionError``.

        The name is compared with ``==``. ``is`` tests string identity and misses
        equal names that are not the same interned object.

        :param name: attribute name
        :param value: attribute value
        """
        if name == "rel":
            object.__setattr__(self, name, value)
        else:
            super(Mixturenorm1d, self).__setattr__(name, value)

class Mixturenorm2d(nn.Module):
    """BatchNorm2d with a second, auxiliary set of statistics for adversarial inputs.

    ``norm1`` (main BN) handles clean inputs and everything outside training.
    ``norm2`` (auxiliary BN) is used only when ``rel.adv`` is truthy and ``rel``
    is in training mode.

    :param rel: owner object whose ``adv`` flag and ``training`` mode select the
        branch (e.g. the network that contains this layer). It is stored as a
        plain attribute, not a submodule; see ``__setattr__``.
    :param num_features: forwarded, with any extra ``*args``/``**kwargs``, to
        both ``nn.BatchNorm2d`` layers.
    """

    def __init__(self, rel, num_features: int, *args, **kwargs):
        super(Mixturenorm2d, self).__init__()
        self.norm1 = nn.BatchNorm2d(num_features,
                                    *args, **kwargs)
        self.norm2 = nn.BatchNorm2d(num_features,
                                    *args, **kwargs)
        self.rel = rel

    def forward(self, x):
        # The owner's mode (rel.training), not this layer's, decides the branch.
        if self.rel.adv and self.rel.training:
            return self.norm2(x)
        else:
            return self.norm1(x)

    def __setattr__(self, name, value):
        """
        Store ``rel`` as a plain attribute instead of registering it as a child.

        If ``rel`` is the network that contains this layer, registering it would
        create a module cycle that recursive traversals (``train()``,
        ``repr()``) follow until ``RecursionError``.

        The name is compared with ``==`` because ``is`` tests string identity.
        """
        if name == "rel":
            object.__setattr__(self, name, value)
        else:
            super(Mixturenorm2d, self).__setattr__(name, value)

if __name__ == "__main__":

    class Testnet(nn.Module):
        """Toy network showing how Mixturenorm1d reads its owner's ``adv`` flag."""

        def __init__(self):
            super(Testnet, self).__init__()
            self.flag = False
            # Defined up front so the Mixturenorm layer can read it even if
            # ``dense`` is called without going through ``forward``.
            self.adv = False
            self.dense = nn.Sequential(
                nn.Linear(10, 20),
                Mixturenorm1d(self, 20),
                nn.ReLU(),
                nn.Linear(20, 1)
            )

        def forward(self, x, adv=False):
            # The Mixturenorm layer consults this flag to pick its BN branch.
            self.adv = adv
            return self.dense(x)

    x = torch.rand((3, 10))
    test = Testnet()
    out = test(x, True)        # training + adversarial batch -> auxiliary BN (norm2)
    out_clean = test(x)        # training + clean batch -> main BN (norm1)
    test.eval()
    out_eval = test(x, True)   # outside training the main BN is always used
    print(out.shape, out_clean.shape, out_eval.shape)