import os

# Install Whisper at runtime (a common pattern for Hugging Face Spaces).
os.system("pip install git+https://github.com/openai/whisper.git")

import gradio as gr
import torch
import whisper
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline



# GPT-2 with return_dict_in_generate=True, so generate() returns a
# ModelOutput carrying .sequences and (with output_scores=True) .scores,
# instead of a bare tensor of token ids.
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2", return_dict_in_generate=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
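
# A minimal sketch of that output shape (comments only; the demo_* names are
# hypothetical and this is not executed):
#
#   demo_ids = tokenizer("Hello there", return_tensors="pt").input_ids
#   demo_out = gpt2.generate(demo_ids, max_new_tokens=5, output_scores=True,
#                            pad_token_id=tokenizer.eos_token_id)
#   demo_out.sequences    # tensor of shape [1, demo_ids.shape[-1] + 5]
#   len(demo_out.scores)  # 5 -- one logits tensor per generated step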


# text-generation pipeline, used below for the displayed continuations
generator = pipeline('text-generation', model='gpt2')

# Whisper speech-to-text model ("tiny" is the smallest, fastest checkpoint)
model = whisper.load_model("tiny")


        
def inference(audio, state=""):
    # Load the recording and pad/trim it to Whisper's fixed 30-second window.
    audio = whisper.load_audio(audio)
    audio = whisper.pad_or_trim(audio)

    # Compute a log-mel spectrogram and move it to the model's device.
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    
    # Detect the spoken language; lang_probs maps language codes to probabilities.
    _, lang_probs = model.detect_language(mel)
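    # Per the Whisper README, the most likely language is the argmax of the
    # probability dict:
    print("detected language: ", max(lang_probs, key=lang_probs.get))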

    # Decode speech to text; fp16=False keeps inference in float32 (safe on CPU).
    options = whisper.DecodingOptions(fp16=False)
    result = whisper.decode(model, mel, options)
    
    # Prepend an instruction prompt to the transcript before handing it to GPT-2.
    input_prompt = "The following is a transcript of someone talking, please predict what they will say next. \n"
    input_total = input_prompt + result.text
    input_ids = tokenizer(input_total, return_tensors="pt").input_ids
    print("input ids: ", input_ids)

    # prompt length, if needed later:
    # prompt_length = len(tokenizer.decode(input_ids[0]))

    # Sample three continuations; max_new_tokens bounds the generated part and
    # pad_token_id silences GPT-2's missing-pad-token warning.
    generated_outputs = gpt2.generate(input_ids, do_sample=True, num_return_sequences=3,
                                      max_new_tokens=15, output_scores=True,
                                      pad_token_id=tokenizer.eos_token_id)
    print("outputs generated ", generated_outputs[0])
    # keep only the ids that were generated (strip the prompt tokens);
    # gen_sequences has shape [3, 15]: num_return_sequences x max_new_tokens
    gen_sequences = generated_outputs.sequences[:, input_ids.shape[-1]:]
    print("gen sequences: ", gen_sequences)
    
    # stack the per-step logits into one tensor and turn them into probabilities
    probs = torch.stack(generated_outputs.scores, dim=1).softmax(-1)  # -> shape [3, 15, vocab_size]
    
    # now we need to collect the probability of the generated token
    # we need to add a dummy dim in the end to make gather work
    gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
    print("gen probs result: ", gen_probs)
    # now we can do all kinds of things with the probs
    
    # 1) the probs that exactly those sequences are generated again
    # those are normally going to be very small
    # unique_prob_per_sequence = gen_probs.prod(-1)
    
    # 2) normalize the probs over the three sequences
    # normed_gen_probs = gen_probs / gen_probs.sum(0)
    # assert torch.allclose(normed_gen_probs[:, 0].sum(), torch.tensor(1.0)), "probs should be normalized"
    
    # 3) compare normalized probs to each other like in 1)
    # unique_normed_prob_per_sequence = normed_gen_probs.prod(-1)
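    # e.g. a hypothetical follow-up (not used below): surface the sample the
    # model scored highest overall
    # best = gen_sequences[gen_probs.prod(-1).argmax()]
    # print(tokenizer.decode(best))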
    
    # Generate five short continuations of the raw transcript for display; the
    # pipeline returns a list of dicts like [{'generated_text': '...'}, ...].
    getText = generator(result.text, max_new_tokens=10, num_return_sequences=5)
    state = getText
    print(state)
    gt = [out['generated_text'] for out in state]

    # transcript, updated state, and the list of generated continuations
    return result.text, state, gt



# Live interface: stream microphone audio into inference().
gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Audio(source="microphone", type="filepath"),
        "state"
    ],
    outputs=[
        "textbox",
        "state",
        "textbox"
    ],
    live=True,
).launch()