File size: 9,005 Bytes
726ec90
3d7f69e
97e7d36
 
726ec90
 
c27e5a4
3d7f69e
2ec9baa
 
c27e5a4
 
726ec90
8d626cf
 
726ec90
6831f1f
3d7f69e
c27e5a4
97e7d36
2ec9baa
 
 
 
 
 
 
 
 
 
97e7d36
9254534
c158679
 
 
9254534
97e7d36
c27e5a4
 
 
 
 
 
 
 
 
 
97e7d36
 
 
 
 
 
 
 
 
 
 
 
3d7f69e
97e7d36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a00c4d5
97e7d36
 
ec814ef
 
 
 
 
 
 
 
 
 
 
01423c9
 
 
c27e5a4
ec814ef
 
c27e5a4
 
 
ec814ef
 
 
 
 
 
 
 
 
 
 
 
97e7d36
 
726ec90
6831f1f
01423c9
 
 
6831f1f
 
 
 
 
 
 
 
 
 
726ec90
 
6831f1f
726ec90
 
 
 
 
 
 
6831f1f
726ec90
 
 
 
 
 
 
 
 
 
 
 
c27e5a4
726ec90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d7f69e
c27e5a4
01423c9
 
8419299
726ec90
c27e5a4
 
 
 
726ec90
3d7f69e
97e7d36
 
 
 
 
 
 
 
 
 
c27e5a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6ff669
c27e5a4
 
 
 
 
 
 
 
97e7d36
 
2ec9baa
97e7d36
 
 
8d626cf
 
d7bd042
 
 
 
 
 
 
 
 
726ec90
c27e5a4
8d626cf
d7bd042
 
 
c27e5a4
 
 
 
 
 
 
726ec90
6831f1f
3d7f69e
 
 
 
c27e5a4
 
 
 
 
 
 
 
 
6831f1f
726ec90
97e7d36
 
 
 
 
457d4b2
 
2ec9baa
 
97e7d36
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
import io
from flask import Flask, Response, send_from_directory, jsonify, request, abort
import os
from flask_cors import CORS
from multiprocessing import Queue
import base64
from typing import Any, Dict, Tuple
from multiprocessing import Queue
import logging
import sys
from threading import Lock
from multiprocessing import Manager

import torch

from server.AudioTranscriber import AudioTranscriber
from server.ActionProcessor import ActionProcessor
from server.StandaloneApplication import StandaloneApplication
from server.TextFilterer import TextFilterer

# Configure logging: INFO and above, timestamped records, written to stdout.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Directory the static front-end is served from: the fixed path
# "/app/server/static" unless the DEBUG environment variable is exactly
# "True", in which case an "html" folder under the current working
# directory is used instead.
STATIC_DIR = (
    "/app/server/static"
    if os.getenv("DEBUG") != "True"
    else os.path.join(os.getcwd(), "html")
)

# Inter-worker queues. Each packet is a tuple of (data, token).
#   audio_queue         - filled by POST /api/process with decoded audio bytes;
#                         first argument given to each AudioTranscriber.
#   text_queue          - filled by POST /api/order; also the second argument
#                         given to each AudioTranscriber and the first given
#                         to TextFilterer.
#   filtered_text_queue - second argument given to TextFilterer and first
#                         given to each ActionProcessor.
#   action_queue        - second argument given to each ActionProcessor;
#                         drained by ActionConsumer.
audio_queue: "Queue[Tuple[io.BytesIO, str]]" = Queue()
text_queue: "Queue[Tuple[str, str]]" = Queue()
filtered_text_queue: "Queue[Tuple[str, str]]" = Queue()
action_queue: "Queue[Tuple[Dict[str, Any], str]]" = Queue()

