Tuchuanhuhuhu committed · b8e4532 · 1 Parent(s): c915adf
MOSS: support streaming output (MOSS支持流式传输)

modules/models/MOSS.py +246 -4
modules/models/MOSS.py
CHANGED
@@ -1,11 +1,16 @@
+# Code mainly adapted from https://github.com/OpenLMLab/MOSS/blob/main/moss_inference.py
+
 import os
 import torch
 import warnings
 import platform
+import time
+from typing import Union, List, Tuple, Optional, Dict

 from huggingface_hub import snapshot_download
 from transformers.generation.utils import logger
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+from transformers.modeling_outputs import BaseModelOutputWithPast
 try:
     from transformers import MossForCausalLM, MossTokenizer
 except (ImportError, ModuleNotFoundError):
@@ -25,7 +30,7 @@ class MOSS_Client(BaseLLMModel):
         logger.setLevel("ERROR")
         warnings.filterwarnings("ignore")
         if MOSS_MODEL is None:
-            model_path = "
+            model_path = "/home/guest/llm_models/moss/moss-moon-003-sft"
             if not os.path.exists(model_path):
                 model_path = snapshot_download("fnlp/moss-moon-003-sft")

@@ -57,10 +62,33 @@ class MOSS_Client(BaseLLMModel):
         self.text_to_image_switch = '- Text-to-image: disabled.\n'
         self.image_edition_switch = '- Image edition: disabled.\n'
         self.text_to_speech_switch = '- Text-to-speech: disabled.\n'
-        self.token_upper_limit =
-        self.top_p = 0.
-        self.top_k =
+        self.token_upper_limit = 2048
+        self.top_p = 0.8
+        self.top_k = 40
         self.temperature = 0.7
+        self.repetition_penalty = 1.1
+        self.max_generation_token = 2048
+
+        self.default_paras = {
+            "temperature": 0.7,
+            "top_k": 0,
+            "top_p": 0.8,
+            "length_penalty": 1,
+            "max_time": 60,
+            "repetition_penalty": 1.1,
+            "max_iterations": 512,
+            "regulation_start": 512,
+        }
+        self.num_layers, self.heads, self.hidden, self.vocab_size = 34, 24, 256, 107008
+
+        self.moss_startwords = torch.LongTensor([27, 91, 44, 18420, 91, 31175])
+        self.tool_startwords = torch.LongTensor([27, 91, 6935, 1746, 91, 31175])
+        self.tool_specialwords = torch.LongTensor([6045])
+
+        self.innerthought_stopwords = torch.LongTensor([MOSS_TOKENIZER.convert_tokens_to_ids("<eot>")])
+        self.tool_stopwords = torch.LongTensor([MOSS_TOKENIZER.convert_tokens_to_ids("<eoc>")])
+        self.result_stopwords = torch.LongTensor([MOSS_TOKENIZER.convert_tokens_to_ids("<eor>")])
+        self.moss_stopwords = torch.LongTensor([MOSS_TOKENIZER.convert_tokens_to_ids("<eom>")])

     def _get_main_instruction(self):
         return self.system_prompt + self.web_search_switch + self.calculator_switch + self.equation_solver_switch + self.text_to_image_switch + self.image_edition_switch + self.text_to_speech_switch
@@ -86,6 +114,7 @@ class MOSS_Client(BaseLLMModel):
                 top_k=self.top_k,
                 top_p=self.top_p,
                 temperature=self.temperature,
+                repetition_penalty=self.repetition_penalty,
                 num_return_sequences=1,
                 eos_token_id=106068,
                 pad_token_id=MOSS_TOKENIZER.pad_token_id)
@@ -93,6 +122,219 @@ class MOSS_Client(BaseLLMModel):
         response = response.lstrip("<|MOSS|>: ")
         return response, len(response)

+    def get_answer_stream_iter(self):
+        prompt = self._get_moss_style_inputs()
+        it = self.forward(prompt)
+        for i in it:
+            yield i
+
+    def preprocess(self, raw_text: str) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Preprocesses the raw input text by adding the prefix and tokenizing it.
+
+        Args:
+            raw_text (str): The raw input text.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple containing the tokenized input IDs and attention mask.
+        """
+
+        tokens = MOSS_TOKENIZER.batch_encode_plus([raw_text], return_tensors="pt")
+        input_ids, attention_mask = tokens['input_ids'], tokens['attention_mask']
+
+        return input_ids, attention_mask
+
+    def forward(
+        self, data: str, paras: Optional[Dict[str, float]] = None
+    ) -> List[str]:
+        """
+        Generates text using the model, given the input data and generation parameters.
+
+        Args:
+            data (str): The input text for generation.
+            paras (Optional[Dict[str, float]], optional): A dictionary of generation parameters. Defaults to None.
+
+        Returns:
+            List[str]: The list of generated texts.
+        """
+        input_ids, attention_mask = self.preprocess(data)
+
+        if not paras:
+            paras = self.default_paras
+
+        streaming_iter = self.streaming_topk_search(
+            input_ids,
+            attention_mask,
+            temperature=self.temperature,
+            repetition_penalty=self.repetition_penalty,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            max_iterations=self.max_generation_token,
+            regulation_start=paras["regulation_start"],
+            length_penalty=paras["length_penalty"],
+            max_time=paras["max_time"],
+        )
+
+        for outputs in streaming_iter:
+
+            preds = MOSS_TOKENIZER.batch_decode(outputs)
+
+            res = [pred.lstrip(data) for pred in preds]
+
+            yield res[0]
+
+    def streaming_topk_search(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        temperature: float = 0.7,
+        repetition_penalty: float = 1.1,
+        top_k: int = 0,
+        top_p: float = 0.92,
+        max_iterations: int = 1024,
+        regulation_start: int = 512,
+        length_penalty: float = 1,
+        max_time: int = 60,
+    ) -> torch.Tensor:
+        """
+        Performs a streaming top-k search using the given parameters.
+
+        Args:
+            input_ids (torch.Tensor): The input IDs tensor.
+            attention_mask (torch.Tensor): The attention mask tensor.
+            temperature (float, optional): The temperature for logits. Defaults to 0.7.
+            repetition_penalty (float, optional): The repetition penalty factor. Defaults to 1.1.
+            top_k (int, optional): The top-k value for filtering. Defaults to 0.
+            top_p (float, optional): The top-p value for filtering. Defaults to 0.92.
+            max_iterations (int, optional): The maximum number of iterations. Defaults to 1024.
+            regulation_start (int, optional): The number of iterations after which regulation starts. Defaults to 512.
+            length_penalty (float, optional): The length penalty factor. Defaults to 1.
+            max_time (int, optional): The maximum allowed time in seconds. Defaults to 60.
+
+        Returns:
+            torch.Tensor: The generated output IDs tensor.
+        """
+        assert input_ids.dtype == torch.int64 and attention_mask.dtype == torch.int64
+
+        self.bsz, self.seqlen = input_ids.shape
+
+        input_ids, attention_mask = input_ids.to('cuda'), attention_mask.to('cuda')
+        last_token_indices = attention_mask.sum(1) - 1
+
+        moss_stopwords = self.moss_stopwords.to(input_ids.device)
+        queue_for_moss_stopwords = torch.empty(size=(self.bsz, len(self.moss_stopwords)), device=input_ids.device, dtype=input_ids.dtype)
+        all_shall_stop = torch.tensor([False] * self.bsz, device=input_ids.device)
+        moss_stop = torch.tensor([False] * self.bsz, device=input_ids.device)
+
+        generations, start_time = torch.ones(self.bsz, 1, dtype=torch.int64), time.time()
+
+        past_key_values = None
+        for i in range(int(max_iterations)):
+            logits, past_key_values = self.infer_(input_ids if i == 0 else new_generated_id, attention_mask, past_key_values)
+
+            if i == 0:
+                logits = logits.gather(1, last_token_indices.view(self.bsz, 1, 1).repeat(1, 1, self.vocab_size)).squeeze(1)
+            else:
+                logits = logits[:, -1, :]
+
+
+            if repetition_penalty > 1:
+                score = logits.gather(1, input_ids)
+                # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+                # just gather the history tokens from input_ids, preprocess, then scatter back
+                # here we apply extra work to exclude special tokens
+
+                score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
+
+                logits.scatter_(1, input_ids, score)
+
+            logits = logits / temperature
+
+            filtered_logits = self.top_k_top_p_filtering(logits, top_k, top_p)
+            probabilities = torch.softmax(filtered_logits, dim=-1)
+
+            cur_len = i
+            if cur_len > int(regulation_start):
+                for i in self.moss_stopwords:
+                    probabilities[:, i] = probabilities[:, i] * pow(length_penalty, cur_len - regulation_start)
+
+            new_generated_id = torch.multinomial(probabilities, 1)
+
+            # update extra_ignored_tokens
+            new_generated_id_cpu = new_generated_id.cpu()
+
+            input_ids, attention_mask = torch.cat([input_ids, new_generated_id], dim=1), torch.cat([attention_mask, torch.ones((self.bsz, 1), device=attention_mask.device, dtype=attention_mask.dtype)], dim=1)
+
+            generations = torch.cat([generations, new_generated_id.cpu()], dim=1)
+
+            # stop words components
+            queue_for_moss_stopwords = torch.cat([queue_for_moss_stopwords[:, 1:], new_generated_id], dim=1)
+
+            moss_stop |= (queue_for_moss_stopwords == moss_stopwords).all(1)
+
+            all_shall_stop |= moss_stop
+
+            if all_shall_stop.all().item():
+                break
+            elif time.time() - start_time > max_time:
+                break
+
+            yield input_ids
+
+    def top_k_top_p_filtering(self, logits, top_k, top_p, filter_value=-float("Inf"), min_tokens_to_keep=1):
+        if top_k > 0:
+            # Remove all tokens with a probability less than the last token of the top-k
+            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+            logits[indices_to_remove] = filter_value
+
+        if top_p < 1.0:
+            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+
+            # Remove tokens with cumulative probability above the threshold (tokens with 0 are kept)
+            sorted_indices_to_remove = cumulative_probs > top_p
+            if min_tokens_to_keep > 1:
+                # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+                sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+            # Shift the indices to the right to keep also the first token above the threshold
+            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+            sorted_indices_to_remove[..., 0] = 0
+            # scatter sorted tensors to original indexing
+            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+            logits[indices_to_remove] = filter_value
+
+        return logits
+
+    def infer_(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        past_key_values: Optional[Tuple[torch.Tensor]],
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
+        """
+        Inference method that computes logits and past key values.
+
+        Args:
+            input_ids (torch.Tensor): The input IDs tensor.
+            attention_mask (torch.Tensor): The attention mask tensor.
+            past_key_values (Optional[Tuple[torch.Tensor]]): The past key values tuple.
+
+        Returns:
+            Tuple[torch.Tensor, Tuple[torch.Tensor]]: A tuple containing the logits and past key values.
+        """
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+        }
+        with torch.no_grad():
+            outputs: BaseModelOutputWithPast = MOSS_MODEL(**inputs)
+
+        return outputs.logits, outputs.past_key_values
+
+    def __call__(self, input):
+        return self.forward(input)
+

 if __name__ == "__main__":
     model = MOSS_Client("MOSS")
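With this commit, MOSS responses can be rendered incrementally: get_answer_stream_iter() re-decodes the growing token sequence on every step and yields the partial reply. A minimal consumption sketch (the import path follows this diff; the history attribute and its format are assumptions about the surrounding ChuanhuChatGPT BaseLLMModel, which is not shown here):

    # Hedged sketch: assumes the Space's module layout and that BaseLLMModel
    # keeps the conversation in `client.history` as {"role", "content"} dicts.
    from modules.models.MOSS import MOSS_Client

    client = MOSS_Client("MOSS")
    client.history = [{"role": "user", "content": "Hello"}]  # hypothetical history format

    # Each yielded value is the decoded text so far, so a UI can simply
    # overwrite the displayed answer after every generated token.
    for partial in client.get_answer_stream_iter():
        print(partial, end="\r", flush=True)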
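Inside streaming_topk_search, the repetition penalty gathers the logits of every token already in the context, shrinks them, and scatters them back: positive logits are divided by the penalty and negative ones multiplied, so already-seen tokens lose probability either way. A self-contained toy illustration of that gather/where/scatter pattern (the values are made up for the demo):

    import torch

    # One sequence over a 6-token vocabulary; tokens 1 and 4 already appeared.
    logits = torch.tensor([[1.0, 2.0, 0.5, -0.5, -1.0, 3.0]])
    context_ids = torch.tensor([[1, 4]])
    penalty = 1.1

    # Gather the already-seen scores, penalize them, scatter back in place.
    score = logits.gather(1, context_ids)                      # tensor([[ 2., -1.]])
    score = torch.where(score < 0, score * penalty, score / penalty)
    logits.scatter_(1, context_ids, score)
    print(logits)  # token 1: 2.0 -> ~1.82, token 4: -1.0 -> -1.10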
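The new top_k_top_p_filtering helper is the standard top-k / nucleus-sampling filter: it masks everything outside the k highest logits, then everything outside the smallest set of tokens whose cumulative probability exceeds top_p. A self-contained sketch of the same technique on a toy batch (written for illustration, not the committed implementation verbatim):

    import torch

    def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("inf")):
        if top_k > 0:
            # Drop every logit smaller than the k-th largest one in its row.
            kth_best = torch.topk(logits, top_k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth_best, filter_value)
        if top_p < 1.0:
            # Keep the smallest prefix of sorted tokens whose cumulative
            # probability exceeds top_p; always keep the highest-scoring token.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            remove = cumulative > top_p
            remove[..., 1:] = remove[..., :-1].clone()
            remove[..., 0] = False
            logits = logits.masked_fill(remove.scatter(1, sorted_indices, remove), filter_value)
        return logits

    # With top_k=2 only the two largest logits survive; sampling from the
    # softmax of the result can then never pick the masked tokens.
    logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
    print(top_k_top_p_filtering(logits, top_k=2))  # tensor([[2., 1., -inf, -inf, -inf]])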