Fedir Zadniprovskyi committed
Commit 323aa51 · Parent: 2a79f48

feat: handle srt and vtt response formats

faster_whisper_server/config.py CHANGED
@@ -15,35 +15,8 @@ class ResponseFormat(enum.StrEnum):
     TEXT = "text"
     JSON = "json"
     VERBOSE_JSON = "verbose_json"
-    # NOTE: While inspecting outputs of these formats with `curl`, I noticed there's one or two "\n" inserted at the end of the response.  # noqa: E501
-
-    # VTT = "vtt"  # TODO
-    # 1
-    # 00:00:00,000 --> 00:00:09,220
-    # In his video on Large Language Models or LLMs, OpenAI co-founder and YouTuber Andrej Karpathy
-    #
-    # 2
-    # 00:00:09,220 --> 00:00:12,280
-    # likened LLMs to operating systems.
-    #
-    # 3
-    # 00:00:12,280 --> 00:00:13,280
-    # Karpathy said,
-    #
-    # SRT = "srt"  # TODO
-    # WEBVTT
-    #
-    # 00:00:00.000 --> 00:00:09.220
-    # In his video on Large Language Models or LLMs, OpenAI co-founder and YouTuber Andrej Karpathy
-    #
-    # 00:00:09.220 --> 00:00:12.280
-    # likened LLMs to operating systems.
-    #
-    # 00:00:12.280 --> 00:00:13.280
-    # Karpathy said,
-    #
-    # 00:00:13.280 --> 00:00:19.799
-    # I see a lot of equivalence between this new LLM OS and operating systems of today.
+    SRT = "srt"
+    VTT = "vtt"


 class Device(enum.StrEnum):
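The two new enum members are what clients select through the `response_format` form field on the transcription endpoint (see main.py and the tests below). A minimal sketch of a non-streaming request follows; the host/port and the local audio file name are assumptions, not part of this commit.

```python
# Hypothetical client-side request for the new formats (base URL and file name are assumptions).
import httpx

with open("audio.wav", "rb") as f:
    files = {"file": ("audio.wav", f.read(), "audio/wav")}

response = httpx.post(
    "http://localhost:8000/v1/audio/transcriptions",
    files=files,
    data={"response_format": "vtt"},  # or "srt", "text", "json", "verbose_json"
    timeout=None,  # transcription can take a while
)
print(response.headers["content-type"])  # "text/vtt; charset=utf-8" per the new handler
print(response.text)                     # body starts with "WEBVTT"
```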
faster_whisper_server/core.py CHANGED
@@ -172,6 +172,62 @@ def segments_to_text(segments: Iterable[Segment]) -> str:
     return "".join(segment.text for segment in segments).strip()


+def srt_format_timestamp(ts: float) -> str:
+    hours = ts // 3600
+    minutes = (ts % 3600) // 60
+    seconds = ts % 60
+    milliseconds = (ts * 1000) % 1000
+    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"
+
+
+def test_srt_format_timestamp() -> None:
+    assert srt_format_timestamp(0.0) == "00:00:00,000"
+    assert srt_format_timestamp(1.0) == "00:00:01,000"
+    assert srt_format_timestamp(1.234) == "00:00:01,234"
+    assert srt_format_timestamp(60.0) == "00:01:00,000"
+    assert srt_format_timestamp(61.0) == "00:01:01,000"
+    assert srt_format_timestamp(61.234) == "00:01:01,234"
+    assert srt_format_timestamp(3600.0) == "01:00:00,000"
+    assert srt_format_timestamp(3601.0) == "01:00:01,000"
+    assert srt_format_timestamp(3601.234) == "01:00:01,234"
+    assert srt_format_timestamp(23423.4234) == "06:30:23,423"
+
+
+def vtt_format_timestamp(ts: float) -> str:
+    hours = ts // 3600
+    minutes = (ts % 3600) // 60
+    seconds = ts % 60
+    milliseconds = (ts * 1000) % 1000
+    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"
+
+
+def test_vtt_format_timestamp() -> None:
+    assert vtt_format_timestamp(0.0) == "00:00:00.000"
+    assert vtt_format_timestamp(1.0) == "00:00:01.000"
+    assert vtt_format_timestamp(1.234) == "00:00:01.234"
+    assert vtt_format_timestamp(60.0) == "00:01:00.000"
+    assert vtt_format_timestamp(61.0) == "00:01:01.000"
+    assert vtt_format_timestamp(61.234) == "00:01:01.234"
+    assert vtt_format_timestamp(3600.0) == "01:00:00.000"
+    assert vtt_format_timestamp(3601.0) == "01:00:01.000"
+    assert vtt_format_timestamp(3601.234) == "01:00:01.234"
+    assert vtt_format_timestamp(23423.4234) == "06:30:23.423"
+
+
+def segments_to_vtt(segment: Segment, i: int) -> str:
+    start = segment.start if i > 0 else 0.0
+    result = f"{vtt_format_timestamp(start)} --> {vtt_format_timestamp(segment.end)}\n{segment.text}\n\n"
+
+    if i == 0:
+        return f"WEBVTT\n\n{result}"
+    else:
+        return result
+
+
+def segments_to_srt(segment: Segment, i: int) -> str:
+    return f"{i + 1}\n{srt_format_timestamp(segment.start)} --> {srt_format_timestamp(segment.end)}\n{segment.text}\n\n"
+
+
 def canonicalize_word(text: str) -> str:
     text = text.lower()
     # Remove non-alphabetic characters using regular expression
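The new helpers are plain string builders: `srt_format_timestamp` and `vtt_format_timestamp` differ only in the milliseconds separator (comma for SRT, dot for WebVTT), and `segments_to_vtt` prepends the `WEBVTT` header on the first cue. A quick illustration with a stand-in segment object; the stand-in is hypothetical and only carries the attributes the helpers read, whereas the real callers pass `core.Segment` instances.

```python
from types import SimpleNamespace

from faster_whisper_server.core import (
    segments_to_srt,
    segments_to_vtt,
    srt_format_timestamp,
    vtt_format_timestamp,
)

print(srt_format_timestamp(3601.0))  # 01:00:01,000 (SRT uses a comma before milliseconds)
print(vtt_format_timestamp(3601.0))  # 01:00:01.000 (WebVTT uses a dot)

# Stand-in with just .start, .end, .text (hypothetical cue for illustration only).
seg = SimpleNamespace(start=0.0, end=9.0, text="In his video on Large Language Models or LLMs...")
print(segments_to_srt(seg, 0), end="")  # "1\n00:00:00,000 --> 00:00:09,000\n<text>\n\n"
print(segments_to_vtt(seg, 0), end="")  # "WEBVTT\n\n00:00:00.000 --> 00:00:09.000\n<text>\n\n"
```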
faster_whisper_server/main.py CHANGED
@@ -33,7 +33,7 @@ from faster_whisper_server.config import (
     Task,
     config,
 )
-from faster_whisper_server.core import Segment, segments_to_text
+from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.logger import logger
 from faster_whisper_server.server_models import (
     ModelListResponse,
@@ -154,14 +154,28 @@ def segments_to_response(
     segments: Iterable[Segment],
     transcription_info: TranscriptionInfo,
     response_format: ResponseFormat,
-) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse:
+) -> Response:
     segments = list(segments)
     if response_format == ResponseFormat.TEXT:  # noqa: RET503
-        return segments_to_text(segments)
+        return Response(segments_to_text(segments), media_type="text/plain")
     elif response_format == ResponseFormat.JSON:
-        return TranscriptionJsonResponse.from_segments(segments)
+        return Response(
+            TranscriptionJsonResponse.from_segments(segments).model_dump_json(),
+            media_type="application/json",
+        )
     elif response_format == ResponseFormat.VERBOSE_JSON:
-        return TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info)
+        return Response(
+            TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info).model_dump_json(),
+            media_type="application/json",
+        )
+    elif response_format == ResponseFormat.VTT:
+        return Response(
+            "".join(segments_to_vtt(segment, i) for i, segment in enumerate(segments)), media_type="text/vtt"
+        )
+    elif response_format == ResponseFormat.SRT:
+        return Response(
+            "".join(segments_to_srt(segment, i) for i, segment in enumerate(segments)), media_type="text/plain"
+        )


 def format_as_sse(data: str) -> str:
@@ -174,13 +188,17 @@ def segments_to_streaming_response(
     response_format: ResponseFormat,
 ) -> StreamingResponse:
     def segment_responses() -> Generator[str, None, None]:
-        for segment in segments:
+        for i, segment in enumerate(segments):
             if response_format == ResponseFormat.TEXT:
                 data = segment.text
             elif response_format == ResponseFormat.JSON:
                 data = TranscriptionJsonResponse.from_segments([segment]).model_dump_json()
             elif response_format == ResponseFormat.VERBOSE_JSON:
                 data = TranscriptionVerboseJsonResponse.from_segment(segment, transcription_info).model_dump_json()
+            elif response_format == ResponseFormat.VTT:
+                data = segments_to_vtt(segment, i)
+            elif response_format == ResponseFormat.SRT:
+                data = segments_to_srt(segment, i)
             yield format_as_sse(data)

     return StreamingResponse(segment_responses(), media_type="text/event-stream")
@@ -211,7 +229,7 @@ def translate_file(
     response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
-) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse | StreamingResponse:
+) -> Response | StreamingResponse:
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -247,7 +265,7 @@ def transcribe_file(
     ] = ["segment"],
     stream: Annotated[bool, Form()] = False,
     hotwords: Annotated[str | None, Form()] = None,
-) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse | StreamingResponse:
+) -> Response | StreamingResponse:
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
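With `stream=True`, the same helpers feed the SSE path: `segments_to_streaming_response` now emits one VTT or SRT cue per segment as a server-sent event. A hedged sketch of consuming that stream with `httpx-sse` (the same library the tests use) follows; the base URL and the local audio file are assumptions.

```python
import httpx
from httpx_sse import connect_sse

with open("audio.wav", "rb") as f:
    files = {"file": ("audio.wav", f.read(), "audio/wav")}

with httpx.Client(base_url="http://localhost:8000", timeout=None) as client:
    with connect_sse(
        client,
        "POST",
        "/v1/audio/transcriptions",
        files=files,
        data={"response_format": "srt", "stream": True},
    ) as event_source:
        for event in event_source.iter_sse():
            # Each event's data is one SRT cue produced by segments_to_srt().
            print(event.data)
```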
pyproject.toml CHANGED
@@ -18,7 +18,7 @@ dependencies = [
 ]

 [project.optional-dependencies]
-dev = ["ruff==0.5.3", "pytest", "basedpyright==1.13.0", "pytest-xdist"]
+dev = ["ruff==0.5.3", "pytest", "webvtt-py", "srt", "basedpyright==1.13.0", "pytest-xdist"]

 other = ["youtube-dl @ git+https://github.com/ytdl-org/youtube-dl.git@37cea84f775129ad715b9bcd617251c831fcc980", "aider-chat==0.39.0"]

requirements-all.txt CHANGED
@@ -496,7 +496,7 @@ scipy==1.13.1
     # via aider-chat
 semantic-version==2.10.0
     # via gradio
-setuptools==71.0.3
+setuptools==71.0.4
     # via ctranslate2
 shellingham==1.5.4
     # via typer
@@ -524,11 +524,13 @@ soupsieve==2.5
     # via
     #   aider-chat
     #   beautifulsoup4
+srt==3.5.3
+    # via faster-whisper-server (pyproject.toml)
 starlette==0.37.2
     # via fastapi
 streamlit==1.35.0
     # via aider-chat
-sympy==1.13.0
+sympy==1.13.1
     # via onnxruntime
 tenacity==8.3.0
     # via
@@ -623,6 +625,8 @@ websockets==11.0.3
     # via
     #   gradio-client
     #   uvicorn
+webvtt-py==0.5.1
+    # via faster-whisper-server (pyproject.toml)
 yarl==1.9.4
     # via
     #   aider-chat
requirements-dev.txt CHANGED
@@ -146,7 +146,7 @@ numpy==1.26.4
     #   pandas
 onnxruntime==1.18.1
     # via faster-whisper
-openai==1.35.15
+openai==1.36.0
     # via faster-whisper-server (pyproject.toml)
 orjson==3.10.6
     # via gradio
@@ -235,7 +235,7 @@ ruff==0.5.3
     #   gradio
 semantic-version==2.10.0
     # via gradio
-setuptools==71.0.3
+setuptools==71.0.4
     # via ctranslate2
 shellingham==1.5.4
     # via typer
@@ -248,9 +248,11 @@ sniffio==1.3.1
     #   openai
 soundfile==0.12.1
     # via faster-whisper-server (pyproject.toml)
+srt==3.5.3
+    # via faster-whisper-server (pyproject.toml)
 starlette==0.37.2
     # via fastapi
-sympy==1.13.0
+sympy==1.13.1
     # via onnxruntime
 tokenizers==0.19.1
     # via faster-whisper
@@ -295,3 +297,5 @@ websockets==11.0.3
     # via
     #   gradio-client
     #   uvicorn
+webvtt-py==0.5.1
+    # via faster-whisper-server (pyproject.toml)
requirements.txt CHANGED
@@ -138,7 +138,7 @@ numpy==1.26.4
     #   pandas
 onnxruntime==1.18.1
     # via faster-whisper
-openai==1.35.15
+openai==1.36.0
     # via faster-whisper-server (pyproject.toml)
 orjson==3.10.6
     # via gradio
@@ -216,7 +216,7 @@ ruff==0.5.3
     # via gradio
 semantic-version==2.10.0
     # via gradio
-setuptools==71.0.3
+setuptools==71.0.4
     # via ctranslate2
 shellingham==1.5.4
     # via typer
@@ -231,7 +231,7 @@ soundfile==0.12.1
     # via faster-whisper-server (pyproject.toml)
 starlette==0.37.2
     # via fastapi
-sympy==1.13.0
+sympy==1.13.1
     # via onnxruntime
 tokenizers==0.19.1
     # via faster-whisper
tests/conftest.py CHANGED
@@ -1,10 +1,12 @@
 from collections.abc import Generator
 import logging
+import os

 from fastapi.testclient import TestClient
 from openai import OpenAI
 import pytest

+os.environ["WHISPER__MODEL"] = "Systran/faster-whisper-tiny.en"
 from faster_whisper_server.main import app

 disable_loggers = ["multipart.multipart", "faster_whisper"]
tests/sse_test.py CHANGED
@@ -4,6 +4,9 @@ import os
 from fastapi.testclient import TestClient
 from httpx_sse import connect_sse
 import pytest
+import srt
+import webvtt
+import webvtt.vtt

 from faster_whisper_server.server_models import (
     TranscriptionJsonResponse,
@@ -61,3 +64,38 @@ def test_streaming_transcription_verbose_json(client: TestClient, file_path: str
     with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
         for event in event_source.iter_sse():
             TranscriptionVerboseJsonResponse(**json.loads(event.data))
+
+
+def test_transcription_vtt(client: TestClient) -> None:
+    with open("audio.wav", "rb") as f:
+        data = f.read()
+    kwargs = {
+        "files": {"file": ("audio.wav", data, "audio/wav")},
+        "data": {"response_format": "vtt", "stream": False},
+    }
+    response = client.post("/v1/audio/transcriptions", **kwargs)
+    assert response.status_code == 200
+    assert response.headers["content-type"] == "text/vtt; charset=utf-8"
+    text = response.text
+    webvtt.from_string(text)
+    text = text.replace("WEBVTT", "YO")
+    with pytest.raises(webvtt.vtt.MalformedFileError):
+        webvtt.from_string(text)
+
+
+def test_transcription_srt(client: TestClient) -> None:
+    with open("audio.wav", "rb") as f:
+        data = f.read()
+    kwargs = {
+        "files": {"file": ("audio.wav", data, "audio/wav")},
+        "data": {"response_format": "srt", "stream": False},
+    }
+    response = client.post("/v1/audio/transcriptions", **kwargs)
+    assert response.status_code == 200
+    assert "text/plain" in response.headers["content-type"]
+
+    text = response.text
+    list(srt.parse(text))
+    text = text.replace("1", "YO")
+    with pytest.raises(srt.SRTParseError):
+        list(srt.parse(text))