# Storage for pending actions, keyed by session token.
# NOTE(review): action_storage_lock is a threading.Lock, so it only
# serialises threads inside a single process, while action_storage is a
# Manager-backed dict meant to be shared across processes - confirm whether
# a cross-process lock is needed.
action_storage_lock = Lock()
manager = Manager()
action_storage = manager.dict()  # Shared dictionary across processes

# The Flask application; its static folder is STATIC_DIR.
app = Flask(__name__, static_folder=STATIC_DIR)

# Enable CORS for every origin on GET/POST/OPTIONS. The return value is
# not needed, hence the throwaway name.
_ = CORS(
    app,
    origins=["*"],
    methods=["GET", "POST", "OPTIONS"],
    allow_headers=["Content-Type", "Authorization"],
)


@app.after_request
def add_header(response: Response):
    """Stamp permissive CORS and cross-origin isolation headers on a response.

    Registered as an after-request hook on ``app``. Every header listed
    below is overwritten unconditionally, then the same response object is
    returned.
    """
    forced_headers = (
        # Permissive CORS
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"),
        ("Access-Control-Allow-Headers", "*"),  # allow all headers
        # Cross-origin isolation
        ("Cross-Origin-Embedder-Policy", "require-corp"),
        ("Cross-Origin-Opener-Policy", "same-origin"),
        ("Cross-Origin-Resource-Policy", "cross-origin"),
    )
    for header_name, header_value in forced_headers:
        response.headers[header_name] = header_value
    return response


@app.route("/")
def serve_index():
    """Serve ``index.html`` from the static folder with isolation headers set.

    Aborts with a 404 whose description names the static folder when a
    ``FileNotFoundError`` is raised while sending the file.
    """
    try:
        page = send_from_directory(app.static_folder, "index.html")
    except FileNotFoundError:
        abort(
            404,
            description=f"Static folder or index.html not found. Static folder: {app.static_folder}",
        )
    else:
        isolation_headers = (
            ("Cross-Origin-Opener-Policy", "same-origin"),
            ("Cross-Origin-Embedder-Policy", "require-corp"),
        )
        for header_name, header_value in isolation_headers:
            page.headers[header_name] = header_value
        return page


@app.route("/api/data", methods=["GET"])
def get_data():
    """Return a fixed JSON body reporting success."""
    return jsonify(status="success")


@app.route("/api/order", methods=["POST"])
def post_order() -> Tuple[Response, int]:
    """Queue a textual order for the session identified by ``token``.

    Expects a JSON object body with a string ``"action"`` field and a
    ``token`` query parameter. The text is pushed onto ``text_queue`` three
    times: its first half, the full text, then its second half.
    NOTE(review): presumably this mimics partial transcriptions arriving
    around the full one - confirm the intent with the queue's consumer.

    Returns:
        ``(response, status)``: 200 on success, 400 when the body is not a
        JSON object, ``action`` is missing or not a string, or the token is
        missing, and 500 on any unexpected failure.
    """
    try:
        # silent=True makes an absent/malformed JSON body come back as None
        # instead of raising an HTTP error that the blanket handler below
        # would misreport as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or "action" not in data:
            return (
                jsonify({"error": "Missing 'action' in request", "status": "error"}),
                400,
            )

        action_text = data["action"]
        if not isinstance(action_text, str):
            # Anything else would break the slicing below or the consumers
            # of text_queue, which is typed as carrying strings.
            return (
                jsonify({"error": "'action' must be a string", "status": "error"}),
                400,
            )

        token = request.args.get("token")
        if not token:
            return jsonify({"error": "Missing token parameter", "status": "error"}), 400

        mid_split = len(action_text) // 2
        # Add the text to the queue: first half, whole text, second half.
        text_queue.put((action_text[:mid_split], token))
        text_queue.put((action_text, token))
        text_queue.put((action_text[mid_split:], token))

        return jsonify({"status": "success"}), 200

    except Exception as e:
        # Last-resort boundary so the client always receives a JSON error.
        return (
            jsonify(
                {"error": f"Failed to process request: {str(e)}", "status": "error"}
            ),
            500,
        )


