zhenxuan committed
Commit
042a25e
1 Parent(s): b6632e1

create demo

Files changed (1)
app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
+ import datetime
+ import os
+ os.system('pip install git+https://github.com/openai/whisper.git')  # install whisper at runtime
+ import gradio as gr
+ import wave
+ import whisper
+ import logging
+ import torchaudio
+ import torchaudio.functional as F
+
+ LOGGING_FORMAT = '%(asctime)s %(message)s'
+ logging.basicConfig(format=LOGGING_FORMAT, level=logging.INFO)
+
+ REC_INTERVAL_IN_SECONDS = 3
+
+ # tmp dir to store audio files.
+ if not os.path.isdir('./tmp/'):
+     os.mkdir('./tmp')
+
+ class WhisperStreaming():
+     def __init__(self, model_name='base', language='en', fp16=False):
+         self.model_name = model_name
+         self.language = language
+         self.fp16 = fp16
+         self.whisper_model = whisper.load_model(f'{model_name}.{language}')
+         self.decode_option = whisper.DecodingOptions(language=self.language,
+                                                      without_timestamps=True,
+                                                      fp16=self.fp16)
+         self.whisper_sample_rate = 16000
+
+     def transcribe_audio_file(self, wave_file_path):
+         waveform, sample_rate = torchaudio.load(wave_file_path)
+         resampled_waveform = F.resample(waveform, sample_rate, self.whisper_sample_rate, lowpass_filter_width=6)
+         audio_tmp = whisper.pad_or_trim(resampled_waveform[0])
+         mel = whisper.log_mel_spectrogram(audio_tmp)
+         results = self.whisper_model.decode(mel, self.decode_option)
+         return results
+
+ def concat_multiple_wav_files(wav_files):
+     logging.info(f'Concat {wav_files}')
+     concat_audio = []
+     for wav_file in wav_files:
+         w = wave.open(wav_file, 'rb')
+         concat_audio.append([w.getparams(), w.readframes(w.getnframes())])
+         w.close()
+         logging.info(f'Delete audio file {wav_file}')
+         os.remove(wav_file)
+
+     output_file_name = f'{datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")}.wav'
+     output_file_path = os.path.join('./tmp', output_file_name)
+     output = wave.open(output_file_path, 'wb')
+     output.setparams(concat_audio[0][0])
+
+     for i in range(len(concat_audio)):
+         output.writeframes(concat_audio[i][1])
+     output.close()
+     logging.info(f'Concatenated the past {len(wav_files)} wav files into {output_file_path}')
+     return output_file_path
+
+
+ # fp16 selects float16 vs. float32 inference. PyTorch normally does not support fp16 on CPU.
+ whisper_model = WhisperStreaming(model_name='base', language='en', fp16=False)
+
+
+ def transcribe(audio, state={}):
+     logging.info(f'Transcribe audio file {audio}')
+     print('=====================')
+     logging.info(state)
+
+     if not state:
+         state['concated_audio'] = audio
+         state['result_text'] = 'Waiting...'
+         state['count'] = 0
+     else:
+         state['concated_audio'] = concat_multiple_wav_files([state['concated_audio'], audio])
+         state['count'] += 1
+
+     if state['count'] % REC_INTERVAL_IN_SECONDS == 0 and state['count'] > 0:
+         logging.info('Start transcribing...')
+         result = whisper_model.transcribe_audio_file(state['concated_audio'])
+         logging.info('Transcription complete')
+         state['result_text'] = result.text
+         logging.info('The text is: ' + state['result_text'])
+     else:
+         logging.info(f'Streaming chunk count is {state["count"]}, skipping speech recognition')
+
+     return state['result_text'], state
+
+
+ gr.Interface(fn=transcribe,
+              inputs=[gr.Audio(source="microphone", type='filepath', streaming=True), 'state'],
+              outputs=['text', 'state'],
+              live=True).launch()
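
For reference, a minimal sketch of exercising the WhisperStreaming wrapper outside the Gradio interface. It assumes the trailing gr.Interface(...).launch() call is guarded by if __name__ == '__main__': so the import does not block, and that sample.wav is a hypothetical local recording; neither assumption is part of this commit.

from app import WhisperStreaming

# Base English model with float32 decoding (fp16=False).
asr = WhisperStreaming(model_name='base', language='en', fp16=False)

# Transcribe one wav file; it is resampled to 16 kHz and only the first channel is used.
result = asr.transcribe_audio_file('sample.wav')
print(result.text)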