#!/usr/bin/env python3
# coding=utf-8
#
# Copyright 2020 Institute of Formal and Applied Linguistics, Faculty of
# Mathematics and Physics, Charles University, Czech Republic.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Frontend REST server distributing requests among backend servers.

The frontend answers `/models` itself (merging the model lists reported by
all backends) and forwards every other request to one of the backends able to
serve the requested model, streaming the backend's response back to the client.
"""
import email.parser
import http.server
import json
import random
import socketserver
import sys
import threading
import time
import urllib.error
import urllib.parse
import urllib.request


class FrontendRESTServer(socketserver.TCPServer):
    """TCP server handling every connection in its own thread and forwarding requests to backends."""

    class Backend():
        """One backend server together with the models it reported at startup."""

        def __init__(self, server):
            """Connect to `server` and fetch its `/models` description.

            `server` is placed right after `http://` in every request URL, so it is
            the backend authority (presumably `host:port` -- confirm with deployment).

            Raises:
                ValueError: if the `/models` response lacks a `models` dict or a
                    `default_model` string.
            """
            self._server = server

            with self.request("/models") as response:
                data = json.loads(response.read())
            if not ("models" in data and isinstance(data["models"], dict)):
                raise ValueError("Backend '{}' did not report a 'models' dictionary.".format(server))
            self.models = data["models"]
            if not ("default_model" in data and isinstance(data["default_model"], str)):
                raise ValueError("Backend '{}' did not report a 'default_model' string.".format(server))
            self.default_model = data["default_model"]

        def request(self, url, data=None, data_content_type=None):
            """Send a request for `url` to this backend and return the `urlopen` response.

            A POST is issued when `data` is given (with `data_content_type` as its
            Content-Type), a GET otherwise. Error HTTP statuses raise
            `urllib.error.HTTPError`.
            """
            return urllib.request.urlopen(urllib.request.Request(
                url="http://{}{}".format(self._server, url),
                data=data,
                headers={} if data is None else {"Content-Type": data_content_type},
            ))

    class FrontendRESTServer(http.server.BaseHTTPRequestHandler):
        """HTTP request handler; throughout, `request` plays the role of `self`."""

        protocol_version = "HTTP/1.1"

        # Newlines inside logged values become CRs and existing CRs are dropped,
        # so each print() below emits exactly one '\n'-terminated log line.
        format_for_log_table = str.maketrans("\n", "\r", "\r")

        def format_for_log(request, data, limit=None):
            """Return `data` made single-line for logging, optionally shortened.

            With `limit <= 0` only the length is logged as "[<len>B]"; with a positive
            `limit`, longer values keep roughly their first and last `limit / 2`
            characters joined by " ... ".
            """
            if limit is not None:
                if limit <= 0:
                    data = "[{}B]".format(len(data))
                elif len(data) > limit:
                    # min(-1, ...) keeps the tail slice non-empty even for limit == 1.
                    data = data[:limit // 2] + " ... " + data[min(-1, -limit // 2):]
            return data.translate(request.format_for_log_table)

        def respond(request, content_type, code=200, additional_headers=None):
            """Send the status line and headers; the connection is closed after the body."""
            request.close_connection = True
            request.send_response(code)
            request.send_header("Connection", "close")
            request.send_header("Content-Type", content_type)
            request.send_header("Access-Control-Allow-Origin", "*")
            for key, value in (additional_headers or {}).items():
                request.send_header(key, value)
            request.end_headers()

        def respond_error(request, message, code=400):
            """Send `message` as a text/plain error response with the given status `code`."""
            request.respond("text/plain", code)
            request.wfile.write(message.encode("utf-8"))

        def handle_expect_100(request):
            """Refuse `Expect: 100-continue` requests whose declared payload is too large."""
            try:
                request_too_long = int(request.headers["Content-Length"]) > request.server._args.max_request_size
            except (TypeError, ValueError):
                # Missing (None) or malformed Content-Length; do_GET reports it later.
                request_too_long = False
            if request_too_long:
                request.respond_error("The payload size is too large.")
                return False
            return super().handle_expect_100()

        def do_GET(request):
            """Handle a GET (and, via do_POST, a POST) request.

            Parameters are collected from the URL query and, for POST, from a
            multipart/form-data or application/x-www-form-urlencoded payload.
            `/models` is answered locally; any other path is forwarded unchanged
            (path and raw body) to a backend, trying further backends when one
            fails before any response byte has been sent.
            """
            # Parse the parameters from the URL
            params, body, body_content_type = {}, None, None
            try:
                # BaseHTTPRequestHandler decodes the request line as ISO-8859-1;
                # re-encode it to recover the raw bytes and decode them as UTF-8.
                encoded_path = request.path.encode("iso-8859-1").decode("utf-8")
                url = urllib.parse.urlparse(encoded_path)
                for name, value in urllib.parse.parse_qsl(url.query, encoding="utf-8", keep_blank_values=True, errors="strict"):
                    params[name] = value
            except ValueError:  # includes UnicodeEncodeError/UnicodeDecodeError
                return request.respond_error("Cannot parse request URL.")

            # Parse the body of a POST request
            if request.command == "POST":
                if request.headers.get("Transfer-Encoding", "identity").lower() != "identity":
                    return request.respond_error("Only 'identity' Transfer-Encoding of payload is supported for now.")

                try:
                    content_length = int(request.headers["Content-Length"])
                    # A negative length would make rfile.read() block until the client
                    # closes the connection, so treat it as invalid.
                    if content_length < 0:
                        raise ValueError("negative Content-Length")
                except (TypeError, ValueError):
                    return request.respond_error("The Content-Length of payload is required.")

                if content_length > request.server._args.max_request_size:
                    # Consume and discard the whole oversized payload before answering.
                    while content_length:
                        read = request.rfile.read(min(content_length, 65536))
                        content_length -= len(read) if read else content_length
                    return request.respond_error("The payload size is too large.")

                body = request.rfile.read(content_length)
                body_content_type = request.headers.get("Content-Type", "")

                # multipart/form-data
                if body_content_type.startswith("multipart/form-data"):
                    try:
                        parser = email.parser.BytesFeedParser()
                        parser.feed(b"Content-Type: " + body_content_type.encode("ascii") + b"\r\n\r\n")
                        parser.feed(body)
                        for part in parser.close().get_payload():
                            name = part.get_param("name", header="Content-Disposition")
                            if name:
                                params[name] = part.get_payload(decode=True).decode("utf-8")
                    except Exception:
                        return request.respond_error("Cannot parse the multipart/form-data payload.")
                # x-www-form-urlencoded
                elif body_content_type.startswith("application/x-www-form-urlencoded"):
                    try:
                        for name, value in urllib.parse.parse_qsl(
                                body.decode("utf-8"), encoding="utf-8", keep_blank_values=True, errors="strict"):
                            params[name] = value
                    except ValueError:
                        return request.respond_error("Cannot parse the application/x-www-form-urlencoded payload.")

            # Log if required; `--log_data 0` (or less) logs only the size of `data`.
            if request.server._args.log_data is not None:
                print(url.path,
                      " ".join(request.headers.get_all("X-Forwarded-For", [])),
                      *["{}:{}".format(key, request.format_for_log(value)) for key, value in params.items() if key != "data"],
                      "data:" + request.format_for_log(params.get("data", ""), request.server._args.log_data),
                      sep="\t", file=sys.stderr, flush=True)

            # Handle /models locally by merging the models of all backends
            if url.path == "/models":
                response = {
                    "models": {name: value for backend in request.server.backends for name, value in backend.models.items()},
                    "default_model": request.server.backends[0].default_model,
                }
                request.respond("application/json")
                request.wfile.write(json.dumps(response, indent=1).encode("utf-8"))
                return

            # Handle everything else; start by finding appropriate backends. When the
            # model is a known alias, prefer the backends serving the resolved model.
            backends = request.server.backends.copy()
            model = params.get("model", request.server.backends[0].default_model)
            if model in request.server.aliases:
                resolved_model = request.server.aliases[model]
                backends = [backend for backend in request.server.backends if resolved_model in backend.models] or backends

            # Forward the request to a randomly chosen backend, retrying on others
            # only while nothing has been sent to the client yet.
            started_responding = False
            try:
                assert backends, "No backends found!"
                while backends:
                    backend = random.choice(backends) if len(backends) > 1 else backends[0]
                    backends.remove(backend)
                    try:
                        with backend.request(request.path, body, body_content_type) as response:
                            while True:
                                data = response.read(32768)
                                if not started_responding:
                                    started_responding = True
                                    billing_infclen = response.getheader("X-Billing-Input-NFC-Len", None)
                                    headers = {"X-Billing-Input-NFC-Len": billing_infclen} if billing_infclen is not None else {}
                                    request.respond(response.getheader("Content-Type", "application/json"),
                                                    code=response.code, additional_headers=headers)
                                if len(data) == 0:
                                    break
                                request.wfile.write(data)
                    except urllib.error.HTTPError as error:
                        # The backend answered with an error status: relay it as-is.
                        if not started_responding:
                            started_responding = True
                            try:
                                request.respond(error.headers.get("Content-Type", "text/plain"), code=error.code)
                                request.wfile.write(error.file.read())
                            finally:
                                error.close()
                            break
                        raise
                    except Exception:
                        if backends and not started_responding:
                            import traceback
                            traceback.print_exc(file=sys.stderr)
                            print("The above error occurred during request processing on '{}',".format(backend._server),
                                  "but more backends are available, retrying.", file=sys.stderr, flush=True)
                            continue
                        raise
                    break
            except Exception:
                import traceback
                traceback.print_exc(file=sys.stderr)
                sys.stderr.flush()

                if not started_responding:
                    request.respond_error("An internal error occurred during processing.")
                else:
                    # Headers (and part of the body) are already sent; append a marker
                    # so the client can tell the output is broken.
                    request.wfile.write(b'",\n"An internal error occurred during processing, producing incorrect JSON!"')

        def do_POST(request):
            """Handle a POST request exactly like a GET one (see do_GET)."""
            return request.do_GET()

    def __init__(self, args):
        """Discover the backends, load model aliases and bind the listening socket.

        `args` is the parsed command line (see the `__main__` section): it provides
        `backends`, `aliases`, `port`, `log_data`, `max_concurrency` and
        `max_request_size`.
        """
        self._args = args

        # Initialize all backends
        self.backends = [self.Backend(backend) for backend in args.backends]

        # Initialize the aliases
        self.aliases = {} if args.aliases is None else self._load_aliases(args.aliases)

        # Initialize the server
        self._threads = []
        super().__init__(("", self._args.port), self.FrontendRESTServer)

    @staticmethod
    def _load_aliases(path):
        """Parse an aliases file into a dict mapping every alias to its canonical model name.

        Each non-empty, non-`#` line has 3-4 whitespace-separated columns; the first
        one is a `:`-separated list of names, the first of which is canonical. Every
        name and each of its `-`-separated prefixes ("a-b-c", "a-b", "a") becomes an
        alias; the first line defining an alias wins.

        Raises:
            ValueError: if a line does not have 3 or 4 columns.
        """
        aliases = {}
        with open(path, "r", encoding="utf-8") as aliases_file:
            for line in aliases_file:
                line = line.rstrip("\r\n")
                if not line.strip() or line.startswith("#"):
                    continue
                columns = line.split()
                if len(columns) not in [3, 4]:
                    raise ValueError("Expected 3-4 columns in the aliases file: line '{}'".format(line))
                names = columns[0].split(":")
                for name in names:
                    parts = name.split("-")
                    for dropped in range(len(parts)):
                        aliases.setdefault("-".join(parts[:len(parts) - dropped]), names[0])
        return aliases

    def server_bind(self):
        """Enable address (and, where supported, port) reuse before binding."""
        import socket
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if hasattr(socket, "SO_REUSEPORT"):
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        super().server_bind()

    def server_activate(self):
        """Start listening with a backlog of 256 pending connections."""
        self.socket.listen(256)

    def process_request_thread(self, request, client_address):
        """Thread body: handle one connection, reporting errors, always closing it."""
        try:
            self.finish_request(request, client_address)
        except Exception:
            self.handle_error(request, client_address)
        finally:
            self.shutdown_request(request)

    def process_request(self, request, client_address):
        """Handle the connection in a new non-daemon thread (joined in server_close)."""
        thread = threading.Thread(target=self.process_request_thread, args=(request, client_address), daemon=False)
        self._threads.append(thread)
        thread.start()

    def service_actions(self):
        """Called from the serve_forever loop; blocks accepting while at max_concurrency."""
        if len(self._threads) >= self._args.max_concurrency:
            self._threads = [thread for thread in self._threads if thread.is_alive()]
            while len(self._threads) >= self._args.max_concurrency:
                time.sleep(0.1)
                self._threads = [thread for thread in self._threads if thread.is_alive()]

    def server_close(self):
        """Close the listening socket and wait for all request threads to finish."""
        super().server_close()
        for thread in self._threads:
            thread.join()


if __name__ == "__main__":
    import argparse
    import signal

    # Parse server arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("port", type=int, help="Port to use")
    parser.add_argument("backends", type=str, nargs="+", help="Backends to use")
    parser.add_argument("--aliases", default=None, type=str, help="Path to model aliases")
    parser.add_argument("--logfile", default=None, type=str, help="Log path")
    parser.add_argument("--log_data", default=None, type=int, help="Log that much bytes of every request data")
    parser.add_argument("--max_concurrency", default=256, type=int, help="Maximum concurrency")
    parser.add_argument("--max_request_size", default=4096*1024, type=int, help="Maximum request size")
    args = parser.parse_args()

    # Log stderr to logfile if given; line-buffered, because the per-request lines
    # written by BaseHTTPRequestHandler.log_message are not flushed explicitly.
    if args.logfile is not None:
        sys.stderr = open(args.logfile, "a", encoding="utf-8", buffering=1)

    # Create the server
    server = FrontendRESTServer(args)

    # Block the shutdown signals BEFORE starting any thread: new threads inherit
    # the signal mask, so the signals stay pending for sigwait below instead of
    # being delivered to a server thread (where SIGUSR1 would kill the process).
    signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGINT, signal.SIGUSR1])

    server_thread = threading.Thread(target=server.serve_forever, daemon=True)
    server_thread.start()

    print("Started Frontend REST server on port {}.".format(args.port), file=sys.stderr)
    print("To stop it gracefully, either send SIGINT (Ctrl+C) or SIGUSR1.", file=sys.stderr, flush=True)

    # Wait until the server should be closed
    signal.sigwait([signal.SIGINT, signal.SIGUSR1])
    print("Initiating shutdown of the Frontend REST server.", file=sys.stderr, flush=True)
    server.shutdown()
    print("Stopped handling new requests, processing all current ones.", file=sys.stderr, flush=True)
    server.server_close()
    print("Finished shutdown of the Frontend REST server.", file=sys.stderr, flush=True)