"""HTTP front end for the parametric neural style transfer optimizer.

Clients upload a content and a style image, receive a task ID, poll the task
status and finally download the stylized image and the optimized visual
parameters.  Tasks are executed strictly one after another by a background
scheduler thread.
"""
import datetime
import os
from pathlib import Path
import secrets
import sys

from flask import Flask, jsonify, request, send_file, abort
from flask_uploads import UploadSet, configure_uploads, IMAGES, UploadNotAllowed
from werkzeug.exceptions import default_exceptions
from werkzeug.exceptions import HTTPException, NotFound
import json
import torch
import time
import threading
import traceback
from PIL import Image
import numpy as np

PACKAGE_PARENT = '..'
WISE_DIR = '../wise/'
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
# Make the sibling packages importable when this file is run as a script.
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, WISE_DIR)))

from parameter_optimization.parametric_styletransfer import single_optimize
from parameter_optimization.parametric_styletransfer import CONFIG as ST_CONFIG
from parameter_optimization.strotss_org import strotss, pil_resize_long_edge_to
from helpers import torch_to_np, np_to_torch
from effects import get_default_settings, MinimalPipelineEffect


class JSONExceptionHandler(object):
    """Registers handlers that turn HTTP errors into JSON responses."""

    def __init__(self, app=None):
        if app:
            self.init_app(app)

    def std_handler(self, error):
        """Build a JSON response ``{"message": ...}`` for ``error``.

        The status code is the exception's HTTP code, or 500 when the error
        is not an ``HTTPException`` (or carries no code).
        """
        # Werkzeug's HTTPException exposes its text as ``description``; it
        # has no ``message`` attribute.
        message = getattr(error, "description", None) or str(error)
        response = jsonify(message=message)
        code = error.code if isinstance(error, HTTPException) else None
        response.status_code = code or 500
        return response

    def init_app(self, app):
        """Attach the JSON handler to ``app`` for all default HTTP errors."""
        self.app = app
        self.register(HTTPException)
        for code in default_exceptions:
            self.register(code)

    def register(self, exception_or_code, handler=None):
        """Register ``handler`` (default: ``std_handler``) on the app."""
        self.app.errorhandler(exception_or_code)(handler or self.std_handler)


app = Flask(__name__)
handler = JSONExceptionHandler(app)

image_folder = 'img_received'
photos = UploadSet('photos', IMAGES)
app.config['UPLOADED_PHOTOS_DEST'] = image_folder
configure_uploads(app, photos)


class Args(object):
    """Simple attribute bag initialised from a dictionary."""

    def __init__(self, initial_data):
        self.set_attributes(initial_data)

    def set_attributes(self, val_dict):
        """Set one attribute per key/value pair of ``val_dict``."""
        for key, value in val_dict.items():
            setattr(self, key, value)


default_args = {
    "output_image": "output.jpg",
    ## values always set by request ##
    "content_image": "",
    "style_image": "",
    "output_vp": "",
    "iters": 500,
}

total_task_count = 0


