Feliciano Long committed
Commit 3ac03d8 · unverified · 1 Parent(s): 0f313bc

feat: Add Inspur YuanAI 1.0 (#739)

* Initial commit with a basic chat function for YuanAI

* Handle generation parameters and parse the system prompt into few-shot examples (see the sketch after this commit message)

* Receive the system prompt when switching models

* Code refactor: clean up the classes and the API key handling

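For orientation, a minimal sketch (hypothetical values, not part of the commit) of the two conventions the new client relies on: the API key is passed as "account||phone" (split in Yuan.set_account), and the system prompt is read as alternating input/output lines that become few-shot examples (as in Yuan_Client.get_answer_at_once):

    api_key = "my_account||138xxxxxxxx"               # hypothetical "account||phone" credentials
    account, phone = api_key.split("||")              # mirrors Yuan.set_account()

    system_prompt = "你好\n您好,请问有什么可以帮您?"    # hypothetical few-shot pair: input line, output line
    lines = system_prompt.splitlines()
    examples = [(lines[i], lines[i + 1] if i + 1 < len(lines) else "")
                for i in range(0, len(lines), 2)]     # mirrors Yuan_Client.get_answer_at_once()
    print(examples)                                   # [('你好', '您好,请问有什么可以帮您?')]
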
modules/models/base_model.py CHANGED
@@ -33,6 +33,7 @@ class ModelType(Enum):
     XMChat = 3
     StableLM = 4
     MOSS = 5
+    YuanAI = 6
 
     @classmethod
     def get_type(cls, model_name: str):
@@ -50,6 +51,8 @@ class ModelType(Enum):
             model_type = ModelType.StableLM
         elif "moss" in model_name_lower:
             model_type = ModelType.MOSS
+        elif "yuanai" in model_name_lower:
+            model_type = ModelType.YuanAI
         else:
             model_type = ModelType.Unknown
         return model_type
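
As a quick sanity check of the new mapping (a hedged sketch; it only assumes the import path shown in this diff), any model name containing "yuanai" should now resolve to the new enum member:

    from modules.models.base_model import ModelType

    assert ModelType.get_type("yuanai-1.0-base_10B") == ModelType.YuanAI
    assert ModelType.get_type("yuanai-1.0-dialog") == ModelType.YuanAI
    assert ModelType.get_type("moss-moon-003-sft") == ModelType.MOSS   # hypothetical name containing "moss"
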
modules/models/inspurai.py ADDED
@@ -0,0 +1,345 @@
+# Code adapted mainly from https://github.com/Shawn-Inspur/Yuan-1.0/blob/main/yuan_api/inspurai.py
+
+import hashlib
+import json
+import os
+import time
+import uuid
+from datetime import datetime
+
+import pytz
+import requests
+
+from modules.presets import NO_APIKEY_MSG
+from modules.models.base_model import BaseLLMModel
+
+
+class Example:
+    """Store some examples (input/output pairs and formats) for few-shot priming of the model."""
+
+    def __init__(self, inp, out):
+        self.input = inp
+        self.output = out
+        self.id = uuid.uuid4().hex
+
+    def get_input(self):
+        """Return the input of the example."""
+        return self.input
+
+    def get_output(self):
+        """Return the output of the example."""
+        return self.output
+
+    def get_id(self):
+        """Return the unique ID of the example."""
+        return self.id
+
+    def as_dict(self):
+        return {
+            "input": self.get_input(),
+            "output": self.get_output(),
+            "id": self.get_id(),
+        }
+
+
+class Yuan:
+    """The main class for a user to interface with the Inspur Yuan API.
+    A user can set account info and add examples for the API request.
+    """
+
+    def __init__(self,
+                 engine='base_10B',
+                 temperature=0.9,
+                 max_tokens=100,
+                 input_prefix='',
+                 input_suffix='\n',
+                 output_prefix='答:',
+                 output_suffix='\n\n',
+                 append_output_prefix_to_query=False,
+                 topK=1,
+                 topP=0.9,
+                 frequencyPenalty=1.2,
+                 responsePenalty=1.2,
+                 noRepeatNgramSize=2):
+
+        self.examples = {}
+        self.engine = engine
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.topK = topK
+        self.topP = topP
+        self.frequencyPenalty = frequencyPenalty
+        self.responsePenalty = responsePenalty
+        self.noRepeatNgramSize = noRepeatNgramSize
+        self.input_prefix = input_prefix
+        self.input_suffix = input_suffix
+        self.output_prefix = output_prefix
+        self.output_suffix = output_suffix
+        self.append_output_prefix_to_query = append_output_prefix_to_query
+        self.stop = (output_suffix + input_prefix).strip()
+        self.api = None
+
+    # if self.engine not in ['base_10B','translate','dialog']:
+    #     raise Exception('engine must be one of [\'base_10B\',\'translate\',\'dialog\'] ')
+    def set_account(self, api_key):
+        account = api_key.split('||')
+        self.api = YuanAPI(user=account[0], phone=account[1])
+
+    def add_example(self, ex):
+        """Add an example to the object.
+        Example must be an instance of the Example class."""
+        assert isinstance(ex, Example), "Please create an Example object."
+        self.examples[ex.get_id()] = ex
+
+    def delete_example(self, id):
+        """Delete example with the specific id."""
+        if id in self.examples:
+            del self.examples[id]
+
+    def get_example(self, id):
+        """Get a single example."""
+        return self.examples.get(id, None)
+
+    def get_all_examples(self):
+        """Return all examples as a list of dicts."""
+        return {k: v.as_dict() for k, v in self.examples.items()}
+
+    def get_prime_text(self):
+        """Format all examples to prime the model."""
+        return "".join(
+            [self.format_example(ex) for ex in self.examples.values()])
+
+    def get_engine(self):
+        """Return the engine specified for the API."""
+        return self.engine
+
+    def get_temperature(self):
+        """Return the temperature specified for the API."""
+        return self.temperature
+
+    def get_max_tokens(self):
+        """Return the max tokens specified for the API."""
+        return self.max_tokens
+
+    def craft_query(self, prompt):
+        """Create the query for the API request."""
+        q = self.get_prime_text(
+        ) + self.input_prefix + prompt + self.input_suffix
+        if self.append_output_prefix_to_query:
+            q = q + self.output_prefix
+
+        return q
+
+    def format_example(self, ex):
+        """Format the input, output pair."""
+        return self.input_prefix + ex.get_input(
+        ) + self.input_suffix + self.output_prefix + ex.get_output(
+        ) + self.output_suffix
+
+    def response(self,
+                 query,
+                 engine='base_10B',
+                 max_tokens=20,
+                 temperature=0.9,
+                 topP=0.1,
+                 topK=1,
+                 frequencyPenalty=1.0,
+                 responsePenalty=1.0,
+                 noRepeatNgramSize=0):
+        """Obtain the original result returned by the API."""
+
+        if self.api is None:
+            return NO_APIKEY_MSG
+        try:
+            # requestId = submit_request(query,temperature,topP,topK,max_tokens, engine)
+            requestId = self.api.submit_request(query, temperature, topP, topK, max_tokens, engine, frequencyPenalty,
+                                                responsePenalty, noRepeatNgramSize)
+            response_text = self.api.reply_request(requestId)
+        except Exception as e:
+            raise e
+
+        return response_text
+
+    def del_special_chars(self, msg):
+        special_chars = ['<unk>', '<eod>', '#', '▃', '▁', '▂', ' ']
+        for char in special_chars:
+            msg = msg.replace(char, '')
+        return msg
+
+    def submit_API(self, prompt, trun=[]):
+        """Submit prompt to the Yuan API interface and obtain a pure-text reply.
+        :prompt: Question or any content a user may input.
+        :return: pure text response."""
+        query = self.craft_query(prompt)
+        res = self.response(query, engine=self.engine,
+                            max_tokens=self.max_tokens,
+                            temperature=self.temperature,
+                            topP=self.topP,
+                            topK=self.topK,
+                            frequencyPenalty=self.frequencyPenalty,
+                            responsePenalty=self.responsePenalty,
+                            noRepeatNgramSize=self.noRepeatNgramSize)
+        if 'resData' in res and res['resData'] != None:
+            txt = res['resData']
+        else:
+            txt = '模型返回为空,请尝试修改输入'
+        # post-processing that only applies to the translate engine
+        if self.engine == 'translate':
+            txt = txt.replace(' ##', '').replace(' "', '"').replace(": ", ":").replace(" ,", ",") \
+                .replace('英文:', '').replace('文:', '').replace("( ", "(").replace(" )", ")")
+        else:
+            txt = txt.replace(' ', '')
+        txt = self.del_special_chars(txt)
+
+        # trun: truncate the model output at any of the given stop strings
+        if isinstance(trun, str):
+            trun = [trun]
+        try:
+            if trun != None and isinstance(trun, list) and trun != []:
+                for tr in trun:
+                    if tr in txt and tr != "":
+                        txt = txt[:txt.index(tr)]
+                    else:
+                        continue
+        except:
+            return txt
+        return txt
+
+
+class YuanAPI:
+    ACCOUNT = ''
+    PHONE = ''
+
+    SUBMIT_URL = "http://api.airyuan.cn:32102/v1/interface/api/infer/getRequestId?"
+    REPLY_URL = "http://api.airyuan.cn:32102/v1/interface/api/result?"
+
+    def __init__(self, user, phone):
+        self.ACCOUNT = user
+        self.PHONE = phone
+
+    @staticmethod
+    def code_md5(str):
+        code = str.encode("utf-8")
+        m = hashlib.md5()
+        m.update(code)
+        result = m.hexdigest()
+        return result
+
+    @staticmethod
+    def rest_get(url, header, timeout, show_error=False):
+        '''Call rest get method'''
+        try:
+            response = requests.get(url, headers=header, timeout=timeout, verify=False)
+            return response
+        except Exception as exception:
+            if show_error:
+                print(exception)
+            return None
+
+    def header_generation(self):
+        """Generate header for API request."""
+        t = datetime.now(pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d")
+        token = self.code_md5(self.ACCOUNT + self.PHONE + t)
+        headers = {'token': token}
+        return headers
+
+    def submit_request(self, query, temperature, topP, topK, max_tokens, engine, frequencyPenalty, responsePenalty,
+                       noRepeatNgramSize):
+        """Submit query to the backend server and get requestID."""
+        headers = self.header_generation()
+        # url=SUBMIT_URL + "account={0}&data={1}&temperature={2}&topP={3}&topK={4}&tokensToGenerate={5}&type={6}".format(ACCOUNT,query,temperature,topP,topK,max_tokens,"api")
+        # url=SUBMIT_URL + "engine={0}&account={1}&data={2}&temperature={3}&topP={4}&topK={5}&tokensToGenerate={6}" \
+        #     "&type={7}".format(engine,ACCOUNT,query,temperature,topP,topK, max_tokens,"api")
+        url = self.SUBMIT_URL + "engine={0}&account={1}&data={2}&temperature={3}&topP={4}&topK={5}&tokensToGenerate={6}" \
+                                "&type={7}&frequencyPenalty={8}&responsePenalty={9}&noRepeatNgramSize={10}". \
+            format(engine, self.ACCOUNT, query, temperature, topP, topK, max_tokens, "api", frequencyPenalty,
+                   responsePenalty, noRepeatNgramSize)
+        response = self.rest_get(url, headers, 30)
+        response_text = json.loads(response.text)
+        if response_text["flag"]:
+            requestId = response_text["resData"]
+            return requestId
+        else:
+            raise RuntimeWarning(response_text)
+
+    def reply_request(self, requestId, cycle_count=5):
+        """Check reply API to get the inference response."""
+        url = self.REPLY_URL + "account={0}&requestId={1}".format(self.ACCOUNT, requestId)
+        headers = self.header_generation()
+        response_text = {"flag": True, "resData": None}
+        for i in range(cycle_count):
+            response = self.rest_get(url, headers, 30, show_error=True)
+            response_text = json.loads(response.text)
+            if response_text["resData"] is not None:
+                return response_text
+            if response_text["flag"] is False and i == cycle_count - 1:
+                raise RuntimeWarning(response_text)
+            time.sleep(3)
+        return response_text
+
+
+class Yuan_Client(BaseLLMModel):
+
+    def __init__(self, model_name, api_key, user_name="", system_prompt=None):
+        super().__init__(model_name=model_name, user=user_name)
+        self.history = []
+        self.api_key = api_key
+        self.system_prompt = system_prompt
+
+        self.input_prefix = ""
+        self.output_prefix = ""
+
+    def set_text_prefix(self, option, value):
+        if option == 'input_prefix':
+            self.input_prefix = value
+        elif option == 'output_prefix':
+            self.output_prefix = value
+
+    def get_answer_at_once(self):
+        # Yuan temperature is (0,1] while the base model temperature is [0,2]; Yuan 0.9 == base 1, so convert.
+        temperature = self.temperature if self.temperature <= 1 else 0.9 + (self.temperature - 1) / 10
+        topP = self.top_p
+        topK = self.n_choices
+        # max_tokens should be in [1,200]
+        max_tokens = self.max_generation_token if self.max_generation_token is not None else 50
+        if max_tokens > 200:
+            max_tokens = 200
+        stop = self.stop_sequence if self.stop_sequence is not None else []
+        examples = []
+        system_prompt = self.system_prompt
+        if system_prompt is not None:
+            lines = system_prompt.splitlines()
+            # TODO: support prefixes in system prompt or settings
+            """
+            if lines[0].startswith('-'):
+                prefixes = lines.pop()[1:].split('|')
+                self.input_prefix = prefixes[0]
+                if len(prefixes) > 1:
+                    self.output_prefix = prefixes[1]
+                if len(prefixes) > 2:
+                    stop = prefixes[2].split(',')
+            """
+            for i in range(0, len(lines), 2):
+                in_line = lines[i]
+                out_line = lines[i + 1] if i + 1 < len(lines) else ""
+                examples.append((in_line, out_line))
+        yuan = Yuan(engine=self.model_name.replace('yuanai-1.0-', ''),
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                    topK=topK,
+                    topP=topP,
+                    input_prefix=self.input_prefix,
+                    input_suffix="",
+                    output_prefix=self.output_prefix,
+                    output_suffix="".join(stop),
+                    )
+        if not self.api_key:
+            return NO_APIKEY_MSG, 0
+        yuan.set_account(self.api_key)
+
+        for in_line, out_line in examples:
+            yuan.add_example(Example(inp=in_line, out=out_line))
+
+        prompt = self.history[-1]["content"]
+        answer = yuan.submit_API(prompt, trun=stop)
+        return answer, len(answer)
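
For orientation, a rough usage sketch of the low-level classes above (credentials and prompt text are hypothetical; this mirrors what Yuan_Client.get_answer_at_once does internally):

    from modules.models.inspurai import Example, Yuan

    yuan = Yuan(engine="dialog",                      # engines added in this PR: base_10B, translate, dialog, rhythm_poems
                temperature=0.9,
                max_tokens=100)
    yuan.set_account("my_account||138xxxxxxxx")       # hypothetical "account||phone" API key
    yuan.add_example(Example(inp="你好", out="您好,请问有什么可以帮您?"))  # optional few-shot priming
    print(yuan.submit_API("介绍一下浪潮源1.0", trun=[]))
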
modules/models/models.py CHANGED
@@ -599,6 +599,9 @@ def get_model(
         elif model_type == ModelType.MOSS:
             from .MOSS import MOSS_Client
             model = MOSS_Client(model_name, user_name=user_name)
+        elif model_type == ModelType.YuanAI:
+            from .inspurai import Yuan_Client
+            model = Yuan_Client(model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
         elif model_type == ModelType.Unknown:
             raise ValueError(f"未知模型: {model_name}")
         logging.info(msg)
modules/presets.py CHANGED
@@ -68,6 +68,10 @@ ONLINE_MODELS = [
     "gpt-4-32k",
     "gpt-4-32k-0314",
     "xmchat",
+    "yuanai-1.0-base_10B",
+    "yuanai-1.0-translate",
+    "yuanai-1.0-dialog",
+    "yuanai-1.0-rhythm_poems",
 ]
 
 LOCAL_MODELS = [
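
For reference, a small hedged sketch of how these new ONLINE_MODELS entries map onto Yuan engine names (mirroring the prefix strip in Yuan_Client.get_answer_at_once):

    for model_name in ["yuanai-1.0-base_10B", "yuanai-1.0-translate",
                       "yuanai-1.0-dialog", "yuanai-1.0-rhythm_poems"]:
        engine = model_name.replace("yuanai-1.0-", "")   # engine passed to Yuan(engine=...)
        print(model_name, "->", engine)                  # base_10B, translate, dialog, rhythm_poems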