File size: 6,001 Bytes
85cce87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import base64
import datetime
import os
import sys
from io import BytesIO
from pathlib import Path
import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
import time
import streamlit as st
from demo_config import HUGGING_FACE, WORKER_URL



# Add the ``wise`` directory that sits next to this script to ``sys.path`` so the
# ``parameter_optimization`` / ``helpers`` imports below can be resolved as
# top-level modules (presumably they live inside that directory). This must run
# before those imports.
PACKAGE_PARENT = 'wise'
# Absolute directory of this script: ``__file__`` (with ``~`` expanded) joined onto
# the current working directory, then resolved with ``realpath``.
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))

from parameter_optimization.parametric_styletransfer import single_optimize
from parameter_optimization.parametric_styletransfer import CONFIG as ST_CONFIG
from parameter_optimization.strotss_org import strotss, pil_resize_long_edge_to
from helpers import torch_to_np, np_to_torch

def retrieve_for_results_from_server():
    """Download the results of the current server task into the session state.

    Fetches the visual parameters from ``/get_vp`` (a NumPy archive read with
    ``np.load`` that contains the key ``"vp"``) and the effect input image
    from ``/get_image`` for ``st.session_state['current_server_task_id']``.
    On success both are converted to torch tensors, moved to CUDA and stored
    as ``st.session_state["result_vp"]`` and ``st.session_state["effect_input"]``.

    The task id is cleared in every case so the task is not polled again.

    Raises:
        requests.HTTPError: if either endpoint answers with a 4xx/5xx status
            (the ``/get_vp`` error is raised first if both fail).
    """
    task_id = st.session_state['current_server_task_id']
    params = {"task_id": task_id}
    vp_res = requests.get(WORKER_URL + "/get_vp", params=params)
    image_res = requests.get(WORKER_URL + "/get_image", params=params)
    # The task is consumed whether or not the download succeeds.
    st.session_state['current_server_task_id'] = None

    failed = [
        (endpoint, res)
        for endpoint, res in (("/get_vp", vp_res), ("/get_image", image_res))
        if res.status_code != 200
    ]
    if failed:
        # Only report the endpoint(s) that actually failed.
        for endpoint, res in failed:
            st.warning(f"got status for {WORKER_URL}{endpoint}: {res.status_code}")
        for _, res in failed:
            res.raise_for_status()
        # Non-200 statuses that are not errors (e.g. 204) leave nothing usable to store.
        return

    vp = np.load(BytesIO(vp_res.content))["vp"]
    print("received vp from server")
    print("got numpy array", vp.shape)
    vp = torch.from_numpy(vp).cuda()
    image = Image.open(BytesIO(image_res.content))
    print("received image from server")
    image = np_to_torch(np.asarray(image)).cuda()

    st.session_state["effect_input"] = image
    st.session_state["result_vp"] = vp


def monitor_task(progress_placeholder, poll_interval=2.0, max_failed_polls=3):
    """Poll the worker until the current server task leaves the queued/running state.

    The task id is read from ``st.session_state['current_server_task_id']``.
    Every ``poll_interval`` seconds ``/get_status`` is queried and
    ``progress_placeholder`` is updated:

    * ``"queued"``: shows the worker's queue length (if it can be fetched);
    * ``"running"`` with ``progress == 0.0``: fills the first half of the bar
      from elapsed time, assuming roughly 80 s for the STROTSS stage;
    * ``"running"`` otherwise: maps the reported progress onto the second half.

    On any other status the loop ends. A non-empty ``msg`` is shown as an
    error, the task id is cleared and ``st.stop()`` is called. A ``"finished"``
    task has its results fetched via ``retrieve_for_results_from_server``.

    Args:
        progress_placeholder: Streamlit placeholder used for status output.
        poll_interval: Seconds to wait between status requests.
        max_failed_polls: Number of non-200 ``/get_status`` responses (counted
            over the whole run, not consecutively) after which monitoring
            gives up and returns without clearing the task id.
    """
    task_id = st.session_state['current_server_task_id']

    started_time = time.time()
    failures_left = max_failed_polls
    while True:
        status_res = requests.get(WORKER_URL + "/get_status", params={"task_id": task_id})
        if status_res.status_code != 200:
            print("get_status got status_code", status_res.status_code)
            st.warning(status_res.content)
            failures_left -= 1
            if failures_left <= 0:
                return
            time.sleep(poll_interval)
            continue

        status = status_res.json()
        print(status)
        state = status["status"]
        if state not in ("running", "queued"):
            msg = status.get("msg", "")
            if msg != "":
                print("got error for task", task_id, ":", msg)
                progress_placeholder.error(msg)
                st.session_state['current_server_task_id'] = None
                st.stop()
            if state == "finished":
                retrieve_for_results_from_server()
            return

        if state == "queued":
            # Restart the clock so the elapsed-time estimate starts when the task leaves the queue.
            started_time = time.time()
            queue_res = requests.get(WORKER_URL + "/queue_length")
            if queue_res.status_code == 200:
                progress_placeholder.write(f"There are {queue_res.json()['length']} tasks in the queue")
            else:
                # The queue length is informational only; keep monitoring instead of crashing.
                print("queue_length got status_code", queue_res.status_code)
        elif status["progress"] == 0.0:
            progressed = min(0.5 * (time.time() - started_time) / 80.0, 0.5)  # estimate 80s for strotss
            progress_placeholder.progress(progressed)
        else:
            progress_placeholder.progress(min(0.5 + status["progress"] / 2.0, 1.0))

        time.sleep(poll_interval)


