cydxg committed on
Commit 480f3c3 · verified · 1 Parent(s): 54f2893

Upload 5 files

Files changed (5)
  1. flow_inference.py +142 -0
  2. model_server.py +116 -0
  3. quantifization.py +27 -0
  4. requirements.txt +36 -0
  5. web_demo.py +264 -0
flow_inference.py ADDED
@@ -0,0 +1,142 @@
+ import torch
+ import torchaudio
+ import numpy as np
+ import re
+ from hyperpyyaml import load_hyperpyyaml
+ import uuid
+ from collections import defaultdict
+
+
+ def fade_in_out(fade_in_mel, fade_out_mel, window):
+     device = fade_in_mel.device
+     fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
+     mel_overlap_len = int(window.shape[0] / 2)
+     fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
+                                          fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
+     return fade_in_mel.to(device)
+
+
+ class AudioDecoder:
+     def __init__(self, config_path, flow_ckpt_path, hift_ckpt_path, device="cuda"):
+         self.device = device
+
+         with open(config_path, 'r') as f:
+             self.scratch_configs = load_hyperpyyaml(f)
+
+         # Load models
+         self.flow = self.scratch_configs['flow']
+         self.flow.load_state_dict(torch.load(flow_ckpt_path, map_location=self.device))
+         self.hift = self.scratch_configs['hift']
+         self.hift.load_state_dict(torch.load(hift_ckpt_path, map_location=self.device))
+
+         # Move models to the appropriate device
+         self.flow.to(self.device)
+         self.hift.to(self.device)
+         self.mel_overlap_dict = defaultdict(lambda: None)
+         self.hift_cache_dict = defaultdict(lambda: None)
+         self.token_min_hop_len = 2 * self.flow.input_frame_rate
+         self.token_max_hop_len = 4 * self.flow.input_frame_rate
+         self.token_overlap_len = 5
+         self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
+         self.mel_window = np.hamming(2 * self.mel_overlap_len)
+         # hift cache
+         self.mel_cache_len = 1
+         self.source_cache_len = int(self.mel_cache_len * 256)
+         # speech fade in out
+         self.speech_window = np.hamming(2 * self.source_cache_len)
+
+     def token2wav(self, token, uuid, prompt_token=torch.zeros(1, 0, dtype=torch.int32),
+                   prompt_feat=torch.zeros(1, 0, 80), embedding=torch.zeros(1, 192), finalize=False):
+         tts_mel = self.flow.inference(token=token.to(self.device),
+                                       token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                       prompt_token=prompt_token.to(self.device),
+                                       prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(
+                                           self.device),
+                                       prompt_feat=prompt_feat.to(self.device),
+                                       prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(
+                                           self.device),
+                                       embedding=embedding.to(self.device))
+
+         # mel overlap fade in out
+         if self.mel_overlap_dict[uuid] is not None:
+             tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
+         # append hift cache
+         if self.hift_cache_dict[uuid] is not None:
+             hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
+             tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
+         else:
+             hift_cache_source = torch.zeros(1, 1, 0)
+         # _tts_mel=tts_mel.contiguous()
+         # keep overlap mel and hift cache
+         if finalize is False:
+             self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
+             tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
+             tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+
+             self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
+                                           'source': tts_source[:, :, -self.source_cache_len:],
+                                           'speech': tts_speech[:, -self.source_cache_len:]}
+             # if self.hift_cache_dict[uuid] is not None:
+             #     tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+             tts_speech = tts_speech[:, :-self.source_cache_len]
+
+         else:
+             tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
+             del self.hift_cache_dict[uuid]
+             del self.mel_overlap_dict[uuid]
+             # if uuid in self.hift_cache_dict.keys() and self.hift_cache_dict[uuid] is not None:
+             #     tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
+         return tts_speech, tts_mel
+
+     def offline_inference(self, token):
+         this_uuid = str(uuid.uuid1())
+         tts_speech, tts_mel = self.token2wav(token, uuid=this_uuid, finalize=True)
+         return tts_speech.cpu()
+
+     def stream_inference(self, token):
+         token = token.to(self.device)
+         this_uuid = str(uuid.uuid1())
+
+         # Prepare other necessary input tensors
+         llm_embedding = torch.zeros(1, 192).to(self.device)
+         prompt_speech_feat = torch.zeros(1, 0, 80).to(self.device)
+         flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32).to(self.device)
+
+         tts_speechs = []
+         tts_mels = []
+
+         block_size = self.flow.encoder.block_size
+         prev_mel = None
+
+         for idx in range(0, token.size(1), block_size):
+             # if idx > block_size: break
+             tts_token = token[:, idx:idx + block_size]
+
+             print(tts_token.size())
+
+             if prev_mel is not None:
+                 prompt_speech_feat = torch.cat(tts_mels, dim=-1).transpose(1, 2)
+                 flow_prompt_speech_token = token[:, :idx]
+
+             if idx + block_size >= token.size(-1):
+                 is_finalize = True
+             else:
+                 is_finalize = False
+
+             tts_speech, tts_mel = self.token2wav(tts_token, uuid=this_uuid,
+                                                  prompt_token=flow_prompt_speech_token.to(self.device),
+                                                  prompt_feat=prompt_speech_feat.to(self.device), finalize=is_finalize)
+
+             prev_mel = tts_mel
+             prev_speech = tts_speech
+             print(tts_mel.size())
+
+             tts_speechs.append(tts_speech)
+             tts_mels.append(tts_mel)
+
+         # Concatenate the streamed chunks into the final waveform
+         tts_speech = torch.cat(tts_speechs, dim=-1).cpu()
+
+         return tts_speech.cpu()
+
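
A minimal offline-usage sketch of AudioDecoder (illustrative only, not part of the upload): the checkpoint paths below assume the layout used by the web_demo.py defaults (--flow-path pointing at config.yaml, flow.pt, hift.pt), and the token tensor is a placeholder for the audio token ids produced by the GLM model.

import torch
import torchaudio
from flow_inference import AudioDecoder

# assumed checkpoint layout, matching the --flow-path default in web_demo.py
decoder = AudioDecoder(config_path="./glm-4-voice-decoder/config.yaml",
                       flow_ckpt_path="./glm-4-voice-decoder/flow.pt",
                       hift_ckpt_path="./glm-4-voice-decoder/hift.pt",
                       device="cuda")

speech_tokens = torch.zeros(1, 50, dtype=torch.int64)  # placeholder audio token ids
waveform = decoder.offline_inference(speech_tokens)    # (1, num_samples) waveform
torchaudio.save("out.wav", waveform, 22050)            # decoder output is 22050 Hz
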
model_server.py ADDED
@@ -0,0 +1,116 @@
+ """
+ A model worker executes the model.
+ """
+ import argparse
+ import json
+ import uuid
+
+ from fastapi import FastAPI, Request
+ from fastapi.responses import StreamingResponse
+ from transformers import AutoModel, AutoTokenizer
+ import torch
+ import uvicorn
+
+ from transformers.generation.streamers import BaseStreamer
+ from threading import Thread
+ from queue import Queue
+
+
+ class TokenStreamer(BaseStreamer):
+     def __init__(self, skip_prompt: bool = False, timeout=None):
+         self.skip_prompt = skip_prompt
+
+         # variables used in the streaming process
+         self.token_queue = Queue()
+         self.stop_signal = None
+         self.next_tokens_are_prompt = True
+         self.timeout = timeout
+
+     def put(self, value):
+         if len(value.shape) > 1 and value.shape[0] > 1:
+             raise ValueError("TokenStreamer only supports batch size 1")
+         elif len(value.shape) > 1:
+             value = value[0]
+
+         if self.skip_prompt and self.next_tokens_are_prompt:
+             self.next_tokens_are_prompt = False
+             return
+
+         for token in value.tolist():
+             self.token_queue.put(token)
+
+     def end(self):
+         self.token_queue.put(self.stop_signal)
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         value = self.token_queue.get(timeout=self.timeout)
+         if value == self.stop_signal:
+             raise StopIteration()
+         else:
+             return value
+
+
+ class ModelWorker:
+     def __init__(self, model_path, device='cuda'):
+         self.device = device
+         self.glm_model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
+                                                    device_map=device, low_cpu_mem_usage=True, load_in_8bit=True).eval()
+         self.glm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+     @torch.inference_mode()
+     def generate_stream(self, params):
+         tokenizer, model = self.glm_tokenizer, self.glm_model
+
+         prompt = params["prompt"]
+
+         temperature = float(params.get("temperature", 1.0))
+         top_p = float(params.get("top_p", 1.0))
+         max_new_tokens = int(params.get("max_new_tokens", 256))
+
+         inputs = tokenizer([prompt], return_tensors="pt")
+         inputs = inputs.to(self.device)
+         streamer = TokenStreamer(skip_prompt=True)
+         thread = Thread(target=model.generate,
+                         kwargs=dict(**inputs, max_new_tokens=int(max_new_tokens),
+                                     temperature=float(temperature), top_p=float(top_p),
+                                     streamer=streamer))
+         thread.start()
+         for token_id in streamer:
+             yield (json.dumps({"token_id": token_id, "error_code": 0}) + "\n").encode()
+
+     def generate_stream_gate(self, params):
+         try:
+             for x in self.generate_stream(params):
+                 yield x
+         except Exception as e:
+             print("Caught Unknown Error", e)
+             ret = {
+                 "text": "Server Error",
+                 "error_code": 1,
+             }
+             yield (json.dumps(ret) + "\n").encode()
+
+
+ app = FastAPI()
+
+
+ @app.post("/generate_stream")
+ async def generate_stream(request: Request):
+     params = await request.json()
+
+     generator = worker.generate_stream_gate(params)
+     return StreamingResponse(generator)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--host", type=str, default="localhost")
+     parser.add_argument("--port", type=int, default=10000)
+     parser.add_argument("--model-path", type=str, default="glm-4-voice-9b-int8")
+     args = parser.parse_args()
+
+     worker = ModelWorker(args.model_path)
+     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
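
A minimal streaming client sketch for the /generate_stream endpoint above (illustrative only, not part of the upload): host and port follow the argparse defaults, and the prompt string mirrors the format assembled in web_demo.py.

import json
import requests

resp = requests.post(
    "http://localhost:10000/generate_stream",
    data=json.dumps({
        "prompt": "<|user|>\nHello<|assistant|>streaming_transcription\n",
        "temperature": 0.2,
        "top_p": 0.8,
        "max_new_tokens": 64,
    }),
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(json.loads(line)["token_id"])  # one streamed token id per JSON line
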
quantifization.py ADDED
@@ -0,0 +1,27 @@
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ device = "cuda:0"
+
+ tokenizer = AutoTokenizer.from_pretrained("glm-4-voice-9b", trust_remote_code=True)
+
+ tokenizer.chat_template = "{{role}}: {{content}}"
+
+ query = "你好"
+
+ inputs = tokenizer.apply_chat_template([{"role": "user", "content": query}],
+                                        add_generation_prompt=True,
+                                        tokenize=True,
+                                        return_tensors="pt",
+                                        return_dict=True
+                                        )
+
+ inputs = inputs.to(device)
+ model = AutoModelForCausalLM.from_pretrained(
+     "glm-4-voice-9b",
+     low_cpu_mem_usage=True,
+     trust_remote_code=True,
+     load_in_8bit=True
+ ).eval()
+ model.save_pretrained("glm-4-voice-9b-int8")
+ tokenizer.save_pretrained("glm-4-voice-9b-int8")
requirements.txt ADDED
@@ -0,0 +1,36 @@
+ conformer==0.3.2
+ deepspeed==0.14.2; sys_platform == 'linux'
+ diffusers==0.27.2
+ fastapi==0.115.3
+ fastapi-cli==0.0.4
+ gdown==5.1.0
+ gradio==5.3.0
+ grpcio==1.57.0
+ grpcio-tools==1.57.0
+ huggingface_hub==0.25.2
+ hydra-core==1.3.2
+ HyperPyYAML==1.2.2
+ inflect==7.3.1
+ librosa==0.10.2
+ lightning==2.2.4
+ matplotlib==3.7.5
+ modelscope==1.15.0
+ networkx==3.1
+ numpy==1.24.4
+ omegaconf==2.3.0
+ onnxruntime-gpu==1.16.0; sys_platform == 'linux'
+ onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'win32'
+ openai-whisper==20231117
+ protobuf==4.25
+ pydantic==2.7.0
+ rich==13.7.1
+ Requests==2.32.3
+ safetensors==0.4.5
+ soundfile==0.12.1
+ tensorboard==2.14.0
+ transformers==4.44.1
+ uvicorn==0.32.0
+ wget==3.2
+ WeTextProcessing==1.0.3
+ torch==2.3.0
+ torchaudio==2.3.0
web_demo.py ADDED
@@ -0,0 +1,264 @@
+ import json
+ import os.path
+ import tempfile
+ import sys
+ import re
+ import uuid
+ import requests
+ from argparse import ArgumentParser
+
+ import torchaudio
+ from transformers import WhisperFeatureExtractor, AutoTokenizer, AutoModel
+ from speech_tokenizer.modeling_whisper import WhisperVQEncoder
+
+ # import gc
+
+ sys.path.insert(0, "./cosyvoice")
+ sys.path.insert(0, "./third_party/Matcha-TTS")
+
+ from speech_tokenizer.utils import extract_speech_token
+
+ import gradio as gr
+ import torch
+
+ audio_token_pattern = re.compile(r"<\|audio_(\d+)\|>")
+
+ from flow_inference import AudioDecoder
+
+ if __name__ == "__main__":
+     parser = ArgumentParser()
+     parser.add_argument("--host", type=str, default="localhost")
+     parser.add_argument("--port", type=int, default=8888)
+     parser.add_argument("--flow-path", type=str, default="./glm-4-voice-decoder")
+     parser.add_argument("--model-path", type=str, default="./glm-4-voice-9b-int8")
+     parser.add_argument("--tokenizer-path", type=str, default="./glm-4-voice-tokenizer")
+     args = parser.parse_args()
+
+     flow_config = os.path.join(args.flow_path, "config.yaml")
+     flow_checkpoint = os.path.join(args.flow_path, 'flow.pt')
+     hift_checkpoint = os.path.join(args.flow_path, 'hift.pt')
+     glm_tokenizer = None
+     device = "cuda"
+     audio_decoder: AudioDecoder = None
+     whisper_model, feature_extractor = None, None
+
+
+     def initialize_fn():
+         global audio_decoder, feature_extractor, whisper_model, glm_model, glm_tokenizer
+         if audio_decoder is not None:
+             return
+
+         # GLM
+         glm_tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
+
+         # Flow & Hift
+         audio_decoder = AudioDecoder(config_path=flow_config, flow_ckpt_path=flow_checkpoint,
+                                      hift_ckpt_path=hift_checkpoint,
+                                      device=device)
+
+         # Speech tokenizer
+         whisper_model = WhisperVQEncoder.from_pretrained(args.tokenizer_path).eval().to(device)
+         feature_extractor = WhisperFeatureExtractor.from_pretrained(args.tokenizer_path)
+
+
+     def clear_fn():
+         return [], [], '', '', '', None, None
+
+
+     def inference_fn(
+             temperature: float,
+             top_p: float,
+             max_new_token: int,
+             input_mode,
+             audio_path: str | None,
+             input_text: str | None,
+             history: list[dict],
+             previous_input_tokens: str,
+             previous_completion_tokens: str,
+     ):
+
+         # gc.collect()
+         # torch.cuda.empty_cache()
+         # torch.cuda.ipc_collect()
+
+         if input_mode == "audio":
+             assert audio_path is not None
+             history.append({"role": "user", "content": {"path": audio_path}})
+             audio_tokens = extract_speech_token(
+                 whisper_model, feature_extractor, [audio_path]
+             )[0]
+             if len(audio_tokens) == 0:
+                 raise gr.Error("No audio tokens extracted")
+             audio_tokens = "".join([f"<|audio_{x}|>" for x in audio_tokens])
+             audio_tokens = "<|begin_of_audio|>" + audio_tokens + "<|end_of_audio|>"
+             user_input = audio_tokens
+             system_prompt = "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens. "
+         else:
+             assert input_text is not None
+             history.append({"role": "user", "content": input_text})
+             user_input = input_text
+             system_prompt = "User will provide you with a text instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens."
+
+         # Gather history
+         inputs = previous_input_tokens + previous_completion_tokens
+         inputs = inputs.strip()
+         if "<|system|>" not in inputs:
+             inputs += f"<|system|>\n{system_prompt}"
+         inputs += f"<|user|>\n{user_input}<|assistant|>streaming_transcription\n"
+
+         with torch.no_grad():
+             response = requests.post(
+                 "http://localhost:10000/generate_stream",
+                 data=json.dumps({
+                     "prompt": inputs,
+                     "temperature": temperature,
+                     "top_p": top_p,
+                     "max_new_tokens": max_new_token,
+                 }),
+                 stream=True
+             )
+             text_tokens, audio_tokens = [], []
+             audio_offset = glm_tokenizer.convert_tokens_to_ids('<|audio_0|>')
+             end_token_id = glm_tokenizer.convert_tokens_to_ids('<|user|>')
+             complete_tokens = []
+             prompt_speech_feat = torch.zeros(1, 0, 80).to(device)
+             flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int64).to(device)
+             this_uuid = str(uuid.uuid4())
+             tts_speechs = []
+             tts_mels = []
+             prev_mel = None
+             is_finalize = False
+             block_size = 10
+             for chunk in response.iter_lines():
+                 token_id = json.loads(chunk)["token_id"]
+                 if token_id == end_token_id:
+                     is_finalize = True
+                 if len(audio_tokens) >= block_size or (is_finalize and audio_tokens):
+                     block_size = 20
+                     tts_token = torch.tensor(audio_tokens, device=device).unsqueeze(0)
+
+                     if prev_mel is not None:
+                         prompt_speech_feat = torch.cat(tts_mels, dim=-1).transpose(1, 2)
+
+                     tts_speech, tts_mel = audio_decoder.token2wav(tts_token, uuid=this_uuid,
+                                                                   prompt_token=flow_prompt_speech_token.to(device),
+                                                                   prompt_feat=prompt_speech_feat.to(device),
+                                                                   finalize=is_finalize)
+                     prev_mel = tts_mel
+
+                     tts_speechs.append(tts_speech.squeeze())
+                     tts_mels.append(tts_mel)
+                     yield history, inputs, '', '', (22050, tts_speech.squeeze().cpu().numpy()), None
+                     flow_prompt_speech_token = torch.cat((flow_prompt_speech_token, tts_token), dim=-1)
+                     audio_tokens = []
+                 if not is_finalize:
+                     complete_tokens.append(token_id)
+                     if token_id >= audio_offset:
+                         audio_tokens.append(token_id - audio_offset)
+                     else:
+                         text_tokens.append(token_id)
+         tts_speech = torch.cat(tts_speechs, dim=-1).cpu()
+         complete_text = glm_tokenizer.decode(complete_tokens, spaces_between_special_tokens=False)
+         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+             torchaudio.save(f, tts_speech.unsqueeze(0), 22050, format="wav")
+         history.append({"role": "assistant", "content": {"path": f.name, "type": "audio/wav"}})
+         history.append({"role": "assistant", "content": glm_tokenizer.decode(text_tokens, ignore_special_tokens=False)})
+         yield history, inputs, complete_text, '', None, (22050, tts_speech.numpy())
+
+
+     def update_input_interface(input_mode):
+         if input_mode == "audio":
+             return [gr.update(visible=True), gr.update(visible=False)]
+         else:
+             return [gr.update(visible=False), gr.update(visible=True)]
+
+
+     # Create the Gradio interface
+     with gr.Blocks(title="GLM-4-Voice Demo", fill_height=True) as demo:
+         with gr.Row():
+             temperature = gr.Number(
+                 label="Temperature",
+                 value=0.2
+             )
+
+             top_p = gr.Number(
+                 label="Top p",
+                 value=0.8
+             )
+
+             max_new_token = gr.Number(
+                 label="Max new tokens",
+                 value=2000,
+             )
+
+         chatbot = gr.Chatbot(
+             elem_id="chatbot",
+             bubble_full_width=False,
+             type="messages",
+             scale=1,
+         )
+
+         with gr.Row():
+             with gr.Column():
+                 input_mode = gr.Radio(["audio", "text"], label="Input Mode", value="audio")
+                 audio = gr.Audio(label="Input audio", type='filepath', show_download_button=True, visible=True)
+                 text_input = gr.Textbox(label="Input text", placeholder="Enter your text here...", lines=2, visible=False)
+
+             with gr.Column():
+                 submit_btn = gr.Button("Submit")
+                 reset_btn = gr.Button("Clear")
+                 output_audio = gr.Audio(label="Play", streaming=True,
+                                         autoplay=True, show_download_button=False)
+                 complete_audio = gr.Audio(label="Last Output Audio (If Any)", show_download_button=True)
+
+         gr.Markdown("""## Debug Info""")
+         with gr.Row():
+             input_tokens = gr.Textbox(
+                 label="Input Tokens",
+                 interactive=False,
+             )
+
+             completion_tokens = gr.Textbox(
+                 label="Completion Tokens",
+                 interactive=False,
+             )
+
+             detailed_error = gr.Textbox(
+                 label="Detailed Error",
+                 interactive=False,
+             )
+
+         history_state = gr.State([])
+
+         respond = submit_btn.click(
+             inference_fn,
+             inputs=[
+                 temperature,
+                 top_p,
+                 max_new_token,
+                 input_mode,
+                 audio,
+                 text_input,
+                 history_state,
+                 input_tokens,
+                 completion_tokens,
+             ],
+             outputs=[history_state, input_tokens, completion_tokens, detailed_error, output_audio, complete_audio]
+         )
+
+         respond.then(lambda s: s, [history_state], chatbot)
+
+         reset_btn.click(clear_fn, outputs=[chatbot, history_state, input_tokens, completion_tokens, detailed_error, output_audio, complete_audio])
+         input_mode.input(clear_fn, outputs=[chatbot, history_state, input_tokens, completion_tokens, detailed_error, output_audio, complete_audio]).then(update_input_interface, inputs=[input_mode], outputs=[audio, text_input])
+
+     initialize_fn()
+     # Launch the interface
+     demo.launch(
+         server_port=args.port,
+         server_name=args.host
+     )