Kr08 commited on
Commit
e564472
·
verified ·
1 Parent(s): 353faef

Implemented Chunking

Browse files
Files changed (1) hide show
  1. app.py +42 -21
app.py CHANGED
@@ -3,6 +3,7 @@ import pickle
3
  import whisper
4
  import streamlit as st
5
  import torchaudio as ta
 
6
 
7
  from io import BytesIO
8
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
@@ -12,10 +13,11 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
12
  torch_dtype = torch.float16 if device == "cuda:0" else torch.float32
13
 
14
  SAMPLING_RATE = 16000
 
15
 
16
  # Load Whisper model and processor
17
  processor = WhisperProcessor.from_pretrained("openai/whisper-small")
18
- model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
19
 
20
  # Title of the app
21
  st.title("Audio Player with Live Transcription")
@@ -36,18 +38,42 @@ if 'audio_files' not in st.session_state:
36
 
37
  def detect_language(audio_file):
38
  whisper_model = whisper.load_model("small")
39
- trimmed_audio = whisper.pad_or_trim(audio_file)
40
  mel = whisper.log_mel_spectrogram(trimmed_audio).to(whisper_model.device)
41
- _, probs = whisper_model.detect_language(mel[0])
42
- detected_lang = max(probs, key=probs.get)
43
  print(f"Detected language: {detected_lang}")
44
  return detected_lang
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # Process uploaded files
48
  if submit_button and uploaded_files is not None:
49
  st.session_state.audio_files = uploaded_files
50
  st.session_state.detected_languages = []
 
51
 
52
  for uploaded_file in uploaded_files:
53
  waveform, sampling_rate = ta.load(BytesIO(uploaded_file.read()))
@@ -69,30 +95,25 @@ if 'audio_files' in st.session_state and st.session_state.audio_files:
69
  st.write(f"**Detected Language**: {st.session_state.detected_languages[i]}")
70
 
71
  with col2:
72
- # import pdb;pdb.set_trace()
73
- input_features = processor(st.session_state.waveforms[i][0], sampling_rate=SAMPLING_RATE, return_tensors='pt').input_features
74
-
75
  if st.button(f"Transcribe {uploaded_file.name}"):
76
- predicted_ids = model.generate(input_features)
77
- transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
78
- st.session_state.transcriptions[i] = transcription
79
 
80
  if st.session_state.transcriptions.get(i):
81
  st.write("**Transcription**:")
82
- for line in st.session_state.transcriptions[i]:
83
- st.write(line)
84
 
85
  if st.button(f"Translate {uploaded_file.name}"):
86
- with open('languages.pkl', 'rb') as f:
87
- lang_dict = pickle.load(f)
88
- detected_language_name = lang_dict[st.session_state.detected_languages[i]]
 
89
 
90
- forced_decoder_ids = processor.get_decoder_prompt_ids(language=detected_language_name, task="translate")
91
- predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
92
- translation = processor.batch_decode(predicted_ids, skip_special_tokens=True)
93
- st.session_state.translations[i] = translation
94
 
95
  if st.session_state.translations.get(i):
96
  st.write("**Translation**:")
97
- for line in st.session_state.translations[i]:
98
- st.write(line)
 
3
  import whisper
4
  import streamlit as st
5
  import torchaudio as ta
6
+ import numpy as np
7
 
8
  from io import BytesIO
9
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
 
13
  torch_dtype = torch.float16 if device == "cuda:0" else torch.float32
14
 
15
  SAMPLING_RATE = 16000
16
+ CHUNK_LENGTH_S = 20 # 30 seconds per chunk
17
 
18
  # Load Whisper model and processor
19
  processor = WhisperProcessor.from_pretrained("openai/whisper-small")
20
+ model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
21
 
22
  # Title of the app
23
  st.title("Audio Player with Live Transcription")
 
38
 
39
  def detect_language(audio_file):
40
  whisper_model = whisper.load_model("small")
41
+ trimmed_audio = whisper.pad_or_trim(audio_file.squeeze())
42
  mel = whisper.log_mel_spectrogram(trimmed_audio).to(whisper_model.device)
43
+ _, probs = whisper_model.detect_language(mel)
44
+ detected_lang = max(probs[0], key=probs[0].get)
45
  print(f"Detected language: {detected_lang}")
46
  return detected_lang
47
 
48
 
49
+ def process_long_audio(waveform, sampling_rate, task="transcribe", language=None):
50
+ input_length = waveform.shape[1]
51
+ chunk_length = int(CHUNK_LENGTH_S * sampling_rate)
52
+ chunks = [waveform[:, i:i + chunk_length] for i in range(0, input_length, chunk_length)]
53
+
54
+ results = []
55
+ for chunk in chunks:
56
+ # import pdb;pdb.set_trace()
57
+ input_features = processor(chunk[0], sampling_rate=sampling_rate, return_tensors="pt").input_features.to(device)
58
+
59
+ with torch.no_grad():
60
+ if task == "translate":
61
+ forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task="translate")
62
+ generated_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
63
+ else:
64
+ generated_ids = model.generate(input_features)
65
+
66
+ transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
67
+ results.extend(transcription)
68
+
69
+ return " ".join(results)
70
+
71
+
72
  # Process uploaded files
73
  if submit_button and uploaded_files is not None:
74
  st.session_state.audio_files = uploaded_files
75
  st.session_state.detected_languages = []
76
+ st.session_state.waveforms = []
77
 
78
  for uploaded_file in uploaded_files:
79
  waveform, sampling_rate = ta.load(BytesIO(uploaded_file.read()))
 
95
  st.write(f"**Detected Language**: {st.session_state.detected_languages[i]}")
96
 
97
  with col2:
 
 
 
98
  if st.button(f"Transcribe {uploaded_file.name}"):
99
+ with st.spinner("Transcribing..."):
100
+ transcription = process_long_audio(st.session_state.waveforms[i], SAMPLING_RATE)
101
+ st.session_state.transcriptions[i] = transcription
102
 
103
  if st.session_state.transcriptions.get(i):
104
  st.write("**Transcription**:")
105
+ st.write(st.session_state.transcriptions[i])
 
106
 
107
  if st.button(f"Translate {uploaded_file.name}"):
108
+ with st.spinner("Translating..."):
109
+ with open('languages.pkl', 'rb') as f:
110
+ lang_dict = pickle.load(f)
111
+ detected_language_name = lang_dict[st.session_state.detected_languages[i]]
112
 
113
+ translation = process_long_audio(st.session_state.waveforms[i], SAMPLING_RATE, task="translate",
114
+ language=detected_language_name)
115
+ st.session_state.translations[i] = translation
 
116
 
117
  if st.session_state.translations.get(i):
118
  st.write("**Translation**:")
119
+ st.write(st.session_state.translations[i])