Spaces:
Paused
Paused
Anton Forsman
commited on
Commit
•
585cc65
1
Parent(s):
cdd9a51
fixed small things, also gif and live progress
Browse files- app.py +17 -15
- diffusion.py +4 -1
- inference.py +2 -2
app.py
CHANGED
@@ -3,6 +3,7 @@ from PIL import Image
|
|
3 |
from inference import inference
|
4 |
import torch
|
5 |
import io
|
|
|
6 |
|
7 |
def main():
|
8 |
|
@@ -22,7 +23,7 @@ def main():
|
|
22 |
'Thriller': 13
|
23 |
}
|
24 |
|
25 |
-
st.title("
|
26 |
cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
27 |
|
28 |
# Add a sidebar for genre selection
|
@@ -31,6 +32,9 @@ def main():
|
|
31 |
|
32 |
selected_genres = st.sidebar.multiselect('Select Genres', list(genres_dict.keys()))
|
33 |
|
|
|
|
|
|
|
34 |
|
35 |
|
36 |
# Button to trigger image generation
|
@@ -38,20 +42,18 @@ def main():
|
|
38 |
for genre in selected_genres:
|
39 |
code = genres_dict[genre]
|
40 |
cond[code-1] = code
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
# Display the generated image
|
54 |
-
st.image(img_buffer, caption='Generated Image', use_column_width=True)
|
55 |
|
56 |
if __name__ == "__main__":
|
57 |
main()
|
|
|
3 |
from inference import inference
|
4 |
import torch
|
5 |
import io
|
6 |
+
from diffusion import DiffusionImageAPI
|
7 |
|
8 |
def main():
|
9 |
|
|
|
23 |
'Thriller': 13
|
24 |
}
|
25 |
|
26 |
+
st.title("Movie Diffusion")
|
27 |
cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
28 |
|
29 |
# Add a sidebar for genre selection
|
|
|
32 |
|
33 |
selected_genres = st.sidebar.multiselect('Select Genres', list(genres_dict.keys()))
|
34 |
|
35 |
+
progress_placeholder = st.empty()
|
36 |
+
image_placeholder = st.empty()
|
37 |
+
|
38 |
|
39 |
|
40 |
# Button to trigger image generation
|
|
|
42 |
for genre in selected_genres:
|
43 |
code = genres_dict[genre]
|
44 |
cond[code-1] = code
|
45 |
+
|
46 |
+
def callback(image, progress):
|
47 |
+
image = DiffusionImageAPI(None).tensor_to_image(image.squeeze(0))
|
48 |
+
img_buffer = io.BytesIO()
|
49 |
+
image.save(img_buffer, format="PNG")
|
50 |
+
img_buffer.seek(0)
|
51 |
+
|
52 |
+
# Update the content of the placeholders
|
53 |
+
progress_placeholder.write(f"Generating Image...\nProgress: {(progress * 100):.2f}%")
|
54 |
+
image_placeholder.image(img_buffer, caption='Generated Image', width=300)
|
55 |
+
|
56 |
+
inference(cond, callback=callback)
|
|
|
|
|
57 |
|
58 |
if __name__ == "__main__":
|
59 |
main()
|
diffusion.py
CHANGED
@@ -160,7 +160,7 @@ class GaussianDiffusion:
|
|
160 |
|
161 |
return x_t_minus_1
|
162 |
|
163 |
-
def sample(self, num_samples, show_progress=True, cond=None, x0=None):
|
164 |
"""
|
165 |
Sample from the model
|
166 |
"""
|
@@ -193,6 +193,9 @@ class GaussianDiffusion:
|
|
193 |
if show_progress:
|
194 |
it = tqdm(it)
|
195 |
for t in it:
|
|
|
|
|
|
|
196 |
image_versions.append(self.denormalize_image(torch.clip(x, -1, 1)).clone().squeeze(0))
|
197 |
|
198 |
if x0 is not None and t > 80:
|
|
|
160 |
|
161 |
return x_t_minus_1
|
162 |
|
163 |
+
def sample(self, num_samples, show_progress=True, cond=None, x0=None, cb=None):
|
164 |
"""
|
165 |
Sample from the model
|
166 |
"""
|
|
|
193 |
if show_progress:
|
194 |
it = tqdm(it)
|
195 |
for t in it:
|
196 |
+
temp_image = self.denormalize_image(torch.clip(x, -1, 1)).clone().squeeze(0)
|
197 |
+
if cb is not None:
|
198 |
+
cb(temp_image, 1-t/self.noise_steps)
|
199 |
image_versions.append(self.denormalize_image(torch.clip(x, -1, 1)).clone().squeeze(0))
|
200 |
|
201 |
if x0 is not None and t > 80:
|
inference.py
CHANGED
@@ -19,7 +19,7 @@ def inference1():
|
|
19 |
image = requests.get("https://picsum.photos/120/80").content
|
20 |
return Image.open(io.BytesIO(image))
|
21 |
|
22 |
-
def inference(cond, x0=None, gif=False):
|
23 |
model = Unet(
|
24 |
image_channels=3,
|
25 |
dropout=0.1,
|
@@ -48,7 +48,7 @@ def inference(cond, x0=None, gif=False):
|
|
48 |
|
49 |
imageAPI = DiffusionImageAPI(diffusion)
|
50 |
|
51 |
-
new_images, versions = diffusion.sample(1,cond=cond,x0=x0)
|
52 |
if gif:
|
53 |
images = []
|
54 |
for image in versions:
|
|
|
19 |
image = requests.get("https://picsum.photos/120/80").content
|
20 |
return Image.open(io.BytesIO(image))
|
21 |
|
22 |
+
def inference(cond, x0=None, gif=False, callback=None):
|
23 |
model = Unet(
|
24 |
image_channels=3,
|
25 |
dropout=0.1,
|
|
|
48 |
|
49 |
imageAPI = DiffusionImageAPI(diffusion)
|
50 |
|
51 |
+
new_images, versions = diffusion.sample(1,cond=cond,x0=x0, cb=callback)
|
52 |
if gif:
|
53 |
images = []
|
54 |
for image in versions:
|