def optimize_on_server(content, style, result_image_placeholder):
    """Upload a content/style image pair to the worker and monitor the task.

    Both images must have a height/width ratio within [0.5, 2.0]. Otherwise an
    error is shown in ``result_image_placeholder`` and ``st.stop()`` is called.
    The images are resized to a 1024 px long edge and written as JPEGs into a
    private temporary directory. They are then POSTed to
    ``WORKER_URL + "/upload"``, and the directory is removed afterwards.

    On success the returned task id is stored in
    ``st.session_state['current_server_task_id']`` and ``monitor_task`` is run.
    On a non-200 upload response the body is shown as an error and
    ``st.stop()`` is called.

    Args:
        content: Content image (PIL-style image: ``.width``, ``.height``, ``.save``).
        style: Style image (same kind as ``content``).
        result_image_placeholder: Streamlit placeholder for errors and progress.
    """
    import tempfile  # local import: only this function needs scratch files

    url = WORKER_URL + "/upload"
    asp_c, asp_s = content.height / content.width, style.height / style.width
    if any(a < 0.5 or a > 2.0 for a in (asp_c, asp_s)):
        result_image_placeholder.error('aspect ratio must be <= 2')
        st.stop()
    content = pil_resize_long_edge_to(content, 1024)
    style = pil_resize_long_edge_to(style, 1024)

    stamp = str(datetime.datetime.now().timestamp())
    # Distinct file names per image (identical names would let one image
    # overwrite the other) inside a private temp dir that is cleaned up after
    # the upload.
    with tempfile.TemporaryDirectory() as tmp_dir:
        content_path = os.path.join(tmp_dir, f"content-wise-uploaded{stamp}.jpg")
        style_path = os.path.join(tmp_dir, f"style-wise-uploaded{stamp}.jpg")
        # JPEG cannot store alpha or palette modes; normalise to RGB before saving.
        content.convert("RGB").save(content_path)
        style.convert("RGB").save(style_path)
        print("start-optimizing")
        # Close both handles once the request is sent.
        with open(style_path, "rb") as style_file, open(content_path, "rb") as content_file:
            files = {'style-image': style_file, "content-image": content_file}
            task_id_res = requests.post(url, files=files)

    if task_id_res.status_code != 200:
        result_image_placeholder.error(task_id_res.content)
        st.stop()
    else:
        task_id = task_id_res.json()['task_id']
        st.session_state['current_server_task_id'] = task_id

    monitor_task(result_image_placeholder)

def _make_result_dir(stamp):
    """Create ``result/<stamp>`` and return its path, appending `` (2)``, `` (3)``, ... if it already exists."""
    base_dir = f"result/{stamp}"
    attempt = 1
    while True:
        try:
            # makedirs fails atomically if the directory exists, so two runs never share one.
            os.makedirs(base_dir)
            return base_dir
        except FileExistsError:
            # Another run started within the same second; try the next suffix.
            attempt += 1
            base_dir = f"result/{stamp} ({attempt})"


def optimize_params(effect, preset, content, style, result_image_placeholder):
    """Run the local parameter optimization for ``effect`` / ``preset``.

    Steps:
      1. Create a reference image with ``strotss`` on CUDA. Content and style
         are resized to a 1024 px long edge, with ``content_weight=16.0`` and
         ``space="uniform"``.
      2. Resize the reference to a 720 px long edge. Save it as
         ``reference.jpg``, and the unresized content as ``content.jpg``, into
         a fresh ``result/<timestamp>`` directory.
      3. Set ``ST_CONFIG["n_iterations"] = 300`` (this mutates the shared
         config). Then call ``single_optimize`` with the ``"l1"`` option
         (presumably the loss name — confirm), driving a progress bar from its
         iteration callback.
      4. Store the returned effect input (detached) and visual parameters
         (moved to CUDA, detached) in ``st.session_state["effect_input"]`` and
         ``st.session_state["result_vp"]``.

    Args:
        effect: Effect passed through to ``single_optimize``.
        preset: Preset passed through to ``single_optimize``.
        content: Content image (PIL-style image).
        style: Style image (PIL-style image).
        result_image_placeholder: Streamlit placeholder for status text and progress.
    """
    result_image_placeholder.text("Executing NST to create reference image..")
    base_dir = _make_result_dir(datetime.datetime.now().strftime(r'%Y-%m-%d %H.%Mh %Ss'))
    reference = strotss(pil_resize_long_edge_to(content, 1024),
                        pil_resize_long_edge_to(style, 1024), content_weight=16.0,
                        device=torch.device("cuda"), space="uniform")
    progress_bar = result_image_placeholder.progress(0.0)
    ref_save_path = os.path.join(base_dir, "reference.jpg")
    content_save_path = os.path.join(base_dir, "content.jpg")
    resize_to = 720
    reference = pil_resize_long_edge_to(reference, resize_to)
    reference.save(ref_save_path)
    content.save(content_save_path)
    ST_CONFIG["n_iterations"] = 300

    vp, content_img_cuda = single_optimize(effect, preset, "l1", content_save_path, ref_save_path,
                                           write_video=False, base_dir=base_dir,
                                           iter_callback=lambda i: progress_bar.progress(
                                               float(i) / ST_CONFIG["n_iterations"]))
    st.session_state["effect_input"], st.session_state["result_vp"] = content_img_cuda.detach(), vp.cuda().detach()