bachvudinh committed on
Commit
e10af0d
1 Parent(s): c3d86d3

add @spaces.GPU

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import torch
 
3
  import torchaudio
4
  from encodec import EncodecModel
5
  from whisperspeech.vq_stoks import RQBottleneckTransformer
@@ -19,7 +20,7 @@ vq_model = RQBottleneckTransformer.load_model(
19
  "whisper-vq-stoks-medium-en+pl-fixed.model"
20
  ).to(device)
21
  vq_model.ensure_whisper(device)
22
-
23
  def audio_to_sound_tokens_whisperspeech(audio_path):
24
  wav, sr = torchaudio.load(audio_path)
25
  if sr != 16000:
@@ -30,6 +31,7 @@ def audio_to_sound_tokens_whisperspeech(audio_path):
30
 
31
  result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
32
  return f'<|sound_start|>{result}<|sound_end|>'
 
33
  def audio_to_sound_tokens_whisperspeech_transcribe(audio_path):
34
  wav, sr = torchaudio.load(audio_path)
35
  if sr != 16000:
@@ -57,7 +59,7 @@ def audio_to_sound_tokens(audio_path, target_bandwidth=1.5, device="cuda"):
57
  flatten_tokens = torch.stack((audio_code1, audio_code2), dim=1).flatten().tolist()
58
  result = ''.join(f'<|sound_{num:04d}|>' for num in flatten_tokens)
59
  return f'<|sound_start|>{result}<|sound_end|>'
60
-
61
  def setup_pipeline(model_path, use_4bit=False, use_8bit=False):
62
  tokenizer = AutoTokenizer.from_pretrained(model_path)
63
  model_kwargs = {"device_map": "auto"}
@@ -79,6 +81,7 @@ tokenizer = pipe.tokenizer
79
  model = pipe.model
80
  # print(tokenizer.encode("<|sound_0001|>", add_special_tokens=False))# return the audio tensor
81
  # print(tokenizer.eos_token)
 
82
  def text_to_audio_file(text):
83
  # gen a random id for the audio file
84
  id = str(uuid.uuid4())
@@ -93,6 +96,7 @@ def text_to_audio_file(text):
93
  # torchaudio.save(temp_file, audio.cpu(), sample_rate=24000)
94
  print(f"Saved audio to {temp_file}")
95
  return temp_file
 
96
  def process_input(input_type, text_input=None, audio_file=None):
97
  # if input_type == "text":
98
  # audio_file = "temp_audio.wav"
@@ -102,6 +106,7 @@ def process_input(input_type, text_input=None, audio_file=None):
102
 
103
  # if input_type == "text":
104
  # os.remove(audio_file)
 
105
  def process_transcribe_input(input_type, text_input=None, audio_file=None):
106
  # if input_type == "text":
107
  # audio_file = "temp_audio.wav"
@@ -119,6 +124,7 @@ class StopOnTokens(StoppingCriteria):
119
  if input_ids[0][-1] == stop_id:
120
  return True
121
  return False
 
122
  def process_audio(audio_file, transcript=False):
123
  if audio_file is None:
124
  raise ValueError("No audio file provided")
 
1
  import gradio as gr
2
  import torch
3
+ import spaces
4
  import torchaudio
5
  from encodec import EncodecModel
6
  from whisperspeech.vq_stoks import RQBottleneckTransformer
 
20
  "whisper-vq-stoks-medium-en+pl-fixed.model"
21
  ).to(device)
22
  vq_model.ensure_whisper(device)
23
+ @spaces.GPU
24
  def audio_to_sound_tokens_whisperspeech(audio_path):
25
  wav, sr = torchaudio.load(audio_path)
26
  if sr != 16000:
 
31
 
32
  result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
33
  return f'<|sound_start|>{result}<|sound_end|>'
34
+ @spaces.GPU
35
  def audio_to_sound_tokens_whisperspeech_transcribe(audio_path):
36
  wav, sr = torchaudio.load(audio_path)
37
  if sr != 16000:
 
59
  flatten_tokens = torch.stack((audio_code1, audio_code2), dim=1).flatten().tolist()
60
  result = ''.join(f'<|sound_{num:04d}|>' for num in flatten_tokens)
61
  return f'<|sound_start|>{result}<|sound_end|>'
62
+ @spaces.GPU
63
  def setup_pipeline(model_path, use_4bit=False, use_8bit=False):
64
  tokenizer = AutoTokenizer.from_pretrained(model_path)
65
  model_kwargs = {"device_map": "auto"}
 
81
  model = pipe.model
82
  # print(tokenizer.encode("<|sound_0001|>", add_special_tokens=False))# return the audio tensor
83
  # print(tokenizer.eos_token)
84
+ @spaces.GPU
85
  def text_to_audio_file(text):
86
  # gen a random id for the audio file
87
  id = str(uuid.uuid4())
 
96
  # torchaudio.save(temp_file, audio.cpu(), sample_rate=24000)
97
  print(f"Saved audio to {temp_file}")
98
  return temp_file
99
+ @spaces.GPU
100
  def process_input(input_type, text_input=None, audio_file=None):
101
  # if input_type == "text":
102
  # audio_file = "temp_audio.wav"
 
106
 
107
  # if input_type == "text":
108
  # os.remove(audio_file)
109
+ @spaces.GPU
110
  def process_transcribe_input(input_type, text_input=None, audio_file=None):
111
  # if input_type == "text":
112
  # audio_file = "temp_audio.wav"
 
124
  if input_ids[0][-1] == stop_id:
125
  return True
126
  return False
127
+ @spaces.GPU
128
  def process_audio(audio_file, transcript=False):
129
  if audio_file is None:
130
  raise ValueError("No audio file provided")