PyTorch构建超分辨率模型——常用模块
阅读原文时间:2023年08月27日阅读:2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.models import vgg19
from torchvision.datasets import ImageFolder


class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU bundled as one module.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution (and normalised
            by the batch-norm layer).
        kernel_size: Passed unchanged to ``nn.Conv2d``.
        stride: Passed unchanged to ``nn.Conv2d``.
        padding: Passed unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # The container name and layer order determine the state_dict keys
        # (conv.0.* for the convolution, conv.1.* for batch norm), so keep both.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through convolution, batch norm and ReLU in order."""
        out = self.conv(x)
        return out


class UpscaleBlock(nn.Module):
    """Sub-pixel upsampling stage: Conv2d -> PixelShuffle -> ReLU.

    The 3x3 convolution (stride 1, padding 1, so spatial size is preserved)
    expands the channels to ``in_channels * scale_factor ** 2``.
    ``nn.PixelShuffle`` then rearranges those channels into a map with the
    original ``in_channels`` and ``scale_factor`` times the height and width.

    Args:
        in_channels: Channel count of both the input and the output.
        scale_factor: Integer spatial upscaling factor.
    """

    def __init__(self, in_channels, scale_factor):
        super().__init__()
        expanded_channels = in_channels * scale_factor ** 2
        self.conv = nn.Conv2d(
            in_channels, expanded_channels, kernel_size=3, stride=1, padding=1
        )
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``x`` upscaled spatially by ``scale_factor``."""
        return self.relu(self.pixel_shuffle(self.conv(x)))


class SuperResolutionModel(nn.Module):
    """Upscale 3-channel images by ``upscale_factor``.

    Stages, applied in order:
        1. ConvBlock 3 -> 64 channels, 9x9 kernel, padding 4
        2. ConvBlock 64 -> 32 channels, 1x1 kernel, padding 0
        3. UpscaleBlock on 32 channels, x ``upscale_factor`` in H and W
        4. Conv2d 32 -> 3 channels, 9x9 kernel, padding 4 (no activation)

    Every stage except the UpscaleBlock uses stride 1 with "same" padding,
    so only stage 3 changes the spatial size.

    Args:
        upscale_factor: Integer spatial upscaling factor.
    """

    def __init__(self, upscale_factor):
        super().__init__()
        # Attribute names and registration order define the state_dict keys.
        self.conv1 = ConvBlock(3, 64, kernel_size=9, stride=1, padding=4)
        self.conv2 = ConvBlock(64, 32, kernel_size=1, stride=1, padding=0)
        self.upscale = UpscaleBlock(32, upscale_factor)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Return the upscaled reconstruction of ``x``."""
        for stage in (self.conv1, self.conv2, self.upscale, self.conv3):
            x = stage(x)
        return x


class SuperResolutionDataset(torch.utils.data.Dataset):
    """(low-resolution input, high-resolution target) pairs from an image folder.

    Images are loaded with ``ImageFolder``; its class labels are discarded.
    Both transforms receive the same raw image returned by the folder
    dataset, so each of them may end with ``transforms.ToTensor()``.

    Args:
        image_folder: Root directory in the layout ``ImageFolder`` expects.
        input_transform: Callable mapping a raw image to the model input.
        target_transform: Callable mapping a raw image to the target.
    """

    def __init__(self, image_folder, input_transform, target_transform):
        self.dataset = ImageFolder(image_folder)
        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(low_res_input, target)`` for the image at ``index``."""
        img, _ = self.dataset[index]
        target = self.target_transform(img)
        # Give input_transform the raw image, not ``target``. When
        # target_transform ends in ToTensor(), ``target`` is already a tensor,
        # and a second ToTensor() inside input_transform raises TypeError
        # because it only accepts PIL images / ndarrays.
        # NOTE(review): input and target stay pixel-aligned only while both
        # transforms are deterministic (no random crops/flips).
        low_res = self.input_transform(img)
        return low_res, target

    def __len__(self):
        """Return the number of images found in the folder."""
        return len(self.dataset)


upscale_factor = 2
# `device` is referenced by the model placement and the validation loop but
# was never defined; use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SuperResolutionModel(upscale_factor).to(device)
criterion = nn.MSELoss()
# NOTE(review): no training loop consuming this optimizer is visible in this script.
optimizer = optim.Adam(model.parameters(), lr=1e-4)


# Model input: bicubic resize to (256 // upscale_factor) x (256 // upscale_factor),
# then conversion to a float tensor.
input_transform = transforms.Compose(
    [
        transforms.Resize(
            (256 // upscale_factor, 256 // upscale_factor),
            interpolation=TF.InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
    ]
)

# Target: bicubic resize to 256 x 256, then conversion to a float tensor.
target_transform = transforms.Compose(
    [
        transforms.Resize((256, 256), interpolation=TF.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
    ]
)


# Placeholder paths: point these at ImageFolder-style directories
# (images inside sub-folders; the sub-folder labels are not used).
# NOTE(review): with num_workers > 0 and the "spawn" start method (the default
# on Windows/macOS), worker processes re-import this script, so the
# module-level code should live under `if __name__ == "__main__":`.
train_dataset = SuperResolutionDataset("path/to/train_data", input_transform, target_transform)
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,  # new sample order every epoch
    num_workers=4,
)

val_dataset = SuperResolutionDataset("path/to/val_data", input_transform, target_transform)
val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,  # fixed order for reproducible validation
    num_workers=4,
)


def evaluate(model, loader, criterion, device):
    """Return the mean per-sample loss of ``model`` over ``loader``.

    The model is put in eval mode and run under ``torch.no_grad()``. Each
    batch loss is weighted by its batch size, so a smaller final batch does
    not skew the average. This assumes ``criterion`` averages over the batch,
    as ``nn.MSELoss()`` does by default.

    Args:
        model: Network to evaluate.
        loader: Iterable of ``(inputs, targets)`` tensor batches.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The weighted mean loss, or ``float("nan")`` if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            batch_size = inputs.size(0)
            total_loss += criterion(outputs, targets).item() * batch_size
            num_samples += batch_size
    # Avoid ZeroDivisionError on an empty loader.
    return total_loss / num_samples if num_samples else float("nan")


# A single validation pass; repeating it on the same model and loader adds nothing.
# NOTE(review): no training step is visible before this point in the script.
val_loss = evaluate(model, val_loader, criterion, device)
print(f"Validation Loss: {val_loss:.4f}")

手机扫一扫

移动阅读更方便

阿里云服务器
腾讯云服务器
七牛云服务器