rnwang committed on
Commit
c8bce00
·
1 Parent(s): 002720e

infer demo

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .gitignore +6 -0
  3. Marvelous_Maisel.jpg +0 -0
  4. README.md +1 -1
  5. app.py +144 -0
  6. data.py +205 -0
  7. data/webcam/input/00000.png +0 -0
  8. data/webcam/input/00001.png +0 -0
  9. data/webcam/input/00002.png +0 -0
  10. data/webcam/input/00003.png +0 -0
  11. data/webcam/input/00004.png +0 -0
  12. data/webcam/input/00005.png +0 -0
  13. data/webcam/input/00006.png +0 -0
  14. data/webcam/input/00007.png +0 -0
  15. data/webcam/input/00008.png +0 -0
  16. data/webcam/input/00009.png +0 -0
  17. data/webcam/input/00010.png +0 -0
  18. data/webcam/input/00011.png +0 -0
  19. data/webcam/input/00012.png +0 -0
  20. data/webcam/input/00013.png +0 -0
  21. data/webcam/input/00014.png +0 -0
  22. data/webcam/input/00015.png +0 -0
  23. data/webcam/input/00016.png +0 -0
  24. data/webcam/input/00017.png +0 -0
  25. data/webcam/input/00018.png +0 -0
  26. data/webcam/input/00019.png +0 -0
  27. data/webcam/input/00020.png +0 -0
  28. data/webcam/input/00021.png +0 -0
  29. data/webcam/input/00022.png +0 -0
  30. data/webcam/input/00023.png +0 -0
  31. data/webcam/input/00024.png +0 -0
  32. data/webcam/input/00025.png +0 -0
  33. data/webcam/input/00026.png +0 -0
  34. data/webcam/input/00027.png +0 -0
  35. data/webcam/input/00028.png +0 -0
  36. data/webcam/input/00029.png +0 -0
  37. data/webcam/input/00030.png +0 -0
  38. data/webcam/input/00031.png +0 -0
  39. data/webcam/input/00032.png +0 -0
  40. data/webcam/input/00033.png +0 -0
  41. data/webcam/input/00034.png +0 -0
  42. data/webcam/input/00035.png +0 -0
  43. data/webcam/input/00036.png +0 -0
  44. data/webcam/input/00037.png +0 -0
  45. data/webcam/input/00038.png +0 -0
  46. data/webcam/input/00039.png +0 -0
  47. data/webcam/input/00040.png +0 -0
  48. data/webcam/input/00041.png +0 -0
  49. data/webcam/input/00042.png +0 -0
  50. data/webcam/input/00043.png +0 -0
