# movie-diffusion / app.py
# Author: anforsm — commit ee3757e: "Random genre selection if none has been selected. (#2)"
# (File size at that commit: 1.72 kB.)
import streamlit as st
from PIL import Image
from inference import inference
import torch
import io
from diffusion import DiffusionImageAPI
import math
def _build_condition(selected_genres, genres_dict, length=14):
    """Encode genre names into the conditioning tensor passed to ``inference``.

    For every genre, slot ``code - 1`` is set to ``code``, where ``code`` is the
    genre's value in *genres_dict*. All other slots stay 0. If
    *selected_genres* is empty, one genre is drawn uniformly at random from
    *genres_dict*, so generation is always conditioned on at least one genre.

    Args:
        selected_genres: Iterable of genre names (keys of *genres_dict*).
        genres_dict: Mapping of genre name to its 1-based integer code.
        length: Number of slots in the returned tensor. 14 slots for 13 genres
            matches the original app; presumably this is the model's
            conditioning width (TODO confirm against the model).

    Returns:
        A 1-D ``torch.long`` tensor of size *length*.
    """
    cond = torch.zeros(length, dtype=torch.long)
    genres = list(selected_genres)
    if not genres:
        # No user choice: pick one genre at random instead of sending an
        # all-zero (unconditioned) vector.
        names = list(genres_dict)
        genres = [names[int(torch.randint(0, len(names), (1,)).item())]]
    for genre in genres:
        code = genres_dict[genre]
        cond[code - 1] = code
    return cond


def main():
    """Render the Movie Diffusion Streamlit page and generate an image on demand.

    The sidebar lets the user pick any number of genres. When "Generate Image"
    is pressed, the picked genres are encoded with ``_build_condition``. If none
    were picked, one is chosen at random. The tensor is then passed to
    ``inference``, and each intermediate image it reports through the callback
    is shown in the page together with a progress readout.
    """
    genres_dict = {
        'Action': 1,
        'Adventure': 2,
        'Animation': 3,
        'Comedy': 4,
        'Drama': 5,
        'Family': 6,
        'Horror': 7,
        'Music': 8,
        'Romance': 9,
        'Science Fiction': 10,
        'Western': 11,
        'Fantasy': 12,
        'Thriller': 13,
    }
    st.title("Movie Diffusion")
    # Add a sidebar for genre selection
    selected_genres = st.sidebar.multiselect('Select Genres', list(genres_dict.keys()))
    progress_placeholder = st.empty()
    image_placeholder = st.empty()
    # Button to trigger image generation
    if st.button('Generate Image'):
        cond = _build_condition(selected_genres, genres_dict)

        def callback(image, progress):
            # Convert the intermediate tensor to a PIL image, then to PNG bytes
            # so Streamlit can display it.
            image = DiffusionImageAPI(None).tensor_to_image(image.squeeze(0))
            img_buffer = io.BytesIO()
            image.save(img_buffer, format="PNG")
            img_buffer.seek(0)
            # Update the content of the placeholders.
            # NOTE(review): progress is scaled by 110 and capped at 100,
            # presumably so the readout reaches 100% before the final step is
            # reported — confirm against what `inference` passes as `progress`.
            progress_placeholder.write(f"Generating Image...\nProgress: {min(progress * 110, 100):.2f}%")
            image_placeholder.image(img_buffer, caption='Generated Image', width=300)

        inference(cond, callback=callback)


if __name__ == "__main__":
    main()