import datetime
import os
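# Install openai-whisper from GitHub at startup so the `whisper` import below succeeds
# (a common pattern on hosted demo environments).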
os.system('pip install git+https://github.com/openai/whisper.git')
import gradio as gr
import wave
import whisper
import logging
import torchaudio
import torchaudio.functional as F

LOGGING_FORMAT = '%(asctime)s %(message)s'
logging.basicConfig(format=LOGGING_FORMAT, level=logging.INFO)

# Run recognition roughly every N streamed chunks (the name assumes about one chunk per second).
REC_INTERVAL_IN_SECONDS = 3

# tmp dir to store audio files.
if not os.path.isdir('./tmp/'):
    os.mkdir('./tmp')

class WhisperStreaming():
    """Wraps a Whisper model so accumulated audio chunks can be transcribed from a WAV file."""
    def __init__(self, model_name='base', language='en', fp16=False):
        self.model_name = model_name
        self.language = language
        self.fp16 = fp16
        # English-only checkpoints are named '<size>.<language>', e.g. 'base.en'.
        self.whisper_model = whisper.load_model(f'{model_name}.{language}')
        self.decode_option = whisper.DecodingOptions(language=self.language,
                                                     without_timestamps=True,
                                                     fp16=self.fp16)
        # Whisper expects 16 kHz audio.
        self.whisper_sample_rate = 16000

    def transcribe_audio_file(self, wave_file_path):
        waveform, sample_rate = torchaudio.load(wave_file_path)
        # Resample to the 16 kHz rate Whisper expects.
        resampled_waveform = F.resample(waveform, sample_rate, self.whisper_sample_rate, lowpass_filter_width=6)
        # Pad/trim to Whisper's 30-second window, then compute the log-Mel spectrogram.
        audio_tmp = whisper.pad_or_trim(resampled_waveform[0])
        mel = whisper.log_mel_spectrogram(audio_tmp)
        results = self.whisper_model.decode(mel, self.decode_option)
        return results
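
# Example of using WhisperStreaming on its own (./tmp/sample.wav is a hypothetical file):
#   model = WhisperStreaming(model_name='base', language='en', fp16=False)
#   print(model.transcribe_audio_file('./tmp/sample.wav').text)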

def concat_multiple_wav_files(wav_files):
    logging.info(f'Concat {wav_files}')
    concat_audio = []
    # Read the params and raw frames of every chunk, then delete the source files.
    for wav_file in wav_files:
        w = wave.open(wav_file, 'rb')
        concat_audio.append([w.getparams(), w.readframes(w.getnframes())])
        w.close()
        logging.info(f'Delete audio file {wav_file}')
        os.remove(wav_file)

    output_file_name = f'{datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")}.wav'
    output_file_path = os.path.join('./tmp', output_file_name)
    output = wave.open(output_file_path, 'wb')
    # All chunks come from the same microphone stream, so the first file's params apply to all of them.
    output.setparams(concat_audio[0][0])

    for _, frames in concat_audio:
        output.writeframes(frames)
    output.close()
    logging.info(f'Concatenated {len(wav_files)} wav files into {output_file_path}')
    return output_file_path


# fp16 selects float16 vs. float32 inference; PyTorch generally does not support fp16 on CPU, so keep it False there.
whisper_model = WhisperStreaming(model_name='base', language='en', fp16=False)


def transcribe(audio, state={}):
    logging.info(f'Transcribe audio file {audio}')
    logging.info(state)

    if not state:
        # First chunk of the stream: initialise the session state.
        state['concated_audio'] = audio
        state['result_text'] = 'Waiting...'
        state['count'] = 0
    else:
        # Append the new chunk to the audio accumulated so far.
        state['concated_audio'] = concat_multiple_wav_files([state['concated_audio'], audio])
        state['count'] += 1

    # Only run recognition every REC_INTERVAL_IN_SECONDS chunks to keep the interface responsive.
    if state['count'] % REC_INTERVAL_IN_SECONDS == 0 and state['count'] > 0:
        logging.info('start to transcribe.......')
        result = whisper_model.transcribe_audio_file(state['concated_audio'])
        logging.info('complete transcribe.......')
        state['result_text'] = result.text
        logging.info('The text is: ' + state['result_text'])
    else:
        logging.info(f'The count of streaming is {state["count"]}, skipping speech recognition')

    return state['result_text'], state


# Live streaming interface: microphone audio arrives in chunks, and 'state' carries the
# accumulated audio file and transcript between calls.
# Note: gr.Audio(source=...) is the older (pre-4.x) Gradio API.
gr.Interface(fn=transcribe,
             inputs=[gr.Audio(source="microphone", type='filepath', streaming=True), 'state'],
             outputs=['text', 'state'],
             live=True).launch()