import io
import math

import streamlit as st
import torch
from PIL import Image

from diffusion import DiffusionImageAPI
from inference import inference


def _build_condition(selected_genres, genres_dict, size=14):
    """Build the genre conditioning tensor passed to ``inference``.

    Each selected genre's code ``c`` is written at index ``c - 1``. When no
    genre is selected, one genre is drawn uniformly at random so the model
    always receives at least one active genre.

    Args:
        selected_genres: Genre names chosen by the user (keys of ``genres_dict``).
        genres_dict: Mapping of genre name -> 1-based genre code. The random
            fallback assumes codes are contiguous ``1..len(genres_dict)``.
        size: Length of the conditioning vector. Defaults to 14, one more
            than the number of genres; the last slot is never set here.
            NOTE(review): confirm the model really expects 14 entries.

    Returns:
        A 1-D ``torch.long`` tensor of length ``size``.
    """
    cond = torch.zeros(size, dtype=torch.long)
    for genre in selected_genres:
        code = genres_dict[genre]
        cond[code - 1] = code

    # Fallback only when the user chose nothing; the previous code did the
    # opposite and added a random genre on top of the user's selection.
    if not selected_genres:
        idx = torch.randint(0, len(genres_dict), (1,)).item()
        cond[idx] = idx + 1
    return cond


def main():
    """Render the Movie Diffusion Streamlit app.

    Lets the user pick genres in the sidebar. On "Generate Image", it runs
    ``inference`` and streams intermediate images and progress into
    placeholders.
    """
    genres_dict = {
        'Action': 1,
        'Adventure': 2,
        'Animation': 3,
        'Comedy': 4,
        'Drama': 5,
        'Family': 6,
        'Horror': 7,
        'Music': 8,
        'Romance': 9,
        'Science Fiction': 10,
        'Western': 11,
        'Fantasy': 12,
        'Thriller': 13,
    }

    st.title("Movie Diffusion")

    # Sidebar for genre selection.
    selected_genres = st.sidebar.multiselect('Select Genres', list(genres_dict.keys()))

    # Placeholders updated in place by the progress callback.
    progress_placeholder = st.empty()
    image_placeholder = st.empty()

    # Button to trigger image generation.
    if st.button('Generate Image'):
        cond = _build_condition(selected_genres, genres_dict)

        # Build the converter once per generation rather than on every
        # callback step. Assumes DiffusionImageAPI(None) is stateless for
        # tensor_to_image purposes; TODO confirm.
        to_image = DiffusionImageAPI(None).tensor_to_image

        def callback(image, progress):
            """Show the intermediate ``image`` and ``progress`` in the placeholders.

            Assumes ``image`` carries a leading batch dimension of size 1,
            which ``squeeze(0)`` drops; TODO confirm.
            """
            pil_image = to_image(image.squeeze(0))
            img_buffer = io.BytesIO()
            pil_image.save(img_buffer, format="PNG")
            img_buffer.seek(0)

            # NOTE(review): the 110 factor, capped at 100, presumably makes the
            # display reach 100% even though ``progress`` may not hit 1.0 on
            # the final step. Confirm against ``inference``.
            progress_placeholder.write(f"Generating Image...\nProgress: {min(progress * 110, 100):.2f}%")
            image_placeholder.image(img_buffer, caption='Generated Image', width=300)

        inference(cond, callback=callback)


if __name__ == "__main__":
    main()