Anton Forsman committed on
Commit
585cc65
1 Parent(s): cdd9a51

fixed small things, also gif and live progress

Browse files
Files changed (3) hide show
  1. app.py +17 -15
  2. diffusion.py +4 -1
  3. inference.py +2 -2
app.py CHANGED
@@ -3,6 +3,7 @@ from PIL import Image
3
  from inference import inference
4
  import torch
5
  import io
 
6
 
7
  def main():
8
 
@@ -22,7 +23,7 @@ def main():
22
  'Thriller': 13
23
  }
24
 
25
- st.title("Image Display App")
26
  cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
27
 
28
  # Add a sidebar for genre selection
@@ -31,6 +32,9 @@ def main():
31
 
32
  selected_genres = st.sidebar.multiselect('Select Genres', list(genres_dict.keys()))
33
 
 
 
 
34
 
35
 
36
  # Button to trigger image generation
@@ -38,20 +42,18 @@ def main():
38
  for genre in selected_genres:
39
  code = genres_dict[genre]
40
  cond[code-1] = code
41
- # Display loading sign while generating image
42
- with st.spinner('Generating Image...'):
43
- # Call the function from inference.py with selected genre
44
- image = inference(cond)
45
- #image = inference(genre)
46
-
47
- # Convert Pillow image to bytes for display in Streamlit
48
- img_buffer = io.BytesIO()
49
- #"""0,0,0,0,0,0,0,1, 2, 7, 4, 0, 0, 0"""
50
- image.save(img_buffer, format="PNG")
51
- img_buffer.seek(0)
52
-
53
- # Display the generated image
54
- st.image(img_buffer, caption='Generated Image', use_column_width=True)
55
 
56
  if __name__ == "__main__":
57
  main()
 
3
  from inference import inference
4
  import torch
5
  import io
6
+ from diffusion import DiffusionImageAPI
7
 
8
  def main():
9
 
 
23
  'Thriller': 13
24
  }
25
 
26
+ st.title("Movie Diffusion")
27
  cond = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
28
 
29
  # Add a sidebar for genre selection
 
32
 
33
  selected_genres = st.sidebar.multiselect('Select Genres', list(genres_dict.keys()))
34
 
35
+ progress_placeholder = st.empty()
36
+ image_placeholder = st.empty()
37
+
38
 
39
 
40
  # Button to trigger image generation
 
42
  for genre in selected_genres:
43
  code = genres_dict[genre]
44
  cond[code-1] = code
45
+
46
+ def callback(image, progress):
47
+ image = DiffusionImageAPI(None).tensor_to_image(image.squeeze(0))
48
+ img_buffer = io.BytesIO()
49
+ image.save(img_buffer, format="PNG")
50
+ img_buffer.seek(0)
51
+
52
+ # Update the content of the placeholders
53
+ progress_placeholder.write(f"Generating Image...\nProgress: {(progress * 100):.2f}%")
54
+ image_placeholder.image(img_buffer, caption='Generated Image', width=300)
55
+
56
+ inference(cond, callback=callback)
 
 
57
 
58
  if __name__ == "__main__":
59
  main()
diffusion.py CHANGED
@@ -160,7 +160,7 @@ class GaussianDiffusion:
160
 
161
  return x_t_minus_1
162
 
163
- def sample(self, num_samples, show_progress=True, cond=None, x0=None):
164
  """
165
  Sample from the model
166
  """
@@ -193,6 +193,9 @@ class GaussianDiffusion:
193
  if show_progress:
194
  it = tqdm(it)
195
  for t in it:
 
 
 
196
  image_versions.append(self.denormalize_image(torch.clip(x, -1, 1)).clone().squeeze(0))
197
 
198
  if x0 is not None and t > 80:
 
160
 
161
  return x_t_minus_1
162
 
163
+ def sample(self, num_samples, show_progress=True, cond=None, x0=None, cb=None):
164
  """
165
  Sample from the model
166
  """
 
193
  if show_progress:
194
  it = tqdm(it)
195
  for t in it:
196
+ temp_image = self.denormalize_image(torch.clip(x, -1, 1)).clone().squeeze(0)
197
+ if cb is not None:
198
+ cb(temp_image, 1-t/self.noise_steps)
199
  image_versions.append(self.denormalize_image(torch.clip(x, -1, 1)).clone().squeeze(0))
200
 
201
  if x0 is not None and t > 80:
inference.py CHANGED
@@ -19,7 +19,7 @@ def inference1():
19
  image = requests.get("https://picsum.photos/120/80").content
20
  return Image.open(io.BytesIO(image))
21
 
22
- def inference(cond, x0=None, gif=False):
23
  model = Unet(
24
  image_channels=3,
25
  dropout=0.1,
@@ -48,7 +48,7 @@ def inference(cond, x0=None, gif=False):
48
 
49
  imageAPI = DiffusionImageAPI(diffusion)
50
 
51
- new_images, versions = diffusion.sample(1,cond=cond,x0=x0)
52
  if gif:
53
  images = []
54
  for image in versions:
 
19
  image = requests.get("https://picsum.photos/120/80").content
20
  return Image.open(io.BytesIO(image))
21
 
22
+ def inference(cond, x0=None, gif=False, callback=None):
23
  model = Unet(
24
  image_channels=3,
25
  dropout=0.1,
 
48
 
49
  imageAPI = DiffusionImageAPI(diffusion)
50
 
51
+ new_images, versions = diffusion.sample(1,cond=cond,x0=x0, cb=callback)
52
  if gif:
53
  images = []
54
  for image in versions: