Spaces:
Runtime error
Runtime error
Javi
commited on
Commit
•
ed1918f
1
Parent(s):
47ed623
First version working
Browse files- .gitignore +1 -0
- requirements.txt +1 -0
- session_state.py +86 -0
- streamlit_app.py +76 -0
.gitignore
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
# Byte-compiled / optimized / DLL files
|
2 |
__pycache__/
|
3 |
*.py[cod]
|
|
|
1 |
+
.idea
|
2 |
# Byte-compiled / optimized / DLL files
|
3 |
__pycache__/
|
4 |
*.py[cod]
|
requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
streamlit~=0.76.0
|
session_state.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# From https://gist.github.com/okld/0aba4869ba6fdc8d49132e6974e2e662
|
2 |
+
from streamlit.hashing import _CodeHasher
|
3 |
+
|
4 |
+
try:
|
5 |
+
# Before Streamlit 0.65
|
6 |
+
from streamlit.ReportThread import get_report_ctx
|
7 |
+
from streamlit.server.Server import Server
|
8 |
+
except ModuleNotFoundError:
|
9 |
+
# After Streamlit 0.65
|
10 |
+
from streamlit.report_thread import get_report_ctx
|
11 |
+
from streamlit.server.server import Server
|
12 |
+
|
13 |
+
|
14 |
+
class SessionState:
|
15 |
+
def __init__(self, session, hash_funcs):
|
16 |
+
"""Initialize SessionState instance."""
|
17 |
+
self.__dict__["_state"] = {
|
18 |
+
"data": {},
|
19 |
+
"hash": None,
|
20 |
+
"hasher": _CodeHasher(hash_funcs),
|
21 |
+
"is_rerun": False,
|
22 |
+
"session": session,
|
23 |
+
}
|
24 |
+
|
25 |
+
def __call__(self, **kwargs):
|
26 |
+
"""Initialize state data once."""
|
27 |
+
for item, value in kwargs.items():
|
28 |
+
if item not in self._state["data"]:
|
29 |
+
self._state["data"][item] = value
|
30 |
+
|
31 |
+
def __getitem__(self, item):
|
32 |
+
"""Return a saved state value, None if item is undefined."""
|
33 |
+
return self._state["data"].get(item, None)
|
34 |
+
|
35 |
+
def __getattr__(self, item):
|
36 |
+
"""Return a saved state value, None if item is undefined."""
|
37 |
+
return self._state["data"].get(item, None)
|
38 |
+
|
39 |
+
def __setitem__(self, item, value):
|
40 |
+
"""Set state value."""
|
41 |
+
self._state["data"][item] = value
|
42 |
+
|
43 |
+
def __setattr__(self, item, value):
|
44 |
+
"""Set state value."""
|
45 |
+
self._state["data"][item] = value
|
46 |
+
|
47 |
+
def clear(self):
|
48 |
+
"""Clear session state and request a rerun."""
|
49 |
+
self._state["data"].clear()
|
50 |
+
self._state["session"].request_rerun()
|
51 |
+
|
52 |
+
def sync(self):
|
53 |
+
"""Rerun the app with all state values up to date from the beginning to fix rollbacks."""
|
54 |
+
|
55 |
+
# Ensure to rerun only once to avoid infinite loops
|
56 |
+
# caused by a constantly changing state value at each run.
|
57 |
+
#
|
58 |
+
# Example: state.value += 1
|
59 |
+
if self._state["is_rerun"]:
|
60 |
+
self._state["is_rerun"] = False
|
61 |
+
|
62 |
+
elif self._state["hash"] is not None:
|
63 |
+
if self._state["hash"] != self._state["hasher"].to_bytes(self._state["data"], None):
|
64 |
+
self._state["is_rerun"] = True
|
65 |
+
self._state["session"].request_rerun()
|
66 |
+
|
67 |
+
self._state["hash"] = self._state["hasher"].to_bytes(self._state["data"], None)
|
68 |
+
|
69 |
+
|
70 |
+
def get_session():
|
71 |
+
session_id = get_report_ctx().session_id
|
72 |
+
session_info = Server.get_current()._get_session_info(session_id)
|
73 |
+
|
74 |
+
if session_info is None:
|
75 |
+
raise RuntimeError("Couldn't get your Streamlit Session object.")
|
76 |
+
|
77 |
+
return session_info.session
|
78 |
+
|
79 |
+
|
80 |
+
def get_state(hash_funcs=None):
|
81 |
+
session = get_session()
|
82 |
+
|
83 |
+
if not hasattr(session, "_custom_session_state"):
|
84 |
+
session._custom_session_state = SessionState(session, hash_funcs)
|
85 |
+
|
86 |
+
return session._custom_session_state
|
streamlit_app.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import streamlit as st
|
3 |
+
import booste
|
4 |
+
|
5 |
+
from session_state import SessionState, get_state
|
6 |
+
|
7 |
+
# Unfortunately Streamlit sharing does not allow to hide enviroment variables yet.
|
8 |
+
# Do not copy this API key, go to https://www.booste.io/ and get your own, it is free!
|
9 |
+
BOOSTE_API_KEY = "3818ba84-3526-4029-9dc8-ef3038697ea2"
|
10 |
+
|
11 |
+
|
12 |
+
task_name: str = st.sidebar.radio("Task", options=["Image classification", "Image ranking", "Prompt ranking"])
|
13 |
+
|
14 |
+
st.markdown("# CLIP playground")
|
15 |
+
st.markdown("### Try OpenAI's CLIP model in your browser")
|
16 |
+
st.markdown(" "); st.markdown(" ")
|
17 |
+
with st.beta_expander("What is CLIP?"):
|
18 |
+
st.markdown("Nice CLIP explaination")
|
19 |
+
st.markdown(" "); st.markdown(" ")
|
20 |
+
if task_name == "Image classification":
|
21 |
+
session_state = get_state()
|
22 |
+
uploaded_image = st.file_uploader("Upload image", type=[".png", ".jpg", ".jpeg"],
|
23 |
+
accept_multiple_files=False)
|
24 |
+
st.markdown("or choose one from")
|
25 |
+
col1, col2, col3 = st.beta_columns(3)
|
26 |
+
with col1:
|
27 |
+
default_image_1 = "https://cdn.pixabay.com/photo/2014/10/13/21/34/clipper-487503_960_720.jpg"
|
28 |
+
st.image(default_image_1, use_column_width=True)
|
29 |
+
if st.button("Select image 1"):
|
30 |
+
session_state.image = default_image_1
|
31 |
+
with col2:
|
32 |
+
default_image_2 = "https://cdn.pixabay.com/photo/2019/12/17/18/20/peacock-4702197_960_720.jpg"
|
33 |
+
st.image(default_image_2, use_column_width=True)
|
34 |
+
if st.button("Select image 2"):
|
35 |
+
session_state.image = default_image_2
|
36 |
+
with col3:
|
37 |
+
default_image_3 = "https://cdn.pixabay.com/photo/2016/11/15/16/24/banana-1826760_960_720.jpg"
|
38 |
+
st.image(default_image_3, use_column_width=True)
|
39 |
+
if st.button("Select image 3"):
|
40 |
+
session_state.image = default_image_3
|
41 |
+
raw_classes = st.text_input("Enter the classes to chose from separated by a comma."
|
42 |
+
" (f.x. `banana, sailing boat, honesty, apple`)")
|
43 |
+
if raw_classes:
|
44 |
+
session_state.processed_classes = raw_classes.split(",")
|
45 |
+
input_prompts = ["A picture of a " + class_name for class_name in session_state.processed_classes]
|
46 |
+
|
47 |
+
col1, col2 = st.beta_columns([2, 1])
|
48 |
+
with col1:
|
49 |
+
st.markdown("Image to classify")
|
50 |
+
if session_state.image is not None:
|
51 |
+
st.image(session_state.image, use_column_width=True)
|
52 |
+
else:
|
53 |
+
st.warning("Select an image")
|
54 |
+
|
55 |
+
with col2:
|
56 |
+
st.markdown("Classes to choose from")
|
57 |
+
if session_state.processed_classes is not None:
|
58 |
+
for class_name in session_state.processed_classes:
|
59 |
+
st.write(class_name)
|
60 |
+
else:
|
61 |
+
st.warning("Enter the classes to classify from")
|
62 |
+
|
63 |
+
# Possible way of customize this https://discuss.streamlit.io/t/st-button-in-a-custom-layout/2187/2
|
64 |
+
if st.button("Predict"):
|
65 |
+
with st.spinner("Predicting..."):
|
66 |
+
clip_response = booste.clip(BOOSTE_API_KEY,
|
67 |
+
prompts=input_prompts,
|
68 |
+
images=[session_state.image],
|
69 |
+
pretty_print=True)
|
70 |
+
st.write(clip_response)
|
71 |
+
|
72 |
+
|
73 |
+
session_state.sync()
|
74 |
+
|
75 |
+
|
76 |
+
|