mskov's picture
Duplicate from mskov/whisper_stream
c8eb530
raw
history blame
3.75 kB
import os
from pprint import pprint
os.system("pip install git+https://github.com/openai/whisper.git")
import gradio as gr
import whisper
from transformers import pipeline
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
import time
# --- Model setup (runs once at import time) ---
# Stand-alone GPT-2 model + tokenizer, used inside inference() to inspect the
# per-token probabilities of sampled continuations.
# return_dict_in_generate=True makes gpt2.generate() return an output object
# (read below via .sequences and .scores) instead of a bare tensor of ids.
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2", return_dict_in_generate=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Text-generation pipeline that produces the continuations returned to the UI.
# NOTE(review): this loads a second copy of the GPT-2 weights next to `gpt2`
# above; sharing them (pipeline(..., model=gpt2, tokenizer=tokenizer)) would
# save memory, but must first be checked against return_dict_in_generate=True
# on that model's config — TODO confirm with the pinned transformers version.
generator = pipeline('text-generation', model='gpt2')
# Whisper speech-to-text model, "tiny" checkpoint.
model = whisper.load_model("tiny")
# Instruction prepended to the transcript before it is fed to the stand-alone
# GPT-2 model for the probability-inspection debug output.
_PREDICTION_PROMPT = (
    "The following is a transcript of someone talking, please predict what "
    "they will say next. \n"
)


def _transcribe(audio_path):
    """Return the Whisper transcript of the audio file at *audio_path*."""
    audio = whisper.load_audio(audio_path)
    # Whisper decodes a fixed-length window: pad short clips, trim long ones.
    audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    # fp16=False requests full-precision decoding. No language is set in the
    # options; the separate detect_language() call whose result was discarded
    # has been dropped.
    options = whisper.DecodingOptions(fp16=False)
    return whisper.decode(model, mel, options).text


def _log_continuation_probs(text, num_sequences=3, max_new_tokens=10):
    """Sample GPT-2 continuations of *text* and print each sampled token's probability.

    Debug aid only: nothing is returned; results are printed to stdout.

    Args:
        text: transcript to continue (the instruction prompt is prepended).
        num_sequences: number of continuations to sample.
        max_new_tokens: maximum number of tokens generated per continuation.
    """
    encoded = tokenizer(_PREDICTION_PROMPT + text, return_tensors="pt")
    input_ids = encoded.input_ids
    print("inputs ", input_ids)
    generated_outputs = gpt2.generate(
        input_ids,
        attention_mask=encoded.attention_mask,
        do_sample=True,
        num_return_sequences=num_sequences,
        output_scores=True,
        # Bound the *new* tokens explicitly. Without it generate() uses the
        # default max_length, which includes the prompt; the prompt alone is
        # already about that long, so (depending on the transformers version)
        # generation either errors out or yields a single token.
        max_new_tokens=max_new_tokens,
        # GPT-2 has no pad token; pad early-finished sequences with EOS.
        pad_token_id=tokenizer.eos_token_id,
    )
    print("outputs generated ", generated_outputs.sequences)
    # Keep only the newly generated ids: [num_sequences, generated_steps].
    gen_sequences = generated_outputs.sequences[:, input_ids.shape[-1]:]
    print("gen sequences: ", gen_sequences)
    # One score tensor per step, stacked -> [num_sequences, steps, vocab_size].
    # These are the processed scores the sampler used, turned into probabilities.
    probs = torch.stack(generated_outputs.scores, dim=1).softmax(-1)
    # Probability assigned to each token that was actually sampled. Positions
    # after a sequence emitted EOS hold padding and are not meaningful.
    gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
    print("gen probs result: ", gen_probs)


def inference(audio, state=""):
    """Transcribe a recorded clip and suggest how the speaker might continue.

    Args:
        audio: path to the recorded audio file (the Gradio input uses
            type="filepath").
        state: Gradio session state. It is not read; the new generator output
            replaces it.

    Returns:
        A tuple ``(transcript, state, continuations)``:
        - transcript: the Whisper transcript text;
        - state: the raw list of dicts returned by the text-generation pipeline;
        - continuations: the ``generated_text`` of each of those dicts.
    """
    transcript = _transcribe(audio)
    _log_continuation_probs(transcript)
    # Five short continuations of the bare transcript (no instruction prompt).
    state = generator(transcript, max_new_tokens=10, num_return_sequences=5)
    print(state)
    continuations = [item['generated_text'] for item in state]
    return transcript, state, continuations
# Build and start the live web UI:
#   inputs  -> microphone recording (passed to `inference` as a file path) and
#              the per-session state;
#   outputs -> transcript textbox, updated session state, continuations textbox
#              (same order as the tuple `inference` returns).
# live=True makes Gradio re-run `inference` whenever the input changes, so
# there is no submit button.
# NOTE(review): `gr.inputs` is a deprecated namespace in Gradio 3.x and is gone
# in 4.x; gr.Audio(...) is the replacement — confirm the pinned Gradio version
# before switching.
demo = gr.Interface(
    fn=inference,
    inputs=[gr.inputs.Audio(source="microphone", type="filepath"), "state"],
    outputs=["textbox", "state", "textbox"],
    live=True,
)
demo.launch()