File size: 6,453 Bytes
85cce87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89dbdbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85cce87
89dbdbc
 
 
 
 
 
 
85cce87
89dbdbc
 
 
 
 
 
 
85cce87
 
 
 
 
 
 
 
 
89dbdbc
85cce87
 
 
 
 
34a420d
89dbdbc
85cce87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34a420d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import base64
import datetime
import os
import sys
from io import BytesIO
from pathlib import Path
import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
import time
import streamlit as st
from demo_config import HUGGING_FACE, WORKER_URL



# Name of the sub-directory (next to this script) that is put on sys.path.
PACKAGE_PARENT = 'wise'
# Absolute, symlink-resolved directory of this script. ``__file__`` is joined
# onto the current working directory so a relative ``__file__`` still resolves.
SCRIPT_DIR = os.path.dirname(
    os.path.realpath(
        os.path.join(os.getcwd(), os.path.expanduser(__file__))
    )
)
# Append <SCRIPT_DIR>/wise to the import path; presumably this is where the
# parameter_optimization and helpers modules imported below live.
sys.path.append(
    os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT))
)

from parameter_optimization.parametric_styletransfer import single_optimize
from parameter_optimization.parametric_styletransfer import CONFIG as ST_CONFIG
from parameter_optimization.strotss_org import strotss, pil_resize_long_edge_to
from helpers import torch_to_np, np_to_torch

def retrieve_for_results_from_server():
    """Download a finished server task's parameters and input image.

    Reads the task id from ``st.session_state['current_server_task_id']`` and
    fetches ``WORKER_URL + "/get_vp"`` and ``WORKER_URL + "/get_image"`` for
    it. The task id is cleared afterwards on every path.

    On success:
      * the ``"vp"`` array from the returned ``.npz`` payload is moved to CUDA
        and stored in ``st.session_state["result_vp"]``;
      * the returned image is decoded with PIL, converted with
        ``np_to_torch``, moved to CUDA and stored in
        ``st.session_state["effect_input"]``.

    On a non-200 response, a warning is shown for each failing request and
    nothing is stored.

    Raises:
        requests.HTTPError: if either request returned a 4xx/5xx status.
    """
    task_id = st.session_state['current_server_task_id']
    vp_url = WORKER_URL + "/get_vp"
    image_url = WORKER_URL + "/get_image"
    vp_res = requests.get(vp_url, params={"task_id": task_id})
    image_res = requests.get(image_url, params={"task_id": task_id})
    # The task id is consumed whether or not the download succeeds.
    st.session_state['current_server_task_id'] = None

    if vp_res.status_code != 200 or image_res.status_code != 200:
        # Report only the request(s) that actually failed, naming the real
        # endpoint and separating it from the status code.
        for url, res in ((vp_url, vp_res), (image_url, image_res)):
            if res.status_code != 200:
                st.warning(f"got status for {url}: {res.status_code}")
        vp_res.raise_for_status()
        image_res.raise_for_status()
        # Non-200 codes that are not errors (e.g. 204/3xx) end up here:
        # there is nothing to load.
        return

    vp = np.load(BytesIO(vp_res.content))["vp"]
    print("received vp from server")
    print("got numpy array", vp.shape)
    vp = torch.from_numpy(vp).cuda()
    image = Image.open(BytesIO(image_res.content))
    print("received image from server")
    image = np_to_torch(np.asarray(image)).cuda()

    st.session_state["effect_input"] = image
    st.session_state["result_vp"] = vp


def monitor_task(progress_placeholder):
    """Poll the worker until the current server task ends, showing progress.

    Polls ``WORKER_URL + "/get_status"`` every 2 seconds for the task id in
    ``st.session_state['current_server_task_id']``. It renders a warning and a
    progress element inside ``progress_placeholder`` according to the status:

    * ``"queued"``: shows the worker's queue length.
    * running with ``progress == 0.0``: fills the first half of the bar from
      elapsed time, assuming the STROTSS stage takes about 80 s.
    * running with non-zero ``progress``: maps it onto the second half.
    * any other status:
        - a non-empty ``msg`` is shown as an error, the task id is cleared and
          the script run is stopped with ``st.stop()``;
        - otherwise a ``"finished"`` task has its results fetched via
          ``retrieve_for_results_from_server()``;
        - the function then returns.

    Returns early after 3 consecutive non-200 status responses.
    """
    task_id = st.session_state['current_server_task_id']

    max_failures = 3
    failures_left = max_failures
    started_time = time.time()
    with progress_placeholder.container():
        st.warning("Do not interact with the app until results are shown - otherwise results might be lost.")
        progress_bar = st.empty()
        while True:
            response = requests.get(WORKER_URL + "/get_status", params={"task_id": task_id})
            if response.status_code != 200:
                print("get_status got status_code", response.status_code)
                st.warning(response.content)
                failures_left -= 1
                if failures_left == 0:
                    return
                time.sleep(2)
                continue
            # A successful poll restores the budget, so only *consecutive*
            # failures abort monitoring of a long-running task.
            failures_left = max_failures

            status = response.json()
            print(status)
            if status["status"] != "running" and status["status"] != "queued":
                if status["msg"] != "":
                    print("got error for task", task_id, ":", status["msg"])
                    progress_placeholder.error(status["msg"])
                    st.session_state['current_server_task_id'] = None
                    st.stop()
                if status["status"] == "finished":
                    retrieve_for_results_from_server()
                return
            elif status["status"] == "queued":
                # Restart the clock so the time-based estimate below starts
                # when the task leaves the queue.
                started_time = time.time()
                progress_bar.write(f"There are {get_queue_length()} tasks in the queue")
            elif status["progress"] == 0.0:
                # No reported progress yet: estimate the STROTSS stage (~80 s)
                # from elapsed time, capped at the halfway mark.
                progressed = min(0.5 * (time.time() - started_time) / 80.0, 0.5)
                progress_bar.progress(progressed)
            else:
                # Reported progress fills the second half of the bar.
                progress_bar.progress(min(0.5 + status["progress"] / 2.0, 1.0))

            time.sleep(2)

