csaybar commited on
Commit
689312e
·
verified ·
1 Parent(s): 6c08128

Upload 4 files

Browse files
opensr_model/__pycache__/utils.cpython-310.pyc ADDED
Binary file (1.5 kB). View file
 
opensr_model/run.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import opensr_test
2
+ import matplotlib.pyplot as plt
3
+ from utils import create_opensr_model, run_opensr_model
4
+
5
+ # Load the model
6
+ model = create_opensr_model(device="cpu")
7
+
8
+ # Load the dataset
9
+ dataset = opensr_test.load("naip")
10
+ lr_dataset, hr_dataset = dataset["L2A"], dataset["HRharm"]
11
+
12
+ # Run the model
13
+ results = run_opensr_model(
14
+ model=model,
15
+ lr=lr_dataset[7],
16
+ hr=hr_dataset[7],
17
+ device="cpu"
18
+ )
19
+
20
+ # Display the results
21
+ fig, ax = plt.subplots(1, 3, figsize=(10, 5))
22
+ ax[0].imshow(results["lr"].transpose(1, 2, 0)/3000)
23
+ ax[0].set_title("LR")
24
+ ax[0].axis("off")
25
+ ax[1].imshow(results["sr"].transpose(1, 2, 0)/3000)
26
+ ax[1].set_title("SR")
27
+ ax[1].axis("off")
28
+ ax[2].imshow(results["hr"].transpose(1, 2, 0) / 3000)
29
+ ax[2].set_title("HR")
30
+ plt.show()
opensr_model/utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import opensr_model
4
+ from typing import Union
5
+
6
+ def create_opensr_model(
7
+ device: Union[str, torch.device] = "cpu"
8
+ ) -> opensr_model:
9
+ """ Create the super image model
10
+ Returns:
11
+ HanModel: The super image model
12
+ """
13
+ model = opensr_model.SRLatentDiffusion(device=device)
14
+ model.load_pretrained("./weights/opensr_10m_v4_v5.ckpt")
15
+ model.eval()
16
+ return model
17
+
18
+
19
+ def run_opensr_model(
20
+ model: opensr_model,
21
+ lr: np.ndarray,
22
+ hr: np.ndarray,
23
+ device: Union[str, torch.device] = "cpu"
24
+ ) -> dict:
25
+ # Convert the input to torch tensors
26
+ lr_img = torch.from_numpy(lr[[3, 2, 1, 7]] / 10000).to(device).float()
27
+ hr_img = hr[0:3]
28
+
29
+ if lr_img.shape[1] == 121:
30
+ # add padding
31
+ lr_img = torch.nn.functional.pad(
32
+ lr_img[None],
33
+ pad=(3, 4, 3, 4),
34
+ mode='reflect'
35
+ ).squeeze()
36
+
37
+ # Run the model
38
+ with torch.no_grad():
39
+ sr_img = model(lr_img[None]).squeeze()
40
+
41
+ # take out padding
42
+ lr_img = lr_img[:, 3:-4, 3:-4]
43
+ sr_img = sr_img[:, 3*4:-4*4, 3*4:-4*4]
44
+ else:
45
+ # Run the model
46
+ with torch.no_grad():
47
+ sr_img = model(lr_img[None]).squeeze()
48
+
49
+ # Convert the output to numpy
50
+ lr_img = (lr_img.cpu().numpy()[0:3] * 10000).astype(np.uint16)
51
+ sr_img = (sr_img.cpu().numpy()[0:3] * 10000).astype(np.uint16)
52
+ hr_img = hr_img
53
+
54
+ # Return the results
55
+ return {
56
+ "lr": lr_img,
57
+ "sr": sr_img,
58
+ "hr": hr_img
59
+ }
opensr_model/weights/opensr_10m_v4_v5.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee86e546d7ecb2aa564c4f605d6176d9d31a1cf8e4ea0c6877e6d2e88f0222cd
3
+ size 2109942091