# Spaces: Paused
import streamlit as st | |
from PIL import Image | |
from inference import inference | |
import torch | |
import io | |
from diffusion import DiffusionImageAPI | |
import math | |
def _build_condition(selected_genres, genres_dict, size=14):
    """Build the integer genre-conditioning vector passed to ``inference``.

    For every selected genre with code ``c``, slot ``c - 1`` is set to ``c``;
    every other slot stays 0. If no genre is selected, one genre is chosen at
    random so the condition is never empty.

    Args:
        selected_genres: Genre names chosen by the user (keys of ``genres_dict``).
        genres_dict: Mapping of genre name -> 1-based integer code.
        size: Length of the condition vector.

    Returns:
        A 1-D ``torch.long`` tensor of length ``size``.
    """
    # NOTE(review): 14 slots for 13 genres -- the last slot is never written
    # here; presumably the model was trained on a 14-dim condition. Confirm
    # before changing ``size``.
    cond = torch.zeros(size, dtype=torch.long)
    for genre in selected_genres:
        code = genres_dict[genre]
        cond[code - 1] = code
    # The original tested `torch.any(cond != 0)`, which added a random extra
    # genre exactly when the user HAD selected genres. The fallback belongs to
    # the empty-selection case.
    if not torch.any(cond != 0):
        codes = list(genres_dict.values())
        code = codes[torch.randint(0, len(codes), (1,)).item()]
        cond[code - 1] = code
    return cond


def main():
    """Render the Movie Diffusion page and generate an image on demand.

    The sidebar lets the user pick genres. Pressing "Generate Image" builds a
    genre condition vector and runs ``inference``, streaming intermediate
    images and progress into placeholders on the page.
    """
    genres_dict = {
        'Action': 1,
        'Adventure': 2,
        'Animation': 3,
        'Comedy': 4,
        'Drama': 5,
        'Family': 6,
        'Horror': 7,
        'Music': 8,
        'Romance': 9,
        'Science Fiction': 10,
        'Western': 11,
        'Fantasy': 12,
        'Thriller': 13,
    }
    st.title("Movie Diffusion")
    # Sidebar genre selection (multiple genres may be combined).
    selected_genres = st.sidebar.multiselect('Select Genres', list(genres_dict.keys()))
    # Placeholders are overwritten in place by the progress callback.
    progress_placeholder = st.empty()
    image_placeholder = st.empty()

    if not st.button('Generate Image'):
        return

    cond = _build_condition(selected_genres, genres_dict)
    # Build the converter once instead of on every progress callback.
    image_api = DiffusionImageAPI(None)

    def callback(image, progress):
        """Show the intermediate ``image`` tensor and ``progress`` from ``inference``."""
        pil_image = image_api.tensor_to_image(image.squeeze(0))
        img_buffer = io.BytesIO()
        pil_image.save(img_buffer, format="PNG")
        img_buffer.seek(0)
        # NOTE(review): progress is scaled by 110 and clamped to 100;
        # presumably inference reports a fraction that tops out below 1.0.
        progress_placeholder.write(f"Generating Image...\nProgress: {min(progress * 110, 100):.2f}%")
        image_placeholder.image(img_buffer, caption='Generated Image', width=300)

    inference(cond, callback=callback)
if __name__ == "__main__":
    # Launch the app only when this file is executed directly, not on import.
    main()