目录结构
dogsData.py
import json
import torch
import os, glob
import random, csv
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode
class Dogs(Dataset):
    """Dog-image dataset laid out as ``root/<class_name>/<image file>``.

    Scans the class subdirectories of *root*, builds a deterministic
    name->label mapping, caches the shuffled file list in ``images.csv``,
    and splits it 80/10/10 into train/val/test according to *mode*.

    Args:
        root:   dataset root directory containing one subdirectory per class.
        resize: target square crop size fed to the model (e.g. 224).
        mode:   'train' (first 80%), 'val' (next 10%), anything else -> test
                split (last 10%), matching the original behaviour.
    """

    def __init__(self, root, resize, mode):
        super().__init__()
        self.root = root
        self.resize = resize

        # Map class-directory name -> integer label; sorted() keeps the
        # assignment stable across runs and machines.
        self.nameLable = {}
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.nameLable[name] = len(self.nameLable)

        # Persist the mapping once so labels stay interpretable later.
        if not os.path.exists(os.path.join(self.root, 'label.txt')):
            with open(os.path.join(self.root, 'label.txt'), 'w', encoding='utf-8') as f:
                f.write(json.dumps(self.nameLable, ensure_ascii=False))

        self.images, self.labels = self.load_csv('images.csv')

        if mode == 'train':    # first 80%
            self.images = self.images[:int(0.8 * len(self.images))]
            self.labels = self.labels[:int(0.8 * len(self.labels))]
        elif mode == 'val':    # next 10%
            self.images = self.images[int(0.8 * len(self.images)):int(0.9 * len(self.images))]
            self.labels = self.labels[int(0.8 * len(self.labels)):int(0.9 * len(self.labels))]
        else:                  # last 10% — any other mode string means 'test'
            self.images = self.images[int(0.9 * len(self.images)):]
            self.labels = self.labels[int(0.9 * len(self.labels)):]

        # Build the transform pipeline ONCE here instead of on every
        # __getitem__ call (the original recreated the Compose per sample).
        # Random augmentations still resample per call.
        self.tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def load_csv(self, filename):
        """Return ``(image_paths, labels)``; build and cache the csv on first use."""
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.nameLable.keys():
                # one loop over the accepted extensions instead of three copies
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # the parent directory name is the class name
                    name = img.split(os.sep)[-2]
                    writer.writerow([img, self.nameLable[name]])
            print('csv write successful')

        images, labels = [], []
        with open(csv_path) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def denormalize(self, x_hat):
        """Invert ImageNet normalization: ``x = x_hat * std + mean``.

        Expects ``x_hat`` shaped [c, h, w] (mean/std are reshaped to
        [3, 1, 1] so they broadcast over the spatial dims).
        """
        mean = torch.tensor([0.485, 0.456, 0.406]).unsqueeze(1).unsqueeze(1)
        std = torch.tensor([0.229, 0.224, 0.225]).unsqueeze(1).unsqueeze(1)
        return x_hat * std + mean

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # path string -> augmented, normalized tensor
        img, label = self.images[idx], self.labels[idx]
        img = self.tf(img)
        label = torch.tensor(label)
        return img, label
def main():
    """Quick sanity check: load the dataset and report loader/label info.

    Requires a running visdom server (``python -m visdom.server``).
    """
    import visdom
    import time

    viz = visdom.Visdom()

    # Generic approach: custom Dataset + DataLoader.
    dataset = Dogs('Images_Data_Dog', 224, 'train')
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8)
    print(len(loader))
    print(dataset.nameLable)

    # Uncomment to stream batches to visdom:
    # for x, y in loader:
    #     viz.images(dataset.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
    #     viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
    #     time.sleep(10)

    # Alternative approach: torchvision.datasets.ImageFolder with a plain
    # transforms.Compose — works when no custom csv caching is needed.


if __name__ == '__main__':
    main()
utils.py
import torch
from torch import nn
class Flatten(nn.Module):
    """Flatten all dimensions except the batch dim: [b, *dims] -> [b, prod(dims)].

    Used to bridge a conv feature map (e.g. [b, 512, 1, 1]) into a Linear layer.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Let view() infer the flattened size directly instead of computing
        # it per call via torch.prod(torch.tensor(...)).item(); same result,
        # no tensor allocation / host sync on every forward.
        return x.view(x.size(0), -1)
train.py
import os
import sys
base_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(base_path)
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_path)
import torch
import visdom
from torch import optim, nn
import torchvision
from torch.utils.data import DataLoader
from dogs_train.utils import Flatten
from dogsData import Dogs
from torchvision.models import resnet18
viz = visdom.Visdom()

# Hyperparameters.
batchsz = 32
lr = 1e-3
epochs = 20

# Fall back to CPU when no GPU is present instead of crashing at the first
# .to(device) call (the original hard-coded 'cuda').
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(1234)  # reproducible initialization

# 80/10/10 split is handled inside Dogs via the mode argument.
train_db = Dogs('Images_Data_Dog', 224, mode='train')
val_db = Dogs('Images_Data_Dog', 224, mode='val')
test_db = Dogs('Images_Data_Dog', 224, mode='test')

train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
def evalute(model, loader):
    """Return top-1 classification accuracy of *model* over *loader*.

    Fix: the original evaluated with the model left in training mode, so
    dropout stayed active and BatchNorm used batch statistics — this
    switches to eval mode for the measurement and restores the previous
    mode afterwards. Gradients are disabled for the whole loop.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = len(loader.dataset)
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            pred = model(x).argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
    if was_training:
        model.train()
    return correct / total
def main():
    """Fine-tune an ImageNet-pretrained ResNet-18 on the 27-class dog dataset.

    Keeps every ResNet layer except the final fc, flattens the 512-d
    feature, and trains a fresh Linear head. Tracks loss / val accuracy in
    visdom, checkpoints the best model, and reports test accuracy.
    """
    # NOTE(review): pretrained=True is deprecated in newer torchvision
    # (use weights=ResNet18_Weights.DEFAULT there); kept for compatibility
    # with the torchvision version this file targets.
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(
        *list(trained_model.children())[:-1],  # feature extractor -> [b, 512, 1, 1]
        Flatten(),                             # [b, 512, 1, 1] => [b, 512]
        nn.Linear(512, 27),                    # 27 dog classes
    ).to(device)

    # Sanity-check the forward shape before training.
    x = torch.randn(2, 3, 224, 224).to(device)
    print(model(x).shape)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        # Make sure dropout/BatchNorm are in training mode each epoch
        # (evaluation code may have toggled the model to eval mode).
        model.train()
        for step, (x, y) in enumerate(train_loader):
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        # Validate every other epoch; checkpoint on improvement.
        if epoch % 2 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(model.state_dict(), 'best.mdl')
            viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc', best_acc, 'best epoch', best_epoch)

    # Restore the best checkpoint before the final test evaluation.
    model.load_state_dict(torch.load('best.mdl'))
    print('loader from ckpt')

    test_acc = evalute(model, test_loader)
    print(test_acc)


if __name__ == '__main__':
    main()
手机扫一扫
移动阅读更方便
你可能感兴趣的文章