DrishtiSharma committed on
Commit
5da8d71
1 Parent(s): ca5824e

Update app.py

Files changed (1)
  1. app.py +54 -53
app.py CHANGED
@@ -1,68 +1,69 @@
- # -*- coding: utf-8 -*-
- """Untitled29.ipynb
-
- Automatically generated by Colaboratory.
-
- Original file is located at
-     https://colab.research.google.com/drive/1Lv3LjRH9bHwMhKsWvFcELMzKqmXd9UIb
- """
-
- !pip install -q transformers
- !pip install -q gradio
-
- import nltk
- import librosa
- import torch
- import soundfile as sf
- import gradio as gr
- from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
- nltk.download("punkt")
-
- input_file = "/content/drive/MyDrive/AAAAUDIO/My Audio.wav"
-
- tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
- model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
-
- def load_data(input_file):
-     """Resampling function to ensure that the speech input is sampled at 16 kHz."""
-     # Read the file
-     speech, sample_rate = sf.read(input_file)
-
-     # Make it 1-D
-     if len(speech.shape) > 1:
-         speech = speech[:, 0] + speech[:, 1]
-
-     # Resample to 16 kHz, since wav2vec2-base-960h is pretrained and fine-tuned on 16 kHz speech audio
-     if sample_rate != 16000:
-         speech = librosa.resample(speech, sample_rate, 16000)
-     return speech
-
- def asr_transcript(input_file):
-     speech = load_data(input_file)
-
-     # Tokenize
-     input_values = tokenizer(speech, return_tensors="pt").input_values
-
-     # Take logits
-     logits = model(input_values).logits
-
-     # Take argmax
-     predicted_ids = torch.argmax(logits, dim=-1)
-
-     # Get the words from the predicted word ids
-     transcription = tokenizer.decode(predicted_ids[0])
-
-     # Output is all upper case
-     transcription = correct_casing(transcription.lower())
-
-     return transcription
-
- gr.Interface(asr_transcript,
-              inputs=gr.inputs.Audio(label="Input Audio", type="file"),
-              outputs=gr.outputs.Textbox(label="Output Text"),
-              title="Real-time ASR using Wav2Vec 2.0",
-              description="asdfghnjmk",
-              examples=[["/content/drive/MyDrive/AAAAUDIO/My Audio.wav"]]).launch()
+ # Import all the necessary packages
+ import torch
+ import transformers
+ import gradio as gr
+ from torchaudio.sox_effects import apply_effects_file
+ from termcolor import colored
+ from transformers import Wav2Vec2FeatureExtractor, UniSpeechSatForAudioFrameClassification
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # SoX effects to apply to the input audio file
+ EFFECTS = [
+     ["remix", "-"],  # merge all the channels
+     ["channels", "1"],  # downmix to mono
+     ["rate", "16000"],  # resample to 16000 Hz
+     ["gain", "-1.0"],  # attenuate by 1 dB
+     ["silence", "1", "0.1", "0.1%", "-1", "0.1", "0.1%"],
+     # ["pad", "0", "1.5"],  # add 1.5 seconds of silence at the end
+     ["trim", "0", "10"],  # keep only the first 10 seconds
+ ]
+
+ THRESHOLD = 0.85  # depends on the dataset
+
+ model_name = "microsoft/unispeech-sat-base-sd"
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
+ model = UniSpeechSatForAudioFrameClassification.from_pretrained(model_name).to(device)
+
+ def fn(path):
+     # Apply the effects to the input audio file
+     wav, _ = apply_effects_file(path, EFFECTS)
+
+     # Extract features
+     input_values = feature_extractor(wav.squeeze(0), return_tensors="pt", sampling_rate=16000).input_values.to(device)
+
+     with torch.no_grad():
+         logits = model(input_values).logits
+     probabilities = torch.sigmoid(logits[0])
+
+     # labels is a one-hot array of shape (num_frames, num_speakers)
+     labels = (probabilities > 0.5).long()
+     return labels
+
+ inputs = [
+     gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Speaker #1"),
+ ]
+
+ output = gr.outputs.HTML(label="")
+
+ gr.Interface(
+     fn=fn,
+     inputs=inputs,
+     outputs=output,
+     title="Speaker diarization using UniSpeech-SAT and X-Vectors").launch(enable_queue=True)
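Since the updated fn() hands the raw (num_frames, num_speakers) label tensor straight to an HTML output, a small post-processing sketch may help illustrate what that tensor represents. The helper below is illustrative only and not part of this commit: labels_to_segments, the 50 frames-per-second rate (an assumption based on the roughly 20 ms frame stride of wav2vec2-style encoders), and the dummy tensor are all hypothetical.

import torch

def labels_to_segments(labels, frames_per_second=50.0):
    # Convert a (num_frames, num_speakers) 0/1 tensor into (speaker, start_s, end_s) tuples.
    # Hypothetical helper; the frame rate is an assumption, not taken from the commit.
    segments = []
    num_frames, num_speakers = labels.shape
    for speaker in range(num_speakers):
        active = labels[:, speaker].tolist()
        start = None
        # A trailing 0 acts as a sentinel that closes any segment still open at the end.
        for frame, value in enumerate(active + [0]):
            if value == 1 and start is None:
                start = frame
            elif value == 0 and start is not None:
                segments.append((speaker, start / frames_per_second, frame / frames_per_second))
                start = None
    return segments

# Dummy labels for two speakers over six frames, purely for illustration.
dummy = torch.tensor([[1, 0], [1, 0], [0, 0], [0, 1], [0, 1], [0, 1]])
print(labels_to_segments(dummy))
# [(0, 0.0, 0.04), (1, 0.06, 0.12)]

A string built from these segments (or any other readable summary) would be a natural fit for the gr.outputs.HTML component that the app already declares.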