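"""Parameter-optimization helpers for the Streamlit style-transfer demo.

Tasks are either submitted to a remote worker (optimize_on_server) and polled
until finished, or run locally on the GPU via STROTSS and single_optimize
(optimize_params); results are stored in st.session_state.
"""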
import base64
import datetime
import os
import sys
from io import BytesIO
from pathlib import Path
import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
import time
import streamlit as st
from demo_config import HUGGING_FACE, WORKER_URL
PACKAGE_PARENT = 'wise'
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
from parameter_optimization.parametric_styletransfer import single_optimize
from parameter_optimization.parametric_styletransfer import CONFIG as ST_CONFIG
from parameter_optimization.strotss_org import strotss, pil_resize_long_edge_to
from helpers import torch_to_np, np_to_torch
def retrieve_for_results_from_server():
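    """Fetch the finished task's parameter array ("vp") and result image from the
    worker, move both to the GPU, and store them in st.session_state."""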
    task_id = st.session_state['current_server_task_id']
    vp_res = requests.get(WORKER_URL + "/get_vp", params={"task_id": task_id})
    image_res = requests.get(WORKER_URL + "/get_image", params={"task_id": task_id})
    if vp_res.status_code != 200 or image_res.status_code != 200:
        st.warning("got status for " + WORKER_URL + "/get_vp: " + str(vp_res.status_code))
        st.warning("got status for " + WORKER_URL + "/get_image: " + str(image_res.status_code))
        st.session_state['current_server_task_id'] = None
        vp_res.raise_for_status()
        image_res.raise_for_status()
    else:
        st.session_state['current_server_task_id'] = None
        vp = np.load(BytesIO(vp_res.content))["vp"]
        print("received vp from server")
        print("got numpy array", vp.shape)
        vp = torch.from_numpy(vp).cuda()
        image = Image.open(BytesIO(image_res.content))
        print("received image from server")
        image = np_to_torch(np.asarray(image)).cuda()
        st.session_state["effect_input"] = image
        st.session_state["result_vp"] = vp
def monitor_task(progress_placeholder):
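    """Poll the worker for the current task's status, show queue position and a
    progress bar, and fetch the results once the task is finished."""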
    task_id = st.session_state['current_server_task_id']
    started_time = time.time()
    retries = 3
    with progress_placeholder.container():
        st.warning("Do not interact with the app until results are shown - otherwise results might be lost.")
        progress_bar = st.empty()
        while True:
            status = requests.get(WORKER_URL + "/get_status", params={"task_id": task_id})
            if status.status_code != 200:
                print("get_status got status_code", status.status_code)
                st.warning(status.content)
                retries -= 1
                if retries == 0:
                    return
                else:
                    time.sleep(2)
                    continue
            status = status.json()
            print(status)
            if status["status"] != "running" and status["status"] != "queued":
                if status["msg"] != "":
                    print("got error for task", task_id, ":", status["msg"])
                    progress_placeholder.error(status["msg"])
                    st.session_state['current_server_task_id'] = None
                    st.stop()
                if status["status"] == "finished":
                    retrieve_for_results_from_server()
                    return
            elif status["status"] == "queued":
                started_time = time.time()
                queue_length = requests.get(WORKER_URL + "/queue_length").json()
                progress_bar.write(f"There are {queue_length['length']} tasks in the queue")
            elif status["progress"] == 0.0:
                # first half of the bar: rough time estimate (~80 s) for the STROTSS pass
                progressed = min(0.5 * (time.time() - started_time) / 80.0, 0.5)
                progress_bar.progress(progressed)
            else:
                # second half: progress reported by the parameter optimization
                progress_bar.progress(min(0.5 + status["progress"] / 2.0, 1.0))
            time.sleep(2)
def get_queue_length():
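    """Return the number of tasks currently waiting in the worker's queue."""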
    queue_length = requests.get(WORKER_URL + "/queue_length").json()
    return queue_length['length']
def optimize_on_server(content, style, result_image_placeholder):
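    """Resize and upload the content and style images to the worker, start a remote
    optimization task, and monitor it until its results are available."""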
    content_path = f"/tmp/content-wise-uploaded{str(datetime.datetime.timestamp(datetime.datetime.now()))}.jpg"
    style_path = f"/tmp/style-wise-uploaded{str(datetime.datetime.timestamp(datetime.datetime.now()))}.jpg"
    asp_c, asp_s = content.height / content.width, style.height / style.width
    if any(a < 0.5 or a > 2.0 for a in (asp_c, asp_s)):
        result_image_placeholder.error('aspect ratio must be between 0.5 and 2')
        st.stop()
    content = pil_resize_long_edge_to(content, 1024)
    content.save(content_path)
    style = pil_resize_long_edge_to(style, 1024)
    style.save(style_path)
    files = {'style-image': open(style_path, "rb"), "content-image": open(content_path, "rb")}
    print("start-optimizing")
    url = WORKER_URL + "/upload"
    task_id_res = requests.post(url, files=files)
    if task_id_res.status_code != 200:
        result_image_placeholder.error(task_id_res.content)
        st.stop()
    else:
        task_id = task_id_res.json()['task_id']
        st.session_state['current_server_task_id'] = task_id
    monitor_task(result_image_placeholder)
def optimize_params(effect, preset, content, style, result_image_placeholder):
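    """Run the optimization locally: compute a STROTSS reference image, then fit the
    effect's parameters to it with single_optimize; store results in st.session_state."""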
    result_image_placeholder.text("Executing NST to create reference image..")
    base_dir = f"result/{datetime.datetime.now().strftime(r'%Y-%m-%d %H.%Mh %Ss')}"
    os.makedirs(base_dir)
    reference = strotss(pil_resize_long_edge_to(content, 1024),
                        pil_resize_long_edge_to(style, 1024), content_weight=16.0,
                        device=torch.device("cuda"), space="uniform")
    progress_bar = result_image_placeholder.progress(0.0)
    ref_save_path = os.path.join(base_dir, "reference.jpg")
    content_save_path = os.path.join(base_dir, "content.jpg")
    resize_to = 720
    reference = pil_resize_long_edge_to(reference, resize_to)
    reference.save(ref_save_path)
    content.save(content_save_path)
    ST_CONFIG["n_iterations"] = 300
    vp, content_img_cuda = single_optimize(effect, preset, "l1", content_save_path, str(ref_save_path),
                                           write_video=False, base_dir=base_dir,
                                           iter_callback=lambda i: progress_bar.progress(
                                               float(i) / ST_CONFIG["n_iterations"]))
    st.session_state["effect_input"], st.session_state["result_vp"] = content_img_cuda.detach(), vp.cuda().detach()