Fedir Zadniprovskyi committed
Commit dc4f25f · 1 Parent(s): 8ad4ca5

chore: fix ruff errors

faster_whisper_server/asr.py CHANGED
@@ -1,6 +1,6 @@
 import asyncio
+from collections.abc import Iterable
 import time
-from typing import Iterable
 
 from faster_whisper import transcribe
@@ -45,7 +45,7 @@ class FasterWhisperASR:
         audio: Audio,
         prompt: str | None = None,
     ) -> tuple[Transcription, transcribe.TranscriptionInfo]:
-        """Wrapper around _transcribe so it can be used in async context"""
+        """Wrapper around _transcribe so it can be used in async context."""
         # is this the optimal way to execute a blocking call in an async context?
         # TODO: verify performance when running inference on a CPU
         return await asyncio.get_running_loop().run_in_executor(
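
Note: the `run_in_executor` call kept as context above is the standard way to run a blocking call (here, model inference) without stalling the event loop. A minimal sketch of the pattern; the `blocking_work` name is illustrative, not from the repo:

    import asyncio

    def blocking_work(x: int) -> int:
        # stands in for a blocking call such as model inference
        return x * x

    async def main() -> None:
        loop = asyncio.get_running_loop()
        # None selects the loop's default ThreadPoolExecutor
        result = await loop.run_in_executor(None, blocking_work, 21)
        print(result)  # 441

    asyncio.run(main())
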
faster_whisper_server/audio.py CHANGED
@@ -1,15 +1,19 @@
 from __future__ import annotations
 
 import asyncio
-from typing import AsyncGenerator, BinaryIO
+from typing import TYPE_CHECKING, BinaryIO
 
 import numpy as np
 import soundfile as sf
-from numpy.typing import NDArray
 
 from faster_whisper_server.config import SAMPLES_PER_SECOND
 from faster_whisper_server.logger import logger
 
+if TYPE_CHECKING:
+    from collections.abc import AsyncGenerator
+
+    from numpy.typing import NDArray
+
 
 def audio_samples_from_file(file: BinaryIO) -> NDArray[np.float32]:
     audio_and_sample_rate = sf.read(
@@ -22,7 +26,7 @@ def audio_samples_from_file(file: BinaryIO) -> NDArray[np.float32]:
         endian="LITTLE",
     )
     audio = audio_and_sample_rate[0]
-    return audio  # type: ignore
+    return audio  # pyright: ignore[reportReturnType]
 
 
 class Audio:
@@ -78,9 +82,7 @@ class AudioStream(Audio):
         self.modify_event.set()
         logger.info("AudioStream closed")
 
-    async def chunks(
-        self, min_duration: float
-    ) -> AsyncGenerator[NDArray[np.float32], None]:
+    async def chunks(self, min_duration: float) -> AsyncGenerator[NDArray[np.float32], None]:
         i = 0.0  # end time of last chunk
         while True:
             await self.modify_event.wait()
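
Note: this hunk shows the pattern the commit applies across the codebase for ruff's flake8-type-checking rules: imports needed only for annotations move under `if TYPE_CHECKING:`, a block that type checkers evaluate but the interpreter never does, so the modules cost nothing at runtime. A minimal sketch of why this is safe:

    from __future__ import annotations  # annotations become strings, resolved lazily

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # seen by type checkers only; never imported at runtime
        from collections.abc import AsyncGenerator


    async def countdown(n: int) -> AsyncGenerator[int, None]:
        while n > 0:
            yield n
            n -= 1
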
faster_whisper_server/config.py CHANGED
@@ -15,7 +15,7 @@ class ResponseFormat(enum.StrEnum):
     TEXT = "text"
     JSON = "json"
     VERBOSE_JSON = "verbose_json"
-    # NOTE: While inspecting outputs of these formats with `curl`, I noticed there's one or two "\n" inserted at the end of the response.
+    # NOTE: While inspecting outputs of these formats with `curl`, I noticed there's one or two "\n" inserted at the end of the response. # noqa: E501
 
     # VTT = "vtt"  # TODO
     # 1
@@ -185,8 +185,8 @@ class WhisperConfig(BaseModel):
 
 
 class Config(BaseSettings):
-    """
-    Configuration for the application. Values can be set via environment variables.
+    """Configuration for the application. Values can be set via environment variables.
+
     Pydantic will automatically handle mapping uppercased environment variables to the corresponding fields.
     To populate nested, the environment should be prefixed with the nested field name and an underscore. For example,
     the environment variable `LOG_LEVEL` will be mapped to `log_level`, `WHISPER_MODEL` to `whisper.model`, etc.
@@ -208,7 +208,7 @@ class Config(BaseSettings):
     max_inactivity_seconds: float = 5.0
     """
    Max allowed audio duration without any speech being detected before transcription is finilized and connection is closed.
-    """
+    """  # noqa: E501
     inactivity_window_seconds: float = 10.0
     """
     Controls how many latest seconds of audio are being passed through VAD.
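
Note: the Config docstring above describes pydantic-settings' environment-variable mapping. A rough sketch of that behavior, assuming a nested-delimiter setup like the one the docstring implies; the field names and defaults here are illustrative:

    from pydantic import BaseModel
    from pydantic_settings import BaseSettings, SettingsConfigDict


    class WhisperConfig(BaseModel):
        model: str = "Systran/faster-whisper-medium.en"


    class Config(BaseSettings):
        model_config = SettingsConfigDict(env_nested_delimiter="_")

        log_level: str = "debug"
        whisper: WhisperConfig = WhisperConfig()


    # LOG_LEVEL=info maps to log_level; WHISPER_MODEL maps to whisper.model
    config = Config()
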
faster_whisper_server/core.py CHANGED
@@ -1,8 +1,8 @@
 # TODO: rename module
 from __future__ import annotations
 
-import re
 from dataclasses import dataclass
+import re
 
 from faster_whisper_server.config import config
 
@@ -18,10 +18,7 @@ class Segment:
     def is_eos(self) -> bool:
         if self.text.endswith("..."):
             return False
-        for punctuation_symbol in ".?!":
-            if self.text.endswith(punctuation_symbol):
-                return True
-        return False
+        return any(self.text.endswith(punctuation_symbol) for punctuation_symbol in ".?!")
 
     def offset(self, seconds: float) -> None:
         self.start += seconds
@@ -36,11 +33,7 @@ class Word(Segment):
     @classmethod
     def common_prefix(cls, a: list[Word], b: list[Word]) -> list[Word]:
         i = 0
-        while (
-            i < len(a)
-            and i < len(b)
-            and canonicalize_word(a[i].text) == canonicalize_word(b[i].text)
-        ):
+        while i < len(a) and i < len(b) and canonicalize_word(a[i].text) == canonicalize_word(b[i].text):
             i += 1
         return a[:i]
 
@@ -67,9 +60,7 @@ class Transcription:
         return self.end - self.start
 
     def after(self, seconds: float) -> Transcription:
-        return Transcription(
-            words=[word for word in self.words if word.start > seconds]
-        )
+        return Transcription(words=[word for word in self.words if word.start > seconds])
 
     def extend(self, words: list[Word]) -> None:
         self._ensure_no_word_overlap(words)
@@ -77,21 +68,16 @@ class Transcription:
 
     def _ensure_no_word_overlap(self, words: list[Word]) -> None:
         if len(self.words) > 0 and len(words) > 0:
-            if (
-                words[0].start + config.word_timestamp_error_margin
-                <= self.words[-1].end
-            ):
+            if words[0].start + config.word_timestamp_error_margin <= self.words[-1].end:
                 raise ValueError(
-                    f"Words overlap: {self.words[-1]} and {words[0]}. Error margin: {config.word_timestamp_error_margin}"
+                    f"Words overlap: {self.words[-1]} and {words[0]}. Error margin: {config.word_timestamp_error_margin}"  # noqa: E501
                 )
         for i in range(1, len(words)):
             if words[i].start + config.word_timestamp_error_margin <= words[i - 1].end:
-                raise ValueError(
-                    f"Words overlap: {words[i - 1]} and {words[i]}. All words: {words}"
-                )
+                raise ValueError(f"Words overlap: {words[i - 1]} and {words[i]}. All words: {words}")
 
 
-def test_segment_is_eos():
+def test_segment_is_eos() -> None:
     assert not Segment("Hello").is_eos
     assert not Segment("Hello...").is_eos
     assert Segment("Hello.").is_eos
@@ -117,16 +103,14 @@ def to_full_sentences(words: list[Word]) -> list[Segment]:
     return sentences
 
 
-def tests_to_full_sentences():
+def tests_to_full_sentences() -> None:
     assert to_full_sentences([]) == []
     assert to_full_sentences([Word(text="Hello")]) == []
     assert to_full_sentences([Word(text="Hello..."), Word(" world")]) == []
-    assert to_full_sentences([Word(text="Hello..."), Word(" world.")]) == [
-        Segment(text="Hello... world.")
-    ]
-    assert to_full_sentences(
-        [Word(text="Hello..."), Word(" world."), Word(" How")]
-    ) == [Segment(text="Hello... world.")]
+    assert to_full_sentences([Word(text="Hello..."), Word(" world.")]) == [Segment(text="Hello... world.")]
+    assert to_full_sentences([Word(text="Hello..."), Word(" world."), Word(" How")]) == [
+        Segment(text="Hello... world.")
+    ]
 
 
 def to_text(words: list[Word]) -> str:
@@ -144,7 +128,7 @@ def canonicalize_word(text: str) -> str:
     return text.lower().strip().strip(".,?!")
 
 
-def test_canonicalize_word():
+def test_canonicalize_word() -> None:
     assert canonicalize_word("ABC") == "abc"
     assert canonicalize_word("...ABC?") == "abc"
     assert canonicalize_word("... AbC ...") == "abc"
@@ -152,16 +136,12 @@ def test_canonicalize_word():
 
 def common_prefix(a: list[Word], b: list[Word]) -> list[Word]:
     i = 0
-    while (
-        i < len(a)
-        and i < len(b)
-        and canonicalize_word(a[i].text) == canonicalize_word(b[i].text)
-    ):
+    while i < len(a) and i < len(b) and canonicalize_word(a[i].text) == canonicalize_word(b[i].text):
         i += 1
     return a[:i]
 
 
-def test_common_prefix():
+def test_common_prefix() -> None:
     def word(text: str) -> Word:
         return Word(text=text, start=0.0, end=0.0, probability=0.0)
 
@@ -194,7 +174,7 @@ def test_common_prefix():
     assert common_prefix(a, b) == []
 
 
-def test_common_prefix_and_canonicalization():
+def test_common_prefix_and_canonicalization() -> None:
     def word(text: str) -> Word:
         return Word(text=text, start=0.0, end=0.0, probability=0.0)
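
Note: the `is_eos` rewrite above folds the punctuation loop into a single `any()` over a generator expression; behavior is unchanged. Restated as a standalone sketch:

    def is_eos(text: str) -> bool:
        if text.endswith("..."):
            return False  # "..." marks an unfinished sentence, not an end
        return any(text.endswith(punctuation) for punctuation in ".?!")


    assert is_eos("Hello.")
    assert not is_eos("Hello...")
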
 
faster_whisper_server/gradio_app.py CHANGED
@@ -1,5 +1,5 @@
+from collections.abc import Generator
 import os
-from typing import Generator
 
 import gradio as gr
 import httpx
@@ -13,26 +13,20 @@ TRANSLATION_ENDPOINT = "/v1/audio/translations"
 
 def create_gradio_demo(config: Config) -> gr.Blocks:
     host = os.getenv("UVICORN_HOST", "0.0.0.0")
-    port = os.getenv("UVICORN_PORT", 8000)
+    port = int(os.getenv("UVICORN_PORT", "8000"))
     # NOTE: worth looking into generated clients
     http_client = httpx.Client(base_url=f"http://{host}:{port}", timeout=None)
 
-    def handler(
-        file_path: str, model: str, task: Task, temperature: float, stream: bool
-    ) -> Generator[str, None, None]:
+    def handler(file_path: str, model: str, task: Task, temperature: float, stream: bool) -> Generator[str, None, None]:
         if stream:
             previous_transcription = ""
-            for transcription in transcribe_audio_streaming(
-                file_path, task, temperature, model
-            ):
+            for transcription in transcribe_audio_streaming(file_path, task, temperature, model):
                 previous_transcription += transcription
                 yield previous_transcription
         else:
             yield transcribe_audio(file_path, task, temperature, model)
 
-    def transcribe_audio(
-        file_path: str, task: Task, temperature: float, model: str
-    ) -> str:
+    def transcribe_audio(file_path: str, task: Task, temperature: float, model: str) -> str:
         if task == Task.TRANSCRIBE:
             endpoint = TRANSCRIPTION_ENDPOINT
         elif task == Task.TRANSLATE:
@@ -65,11 +59,7 @@ def create_gradio_demo(config: Config) -> gr.Blocks:
                 "stream": True,
             },
         }
-        endpoint = (
-            TRANSCRIPTION_ENDPOINT
-            if task == Task.TRANSCRIBE
-            else TRANSLATION_ENDPOINT
-        )
+        endpoint = TRANSCRIPTION_ENDPOINT if task == Task.TRANSCRIBE else TRANSLATION_ENDPOINT
         with connect_sse(http_client, "POST", endpoint, **kwargs) as event_source:
             for event in event_source.iter_sse():
                 yield event.data
@@ -79,18 +69,15 @@ def create_gradio_demo(config: Config) -> gr.Blocks:
         res_data = res.json()
         models: list[str] = [model["id"] for model in res_data]
         assert config.whisper.model in models
-        recommended_models = set(
-            model for model in models if model.startswith("Systran")
-        )
+        recommended_models = {model for model in models if model.startswith("Systran")}
         other_models = [model for model in models if model not in recommended_models]
         models = list(recommended_models) + other_models
-        model_dropdown = gr.Dropdown(
+        return gr.Dropdown(
             # no idea why it's complaining
-            choices=models,  # type: ignore
+            choices=models,  # pyright: ignore[reportArgumentType]
             label="Model",
             value=config.whisper.model,
         )
-        return model_dropdown
 
     model_dropdown = gr.Dropdown(
         choices=[config.whisper.model],
@@ -102,13 +89,11 @@ def create_gradio_demo(config: Config) -> gr.Blocks:
         label="Task",
         value=Task.TRANSCRIBE,
     )
-    temperature_slider = gr.Slider(
-        minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.0
-    )
+    temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.0)
    stream_checkbox = gr.Checkbox(label="Stream", value=True)
    with gr.Interface(
        title="Whisper Playground",
-        description="""Consider supporting the project by starring the <a href="https://github.com/fedirz/faster-whisper-server">repository on GitHub</a>.""",
+        description="""Consider supporting the project by starring the <a href="https://github.com/fedirz/faster-whisper-server">repository on GitHub</a>.""",  # noqa: E501
        inputs=[
            gr.Audio(type="filepath"),
            model_dropdown,
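
Note: one behavioral fix hides among the style changes above: `os.getenv` returns `str | None`, so ruff (PLW1508, invalid-envvar-default) rejects a non-string default like `8000`, and the port is now converted explicitly. A sketch of the corrected idiom:

    import os

    # the default must be a string; convert after the lookup
    port = int(os.getenv("UVICORN_PORT", "8000"))
    print(f"listening on port {port}")
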
faster_whisper_server/logger.py CHANGED
@@ -8,6 +8,4 @@ root_logger = logging.getLogger()
 root_logger.setLevel(logging.CRITICAL)
 logger = logging.getLogger(__name__)
 logger.setLevel(config.log_level.upper())
-logging.basicConfig(
-    format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(message)s"
-)
+logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(message)s")
faster_whisper_server/main.py CHANGED
@@ -1,12 +1,11 @@
 from __future__ import annotations
 
 import asyncio
-import time
+from collections import OrderedDict
 from io import BytesIO
-from typing import Annotated, Generator, Iterable, Literal, OrderedDict
+import time
+from typing import TYPE_CHECKING, Annotated, Literal
 
-import gradio as gr
-import huggingface_hub
 from fastapi import (
     FastAPI,
     Form,
@@ -21,9 +20,9 @@ from fastapi import (
 from fastapi.responses import StreamingResponse
 from fastapi.websockets import WebSocketState
 from faster_whisper import WhisperModel
-from faster_whisper.transcribe import Segment, TranscriptionInfo
 from faster_whisper.vad import VadOptions, get_speech_timestamps
-from huggingface_hub.hf_api import ModelInfo
+import gradio as gr
+import huggingface_hub
 from pydantic import AfterValidator
 
 from faster_whisper_server import utils
@@ -45,6 +44,12 @@ from faster_whisper_server.server_models import (
 )
 from faster_whisper_server.transcriber import audio_transcriber
 
+if TYPE_CHECKING:
+    from collections.abc import Generator, Iterable
+
+    from faster_whisper.transcribe import Segment, TranscriptionInfo
+    from huggingface_hub.hf_api import ModelInfo
+
 loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
 
 
@@ -54,9 +59,7 @@ def load_model(model_name: str) -> WhisperModel:
         return loaded_models[model_name]
     if len(loaded_models) >= config.max_models:
         oldest_model_name = next(iter(loaded_models))
-        logger.info(
-            f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
-        )
+        logger.info(f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}")
         del loaded_models[oldest_model_name]
     logger.debug(f"Loading {model_name}...")
     start = time.perf_counter()
@@ -67,7 +70,7 @@ def load_model(model_name: str) -> WhisperModel:
         compute_type=config.whisper.compute_type,
     )
     logger.info(
-        f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."
+        f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."  # noqa: E501
     )
     loaded_models[model_name] = whisper
     return whisper
@@ -102,9 +105,7 @@ def get_models() -> list[ModelObject]:
 def get_model(
     model_name: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")],
 ) -> ModelObject:
-    models = list(
-        huggingface_hub.list_models(model_name=model_name, library="ctranslate2")
-    )
+    models = list(huggingface_hub.list_models(model_name=model_name, library="ctranslate2"))
     if len(models) == 0:
         raise HTTPException(status_code=404, detail="Model doesn't exists")
     exact_match: ModelInfo | None = None
@@ -132,14 +133,12 @@ def segments_to_response(
     response_format: ResponseFormat,
 ) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse:
     segments = list(segments)
-    if response_format == ResponseFormat.TEXT:
+    if response_format == ResponseFormat.TEXT:  # noqa: RET503
         return utils.segments_text(segments)
     elif response_format == ResponseFormat.JSON:
         return TranscriptionJsonResponse.from_segments(segments)
     elif response_format == ResponseFormat.VERBOSE_JSON:
-        return TranscriptionVerboseJsonResponse.from_segments(
-            segments, transcription_info
-        )
+        return TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info)
 
 
 def format_as_sse(data: str) -> str:
@@ -156,26 +155,21 @@ def segments_to_streaming_response(
         if response_format == ResponseFormat.TEXT:
             data = segment.text
         elif response_format == ResponseFormat.JSON:
-            data = TranscriptionJsonResponse.from_segments(
-                [segment]
-            ).model_dump_json()
+            data = TranscriptionJsonResponse.from_segments([segment]).model_dump_json()
         elif response_format == ResponseFormat.VERBOSE_JSON:
-            data = TranscriptionVerboseJsonResponse.from_segment(
-                segment, transcription_info
-            ).model_dump_json()
+            data = TranscriptionVerboseJsonResponse.from_segment(segment, transcription_info).model_dump_json()
         yield format_as_sse(data)
 
     return StreamingResponse(segment_responses(), media_type="text/event-stream")
 
 
 def handle_default_openai_model(model_name: str) -> str:
-    """This exists because some callers may not be able override the default("whisper-1") model name.
+    """Exists because some callers may not be able override the default("whisper-1") model name.
+
     For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
     """
     if model_name == "whisper-1":
-        logger.info(
-            f"{model_name} is not a valid model name. Using {config.whisper.model} instead."
-        )
+        logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
         return config.whisper.model
     return model_name
 
@@ -194,12 +188,7 @@ def translate_file(
     response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
-) -> (
-    str
-    | TranscriptionJsonResponse
-    | TranscriptionVerboseJsonResponse
-    | StreamingResponse
-):
+) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse | StreamingResponse:
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -210,9 +199,7 @@ def translate_file(
     )
 
     if stream:
-        return segments_to_streaming_response(
-            segments, transcription_info, response_format
-        )
+        return segments_to_streaming_response(segments, transcription_info, response_format)
     else:
         return segments_to_response(segments, transcription_info, response_format)
 
@@ -231,16 +218,11 @@ def transcribe_file(
     response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
     temperature: Annotated[float, Form()] = 0.0,
     timestamp_granularities: Annotated[
-        list[Literal["segment"] | Literal["word"]],
+        list[Literal["segment", "word"]],
         Form(alias="timestamp_granularities[]"),
     ] = ["segment"],
     stream: Annotated[bool, Form()] = False,
-) -> (
-    str
-    | TranscriptionJsonResponse
-    | TranscriptionVerboseJsonResponse
-    | StreamingResponse
-):
+) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse | StreamingResponse:
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -253,9 +235,7 @@ def transcribe_file(
     )
 
     if stream:
-        return segments_to_streaming_response(
-            segments, transcription_info, response_format
-        )
+        return segments_to_streaming_response(segments, transcription_info, response_format)
     else:
         return segments_to_response(segments, transcription_info, response_format)
 
@@ -263,39 +243,28 @@ def transcribe_file(
 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
     try:
         while True:
-            bytes_ = await asyncio.wait_for(
-                ws.receive_bytes(), timeout=config.max_no_data_seconds
-            )
+            bytes_ = await asyncio.wait_for(ws.receive_bytes(), timeout=config.max_no_data_seconds)
             logger.debug(f"Received {len(bytes_)} bytes of audio data")
             audio_samples = audio_samples_from_file(BytesIO(bytes_))
             audio_stream.extend(audio_samples)
             if audio_stream.duration - config.inactivity_window_seconds >= 0:
-                audio = audio_stream.after(
-                    audio_stream.duration - config.inactivity_window_seconds
-                )
+                audio = audio_stream.after(audio_stream.duration - config.inactivity_window_seconds)
                 vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0)
                 # NOTE: This is a synchronous operation that runs every time new data is received.
-                # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato.
+                # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato. # noqa: E501
                 timestamps = get_speech_timestamps(audio.data, vad_opts)
                 if len(timestamps) == 0:
-                    logger.info(
-                        f"No speech detected in the last {config.inactivity_window_seconds} seconds."
-                    )
+                    logger.info(f"No speech detected in the last {config.inactivity_window_seconds} seconds.")
                     break
                 elif (
                     # last speech end time
-                    config.inactivity_window_seconds
-                    - timestamps[-1]["end"] / SAMPLES_PER_SECOND
+                    config.inactivity_window_seconds - timestamps[-1]["end"] / SAMPLES_PER_SECOND
                     >= config.max_inactivity_seconds
                 ):
-                    logger.info(
-                        f"Not enough speech in the last {config.inactivity_window_seconds} seconds."
-                    )
+                    logger.info(f"Not enough speech in the last {config.inactivity_window_seconds} seconds.")
                     break
-    except asyncio.TimeoutError:
-        logger.info(
-            f"No data received in {config.max_no_data_seconds} seconds. Closing the connection."
-        )
+    except TimeoutError:
+        logger.info(f"No data received in {config.max_no_data_seconds} seconds. Closing the connection.")
     except WebSocketDisconnect as e:
         logger.info(f"Client disconnected: {e}")
     audio_stream.close()
@@ -306,9 +275,7 @@ async def transcribe_stream(
     ws: WebSocket,
     model: Annotated[ModelName, Query()] = config.whisper.model,
     language: Annotated[Language | None, Query()] = config.default_language,
-    response_format: Annotated[
-        ResponseFormat, Query()
-    ] = config.default_response_format,
+    response_format: Annotated[ResponseFormat, Query()] = config.default_response_format,
     temperature: Annotated[float, Query()] = 0.0,
 ) -> None:
     await ws.accept()
@@ -331,19 +298,11 @@ async def transcribe_stream(
             if response_format == ResponseFormat.TEXT:
                 await ws.send_text(transcription.text)
             elif response_format == ResponseFormat.JSON:
-                await ws.send_json(
-                    TranscriptionJsonResponse.from_transcription(
-                        transcription
-                    ).model_dump()
-                )
+                await ws.send_json(TranscriptionJsonResponse.from_transcription(transcription).model_dump())
             elif response_format == ResponseFormat.VERBOSE_JSON:
-                await ws.send_json(
-                    TranscriptionVerboseJsonResponse.from_transcription(
-                        transcription
-                    ).model_dump()
-                )
+                await ws.send_json(TranscriptionVerboseJsonResponse.from_transcription(transcription).model_dump())
 
-    if not ws.client_state == WebSocketState.DISCONNECTED:
+    if ws.client_state != WebSocketState.DISCONNECTED:
         logger.info("Closing the connection.")
         await ws.close()
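
Note: in the `audio_receiver` hunk, the `except` clause changes because since Python 3.11 `asyncio.TimeoutError` is an alias of the built-in `TimeoutError`, which is what ruff's UP041 rule rewrites. A minimal sketch:

    import asyncio


    async def main() -> None:
        try:
            await asyncio.wait_for(asyncio.sleep(10), timeout=0.01)
        except TimeoutError:  # on 3.11+ this also catches asyncio.TimeoutError
            print("timed out")


    asyncio.run(main())
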
 
faster_whisper_server/server_models.py CHANGED
@@ -1,12 +1,15 @@
 from __future__ import annotations
 
-from typing import Literal
+from typing import TYPE_CHECKING, Literal
 
-from faster_whisper.transcribe import Segment, TranscriptionInfo, Word
 from pydantic import BaseModel, ConfigDict, Field
 
 from faster_whisper_server import utils
-from faster_whisper_server.core import Transcription
+
+if TYPE_CHECKING:
+    from faster_whisper.transcribe import Segment, TranscriptionInfo, Word
+
+    from faster_whisper_server.core import Transcription
 
 
 # https://platform.openai.com/docs/api-reference/audio/json-object
@@ -18,9 +21,7 @@ class TranscriptionJsonResponse(BaseModel):
         return cls(text=utils.segments_text(segments))
 
     @classmethod
-    def from_transcription(
-        cls, transcription: Transcription
-    ) -> TranscriptionJsonResponse:
+    def from_transcription(cls, transcription: Transcription) -> TranscriptionJsonResponse:
         return cls(text=transcription.text)
 
 
@@ -78,18 +79,12 @@ class TranscriptionVerboseJsonResponse(BaseModel):
     segments: list[SegmentObject]
 
     @classmethod
-    def from_segment(
-        cls, segment: Segment, transcription_info: TranscriptionInfo
-    ) -> TranscriptionVerboseJsonResponse:
+    def from_segment(cls, segment: Segment, transcription_info: TranscriptionInfo) -> TranscriptionVerboseJsonResponse:
         return cls(
             language=transcription_info.language,
             duration=segment.end - segment.start,
             text=segment.text,
-            words=(
-                [WordObject.from_word(word) for word in segment.words]
-                if isinstance(segment.words, list)
-                else []
-            ),
+            words=([WordObject.from_word(word) for word in segment.words] if isinstance(segment.words, list) else []),
             segments=[SegmentObject.from_segment(segment)],
         )
 
@@ -102,16 +97,11 @@ class TranscriptionVerboseJsonResponse(BaseModel):
             duration=transcription_info.duration,
             text=utils.segments_text(segments),
             segments=[SegmentObject.from_segment(segment) for segment in segments],
-            words=[
-                WordObject.from_word(word)
-                for word in utils.words_from_segments(segments)
-            ],
+            words=[WordObject.from_word(word) for word in utils.words_from_segments(segments)],
         )
 
     @classmethod
-    def from_transcription(
-        cls, transcription: Transcription
-    ) -> TranscriptionVerboseJsonResponse:
+    def from_transcription(cls, transcription: Transcription) -> TranscriptionVerboseJsonResponse:
         return cls(
             language="english",  # FIX: hardcoded
             duration=transcription.duration,
faster_whisper_server/transcriber.py CHANGED
@@ -1,8 +1,7 @@
 from __future__ import annotations
 
-from typing import AsyncGenerator
+from typing import TYPE_CHECKING
 
-from faster_whisper_server.asr import FasterWhisperASR
 from faster_whisper_server.audio import Audio, AudioStream
 from faster_whisper_server.config import config
 from faster_whisper_server.core import (
@@ -13,6 +12,11 @@ from faster_whisper_server.core import (
 )
 from faster_whisper_server.logger import logger
 
+if TYPE_CHECKING:
+    from collections.abc import AsyncGenerator
+
+    from faster_whisper_server.asr import FasterWhisperASR
+
 
 class LocalAgreement:
     def __init__(self) -> None:
pyproject.toml CHANGED
@@ -28,18 +28,35 @@ target-version = "py312"
 [tool.ruff.lint]
 select = ["ALL"]
 ignore = [
-    "D10", # disabled required docstrings
+    "FIX",
+    "TD", # disable todo warnings
     "ERA", # allow commented out code
-    "TD", # disable TODO warnings
-    "FIX002", # disable TODO warnings
+    "PTH",
 
+    "ANN003", # missing kwargs
+    "ANN101", # missing self type
+    "ANN102", # missing cls
+    "B006",
+    "B008",
     "COM812", # trailing comma
-    "T201", # print
+    "D10", # disabled required docstrings
+    "D401",
+    "EM102",
+    "FBT001",
+    "FBT002",
+    "PLR0913",
+    "PLR2004", # magic
+    "RET504",
+    "RET505",
+    "RET508",
     "S101", # allow assert
-    "PTH123", # Path.open
+    "S104",
     "S603", # subprocess untrusted input
-
-    "ANN101", # missing self type
+    "SIM102",
+    "T201", # print
+    "TRY003",
+    "W505",
+    "ISC001" # recommended to disable for formatting
 ]
 
 [tool.ruff.lint.isort]
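
Note: the one entry with a comment worth unpacking is `ISC001`, which ruff's formatter documentation recommends disabling because it can conflict with how the formatter joins strings. ISC001 flags single-line implicit string concatenation, which usually signals a missing comma; an illustrative sketch:

    # ISC001 would flag the accidental concatenation here:
    names = ["alice" "bob", "carol"]  # one element "alicebob", probably a bug
    print(names)
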
tests/api_model_test.py CHANGED
@@ -4,9 +4,7 @@ from faster_whisper_server.server_models import ModelObject
 
 MODEL_THAT_EXISTS = "Systran/faster-whisper-tiny.en"
 MODEL_THAT_DOES_NOT_EXIST = "i-do-not-exist"
-MIN_EXPECTED_NUMBER_OF_MODELS = (
-    200  # At the time of the test creation there are 228 models
-)
+MIN_EXPECTED_NUMBER_OF_MODELS = 200  # At the time of the test creation there are 228 models
 
 
 # HACK: because ModelObject(**data) doesn't work
@@ -19,20 +17,20 @@ def model_dict_to_object(model_dict: dict) -> ModelObject:
     )
 
 
-def test_list_models(client: TestClient):
+def test_list_models(client: TestClient) -> None:
     response = client.get("/v1/models")
     data = response.json()
     models = [model_dict_to_object(model_dict) for model_dict in data]
     assert len(models) > MIN_EXPECTED_NUMBER_OF_MODELS
 
 
-def test_model_exists(client: TestClient):
+def test_model_exists(client: TestClient) -> None:
     response = client.get(f"/v1/models/{MODEL_THAT_EXISTS}")
     data = response.json()
     model = model_dict_to_object(data)
     assert model.id == MODEL_THAT_EXISTS
 
 
-def test_model_does_not_exist(client: TestClient):
+def test_model_does_not_exist(client: TestClient) -> None:
     response = client.get(f"/v1/models/{MODEL_THAT_DOES_NOT_EXIST}")
     assert response.status_code == 404
tests/app_test.py CHANGED
@@ -1,10 +1,10 @@
+from collections.abc import Generator
 import json
 import os
 import time
-from typing import Generator
 
-import pytest
 from fastapi.testclient import TestClient
+import pytest
 from starlette.testclient import WebSocketTestSession
 
 from faster_whisper_server.config import BYTES_PER_SECOND
@@ -22,35 +22,31 @@ def ws(client: TestClient) -> Generator[WebSocketTestSession, None, None]:
     yield ws
 
 
-def get_audio_file_paths():
-    file_paths = []
+def get_audio_file_paths() -> list[str]:
+    file_paths: list[str] = []
     directory = "tests/data"
     for filename in sorted(os.listdir(directory)[:AUDIO_FILES_LIMIT]):
-        file_paths.append(os.path.join(directory, filename))
+        file_paths.append(os.path.join(directory, filename))  # noqa: PERF401
     return file_paths
 
 
 file_paths = get_audio_file_paths()
 
 
-def stream_audio_data(
-    ws: WebSocketTestSession, data: bytes, *, chunk_size: int = 4000, speed: float = 1.0
-):
+def stream_audio_data(ws: WebSocketTestSession, data: bytes, *, chunk_size: int = 4000, speed: float = 1.0) -> None:
     for i in range(0, len(data), chunk_size):
         ws.send_bytes(data[i : i + chunk_size])
         delay = len(data[i : i + chunk_size]) / BYTES_PER_SECOND / speed
         time.sleep(delay)
 
 
-def transcribe_audio_data(
-    client: TestClient, data: bytes
-) -> TranscriptionVerboseJsonResponse:
+def transcribe_audio_data(client: TestClient, data: bytes) -> TranscriptionVerboseJsonResponse:
     response = client.post(
         TRANSCRIBE_ENDPOINT,
         files={"file": ("audio.raw", data, "audio/raw")},
     )
     data = json.loads(response.json())  # TODO: figure this out
-    return TranscriptionVerboseJsonResponse(**data)  # type: ignore
+    return TranscriptionVerboseJsonResponse(**data)  # pyright: ignore[reportCallIssue]
 
 
 # @pytest.mark.parametrize("file_path", file_paths)
@@ -60,7 +56,7 @@ def transcribe_audio_data(
 #     with open(file_path, "rb") as file:
 #         data = file.read()
 #
-#     streaming_transcription: TranscriptionVerboseJsonResponse = None  # type: ignore
+#     streaming_transcription: TranscriptionVerboseJsonResponse = None  # type: ignore  # noqa: PGH003
 #     thread = threading.Thread(
 #         target=stream_audio_data, args=(ws, data), kwargs={"speed": 4.0}
 #     )
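
Note: the `# noqa: PERF401` above silences ruff's preference for a comprehension over append-in-a-loop instead of rewriting the loop. The equivalent comprehension would be something like this sketch (same semantics; the limit value is illustrative):

    import os

    AUDIO_FILES_LIMIT = 5  # illustrative value
    directory = "tests/data"
    file_paths = [os.path.join(directory, name) for name in sorted(os.listdir(directory)[:AUDIO_FILES_LIMIT])]
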
tests/conftest.py CHANGED
@@ -1,18 +1,15 @@
+from collections.abc import Generator
 import logging
-import os
-from typing import Generator
 
-import pytest
 from fastapi.testclient import TestClient
+import pytest
 
-# HACK
-os.environ["WHISPER_MODEL"] = "Systran/faster-whisper-tiny.en"
-from faster_whisper_server.main import app  # noqa: E402
+from faster_whisper_server.main import app
 
 disable_loggers = ["multipart.multipart", "faster_whisper"]
 
 
-def pytest_configure():
+def pytest_configure() -> None:
     for logger_name in disable_loggers:
         logger = logging.getLogger(logger_name)
         logger.disabled = True
tests/sse_test.py CHANGED
@@ -1,9 +1,9 @@
 import json
 import os
 
-import pytest
 from fastapi.testclient import TestClient
 from httpx_sse import connect_sse
+import pytest
 
 from faster_whisper_server.server_models import (
     TranscriptionJsonResponse,
@@ -17,15 +17,11 @@
 ]
 
 
-parameters = [
-    (file_path, endpoint) for endpoint in ENDPOINTS for file_path in FILE_PATHS
-]
+parameters = [(file_path, endpoint) for endpoint in ENDPOINTS for file_path in FILE_PATHS]
 
 
-@pytest.mark.parametrize("file_path,endpoint", parameters)
-def test_streaming_transcription_text(
-    client: TestClient, file_path: str, endpoint: str
-):
+@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
+def test_streaming_transcription_text(client: TestClient, file_path: str, endpoint: str) -> None:
     extension = os.path.splitext(file_path)[1]
     with open(file_path, "rb") as f:
         data = f.read()
@@ -36,15 +32,11 @@ def test_streaming_transcription_text(
     with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
         for event in event_source.iter_sse():
             print(event)
-            assert (
-                len(event.data) > 1
-            )  # HACK: 1 because of the space character that's always prepended
+            assert len(event.data) > 1  # HACK: 1 because of the space character that's always prepended
 
 
-@pytest.mark.parametrize("file_path,endpoint", parameters)
-def test_streaming_transcription_json(
-    client: TestClient, file_path: str, endpoint: str
-):
+@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
+def test_streaming_transcription_json(client: TestClient, file_path: str, endpoint: str) -> None:
     extension = os.path.splitext(file_path)[1]
     with open(file_path, "rb") as f:
         data = f.read()
@@ -57,10 +49,8 @@ def test_streaming_transcription_json(
             TranscriptionJsonResponse(**json.loads(event.data))
 
 
-@pytest.mark.parametrize("file_path,endpoint", parameters)
-def test_streaming_transcription_verbose_json(
-    client: TestClient, file_path: str, endpoint: str
-):
+@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
+def test_streaming_transcription_verbose_json(client: TestClient, file_path: str, endpoint: str) -> None:
     extension = os.path.splitext(file_path)[1]
     with open(file_path, "rb") as f:
         data = f.read()
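
Note: the recurring fix in this file is ruff's pytest-style rule PT006, which wants parametrize argument names given as a tuple rather than a comma-separated string. A self-contained sketch of the preferred form:

    import pytest


    @pytest.mark.parametrize(("a", "b"), [(1, 2), (3, 4)])
    def test_sum_is_positive(a: int, b: int) -> None:
        assert a + b > 0
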