|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import os |
|
|
|
import cv2 |
|
import torch |
|
from torch import nn |
|
|
|
import imgproc |
|
import model |
|
from utils import load_state_dict |
|
|
|
# Sorted list of public, lower-case callables exported by the `model` module
# (dunder names are skipped). These are the names `build_model` can look up.
model_names = sorted(
    attr_name
    for attr_name, attr_value in model.__dict__.items()
    if name_ok(attr_name) and callable(attr_value)
) if False else sorted(
    attr_name
    for attr_name, attr_value in model.__dict__.items()
    if attr_name.islower() and not attr_name.startswith("__") and callable(attr_value)
)
|
|
|
|
|
def choice_device(device_type: str) -> torch.device:
    """Map a device-type string to a ``torch.device``.

    Args:
        device_type: ``"cuda"`` selects the first CUDA device (index 0); any
            other value selects the CPU.

    Returns:
        The corresponding ``torch.device``.
    """
    return torch.device("cuda", 0) if device_type == "cuda" else torch.device("cpu")
|
|
|
|
|
def build_model(model_arch_name: str, device: torch.device) -> nn.Module:
    """Instantiate a model factory from the ``model`` module and move it to ``device``.

    The factory is looked up directly in ``model.__dict__``, so an unknown
    name raises ``KeyError``. It is always called with 3 input/output
    channels, 64 feature channels and ``num_rcb=16``.

    Args:
        model_arch_name: Name of a callable defined in the ``model`` module.
        device: Device the constructed network is moved to.

    Returns:
        The constructed network, placed on ``device``.
    """
    factory = model.__dict__[model_arch_name]
    network = factory(in_channels=3, out_channels=3, channels=64, num_rcb=16)
    return network.to(device=device)
|
|
|
|
|
def main(args):
    """Super-resolve a single image with a pretrained model and save the result.

    Builds the requested architecture, loads its weights, runs it on the image
    at ``args.inputs_path`` and writes the output to ``args.output_path``.

    Args:
        args: Parsed command-line namespace. Reads ``device_type``,
            ``model_arch_name``, ``model_weights_path``, ``inputs_path`` and
            ``output_path``.

    Raises:
        FileNotFoundError: If ``args.inputs_path`` is not an existing file.
        OSError: If OpenCV reports that the output image could not be written.
    """
    # Validate the input before paying for model construction and weight
    # loading; a missing file would otherwise only fail later, inside the
    # image-loading helper, with a less obvious error.
    if not os.path.isfile(args.inputs_path):
        raise FileNotFoundError(f"Input image `{os.path.abspath(args.inputs_path)}` does not exist.")

    device = choice_device(args.device_type)

    # Initialize the model.
    sr_model = build_model(args.model_arch_name, device)
    print(f"Build `{args.model_arch_name}` model successfully.")

    # Load the pretrained weights into the freshly built model.
    sr_model = load_state_dict(sr_model, args.model_weights_path)
    print(f"Load `{args.model_arch_name}` model weights `{os.path.abspath(args.model_weights_path)}` successfully.")

    # Put the model in evaluation mode before inference.
    sr_model.eval()

    lr_tensor = imgproc.preprocess_one_image(args.inputs_path, device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        sr_tensor = sr_model(lr_tensor)

    # NOTE(review): the meaning of the two positional `False` flags is defined
    # by imgproc.tensor_to_image — confirm there.
    sr_image = imgproc.tensor_to_image(sr_tensor, False, False)
    # The code assumes tensor_to_image yields RGB; OpenCV writes BGR, hence the swap.
    sr_image = cv2.cvtColor(sr_image, cv2.COLOR_RGB2BGR)

    # cv2.imwrite neither creates missing directories nor raises on an I/O
    # failure; it only returns False. Create the parent directory and check
    # the result so a failed save is never reported as success.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    if not cv2.imwrite(args.output_path, sr_image):
        raise OSError(f"Failed to write SR image to `{os.path.abspath(args.output_path)}`.")

    print(f"SR image save to `{args.output_path}`")
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Using the model generator super-resolution images.")
    # Restrict the architecture to callables actually exported by `model`, so a
    # typo yields a clear argparse error instead of a KeyError in build_model.
    parser.add_argument("--model_arch_name",
                        type=str,
                        default="srresnet_x4",
                        choices=model_names,
                        help="Model architecture name (factory defined in the `model` module).")
    parser.add_argument("--inputs_path",
                        type=str,
                        default="./figure/comic_lr.png",
                        help="Low-resolution image path.")
    parser.add_argument("--output_path",
                        type=str,
                        default="./figure/comic_sr.png",
                        help="Super-resolution image path.")
    parser.add_argument("--model_weights_path",
                        type=str,
                        default="./results/pretrained_models/SRGAN_x4-ImageNet-8c4a7569.pth.tar",
                        help="Model weights file path.")
    parser.add_argument("--device_type",
                        type=str,
                        default="cpu",
                        choices=["cpu", "cuda"],
                        help="Device to run inference on ('cuda' uses GPU index 0).")
    args = parser.parse_args()

    main(args)
|
|