class NeuralOptimizer():
    """Runs one style transfer optimization described by ``args``."""

    def __init__(self, args) -> None:
        # Updated through the iteration callback; read by /get_status.
        self.cur_iteration = 0
        self.args = args

    def optimize(self):
        """Stylize ``args.content_image`` with ``args.style_image``.

        Writes the stylized image to ``args.output_image`` and the optimized
        parameters to ``args.output_vp`` (compressed ``.npz`` with key
        ``vp``).  Intermediate results go to a timestamped ``result/``
        sub-directory.
        """
        base_dir = f"result/{datetime.datetime.now().strftime(r'%Y-%m-%d %H.%Mh %Ss')}"
        # exist_ok: the timestamp only has second resolution.
        os.makedirs(base_dir, exist_ok=True)

        # Parameter name kept as-is: the callee's calling convention for the
        # callback is not visible from this file.
        def set_iter(iter):
            self.cur_iteration = iter

        effect, preset, _ = get_default_settings("minimal_pipeline")
        effect.enable_checkpoints()

        with Image.open(self.args.content_image) as content, \
                Image.open(self.args.style_image) as style:
            reference = strotss(
                pil_resize_long_edge_to(content, 1024),
                pil_resize_long_edge_to(style, 1024),
                content_weight=16.0,
                device=torch.device("cuda"),
                space="uniform",
            )

        ref_save_path = os.path.join(base_dir, "reference.jpg")
        resize_to = 720
        reference = pil_resize_long_edge_to(reference, resize_to)
        reference.save(ref_save_path)

        ST_CONFIG["n_iterations"] = self.args.iters
        vp, content_img_cuda = single_optimize(
            effect, preset, "l1", self.args.content_image, str(ref_save_path),
            write_video=False, base_dir=base_dir, iter_callback=set_iter,
        )

        output = Image.fromarray(
            torch_to_np(content_img_cuda.detach().cpu() * 255.0).astype(np.uint8)
        )
        output.save(self.args.output_image)

        # torch.save (vp.detach().clone(), self.args.output_vp)
        # preset_tensor = effect.vpd.preset_tensor(preset, np_to_torch(np.array(content)).cuda(), add_local_dims=True)
        np.savez_compressed(self.args.output_vp, vp=vp.detach().cpu().numpy())


class StyleTask:
    """One queued stylization job and its current status.

    ``status`` is one of ``"queued"``, ``"running"``, ``"finished"`` or
    ``"error"``; ``error_msg`` holds the exception text in the error case.
    """

    def __init__(self, task_id, style_filename, content_filename):
        self.content_filename = content_filename
        self.style_filename = style_filename

        self.status = "queued"
        self.task_id = task_id
        self.error_msg = ""

        # splitext keeps inner dots, so "a.b.jpg" and "a.c.jpg" do not both
        # map to "a_output.jpg".
        stem = os.path.splitext(content_filename)[0]
        self.output_filename = stem + "_output.jpg"
        self.vp_output_filename = stem + "_output.npz"

        # global neural_optimizer
        # if neural_optimizer is None:
        #     neural_optimizer = NeuralOptimizer(Args(default_args))
        self.neural_optimizer = NeuralOptimizer(Args(default_args))

    def start(self):
        """Fill in the file paths and run the optimization in a daemon thread."""
        args = self.neural_optimizer.args
        args.set_attributes(default_args)
        args.style_image = os.path.join(image_folder, self.style_filename)
        args.content_image = os.path.join(image_folder, self.content_filename)
        args.output_image = os.path.join(image_folder, self.output_filename)
        args.output_vp = os.path.join(image_folder, self.vp_output_filename)

        thread = threading.Thread(target=self.run, args=())
        thread.daemon = True  # Daemonize thread
        thread.start()  # Start the execution

    def run(self):
        """Execute the optimization and record the outcome in ``status``."""
        self.status = "running"
        try:
            self.neural_optimizer.optimize()
        except Exception as e:
            # Top-level boundary of the worker thread: report and keep the
            # server alive.
            print("Error in task %d :"%(self.task_id), str(e))
            traceback.print_exc()

            self.status = "error"
            self.error_msg = str(e)
            return

        self.status = "finished"
        print("finished styling task: " + str(self.task_id))


