Spaces: Running
Tuchuanhuhuhu committed
Commit 921af92 · Parent(s): 70118ca
Added support for Groq, the super fast inference service.
- config_example.json +1 -0
- modules/config.py +3 -0
- modules/models/Groq.py +55 -0
- modules/models/base_model.py +3 -0
- modules/models/models.py +4 -0
- modules/presets.py +25 -0
- requirements.txt +1 -0
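The new client is a thin wrapper around the official groq Python SDK. As a quick orientation before the per-file diffs, here is a minimal, standalone sketch of the underlying call it relies on; the prompt text and the choice of llama3-8b-8192 are placeholders for illustration, not part of the commit:

import os
from groq import Groq

# Reads the key that modules/config.py now exports as GROQ_API_KEY.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# One-shot completion, mirroring Groq_Client.get_answer_at_once().
chat_completion = client.chat.completions.create(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Why is Groq fast?"},
    ],
    model="llama3-8b-8192",  # one of the model IDs registered in presets.py
)
print(chat_completion.choices[0].message.content)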
config_example.json
CHANGED
@@ -19,6 +19,7 @@
     "ernie_secret_key": "", // Your ERNIE Bot Secret Key from Baidu Cloud, used for the ERNIE Bot chat model
     "ollama_host": "", // Your Ollama host, used for the Ollama chat model
     "huggingface_auth_token": "", // Your Hugging Face API token, used to access gated models
+    "groq_api_key": "", // Your Groq API key, used for Groq chat models (https://console.groq.com/)

     //== Azure ==
     "openai_api_type": "openai", // Options: azure, openai
modules/config.py
CHANGED
@@ -158,6 +158,9 @@ os.environ["ERNIE_SECRETKEY"] = ernie_secret_key
 ollama_host = config.get("ollama_host", "")
 os.environ["OLLAMA_HOST"] = ollama_host

+groq_api_key = config.get("groq_api_key", "")
+os.environ["GROQ_API_KEY"] = groq_api_key
+
 load_config_to_environ(["openai_api_type", "azure_openai_api_key", "azure_openai_api_base_url",
                         "azure_openai_api_version", "azure_deployment_name", "azure_embedding_deployment_name", "azure_embedding_model_name"])

modules/models/Groq.py
ADDED
@@ -0,0 +1,55 @@
+import json
+import logging
+import textwrap
+import uuid
+
+import os
+from groq import Groq
+import gradio as gr
+import PIL
+import requests
+
+from modules.presets import i18n
+
+from ..index_func import construct_index
+from ..utils import count_token, construct_system
+from .base_model import BaseLLMModel
+
+
+class Groq_Client(BaseLLMModel):
+    def __init__(self, model_name, api_key, user_name="") -> None:
+        super().__init__(model_name=model_name, user=user_name)
+        self.api_key = api_key
+        self.client = Groq(
+            api_key=os.environ.get("GROQ_API_KEY"),
+        )
+
+    def _get_groq_style_input(self):
+        messages = [construct_system(self.system_prompt), *self.history]
+        return messages
+
+    def get_answer_at_once(self):
+        messages = self._get_groq_style_input()
+        chat_completion = self.client.chat.completions.create(
+            messages=messages,
+            model=self.model_name,
+        )
+        return chat_completion.choices[0].message.content, chat_completion.usage.total_tokens
+
+
+    def get_answer_stream_iter(self):
+        messages = self._get_groq_style_input()
+        completion = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=messages,
+            temperature=self.temperature,
+            max_tokens=self.max_generation_token,
+            top_p=self.top_p,
+            stream=True,
+            stop=self.stop_sequence,
+        )
+
+        partial_text = ""
+        for chunk in completion:
+            partial_text += chunk.choices[0].delta.content or ""
+            yield partial_text
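For reference, a hedged sketch of the streaming path as it would behave when driven directly through the groq SDK, mirroring get_answer_stream_iter(); the prompt and the sampling values below are illustrative, since the app supplies temperature, top_p, and token limits from its own settings:

import os
from groq import Groq

client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# Streaming completion: each chunk carries a delta, which the client
# accumulates into partial_text, just as the new class does.
stream = client.chat.completions.create(
    model="llama3-70b-8192",
    messages=[{"role": "user", "content": "Stream a short haiku."}],
    temperature=0.7,
    max_tokens=256,
    top_p=1.0,
    stream=True,
)

partial_text = ""
for chunk in stream:
    partial_text += chunk.choices[0].delta.content or ""
print(partial_text)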
modules/models/base_model.py
CHANGED
@@ -155,6 +155,7 @@ class ModelType(Enum):
     GoogleGemini = 19
     GoogleGemma = 20
     Ollama = 21
+    Groq = 22

     @classmethod
     def get_type(cls, model_name: str):
@@ -173,6 +174,8 @@ class ModelType(Enum):
             model_type = ModelType.OpenAI
         elif "chatglm" in model_name_lower:
             model_type = ModelType.ChatGLM
+        elif "groq" in model_name_lower:
+            model_type = ModelType.Groq
         elif "ollama" in model_name_lower:
             model_type = ModelType.Ollama
         elif "llama" in model_name_lower or "alpaca" in model_name_lower:
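Dispatch is substring-based, so any display name containing "groq" (case-insensitively) resolves to the new enum member. A small illustration, assuming the repository root is on the import path; note that the new branch sits before the generic "llama" check, which matters because names like "Groq LLaMA3 8B" also contain "llama":

from modules.models.base_model import ModelType

# "Groq LLaMA3 8B" lowercases to a string containing "groq", so the
# elif added above matches before the "llama"/"alpaca" branch.
assert ModelType.get_type("Groq LLaMA3 8B") == ModelType.Groq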
modules/models/models.py
CHANGED
@@ -61,6 +61,10 @@ def get_model(
             logging.info(f"Loading ChatGLM model: {model_name}")
             from .ChatGLM import ChatGLM_Client
             model = ChatGLM_Client(model_name, user_name=user_name)
+        elif model_type == ModelType.Groq:
+            logging.info(f"Loading Groq model: {model_name}")
+            from .Groq import Groq_Client
+            model = Groq_Client(model_name, access_key, user_name=user_name)
         elif model_type == ModelType.LLaMA and lora_model_path == "":
             msg = f"Please select a LoRA model for {model_name}"
             logging.info(msg)
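In effect, selecting a Groq model makes the factory construct the new client directly. A hedged sketch of what the added branch does, assuming the app's environment (a valid GROQ_API_KEY and a resolved model ID from MODEL_METADATA); get_model's full signature is not shown in this hunk, so the snippet calls Groq_Client itself:

from modules.models.Groq import Groq_Client

# Equivalent of the new branch: access_key is forwarded as the second
# positional argument (api_key), and user_name is passed through.
model = Groq_Client("llama3-8b-8192", "", user_name="demo")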
modules/presets.py
CHANGED
@@ -70,6 +70,11 @@ ONLINE_MODELS = [
     "DALL-E 3",
     "Gemini Pro",
     "Gemini Pro Vision",
+    "Groq LLaMA3 8B",
+    "Groq LLaMA3 70B",
+    "Groq LLaMA2 70B",
+    "Groq Mixtral 8x7B",
+    "Groq Gemma 7B",
     "GooglePaLM",
     "Gemma 2B",
     "Gemma 7B",
@@ -218,6 +223,26 @@ MODEL_METADATA = {
         "repo_id": "google/gemma-7b-it",
         "model_name": "gemma-7b-it",
         "token_limit": 8192,
+    },
+    "Groq LLaMA3 8B": {
+        "model_name": "llama3-8b-8192",
+        "token_limit": 8192,
+    },
+    "Groq LLaMA3 70B": {
+        "model_name": "llama3-70b-8192",
+        "token_limit": 8192,
+    },
+    "Groq LLaMA2 70B": {
+        "model_name": "llama2-70b-4096",
+        "token_limit": 4096,
+    },
+    "Groq Mixtral 8x7B": {
+        "model_name": "mixtral-8x7b-32768",
+        "token_limit": 32768,
+    },
+    "Groq Gemma 7B": {
+        "model_name": "gemma-7b-it",
+        "token_limit": 8192,
     }
 }

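The display names added to ONLINE_MODELS resolve to Groq's served model IDs and context limits through MODEL_METADATA; a small lookup example, assuming the table is importable as modules.presets:

from modules.presets import MODEL_METADATA

meta = MODEL_METADATA["Groq Mixtral 8x7B"]
print(meta["model_name"])   # mixtral-8x7b-32768
print(meta["token_limit"])  # 32768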
requirements.txt
CHANGED
@@ -12,6 +12,7 @@ langchain==0.1.14
 langchain-openai
 langchainhub
 langchain_community
+groq
 markdown
 PyPDF2
 pdfplumber