Spaces:
Paused
Paused
Added genre selection + inpainting via notebook (#1)
Browse files- Added genre selection + inpainting via notebook (d56e77f57a166ee70088a054b4df03ec76461393)
Co-authored-by: Carlos Marí Noguera <[email protected]>
- app.py +38 -4
- diffusion.py +25 -3
- inference.py +22 -16
- inpainting.ipynb +160 -0
app.py
CHANGED
@@ -1,22 +1,56 @@
|
|
1 |
import streamlit as st
|
2 |
from PIL import Image
|
3 |
-
from inference import inference
|
|
|
4 |
import io
|
5 |
|
6 |
def main():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
st.title("Image Display App")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
# Button to trigger image generation
|
10 |
if st.button('Generate Image'):
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
# Convert Pillow image to bytes for display in Streamlit
|
15 |
img_buffer = io.BytesIO()
|
|
|
16 |
image.save(img_buffer, format="PNG")
|
17 |
img_buffer.seek(0)
|
18 |
|
19 |
-
# Display the image
|
20 |
st.image(img_buffer, caption='Generated Image', use_column_width=True)
|
21 |
|
22 |
if __name__ == "__main__":
|
|
|
1 |
import streamlit as st
|
2 |
from PIL import Image
|
3 |
+
from inference import inference
|
4 |
+
import torch
|
5 |
import io
|
6 |
|
7 |
def main():
    """Render the Streamlit page: choose genres in the sidebar, then generate an image.

    When the 'Generate Image' button is clicked, the selected genres are
    written into a 14-slot integer tensor, that tensor is handed to
    ``inference``, and the image it returns is shown on the page as a PNG.
    """
    # Genre names in code order: a genre's code is its 1-based position here.
    genre_names = (
        'Action',
        'Adventure',
        'Animation',
        'Comedy',
        'Drama',
        'Family',
        'Horror',
        'Music',
        'Romance',
        'Science Fiction',
        'Western',
        'Fantasy',
        'Thriller',
    )
    genres_dict = {name: code for code, name in enumerate(genre_names, start=1)}

    st.title("Image Display App")

    # Sidebar widget that lets the user pick any number of genres
    selected_genres = st.sidebar.multiselect('Select Genres', list(genres_dict))

    # Nothing else to render until the button has been clicked
    if not st.button('Generate Image'):
        return

    # Conditioning tensor: slot (code - 1) holds the code of every selected
    # genre; all remaining slots stay 0.
    cond = torch.zeros(14, dtype=torch.int64)
    for code in map(genres_dict.__getitem__, selected_genres):
        cond[code - 1] = code

    # Show a loading indicator while the image is being generated
    with st.spinner('Generating Image...'):
        image = inference(cond)

    # Serialize the Pillow image into an in-memory PNG for Streamlit
    png_stream = io.BytesIO()
    image.save(png_stream, format="PNG")
    png_stream.seek(0)

    # Display the generated image
    st.image(png_stream, caption='Generated Image', use_column_width=True)
|
55 |
|
56 |
if __name__ == "__main__":
|
diffusion.py
CHANGED
@@ -160,26 +160,48 @@ class GaussianDiffusion:
|
|
160 |
|
161 |
return x_t_minus_1
|
162 |
|
163 |
-
def sample(self, num_samples, show_progress=True):
|
164 |
"""
|
165 |
Sample from the model
|
166 |
"""
|
167 |
-
cond = None
|
168 |
-
if
|
169 |
# cond is arange()
|
170 |
assert num_samples <= self.model.num_classes, "num_samples must be less than or equal to the number of classes"
|
171 |
cond = torch.arange(self.model.num_classes)[:num_samples].to(self.device)
|
172 |
cond = rearrange(cond, 'i -> i ()')
|
173 |
|
|
|
|
|
174 |
self.model.eval()
|
175 |
image_versions = []
|
176 |
with torch.no_grad():
|
177 |
x = torch.randn(num_samples, self.channels, *self.image_size).to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
it = reversed(range(1, self.noise_steps))
|
179 |
if show_progress:
|
180 |
it = tqdm(it)
|
181 |
for t in it:
|
182 |
image_versions.append(self.denormalize_image(torch.clip(x, -1, 1)).clone().squeeze(0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
x = self.sample_step(x, t, cond)
|
184 |
self.model.train()
|
185 |
x = torch.clip(x, -1.0, 1.0)
|
|
|
160 |
|
161 |
return x_t_minus_1
|
162 |
|
163 |
+
def sample(self, num_samples, show_progress=True, cond=None, x0=None):
|
164 |
"""
|
165 |
Sample from the model
|
166 |
"""
|
167 |
+
#cond = None
|
168 |
+
if cond == None:
|
169 |
# cond is arange()
|
170 |
assert num_samples <= self.model.num_classes, "num_samples must be less than or equal to the number of classes"
|
171 |
cond = torch.arange(self.model.num_classes)[:num_samples].to(self.device)
|
172 |
cond = rearrange(cond, 'i -> i ()')
|
173 |
|
174 |
+
|
175 |
+
# Inpainting
|
176 |
self.model.eval()
|
177 |
image_versions = []
|
178 |
with torch.no_grad():
|
179 |
x = torch.randn(num_samples, self.channels, *self.image_size).to(self.device)
|
180 |
+
|
181 |
+
|
182 |
+
if x0 is not None:
|
183 |
+
x0 = x0.to(self.device)
|
184 |
+
mask = x0 != -1
|
185 |
+
x_noised = self.apply_noise(x0,self.noise_steps -1)[0].to(self.device)
|
186 |
+
new_x = x
|
187 |
+
new_x[mask] = x_noised[mask]
|
188 |
+
|
189 |
+
x = new_x
|
190 |
+
|
191 |
+
|
192 |
it = reversed(range(1, self.noise_steps))
|
193 |
if show_progress:
|
194 |
it = tqdm(it)
|
195 |
for t in it:
|
196 |
image_versions.append(self.denormalize_image(torch.clip(x, -1, 1)).clone().squeeze(0))
|
197 |
+
|
198 |
+
if x0 is not None and t > 80:
|
199 |
+
x_noised = self.apply_noise(x0,t)[0]
|
200 |
+
new_x = x
|
201 |
+
new_x[mask] = x_noised[mask]
|
202 |
+
|
203 |
+
x = new_x
|
204 |
+
|
205 |
x = self.sample_step(x, t, cond)
|
206 |
self.model.train()
|
207 |
x = torch.clip(x, -1.0, 1.0)
|
inference.py
CHANGED
@@ -13,12 +13,13 @@ from diffusion import GaussianDiffusion, DiffusionImageAPI
|
|
13 |
|
14 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
|
|
|
16 |
def inference1():
|
17 |
# new image from web page
|
18 |
image = requests.get("https://picsum.photos/120/80").content
|
19 |
return Image.open(io.BytesIO(image))
|
20 |
|
21 |
-
def inference():
|
22 |
model = Unet(
|
23 |
image_channels=3,
|
24 |
dropout=0.1,
|
@@ -37,26 +38,31 @@ def inference():
|
|
37 |
image_size=(192, 128),
|
38 |
)
|
39 |
|
|
|
|
|
|
|
|
|
|
|
40 |
model.to(device)
|
41 |
diffusion.to(device)
|
42 |
|
43 |
imageAPI = DiffusionImageAPI(diffusion)
|
44 |
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
return imageAPI.tensor_to_image(
|
60 |
|
61 |
if __name__ == "__main__":
|
62 |
inference().show()
|
|
|
13 |
|
14 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
|
16 |
+
|
17 |
def inference1():
    """Download a placeholder picture from picsum.photos and return it as a Pillow image.

    Returns:
        The image served at ``https://picsum.photos/120/80``, opened with
        ``Image.open``.

    Raises:
        requests.RequestException: if the request times out, the connection
            fails, or the server answers with an HTTP error status.
    """
    # new image from web page
    # A timeout keeps the caller from hanging forever on a stalled connection.
    response = requests.get("https://picsum.photos/120/80", timeout=10)
    # Fail here on 4xx/5xx instead of handing an error page to the image decoder.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
|
21 |
|
22 |
+
def inference(cond, x0=None, gif=False):
|
23 |
model = Unet(
|
24 |
image_channels=3,
|
25 |
dropout=0.1,
|
|
|
38 |
image_size=(192, 128),
|
39 |
)
|
40 |
|
41 |
+
if x0 is not None:
|
42 |
+
x0 = diffusion.normalize_image(x0)
|
43 |
+
x0 = x0.permute(2, 0, 1)
|
44 |
+
x0 = x0.unsqueeze(0)
|
45 |
+
|
46 |
model.to(device)
|
47 |
diffusion.to(device)
|
48 |
|
49 |
imageAPI = DiffusionImageAPI(diffusion)
|
50 |
|
51 |
+
new_images, versions = diffusion.sample(1,cond=cond,x0=x0)
|
52 |
+
if gif:
|
53 |
+
images = []
|
54 |
+
for image in versions:
|
55 |
+
images.append(imageAPI.tensor_to_image(image.squeeze(0)))
|
56 |
+
|
57 |
+
print(len(images))
|
58 |
+
print(images[0])
|
59 |
+
# make gif out of pillow images
|
60 |
+
images[0].save('./gif_output/versions.gif',
|
61 |
+
save_all=True,
|
62 |
+
append_images=images[1:],
|
63 |
+
duration=100,
|
64 |
+
loop=0)
|
65 |
+
return imageAPI.tensor_to_image(new_images.squeeze(0))
|
66 |
|
67 |
if __name__ == "__main__":
    # inference() requires a conditioning tensor. Use 14 zeros, which is what
    # app.py passes when no genre has been selected.
    inference(torch.zeros(14, dtype=torch.long)).show()
|
inpainting.ipynb
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 29,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stdout",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"The autoreload extension is already loaded. To reload it, use:\n",
|
13 |
+
" %reload_ext autoreload\n"
|
14 |
+
]
|
15 |
+
}
|
16 |
+
],
|
17 |
+
"source": [
|
18 |
+
"%load_ext autoreload\n",
|
19 |
+
"%autoreload 2\n",
|
20 |
+
"from PIL import Image\n",
|
21 |
+
"import torch \n",
|
22 |
+
"from diffusion import GaussianDiffusion, DiffusionImageAPI\n",
|
23 |
+
"from unet import Unet\n",
|
24 |
+
"from inference import inference\n",
|
25 |
+
"import numpy as np"
|
26 |
+
]
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"cell_type": "code",
|
30 |
+
"execution_count": 30,
|
31 |
+
"metadata": {},
|
32 |
+
"outputs": [
|
33 |
+
{
|
34 |
+
"data": {
|
35 |
+
"text/plain": [
|
36 |
+
"True"
|
37 |
+
]
|
38 |
+
},
|
39 |
+
"execution_count": 30,
|
40 |
+
"metadata": {},
|
41 |
+
"output_type": "execute_result"
|
42 |
+
}
|
43 |
+
],
|
44 |
+
"source": [
|
45 |
+
"torch.cuda.is_available()"
|
46 |
+
]
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"cell_type": "code",
|
50 |
+
"execution_count": 31,
|
51 |
+
"metadata": {},
|
52 |
+
"outputs": [],
|
53 |
+
"source": [
|
54 |
+
"cond = torch.tensor([2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) \n",
|
55 |
+
"genres_dict = {\n",
|
56 |
+
" 'Action': 1,\n",
|
57 |
+
" 'Adventure': 2,\n",
|
58 |
+
" 'Animation': 3,\n",
|
59 |
+
" 'Comedy': 4,\n",
|
60 |
+
" 'Drama': 5,\n",
|
61 |
+
" 'Family': 6,\n",
|
62 |
+
" 'Horror': 7,\n",
|
63 |
+
" 'Music': 8,\n",
|
64 |
+
" 'Romance': 9,\n",
|
65 |
+
" 'Science Fiction': 10,\n",
|
66 |
+
" 'Western': 11,\n",
|
67 |
+
" 'Fantasy': 12,\n",
|
68 |
+
" 'Thriller': 13\n",
|
69 |
+
"}"
|
70 |
+
]
|
71 |
+
},
|
72 |
+
{
|
73 |
+
"cell_type": "code",
|
74 |
+
"execution_count": 45,
|
75 |
+
"metadata": {},
|
76 |
+
"outputs": [
|
77 |
+
{
|
78 |
+
"name": "stderr",
|
79 |
+
"output_type": "stream",
|
80 |
+
"text": [
|
81 |
+
"999it [01:18, 12.69it/s]\n"
|
82 |
+
]
|
83 |
+
}
|
84 |
+
],
|
85 |
+
"source": [
|
86 |
+
"pic = 'IndianaBovik'\n",
|
87 |
+
"image_np = np.array(Image.open(f\"InferenceTests/{pic}.png\").convert('RGB'))\n",
|
88 |
+
"\n",
|
89 |
+
"# Convert the NumPy array to a PyTorch tensor with explicitly specifying the data type\n",
|
90 |
+
"x0 = torch.tensor(image_np, dtype=torch.float32)\n",
|
91 |
+
"\n",
|
92 |
+
"\n",
|
93 |
+
"\n",
|
94 |
+
"image = inference(cond,x0)\n"
|
95 |
+
]
|
96 |
+
},
|
97 |
+
{
|
98 |
+
"cell_type": "code",
|
99 |
+
"execution_count": 46,
|
100 |
+
"metadata": {},
|
101 |
+
"outputs": [
|
102 |
+
{
|
103 |
+
"data": {
|
104 |
+
"image/png": "",
|
105 |
+
"text/plain": [
|
106 |
+
"<PIL.Image.Image image mode=RGB size=128x192>"
|
107 |
+
]
|
108 |
+
},
|
109 |
+
"execution_count": 46,
|
110 |
+
"metadata": {},
|
111 |
+
"output_type": "execute_result"
|
112 |
+
}
|
113 |
+
],
|
114 |
+
"source": [
|
115 |
+
"image"
|
116 |
+
]
|
117 |
+
},
|
118 |
+
{
|
119 |
+
"cell_type": "code",
|
120 |
+
"execution_count": 47,
|
121 |
+
"metadata": {},
|
122 |
+
"outputs": [],
|
123 |
+
"source": [
|
124 |
+
"i = 0"
|
125 |
+
]
|
126 |
+
},
|
127 |
+
{
|
128 |
+
"cell_type": "code",
|
129 |
+
"execution_count": 48,
|
130 |
+
"metadata": {},
|
131 |
+
"outputs": [],
|
132 |
+
"source": [
|
133 |
+
"route = f'./InferenceOutputs/{pic}{i}.png'\n",
|
134 |
+
"i+=1\n",
|
135 |
+
"image.save(route)"
|
136 |
+
]
|
137 |
+
}
|
138 |
+
],
|
139 |
+
"metadata": {
|
140 |
+
"kernelspec": {
|
141 |
+
"display_name": "DIP_DEMO",
|
142 |
+
"language": "python",
|
143 |
+
"name": "python3"
|
144 |
+
},
|
145 |
+
"language_info": {
|
146 |
+
"codemirror_mode": {
|
147 |
+
"name": "ipython",
|
148 |
+
"version": 3
|
149 |
+
},
|
150 |
+
"file_extension": ".py",
|
151 |
+
"mimetype": "text/x-python",
|
152 |
+
"name": "python",
|
153 |
+
"nbconvert_exporter": "python",
|
154 |
+
"pygments_lexer": "ipython3",
|
155 |
+
"version": "3.10.13"
|
156 |
+
}
|
157 |
+
},
|
158 |
+
"nbformat": 4,
|
159 |
+
"nbformat_minor": 2
|
160 |
+
}
|