Tuchuanhuhuhu committed on
Commit 921af92 · 1 Parent(s): 70118ca

Added support for Groq, the super-fast inference service.

config_example.json CHANGED
@@ -19,6 +19,7 @@
     "ernie_secret_key": "", // Your ERNIE Bot Secret Key from Baidu Cloud, used for the ERNIE Bot chat model
     "ollama_host": "", // Your Ollama host, used for Ollama chat models
     "huggingface_auth_token": "", // Your Hugging Face API token, used to access gated models
+    "groq_api_key": "", // Your Groq API key, used for Groq chat models (https://console.groq.com/)
 
     //== Azure ==
     "openai_api_type": "openai", // Options: azure, openai
modules/config.py CHANGED
@@ -158,6 +158,9 @@ os.environ["ERNIE_SECRETKEY"] = ernie_secret_key
 ollama_host = config.get("ollama_host", "")
 os.environ["OLLAMA_HOST"] = ollama_host
 
+groq_api_key = config.get("groq_api_key", "")
+os.environ["GROQ_API_KEY"] = groq_api_key
+
 load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
                         "azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])
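
For context, exporting GROQ_API_KEY here is what lets the new Groq client below work without threading the key through explicitly: the groq SDK falls back to that environment variable when no api_key argument is supplied. A minimal sketch of that fallback (the key value is a placeholder):

import os
from groq import Groq

os.environ["GROQ_API_KEY"] = "gsk_..."  # placeholder; normally set from config.json as above
client = Groq()  # the SDK reads GROQ_API_KEY when api_key is omitted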
modules/models/Groq.py ADDED
@@ -0,0 +1,55 @@
+import json
+import logging
+import textwrap
+import uuid
+
+import os
+from groq import Groq
+import gradio as gr
+import PIL
+import requests
+
+from modules.presets import i18n
+
+from ..index_func import construct_index
+from ..utils import count_token, construct_system
+from .base_model import BaseLLMModel
+
+
+class Groq_Client(BaseLLMModel):
+    def __init__(self, model_name, api_key, user_name="") -> None:
+        super().__init__(model_name=model_name, user=user_name)
+        self.api_key = api_key
+        self.client = Groq(
+            api_key=os.environ.get("GROQ_API_KEY"),
+        )
+
+    def _get_groq_style_input(self):
+        messages = [construct_system(self.system_prompt), *self.history]
+        return messages
+
+    def get_answer_at_once(self):
+        messages = self._get_groq_style_input()
+        chat_completion = self.client.chat.completions.create(
+            messages=messages,
+            model=self.model_name,
+        )
+        return chat_completion.choices[0].message.content, chat_completion.usage.total_tokens
+
+
+    def get_answer_stream_iter(self):
+        messages = self._get_groq_style_input()
+        completion = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=messages,
+            temperature=self.temperature,
+            max_tokens=self.max_generation_token,
+            top_p=self.top_p,
+            stream=True,
+            stop=self.stop_sequence,
+        )
+
+        partial_text = ""
+        for chunk in completion:
+            partial_text += chunk.choices[0].delta.content or ""
+            yield partial_text
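
Two details are worth noting. The constructor stores its api_key argument but actually builds the client from the GROQ_API_KEY environment variable that modules/config.py exports, and get_answer_stream_iter yields the accumulated reply after each chunk rather than the raw delta, which is the shape the chat UI expects. A standalone sketch of the same streaming pattern, assuming a valid GROQ_API_KEY is set (the model id matches the MODEL_METADATA entry added below):

import os
from groq import Groq

client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
stream = client.chat.completions.create(
    model="llama3-8b-8192",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ],
    stream=True,
)

text = ""
for chunk in stream:
    # delta.content can be None on the final chunk, hence the `or ""` guard
    # mirrored in Groq_Client above.
    text += chunk.choices[0].delta.content or ""
print(text)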
modules/models/base_model.py CHANGED
@@ -155,6 +155,7 @@ class ModelType(Enum):
     GoogleGemini = 19
     GoogleGemma = 20
     Ollama = 21
+    Groq = 22
 
     @classmethod
     def get_type(cls, model_name: str):
@@ -173,6 +174,8 @@ class ModelType(Enum):
             model_type = ModelType.OpenAI
         elif "chatglm" in model_name_lower:
             model_type = ModelType.ChatGLM
+        elif "groq" in model_name_lower:
+            model_type = ModelType.Groq
         elif "ollama" in model_name_lower:
            model_type = ModelType.Ollama
         elif "llama" in model_name_lower or "alpaca" in model_name_lower:
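
The position of the new branch matters: "groq" is tested before "ollama" and "llama", so display names containing both keywords still resolve to the Groq backend. A quick illustration (a sketch, assuming the module path above and that get_type returns the matched member):

from modules.models.base_model import ModelType

assert ModelType.get_type("Groq LLaMA3 70B") == ModelType.Groq  # "groq" wins over "llama"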
modules/models/models.py CHANGED
@@ -61,6 +61,10 @@ def get_model(
         logging.info(f"Loading ChatGLM model: {model_name}")
         from .ChatGLM import ChatGLM_Client
         model = ChatGLM_Client(model_name, user_name=user_name)
+    elif model_type == ModelType.Groq:
+        logging.info(f"Loading Groq model: {model_name}")
+        from .Groq import Groq_Client
+        model = Groq_Client(model_name, access_key, user_name=user_name)
     elif model_type == ModelType.LLaMA and lora_model_path == "":
         msg = f"Please select a LoRA model for {model_name}"
         logging.info(msg)
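
Note that access_key is forwarded as Groq_Client's api_key argument, although, as noted above, the client itself reads GROQ_API_KEY from the environment. A sketch of the construction this branch performs (hypothetical values; GROQ_API_KEY must already be exported, which modules/config.py does at startup):

from modules.models.Groq import Groq_Client

model = Groq_Client("llama3-8b-8192", api_key="", user_name="demo")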
modules/presets.py CHANGED
@@ -70,6 +70,11 @@ ONLINE_MODELS = [
     "DALL-E 3",
     "Gemini Pro",
     "Gemini Pro Vision",
+    "Groq LLaMA3 8B",
+    "Groq LLaMA3 70B",
+    "Groq LLaMA2 70B",
+    "Groq Mixtral 8x7B",
+    "Groq Gemma 7B",
     "GooglePaLM",
     "Gemma 2B",
     "Gemma 7B",
@@ -218,6 +223,26 @@ MODEL_METADATA = {
         "repo_id": "google/gemma-7b-it",
         "model_name": "gemma-7b-it",
         "token_limit": 8192,
+    },
+    "Groq LLaMA3 8B": {
+        "model_name": "llama3-8b-8192",
+        "token_limit": 8192,
+    },
+    "Groq LLaMA3 70B": {
+        "model_name": "llama3-70b-8192",
+        "token_limit": 8192,
+    },
+    "Groq LLaMA2 70B": {
+        "model_name": "llama2-70b-4096",
+        "token_limit": 4096,
+    },
+    "Groq Mixtral 8x7B": {
+        "model_name": "mixtral-8x7b-32768",
+        "token_limit": 32768,
+    },
+    "Groq Gemma 7B": {
+        "model_name": "gemma-7b-it",
+        "token_limit": 8192,
     }
 }
 
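
Each new entry maps a UI display name to the model id that is sent to the Groq API, together with the model's context window, presumably used for token accounting. A lookup sketch (names as in MODEL_METADATA above):

from modules.presets import MODEL_METADATA

meta = MODEL_METADATA["Groq Mixtral 8x7B"]
print(meta["model_name"], meta["token_limit"])  # mixtral-8x7b-32768 32768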
requirements.txt CHANGED
@@ -12,6 +12,7 @@ langchain==0.1.14
 langchain-openai
 langchainhub
 langchain_community
+groq
 markdown
 PyPDF2
 pdfplumber