Spaces:
Runtime error
Runtime error
import time | |
import sys | |
import streamlit as st | |
import string | |
import os | |
from io import StringIO | |
import pdb | |
import json | |
import torch | |
import requests | |
import socket | |
from streamlit_image_select import image_select | |
# Use cases shown in the page intro, keyed by the app_mode string given to app_main.
use_case = {
    "1": "Image background removal",
    "2": "Masking foreground for downstream inpainting task",
}
# Masking choices: selectbox label shown to the user -> "mask" value sent to the server.
mask_types = {
    "blur - blurs background": "blur",
    "map - makes the foreground white and rest black ": "map",
    "rgba - makes background white": "rgba",
    "green - makes the background green": "green",
}
APP_NAME = "hf/salient_object_detection"  # app identifier reported to the stats service
INFO_URL = "https://www.taskswithcode.com/stats/"  # stats (view-count) endpoint
TMP_DIR = "tmp_dir"  # scratch directory for temporary files
TMP_SEED = 1  # incremented per upload to keep generated file names unique
def _post_view_event(action, timeout=10):
    """POST one ``action`` event to INFO_URL; return the decoded JSON or None.

    Best-effort: host lookup, network, and JSON-decoding failures all yield
    None, so a stats outage never breaks the page.
    """
    try:
        hostname = socket.gethostname()
        ip_address = socket.gethostbyname(hostname)
        app_info = {'name': APP_NAME, "action": action, "host": hostname, "ip": ip_address}
        # A timeout keeps a slow stats server from blocking page rendering.
        return requests.post(INFO_URL, json=app_info, timeout=timeout).json()
    except (OSError, ValueError, requests.RequestException):
        # OSError covers socket.gaierror; ValueError covers bad JSON bodies.
        return None


def get_views(action):
    """Report a page event to the stats service and return the view count.

    The first call in a Streamlit session posts ``action`` and caches the
    ``count`` field of the reply in ``st.session_state["view_count"]`` (0 if
    the call fails or the reply has no usable count). Later calls reuse the
    cached count and, unless ``action`` is "init", post the event again.

    Args:
        action: Event name sent to the stats service (e.g. "init", "submit").

    Returns:
        The cached view count formatted with thousands separators, e.g. "1,234".
    """
    if "view_count" not in st.session_state:
        res = _post_view_event(action)
        try:
            count = int(res["count"])
        except (TypeError, KeyError, ValueError):
            # No reply, wrong shape, or non-numeric count: show 0 instead of failing.
            count = 0
        st.session_state["view_count"] = count
        return "{:,}".format(count)
    if action != "init":
        _post_view_event(action)
    return "{:,}".format(st.session_state["view_count"])
def construct_model_info_for_display(model_names):
    """Build the model option list and the HTML "model info" panel.

    Args:
        model_names: Model descriptors as loaded from the models JSON file.
            Each needs "name" and "mark". Entries whose "mark" is the string
            "True" or the JSON boolean true are described in the panel. Those
            entries must also carry "paper_url", "orig_author_url",
            "orig_author" and "sota_info" (with "sota_link" and "task"). A
            described entry that has a "Note" must also have "alt_url".

    Returns:
        (options_arr, markdown_str): every model name in input order, and an
        HTML string meant for st.markdown(..., unsafe_allow_html=True).

    Raises:
        KeyError: if a descriptor lacks one of the required fields.
    """
    options_arr = [node["name"] for node in model_names]
    parts = [
        "<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Model evaluated </b><br/></div>",
        "<div style=\"font-size:2px; color: #2f2f2f; text-align: left\"><br/></div>",
    ]
    for node in model_names:
        # str() accepts both the string "True" and a JSON boolean true.
        if str(node["mark"]) != "True":
            continue
        parts.append(
            f"<div style=\"font-size:16px; color: #5f5f5f; text-align: left\"> • Model: <a href=\'{node['paper_url']}\' target='_blank'>{node['name']}</a><br/> Code released by: <a href=\'{node['orig_author_url']}\' target='_blank'>{node['orig_author']}</a><br/> Model info: <a href=\'{node['sota_info']['sota_link']}\' target='_blank'>{node['sota_info']['task']}</a></div>"
        )
        if "Note" in node:
            parts.append(
                f"<div style=\"font-size:16px; color: #a91212; text-align: left\"> {node['Note']}<a href=\'{node['alt_url']}\' target='_blank'>link</a></div>"
            )
        # Vertical spacer after each described model.
        parts.append("<div style=\"font-size:16px; color: #5f5f5f; text-align: left\"><br/></div>")
    parts.append("<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><b>Note:</b><br/>• Uploaded files are loaded into non-persistent memory for the duration of the computation. They are not cached</div>")
    parts.append("<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><br/><a href=\'https://github.com/taskswithcode/salient_object_detection_app.git\' target='_blank'>Github code</a> for this app</div>")
    return options_arr, "".join(parts)