def get_queue_length():
    """Return the number of tasks the worker reports as queued.

    Issues a GET to ``WORKER_URL + "/queue_length"`` and reads the
    ``length`` field of the JSON body.
    """
    response = requests.get(WORKER_URL + "/queue_length")
    payload = response.json()
    return payload['length']


def _encode_jpeg(image):
    """Encode ``image`` as JPEG into an in-memory buffer rewound to the start."""
    # NOTE(review): PIL cannot write RGBA/P-mode images as JPEG; presumably
    # pil_resize_long_edge_to returns RGB images — confirm.
    buf = BytesIO()
    image.save(buf, format="JPEG")
    buf.seek(0)
    return buf


def optimize_on_server(content, style, result_image_placeholder):
    """Upload a content/style pair to the worker and wait for the result.

    Validation: if either image's height/width ratio is outside [0.5, 2.0],
    an error is shown in ``result_image_placeholder`` and the script run is
    stopped.

    Upload:
      * both images are resized so their long edge is 1024 px;
      * they are JPEG-encoded in memory and POSTed to
        ``WORKER_URL + "/upload"`` as ``style-image`` / ``content-image``.

    Outcome:
      * on success, the returned task id is stored in
        ``st.session_state['current_server_task_id']`` and ``monitor_task``
        polls it;
      * on a non-200 response, the body is shown as an error and the run is
        stopped.

    Args:
        content: content image (presumably a PIL image: uses ``.width``,
            ``.height`` and ``pil_resize_long_edge_to``).
        style: style image, same kind as ``content``.
        result_image_placeholder: Streamlit placeholder for errors/progress.
    """
    asp_c, asp_s = content.height / content.width, style.height / style.width
    if any(a < 0.5 or a > 2.0 for a in (asp_c, asp_s)):
        result_image_placeholder.error('aspect ratio must be <= 2')
        st.stop()

    # Encode in memory instead of via /tmp files. This avoids:
    #   * the two paths colliding (both used the "content-" prefix with
    #     separately taken timestamps);
    #   * file handles that were never closed;
    #   * temp files that were never deleted.
    stamp = str(datetime.datetime.timestamp(datetime.datetime.now()))
    content_buf = _encode_jpeg(pil_resize_long_edge_to(content, 1024))
    style_buf = _encode_jpeg(pil_resize_long_edge_to(style, 1024))
    files = {
        'style-image': (f"style-wise-uploaded{stamp}.jpg", style_buf),
        "content-image": (f"content-wise-uploaded{stamp}.jpg", content_buf),
    }
    print("start-optimizing. Time: ", datetime.datetime.now())
    url = WORKER_URL + "/upload"
    task_id_res = requests.post(url, files=files)
    if task_id_res.status_code != 200:
        result_image_placeholder.error(task_id_res.content)
        st.stop()
    else:
        task_id = task_id_res.json()['task_id']
        st.session_state['current_server_task_id'] = task_id

    monitor_task(result_image_placeholder)

def optimize_params(effect, preset, content, style, result_image_placeholder):
    """Optimize effect parameters locally against an NST reference image.

    Steps:
      1. Create a timestamped directory under ``result/``.
      2. Run STROTSS on CUDA with both images resized to a 1024 px long edge,
         producing a reference image.
      3. Save the reference (resized to a 720 px long edge) and the original
         ``content`` image there as JPEGs.
      4. Run ``single_optimize`` with the ``"l1"`` option for 300 iterations,
         updating a progress bar in ``result_image_placeholder``.
      5. Store the returned content tensor and parameters (detached; the
         parameters moved to CUDA) in ``st.session_state["effect_input"]``
         and ``st.session_state["result_vp"]``.

    Note: mutates the shared ``ST_CONFIG["n_iterations"]``.
    """
    result_image_placeholder.text("Executing NST to create reference image..")
    stamp_dir = f"result/{datetime.datetime.now().strftime(r'%Y-%m-%d %H.%Mh %Ss')}"
    # The timestamp has only second resolution, so two runs started in the
    # same second (e.g. from different sessions) would make os.makedirs raise
    # FileExistsError. Append -1, -2, ... until a free name is found.
    base_dir = stamp_dir
    suffix = 1
    while True:
        try:
            os.makedirs(base_dir)
            break
        except FileExistsError:
            base_dir = f"{stamp_dir}-{suffix}"
            suffix += 1

    reference = strotss(pil_resize_long_edge_to(content, 1024),
                        pil_resize_long_edge_to(style, 1024), content_weight=16.0,
                        device=torch.device("cuda"), space="uniform")
    progress_bar = result_image_placeholder.progress(0.0)
    ref_save_path = os.path.join(base_dir, "reference.jpg")
    content_save_path = os.path.join(base_dir, "content.jpg")
    resize_to = 720
    reference = pil_resize_long_edge_to(reference, resize_to)
    reference.save(ref_save_path)
    content.save(content_save_path)
    ST_CONFIG["n_iterations"] = 300

    vp, content_img_cuda = single_optimize(effect, preset, "l1", content_save_path, ref_save_path,
                                           write_video=False, base_dir=base_dir,
                                           iter_callback=lambda i: progress_bar.progress(
                                               float(i) / ST_CONFIG["n_iterations"]))
    st.session_state["effect_input"], st.session_state["result_vp"] = content_img_cuda.detach(), vp.cuda().detach()