chore: fix ruff errors
Fedir Zadniprovskyi committed · Commit dc4f25f · Parent(s): 8ad4ca5
Changed files:
- faster_whisper_server/asr.py (+2 -2)
- faster_whisper_server/audio.py (+8 -6)
- faster_whisper_server/config.py (+4 -4)
- faster_whisper_server/core.py (+15 -35)
- faster_whisper_server/gradio_app.py (+11 -26)
- faster_whisper_server/logger.py (+1 -3)
- faster_whisper_server/main.py (+38 -79)
- faster_whisper_server/server_models.py (+11 -21)
- faster_whisper_server/transcriber.py (+6 -2)
- pyproject.toml (+24 -7)
- tests/api_model_test.py (+4 -6)
- tests/app_test.py (+9 -13)
- tests/conftest.py (+4 -7)
- tests/sse_test.py (+9 -19)
faster_whisper_server/asr.py CHANGED
@@ -1,6 +1,6 @@
 import asyncio
+from collections.abc import Iterable
 import time
-from typing import Iterable
 
 from faster_whisper import transcribe
 
@@ -45,7 +45,7 @@ class FasterWhisperASR:
         audio: Audio,
         prompt: str | None = None,
     ) -> tuple[Transcription, transcribe.TranscriptionInfo]:
-        """Wrapper around _transcribe so it can be used in async context"""
+        """Wrapper around _transcribe so it can be used in async context."""
         # is this the optimal way to execute a blocking call in an async context?
         # TODO: verify performance when running inference on a CPU
         return await asyncio.get_running_loop().run_in_executor(
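Review note: the `Iterable` import moves from `typing` to `collections.abc`, which is what ruff's UP035 (deprecated-import) asks for on Python 3.9+; the new position of the import also matches isort's sort-within-sections ordering. A minimal sketch of the preferred style (the helper function is hypothetical, not from this repo):

from collections.abc import Iterable  # preferred over typing.Iterable (ruff UP035)

def total_duration(durations: Iterable[float]) -> float:
    # Iterable works as an annotation exactly as typing.Iterable did.
    return sum(durations)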
faster_whisper_server/audio.py CHANGED
@@ -1,15 +1,19 @@
 from __future__ import annotations
 
 import asyncio
-from typing import AsyncGenerator, BinaryIO
+from typing import TYPE_CHECKING, BinaryIO
 
 import numpy as np
 import soundfile as sf
-from numpy.typing import NDArray
 
 from faster_whisper_server.config import SAMPLES_PER_SECOND
 from faster_whisper_server.logger import logger
 
+if TYPE_CHECKING:
+    from collections.abc import AsyncGenerator
+
+    from numpy.typing import NDArray
+
 
 def audio_samples_from_file(file: BinaryIO) -> NDArray[np.float32]:
     audio_and_sample_rate = sf.read(
@@ -22,7 +26,7 @@ def audio_samples_from_file(file: BinaryIO) -> NDArray[np.float32]:
         endian="LITTLE",
     )
     audio = audio_and_sample_rate[0]
-    return audio  #
+    return audio  # pyright: ignore[reportReturnType]
 
 
 class Audio:
@@ -78,9 +82,7 @@ class AudioStream(Audio):
         self.modify_event.set()
         logger.info("AudioStream closed")
 
-    async def chunks(
-        self, min_duration: float
-    ) -> AsyncGenerator[NDArray[np.float32], None]:
+    async def chunks(self, min_duration: float) -> AsyncGenerator[NDArray[np.float32], None]:
         i = 0.0  # end time of last chunk
         while True:
             await self.modify_event.wait()
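Review note: `AsyncGenerator` and `NDArray` are only used in annotations, so they move under an `if TYPE_CHECKING:` block; combined with `from __future__ import annotations`, the names never have to exist at runtime, so the imports cost nothing when the module loads. A standalone sketch of the pattern (module and function are hypothetical):

from __future__ import annotations  # annotations become strings, evaluated lazily

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only while type checking; skipped entirely at runtime.
    from collections.abc import AsyncGenerator

async def countdown(n: int) -> AsyncGenerator[int, None]:
    while n > 0:
        yield n
        n -= 1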
faster_whisper_server/config.py CHANGED
@@ -15,7 +15,7 @@ class ResponseFormat(enum.StrEnum):
     TEXT = "text"
     JSON = "json"
     VERBOSE_JSON = "verbose_json"
-    # NOTE: While inspecting outputs of these formats with `curl`, I noticed there's one or two "\n" inserted at the end of the response.
+    # NOTE: While inspecting outputs of these formats with `curl`, I noticed there's one or two "\n" inserted at the end of the response.  # noqa: E501
 
     # VTT = "vtt"  # TODO
     # 1
@@ -185,8 +185,8 @@ class WhisperConfig(BaseModel):
 
 
 class Config(BaseSettings):
-    """
-
+    """Configuration for the application. Values can be set via environment variables.
+
     Pydantic will automatically handle mapping uppercased environment variables to the corresponding fields.
     To populate nested, the environment should be prefixed with the nested field name and an underscore. For example,
     the environment variable `LOG_LEVEL` will be mapped to `log_level`, `WHISPER_MODEL` to `whisper.model`, etc.
@@ -208,7 +208,7 @@ class Config(BaseSettings):
     max_inactivity_seconds: float = 5.0
     """
    Max allowed audio duration without any speech being detected before transcription is finilized and connection is closed.
-    """
+    """  # noqa: E501
     inactivity_window_seconds: float = 10.0
     """
     Controls how many latest seconds of audio are being passed through VAD.
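Review note: the new `Config` docstring summary describes the env-var mapping. A hedged sketch of how that mapping typically looks with pydantic-settings (the repo's actual `model_config` and field set are not part of this diff, so the wiring below is illustrative):

import os

from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict

class WhisperConfig(BaseModel):
    model: str = "Systran/faster-whisper-tiny.en"

class Config(BaseSettings):
    # With a nested delimiter of "_", WHISPER_MODEL populates whisper.model.
    model_config = SettingsConfigDict(env_nested_delimiter="_")

    log_level: str = "info"
    whisper: WhisperConfig = WhisperConfig()

os.environ["LOG_LEVEL"] = "debug"
os.environ["WHISPER_MODEL"] = "Systran/faster-distil-whisper-large-v3"
config = Config()
print(config.log_level, config.whisper.model)  # debug Systran/faster-distil-whisper-large-v3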
faster_whisper_server/core.py CHANGED
@@ -1,8 +1,8 @@
 # TODO: rename module
 from __future__ import annotations
 
-import re
 from dataclasses import dataclass
+import re
 
 from faster_whisper_server.config import config
 
@@ -18,10 +18,7 @@ class Segment:
     def is_eos(self) -> bool:
         if self.text.endswith("..."):
             return False
-        for punctuation_symbol in ".?!":
-            if self.text.endswith(punctuation_symbol):
-                return True
-        return False
+        return any(self.text.endswith(punctuation_symbol) for punctuation_symbol in ".?!")
 
     def offset(self, seconds: float) -> None:
         self.start += seconds
@@ -36,11 +33,7 @@ class Word(Segment):
     @classmethod
     def common_prefix(cls, a: list[Word], b: list[Word]) -> list[Word]:
         i = 0
-        while (
-            i < len(a)
-            and i < len(b)
-            and canonicalize_word(a[i].text) == canonicalize_word(b[i].text)
-        ):
+        while i < len(a) and i < len(b) and canonicalize_word(a[i].text) == canonicalize_word(b[i].text):
             i += 1
         return a[:i]
 
@@ -67,9 +60,7 @@ class Transcription:
         return self.end - self.start
 
     def after(self, seconds: float) -> Transcription:
-        return Transcription(
-            words=[word for word in self.words if word.start > seconds]
-        )
+        return Transcription(words=[word for word in self.words if word.start > seconds])
 
     def extend(self, words: list[Word]) -> None:
         self._ensure_no_word_overlap(words)
@@ -77,21 +68,16 @@ class Transcription:
 
     def _ensure_no_word_overlap(self, words: list[Word]) -> None:
         if len(self.words) > 0 and len(words) > 0:
-            if (
-                words[0].start + config.word_timestamp_error_margin
-                <= self.words[-1].end
-            ):
+            if words[0].start + config.word_timestamp_error_margin <= self.words[-1].end:
                 raise ValueError(
-                    f"Words overlap: {self.words[-1]} and {words[0]}. Error margin: {config.word_timestamp_error_margin}"
+                    f"Words overlap: {self.words[-1]} and {words[0]}. Error margin: {config.word_timestamp_error_margin}"  # noqa: E501
                 )
         for i in range(1, len(words)):
             if words[i].start + config.word_timestamp_error_margin <= words[i - 1].end:
-                raise ValueError(
-                    f"Words overlap: {words[i - 1]} and {words[i]}. All words: {words}"
-                )
+                raise ValueError(f"Words overlap: {words[i - 1]} and {words[i]}. All words: {words}")
 
 
-def test_segment_is_eos():
+def test_segment_is_eos() -> None:
     assert not Segment("Hello").is_eos
     assert not Segment("Hello...").is_eos
     assert Segment("Hello.").is_eos
@@ -117,16 +103,14 @@ def to_full_sentences(words: list[Word]) -> list[Segment]:
     return sentences
 
 
-def tests_to_full_sentences():
+def tests_to_full_sentences() -> None:
     assert to_full_sentences([]) == []
     assert to_full_sentences([Word(text="Hello")]) == []
     assert to_full_sentences([Word(text="Hello..."), Word(" world")]) == []
-    assert to_full_sentences([Word(text="Hello..."), Word(" world.")]) == [
+    assert to_full_sentences([Word(text="Hello..."), Word(" world.")]) == [Segment(text="Hello... world.")]
+    assert to_full_sentences([Word(text="Hello..."), Word(" world."), Word(" How")]) == [
         Segment(text="Hello... world.")
     ]
-    assert to_full_sentences(
-        [Word(text="Hello..."), Word(" world."), Word(" How")]
-    ) == [Segment(text="Hello... world.")]
 
 
 def to_text(words: list[Word]) -> str:
@@ -144,7 +128,7 @@ def canonicalize_word(text: str) -> str:
     return text.lower().strip().strip(".,?!")
 
 
-def test_canonicalize_word():
+def test_canonicalize_word() -> None:
     assert canonicalize_word("ABC") == "abc"
     assert canonicalize_word("...ABC?") == "abc"
     assert canonicalize_word("... AbC ...") == "abc"
@@ -152,16 +136,12 @@ def test_canonicalize_word():
 
 def common_prefix(a: list[Word], b: list[Word]) -> list[Word]:
     i = 0
-    while (
-        i < len(a)
-        and i < len(b)
-        and canonicalize_word(a[i].text) == canonicalize_word(b[i].text)
-    ):
+    while i < len(a) and i < len(b) and canonicalize_word(a[i].text) == canonicalize_word(b[i].text):
         i += 1
     return a[:i]
 
 
-def test_common_prefix():
+def test_common_prefix() -> None:
     def word(text: str) -> Word:
         return Word(text=text, start=0.0, end=0.0, probability=0.0)
 
@@ -194,7 +174,7 @@ def test_common_prefix():
     assert common_prefix(a, b) == []
 
 
-def test_common_prefix_and_canonicalization():
+def test_common_prefix_and_canonicalization() -> None:
     def word(text: str) -> Word:
         return Word(text=text, start=0.0, end=0.0, probability=0.0)
 
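Review note: collapsing the `is_eos` loop into `any()` is the rewrite ruff's SIM110 suggests (use a builtin instead of a for-loop that returns True); behavior is unchanged. A quick standalone equivalence check:

def is_eos(text: str) -> bool:
    # Same logic as Segment.is_eos above, extracted for a standalone check.
    if text.endswith("..."):
        return False
    return any(text.endswith(punctuation_symbol) for punctuation_symbol in ".?!")

assert not is_eos("Hello")
assert not is_eos("Hello...")
assert is_eos("Hello.")
assert is_eos("Hello?")
assert is_eos("Hello!")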
faster_whisper_server/gradio_app.py CHANGED
@@ -1,5 +1,5 @@
+from collections.abc import Generator
 import os
-from typing import Generator
 
 import gradio as gr
 import httpx
@@ -13,26 +13,20 @@ TRANSLATION_ENDPOINT = "/v1/audio/translations"
 
 def create_gradio_demo(config: Config) -> gr.Blocks:
     host = os.getenv("UVICORN_HOST", "0.0.0.0")
-    port = os.getenv("UVICORN_PORT", 8000)
+    port = int(os.getenv("UVICORN_PORT", "8000"))
     # NOTE: worth looking into generated clients
     http_client = httpx.Client(base_url=f"http://{host}:{port}", timeout=None)
 
-    def handler(
-        file_path: str, model: str, task: Task, temperature: float, stream: bool
-    ) -> Generator[str, None, None]:
+    def handler(file_path: str, model: str, task: Task, temperature: float, stream: bool) -> Generator[str, None, None]:
         if stream:
             previous_transcription = ""
-            for transcription in transcribe_audio_streaming(
-                file_path, task, temperature, model
-            ):
+            for transcription in transcribe_audio_streaming(file_path, task, temperature, model):
                 previous_transcription += transcription
                 yield previous_transcription
         else:
             yield transcribe_audio(file_path, task, temperature, model)
 
-    def transcribe_audio(
-        file_path: str, task: Task, temperature: float, model: str
-    ) -> str:
+    def transcribe_audio(file_path: str, task: Task, temperature: float, model: str) -> str:
         if task == Task.TRANSCRIBE:
             endpoint = TRANSCRIPTION_ENDPOINT
         elif task == Task.TRANSLATE:
@@ -65,11 +59,7 @@ def create_gradio_demo(config: Config) -> gr.Blocks:
                 "stream": True,
             },
         }
-        endpoint = (
-            TRANSCRIPTION_ENDPOINT
-            if task == Task.TRANSCRIBE
-            else TRANSLATION_ENDPOINT
-        )
+        endpoint = TRANSCRIPTION_ENDPOINT if task == Task.TRANSCRIBE else TRANSLATION_ENDPOINT
         with connect_sse(http_client, "POST", endpoint, **kwargs) as event_source:
             for event in event_source.iter_sse():
                 yield event.data
@@ -79,18 +69,15 @@ def create_gradio_demo(config: Config) -> gr.Blocks:
         res_data = res.json()
         models: list[str] = [model["id"] for model in res_data]
         assert config.whisper.model in models
-        recommended_models = set(
-            model for model in models if model.startswith("Systran")
-        )
+        recommended_models = {model for model in models if model.startswith("Systran")}
         other_models = [model for model in models if model not in recommended_models]
         models = list(recommended_models) + other_models
-        model_dropdown = gr.Dropdown(
+        return gr.Dropdown(
             # no idea why it's complaining
-            choices=models,  #
+            choices=models,  # pyright: ignore[reportArgumentType]
             label="Model",
             value=config.whisper.model,
         )
-        return model_dropdown
 
     model_dropdown = gr.Dropdown(
         choices=[config.whisper.model],
@@ -102,13 +89,11 @@ def create_gradio_demo(config: Config) -> gr.Blocks:
         label="Task",
         value=Task.TRANSCRIBE,
     )
-    temperature_slider = gr.Slider(
-        minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.0
-    )
+    temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.0)
    stream_checkbox = gr.Checkbox(label="Stream", value=True)
    with gr.Interface(
        title="Whisper Playground",
-        description="""Consider supporting the project by starring the <a href="https://github.com/fedirz/faster-whisper-server">repository on GitHub</a>.""",
+        description="""Consider supporting the project by starring the <a href="https://github.com/fedirz/faster-whisper-server">repository on GitHub</a>.""",  # noqa: E501
        inputs=[
            gr.Audio(type="filepath"),
            model_dropdown,
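Review note: the `port` change is a genuine fix, not just formatting. `os.getenv("UVICORN_PORT", 8000)` returns the `int` default when the variable is unset but a `str` when it is set; ruff flags the non-string default (likely PLW1508), and the rewrite yields an `int` either way:

import os

# Before: the type depends on whether the env var is set (str from the env, int from the default).
port_before = os.getenv("UVICORN_PORT", 8000)

# After: always an int, and the default is a string, as os.getenv expects.
port_after = int(os.getenv("UVICORN_PORT", "8000"))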
faster_whisper_server/logger.py CHANGED
@@ -8,6 +8,4 @@ root_logger = logging.getLogger()
 root_logger.setLevel(logging.CRITICAL)
 logger = logging.getLogger(__name__)
 logger.setLevel(config.log_level.upper())
-logging.basicConfig(
-    format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(message)s"
-)
+logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(message)s")
faster_whisper_server/main.py CHANGED
@@ -1,12 +1,11 @@
 from __future__ import annotations
 
 import asyncio
-import time
+from collections import OrderedDict
 from io import BytesIO
-from typing import Annotated, Generator, Iterable, Literal, OrderedDict
+import time
+from typing import TYPE_CHECKING, Annotated, Literal
 
-import gradio as gr
-import huggingface_hub
 from fastapi import (
     FastAPI,
     Form,
@@ -21,9 +20,9 @@ from fastapi import (
 from fastapi.responses import StreamingResponse
 from fastapi.websockets import WebSocketState
 from faster_whisper import WhisperModel
-from faster_whisper.transcribe import Segment, TranscriptionInfo
 from faster_whisper.vad import VadOptions, get_speech_timestamps
-from huggingface_hub.hf_api import ModelInfo
+import gradio as gr
+import huggingface_hub
 from pydantic import AfterValidator
 
 from faster_whisper_server import utils
@@ -45,6 +44,12 @@ from faster_whisper_server.server_models import (
 )
 from faster_whisper_server.transcriber import audio_transcriber
 
+if TYPE_CHECKING:
+    from collections.abc import Generator, Iterable
+
+    from faster_whisper.transcribe import Segment, TranscriptionInfo
+    from huggingface_hub.hf_api import ModelInfo
+
 loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
 
 
@@ -54,9 +59,7 @@ def load_model(model_name: str) -> WhisperModel:
         return loaded_models[model_name]
     if len(loaded_models) >= config.max_models:
         oldest_model_name = next(iter(loaded_models))
-        logger.info(
-            f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
-        )
+        logger.info(f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}")
         del loaded_models[oldest_model_name]
     logger.debug(f"Loading {model_name}...")
     start = time.perf_counter()
@@ -67,7 +70,7 @@ def load_model(model_name: str) -> WhisperModel:
         compute_type=config.whisper.compute_type,
     )
     logger.info(
-        f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."
+        f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."  # noqa: E501
     )
     loaded_models[model_name] = whisper
     return whisper
@@ -102,9 +105,7 @@ def get_models() -> list[ModelObject]:
 def get_model(
     model_name: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")],
 ) -> ModelObject:
-    models = list(
-        huggingface_hub.list_models(model_name=model_name, library="ctranslate2")
-    )
+    models = list(huggingface_hub.list_models(model_name=model_name, library="ctranslate2"))
     if len(models) == 0:
         raise HTTPException(status_code=404, detail="Model doesn't exists")
     exact_match: ModelInfo | None = None
@@ -132,14 +133,12 @@ def segments_to_response(
     response_format: ResponseFormat,
 ) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse:
     segments = list(segments)
-    if response_format == ResponseFormat.TEXT:
+    if response_format == ResponseFormat.TEXT:  # noqa: RET503
         return utils.segments_text(segments)
     elif response_format == ResponseFormat.JSON:
         return TranscriptionJsonResponse.from_segments(segments)
     elif response_format == ResponseFormat.VERBOSE_JSON:
-        return TranscriptionVerboseJsonResponse.from_segments(
-            segments, transcription_info
-        )
+        return TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info)
 
 
 def format_as_sse(data: str) -> str:
@@ -156,26 +155,21 @@ def segments_to_streaming_response(
         if response_format == ResponseFormat.TEXT:
             data = segment.text
         elif response_format == ResponseFormat.JSON:
-            data = TranscriptionJsonResponse.from_segments(
-                [segment]
-            ).model_dump_json()
+            data = TranscriptionJsonResponse.from_segments([segment]).model_dump_json()
         elif response_format == ResponseFormat.VERBOSE_JSON:
-            data = TranscriptionVerboseJsonResponse.from_segment(
-                segment, transcription_info
-            ).model_dump_json()
+            data = TranscriptionVerboseJsonResponse.from_segment(segment, transcription_info).model_dump_json()
         yield format_as_sse(data)
 
     return StreamingResponse(segment_responses(), media_type="text/event-stream")
 
 
 def handle_default_openai_model(model_name: str) -> str:
-    """
+    """Exists because some callers may not be able override the default("whisper-1") model name.
+
     For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
     """
     if model_name == "whisper-1":
-        logger.info(
-            f"{model_name} is not a valid model name. Using {config.whisper.model} instead."
-        )
+        logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
         return config.whisper.model
     return model_name
 
@@ -194,12 +188,7 @@ def translate_file(
     response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
-) -> (
-    str
-    | TranscriptionJsonResponse
-    | TranscriptionVerboseJsonResponse
-    | StreamingResponse
-):
+) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse | StreamingResponse:
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -210,9 +199,7 @@ def translate_file(
     )
 
     if stream:
-        return segments_to_streaming_response(
-            segments, transcription_info, response_format
-        )
+        return segments_to_streaming_response(segments, transcription_info, response_format)
     else:
         return segments_to_response(segments, transcription_info, response_format)
 
@@ -231,16 +218,11 @@ def transcribe_file(
     response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
     temperature: Annotated[float, Form()] = 0.0,
     timestamp_granularities: Annotated[
-        list[Literal["segment"] | Literal["word"]],
+        list[Literal["segment", "word"]],
         Form(alias="timestamp_granularities[]"),
     ] = ["segment"],
     stream: Annotated[bool, Form()] = False,
-) -> (
-    str
-    | TranscriptionJsonResponse
-    | TranscriptionVerboseJsonResponse
-    | StreamingResponse
-):
+) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse | StreamingResponse:
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -253,9 +235,7 @@ def transcribe_file(
     )
 
     if stream:
-        return segments_to_streaming_response(
-            segments, transcription_info, response_format
-        )
+        return segments_to_streaming_response(segments, transcription_info, response_format)
     else:
         return segments_to_response(segments, transcription_info, response_format)
 
@@ -263,39 +243,28 @@ def transcribe_file(
 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
     try:
         while True:
-            bytes_ = await asyncio.wait_for(
-                ws.receive_bytes(), timeout=config.max_no_data_seconds
-            )
+            bytes_ = await asyncio.wait_for(ws.receive_bytes(), timeout=config.max_no_data_seconds)
             logger.debug(f"Received {len(bytes_)} bytes of audio data")
             audio_samples = audio_samples_from_file(BytesIO(bytes_))
             audio_stream.extend(audio_samples)
             if audio_stream.duration - config.inactivity_window_seconds >= 0:
-                audio = audio_stream.after(
-                    audio_stream.duration - config.inactivity_window_seconds
-                )
+                audio = audio_stream.after(audio_stream.duration - config.inactivity_window_seconds)
                 vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0)
                 # NOTE: This is a synchronous operation that runs every time new data is received.
-                # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato.
+                # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato.  # noqa: E501
                 timestamps = get_speech_timestamps(audio.data, vad_opts)
                 if len(timestamps) == 0:
-                    logger.info(
-                        f"No speech detected in the last {config.inactivity_window_seconds} seconds."
-                    )
+                    logger.info(f"No speech detected in the last {config.inactivity_window_seconds} seconds.")
                     break
                 elif (
                     # last speech end time
-                    config.inactivity_window_seconds
-                    - timestamps[-1]["end"] / SAMPLES_PER_SECOND
+                    config.inactivity_window_seconds - timestamps[-1]["end"] / SAMPLES_PER_SECOND
                     >= config.max_inactivity_seconds
                 ):
-                    logger.info(
-                        f"Not enough speech in the last {config.inactivity_window_seconds} seconds."
-                    )
+                    logger.info(f"Not enough speech in the last {config.inactivity_window_seconds} seconds.")
                     break
-    except asyncio.TimeoutError:
-        logger.info(
-            f"No data received in {config.max_no_data_seconds} seconds. Closing the connection."
-        )
+    except TimeoutError:
+        logger.info(f"No data received in {config.max_no_data_seconds} seconds. Closing the connection.")
     except WebSocketDisconnect as e:
         logger.info(f"Client disconnected: {e}")
     audio_stream.close()
@@ -306,9 +275,7 @@ async def transcribe_stream(
     ws: WebSocket,
     model: Annotated[ModelName, Query()] = config.whisper.model,
     language: Annotated[Language | None, Query()] = config.default_language,
-    response_format: Annotated[
-        ResponseFormat, Query()
-    ] = config.default_response_format,
+    response_format: Annotated[ResponseFormat, Query()] = config.default_response_format,
     temperature: Annotated[float, Query()] = 0.0,
 ) -> None:
     await ws.accept()
@@ -331,19 +298,11 @@ async def transcribe_stream(
         if response_format == ResponseFormat.TEXT:
             await ws.send_text(transcription.text)
         elif response_format == ResponseFormat.JSON:
-            await ws.send_json(
-                TranscriptionJsonResponse.from_transcription(
-                    transcription
-                ).model_dump()
-            )
+            await ws.send_json(TranscriptionJsonResponse.from_transcription(transcription).model_dump())
         elif response_format == ResponseFormat.VERBOSE_JSON:
-            await ws.send_json(
-                TranscriptionVerboseJsonResponse.from_transcription(
-                    transcription
-                ).model_dump()
-            )
+            await ws.send_json(TranscriptionVerboseJsonResponse.from_transcription(transcription).model_dump())
 
-    if not ws.client_state == WebSocketState.DISCONNECTED:
+    if ws.client_state != WebSocketState.DISCONNECTED:
         logger.info("Closing the connection.")
         await ws.close()
 
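Review note: since Python 3.11 `asyncio.TimeoutError` is an alias of the builtin `TimeoutError`, and ruff's UP041 rewrites handlers to catch the builtin, which is the change to `audio_receiver` above (the project targets py312 per pyproject.toml). A runnable sketch:

import asyncio

async def main() -> None:
    try:
        # On Python 3.11+ this raises the builtin TimeoutError
        # (asyncio.TimeoutError is the same class).
        await asyncio.wait_for(asyncio.sleep(10), timeout=0.01)
    except TimeoutError:
        print("timed out")

asyncio.run(main())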
faster_whisper_server/server_models.py CHANGED
@@ -1,12 +1,15 @@
 from __future__ import annotations
 
-from typing import Literal
+from typing import TYPE_CHECKING, Literal
 
-from faster_whisper.transcribe import Segment, TranscriptionInfo, Word
 from pydantic import BaseModel, ConfigDict, Field
 
 from faster_whisper_server import utils
-from faster_whisper_server.core import Transcription
+
+if TYPE_CHECKING:
+    from faster_whisper.transcribe import Segment, TranscriptionInfo, Word
+
+    from faster_whisper_server.core import Transcription
 
 
 # https://platform.openai.com/docs/api-reference/audio/json-object
@@ -18,9 +21,7 @@ class TranscriptionJsonResponse(BaseModel):
         return cls(text=utils.segments_text(segments))
 
     @classmethod
-    def from_transcription(
-        cls, transcription: Transcription
-    ) -> TranscriptionJsonResponse:
+    def from_transcription(cls, transcription: Transcription) -> TranscriptionJsonResponse:
         return cls(text=transcription.text)
 
 
@@ -78,18 +79,12 @@ class TranscriptionVerboseJsonResponse(BaseModel):
     segments: list[SegmentObject]
 
     @classmethod
-    def from_segment(
-        cls, segment: Segment, transcription_info: TranscriptionInfo
-    ) -> TranscriptionVerboseJsonResponse:
+    def from_segment(cls, segment: Segment, transcription_info: TranscriptionInfo) -> TranscriptionVerboseJsonResponse:
         return cls(
             language=transcription_info.language,
             duration=segment.end - segment.start,
             text=segment.text,
-            words=(
-                [WordObject.from_word(word) for word in segment.words]
-                if isinstance(segment.words, list)
-                else []
-            ),
+            words=([WordObject.from_word(word) for word in segment.words] if isinstance(segment.words, list) else []),
             segments=[SegmentObject.from_segment(segment)],
         )
 
@@ -102,16 +97,11 @@ class TranscriptionVerboseJsonResponse(BaseModel):
             duration=transcription_info.duration,
             text=utils.segments_text(segments),
             segments=[SegmentObject.from_segment(segment) for segment in segments],
-            words=[
-                WordObject.from_word(word)
-                for word in utils.words_from_segments(segments)
-            ],
+            words=[WordObject.from_word(word) for word in utils.words_from_segments(segments)],
         )
 
     @classmethod
-    def from_transcription(
-        cls, transcription: Transcription
-    ) -> TranscriptionVerboseJsonResponse:
+    def from_transcription(cls, transcription: Transcription) -> TranscriptionVerboseJsonResponse:
         return cls(
             language="english",  # FIX: hardcoded
             duration=transcription.duration,
faster_whisper_server/transcriber.py CHANGED
@@ -1,8 +1,7 @@
 from __future__ import annotations
 
-from typing import AsyncGenerator
+from typing import TYPE_CHECKING
 
-from faster_whisper_server.asr import FasterWhisperASR
 from faster_whisper_server.audio import Audio, AudioStream
 from faster_whisper_server.config import config
 from faster_whisper_server.core import (
@@ -13,6 +12,11 @@ from faster_whisper_server.core import (
 )
 from faster_whisper_server.logger import logger
 
+if TYPE_CHECKING:
+    from collections.abc import AsyncGenerator
+
+    from faster_whisper_server.asr import FasterWhisperASR
+
 
 class LocalAgreement:
     def __init__(self) -> None:
pyproject.toml CHANGED
@@ -28,18 +28,35 @@ target-version = "py312"
 [tool.ruff.lint]
 select = ["ALL"]
 ignore = [
-    "
+    "FIX",
+    "TD",  # disable todo warnings
     "ERA",  # allow commented out code
-    "
-    "FIX002",  # disable TODO warnings
+    "PTH",
 
+    "ANN003",  # missing kwargs
+    "ANN101",  # missing self type
+    "ANN102",  # missing cls
+    "B006",
+    "B008",
     "COM812",  # trailing comma
-    "
+    "D10",  # disabled required docstrings
+    "D401",
+    "EM102",
+    "FBT001",
+    "FBT002",
+    "PLR0913",
+    "PLR2004",  # magic
+    "RET504",
+    "RET505",
+    "RET508",
     "S101",  # allow assert
-    "
+    "S104",
     "S603",  # subprocess untrusted input
-
-    "
+    "SIM102",
+    "T201",  # print
+    "TRY003",
+    "W505",
+    "ISC001"  # recommended to disable for formatting
 ]
 
 [tool.ruff.lint.isort]
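Review note: this config enables every ruff rule (`select = ["ALL"]`) and then opts out per rule or rule family; a bare prefix in `ignore` silences the whole family, which is why `"FIX"` replaces the narrower `"FIX002"` entry. An illustrative excerpt (hypothetical, showing how prefixes expand):

[tool.ruff.lint]
select = ["ALL"]
ignore = [
    "FIX",  # whole family: FIX001-FIX004 (flags FIXME/TODO/XXX/HACK comments)
    "D10",  # whole family: D100-D107 (missing-docstring checks)
]

Most of the mechanical changes in this commit are the kind `ruff check --fix` applies automatically; the `# noqa: E501` markers cover lines deliberately kept over the length limit.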
tests/api_model_test.py CHANGED
@@ -4,9 +4,7 @@ from faster_whisper_server.server_models import ModelObject
 
 MODEL_THAT_EXISTS = "Systran/faster-whisper-tiny.en"
 MODEL_THAT_DOES_NOT_EXIST = "i-do-not-exist"
-MIN_EXPECTED_NUMBER_OF_MODELS = (
-    200  # At the time of the test creation there are 228 models
-)
+MIN_EXPECTED_NUMBER_OF_MODELS = 200  # At the time of the test creation there are 228 models
 
 
 # HACK: because ModelObject(**data) doesn't work
@@ -19,20 +17,20 @@ def model_dict_to_object(model_dict: dict) -> ModelObject:
     )
 
 
-def test_list_models(client: TestClient):
+def test_list_models(client: TestClient) -> None:
     response = client.get("/v1/models")
     data = response.json()
     models = [model_dict_to_object(model_dict) for model_dict in data]
     assert len(models) > MIN_EXPECTED_NUMBER_OF_MODELS
 
 
-def test_model_exists(client: TestClient):
+def test_model_exists(client: TestClient) -> None:
     response = client.get(f"/v1/models/{MODEL_THAT_EXISTS}")
     data = response.json()
     model = model_dict_to_object(data)
     assert model.id == MODEL_THAT_EXISTS
 
 
-def test_model_does_not_exist(client: TestClient):
+def test_model_does_not_exist(client: TestClient) -> None:
     response = client.get(f"/v1/models/{MODEL_THAT_DOES_NOT_EXIST}")
     assert response.status_code == 404
tests/app_test.py CHANGED
@@ -1,10 +1,10 @@
+from collections.abc import Generator
 import json
 import os
 import time
-from typing import Generator
 
-import pytest
 from fastapi.testclient import TestClient
+import pytest
 from starlette.testclient import WebSocketTestSession
 
 from faster_whisper_server.config import BYTES_PER_SECOND
@@ -22,35 +22,31 @@ def ws(client: TestClient) -> Generator[WebSocketTestSession, None, None]:
     yield ws
 
 
-def get_audio_file_paths():
-    file_paths = []
+def get_audio_file_paths() -> list[str]:
+    file_paths: list[str] = []
     directory = "tests/data"
     for filename in sorted(os.listdir(directory)[:AUDIO_FILES_LIMIT]):
-        file_paths.append(os.path.join(directory, filename))
+        file_paths.append(os.path.join(directory, filename))  # noqa: PERF401
     return file_paths
 
 
 file_paths = get_audio_file_paths()
 
 
-def stream_audio_data(
-    ws: WebSocketTestSession, data: bytes, *, chunk_size: int = 4000, speed: float = 1.0
-):
+def stream_audio_data(ws: WebSocketTestSession, data: bytes, *, chunk_size: int = 4000, speed: float = 1.0) -> None:
     for i in range(0, len(data), chunk_size):
         ws.send_bytes(data[i : i + chunk_size])
         delay = len(data[i : i + chunk_size]) / BYTES_PER_SECOND / speed
         time.sleep(delay)
 
 
-def transcribe_audio_data(
-    client: TestClient, data: bytes
-) -> TranscriptionVerboseJsonResponse:
+def transcribe_audio_data(client: TestClient, data: bytes) -> TranscriptionVerboseJsonResponse:
     response = client.post(
         TRANSCRIBE_ENDPOINT,
         files={"file": ("audio.raw", data, "audio/raw")},
     )
     data = json.loads(response.json())  # TODO: figure this out
-    return TranscriptionVerboseJsonResponse(**data)  #
+    return TranscriptionVerboseJsonResponse(**data)  # pyright: ignore[reportCallIssue]
 
 
 # @pytest.mark.parametrize("file_path", file_paths)
@@ -60,7 +56,7 @@ def transcribe_audio_data(
 # with open(file_path, "rb") as file:
 #     data = file.read()
 #
-#     streaming_transcription: TranscriptionVerboseJsonResponse = None  # type: ignore
+#     streaming_transcription: TranscriptionVerboseJsonResponse = None  # type: ignore  # noqa: PGH003
 # thread = threading.Thread(
 #     target=stream_audio_data, args=(ws, data), kwargs={"speed": 4.0}
 # )
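Review note: the `# noqa: PERF401` keeps the explicit append loop; the comprehension ruff would otherwise suggest is equivalent. A sketch, with the limit parameterized since `AUDIO_FILES_LIMIT` is defined elsewhere in the file:

import os

def get_audio_file_paths(directory: str = "tests/data", limit: int | None = None) -> list[str]:
    # PERF401's suggested form: build the list in a single comprehension.
    return [os.path.join(directory, filename) for filename in sorted(os.listdir(directory)[:limit])]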
tests/conftest.py CHANGED
@@ -1,18 +1,15 @@
+from collections.abc import Generator
 import logging
-import os
-from typing import Generator
 
-import pytest
 from fastapi.testclient import TestClient
+import pytest
 
-
-os.environ["WHISPER_MODEL"] = "Systran/faster-whisper-tiny.en"
-from faster_whisper_server.main import app  # noqa: E402
+from faster_whisper_server.main import app
 
 disable_loggers = ["multipart.multipart", "faster_whisper"]
 
 
-def pytest_configure():
+def pytest_configure() -> None:
     for logger_name in disable_loggers:
         logger = logging.getLogger(logger_name)
         logger.disabled = True
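Review note: the old conftest set `WHISPER_MODEL` before importing the app, which forced a mid-module import that needed `# noqa: E402`; the new version drops both, so the tiny test model presumably gets selected some other way (not visible in this diff). The removed pattern, for reference:

# Pattern removed above: the env var must be set before the app module is
# imported, presumably because the config reads the environment at import time.
import os

os.environ["WHISPER_MODEL"] = "Systran/faster-whisper-tiny.en"

from faster_whisper_server.main import app  # noqa: E402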
tests/sse_test.py CHANGED
@@ -1,9 +1,9 @@
 import json
 import os
 
-import pytest
 from fastapi.testclient import TestClient
 from httpx_sse import connect_sse
+import pytest
 
 from faster_whisper_server.server_models import (
     TranscriptionJsonResponse,
@@ -17,15 +17,11 @@ ENDPOINTS = [
 ]
 
 
-parameters = [
-    (file_path, endpoint) for endpoint in ENDPOINTS for file_path in FILE_PATHS
-]
+parameters = [(file_path, endpoint) for endpoint in ENDPOINTS for file_path in FILE_PATHS]
 
 
-@pytest.mark.parametrize("file_path,endpoint", parameters)
-def test_streaming_transcription_text(
-    client: TestClient, file_path: str, endpoint: str
-):
+@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
+def test_streaming_transcription_text(client: TestClient, file_path: str, endpoint: str) -> None:
     extension = os.path.splitext(file_path)[1]
     with open(file_path, "rb") as f:
         data = f.read()
@@ -36,15 +32,11 @@ def test_streaming_transcription_text(
     with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
         for event in event_source.iter_sse():
             print(event)
-            assert (
-                len(event.data) > 1
-            )  # HACK: 1 because of the space character that's always prepended
+            assert len(event.data) > 1  # HACK: 1 because of the space character that's always prepended
 
 
-@pytest.mark.parametrize("file_path,endpoint", parameters)
-def test_streaming_transcription_json(
-    client: TestClient, file_path: str, endpoint: str
-):
+@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
+def test_streaming_transcription_json(client: TestClient, file_path: str, endpoint: str) -> None:
     extension = os.path.splitext(file_path)[1]
     with open(file_path, "rb") as f:
         data = f.read()
@@ -57,10 +49,8 @@ def test_streaming_transcription_json(
             TranscriptionJsonResponse(**json.loads(event.data))
 
 
-@pytest.mark.parametrize("file_path,endpoint", parameters)
-def test_streaming_transcription_verbose_json(
-    client: TestClient, file_path: str, endpoint: str
-):
+@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
+def test_streaming_transcription_verbose_json(client: TestClient, file_path: str, endpoint: str) -> None:
     extension = os.path.splitext(file_path)[1]
     with open(file_path, "rb") as f:
         data = f.read()
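Review note: `@pytest.mark.parametrize(("file_path", "endpoint"), ...)` is ruff PT006's preferred spelling; pytest accepts both forms, but the tuple avoids packing argument names into one comma-separated string. Minimal sketch:

import pytest

# PT006: pass argnames as a tuple of strings instead of "x,y" in one string.
@pytest.mark.parametrize(("x", "y"), [(1, 2), (3, 4)])
def test_sum_is_positive(x: int, y: int) -> None:
    assert x + y > 0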