commit e10af0d (parent: c3d86d3)
Author: bachvudinh

    add @spaces.GPU
CHANGED
@@ -1,5 +1,6 @@
 import gradio as gr
 import torch
+import spaces
 import torchaudio
 from encodec import EncodecModel
 from whisperspeech.vq_stoks import RQBottleneckTransformer
@@ -19,7 +20,7 @@ vq_model = RQBottleneckTransformer.load_model(
     "whisper-vq-stoks-medium-en+pl-fixed.model"
 ).to(device)
 vq_model.ensure_whisper(device)
-
+@spaces.GPU
 def audio_to_sound_tokens_whisperspeech(audio_path):
     wav, sr = torchaudio.load(audio_path)
     if sr != 16000:
@@ -30,6 +31,7 @@ def audio_to_sound_tokens_whisperspeech(audio_path):
 
     result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
     return f'<|sound_start|>{result}<|sound_end|>'
+@spaces.GPU
 def audio_to_sound_tokens_whisperspeech_transcribe(audio_path):
     wav, sr = torchaudio.load(audio_path)
     if sr != 16000:
@@ -57,7 +59,7 @@ def audio_to_sound_tokens(audio_path, target_bandwidth=1.5, device="cuda"):
     flatten_tokens = torch.stack((audio_code1, audio_code2), dim=1).flatten().tolist()
     result = ''.join(f'<|sound_{num:04d}|>' for num in flatten_tokens)
     return f'<|sound_start|>{result}<|sound_end|>'
-
+@spaces.GPU
 def setup_pipeline(model_path, use_4bit=False, use_8bit=False):
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     model_kwargs = {"device_map": "auto"}
@@ -79,6 +81,7 @@ tokenizer = pipe.tokenizer
 model = pipe.model
 # print(tokenizer.encode("<|sound_0001|>", add_special_tokens=False))# return the audio tensor
 # print(tokenizer.eos_token)
+@spaces.GPU
 def text_to_audio_file(text):
     # gen a random id for the audio file
     id = str(uuid.uuid4())
@@ -93,6 +96,7 @@ def text_to_audio_file(text):
     # torchaudio.save(temp_file, audio.cpu(), sample_rate=24000)
     print(f"Saved audio to {temp_file}")
     return temp_file
+@spaces.GPU
 def process_input(input_type, text_input=None, audio_file=None):
     # if input_type == "text":
     #     audio_file = "temp_audio.wav"
@@ -102,6 +106,7 @@ def process_input(input_type, text_input=None, audio_file=None):
 
     # if input_type == "text":
     #     os.remove(audio_file)
+@spaces.GPU
 def process_transcribe_input(input_type, text_input=None, audio_file=None):
     # if input_type == "text":
     #     audio_file = "temp_audio.wav"
@@ -119,6 +124,7 @@ class StopOnTokens(StoppingCriteria):
         if input_ids[0][-1] == stop_id:
             return True
         return False
+@spaces.GPU
 def process_audio(audio_file, transcript=False):
     if audio_file is None:
         raise ValueError("No audio file provided")
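For context, spaces.GPU comes from the spaces package used on Hugging Face Spaces with ZeroGPU hardware: the decorator requests a GPU when a decorated function is called and releases it when the call returns, so GPU-bound entry points must be wrapped this way. Below is a minimal sketch of the pattern this commit applies; the function and variable names are illustrative, not taken from app.py:

import spaces
import torch

# On ZeroGPU, moving a model to "cuda" at startup is intercepted by the
# spaces runtime; the device is only actually attached while a
# @spaces.GPU-decorated call is running.

@spaces.GPU  # allocate a GPU for the duration of each call
def embed(wav: torch.Tensor) -> torch.Tensor:
    # stand-in for real inference work on the GPU
    return wav.cuda().float().mean(dim=-1)

@spaces.GPU(duration=120)  # optionally request a longer allocation, in seconds
def long_running(wav: torch.Tensor) -> torch.Tensor:
    return embed(wav)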