"""Run a pretrained Swin2-MoSE model on an ``opensr_test`` sample and export GeoTIFFs.

The script executes top to bottom at import time:

1. load the model configuration and weights from ``swin2_mose/weights/``,
2. run the model on a single sample of the ``"venus"`` dataset,
3. call ``benchmark.create_geotiff`` with the model and its runner function.
"""

import benchmark

import matplotlib.pyplot as plt

import opensr_test

from swin2_mose.utils import load_swin2_mose, load_config, run_swin2_mose

# Device the model is moved to and that is passed on to the runner.
# NOTE(review): hard-coded; presumably requires a CUDA-capable machine -- confirm.
device = "cuda"
path = 'swin2_mose/weights/config-70.yml'
model_weights = "swin2_mose/weights/model-70.pt"

# Configuration is loaded first because the model loader takes it as input.
cfg = load_config(path)

model = load_swin2_mose(model_weights, cfg)
model.to(device)
# NOTE(review): presumably a torch ``nn.Module``, so this switches it to
# evaluation mode -- confirm against ``load_swin2_mose``.
model.eval()

# Position of the single sample (within both dataset entries) to run on.
index = 2
dataset = opensr_test.load("venus")
# NOTE(review): judging by the variable names, "L2A" holds the low-resolution
# inputs and "HRharm" the high-resolution references -- confirm in opensr_test.
lr_dataset, hr_dataset = dataset["L2A"], dataset["HRharm"]
results = run_swin2_mose(model, lr_dataset[index], hr_dataset[index], device=device)

# NOTE(review): "all" and "swin2_mose/" are presumably a dataset selector and an
# output location -- verify against ``benchmark.create_geotiff``.
benchmark.create_geotiff(model, run_swin2_mose, "all", "swin2_mose/")