def init_page():
    """Configure the Streamlit page (title, icon, layout, About menu) and draw the logo banner.

    Calls st.set_page_config, so it should run before other Streamlit output
    in a script run.
    """
    about_menu = {
        'About': 'This app was created by taskswithcode. http://taskswithcode.com'
    }
    st.set_page_config(
        page_title='TWC - State-of-the-art model salient object detection (visually dominant objects in an image)',
        page_icon="logo.jpg",
        layout='centered',
        initial_sidebar_state='auto',
        menu_items=about_menu,
    )
    # 85/15 split: the logo sits in the wide left column; the right one is padding.
    logo_col, _ = st.columns([85, 15])
    with logo_col:
        st.image("long_form_logo_with_icon.png")
def run_test(config, input_file_name, display_area, uploaded_file, mask_type):
    """Send one image to the inference server and return its reply.

    Args:
        config: Parsed app config; config["SERVER_ADDRESS"] is the URL the
            image is POSTed to.
        input_file_name: Path of the example image, used only when
            ``uploaded_file`` is None.
        display_area: Streamlit placeholder that shows a progress message.
        uploaded_file: File object from st.file_uploader, or None.
        mask_type: Value sent as the "mask" form field (e.g. "blur").

    Returns:
        {"response": <bytes>, "size": "<n,nnn>"} on HTTP 200. Otherwise
        {"error": <message>}, both for non-200 replies and for exceptions
        raised while reading the input or calling the server.
    """
    global TMP_SEED
    display_area.text("Processing request...")
    form = {"mask": mask_type}
    try:
        if uploaded_file is None:
            # The context manager closes the example file once the request returns.
            with open(input_file_name, "rb") as file_data:
                r = requests.post(config["SERVER_ADDRESS"], data=form, files={"test": file_data})
        else:
            # Upload straight from memory: nothing touches disk (matching the
            # page's "not cached" note), nothing is left behind if the request
            # fails, and TMP_DIR need not exist. The multipart filename keeps
            # the unique "<seed>_<timestamp>_<name>" scheme.
            upload_name = f"{TMP_SEED}_{str(time.time()).replace('.', '_')}_{uploaded_file.name}"
            TMP_SEED += 1
            r = requests.post(
                config["SERVER_ADDRESS"],
                data=form,
                files={"test": (upload_name, uploaded_file.read())},
            )
        print("Servers response:", r.status_code, len(r.content))
        if r.status_code == 200:
            return {"response": r.content, "size": "{:,}".format(len(r.content))}
        return {"error": f"API request failed {r.status_code}"}
    except Exception as e:
        # Same {"error": ...} contract as a non-200 reply; the caller renders it.
        return {"error": f"Exception in performing salient object detection: {str(e)}"}
def display_results(results, response_info, mask):
    """Render a successful detection result and stage it for download.

    Args:
        results: Successful result dict from run_test; results["response"]
            holds the image bytes returned by the server.
        response_info: Summary text (timing / size) shown above the image.
        mask: Mask type, shown in the image caption.

    Side effects:
        Stores the image bytes in st.session_state["download_ready"] and
        reports a "submit" event through get_views.
    """
    header = f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">{response_info}<br/><br/></div>"
    st.markdown(header, unsafe_allow_html=True)
    st.image(results["response"], caption=f'Output of salient object detection with mask: {mask}')
    st.session_state["download_ready"] = results["response"]
    get_views("submit")