class StylerQueue:
    """FIFO scheduler that runs at most one ``StyleTask`` at a time."""

    def __init__(self):
        # Per-instance state; the lock guards it because request threads and
        # the scheduler thread both modify it.
        self.queued_tasks = []
        self.finished_tasks = []
        self.running_task = None
        self._lock = threading.Lock()

        thread = threading.Thread(target=self.status_checker, args=())
        thread.daemon = True  # Daemonize thread
        thread.start()  # Start the execution

    def queue_task(self, *args):
        """Create a ``StyleTask(task_id, *args)``, enqueue it, return its ID."""
        global total_task_count

        # Random 63-bit ID: non-negative int like the previous hash-based
        # IDs, but without collisions for uploads in the same clock tick.
        task_id = secrets.randbits(63)
        task = StyleTask(task_id, *args)
        with self._lock:
            total_task_count += 1
            count = total_task_count
            self.queued_tasks.append(task)

        print("queued task num. ", count, "with ID", task_id)
        return task_id

    def get_task(self, task_id):
        """Return the running, queued or finished task with ``task_id``, else None."""
        with self._lock:
            candidates = [self.running_task, *self.queued_tasks, *self.finished_tasks]
        return next(
            (task for task in candidates
             if task is not None and task.task_id == task_id),
            None,
        )

    def status_checker(self):
        """Scheduler loop: retire the finished task and start the next one."""
        while True:
            time.sleep(0.3)

            with self._lock:
                current = self.running_task
                if current is not None:
                    if current.status not in ("finished", "error"):
                        continue
                    self.finished_tasks.append(current)
                    self.running_task = None

                next_task = self.queued_tasks.pop(0) if self.queued_tasks else None
                self.running_task = next_task

            if next_task is not None:
                next_task.start()


styler_queue = StylerQueue()


def _json_error(message, status_code=400):
    """Return a JSON response ``{"message": message}`` with ``status_code``."""
    response = jsonify(message=message)
    response.status_code = status_code
    return response


def _task_from_request(require_finished=False):
    """Resolve the ``task_id`` query parameter to a task.

    Returns ``(task, None)`` on success or ``(None, error_response)`` when
    the parameter is missing/invalid, the task is unknown, or
    ``require_finished`` is set and the task has not finished.
    """
    task_id = request.args.get("task_id", type=int)
    if task_id is None:
        return None, _json_error("request needs an integer task_id")

    task = styler_queue.get_task(task_id)
    if task is None:
        return None, _json_error("task with id %d not found"%task_id)

    if require_finished and task.status != "finished":
        return None, _json_error("task with id %d not in finished state"%task_id)

    return task, None


@app.route('/upload', methods=['POST'])
def upload():
    """Store the uploaded style/content images and queue a stylization task."""
    if 'style-image' not in request.files or 'content-image' not in request.files:
        return _json_error("request needs style, content image")

    try:
        style_filename = photos.save(request.files['style-image'])
        content_filename = photos.save(request.files['content-image'])
    except UploadNotAllowed:
        return _json_error("uploaded file type is not allowed")

    job_id = styler_queue.queue_task(style_filename, content_filename)
    print('added new stylization task', style_filename, content_filename)

    return jsonify({"task_id": job_id})


@app.route('/get_status')
def get_status():
    """Report status, error message and (while running) progress of a task."""
    task, error = _task_from_request()
    if error is not None:
        return error

    status = {
        "status": task.status,
        "msg": task.error_msg
    }

    if task.status == "running":
        if isinstance(task, StyleTask):
            optimizer = task.neural_optimizer
            status["progress"] = float(optimizer.cur_iteration) / float(optimizer.args.iters)

    return jsonify(status)


@app.route('/queue_length')
def get_queue_length():
    """Return the number of waiting tasks plus the running one, if any."""
    tasks = len(styler_queue.queued_tasks)
    if styler_queue.running_task is not None:
        tasks += 1

    status = {
        "length": tasks
    }

    return jsonify(status)


@app.route('/get_image')
def get_image():
    """Send the stylized JPEG of a finished task."""
    task, error = _task_from_request(require_finished=True)
    if error is not None:
        return error

    # Absolute path: the file was written relative to the working directory.
    # NOTE(review): send_file appears to resolve relative paths against the
    # app root rather than the working directory - confirm.
    path = os.path.abspath(os.path.join(image_folder, task.output_filename))
    return send_file(path, mimetype='image/jpeg')


@app.route('/get_vp')
def get_vp():
    """Send the compressed ``.npz`` parameter file of a finished task."""
    task, error = _task_from_request(require_finished=True)
    if error is not None:
        return error

    path = os.path.abspath(os.path.join(image_folder, task.vp_output_filename))
    return send_file(path, mimetype='application/zip')


if __name__ == '__main__':
    app.run(debug=False, host="0.0.0.0", port=8600)