movie-diffusion / app.py
CarlosMN's picture
Random genre selection if none has been selected.
5140c81
raw
history blame
1.72 kB
import streamlit as st
from PIL import Image
from inference import inference
import torch
import io
from diffusion import DiffusionImageAPI
import math
# Length of the conditioning vector handed to ``inference``. Only 13 genres are
# defined in ``main``, so the last slot is never written here; presumably the
# width is fixed by the trained model — TODO confirm against ``inference``.
_COND_SIZE = 14


def _build_condition(selected_genres, genres_dict, cond_size=_COND_SIZE):
    """Return the genre-conditioning tensor passed to ``inference``.

    Each selected genre writes its integer code into slot ``code - 1``. If no
    genre was selected, one genre from ``genres_dict`` is picked at random so
    the model always receives at least one active genre.

    Args:
        selected_genres: Genre names chosen by the user (keys of ``genres_dict``).
        genres_dict: Mapping from genre name to its 1-based integer code.
        cond_size: Length of the returned tensor.

    Returns:
        A 1-D int64 tensor of length ``cond_size``.
    """
    cond = torch.zeros(cond_size, dtype=torch.long)
    for genre in selected_genres:
        code = genres_dict[genre]
        cond[code - 1] = code
    # Fallback only when nothing was selected (the original check was inverted
    # and added a random genre on top of the user's choice instead).
    if genres_dict and not torch.any(cond != 0):
        codes = list(genres_dict.values())
        random_code = codes[int(torch.randint(0, len(codes), (1,)).item())]
        cond[random_code - 1] = random_code
    return cond


def _tensor_to_png_buffer(image):
    """Convert an image tensor reported by ``inference`` into an in-memory PNG.

    ``squeeze(0)`` drops a leading dimension of size 1 (assumed to be the batch
    axis — TODO confirm). ``tensor_to_image`` presumably returns a PIL image,
    since the result is saved with ``format="PNG"``.
    """
    pil_image = DiffusionImageAPI(None).tensor_to_image(image.squeeze(0))
    img_buffer = io.BytesIO()
    pil_image.save(img_buffer, format="PNG")
    img_buffer.seek(0)  # rewind so the consumer reads the PNG from the start
    return img_buffer


def main():
    """Render the Movie Diffusion page and generate an image on button press.

    The sidebar offers a multi-select of genres. Pressing "Generate Image"
    builds the genre-conditioning tensor and runs ``inference``. Its callback
    streams intermediate images and progress into two placeholders.
    """
    genres_dict = {
        'Action': 1,
        'Adventure': 2,
        'Animation': 3,
        'Comedy': 4,
        'Drama': 5,
        'Family': 6,
        'Horror': 7,
        'Music': 8,
        'Romance': 9,
        'Science Fiction': 10,
        'Western': 11,
        'Fantasy': 12,
        'Thriller': 13,
    }
    st.title("Movie Diffusion")
    # Add a sidebar for genre selection
    selected_genres = st.sidebar.multiselect('Select Genres', list(genres_dict.keys()))
    progress_placeholder = st.empty()
    image_placeholder = st.empty()
    # Button to trigger image generation
    if st.button('Generate Image'):
        cond = _build_condition(selected_genres, genres_dict)

        def callback(image, progress):
            """Display the intermediate image and progress reported by ``inference``."""
            img_buffer = _tensor_to_png_buffer(image)
            # Update the content of the placeholders.
            # NOTE(review): progress is scaled by 110 and capped at 100, which
            # presumably reaches 100% before the last reported step — confirm
            # the range of ``progress`` supplied by ``inference``.
            progress_placeholder.write(f"Generating Image...\nProgress: {min(progress * 110, 100):.2f}%")
            image_placeholder.image(img_buffer, caption='Generated Image', width=300)

        inference(cond, callback=callback)
# Build the page only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()