.gitattributes CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ result/*
2
+ input/*
3
+ output/*
4
+ *.gif
5
+ nc_workspace/*
6
+ flagged/*
Marvelous_Maisel.jpg ADDED
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
  title: BigDL-Nano Inference
3
  emoji: 🌖
4
- colorFrom: pink
5
  colorTo: pink
6
  sdk: gradio
7
  sdk_version: 3.0.13
 
1
  ---
2
  title: BigDL-Nano Inference
3
  emoji: 🌖
4
+ colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
  sdk_version: 3.0.13
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import time
4
+ from data import write_image_tensor, PatchDataModule, prepare_data, image2tensor, tensor2image
5
+ import torch
6
+ from tqdm import tqdm
7
+ from bigdl.nano.pytorch.trainer import Trainer
8
+ from torch.utils.data import DataLoader
9
+ from pathlib import Path
10
+ from torch.utils.data import Dataset
11
+ import datetime
12
+
13
+
14
# Inference in this demo runs on CPU in float32.
device = 'cpu'
dtype = torch.float32
# map_location remaps the stored tensors onto `device` while loading, so the
# checkpoint still loads on a CPU-only host if it was saved from another device.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
generator = torch.load("models/generator.pt", map_location=device)
generator.eval()
generator.to(device, dtype)
# Keyword arguments forwarded to the DataLoader built for each request.
params = {'batch_size': 1,
          'num_workers': 0}
21
+
22
+
23
class ImageDataset(Dataset):
    """Single-entry dataset holding one image converted with ``image2tensor``."""

    def __init__(self, img):
        # The conversion happens once, up front; the list always has one item.
        converted = image2tensor(img)
        self.imgs = [converted]

    def __getitem__(self, idx: int) -> dict:
        # NOTE(review): annotated as ``dict`` although the stored value is
        # whatever ``image2tensor`` returned; annotation kept unchanged.
        return self.imgs[idx]

    def __len__(self) -> int:
        return len(self.imgs)
31
+
32
+
33
# Quantize the generator, calibrating on patches cut from the images
# found under data/webcam.
data_path = Path('data/webcam')
train_image_dd = prepare_data(data_path)
dm = PatchDataModule(
    train_image_dd,
    patch_size=2**6,
    batch_size=2**3,
    patch_num=2**6,
)
train_loader = dm.train_dataloader()
train_loader_iter = iter(train_loader)
quantized_model = Trainer.quantize(
    generator,
    accelerator=None,
    calib_dataloader=train_loader,
)
42
+
43
+
44
def _prepare_input(input_img, max_side=3000, multiple=4):
    """Shrink an oversized image and trim it so both sides divide by ``multiple``.

    The previous code used ``np.resize``, which only refills the flattened
    pixel buffer into the new shape (no resampling) and therefore scrambled
    the picture; it also skipped the shrink entirely when the scaled size
    happened to be a multiple of 4 already.

    :param input_img: numpy image indexed as (rows, cols, channels)
        -- assumes the layout delivered by ``gr.Image``; TODO confirm
    :param max_side: largest allowed side length before downscaling
    :param multiple: both sides of the result are multiples of this value
    :return: contiguous numpy array with the same dtype as ``input_img``
    """
    rows, cols = input_img.shape[:2]
    print(datetime.datetime.now())
    print("input size: ", rows, cols)

    # resize too large image (real bilinear resampling, aspect ratio kept)
    if rows > max_side or cols > max_side:
        ratio = min(max_side / rows, max_side / cols)
        rows = max(1, int(rows * ratio))
        cols = max(1, int(cols * ratio))
        source_dtype = input_img.dtype
        pixels = torch.from_numpy(np.ascontiguousarray(input_img))
        # interpolate expects a float (batch, channels, rows, cols) tensor
        pixels = pixels.permute(2, 0, 1).unsqueeze(0).to(torch.float32)
        pixels = torch.nn.functional.interpolate(
            pixels, size=(rows, cols), mode="bilinear", align_corners=False)
        pixels = pixels[0].permute(1, 2, 0)
        if np.issubdtype(source_dtype, np.integer):
            limits = np.iinfo(source_dtype)
            pixels = pixels.round().clamp(limits.min, limits.max)
        input_img = pixels.numpy().astype(source_dtype)

    # Crop (at most multiple-1 pixels per side) instead of reshaping, so the
    # remaining pixels stay exactly where they were.
    rows -= rows % multiple
    cols -= cols % multiple
    return np.ascontiguousarray(input_img[:rows, :cols])


def _stylize(model, input_img):
    """Run ``model`` on one image and time only the forward pass.

    :param model: callable applied to the batched input tensor
    :param input_img: numpy image, or ``None`` when nothing was uploaded
    :return: ``(output image as numpy array, latency formatted as "x.xxxs")``,
        or ``(None, None)`` when ``input_img`` is ``None``
    """
    if input_img is None:
        return None, None

    dataset = ImageDataset(_prepare_input(input_img))
    loader = DataLoader(dataset, **params)
    result_image, latency = None, None
    with torch.no_grad():
        for inputs in tqdm(loader):
            inputs = inputs.to(device, dtype)
            # Start the clock right before inference so that preprocessing
            # and data loading are not counted.
            start = time.perf_counter()
            outputs = model(inputs)
            latency = "{:.3f}s".format(time.perf_counter() - start)
            result_image = np.array(tensor2image(outputs[0]))
            del inputs
            del outputs
    return result_image, latency


def original_transfer(input_img):
    """Stylize ``input_img`` with the unquantized ``generator``.

    :return: ``(stylized image, latency string)``
    """
    return _stylize(generator, input_img)


def nano_transfer(input_img):
    """Stylize ``input_img`` with ``quantized_model``.

    :return: ``(stylized image, latency string)``
    """
    return _stylize(quantized_model, input_img)
99
+
100
+
101
def clear():
    """Return four ``None`` values, used to blank the four output components."""
    return (None,) * 4
103
+
104
+
105
+ demo = gr.Blocks()
106
+
107
+ with demo:
108
+ gr.Markdown("<h1><center>BigDL-Nano inference demo</center></h1>")
109
+ with gr.Row().style(equal_height=False):
110
+ with gr.Column():
111
+ gr.Markdown('''
112
+ <h2>Overview</h2>
113
+
114
+ BigDL-Nano is a library in [BigDL 2.0](https://github.com/intel-analytics/BigDL) that allows the users to transparently accelerate their deep learning pipelines (including data processing, training and inference) by automatically integrating optimized libraries, best-known configurations, and software optimizations. </p>
115
+
116
+ The video on the right shows how the user can easily enable quantization using BigDL-Nano (with just a couple of lines of code); you may refer to our [CVPR 2022 demo paper](https://arxiv.org/abs/2204.01715) for more details.
117
+ ''')
118
+ with gr.Column():
119
+ gr.Video(value="nano_quantize_api.mp4")
120
+ gr.Markdown('''
121
+ <h2>Demo</h2>
122
+
123
+ This section uses an image stylization example to demonstrate the speedup of the above code when using quantization in BigDL-Nano (about 2~3x inference time speedup). The demo is adapted from the original [FSPBT-Image-Translation code](https://github.com/rnwzd/FSPBT-Image-Translation/blob/master/eval.py).
124
+ ''')
125
+ with gr.Row().style(equal_height=False):
126
+ input_img = gr.Image(label="input image", value="Marvelous_Maisel.jpg", source="upload")
127
+ with gr.Column():
128
+ ori_but = gr.Button("Standard PyTorch Lightning")
129
+ nano_but = gr.Button("BigDL-Nano")
130
+ clear_but = gr.Button("Clear Output")
131
+ with gr.Row().style(equal_height=False):
132
+ with gr.Column():
133
+ ori_time = gr.Text(label="Standard PyTorch Lightning latency")
134
+ ori_image = gr.Image(label="Standard PyTorch Lightning output image")
135
+ with gr.Column():
136
+ nano_time = gr.Text(label="BigDL-Nano latency")
137
+ nano_image = gr.Image(label="BigDL-Nano output image")
138
+
139
+ ori_but.click(original_transfer, inputs=input_img, outputs=[ori_image, ori_time])
140
+ nano_but.click(nano_transfer, inputs=input_img, outputs=[nano_image, nano_time])
141
+ clear_but.click(clear, inputs=None, outputs=[ori_image, ori_time, nano_image, nano_time])
142
+
143
+
144
+ demo.launch(share=True, enable_queue=True)
data.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Dict
2
+ import torch
3
+ from torch.utils.data import Dataset
4
+ import torchvision.transforms.functional as F
5
+ from bigdl.nano.pytorch.vision.transforms import transforms
6
+ import pytorch_lightning as pl
7
+ from collections.abc import Iterable
8
+ # image reader writer
9
+ from pathlib import Path
10
+ from PIL import Image
11
+ from typing import Tuple
12
+
13
+
14
def read_image(filepath: Path, mode: str = None) -> Image:
    """Open the image file at ``filepath`` and return it converted to ``mode``."""
    with open(filepath, 'rb') as handle:
        # Convert while the handle is still open: PIL reads pixel data lazily.
        opened = Image.open(handle)
        converted = opened.convert(mode)
    return converted
18
+
19
+
20
+ image2tensor = transforms.ToTensor()
21
+ tensor2image = transforms.ToPILImage()
22
+
23
+
24
def write_image(image: Image, filepath: Path):
    """Save ``image`` to ``filepath``, creating missing parent folders first."""
    target_dir = filepath.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    image.save(str(filepath))
27
+
28
+
29
def read_image_tensor(filepath: Path, mode: str = 'RGB') -> torch.Tensor:
    """Read the image at ``filepath`` in ``mode`` and convert it with ``image2tensor``."""
    pil_image = read_image(filepath, mode)
    return image2tensor(pil_image)
31
+
32
+
33
def write_image_tensor(input: torch.Tensor, filepath: Path):
    """Convert ``input`` with ``tensor2image`` and save the result at ``filepath``."""
    pil_image = tensor2image(input)
    write_image(pil_image, filepath)
35
+
36
+
37
def get_valid_indices(H: int, W: int, patch_size: int, random_overlap: int = 0):
    """Build the top-left corners of a grid of patches inside an H x W area.

    Corners are spaced ``patch_size`` apart along each axis; when
    ``random_overlap`` is positive every axis position is additionally shifted
    by a random integer drawn from ``[-random_overlap, random_overlap)``.

    :return: integer tensor of shape (number of patches, 2) holding
        (row, column) pairs
    """
    def axis_starts(extent):
        # Start positions along one axis, leaving ``random_overlap`` as margin.
        return torch.arange(random_overlap,
                            extent - patch_size - random_overlap + 1,
                            patch_size)

    row_starts = axis_starts(H)
    col_starts = axis_starts(W)
    if random_overlap > 0:
        # Row jitter is drawn before column jitter, keeping the RNG order.
        row_jitter = torch.randint_like(row_starts, -random_overlap, random_overlap)
        col_jitter = torch.randint_like(col_starts, -random_overlap, random_overlap)
        row_starts += row_jitter
        col_starts += col_jitter
    grid = torch.stack(torch.meshgrid(row_starts, col_starts))
    return grid.view(2, -1).t()
50
+
51
+
52
def cut_patches(input: torch.Tensor, indices: Tuple[Tuple[int, int]], patch_size: int, padding: int = 0):
    """Crop one square patch per entry of ``indices`` and concatenate them.

    Each entry is shifted by ``-padding`` and used as the top-left corner of a
    crop whose side is ``patch_size + 2 * padding``; the crops are joined
    along dimension 0.
    """
    # TODO use slices to get all patches at the same time ?
    side = patch_size + padding * 2
    crops = [
        F.crop(input, *(corner - padding), side, side)
        for corner in indices
    ]
    return torch.cat(crops, dim=0)
64
+
65
+
66
def prepare_data(data_path: Path, read_func: Callable = read_image_tensor) -> Dict:
    """
    Takes a data_path of a folder which contains subfolders with input, target, etc.
    labelled by the same names, and stacks each existing subfolder into one tensor.

    File names are taken from the "target" subfolder (sorted, so the sample
    order is the same on every run) and looked up in "target", "input" and
    "mask"; subfolders that do not exist are skipped.

    :param data_path: Path of the folder containing data
    :param read_func: function that reads data and returns a tensor
    :return: dict with one stacked tensor per existing subfolder, plus
        'name' (file names), 'len' (number of files) and 'H' / 'W'
        (last two dimensions of the last tensor read)
    :raises ValueError: if the "target" subfolder contains no files
    """
    subdir_names = ["target", "input", "mask"]  # ,"helper"

    # checks only files for which there is a target
    # TODO check for images
    name_ls = sorted(
        file.name for file in (data_path / "target").iterdir() if file.is_file())
    if not name_ls:
        # Fail early with a clear message instead of an opaque torch.stack
        # error on an empty list further down.
        raise ValueError(f"no files found in {data_path / 'target'}")

    data_dict = {}
    for sdn in subdir_names:
        sd = data_path / sdn
        if not sd.is_dir():
            continue
        data_ls = []
        for name in name_ls:
            tensor = read_func(sd / name)
            H, W = tensor.shape[-2:]
            data_ls.append(tensor)
        # TODO check that all sizes match
        data_dict[sd.name] = torch.stack(data_ls, dim=0)

    data_dict['name'] = name_ls
    data_dict['len'] = len(name_ls)
    data_dict['H'] = H
    data_dict['W'] = W
    return data_dict
99
+
100
+
101
+ # TODO an image is loaded whenever a patch is needed, this may be a bottleneck
102
class DataDictLoader():
    """Minimal batch iterator over a dict of equally long sequences.

    Iterable values of ``data_dict`` (tensors, lists, ...) are sliced into
    batches; every other value (e.g. the integer 'len') is copied unchanged
    into each batch.
    """

    def __init__(self, data_dict: Dict,
                 batch_size: int = 16,
                 max_length: int = 128,
                 shuffle: bool = False):
        """
        :param data_dict: dict of data; must contain the item count under 'len'
        :param batch_size: number of items per batch
        :param max_length: upper bound on the items served per pass
            (``None`` for no bound)
        :param shuffle: draw a new random order at the start of every pass
        """
        self.batch_size = batch_size
        self.shuffle = shuffle

        self.data_dict = data_dict
        self.dataset_len = data_dict['len']
        self.len = self.dataset_len if max_length is None else min(
            self.dataset_len, max_length)
        # Calculate # batches (ceiling division)
        num_batches, remainder = divmod(self.len, self.batch_size)
        if remainder > 0:
            num_batches += 1
        self.num_batches = num_batches
        # Defined here so next() works even before the first iter() call.
        self.i = 0

    @staticmethod
    def _reorder(value, perm):
        """Apply the index permutation ``perm`` to one dict value."""
        if isinstance(value, (list, tuple)):
            # Python sequences cannot be indexed with a tensor.
            return type(value)(value[j] for j in perm.tolist())
        if isinstance(value, Iterable):
            return value[perm]
        return value

    def __iter__(self):
        if self.shuffle:
            r = torch.randperm(self.dataset_len)
            self.data_dict = {k: self._reorder(v, r)
                              for k, v in self.data_dict.items()}
        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.len:
            raise StopIteration
        # Clamp the final batch so no more than self.len items are served.
        stop = min(self.i + self.batch_size, self.len)
        batch = {k: v[self.i:stop] if isinstance(v, Iterable) else v
                 for k, v in self.data_dict.items()}

        self.i = stop
        return batch

    def __len__(self):
        return self.num_batches
146
+
147
+
148
class PatchDataModule(pl.LightningDataModule):
    """Data module that serves random patches for training and the full
    data dict for validation and testing, all through ``DataDictLoader``."""

    def __init__(self, data_dict,
                 patch_size: int = 2**5,
                 batch_size: int = 2**4,
                 patch_num: int = 2**6):
        super().__init__()
        self.data_dict = data_dict
        self.H = data_dict['H']
        self.W = data_dict['W']
        self.len = data_dict['len']

        self.batch_size = batch_size
        self.patch_size = patch_size
        self.patch_num = patch_num

    def dataloader(self, data_dict, **kwargs):
        """Wrap ``data_dict`` in a ``DataDictLoader`` built with ``kwargs``."""
        return DataDictLoader(data_dict, **kwargs)

    def train_dataloader(self):
        """Shuffled loader over freshly cut patches, capped at ``patch_num`` items."""
        return self.dataloader(
            self.cut_patches(),
            batch_size=self.batch_size,
            shuffle=True,
            max_length=self.patch_num,
        )

    def val_dataloader(self):
        """Loader over the full data dict, one item per batch."""
        return self.dataloader(self.data_dict, batch_size=1)

    def test_dataloader(self):
        """Loader over the full data dict with the loader's default settings."""
        return self.dataloader(self.data_dict)  # TODO batch size

    def cut_patches(self):
        """Cut every tensor of the data dict into patches and keep those whose
        mean mask value exceeds the threshold (an all-ones mask is used when
        the dict has no 'mask' entry)."""
        # TODO cycle once
        overlap = self.patch_size // 4
        patch_indices = get_valid_indices(
            self.H, self.W, self.patch_size, overlap)

        dd = {}
        for key, value in self.data_dict.items():
            if isinstance(value, torch.Tensor):
                dd[key] = cut_patches(value, patch_indices, self.patch_size)

        threshold = 0.1
        mask = dd.get('mask', torch.ones_like(dd['input']))
        mask_p = torch.mean(mask, dim=(-1, -2, -3))
        masked_idx = (mask_p > threshold).nonzero(as_tuple=True)[0]

        dd = {key: value[masked_idx] for key, value in dd.items()}
        dd['len'] = len(masked_idx)
        dd['H'] = dd['W'] = self.patch_size

        return dd
194
+
195
+
196
class ImageDataset(Dataset):
    """Dataset that reads one image file per item.

    :param file_paths: indexable collection of ``Path`` objects
    :param read_func: function used to load a file; it was previously
        accepted but ignored, now it is actually applied
    """

    def __init__(self, file_paths: Iterable, read_func: Callable = read_image_tensor):
        self.file_paths = file_paths
        self.read_func = read_func

    def __getitem__(self, idx: int) -> tuple:
        """Return ``(loaded data, file name)`` for the file at position ``idx``."""
        file = self.file_paths[idx]
        return self.read_func(file), file.name

    def __len__(self) -> int:
        return len(self.file_paths)
data/webcam/input/00000.png ADDED
data/webcam/input/00001.png ADDED
data/webcam/input/00002.png ADDED
data/webcam/input/00003.png ADDED
data/webcam/input/00004.png ADDED
data/webcam/input/00005.png ADDED
data/webcam/input/00006.png ADDED
data/webcam/input/00007.png ADDED
data/webcam/input/00008.png ADDED
data/webcam/input/00009.png ADDED
data/webcam/input/00010.png ADDED
data/webcam/input/00011.png ADDED
data/webcam/input/00012.png ADDED
data/webcam/input/00013.png ADDED
data/webcam/input/00014.png ADDED
data/webcam/input/00015.png ADDED
data/webcam/input/00016.png ADDED
data/webcam/input/00017.png ADDED
data/webcam/input/00018.png ADDED
data/webcam/input/00019.png ADDED
data/webcam/input/00020.png ADDED
data/webcam/input/00021.png ADDED
data/webcam/input/00022.png ADDED
data/webcam/input/00023.png ADDED
data/webcam/input/00024.png ADDED
data/webcam/input/00025.png ADDED
data/webcam/input/00026.png ADDED
data/webcam/input/00027.png ADDED
data/webcam/input/00028.png ADDED
data/webcam/input/00029.png ADDED
data/webcam/input/00030.png ADDED
data/webcam/input/00031.png ADDED
data/webcam/input/00032.png ADDED
data/webcam/input/00033.png ADDED
data/webcam/input/00034.png ADDED
data/webcam/input/00035.png ADDED
data/webcam/input/00036.png ADDED
data/webcam/input/00037.png ADDED
data/webcam/input/00038.png ADDED
data/webcam/input/00039.png ADDED
data/webcam/input/00040.png ADDED
data/webcam/input/00041.png ADDED
data/webcam/input/00042.png ADDED
data/webcam/input/00043.png ADDED