collinbarnwell commited on
Commit
88b4bc4
1 Parent(s): 7c9f58f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +33 -13
handler.py CHANGED
@@ -24,6 +24,8 @@
24
  from pyannote.audio import Pipeline, Audio
25
  import torch
26
  import os
 
 
27
 
28
  class EndpointHandler:
29
  def __init__(self, path=""):
@@ -47,23 +49,41 @@ class EndpointHandler:
47
  # initialize audio reader
48
  self._io = Audio()
49
 
 
50
  def __call__(self, data):
51
  inputs = data.pop("inputs", data)
52
  waveform = torch.tensor(inputs["waveform"])
53
  sample_rate = inputs["sample_rate"]
54
-
55
  parameters = data.pop("parameters", dict())
56
- diarization = self._pipeline(
57
- {"waveform": waveform, "sample_rate": sample_rate}, **parameters
58
- )
59
 
60
- processed_diarization = [
61
- {
62
- "speaker": speaker,
63
- "start": f"{turn.start:.3f}",
64
- "end": f"{turn.end:.3f}",
65
- }
66
- for turn, _, speaker in diarization.itertracks(yield_label=True)
67
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- return {"diarization": processed_diarization}
 
24
  from pyannote.audio import Pipeline, Audio
25
  import torch
26
  import os
27
+ import threading
28
+ import time
29
 
30
  class EndpointHandler:
31
  def __init__(self, path=""):
 
49
  # initialize audio reader
50
  self._io = Audio()
51
 
52
+
53
  def __call__(self, data):
54
  inputs = data.pop("inputs", data)
55
  waveform = torch.tensor(inputs["waveform"])
56
  sample_rate = inputs["sample_rate"]
 
57
  parameters = data.pop("parameters", dict())
 
 
 
58
 
59
+ # Container for storing diarization result
60
+ diarization_result = {}
61
+
62
+ def diarize():
63
+ nonlocal diarization_result
64
+ diarization = self._pipeline(
65
+ {"waveform": waveform, "sample_rate": sample_rate}, **parameters
66
+ )
67
+ diarization_result = [
68
+ {
69
+ "speaker": speaker,
70
+ "start": f"{turn.start:.3f}",
71
+ "end": f"{turn.end:.3f}",
72
+ }
73
+ for turn, _, speaker in diarization.itertracks(yield_label=True)
74
+ ]
75
+
76
+ # Running diarization in a separate thread
77
+ diarization_thread = threading.Thread(target=diarize)
78
+ diarization_thread.start()
79
+
80
+ # Wait for the diarization to complete or timeout
81
+ diarization_thread.join(timeout=298)
82
+
83
+ # Check if the thread is still alive (indicating a timeout occurred)
84
+ if diarization_thread.is_alive():
85
+ print("Diarization timed out")
86
+ # Handle the timeout case, maybe by raising an error or a warning
87
+ raise TimeoutError("Diarization process exceeded time limit.")
88
 
89
+ return {"diarization": diarization_result}