import datetime
import os
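# Install Whisper from GitHub at startup so the `whisper` import below succeeds.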
os.system('pip install git+https://github.com/openai/whisper.git')
import gradio as gr
import wave
import whisper
import logging
import torchaudio
import torchaudio.functional as F
LOGGING_FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=LOGGING_FORMAT, level=logging.INFO)
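# Run speech recognition only once every REC_INTERVAL_IN_SECONDS streamed chunks
# (assuming the streaming microphone delivers roughly one chunk per second).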
REC_INTERVAL_IN_SECONDS = 3
# tmp dir to store audio files.
if not os.path.isdir('./tmp/'):
    os.mkdir('./tmp')
class WhisperStreaming():
    # Wraps an OpenAI Whisper model for repeated transcription of short wav files.
    def __init__(self, model_name='base', language='en', fp16=False):
        self.model_name = model_name
        self.language = language
        self.fp16 = fp16
        # English-only checkpoints are named like 'base.en'.
        self.whisper_model = whisper.load_model(f'{model_name}.{language}')
        self.decode_option = whisper.DecodingOptions(language=self.language,
                                                     without_timestamps=True,
                                                     fp16=self.fp16)
        # Whisper expects 16 kHz mono audio.
        self.whisper_sample_rate = 16000
    def transcribe_audio_file(self, wave_file_path):
        waveform, sample_rate = torchaudio.load(wave_file_path)
        # Resample to the 16 kHz rate Whisper expects.
        resampled_waveform = F.resample(waveform, sample_rate, self.whisper_sample_rate, lowpass_filter_width=6)
        # Pad/trim to 30 seconds, compute the log-Mel spectrogram, then decode it.
        audio_tmp = whisper.pad_or_trim(resampled_waveform[0])
        mel = whisper.log_mel_spectrogram(audio_tmp)
        results = self.whisper_model.decode(mel, self.decode_option)
        return results
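# Minimal usage sketch (hypothetical file path), independent of the Gradio app below:
#   model = WhisperStreaming(model_name='base', language='en', fp16=False)
#   print(model.transcribe_audio_file('./tmp/example.wav').text)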
# Concatenate several wav files (which must share identical wave params) into one new file
# under ./tmp, deleting the inputs along the way.
def concat_multiple_wav_files(wav_files):
    logging.info(f'Concat {wav_files}')
    concat_audio = []
    for wav_file in wav_files:
        w = wave.open(wav_file, 'rb')
        concat_audio.append([w.getparams(), w.readframes(w.getnframes())])
        w.close()
        logging.info(f'Delete audio file {wav_file}')
        os.remove(wav_file)
    output_file_name = f'{datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")}.wav'
    output_file_path = os.path.join('./tmp', output_file_name)
    output = wave.open(output_file_path, 'wb')
    # All chunks come from the same microphone stream, so reuse the first file's params for the output header.
    output.setparams(concat_audio[0][0])
    for _, frames in concat_audio:
        output.writeframes(frames)
    output.close()
    logging.info(f'Concatenated {len(wav_files)} wav files into {output_file_path}')
    return output_file_path
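# Example (hypothetical paths): the two chunks are removed and a new timestamped wav is written under ./tmp.
#   merged_path = concat_multiple_wav_files(['./tmp/chunk_0.wav', './tmp/chunk_1.wav'])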
# fp16 selects float16 vs float32 inference. PyTorch normally does not support fp16 on CPU, so keep it False there.
whisper_model = WhisperStreaming(model_name='base', language='en', fp16=False)
def transcribe(audio, state={}):
    # `state` carries per-session data (accumulated audio, last result, chunk count) between streaming calls.
    logging.info(f'Transcribe audio file {audio}')
    print('=====================')
    logging.info(state)
    if not state:
        # First streamed chunk: initialise the session state.
        state['concated_audio'] = audio
        state['result_text'] = 'Waiting...'
        state['count'] = 0
    else:
        # Append the new chunk to the audio accumulated so far.
        state['concated_audio'] = concat_multiple_wav_files([state['concated_audio'], audio])
        state['count'] += 1
    # Only run Whisper once every REC_INTERVAL_IN_SECONDS chunks.
    if state['count'] % REC_INTERVAL_IN_SECONDS == 0 and state['count'] > 0:
        logging.info('start to transcribe.......')
        result = whisper_model.transcribe_audio_file(state['concated_audio'])
        logging.info('complete transcribe.......')
        state['result_text'] = result.text
        logging.info('The text is: ' + state['result_text'])
    else:
        logging.info(f'The count of streaming is {state["count"]}, skipping speech recognition')
    return state['result_text'], state
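# With streaming=True and live=True, Gradio repeatedly calls transcribe() with the latest microphone
# chunk plus the previous state, and renders the returned text as it updates.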
gr.Interface(fn=transcribe,
             inputs=[gr.Audio(source="microphone", type='filepath', streaming=True), 'state'],
             outputs=['text', 'state'],
             live=True).launch()