def init_session():
    """Configure the page and reset the session-state keys used by the app.

    app_main calls this unconditionally at its start, so these defaults are
    rewritten on every run; "download_ready" stays None until display_results
    stores a new result.
    """
    print("Init session")
    init_page()
    # Prefix of the downloaded file name.
    # NOTE(review): confirm "ss_test" is the intended label for the model.
    st.session_state["model_name"] = "ss_test"
    st.session_state["download_ready"] = None
    st.session_state["file_name"] = "default"
    st.session_state["mask_type"] = "blur"
def _load_json(path):
    """Read and parse a UTF-8 JSON file."""
    with open(path, encoding="utf-8") as fp:
        return json.load(fp)


def _render_download_button():
    """Draw the PNG download button; it is disabled until a result is staged."""
    ready = st.session_state["download_ready"]
    # "<model>_<mask>_<file name minus last extension>.png", with "/" flattened
    # so example paths such as "twc_samples/x.jpg" become a plain file name.
    stem = '_'.join(st.session_state["file_name"].split(".")[:-1])
    file_name = (st.session_state["model_name"] + "_" + st.session_state["mask_type"] + "_" + stem + ".png").replace("/", "_")
    st.download_button(
        label="Download results as png",
        data=ready if ready is not None else "",
        disabled=ready is None,
        file_name=file_name,
        mime='image/png',
        key="download",
    )


def app_main(app_mode, example_files, model_name_files, config_file):
    """Render the salient-object-detection page and handle one submission.

    Args:
        app_mode: Use-case key into ``use_case``; it does not currently change
            the page.
        example_files: Path to the examples JSON. NOTE(review): not read; the
            example images offered by image_select are hard-coded below.
        model_name_files: Path to the JSON list of model descriptors passed
            to construct_model_info_for_display.
        config_file: Path to the JSON config passed to run_test (which reads
            its "SERVER_ADDRESS").
    """
    init_session()
    model_names = _load_json(model_name_files)
    config = _load_json(config_file)
    st.markdown("<h5 style='text-align: center;'>State-of-the-art model for salient object detection</h5>", unsafe_allow_html=True)
    st.markdown(f"<div style='color: #4f4f4f; text-align: left'>Use cases for salient object detection<br/> • {use_case['1']}<br/> • {use_case['2']}</div>", unsafe_allow_html=True)
    st.markdown(f"<div style='color: #9f9f9f; text-align: right'>views: {get_views('init')}</div>", unsafe_allow_html=True)
    try:
        with st.form('twc_form'):
            step1_line = "Upload an image or choose an example image below"
            uploaded_file = st.file_uploader(step1_line, type=["png", "jpg", "jpeg"])
            selected_file_name = image_select("Select image", ["twc_samples/sample1.jpg", "twc_samples/sample2.jpg", "twc_samples/sample3.jpg", "twc_samples/sample4.jpg"])
            st.write("")
            mask_label = st.selectbox(label='Select type of masking',
                                      options=list(mask_types), index=0, key="twc_mask_types")
            mask_type = mask_types[mask_label]
            st.write("")
            submit_button = st.form_submit_button('Run')
        _, markdown_str = construct_model_info_for_display(model_names)
        st.empty()  # status slot kept for layout; nothing is written to it
        display_area = st.empty()
        if submit_button:
            start = time.time()
            st.session_state["file_name"] = uploaded_file.name if uploaded_file is not None else selected_file_name
            st.session_state["mask_type"] = mask_type
            display_area.empty()
            results = run_test(config, st.session_state["file_name"], display_area, uploaded_file, mask_type)
            with display_area.container():
                if "error" in results:
                    st.error(results["error"])
                else:
                    # NOTE(review): this is the device of the machine running this
                    # Streamlit app; inference runs at config["SERVER_ADDRESS"].
                    device = 'GPU' if torch.cuda.is_available() else 'CPU'
                    response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for image size: {results['size']} bytes"
                    display_results(results, response_info, mask_type)
        _render_download_button()
    except Exception as e:
        st.error("Some error occurred during loading: " + str(e))
        st.stop()
    st.markdown(markdown_str, unsafe_allow_html=True)
if __name__ == "__main__":
    # Script entry point: use case "1" with the bundled examples, model list and server config.
    app_main(
        app_mode="1",
        example_files="sod_app_examples.json",
        model_name_files="sod_app_models.json",
        config_file="config.json",
    )