ayaanzaveri commited on
Commit
52f8ce8
1 Parent(s): 4f0c4d6

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +37 -25
main.py CHANGED
@@ -1,24 +1,34 @@
 
1
  from faster_whisper import WhisperModel
2
- from fastapi import FastAPI
3
- from video import download_convert_video_to_audio
4
  import yt_dlp
5
  import uuid
6
  import os
7
- from fastapi.middleware.cors import CORSMiddleware
8
 
9
- app = FastAPI()
10
 
11
- app.add_middleware(
12
- CORSMiddleware,
13
- allow_origins=["*"],
14
- allow_credentials=True,
15
- allow_methods=["*"],
16
- allow_headers=["*"],
17
- )
18
-
19
- # or run on GPU with INT8
20
- # model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
21
- # or run on CPU with INT8
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def segment_to_dict(segment):
24
  segment = segment._asdict()
@@ -26,28 +36,30 @@ def segment_to_dict(segment):
26
  segment["words"] = [word._asdict() for word in segment["words"]]
27
  return segment
28
 
29
- @app.post("/video")
30
- async def download_video(video_url: str):
31
- download_convert_video_to_audio(yt_dlp, video_url, f"/home/user/{uuid.uuid4().hex}")
32
 
33
- @app.post("/transcribe")
34
- async def transcribe_video(video_url: str, beam_size: int = 5, model_size: str = "tiny", word_timestamps: bool = True):
35
  print("loading model")
36
  model = WhisperModel(model_size, device="cpu", compute_type="int8")
37
  print("getting hex")
38
  rand_id = uuid.uuid4().hex
39
  print("doing download")
40
- download_convert_video_to_audio(yt_dlp, video_url, f"/home/user/{rand_id}")
41
- segments, info = model.transcribe(f"/home/user/{rand_id}.mp3", beam_size=beam_size, word_timestamps=word_timestamps)
42
  segments = [segment_to_dict(segment) for segment in segments]
43
  total_duration = round(info.duration, 2) # Same precision as the Whisper timestamps.
44
  print(info)
45
- os.remove(f"/home/user/{rand_id}.mp3")
46
  print("Detected language '%s' with probability %f" % (info.language, info.language_probability))
47
-
48
  return segments
49
 
50
  # print("Detected language '%s' with probability %f" % (info.language, info.language_probability))
51
 
52
  # for segment in segments:
53
- # print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
 
 
 
 
 
1
+ import pathlib
2
  from faster_whisper import WhisperModel
 
 
3
  import yt_dlp
4
  import uuid
5
  import os
6
+ import gradio as gr
7
 
 
8
 
9
+ # List of all supported video sites here https://github.com/yt-dlp/yt-dlp/blob/master/supportedsites.md
10
+ def download_convert_video_to_audio(
11
+ yt_dlp,
12
+ video_url: str,
13
+ destination_path: pathlib.Path,
14
+ ) -> None:
15
+ ydl_opts = {
16
+ "format": "bestaudio/best",
17
+ "postprocessors": [
18
+ { # Extract audio using ffmpeg
19
+ "key": "FFmpegExtractAudio",
20
+ "preferredcodec": "mp3",
21
+ }
22
+ ],
23
+ "outtmpl": f"{destination_path}.%(ext)s",
24
+ }
25
+ try:
26
+ print(f"Downloading video from {video_url}")
27
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
28
+ ydl.download(video_url)
29
+ print(f"Downloaded video from {video_url} to {destination_path}")
30
+ except Exception as e:
31
+ raise (e)
32
 
33
  def segment_to_dict(segment):
34
  segment = segment._asdict()
 
36
  segment["words"] = [word._asdict() for word in segment["words"]]
37
  return segment
38
 
39
+ def download_video(video_url: str):
40
+ download_convert_video_to_audio(yt_dlp, video_url, f"/content/{uuid.uuid4().hex}")
 
41
 
42
+ def transcribe_video(video_url: str, beam_size: int = 5, model_size: str = "tiny", word_timestamps: bool = True):
 
43
  print("loading model")
44
  model = WhisperModel(model_size, device="cpu", compute_type="int8")
45
  print("getting hex")
46
  rand_id = uuid.uuid4().hex
47
  print("doing download")
48
+ download_convert_video_to_audio(yt_dlp, video_url, f"/content/{rand_id}")
49
+ segments, info = model.transcribe(f"/content/{rand_id}.mp3", beam_size=beam_size, word_timestamps=word_timestamps)
50
  segments = [segment_to_dict(segment) for segment in segments]
51
  total_duration = round(info.duration, 2) # Same precision as the Whisper timestamps.
52
  print(info)
53
+ os.remove(f"/content/{rand_id}.mp3")
54
  print("Detected language '%s' with probability %f" % (info.language, info.language_probability))
55
+ print(segments)
56
  return segments
57
 
58
  # print("Detected language '%s' with probability %f" % (info.language, info.language_probability))
59
 
60
  # for segment in segments:
61
+ # print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
62
+
63
+ demo = gr.Interface(fn=transcribe_video, inputs="text", outputs="text")
64
+
65
+ demo.launch()