English
music
music-captioning
Inference Endpoints
ivillar committed on
Commit d99dc9a · 1 Parent(s): 3ac37de

Update handler and requirements

Files changed (2)
  1. handler.py +67 -35
  2. requirements.txt +19 -0
handler.py CHANGED
@@ -1,17 +1,15 @@
 import torch
 from model.bart import BartCaptionModel
 from utils.audio_utils import load_audio, STR_CH_FIRST
+from typing import Dict, List, Any
 import numpy as np
+import librosa
 
-
-def get_audio(audio_path, duration=10, target_sr=16000):
+def preprocess_audio(audio_signal, sr, duration=10, target_sr=16000):
     n_samples = int(duration * target_sr)
-    audio, sr = load_audio(
-        path= audio_path,
-        ch_format= STR_CH_FIRST,
-        sample_rate= target_sr,
-        downmix_to_mono= True,
-    )
+    audio = librosa.to_mono(audio_signal)
+    audio = librosa.resample(audio, orig_sr = sr, target_sr = target_sr)
+
     if len(audio.shape) == 2:
         audio = audio.mean(0, False) # to mono
     input_size = int(n_samples)
@@ -23,31 +21,65 @@ def get_audio(audio_path, duration=10, target_sr=16000):
     audio = torch.from_numpy(np.stack(np.split(audio[:ceil * n_samples], ceil)).astype('float32'))
     return audio
 
-def captioning(audio_path):
-    audio_tensor = get_audio(audio_path = audio_path)
-    if device is not None:
-        audio_tensor = audio_tensor.to(device)
-    with torch.no_grad():
-        output = model.generate(
-            samples=audio_tensor,
-            num_beams=5,
-        )
-    inference = ""
-    number_of_chunks = range(audio_tensor.shape[0])
-    for chunk, text in zip(number_of_chunks, output):
-        time = f"[{chunk * 10}:00-{(chunk + 1) * 10}:00]"
-        inference += f"{time}\n{text} \n \n"
-    return inference
-
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
-
-example_list = ['electronic.mp3', 'orchestra.wav']
-model = BartCaptionModel(max_length = 128)
-pretrained_object = torch.load('./transfer.pth', map_location='cpu')
-state_dict = pretrained_object['state_dict']
-model.load_state_dict(state_dict)
-if torch.cuda.is_available():
-    torch.cuda.set_device(device)
-    model = model.cuda(device)
+class EndpointHandler:
+    def __init__(self, path=""):
+        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        self.model = BartCaptionModel(max_length = 128)
+        pretrained_object = torch.load('./transfer.pth', map_location='cpu')
+        state_dict = pretrained_object['state_dict']
+        self.model.load_state_dict(state_dict)
+        if torch.cuda.is_available():
+            torch.cuda.set_device(self.device)
+            self.model = self.model.cuda(self.device)
+
+    def _captioning(self, audio_tensor):
+        if self.device is not None:
+            audio_tensor = audio_tensor.to(self.device)
+
+        with torch.no_grad():
+            output = self.model.generate(
+                samples=audio_tensor,
+                num_beams=5,
+            )
+        inference = ""
+        number_of_chunks = range(audio_tensor.shape[0])
+        for chunk, text in zip(number_of_chunks, output):
+            time = f"[{chunk * 10}:00-{(chunk + 1) * 10}:00]"
+            inference += f"{time}\n{text} \n \n"
+        return inference
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+        audio_bytes = data["audio_bytes"]
+        audio_shape = tuple([int(x) for x in data["audio_shape"].split(', ')])
+        audio_dtype = data["audio_dtype"]
+        sr = data["sampling_rate"]
+
+        input_audio = np.frombuffer(audio_bytes, dtype=audio_dtype).reshape(audio_shape)
+
+        preprocessed_audio = preprocess_audio(input_audio, sr)
+
+        return self._captioning(preprocessed_audio)
+    """
+    if __name__ == "__main__":
+        import numpy as np
+        from scipy.io.wavfile import write as wav_write
+        from huggingface_hub import InferenceApi
+
+        handler = EndpointHandler()
+        audio_path = "folk.wav"
+        np_audio, sr = librosa.load(audio_path, sr=44100)
+
+        np_bytes = np_audio.tobytes()
+        np_shape = np_audio.shape
+        np_dtype = np_audio.dtype.name
 
-print(captioning("electronic.mp3"))
+        request = {
+            "audio_bytes": np_bytes,
+            "audio_shape": ', '.join(map(str, np_shape)),
+            "audio_dtype": np_dtype,
+            "sampling_rate": sr
+        }
+
+        print(f"Loaded {audio_path} with sample rate {sr}")
+        print(handler.__call__(request))
+    """
requirements.txt ADDED
@@ -0,0 +1,19 @@
+datasets==2.18.0
+huggingface-hub==0.21.4
+julius==0.2.7
+librosa==0.10.1
+multidict==6.0.5
+multiprocess==0.70.16
+numpy==1.26.4
+packaging==23.2
+pandas==2.2.1
+pydub==0.25.1
+scikit-learn==1.4.1.post1
+scipy==1.12.0
+tokenizers==0.13.3
+torch==1.13.1
+torchaudio==0.13.1
+torchaudio-augmentations==0.2.1
+tqdm==4.66.2
+transformers==4.26.1
+wavaugment==0.2