import base64
import os
import socket
import time
from io import BytesIO
from multiprocessing import Process

import requests
import streamlit as st
from PIL import Image


def start_server():
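    # uvicorn's CLI blocks, so load_models() runs this in a daemon subprocess.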
    os.system("uvicorn server:app --port 8080 --host 0.0.0.0 --workers 2")


def load_models():
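    # Start the model server once: if nothing is listening on port 8080 yet,
    # launch it in a daemon process and block until the port accepts connections.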
    if not is_port_in_use(8080):
        with st.spinner(text="Loading models, please wait..."):
            proc = Process(target=start_server, daemon=True)
            proc.start()
            while not is_port_in_use(8080):
                time.sleep(1)
            st.success("Model server started.")
    else:
        st.success("Model server already running...")
    st.session_state["models_loaded"] = True


def is_port_in_use(port):
    # connect_ex returns 0 on success, i.e. something is already listening.
    # Probe 127.0.0.1: 0.0.0.0 is not a valid destination on all platforms.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(("127.0.0.1", port)) == 0


def generate(prompt):
    # Let requests URL-encode the prompt; characters such as spaces or "&"
    # would otherwise corrupt the query string.
    response = requests.get("http://127.0.0.1:8080/correct", params={"prompt": prompt})
    images = response.json()["images"]
    # The server replies with base64-encoded images; decode them into PIL images.
    images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
    return images


if "models_loaded" not in st.session_state:
    st.session_state["models_loaded"] = False


st.header("minDALL-E")
st.subheader("Generate images from text")

if not st.session_state["models_loaded"]:
    load_models()

prompt = st.text_input("What do you want to see?")

DEBUG = False
# UI code taken from https://huggingface.co/spaces/flax-community/dalle-mini/blob/main/app/streamlit/app.py
if prompt != "":
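    # Placeholder that shows a loading banner now and is overwritten with the
    # prompt caption once generation finishes.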
    container = st.empty()
    container.markdown(
        f"""
        <style> p {{ margin:0 }} div {{ margin:0 }} </style>
        <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
        <div class="stAlert">
        <div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
        <div class="st-b7">
        <div class="css-whx05o e13vu3m50">
        <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
                <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
                Generating predictions for: <b>{prompt}</b>
        </div>
        </div>
        </div>
        </div>
        </div>
        </div>
    """,
        unsafe_allow_html=True,
    )

    print(f"Getting selections: {prompt}")
    selected = generate(prompt)

    margin = 0.1  # spacer width, for better placement of the zoom-in arrow
    n_columns = 3
    cols = st.columns([1] + [margin, 1] * (n_columns - 1))
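    # Even-indexed columns hold images; odd-indexed ones are thin spacers,
    # hence the `* 2` stride in the index below.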
    for i, img in enumerate(selected):
        cols[(i % n_columns) * 2].image(img)
    container.markdown(f"**{prompt}**")

    # Clicking any Streamlit button triggers a rerun of the script, so this
    # regenerates images for the same prompt.
    st.button("Again!", key="again_button")
    
    # Note: this replaces the prompt caption in `container` with the credits.
    container.markdown(
        "<b><i>UI credits: <a href='https://huggingface.co/spaces/flax-community/dalle-mini'>DALL-E mini Space</a></i></b>",
        unsafe_allow_html=True,
    )
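
# For reference, a minimal sketch of the server-side app this client expects
# from "uvicorn server:app". The real server.py is not shown here: the
# "/correct" path, the "prompt" query parameter, and the response shape
# {"images": [<base64-encoded image>, ...]} are inferred from generate()
# above, and run_mindalle() is a hypothetical stand-in for the model call.
#
#   import base64
#   from io import BytesIO
#
#   from fastapi import FastAPI
#
#   app = FastAPI()
#
#   @app.get("/correct")
#   def correct(prompt: str):
#       encoded = []
#       for img in run_mindalle(prompt):  # hypothetical: returns PIL images
#           buf = BytesIO()
#           img.save(buf, format="PNG")
#           encoded.append(base64.b64encode(buf.getvalue()).decode())
#       return {"images": encoded}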