@app.route("/api/process", methods=["POST"])
def process_data():
    """Accept a base64-encoded audio chunk and queue it for processing.

    The chunk is read from the ``audio_chunk`` field of a JSON object or of
    a multipart form, or, for any other content type, from the raw request
    body. A ``token`` query parameter is required. The decoded bytes are
    wrapped in ``io.BytesIO`` and put on ``audio_queue`` with the token.

    Returns:
        A JSON success response, or ``(response, status)`` with 400 for a
        missing token, missing chunk or undecodable chunk, and 500 for any
        unexpected failure.
    """
    try:
        content_type = request.headers.get("Content-Type", "")
        token = request.args.get("token")
        if not token:
            return jsonify({"error": "Missing token parameter", "status": "error"}), 400

        # Handle different content types
        if "application/json" in content_type:
            # silent=True: a malformed body yields None rather than an HTTP
            # error that the blanket handler below would report as a 500.
            data = request.get_json(silent=True)
            audio_base64 = data.get("audio_chunk") if isinstance(data, dict) else None
        elif "multipart/form-data" in content_type:
            audio_base64 = request.form.get("audio_chunk")
        else:
            # Try to get raw data; a body that is not valid UTF-8 is a
            # client error, not a server failure.
            try:
                audio_base64 = request.get_data().decode("utf-8")
            except UnicodeDecodeError as e:
                return (
                    jsonify(
                        {
                            "error": f"Failed to decode audio chunk: {str(e)}",
                            "status": "error",
                        }
                    ),
                    400,
                )

        # Validate the incoming data
        if not audio_base64:
            return (
                jsonify({"error": "Missing audio_chunk in request", "status": "error"}),
                400,
            )

        # Decode the base64 audio chunk. binascii.Error is a ValueError
        # subclass; TypeError covers a non-string JSON value.
        try:
            audio_chunk = base64.b64decode(audio_base64)
        except (ValueError, TypeError) as e:
            return (
                jsonify(
                    {
                        "error": f"Failed to decode audio chunk: {str(e)}",
                        "status": "error",
                    }
                ),
                400,
            )

        # Put the audio chunk in the queue for processing
        audio_queue.put((io.BytesIO(audio_chunk), token))

        return jsonify(
            {
                "status": "success",
            }
        )
    except Exception as e:
        # Last-resort boundary so the client always receives a JSON error.
        return (
            jsonify(
                {"error": f"Failed to process request: {str(e)}", "status": "error"}
            ),
            500,
        )


@app.route("/api/actions", methods=["GET"])
def get_actions() -> Tuple[Response, int]:
    """Retrieve and clear all pending actions for the current session.

    The session is identified by the ``token`` query parameter.

    Returns:
        ``(response, status)``: 200 with the pending actions (possibly an
        empty list), or 400 with an empty list when the token is missing.
    """
    token = request.args.get("token")
    if not token:
        return jsonify({"actions": [], "status": "error"}), 400

    with action_storage_lock:
        # Fetch and clear in a single dict operation instead of a separate
        # get followed by an assignment: nothing can be stored between the
        # two steps and then overwritten, and no empty list is left behind
        # for every token that has ever polled.
        actions = action_storage.pop(token, [])

    return jsonify({"actions": actions, "status": "success"}), 200


@app.route("/<path:path>")
def serve_static(path: str):
    """Serve the file at ``path`` from the static folder.

    Aborts with a 404 naming the path when a ``FileNotFoundError`` is
    raised while sending the file.
    """
    try:
        served_file = send_from_directory(app.static_folder, path)
    except FileNotFoundError:
        abort(404, description=f"File {path} not found in static folder")
    else:
        return served_file


