Training resnet18 on a custom dataset

Directory structure
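A plausible layout, assuming the project folder is named dogs_train (matching the import in train.py) and the dataset root Images_Data_Dog holds one subfolder per class (the breed names below are placeholders):

dogs_train/
├── dogsData.py
├── utils.py
├── train.py
└── Images_Data_Dog/
    ├── breed_a/          # *.jpg / *.jpeg / *.png images of one class
    ├── breed_b/
    ├── ...
    ├── label.txt         # written by Dogs.__init__ (name -> label mapping)
    └── images.csv        # written by Dogs.load_csv (path, label pairs)

label.txt and images.csv are generated on the first run, so only the image folders need to exist up front.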

dogsData.py

import json

import torch
import os, glob
import random, csv

from PIL import Image
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms
from torchvision.transforms import InterpolationMode

class Dogs(Dataset):

    def __init__(self, root, resize, mode):
        super().__init__()
        self.root = root
        self.resize = resize
        # map each class folder name to an integer label
        self.nameLable = {}
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.nameLable[name] = len(self.nameLable.keys())

        # persist the name-to-label mapping so it can be reused later (e.g. at inference time)
        if not os.path.exists(os.path.join(self.root, 'label.txt')):
            with open(os.path.join(self.root, 'label.txt'), 'w', encoding='utf-8') as f:
                f.write(json.dumps(self.nameLable, ensure_ascii=False))

        # print(self.nameLable)
        self.images, self.labels = self.load_csv('images.csv')
        # print(self.labels)

        # split 80% / 10% / 10% into train / val / test
        if mode == 'train':
            self.images = self.images[:int(0.8*len(self.images))]
            self.labels = self.labels[:int(0.8*len(self.labels))]
        elif mode == 'val':
            self.images = self.images[int(0.8*len(self.images)):int(0.9*len(self.images))]
            self.labels = self.labels[int(0.8*len(self.labels)):int(0.9*len(self.labels))]
        else:
            self.images = self.images[int(0.9*len(self.images)):]
            self.labels = self.labels[int(0.9*len(self.labels)):]

    def load_csv(self, filename):

        # on the first run, collect all image paths and write (path, label) rows to the csv
        if not os.path.exists(os.path.join(self.root, filename)):
            images = []
            for name in self.nameLable.keys():
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            # print(len(images))

            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    name = img.split(os.sep)[-2]
                    label = self.nameLable[name]
                    writer.writerow([img, label])
            print('csv write successful')

        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)
                images.append(img)
                labels.append(label)

        assert len(images) == len(labels)

        return images, labels

    def denormalize(self, x_hat):
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # x_hat = (x - mean) / std
        # x = x_hat * std + mean
        # x : [c, h, w]
        # mean [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)

        x = x_hat * std + mean
        return x

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # print(idx, len(self.images), len(self.labels))
        img, label = self.images[idx], self.labels[idx]

        # turn the image path string into a tensor
        # print(self.resize, type(self.resize))
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),
            transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        img = tf(img)

        label = torch.tensor(label)

        return img, label

def main():

    import visdom
    import time

    viz = visdom.Visdom()

    # Approach 1: generic custom Dataset
    db = Dogs('Images_Data_Dog', 224, 'train')
    # take a single sample
    # x, y = next(iter(db))
    # print(x.shape, y)
    # viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

    # take one batch
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
    print(len(loader))
    print(db.nameLable)
    # for x, y in loader:
    #     # print(x.shape, y)
    #     viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
    #     viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
    #     time.sleep(10)

    # # Approach 2: torchvision's ImageFolder
    # import torchvision
    # tf = transforms.Compose([
    #     transforms.Resize((64, 64)),
    #     transforms.RandomRotation(15),
    #     transforms.ToTensor(),
    # ])
    # db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
    # loader = DataLoader(db, batch_size=32, shuffle=True)
    # print(len(loader))
    # for x, y in loader:
    #     # print(x.shape, y)
    #     viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
    #     viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
    #     time.sleep(10)


if __name__ == '__main__':
    main()
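Because __init__ dumps the name-to-label mapping to label.txt as JSON, it can be reloaded later to map predicted indices back to breed names. A minimal sketch, assuming the Images_Data_Dog root used above (the breed keys in the comment are placeholders):

import json, os

root = 'Images_Data_Dog'
with open(os.path.join(root, 'label.txt'), encoding='utf-8') as f:
    name2label = json.load(f)            # e.g. {'breed_a': 0, 'breed_b': 1, ...}
label2name = {v: k for k, v in name2label.items()}
print(label2name[0])                     # breed name for class index 0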

utils.py

import torch
from torch import nn

class Flatten(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # flatten everything except the batch dimension
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)
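A quick shape check for Flatten, which collapses the [b, 512, 1, 1] feature map from resnet18 into [b, 512] (recent PyTorch versions ship an equivalent nn.Flatten()); a minimal sketch:

import torch

flatten = Flatten()
x = torch.randn(2, 512, 1, 1)   # shape of resnet18 features before the fc layer
print(flatten(x).shape)         # torch.Size([2, 512])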

train.py

import os
import sys
base_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(base_path)
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_path)
import torch
import visdom
from torch import optim, nn
import torchvision

from torch.utils.data import DataLoader

from dogs_train.utils import Flatten
from dogsData import Dogs

from torchvision.models import resnet18

viz = visdom.Visdom()

batchsz = 32
lr = 1e-3
epochs = 20

device = torch.device('cuda')
torch.manual_seed(1234)

train_db = Dogs('Images_Data_Dog', 224, mode='train')
val_db = Dogs('Images_Data_Dog', 224, mode='val')
test_db = Dogs('Images_Data_Dog', 224, mode='test')

train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)

def evalute(model, loader):
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()
    return correct / total

def main():

    # model = ResNet18(5).to(device)
    # reuse an ImageNet-pretrained resnet18 and replace its final layer
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(trained_model.children())[:-1],
                          Flatten(),  # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 27)
                          ).to(device)

    x = torch.randn(2, 3, 224, 224).to(device)
    print(model(x).shape)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):

        for step, (x, y) in enumerate(train_loader):
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1
        if epoch % 2 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(model.state_dict(), 'best.mdl')

                viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc', best_acc, 'best epoch', best_epoch)

    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt')

    test_acc = evalute(model, test_loader)
    print(test_acc)

if __name__ == '__main__':
    main()
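For single-image inference with the saved weights, the same architecture can be rebuilt and loaded from best.mdl, reusing the normalization from dogsData.py. A minimal sketch, assuming it runs next to train.py and that test.jpg is a placeholder image path:

import json, os, torch
from torch import nn
from torchvision import transforms
from torchvision.models import resnet18
from PIL import Image
from dogs_train.utils import Flatten

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# rebuild the same architecture as in train.py and load the best checkpoint
trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1],
                      Flatten(),
                      nn.Linear(512, 27)).to(device)
model.load_state_dict(torch.load('best.mdl', map_location=device))
model.eval()

# same resize/normalization as the Dogs dataset, minus the random augmentation
tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

with open(os.path.join('Images_Data_Dog', 'label.txt'), encoding='utf-8') as f:
    label2name = {v: k for k, v in json.load(f).items()}

img = tf(Image.open('test.jpg').convert('RGB')).unsqueeze(0).to(device)  # placeholder path
with torch.no_grad():
    pred = model(img).argmax(dim=1).item()
print(label2name[pred])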