[个人总结] PyTorch 中用 checkpoint 恢复训练后，acc 出现上升的原因
阅读原文 | 时间：2023年07月09日 | 阅读：1

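原因是：checkpoint 确实保存了模型参数、优化器状态和 epoch 等字段，但没有保存随机数状态。DataLoader 在 shuffle=True 时，每一轮都用全局随机数生成器打乱数据，而原训练已经走过了 epoch 轮打乱。重新运行脚本恢复训练时，torch.manual_seed 会把随机数状态重置回初始值，于是 train_loader 又从第一轮的打乱顺序（第一个 load_data）开始，而不是接着原训练的下一轮。下面的代码在恢复后先空跑若干轮 train_loader，把随机数状态推进到与原训练一致的位置。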

# -*- coding:utf-8 -*-
import os
import numpy as np
import torch
import cv2
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from matplotlib import pyplot as plt
import os
from PIL import Image
# Allow duplicate OpenMP runtimes to coexist. This is presumably a workaround
# for the "OMP: Error #15 ... already initialized" abort seen on some
# conda/Windows setups. NOTE(review): confirm it is still needed.
os.environ ['KMP_DUPLICATE_LIB_OK'] ='True'
import sys
# Directory two levels above this file. It is appended to sys.path so that the
# project-local `model.lenet` and `tools.my_dataset` imports below resolve,
# which is why it must run before those imports.
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)
# Module-level lists that the visible code never reads or writes. They are
# presumably buffers for feature-map / gradient hooks left over from another
# lesson; verify before removing.
fmap_block = list()
import torch.nn.functional as F
grad_block = list()
# NOTE(review): LeNet is imported but not used in the visible code; the script
# defines its own Net class instead.
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
# Directory containing this script.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Seed the global torch RNG so that weight initialisation and DataLoader
# shuffling are reproducible from run to run.
torch.manual_seed(1)  # set the random seed
# Mapping from string label key to class index (two classes, matching the two
# logits Net produces). Not referenced in the visible code.
rmb_label = {"1": 0, "100": 1}
class Net(nn.Module):
    """LeNet-style CNN producing 2-class logits for 3-channel 32x32 images.

    Pipeline:
    conv(3->6, 5x5) -> ReLU -> 2x2 max-pool
    -> conv(6->16, 5x5) -> ReLU -> 2x2 max-pool
    -> fc(400->120) -> ReLU -> fc(120->84) -> ReLU -> fc(84->2).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        # 16 channels * 5 * 5 spatial = 400 features for a 32x32 input
        # (32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5).
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 2)

    def forward(self, x):
        """Return logits of shape (N, 2) for an input batch of shape (N, 3, H, W).

        H and W must reduce to a 16x5x5 feature map (e.g. 32x32, which is what the
        training transform resizes to). Any other size raises a RuntimeError in
        fc1 instead of being silently reshaped into extra rows.
        """
        x = self.pool1(F.relu(self.conv1(x)))
        # The second stage uses its own pooling layer. The original reused pool1
        # here; both are MaxPool2d(2, 2), so the result is identical, but pool2
        # was otherwise dead code.
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten per sample. Unlike view(-1, 400), this never folds the spatial
        # positions of a wrongly sized input into the batch dimension.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# ---- hyper-parameters ----
MAX_EPOCH, BATCH_SIZE, LR = 10, 16, 0.01
# How often things happen:
# - log_interval: print training stats every this many iterations.
# - val_interval: not referenced in the visible code.
# - checkpoint_interval: in epochs; used only by the commented-out
#   checkpoint-saving code.
log_interval, val_interval, checkpoint_interval = 10, 1, 5

# ============================ step 1/5 data ============================
# Build the shuffled training DataLoader from <script dir>/rmb_split/train.

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
split_dir = os.path.abspath(os.path.join(BASE_DIR, "rmb_split"))
if not os.path.exists(split_dir):
    # FileNotFoundError states precisely what is wrong. It is a subclass of
    # Exception, so existing `except Exception` handlers still catch it.
    raise FileNotFoundError(r"数据 {} 不存在, 回到lesson-06\1_split_dataset.py生成数据".format(split_dir))
train_dir = os.path.join(split_dir, "train")

# Resize to 32x32 (the input size Net's 16*5*5 flatten expects) and convert to a tensor.
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor()
])
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
# shuffle=True with no explicit generator draws from the global torch RNG on
# each pass, so batch order depends on how many passes have been made since
# torch.manual_seed.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
net = Net()
criterion = nn.CrossEntropyLoss()                                                   # loss function
# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # optimizer

# Resume from the checkpoint written by the original run.
# - map_location="cpu" lets a checkpoint saved from GPU tensors load on a
#   CPU-only machine. Nothing in this script moves the model to a GPU, and
#   load_state_dict copies the values into the existing parameters.
# - optimizer.load_state_dict also restores the saved param_groups (lr,
#   momentum), overriding the LR passed above.
# NOTE(review): the path is relative to the current working directory, not BASE_DIR.
checkpointdict = torch.load('./checkpoint4.pkl', map_location="cpu")
net.load_state_dict(checkpointdict["model_state_dict"])
optimizer.load_state_dict(checkpointdict["optimizer_state_dict"])
# Index of the last completed epoch; training resumes at startepoch + 1.
startepoch = checkpointdict["epoch"]
# ============================ step 5/5 training ============================
train_curve = list()
iter_count = 0

# Re-align the shuffle RNG with the uninterrupted run.
#
# train_loader shuffles from the global torch RNG. That RNG is re-seeded when
# this script starts, but the checkpoint does not store the RNG state. Without
# this step, the first resumed epoch would therefore replay epoch 0's shuffle
# order.
#
# Epochs 0..startepoch are already done (the loop below starts at
# startepoch + 1), so we burn exactly startepoch + 1 passes, once, before
# training. The previous version burned 5 passes inside every epoch: only the
# first resumed epoch lined up, and the hard-coded 5 was only correct for
# checkpoint4.
#
# A full pass (rather than just creating an iterator) is used so that any
# random transforms also consume the RNG exactly as they did originally.
#
# NOTE(review): this assumes nothing else consumed the global RNG during the
# original run (e.g. a validation DataLoader); confirm against that script.
for _skipped_epoch in range(startepoch + 1):
    for _skipped_batch in train_loader:
        pass

for epoch in range(startepoch + 1, MAX_EPOCH):
    loss_mean = 0.
    correct = 0.
    total = 0.
    for i, data in enumerate(train_loader):
        iter_count += 1
        # forward
        inputs, labels = data
        outputs = net(inputs)
        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()
        # update weights
        optimizer.step()
        # accumulate classification statistics (running accuracy for this epoch)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        # logging: loss_mean is the mean loss over the last log_interval iterations
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.
    # Checkpoint-saving code from the original run, kept for reference at epoch
    # level. Also storing torch.get_rng_state() would let a resume restore the
    # shuffle RNG with torch.set_rng_state(...) instead of fast-forwarding above.
    # if ((epoch + 1) % checkpoint_interval == 0):
    #     checkpoint = {"model_state_dict": net.state_dict(),
    #                   "optimizer_state_dict": optimizer.state_dict(),
    #                   "epoch": epoch}
    #     path_checkpoint = './checkpoint{}.pkl'.format(epoch)
    #     torch.save(checkpoint, path_checkpoint)
    # if ((epoch + 1) % 5 == 0):
    #     print("退出")
    #     break