Max Reimann
Integrate servertask into demo app, add dockerfile
85cce87
raw
history blame
6 kB
import base64
import datetime
import os
import sys
from io import BytesIO
from pathlib import Path
import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
import time
import streamlit as st
from demo_config import HUGGING_FACE, WORKER_URL
PACKAGE_PARENT = 'wise'
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
from parameter_optimization.parametric_styletransfer import single_optimize
from parameter_optimization.parametric_styletransfer import CONFIG as ST_CONFIG
from parameter_optimization.strotss_org import strotss, pil_resize_long_edge_to
from helpers import torch_to_np, np_to_torch
def retrieve_for_results_from_server():
task_id = st.session_state['current_server_task_id']
vp_res = requests.get(WORKER_URL+"/get_vp", params={"task_id": task_id})
image_res = requests.get(WORKER_URL+"/get_image", params={"task_id": task_id})
if vp_res.status_code != 200 or image_res.status_code != 200:
st.warning("got status for " + WORKER_URL+"/get_vp" + str(vp_res.status_code))
st.warning("got status for " + WORKER_URL+"/image_res" + str(image_res.status_code))
st.session_state['current_server_task_id'] = None
vp_res.raise_for_status()
image_res.raise_for_status()
else:
st.session_state['current_server_task_id'] = None
vp = np.load(BytesIO(vp_res.content))["vp"]
print("received vp from server")
print("got numpy array", vp.shape)
vp = torch.from_numpy(vp).cuda()
image = Image.open(BytesIO(image_res.content))
print("received image from server")
image = np_to_torch(np.asarray(image)).cuda()
st.session_state["effect_input"] = image
st.session_state["result_vp"] = vp
def monitor_task(progress_placeholder):
task_id = st.session_state['current_server_task_id']
started_time = time.time()
retries = 3
while True:
status = requests.get(WORKER_URL+"/get_status", params={"task_id": task_id})
if status.status_code != 200:
print("get_status got status_code", status.status_code)
st.warning(status.content)
retries -= 1
if retries == 0:
return
else:
time.sleep(2)
continue
status = status.json()
print(status)
if status["status"] != "running" and status["status"] != "queued" :
if status["msg"] != "":
print("got error for task", task_id, ":", status["msg"])
progress_placeholder.error(status["msg"])
st.session_state['current_server_task_id'] = None
st.stop()
if status["status"] == "finished":
retrieve_for_results_from_server()
return
elif status["status"] == "queued":
started_time = time.time()
queue_length = requests.get(WORKER_URL+"/queue_length").json()
progress_placeholder.write(f"There are {queue_length['length']} tasks in the queue")
elif status["progress"] == 0.0:
progressed = min(0.5 * (time.time() - started_time) / 80.0, 0.5) #estimate 80s for strotts
progress_placeholder.progress(progressed)
else:
progress_placeholder.progress(min(0.5 + status["progress"] / 2.0, 1.0))
time.sleep(2)
def optimize_on_server(content, style, result_image_placeholder):
url = WORKER_URL + "/upload"
content_path=f"/tmp/content-wise-uploaded{str(datetime.datetime.timestamp(datetime.datetime.now()))}.jpg"
style_path=f"/tmp/content-wise-uploaded{str(datetime.datetime.timestamp(datetime.datetime.now()))}.jpg"
asp_c, asp_s = content.height / content.width, style.height / style.width
if any(a < 0.5 or a > 2.0 for a in (asp_c, asp_s)):
result_image_placeholder.error('aspect ratio must be <= 2')
st.stop()
content = pil_resize_long_edge_to(content, 1024)
content.save(content_path)
style = pil_resize_long_edge_to(style, 1024)
style.save(style_path)
files = {'style-image': open(style_path, "rb"), "content-image": open(content_path, "rb")}
print("start-optimizing")
task_id_res = requests.post(url, files=files)
if task_id_res.status_code != 200:
result_image_placeholder.error(task_id_res.content)
st.stop()
else:
task_id = task_id_res.json()['task_id']
st.session_state['current_server_task_id'] = task_id
monitor_task(result_image_placeholder)
def optimize_params(effect, preset, content, style, result_image_placeholder):
result_image_placeholder.text("Executing NST to create reference image..")
base_dir = f"result/{datetime.datetime.now().strftime(r'%Y-%m-%d %H.%Mh %Ss')}"
os.makedirs(base_dir)
reference = strotss(pil_resize_long_edge_to(content, 1024),
pil_resize_long_edge_to(style, 1024), content_weight=16.0,
device=torch.device("cuda"), space="uniform")
progress_bar = result_image_placeholder.progress(0.0)
ref_save_path = os.path.join(base_dir, "reference.jpg")
content_save_path = os.path.join(base_dir, "content.jpg")
resize_to = 720
reference = pil_resize_long_edge_to(reference, resize_to)
reference.save(ref_save_path)
content.save(content_save_path)
ST_CONFIG["n_iterations"] = 300
vp, content_img_cuda = single_optimize(effect, preset, "l1", content_save_path, str(ref_save_path),
write_video=False, base_dir=base_dir,
iter_callback=lambda i: progress_bar.progress(
float(i) / ST_CONFIG["n_iterations"]))
st.session_state["effect_input"], st.session_state["result_vp"] = content_img_cuda.detach(), vp.cuda().detach()