DrishtiSharma committed on
Commit
a357a65
1 Parent(s): 70cb4e9

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# -*- coding: utf-8 -*-
"""Real-time ASR demo using Wav2Vec 2.0 (Colab-exported Gradio app).

Original notebook:
https://colab.research.google.com/drive/1Lv3LjRH9bHwMhKsWvFcELMzKqmXd9UIb

Dependencies (install before running this script):
    pip install -q transformers gradio
"""

# NOTE(fix): the Colab export contained IPython magics (`!pip install -q
# transformers` / `!pip install -q gradio`), which are a SyntaxError in a
# plain .py file. They are documented in the module docstring instead.

import nltk
import librosa
import torch
import soundfile as sf
import gradio as gr
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC

# Sentence-tokenizer data used by correct_casing() to re-capitalize output.
nltk.download("punkt")

# Example audio file for the demo (Colab Drive path — only valid in Colab).
input_file = "/content/drive/MyDrive/AAAAUDIO/My Audio.wav"

# Wav2Vec 2.0 base model fine-tuned on 960h LibriSpeech; expects 16 kHz audio.
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
def load_data(input_file):
    """Load an audio file and return a mono waveform resampled to 16 kHz.

    Parameters
    ----------
    input_file : str
        Path to the audio file to read.

    Returns
    -------
    numpy.ndarray
        1-D array of audio samples at 16 kHz.
    """
    # Read the samples and the file's native sampling rate.
    speech, sample_rate = sf.read(input_file)

    # Downmix multi-channel audio to mono. Averaging (rather than summing
    # two channels, as the original did) keeps the signal in its original
    # amplitude range and avoids clipping; it also handles >2 channels.
    if len(speech.shape) > 1:
        speech = speech.mean(axis=1)

    # wav2vec2-base-960h is pretrained/fine-tuned on 16 kHz speech, so
    # resample anything else. Keyword arguments are required: librosa >= 0.10
    # removed the positional orig_sr/target_sr form used by the original.
    if sample_rate != 16000:
        speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
    return speech
def correct_casing(input_sentence):
    """Capitalize the first letter of each sentence in *input_sentence*.

    Defined here because the original script called ``correct_casing``
    without ever defining it, which raised NameError on every
    transcription. Uses the "punkt" tokenizer downloaded at import time.
    """
    sentences = nltk.sent_tokenize(input_sentence)
    return " ".join(s[:1].upper() + s[1:] for s in sentences)


def asr_transcript(input_file):
    """Transcribe a speech audio file with Wav2Vec 2.0.

    Parameters
    ----------
    input_file : str
        Path to the audio file to transcribe.

    Returns
    -------
    str
        The transcription, lowercased then re-capitalized per sentence.
    """
    speech = load_data(input_file)

    # Convert the raw waveform into model input tensors.
    input_values = tokenizer(speech, return_tensors="pt").input_values

    # Forward pass: per-frame logits over the character vocabulary.
    logits = model(input_values).logits

    # Greedy CTC decoding: most likely token id at each frame.
    predicted_ids = torch.argmax(logits, dim=-1)

    # Collapse repeats/blanks and map ids back to text.
    transcription = tokenizer.decode(predicted_ids[0])

    # The model emits ALL CAPS; lowercase, then restore sentence casing.
    return correct_casing(transcription.lower())
# Build and launch the Gradio UI. `type="file"` hands the uploaded/recorded
# audio to asr_transcript as a file path, matching load_data's sf.read call.
gr.Interface(
    asr_transcript,
    inputs=gr.inputs.Audio(label="Input Audio", type="file"),
    outputs=gr.outputs.Textbox(label="Output Text"),
    title="Real-time ASR using Wav2Vec 2.0",
    # Fix: the original description was placeholder keyboard-mash
    # ("asdfghnjmk") shown verbatim in the UI.
    description="Upload or record speech and get a Wav2Vec 2.0 transcription.",
    examples=[["/content/drive/MyDrive/AAAAUDIO/My Audio.wav"]],
).launch()