DeepLabV3+语义分割实战
阅读原文时间:2023年07月10日阅读:1

DeepLabV3+语义分割实战

语义分割是计算机视觉的一项重要任务,本文使用Jittor框架实现了DeepLabV3+``语义分割模型。

DeepLabV3+论文:https://arxiv.org/pdf/1802.02611.pdf

完整代码:https://github.com/Jittor/deeplab-jittor

1.1 数据准备

VOC2012数据集是目标检测、语义分割等任务常用的数据集之一, 本文使用VOC数据集的2012 trainaug (train + sbd set)作为训练集,2012 val set作为测试集。

VOC数据集中的物体共包括20个前景类别:'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' 和背景类别

最终数据集的文件组织如下。

# 

根目录

|----voc_aug

|    |----datalist

|    |    |----train.txt

|    |    |----val.txt

|    |----images

|    |----annotations

1.2 数据加载

使用jittor.dataset.dataset的基类Dataset可以构造自己的数据集,需要实现__init____getitem__、函数。

  1. __init__: 定义数据路径,这里的data_root需设置为之前设定的 voc_augsplit 为 train val test 之一,表示选择训练集、验证集还是测试集。同时需要调用self.set_attr来指定数据集加载所需的参数batch_sizetotal_lenshuffle
  2. __getitem__: 返回单个item的数据。

import numpy as np

import os

from PIL import Image

import matplotlib.pyplot as plt

from jittor.dataset.dataset import Dataset, dataset_root

import jittor as jt

import os

import os.path as osp

from PIL import Image, ImageOps, ImageFilter

import numpy as np

import scipy.io as sio

import random

def fetch(image_path, label_path):

    

        

    

        

    

def scale(image, label):

    

    

    

    

    

    

    

    

def pad(image, label):

    

    

    

    

    

    

    

def crop(image, label):

    

    

    

    

    

    

    

def normalize(image, label):

    

    

    

    

    

    

    

    

def flip(image, label):

    

        

        

    

class BaseDataset(Dataset):

    

        

        

        

        

        

        

        

        

        

        

        

        

            

        

            

            

            

            

            

            

        

        

        

    

        

class TrainDataset(BaseDataset):

    

        

    

        

        

        

        

        

        

        

        

        

        

        

        

class ValDataset(BaseDataset):

    

        

    

        

        

        

        

        

        

        

        

上图为DeepLabV3+论文给出的网络架构图。本文采用ResNebackbone。输入图像尺寸为513*513

整个网络可以分成 backbone aspp decoder 三个部分。

2.1 backbonb 这里使用最常见的ResNet,作为backbone并且在ResNet的最后两次使用空洞卷积来扩大感受野,其完整定义如下:

import jittor as jt

from jittor import nn

from jittor import Module

from jittor import init

from jittor.contrib import concat, argmax_pool

import time

class Bottleneck(Module):

    

    

        

        

        

        

                               

        

        

        

        

        

        

        

    

        

        

        

        

        

        

        

        

        

        

            

        

        

        

class ResNet(Module):

    

        

        

        

        

            

            

        

            

            

        

            

        

        

        

        

        

        

        

        

        

    

        

        

            

                

                          

                

            

        

        

        

        

            

        

    

        

        

            

                

                          

                

            

        

        

                            

        

        

            

                                

        

    

        

        

        

        

        

        

        

        

        

        

def resnet50(output_stride):

    

    

def resnet101(output_stride):

    

    

2.2 ASPP

即使用不同尺寸的 dilation conv 对 backbone 得到的 feature map 进行卷积,最后 concat 并整合得到新的特征。

import jittor as jt

from jittor import nn

from jittor import Module

from jittor import init

from jittor.contrib import concat

class Single_ASPPModule(Module):

    

        

        

                                            

        

        

    

        

        

        

        

class ASPP(Module):

    

        

        

        

            

        

            

        

            

        

        

        

        

        

                                             

                                             

                                             

        

        

        

        

    

        

        

        

        

        

        

        

        

        

        

        

        

class GlobalPooling (Module):

    

        

    

        

2.3 Decoder:

Decoder 将 ASPP 的特征放大后与 ResNet 的中间特征一起 concat, 得到最后分割所用的特征。

import jittor as jt

from jittor import nn

from jittor import Module

from jittor import init

from jittor.contrib import concat

import time

class Decoder(nn.Module):

    

        

        

        

        

        

        

                                       

                                       

                                       

                                       

                                       

                                       

                                       

                                       

    

        

        

        

        

        

        

        

2.4 完整的模型整合如下: 即将以上部分通过一个类连接起来。

import jittor as jt

from jittor import nn

from jittor import Module

from jittor import init

from jittor.contrib import concat

from decoder import Decoder

from aspp import ASPP

from backbone import resnet50, resnet101

class DeepLab(Module):

    

        

        

        

        

    

        

        

        

        

        

3.1 模型训练参数设定如下:

# Learning parameters

batch_size = 8

learning_rate = 0.005

momentum = 0.9

weight_decay = 1e-4

epochs = 50

3.2 定义模型、优化器、数据加载器。

model = DeepLab(output_stride=16, num_classes=21)

optimizer = nn.SGD(model.parameters(),

                   

                   

                   

train_loader = TrainDataset(data_root='/vocdata/',

                            

                            

                            

val_loader = ValDataset(data_root='/vocdata/',

                        

                        

                        

3.3 模型训练与验证

# lr scheduler

def poly_lr_scheduler(opt, init_lr, iter, epoch, max_iter, max_epoch):

    

    

# train function

def train(model, train_loader, optimizer, epoch, init_lr):

    

    

    

        

        

        

        

        

        

# val function

# we omit evaluator code and you can

def val (model, val_loader, epoch, evaluator):

    

    

    

        

        

        

        

        

        

        

    

    

    

    

    

    

        

    

                FWIoU = {} Best Miou = {}'.format(epoch, mIoU, Acc, Acc_class, FWIoU, best_miou))

3.4 evaluator 写法:使用混淆矩阵计算 Pixel accuracy 和 mIoU。

class Evaluator(object):

    

        

        

    

        

        

    

        

        

        

    

        

                 

                 

                 

        

        

    

        

        

                    

                    

        

        

    

        

        

        

        

        

    

        

        

    

        

3.5 训练入口函数

epochs = 50

evaluator = Evaluator(21)

train_loader = TrainDataset(data_root='/voc/data/path/', split='train', batch_size=8, shuffle=True)

val_loader = ValDataset(data_root='/voc/data/path/', split='val', batch_size=1, shuffle=False)

learning_rate = 0.005

momentum = 0.9

weight_decay = 1e-4

optimizer = nn.SGD(model.parameters(), learning_rate, momentum, weight_decay)

for epoch in range (epochs):

    

    
  1. pytorch-deeplab-xception
  2. Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation