File size: 1,863 Bytes
5da8d71
 
 
 
 
 
 
a357a65
 
5da8d71
a357a65
 
5da8d71
 
 
 
 
 
 
 
 
 
a357a65
 
 
5da8d71
a357a65
 
5da8d71
 
 
a357a65
 
 
5da8d71
 
 
a357a65
5da8d71
fa49488
a357a65
5da8d71
 
 
 
a357a65
5da8d71
 
 
 
 
 
 
 
 
 
 
d0bf313
5da8d71
 
 
 
 
7c97c61
5da8d71
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#importing all the necessary packages
import torch
import transformers
import gradio as gr
from torchaudio.sox_effects import apply_effects_file
from termcolor import colored
from transformers import Wav2Vec2FeatureExtractor, UniSpeechSatForAudioFrameClassification


# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


# sox effect chain applied to every input file, in order. Each entry is
# written the way it would appear on a sox command line and split into the
# list-of-arguments form (one list of strings per effect).
EFFECTS = [
    spec.split()
    for spec in (
        "remix -",          # merge all the channels
        "channels 1",       # force mono output
        "rate 16000",       # resample to 16000 Hz
        "gain -1.0",        # attenuate by 1 dB
        # NOTE(review): presumably strips silent stretches throughout the
        # clip (0.1 s / 0.1% thresholds) -- confirm against the sox manual.
        "silence 1 0.1 0.1% -1 0.1 0.1%",
        # A trailing "pad 0 1.5" (1.5 s of silence at the end) is disabled.
        "trim 0 10",        # keep only the first 10 seconds
    )
]


# Dataset-dependent threshold; not referenced by the code visible in this file.
THRESHOLD = 0.85


# Checkpoint shared by the feature extractor and the frame-classification model.
model_name = "microsoft/unispeech-sat-base-sd"

# Turns a raw waveform into the input tensors the model consumes.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)

# Load the pretrained weights, then move the model onto the selected device.
model = UniSpeechSatForAudioFrameClassification.from_pretrained(model_name)
model = model.to(device)



def fn(path):
    """Run the diarization model on one audio file and return frame labels.

    The file is preprocessed with the sox chain in ``EFFECTS``, converted to
    model inputs by ``feature_extractor``, and passed through ``model``
    without gradient tracking. The logits of the first batch item go through
    a sigmoid and are thresholded at 0.5.

    Args:
        path: Filesystem path of the recording. The Gradio audio input is
            optional, so this can be ``None`` when nothing was recorded.

    Returns:
        An integer (``long``) tensor of 0/1 decisions for the first batch
        item. Presumably shaped (num_frames, num_speakers), per the model's
        documentation -- TODO confirm. Each entry is thresholded
        independently, so several speakers may be active in the same frame.

    Raises:
        ValueError: If ``path`` is ``None`` or empty.
    """
    # The audio widget is optional; fail with a clear message instead of an
    # opaque error from inside the sox loader.
    if not path:
        raise ValueError("No audio was provided; record a clip before submitting.")

    # Apply the effect chain to the audio input file (sample rate is unused).
    wav, _ = apply_effects_file(path, EFFECTS)

    # Extract features; squeeze(0) drops the leading dimension of the loaded
    # waveform (assumed to be the single mono channel -- TODO confirm).
    input_values = feature_extractor(
        wav.squeeze(0), return_tensors="pt", sampling_rate=16000
    ).input_values.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(input_values).logits

    # The logits already live on `device` (the model and its input are both
    # there), so no further transfer is required.
    probabilities = torch.sigmoid(logits[0])

    # Binary activity decision per entry.
    labels = (probabilities > 0.5).long()
    return labels
  
  
  
 
# Single optional microphone recording, handed to `fn` as a file path.
inputs = [
    gr.inputs.Audio(
        source="microphone",
        type="filepath",
        optional=True,
        label="Speaker #1",
    ),
]

# Whatever `fn` returns is rendered in a text box.
output = gr.outputs.Textbox(label="Output Text")

# Build the interface first, then start serving it with request queuing on.
interface = gr.Interface(
    fn=fn,
    inputs=inputs,
    outputs=output,
    theme="grass",
    title="Speaker diarization using UniSpeech-SAT and X-Vectors",
)
interface.launch(enable_queue=True)