class ActionConsumer:
    """Drain ``(action, token)`` packets from a queue into ``action_storage``.

    The draining loop runs on a daemon thread created by :meth:`start` and
    continues for as long as ``self.running`` is true.
    """

    def __init__(self, action_queue: Queue):
        self.running = True
        self.action_queue = action_queue

    def start(self):
        """Create the daemon thread that executes :meth:`run` and start it."""
        from threading import Thread

        worker = Thread(target=self.run, daemon=True)
        self.thread = worker
        worker.start()

    def run(self):
        """Block on the queue and store every received action under its token.

        Any exception raised while fetching or storing is logged and the
        loop carries on.
        """
        while self.running:
            try:
                action, token = self.action_queue.get()
                self._remember(action, token)
            except Exception as e:
                logger.error(f"Error in ActionConsumer: {e}")

    def _remember(self, action, token):
        """Append ``action`` to the list stored for ``token`` while holding the lock."""
        with action_storage_lock:
            if token not in action_storage:
                logger.info(f"Creating new action storage for token: {token}")
                action_storage[token] = []
            # Read, extend, then assign back: the list is written to the
            # storage again explicitly rather than relying on in-place
            # mutation being visible through it.
            pending = action_storage[token]
            pending.append(action)
            action_storage[token] = pending


if __name__ == "__main__":
    # Validate configuration first, so a missing key aborts start-up before
    # any worker has been launched.
    MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
    if not MISTRAL_API_KEY:
        raise ValueError("MISTRAL_API_KEY is not set")

    debug_mode = os.getenv("DEBUG") == "True"

    if os.path.exists(app.static_folder):
        logger.info(f"Static folder contents: {os.listdir(app.static_folder)}")

    os.makedirs(app.static_folder, exist_ok=True)

    num_devices = torch.cuda.device_count()
    if num_devices == 0:
        # Without this check the device query below fails with a less
        # helpful error, and `i % num_devices` would divide by zero.
        raise RuntimeError("No CUDA device available; at least one is required")

    # Size the transcriber pool from device 0's memory: one worker per
    # whole 3GB of VRAM, per device.
    device_vram_gb: float = float(
        torch.cuda.get_device_properties(0).total_memory / (1024**3)
    )
    num_3gb_units = int(device_vram_gb) // 3

    logger.info(
        f"Device 0 has {device_vram_gb:.1f}GB VRAM, equivalent to {num_3gb_units} units of Whisper"
    )

    num_transcribers = 4 if debug_mode else num_3gb_units * num_devices
    if num_transcribers == 0:
        logger.warning(
            "Less than 3GB of VRAM on device 0: no audio transcriber will be started"
        )

    # Start the audio transcribers, spread round-robin over the devices.
    transcribers = [
        AudioTranscriber(audio_queue, text_queue, device_index=i % num_devices)
        for i in range(num_transcribers)
    ]
    for transcriber in transcribers:
        transcriber.start()

    # Start the action consumer thread
    action_consumer = ActionConsumer(action_queue)
    action_consumer.start()

    # Start the text filterer between text_queue and filtered_text_queue.
    filterer = TextFilterer(text_queue, filtered_text_queue)
    filterer.start()

    # Start the action processors (4 in debug mode, 16 otherwise).
    actions_processors = [
        ActionProcessor(filtered_text_queue, action_queue, MISTRAL_API_KEY)
        for _ in range(4 if debug_mode else 16)
    ]
    for actions_processor in actions_processors:
        actions_processor.start()

    # Server options handed to StandaloneApplication.
    # NOTE(review): the keys look like gunicorn settings - confirm against
    # server.StandaloneApplication.
    options: Any = {
        "bind": "0.0.0.0:7860",
        "workers": 3,
        "worker_class": "sync",
        "timeout": 120,
        "forwarded_allow_ips": "*",
        "accesslog": None,  # Disable access logging
        "errorlog": "-",  # Keep error logging to stderr
        "capture_output": True,
        "enable_stdio_inheritance": True,
    }

    StandaloneApplication(app, options).run()