# movie-diffusion / app.py
# Author: anforsm
# Commit ee3757e: "Random genre selection if none has been selected. (#2)"
import streamlit as st
from PIL import Image
from inference import inference
import torch
import io
from diffusion import DiffusionImageAPI
import math
def build_genre_condition(selected_genres, genres_dict, num_slots=14):
    """Return the genre-conditioning tensor that is passed to ``inference``.

    Each selected genre with integer code ``c`` sets slot ``c - 1`` to ``c``;
    every other slot stays 0. If no genre is selected, a single genre is
    drawn uniformly at random from ``genres_dict`` so that generation is
    always conditioned on at least one genre.

    Args:
        selected_genres: Genre names chosen by the user (keys of ``genres_dict``).
        genres_dict: Mapping from genre name to its 1-based integer code.
        num_slots: Length of the conditioning vector (default 14, the length
            this app has always used).

    Returns:
        A 1-D ``torch.long`` tensor of length ``num_slots``.
    """
    # NOTE(review): 14 slots for 13 genres -- presumably the model was trained
    # with a 14-wide condition vector; confirm before changing.
    cond = torch.zeros(num_slots, dtype=torch.long)
    for genre in selected_genres:
        code = genres_dict[genre]
        cond[code - 1] = code

    # Only fall back to a random genre when the user picked nothing; a user
    # selection must not be augmented with an extra random genre.
    if not torch.any(cond != 0):
        codes = list(genres_dict.values())
        code = codes[torch.randint(0, len(codes), (1,)).item()]
        cond[code - 1] = code
    return cond


def main():
    """Render the Movie Diffusion Streamlit page and generate an image on demand.

    A sidebar multiselect lets the user choose genres. Pressing
    "Generate Image" builds the genre-conditioning tensor and runs
    ``inference``, whose callback streams intermediate images and a progress
    message into placeholders on the page.
    """
    genres_dict = {
        'Action': 1,
        'Adventure': 2,
        'Animation': 3,
        'Comedy': 4,
        'Drama': 5,
        'Family': 6,
        'Horror': 7,
        'Music': 8,
        'Romance': 9,
        'Science Fiction': 10,
        'Western': 11,
        'Fantasy': 12,
        'Thriller': 13
    }
    st.title("Movie Diffusion")

    # Genre selection lives in the sidebar.
    selected_genres = st.sidebar.multiselect('Select Genres', list(genres_dict.keys()))

    # Placeholders are overwritten in place by the callback on every update.
    progress_placeholder = st.empty()
    image_placeholder = st.empty()

    # Button to trigger image generation
    if st.button('Generate Image'):
        cond = build_genre_condition(selected_genres, genres_dict)

        def callback(image, progress):
            # squeeze(0) drops a leading size-1 dim -- presumably the batch
            # dimension of a single-image batch; confirm against inference().
            image = DiffusionImageAPI(None).tensor_to_image(image.squeeze(0))
            img_buffer = io.BytesIO()
            image.save(img_buffer, format="PNG")
            img_buffer.seek(0)
            # NOTE(review): progress is scaled by 110 and capped at 100 --
            # presumably a 0..1 fraction that never quite reaches 1; confirm.
            progress_placeholder.write(f"Generating Image...\nProgress: {min(progress * 110, 100):.2f}%")
            image_placeholder.image(img_buffer, caption='Generated Image', width=300)

        inference(cond, callback=callback)
if __name__ == "__main__":
    # Run the app only when this file is executed as the main script.
    main()