File size: 14,385 Bytes
0e5da39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
#!/usr/bin/env python3
# coding=utf-8
#
# Copyright 2020 Institute of Formal and Applied Linguistics, Faculty of
# Mathematics and Physics, Charles University, Czech Republic.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

"""Word embeddings server class."""

import email.parser
import http.server
import json
import random
import socketserver
import sys
import threading
import time
import urllib.error
import urllib.parse
import urllib.request


class FrontendRESTServer(socketserver.TCPServer):
    """TCP server forwarding REST requests to a pool of backend servers.

    Every accepted connection is handled in its own non-daemon thread, and
    at most ``args.max_concurrency`` handler threads are kept alive at once
    (see `service_actions`). The ``/models`` endpoint is answered locally by
    merging the model listings of all backends. Any other request is
    forwarded to a backend serving the requested model. If a backend fails
    before anything was sent to the client, another backend is tried.
    """

    class Backend():
        """A single backend server together with the models it reports."""

        def __init__(self, server):
            """Contact the backend and fetch its ``/models`` listing.

            Args:
              server: address inserted between ``http://`` and the request
                path (presumably ``host:port`` -- confirm with deployment).

            Raises:
              ValueError: if the ``/models`` response lacks a ``models``
                dictionary or a ``default_model`` string.
            """
            self._server = server

            with self.request("/models") as response:
                data = json.loads(response.read())

            # Explicit checks instead of `assert`, which `python -O` strips.
            if not ("models" in data and isinstance(data["models"], dict)):
                raise ValueError("Backend '{}' returned no 'models' dictionary.".format(server))
            self.models = data["models"]

            if not ("default_model" in data and isinstance(data["default_model"], str)):
                raise ValueError("Backend '{}' returned no 'default_model' string.".format(server))
            self.default_model = data["default_model"]

        def request(self, url, data=None, data_content_type=None):
            """Send a request to this backend and return the open response.

            Args:
              url: request path, including any query string.
              data: optional request body; when given, urllib issues a POST.
              data_content_type: Content-Type header sent along with `data`.

            Returns:
              The response returned by ``urllib.request.urlopen``. The caller
              must close it, e.g. by using it as a context manager.

            Raises:
              urllib.error.HTTPError: for error status codes of the backend.
              urllib.error.URLError: when the backend cannot be reached.
            """
            return urllib.request.urlopen(urllib.request.Request(
                url="http://{}{}".format(self._server, url),
                data=data,
                headers={} if data is None else {"Content-Type": data_content_type},
            ))


    class FrontendRESTServer(http.server.BaseHTTPRequestHandler):
        """Per-connection request handler (methods name `self` `request`)."""
        protocol_version = "HTTP/1.1"

        # Keeps every log entry on one line: newlines become carriage returns
        # and pre-existing carriage returns are removed.
        format_for_log_table = str.maketrans("\n", "\r", "\r")

        def format_for_log(request, data, limit=None):
            """Return `data` prepared for the single-line, tab-separated log.

            Args:
              data: the string to log.
              limit: None logs `data` fully. A non-positive value logs only
                its length in characters as ``[<len>B]``. A positive value
                keeps about `limit` characters, half from each end, joined
                by ``" ... "``.
            """
            if limit is not None:
                if limit <= 0:
                    data = "[{}B]".format(len(data))
                elif len(data) > limit:
                    data = data[:limit // 2] + " ... " + data[min(-1, -limit // 2):]
            return data.translate(request.format_for_log_table)

        def respond(request, content_type, code=200, additional_headers=None):
            """Send status line and headers of a response closing the connection.

            Args:
              content_type: value of the Content-Type header.
              code: HTTP status code.
              additional_headers: optional mapping of extra header names to
                values (None means no extra headers).
            """
            request.close_connection = True
            request.send_response(code)
            request.send_header("Connection", "close")
            request.send_header("Content-Type", content_type)
            request.send_header("Access-Control-Allow-Origin", "*")
            for key, value in (additional_headers or {}).items():
                request.send_header(key, value)
            request.end_headers()

        def respond_error(request, message, code=400):
            """Send a complete plain-text error response with the given code."""
            request.respond("text/plain", code)
            request.wfile.write(message.encode("utf-8"))

        def handle_expect_100(request):
            """Answer ``Expect: 100-continue``, refusing oversized payloads.

            Returns False (so no do_* method runs) after sending an error
            when the declared Content-Length exceeds ``max_request_size``.
            """
            try:
                request_too_long = int(request.headers["Content-Length"]) > request.server._args.max_request_size
            except (TypeError, ValueError):
                # Missing or malformed Content-Length; left to the request handler.
                request_too_long = False

            if request_too_long:
                request.respond_error("The payload size is too large.")
                return False
            return super().handle_expect_100()

        def do_GET(request):
            """Handle a GET request (and POST requests, via do_POST).

            Parameters are collected from the URL query string and, for POST,
            from a multipart/form-data or application/x-www-form-urlencoded
            body. ``/models`` is answered locally. Any other request is
            forwarded (original path, body and Content-Type) to a backend,
            and the backend response is streamed back to the client.
            """
            # Parse the parameters from the URL.
            params, body, body_content_type = {}, None, None
            try:
                # http.server decodes the request line as ISO-8859-1; recover
                # the raw bytes and reinterpret them as UTF-8.
                encoded_path = request.path.encode("iso-8859-1").decode("utf-8")
                url = urllib.parse.urlparse(encoded_path)
                for name, value in urllib.parse.parse_qsl(url.query, encoding="utf-8", keep_blank_values=True, errors="strict"):
                    params[name] = value
            except Exception:
                return request.respond_error("Cannot parse request URL.")

            # Parse the body of a POST request
            if request.command == "POST":
                if request.headers.get("Transfer-Encoding", "identity").lower() != "identity":
                    return request.respond_error("Only 'identity' Transfer-Encoding of payload is supported for now.")

                try:
                    content_length = int(request.headers["Content-Length"])
                except (TypeError, ValueError):
                    content_length = -1
                # A negative length is invalid; moreover rfile.read() with a
                # negative size would block until the client closes the socket.
                if content_length < 0:
                    return request.respond_error("The Content-Length of payload is required.")

                if content_length > request.server._args.max_request_size:
                    # Consume the whole payload before replying (presumably so
                    # the client is not cut off mid-upload).
                    while content_length:
                        read = request.rfile.read(min(content_length, 65536))
                        content_length -= len(read) if read else content_length
                    return request.respond_error("The payload size is too large.")

                body = request.rfile.read(content_length)
                body_content_type = request.headers.get("Content-Type", "")

                # multipart/form-data
                if body_content_type.startswith("multipart/form-data"):
                    try:
                        # Prepend the Content-Type header so the email parser
                        # knows the multipart boundary.
                        parser = email.parser.BytesFeedParser()
                        parser.feed(b"Content-Type: " + body_content_type.encode("ascii") + b"\r\n\r\n")
                        parser.feed(body)
                        for part in parser.close().get_payload():
                            name = part.get_param("name", header="Content-Disposition")
                            if name:
                                params[name] = part.get_payload(decode=True).decode("utf-8")
                    except Exception:
                        return request.respond_error("Cannot parse the multipart/form-data payload.")
                # x-www-form-urlencoded
                elif body_content_type.startswith("application/x-www-form-urlencoded"):
                    try:
                        for name, value in urllib.parse.parse_qsl(
                                body.decode("utf-8"), encoding="utf-8", keep_blank_values=True, errors="strict"):
                            params[name] = value
                    except Exception:
                        return request.respond_error("Cannot parse the application/x-www-form-urlencoded payload.")

            # Log if required; only the `data` parameter is truncated.
            if request.server._args.log_data:
                print(url.path, " ".join(request.headers.get_all("X-Forwarded-For", [])),
                      *["{}:{}".format(key, request.format_for_log(value)) for key, value in params.items() if key != "data"],
                      "data:" + request.format_for_log(params.get("data", ""), request.server._args.log_data),
                      sep="\t", file=sys.stderr, flush=True)

            # Handle /models
            if url.path == "/models":
                response = {
                    "models": {name: value for backend in request.server.backends for name, value in backend.models.items()},
                    "default_model": request.server.backends[0].default_model,
                }
                request.respond("application/json")
                request.wfile.write(json.dumps(response, indent=1).encode("utf-8"))
            # Handle everything else
            else:
                # Restrict to backends serving the (alias-resolved) model; fall
                # back to all backends when none of them serves it.
                backends = request.server.backends.copy()
                model = params.get("model", request.server.backends[0].default_model)
                if model in request.server.aliases:
                    resolved_model = request.server.aliases[model]
                    backends = [backend for backend in request.server.backends if resolved_model in backend.models] or backends

                # Forward the request to a randomly chosen backend; retry on
                # another one only while nothing was sent to the client yet.
                started_responding = False
                try:
                    assert backends, "No backends found!"
                    while backends:
                        backend = random.choice(backends) if len(backends) > 1 else backends[0]
                        backends.remove(backend)
                        try:
                            with backend.request(request.path, body, body_content_type) as response:
                                while True:
                                    data = response.read(32768)
                                    if not started_responding:
                                        started_responding = True
                                        billing_infclen = response.getheader("X-Billing-Input-NFC-Len", None)
                                        headers = {"X-Billing-Input-NFC-Len": billing_infclen} if billing_infclen is not None else {}
                                        request.respond(response.getheader("Content-Type", "application/json"), code=response.code,
                                                        additional_headers=headers)
                                    if len(data) == 0: break
                                    request.wfile.write(data)
                        except urllib.error.HTTPError as error:
                            # Relay the backend's own error response unchanged.
                            if not started_responding:
                                started_responding = True
                                request.respond(error.headers.get("Content-Type", "text/plain"), code=error.code)
                                request.wfile.write(error.file.read())
                                break
                            raise
                        except Exception:
                            if backends and not started_responding:
                                import traceback
                                traceback.print_exc(file=sys.stderr)
                                print("The above error occurred during request processing on '{}',".format(backend._server),
                                      "but more backends are available, retrying.", file=sys.stderr, flush=True)
                                continue
                            raise
                        break
                except Exception:
                    import traceback
                    traceback.print_exc(file=sys.stderr)
                    sys.stderr.flush()

                    if not started_responding:
                        request.respond_error("An internal error occurred during processing.")
                    else:
                        # Headers are already sent; append a marker to the
                        # partially streamed body instead.
                        request.wfile.write(b'",\n"An internal error occurred during processing, producing incorrect JSON!"')

        def do_POST(request):
            """Handle a POST request; all logic lives in do_GET."""
            return request.do_GET()

    def __init__(self, args):
        """Initialize the backends, the model aliases and the listening socket.

        Args:
          args: parsed command-line arguments; reads `backends`, `aliases`
            and `port` here, and `max_concurrency`, `max_request_size` and
            `log_data` later.

        Raises:
          ValueError: when a backend reports an invalid ``/models`` listing
            or a non-comment line of the aliases file has not 3-4 columns.
        """
        self._args = args

        # Initialize all backends
        self.backends = [self.Backend(backend) for backend in args.backends]

        # Initialize the aliases: each name in the first column (split on
        # ':') and every '-'-separated prefix of it maps to the line's first
        # name; earlier definitions take precedence.
        self.aliases = {}
        if args.aliases is not None:
            with open(args.aliases, "r", encoding="utf-8") as aliases_file:
                for line in aliases_file:
                    line = line.rstrip("\r\n")
                    if not line or line.startswith("#"):
                        continue
                    columns = line.split()
                    if len(columns) not in (3, 4):
                        raise ValueError("Expected 3-4 columns in the aliases file: line '{}'".format(line))
                    names = columns[0].split(":")
                    for name in names:
                        components = name.split("-")
                        # Longest prefix first, e.g. "a-b-c", "a-b", "a".
                        for length in range(len(components), 0, -1):
                            self.aliases.setdefault("-".join(components[:length]), names[0])

        # Initialize the server
        self._threads = []
        super().__init__(("", self._args.port), self.FrontendRESTServer)

    def server_bind(self):
        """Bind the socket with SO_REUSEADDR and SO_REUSEPORT enabled."""
        import socket
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        super().server_bind()

    def server_activate(self):
        """Start listening with a backlog of 256 pending connections."""
        self.socket.listen(256)

    def process_request_thread(self, request, client_address):
        """Handle one connection; errors are reported and the socket closed."""
        try:
            self.finish_request(request, client_address)
        except Exception:
            self.handle_error(request, client_address)
        finally:
            self.shutdown_request(request)

    def process_request(self, request, client_address):
        """Handle the connection in a new non-daemon thread, tracked for joining."""
        thread = threading.Thread(target=self.process_request_thread, args=(request, client_address), daemon=False)
        self._threads.append(thread)
        thread.start()

    def service_actions(self):
        """Pause the serve_forever loop while `max_concurrency` threads are alive.

        Finished threads are pruned; when the limit is still reached, the
        check is repeated every 0.1 s, so no new connection is accepted.
        """
        if len(self._threads) >= self._args.max_concurrency:
            self._threads = [thread for thread in self._threads if thread.is_alive()]

        while len(self._threads) >= self._args.max_concurrency:
            time.sleep(0.1)
            self._threads = [thread for thread in self._threads if thread.is_alive()]

    def server_close(self):
        """Close the listening socket, then wait for all handler threads."""
        super().server_close()
        for thread in self._threads:
            thread.join()


if __name__ == "__main__":
    import argparse
    import signal

    # Parse server arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("port", type=int, help="Port to use")
    parser.add_argument("backends", type=str, nargs="+", help="Backends to use")
    parser.add_argument("--aliases", default=None, type=str, help="Path to model aliases")
    parser.add_argument("--logfile", default=None, type=str, help="Log path")
    parser.add_argument("--log_data", default=None, type=int, help="Log that much bytes of every request data")
    parser.add_argument("--max_concurrency", default=256, type=int, help="Maximum concurrency")
    parser.add_argument("--max_request_size", default=4096*1024, type=int, help="Maximum request size")
    args = parser.parse_args()

    # Log stderr to logfile if given
    if args.logfile is not None:
        sys.stderr = open(args.logfile, "a", encoding="utf-8")

    # Create the server; this queries /models of every backend.
    server = FrontendRESTServer(args)

    # Block SIGINT and SIGUSR1 before starting any thread. New threads inherit
    # the signal mask of the thread creating them, so the server thread and
    # all handler threads it spawns keep these signals blocked, and they stay
    # pending for the sigwait() below. Otherwise a signal could be delivered to
    # a worker thread; for SIGUSR1 the default action terminates the process.
    signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGINT, signal.SIGUSR1])

    server_thread = threading.Thread(target=server.serve_forever, daemon=True)
    server_thread.start()

    print("Started Frontend REST server on port {}.".format(args.port), file=sys.stderr)
    print("To stop it gracefully, either send SIGINT (Ctrl+C) or SIGUSR1.", file=sys.stderr, flush=True)

    # Wait until the server should be closed
    signal.sigwait([signal.SIGINT, signal.SIGUSR1])
    print("Initiating shutdown of the Frontend REST server.", file=sys.stderr, flush=True)
    server.shutdown()
    print("Stopped handling new requests, processing all current ones.", file=sys.stderr, flush=True)
    server.server_close()
    print("Finished shutdown of the Frontend REST server.", file=sys.stderr, flush=True)