import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.models import vgg19
from torchvision.datasets import ImageFolder
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
super(ConvBlock, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.conv(x)
class UpscaleBlock(nn.Module):
def __init__(self, in_channels, scale_factor):
super(UpscaleBlock, self).__init__()
self.conv = nn.Conv2d(in_channels, in_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1)
self.pixel_shuffle = nn.PixelShuffle(scale_factor)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.relu(x)
return x
class SuperResolutionModel(nn.Module):
def __init__(self, upscale_factor):
super(SuperResolutionModel, self).__init__()
self.conv1 = ConvBlock(3, 64, kernel_size=9, stride=1, padding=4)
self.conv2 = ConvBlock(64, 32, kernel_size=1, stride=1, padding=0)
self.upscale = UpscaleBlock(32, upscale_factor)
self.conv3 = nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.upscale(x)
x = self.conv3(x)
return x
class SuperResolutionDataset(torch.utils.data.Dataset):
def __init__(self, image_folder, input_transform, target_transform):
self.dataset = ImageFolder(image_folder)
self.input_transform = input_transform
self.target_transform = target_transform
def __getitem__(self, index):
img, _ = self.dataset[index]
target = self.target_transform(img)
input = self.input_transform(target)
return input, target
def __len__(self):
return len(self.dataset)
upscale_factor = 2
model = SuperResolutionModel(upscale_factor).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
input_transform = transforms.Compose([
transforms.Resize((256 // upscale_factor, 256 // upscale_factor), interpolation=TF.InterpolationMode.BICUBIC),
transforms.ToTensor()
])
target_transform = transforms.Compose([
transforms.Resize((256, 256), interpolation=TF.InterpolationMode.BICUBIC),
transforms.ToTensor()
])
train_dataset = SuperResolutionDataset("path/to/train_data", input_transform, target_transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_dataset = SuperResolutionDataset("path/to/val_data", input_transform, target_transform)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
model.eval()
val_loss = 0.0
with torch.no_grad():
for inputs, targets in val_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
val_loss += loss.item()
val_loss /= len(val_loader)
print(f"Validation Loss: {val_loss:.4f}")
model.eval()
val_loss = 0.0
with torch.no_grad():
for inputs, targets in val_loader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
val_loss += loss.item()
val_loss /= len(val_loader)
print(f"Validation Loss: {val_loss:.4f}")
手机扫一扫
移动阅读更方便
你可能感兴趣的文章