|
import os

from torch import optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import trange

# NOTE(review): torch, nn, np, ImageDBData, CARN_V2 and network_to_half are
# not imported explicitly; presumably they come from the star imports below.
from Dataloader import *
from .utils import image_quality
from .utils.cls import CyclicLR
from .utils.prepare_images import *
|
|
|
# Image folders on disk. Neither path is referenced elsewhere in the visible
# script; the loaders below read from database files instead.
train_folder = "./dataset/train"
test_folder = "./dataset/test"

# ----- Training data -----------------------------------------------------
# At most 24 images (max_images=24) from the training table, shuffled each
# epoch and served in batches of 6 by 6 worker processes.
img_dataset = ImageDBData(
    db_file="dataset/images.db",
    db_table="train_images_size_128_noise_1_rgb",
    max_images=24,
)
img_data = DataLoader(dataset=img_dataset, batch_size=6, shuffle=True, num_workers=6)

# Batches per epoch (len of the DataLoader, not of the dataset).
total_batch = len(img_data)
print(len(img_dataset))  # report how many training images were loaded

# ----- Test data ---------------------------------------------------------
# Whole test table (max_images=None), evaluated one image at a time in a
# fixed order.
test_dataset = ImageDBData(
    db_file="dataset/test2.db",
    db_table="test_images_size_128_noise_1_rgb",
    max_images=None,
)
num_test = len(test_dataset)
test_data = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=1)
|
|
|
# Reconstruction loss: mean absolute (L1) error between output and target.
criteria = nn.L1Loss()

# CARN_V2 network configuration. Argument meanings are defined by CARN_V2
# (not visible here); e.g. scale=2 is presumably the upscaling factor and
# SEBlock=True presumably enables squeeze-and-excitation blocks.
model = CARN_V2(
    color_channels=3, mid_channels=64,
    conv=nn.Conv2d, single_conv_size=3, single_conv_group=1,
    scale=2, activation=nn.LeakyReLU(0.1),
    SEBlock=True, repeat_blocks=3, atrous=(1, 1, 1),
)
model.total_parameters()  # return value unused; presumably reports the parameter count

# Convert to half precision via the network_to_half helper (defined outside
# this file), move to the GPU, then resume from the saved model weights.
model = network_to_half(model).cuda()
model.load_state_dict(torch.load("CARN_model_checkpoint.pt"))
|
|
|
# ----- Optimizer ---------------------------------------------------------
# Adam with the AMSGrad variant and a small weight decay.
learning_rate = 1e-4
weight_decay = 1e-6
optimizer = optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay, amsgrad=True
)

# ----- Learning-rate schedule --------------------------------------------
# last_iter is the batch position the schedule resumes from; the training
# loop saves scheduler.last_batch_iteration to CARN_scheduler_last_iter.pt.
# -1 presumably means "start fresh".
last_iter = -1
# NOTE(review): base_lr == max_lr, so (assuming the usual triangular CLR
# formula) the learning rate stays at 1e-4 for the whole run.
scheduler = CyclicLR(
    optimizer,
    base_lr=learning_rate,
    max_lr=learning_rate,
    step_size=3 * total_batch,  # half-cycle length: three epochs of batches
    mode="triangular",
    last_batch_iteration=last_iter,
)

# Per-batch training metrics and per-epoch test metrics.
train_loss, train_ssim, train_psnr = [], [], []
test_loss, test_ssim, test_psnr = [], [], []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ----- Training / evaluation loop ------------------------------------------
# `counter` counts optimizer steps across all epochs; `iteration` is the
# number of epochs (full passes over `img_data`).
counter = 0
iteration = 2


def _save_training_state():
    """Overwrite the on-disk training checkpoint files.

    Writes the model and optimizer state dicts, the running per-batch train
    metrics and the CyclicLR batch position. Saving all of them together
    keeps every file on disk from the same training step.
    """
    torch.save(model.state_dict(), "CARN_model_checkpoint.pt")
    torch.save(optimizer.state_dict(), "CARN_adam_checkpoint.pt")
    torch.save(train_loss, "train_loss.pt")
    torch.save(train_ssim, "train_ssim.pt")
    torch.save(train_psnr, "train_psnr.pt")
    torch.save(scheduler.last_batch_iteration, "CARN_scheduler_last_iter.pt")


# save_image() writes through PIL, which does not create missing
# directories; create the preview folder once before the first epoch.
os.makedirs("check_test_imgs", exist_ok=True)

ibar = trange(
    iteration,
    ascii=True,
    maxinterval=1,
    postfix={"avg_loss": 0, "train_ssim": 0, "test_ssim": 0},
)
for i in ibar:
    # Re-enable training-mode behaviour after the previous epoch's eval pass.
    model.train()

    for index, batch in enumerate(img_data):
        # The CyclicLR helper advances its schedule once per batch, before the step.
        scheduler.batch_step()
        lr_img, hr_img = batch
        # The model went through network_to_half, so feed fp16 inputs.
        lr_img = lr_img.cuda().half()
        hr_img = hr_img.cuda()

        optimizer.zero_grad()
        # Call the module (not .forward) so registered hooks run; compute
        # the loss in fp32.
        outputs = model(lr_img).float()
        loss = criteria(outputs, hr_img)

        # torch.optim.Adam has no .backward(); that method belongs to
        # loss-scaling wrappers such as apex's FP16_Optimizer, which this
        # script never creates. NOTE(review): without loss scaling, fp16
        # gradients may underflow; consider torch.cuda.amp if that happens.
        loss.backward()
        optimizer.step()
        counter += 1

        # Metrics only: detach so they are not recorded on the autograd graph.
        detached = outputs.detach()
        loss_value = loss.item()
        ssim = image_quality.msssim(detached, hr_img).item()
        psnr = image_quality.psnr(detached, hr_img).item()

        ibar.set_postfix(
            ratio=index / total_batch,
            loss=loss_value,
            ssim=ssim,
            batch=index,
            psnr=psnr,
            lr=scheduler.current_lr,
        )
        train_loss.append(loss_value)
        train_ssim.append(ssim)
        train_psnr.append(psnr)

        # Mid-epoch checkpoint after every 500th optimizer step (counter has
        # already been incremented for this step).
        if counter % 500 == 0:
            _save_training_state()

    # End of epoch: checkpoint everything, including the scheduler position.
    _save_training_state()

    # ----- Evaluation on the held-out test set -----
    model.eval()
    with torch.no_grad():
        ssim = []
        batch_loss = []
        psnr = []
        for index, test_batch in enumerate(test_data):
            lr_img, hr_img = test_batch
            # Same input dtype as training.
            lr_img = lr_img.cuda().half()
            hr_img = hr_img.cuda()

            lr_img_up = model(lr_img).float()
            loss = criteria(lr_img_up, hr_img)

            # Save the first output/target pair of each batch side by side.
            save_image([lr_img_up[0], hr_img[0]], f"check_test_imgs/{index}.png")
            batch_loss.append(loss.item())
            ssim.append(image_quality.msssim(lr_img_up, hr_img).item())
            psnr.append(image_quality.psnr(lr_img_up, hr_img).item())

        # One averaged entry per epoch (np.mean of an empty list yields nan).
        test_ssim.append(np.mean(ssim))
        test_loss.append(np.mean(batch_loss))
        test_psnr.append(np.mean(psnr))

    torch.save(test_loss, "test_loss.pt")
    torch.save(test_ssim, "test_ssim.pt")
    torch.save(test_psnr, "test_psnr.pt")
|
|
|
|
|
|
|
|
|
|