Tuchuanhuhuhu committed
Commit b8e4532 · 1 Parent(s): c915adf

MOSS supports streaming output
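In spirit, the change wraps token-by-token sampling in a Python generator that yields the partial sequence after every step, so the caller can render output incrementally. A toy, self-contained sketch of that pattern (illustrative only; streaming_generate and its fake sampler are hypothetical stand-ins, not code from this commit):

import time

def streaming_generate(prompt_tokens, max_new_tokens=16, max_time=60.0):
    """Toy stand-in for streaming_topk_search: sample one token per step
    and yield the growing sequence so callers can show partial output."""
    tokens = list(prompt_tokens)
    start = time.time()
    for _ in range(max_new_tokens):
        next_token = (tokens[-1] + 1) % 100  # stand-in for multinomial sampling
        tokens.append(next_token)
        if time.time() - start > max_time:  # same wall-clock guard the commit uses
            break
        yield list(tokens)

for partial in streaming_generate([1, 2, 3], max_new_tokens=5):
    print(partial)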

Files changed (1)
  1. modules/models/MOSS.py +246 -4
modules/models/MOSS.py CHANGED
@@ -1,11 +1,16 @@
+# Code mainly from https://github.com/OpenLMLab/MOSS/blob/main/moss_inference.py
+
 import os
 import torch
 import warnings
 import platform
+import time
+from typing import Union, List, Tuple, Optional, Dict
 
 from huggingface_hub import snapshot_download
 from transformers.generation.utils import logger
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+from transformers.modeling_outputs import BaseModelOutputWithPast
 try:
     from transformers import MossForCausalLM, MossTokenizer
 except (ImportError, ModuleNotFoundError):
@@ -25,7 +30,7 @@ class MOSS_Client(BaseLLMModel):
         logger.setLevel("ERROR")
         warnings.filterwarnings("ignore")
         if MOSS_MODEL is None:
-            model_path = "models/moss-moon-003-sft"
+            model_path = "/home/guest/llm_models/moss/moss-moon-003-sft"
             if not os.path.exists(model_path):
                 model_path = snapshot_download("fnlp/moss-moon-003-sft")
 
@@ -57,10 +62,33 @@ class MOSS_Client(BaseLLMModel):
         self.text_to_image_switch = '- Text-to-image: disabled.\n'
         self.image_edition_switch = '- Image edition: disabled.\n'
         self.text_to_speech_switch = '- Text-to-speech: disabled.\n'
-        self.token_upper_limit = 4096
-        self.top_p = 0.95
-        self.top_k = 50
+        self.token_upper_limit = 2048
+        self.top_p = 0.8
+        self.top_k = 40
         self.temperature = 0.7
+        self.repetition_penalty = 1.1
+        self.max_generation_token = 2048
+
+        self.default_paras = {
+            "temperature": 0.7,
+            "top_k": 0,
+            "top_p": 0.8,
+            "length_penalty": 1,
+            "max_time": 60,
+            "repetition_penalty": 1.1,
+            "max_iterations": 512,
+            "regulation_start": 512,
+        }
+        self.num_layers, self.heads, self.hidden, self.vocab_size = 34, 24, 256, 107008
+
+        self.moss_startwords = torch.LongTensor([27, 91, 44, 18420, 91, 31175])
+        self.tool_startwords = torch.LongTensor([27, 91, 6935, 1746, 91, 31175])
+        self.tool_specialwords = torch.LongTensor([6045])
+
+        self.innerthought_stopwords = torch.LongTensor([MOSS_TOKENIZER.convert_tokens_to_ids("<eot>")])
+        self.tool_stopwords = torch.LongTensor([MOSS_TOKENIZER.convert_tokens_to_ids("<eoc>")])
+        self.result_stopwords = torch.LongTensor([MOSS_TOKENIZER.convert_tokens_to_ids("<eor>")])
+        self.moss_stopwords = torch.LongTensor([MOSS_TOKENIZER.convert_tokens_to_ids("<eom>")])
 
     def _get_main_instruction(self):
         return self.system_prompt + self.web_search_switch + self.calculator_switch + self.equation_solver_switch + self.text_to_image_switch + self.image_edition_switch + self.text_to_speech_switch
@@ -86,6 +114,7 @@ class MOSS_Client(BaseLLMModel):
                 top_k=self.top_k,
                 top_p=self.top_p,
                 temperature=self.temperature,
+                repetition_penalty=self.repetition_penalty,
                 num_return_sequences=1,
                 eos_token_id=106068,
                 pad_token_id=MOSS_TOKENIZER.pad_token_id)
@@ -93,6 +122,219 @@ class MOSS_Client(BaseLLMModel):
         response = response.lstrip("<|MOSS|>: ")
         return response, len(response)
 
+    def get_answer_stream_iter(self):
+        prompt = self._get_moss_style_inputs()
+        it = self.forward(prompt)
+        for i in it:
+            yield i
+
+    def preprocess(self, raw_text: str) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Preprocesses the raw input text by adding the prefix and tokenizing it.
+
+        Args:
+            raw_text (str): The raw input text.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple containing the tokenized input IDs and attention mask.
+        """
+
+        tokens = MOSS_TOKENIZER.batch_encode_plus([raw_text], return_tensors="pt")
+        input_ids, attention_mask = tokens['input_ids'], tokens['attention_mask']
+
+        return input_ids, attention_mask
+
+    def forward(
+        self, data: str, paras: Optional[Dict[str, float]] = None
+    ) -> List[str]:
+        """
+        Generates text using the model, given the input data and generation parameters.
+
+        Args:
+            data (str): The input text for generation.
+            paras (Optional[Dict[str, float]], optional): A dictionary of generation parameters. Defaults to None.
+
+        Returns:
+            List[str]: The list of generated texts.
+        """
+        input_ids, attention_mask = self.preprocess(data)
+
+        if not paras:
+            paras = self.default_paras
+
+        streaming_iter = self.streaming_topk_search(
+            input_ids,
+            attention_mask,
+            temperature=self.temperature,
+            repetition_penalty=self.repetition_penalty,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            max_iterations=self.max_generation_token,
+            regulation_start=paras["regulation_start"],
+            length_penalty=paras["length_penalty"],
+            max_time=paras["max_time"],
+        )
+
+        for outputs in streaming_iter:
+
+            preds = MOSS_TOKENIZER.batch_decode(outputs)
+
+            res = [pred.lstrip(data) for pred in preds]
+
+            yield res[0]
+
+    def streaming_topk_search(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        temperature: float = 0.7,
+        repetition_penalty: float = 1.1,
+        top_k: int = 0,
+        top_p: float = 0.92,
+        max_iterations: int = 1024,
+        regulation_start: int = 512,
+        length_penalty: float = 1,
+        max_time: int = 60,
+    ) -> torch.Tensor:
+        """
+        Performs a streaming top-k search using the given parameters.
+
+        Args:
+            input_ids (torch.Tensor): The input IDs tensor.
+            attention_mask (torch.Tensor): The attention mask tensor.
+            temperature (float, optional): The temperature for logits. Defaults to 0.7.
+            repetition_penalty (float, optional): The repetition penalty factor. Defaults to 1.1.
+            top_k (int, optional): The top-k value for filtering. Defaults to 0.
+            top_p (float, optional): The top-p value for filtering. Defaults to 0.92.
+            max_iterations (int, optional): The maximum number of iterations. Defaults to 1024.
+            regulation_start (int, optional): The number of iterations after which regulation starts. Defaults to 512.
+            length_penalty (float, optional): The length penalty factor. Defaults to 1.
+            max_time (int, optional): The maximum allowed time in seconds. Defaults to 60.
+
+        Returns:
+            torch.Tensor: The generated output IDs tensor.
+        """
+        assert input_ids.dtype == torch.int64 and attention_mask.dtype == torch.int64
+
+        self.bsz, self.seqlen = input_ids.shape
+
+        input_ids, attention_mask = input_ids.to('cuda'), attention_mask.to('cuda')
+        last_token_indices = attention_mask.sum(1) - 1
+
+        moss_stopwords = self.moss_stopwords.to(input_ids.device)
+        queue_for_moss_stopwords = torch.empty(size=(self.bsz, len(self.moss_stopwords)), device=input_ids.device, dtype=input_ids.dtype)
+        all_shall_stop = torch.tensor([False] * self.bsz, device=input_ids.device)
+        moss_stop = torch.tensor([False] * self.bsz, device=input_ids.device)
+
+        generations, start_time = torch.ones(self.bsz, 1, dtype=torch.int64), time.time()
+
+        past_key_values = None
+        for i in range(int(max_iterations)):
+            logits, past_key_values = self.infer_(input_ids if i == 0 else new_generated_id, attention_mask, past_key_values)
+
+            if i == 0:
+                logits = logits.gather(1, last_token_indices.view(self.bsz, 1, 1).repeat(1, 1, self.vocab_size)).squeeze(1)
+            else:
+                logits = logits[:, -1, :]
+
+
+            if repetition_penalty > 1:
+                score = logits.gather(1, input_ids)
+                # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+                # just gather the history tokens from input_ids, preprocess then scatter back
+                # here we apply extra work to exclude special tokens
+
+                score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
+
+                logits.scatter_(1, input_ids, score)
+
+            logits = logits / temperature
+
+            filtered_logits = self.top_k_top_p_filtering(logits, top_k, top_p)
+            probabilities = torch.softmax(filtered_logits, dim=-1)
+
+            cur_len = i
+            if cur_len > int(regulation_start):
+                for i in self.moss_stopwords:
+                    probabilities[:, i] = probabilities[:, i] * pow(length_penalty, cur_len - regulation_start)
+
+            new_generated_id = torch.multinomial(probabilities, 1)
+
+            # update extra_ignored_tokens
+            new_generated_id_cpu = new_generated_id.cpu()
+
+            input_ids, attention_mask = torch.cat([input_ids, new_generated_id], dim=1), torch.cat([attention_mask, torch.ones((self.bsz, 1), device=attention_mask.device, dtype=attention_mask.dtype)], dim=1)
+
+            generations = torch.cat([generations, new_generated_id.cpu()], dim=1)
+
+            # stop words components
+            queue_for_moss_stopwords = torch.cat([queue_for_moss_stopwords[:, 1:], new_generated_id], dim=1)
+
+            moss_stop |= (queue_for_moss_stopwords == moss_stopwords).all(1)
+
+            all_shall_stop |= moss_stop
+
+            if all_shall_stop.all().item():
+                break
+            elif time.time() - start_time > max_time:
+                break
+
+            yield input_ids
+
+    def top_k_top_p_filtering(self, logits, top_k, top_p, filter_value=-float("Inf"), min_tokens_to_keep=1):
+        if top_k > 0:
+            # Remove all tokens with a probability less than the last token of the top-k
+            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+            logits[indices_to_remove] = filter_value
+
+        if top_p < 1.0:
+            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+
+            # Remove tokens with cumulative probability above the threshold (tokens with 0 are kept)
+            sorted_indices_to_remove = cumulative_probs > top_p
+            if min_tokens_to_keep > 1:
+                # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+                sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+            # Shift the indices to the right to keep also the first token above the threshold
+            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+            sorted_indices_to_remove[..., 0] = 0
+            # scatter sorted tensors to original indexing
+            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+            logits[indices_to_remove] = filter_value
+
+        return logits
+
+    def infer_(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        past_key_values: Optional[Tuple[torch.Tensor]],
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
+        """
+        Inference method that computes logits and past key values.
+
+        Args:
+            input_ids (torch.Tensor): The input IDs tensor.
+            attention_mask (torch.Tensor): The attention mask tensor.
+            past_key_values (Optional[Tuple[torch.Tensor]]): The past key values tuple.
+
+        Returns:
+            Tuple[torch.Tensor, Tuple[torch.Tensor]]: A tuple containing the logits and past key values.
+        """
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+        }
+        with torch.no_grad():
+            outputs: BaseModelOutputWithPast = MOSS_MODEL(**inputs)
+
+        return outputs.logits, outputs.past_key_values
+
+    def __call__(self, input):
+        return self.forward(input)
+
 
 if __name__ == "__main__":
     model = MOSS_Client("MOSS")
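For context, a minimal sketch of how the new streaming path could be driven end to end. This driver is not part of the commit: it assumes a CUDA machine with the moss-moon-003-sft weights reachable, that it runs from the repository root, and that the <|Human|> ... <eoh> prompt framing matches the upstream MOSS inference convention (the real client builds its prompt via _get_moss_style_inputs()).

# Hypothetical driver for the streaming API added in this commit.
from modules.models.MOSS import MOSS_Client

client = MOSS_Client("MOSS")

# Assumed MOSS-style prompt framing; adjust if the upstream format differs.
prompt = client._get_main_instruction() + "<|Human|>: Hi there<eoh>\n<|MOSS|>:"

printed = 0
for partial in client.forward(prompt):
    # forward() yields progressively longer decoded strings; printing only
    # the new suffix gives a token-by-token streaming effect in a terminal.
    print(partial[printed:], end="", flush=True)
    printed = len(partial)
print()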