File size: 3,749 Bytes
1b5d3b4
 
 
 
 
 
645c5d6
72733aa
1b5d3b4
f9d356e
87448d5
72733aa
1b5d3b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da5b65e
 
 
 
 
 
c442756
 
 
b9b4130
f946ebf
b9b4130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c442756
1b5d3b4
645c5d6
1b5d3b4
 
 
f946ebf
 
f9d356e
 
 
 
f946ebf
da5b65e
 
 
 
f946ebf
 
1b5d3b4
f946ebf
1b5d3b4
 
f9d356e
1b5d3b4
 
 
 
f9d356e
645c5d6
 
67a28b1
645c5d6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import gradio as gr
import sox
import subprocess
from fuzzywuzzy import fuzz
from data import get_data


DATASET = get_data()

def read_file_and_process(wav_file):
    filename = wav_file.split('.')[0]
    filename_16k = filename + "16k.wav"
    resampler(wav_file, filename_16k)
    speech, _ = sf.read(filename_16k)
    inputs = processor(speech, sampling_rate=16_000, return_tensors="pt", padding=True)
    
    return inputs


def resampler(input_file_path, output_file_path):
    command = (
        f"ffmpeg -hide_banner -loglevel panic -i {input_file_path} -ar 16000 -ac 1 -bits_per_raw_sample 16 -vn "
        f"{output_file_path}"
    )
    subprocess.call(command, shell=True)


def parse_transcription(logits):
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0], skip_special_tokens=True)
    return transcription


def parse(wav_file):
    input_values = read_file_and_process(wav_file)
    with torch.no_grad():
        logits = model(**input_values).logits
    user_question = parse_transcription(logits)
    return user_question


# Function to retrieve an answer based on a question (using fuzzy matching)
def get_answer(wav_file=None):
    
    input_values = read_file_and_process(wav_file)
    
    with torch.no_grad():
        logits = model(**input_values).logits
    user_question = parse_transcription(logits)
    
    highest_score = 0
    best_answer = None

    for item in DATASET:
        similarity_score = fuzz.token_set_ratio(user_question, item["question"])
        if similarity_score > highest_score:
            highest_score = similarity_score
            best_answer = item["answer"]

    if highest_score >= 80:  # Adjust the similarity threshold as needed
        return best_answer
    else:
        return "I don't have an answer to that question."


model_id = "jonatasgrosman/wav2vec2-large-xlsr-53-persian"
processor = Wav2Vec2Processor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id)

input_ = [
        gr.Audio(source="microphone",
                  type="filepath",
                  label="لطفا دکمه ضبط صدا را بزنید و شروع به صحبت کنید و بعذ از اتمام صحبت دوباره دکمه ضبط را فشار دهید.",
                  show_download_button=True,
                  show_edit_button=True,
                 ), 
        # gr.Textbox(label="سوال خود را بنویسید.",
        #            lines=3,
        #            text_align="right",
        #            show_label=True,)
         ]

txtbox = gr.Textbox(
            label="پاسخ شما: ",
            lines=5,
            text_align="right",
            show_label=True,
            show_copy_button=True,
        )

title = "Speech-to-Text (persian)"
description = "، توجه داشته باشید که هرچه گفتار شما شمرده تر باشد خروجی با کیفیت تری دارید.روی دکمه ضبط صدا کلیک کنید و سپس دسترسی مرورگر خود را به میکروفون دستگاه بدهید، سپس شروع به صحبت کنید و برای اتمام ضبط دوباره روی دکمه کلیک کنید"
article = "<p style='text-align: center'><a href='https://github.com/nimaprgrmr'>Large-Scale Self- and Semi-Supervised Learning for Speech Translation</a></p>"

demo = gr.Interface(fn=get_answer, inputs = input_,  outputs=txtbox, title=title, description=description, article = article,
             streaming=True, interactive=True,
             analytics_enabled=False, show_tips=False, enable_queue=True)
demo.launch(share=True)