Commit d3fb4a3 · committed by johnsmith253325
Parent(s): 06281ff

feat: Added Gemma support (needs hf auth token)
Changed files:
- config_example.json            +1   -0
- modules/config.py              +4   -0
- modules/models/GoogleGemma.py  +101 -0
- modules/models/LLaMA.py        +2   -32
- modules/models/base_model.py   +61  -28
- modules/models/models.py       +7   -1
- modules/presets.py             +20  -1
config_example.json
CHANGED
@@ -17,6 +17,7 @@
     "claude_api_secret":"",// Your Claude API Secret, used for the Claude chat model
     "ernie_api_key": "",// Your ERNIE Bot API Key on Baidu Cloud, used for the ERNIE Bot chat model
     "ernie_secret_key": "",// Your ERNIE Bot Secret Key on Baidu Cloud, used for the ERNIE Bot chat model
+    "huggingface_auth_token": "", // Your Hugging Face API Token, used to access gated/restricted models
 
 
     //== Azure ==
modules/config.py
CHANGED
@@ -116,6 +116,10 @@ google_genai_api_key = config.get("google_palm_api_key", google_genai_api_key)
 google_genai_api_key = config.get("google_genai_api_key", google_genai_api_key)
 os.environ["GOOGLE_GENAI_API_KEY"] = google_genai_api_key
 
+huggingface_auth_token = os.environ.get("HF_AUTH_TOKEN", "")
+huggingface_auth_token = config.get("hf_auth_token", huggingface_auth_token)
+os.environ["HF_AUTH_TOKEN"] = huggingface_auth_token
+
 xmchat_api_key = config.get("xmchat_api_key", "")
 os.environ["XMCHAT_API_KEY"] = xmchat_api_key
 
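The token is resolved with the same precedence as the other keys in modules/config.py: the HF_AUTH_TOKEN environment variable is read first, an hf_auth_token entry in config.json overrides it when present, and the result is exported back into os.environ so the model loaders can pick it up. (Note that modules/config.py reads the key hf_auth_token, while config_example.json documents huggingface_auth_token, so both spellings appear in this commit.) A minimal, self-contained sketch of that precedence, with a hypothetical dict standing in for the parsed config.json:

    import os

    # Hypothetical stand-in for the parsed config.json.
    config = {"hf_auth_token": "hf_token_from_config"}

    # 1. Environment variable first, falling back to an empty string.
    huggingface_auth_token = os.environ.get("HF_AUTH_TOKEN", "")
    # 2. A value from config.json, when present, takes precedence.
    huggingface_auth_token = config.get("hf_auth_token", huggingface_auth_token)
    # 3. Export the resolved value so loaders can read os.environ["HF_AUTH_TOKEN"].
    os.environ["HF_AUTH_TOKEN"] = huggingface_auth_token

    print(os.environ["HF_AUTH_TOKEN"])  # -> hf_token_from_config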
modules/models/GoogleGemma.py
ADDED
@@ -0,0 +1,101 @@
+import logging
+from threading import Thread
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+from ..presets import *
+from .base_model import BaseLLMModel
+
+
+class GoogleGemmaClient(BaseLLMModel):
+    def __init__(self, model_name, api_key, user_name="") -> None:
+        super().__init__(model_name=model_name, user=user_name)
+
+        global GEMMA_TOKENIZER, GEMMA_MODEL
+        # self.deinitialize()
+        self.max_generation_token = self.token_upper_limit
+        if GEMMA_TOKENIZER is None or GEMMA_MODEL is None:
+            model_path = None
+            if os.path.exists("models"):
+                model_dirs = os.listdir("models")
+                if model_name in model_dirs:
+                    model_path = f"models/{model_name}"
+            if model_path is not None:
+                model_source = model_path
+            else:
+                if os.path.exists(
+                    os.path.join("models", MODEL_METADATA[model_name]["model_name"])
+                ):
+                    model_source = os.path.join(
+                        "models", MODEL_METADATA[model_name]["model_name"]
+                    )
+                else:
+                    try:
+                        model_source = MODEL_METADATA[model_name]["repo_id"]
+                    except:
+                        model_source = model_name
+            dtype = torch.bfloat16
+            GEMMA_TOKENIZER = AutoTokenizer.from_pretrained(
+                model_source, use_auth_token=os.environ["HF_AUTH_TOKEN"]
+            )
+            GEMMA_MODEL = AutoModelForCausalLM.from_pretrained(
+                model_source,
+                device_map="auto",
+                torch_dtype=dtype,
+                trust_remote_code=True,
+                resume_download=True,
+                use_auth_token=os.environ["HF_AUTH_TOKEN"],
+            )
+
+    def deinitialize(self):
+        global GEMMA_TOKENIZER, GEMMA_MODEL
+        GEMMA_TOKENIZER = None
+        GEMMA_MODEL = None
+        self.clear_cuda_cache()
+        logging.info("GEMMA deinitialized")
+
+    def _get_gemma_style_input(self):
+        global GEMMA_TOKENIZER
+        # messages = [{"role": "system", "content": self.system_prompt}, *self.history]  # system prompt is not supported
+        messages = self.history
+        prompt = GEMMA_TOKENIZER.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        inputs = GEMMA_TOKENIZER.encode(
+            prompt, add_special_tokens=True, return_tensors="pt"
+        )
+        return inputs
+
+    def get_answer_at_once(self):
+        global GEMMA_TOKENIZER, GEMMA_MODEL
+        inputs = self._get_gemma_style_input()
+        outputs = GEMMA_MODEL.generate(
+            input_ids=inputs.to(GEMMA_MODEL.device),
+            max_new_tokens=self.max_generation_token,
+        )
+        generated_token_count = outputs.shape[1] - inputs.shape[1]
+        outputs = GEMMA_TOKENIZER.decode(outputs[0], skip_special_tokens=True)
+        outputs = outputs.split("<start_of_turn>model\n")[-1][:-5]
+        self.clear_cuda_cache()
+        return outputs, generated_token_count
+
+    def get_answer_stream_iter(self):
+        global GEMMA_TOKENIZER, GEMMA_MODEL
+        inputs = self._get_gemma_style_input()
+        streamer = TextIteratorStreamer(
+            GEMMA_TOKENIZER, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+        )
+        input_kwargs = dict(
+            input_ids=inputs.to(GEMMA_MODEL.device),
+            max_new_tokens=self.max_generation_token,
+            streamer=streamer,
+        )
+        t = Thread(target=GEMMA_MODEL.generate, kwargs=input_kwargs)
+        t.start()
+
+        partial_text = ""
+        for new_text in streamer:
+            partial_text += new_text
+            yield partial_text
+        self.clear_cuda_cache()
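For orientation, a rough usage sketch of the new client. This is illustrative only: in the app the client is constructed through get_model() in modules/models/models.py, and the shape of self.history is managed by BaseLLMModel, so the direct construction and the history assignment below are assumptions:

    import os

    from modules.models.GoogleGemma import GoogleGemmaClient

    # Assumes modules/config.py has already exported the token.
    os.environ.setdefault("HF_AUTH_TOKEN", "hf_your_token_here")  # placeholder token

    client = GoogleGemmaClient("Gemma 2B", api_key="", user_name="demo")
    client.history = [{"role": "user", "content": "Hello, who are you?"}]

    # Blocking call: full reply plus the number of generated tokens.
    reply, n_new_tokens = client.get_answer_at_once()

    # Streaming call: yields progressively longer partial replies.
    for partial in client.get_answer_stream_iter():
        print(partial, end="\r")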
modules/models/LLaMA.py
CHANGED
@@ -2,14 +2,12 @@ from __future__ import annotations
 
 import json
 import os
-
-from huggingface_hub import hf_hub_download
 from llama_cpp import Llama
 
 from ..index_func import *
 from ..presets import *
 from ..utils import *
-from .base_model import BaseLLMModel
+from .base_model import BaseLLMModel, download
 
 SYS_PREFIX = "<<SYS>>\n"
 SYS_POSTFIX = "\n<</SYS>>\n\n"
@@ -19,34 +17,6 @@ OUTPUT_PREFIX = "[/INST] "
 OUTPUT_POSTFIX = "</s>"
 
 
-def download(repo_id, filename, retry=10):
-    if os.path.exists("./models/downloaded_models.json"):
-        with open("./models/downloaded_models.json", "r") as f:
-            downloaded_models = json.load(f)
-        if repo_id in downloaded_models:
-            return downloaded_models[repo_id]["path"]
-    else:
-        downloaded_models = {}
-    while retry > 0:
-        try:
-            model_path = hf_hub_download(
-                repo_id=repo_id,
-                filename=filename,
-                cache_dir="models",
-                resume_download=True,
-            )
-            downloaded_models[repo_id] = {"path": model_path}
-            with open("./models/downloaded_models.json", "w") as f:
-                json.dump(downloaded_models, f)
-            break
-        except:
-            print("Error downloading model, retrying...")
-            retry -= 1
-    if retry == 0:
-        raise Exception("Error downloading model, please try again later.")
-    return model_path
-
-
 class LLaMA_Client(BaseLLMModel):
     def __init__(self, model_name, lora_path=None, user_name="") -> None:
         super().__init__(model_name=model_name, user=user_name)
@@ -115,7 +85,7 @@ class LLaMA_Client(BaseLLMModel):
         iter = self.model(
             context,
             max_tokens=self.max_generation_token,
-            stop=[SYS_PREFIX, SYS_POSTFIX, INST_PREFIX, OUTPUT_PREFIX,OUTPUT_POSTFIX],
+            stop=[SYS_PREFIX, SYS_POSTFIX, INST_PREFIX, OUTPUT_PREFIX, OUTPUT_POSTFIX],
             echo=False,
             stream=True,
         )
modules/models/base_model.py
CHANGED
@@ -1,43 +1,41 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING, List
 
-import logging
+import asyncio
+import gc
 import json
-import commentjson as cjson
+import logging
 import os
-import sys
-import requests
-import urllib3
-import traceback
 import pathlib
 import shutil
+import sys
+import traceback
+from collections import deque
+from enum import Enum
+from itertools import islice
+from threading import Condition, Thread
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
-from tqdm import tqdm
+import aiohttp
 import colorama
+import commentjson as cjson
+import requests
+import torch
+import urllib3
 from duckduckgo_search import DDGS
-from itertools import islice
-import asyncio
-import aiohttp
-from enum import Enum
-
+from huggingface_hub import hf_hub_download
+from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain.callbacks.base import BaseCallbackManager
-
-from typing import Any, Dict, List, Optional, Union
-
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.input import print_text
-from langchain.schema import AgentAction, AgentFinish, LLMResult
-from threading import Thread, Condition
-from collections import deque
 from langchain.chat_models.base import BaseChatModel
-from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
+from langchain.input import print_text
+from langchain.schema import (AgentAction, AgentFinish, AIMessage, BaseMessage,
+                              HumanMessage, LLMResult, SystemMessage)
+from tqdm import tqdm
 
-from ..presets import *
-from ..index_func import *
-from ..utils import *
 from .. import shared
 from ..config import retrieve_proxy
+from ..index_func import *
+from ..presets import *
+from ..utils import *
 
 
 class CallbackToIterator:
@@ -155,6 +153,7 @@ class ModelType(Enum):
     ERNIE = 17
     DALLE3 = 18
     GoogleGemini = 19
+    GoogleGemma = 20
 
     @classmethod
     def get_type(cls, model_name: str):
@@ -201,11 +200,41 @@ class ModelType(Enum):
             model_type = ModelType.ERNIE
         elif "dall" in model_name_lower:
             model_type = ModelType.DALLE3
+        elif "gemma" in model_name_lower:
+            model_type = ModelType.GoogleGemma
         else:
            model_type = ModelType.LLaMA
        return model_type
 
 
+def download(repo_id, filename, retry=10):
+    if os.path.exists("./models/downloaded_models.json"):
+        with open("./models/downloaded_models.json", "r") as f:
+            downloaded_models = json.load(f)
+        if repo_id in downloaded_models:
+            return downloaded_models[repo_id]["path"]
+    else:
+        downloaded_models = {}
+    while retry > 0:
+        try:
+            model_path = hf_hub_download(
+                repo_id=repo_id,
+                filename=filename,
+                cache_dir="models",
+                resume_download=True,
+            )
+            downloaded_models[repo_id] = {"path": model_path}
+            with open("./models/downloaded_models.json", "w") as f:
+                json.dump(downloaded_models, f)
+            break
+        except:
+            print("Error downloading model, retrying...")
+            retry -= 1
+    if retry == 0:
+        raise Exception("Error downloading model, please try again later.")
+    return model_path
+
+
 class BaseLLMModel:
     def __init__(
         self,
@@ -371,10 +400,10 @@ class BaseLLMModel:
             status = i18n("总结完成")
             logging.info(i18n("生成内容总结中……"))
             os.environ["OPENAI_API_KEY"] = self.api_key
+            from langchain.callbacks import StdOutCallbackHandler
             from langchain.chains.summarize import load_summarize_chain
-            from langchain.prompts import PromptTemplate
             from langchain.chat_models import ChatOpenAI
-            from langchain.callbacks import StdOutCallbackHandler
+            from langchain.prompts import PromptTemplate
 
             prompt_template = (
                 "Write a concise summary of the following:\n\n{text}\n\nCONCISE SUMMARY IN "
@@ -1055,6 +1084,10 @@ class BaseLLMModel:
         """deinitialize the model, implement if needed"""
         pass
 
+    def clear_cuda_cache(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
 
 class Base_Chat_Langchain_Client(BaseLLMModel):
     def __init__(self, model_name, user_name=""):
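The download() helper moved here from LLaMA.py so that any local-model client can reuse it: it caches resolved paths in ./models/downloaded_models.json and retries hf_hub_download up to retry times before raising. A hedged usage sketch (the repo id and filename below are placeholders, not something this commit itself downloads):

    from modules.models.base_model import download

    # Placeholder GGUF weights; any repo_id/filename accepted by hf_hub_download works.
    model_path = download(
        repo_id="TheBloke/Llama-2-7B-Chat-GGUF",
        filename="llama-2-7b-chat.Q4_K_M.gguf",
        retry=3,
    )
    print(model_path)  # local path under ./models, reused on later calls via the JSON cache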
modules/models/models.py
CHANGED
@@ -138,8 +138,14 @@ def get_model(
             from .DALLE3 import OpenAI_DALLE3_Client
             access_key = os.environ.get("OPENAI_API_KEY", access_key)
             model = OpenAI_DALLE3_Client(model_name, api_key=access_key, user_name=user_name)
+        elif model_type == ModelType.GoogleGemma:
+            from .GoogleGemma import GoogleGemmaClient
+            model = GoogleGemmaClient(
+                model_name, access_key, user_name=user_name)
         elif model_type == ModelType.Unknown:
-            raise ValueError(f"未知模型: {model_name}")
+            raise ValueError(f"Unknown model: {model_name}")
+        else:
+            raise ValueError(f"Unimplemented model type: {model_type}")
         logging.info(msg)
     except Exception as e:
         import traceback
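A quick sketch of how the new branch is selected (illustrative; the real call path goes through the web UI into get_model(), and the import assumes the app's package layout):

    from modules.models.base_model import ModelType

    # Any display name containing "gemma" now maps to the new enum member,
    # so both entries added to MODEL_METADATA resolve to GoogleGemmaClient.
    assert ModelType.get_type("Gemma 2B") == ModelType.GoogleGemma
    assert ModelType.get_type("Gemma 7B") == ModelType.GoogleGemma

    # Unrecognized names still fall back to ModelType.LLaMA, as before.
    assert ModelType.get_type("my-local-model") == ModelType.LLaMA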
modules/presets.py
CHANGED
@@ -10,6 +10,8 @@ CHATGLM_MODEL = None
 CHATGLM_TOKENIZER = None
 LLAMA_MODEL = None
 LLAMA_INFERENCER = None
+GEMMA_MODEL = None
+GEMMA_TOKENIZER = None
 
 # ChatGPT settings
 INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
@@ -67,6 +69,8 @@ ONLINE_MODELS = [
     "Gemini Pro",
     "Gemini Pro Vision",
     "GooglePaLM",
+    "Gemma 2B",
+    "Gemma 7B",
     "xmchat",
     "Azure OpenAI",
     "yuanai-1.0-base_10B",
@@ -178,6 +182,16 @@ MODEL_METADATA = {
     "Gemini Pro Vision": {
         "model_name": "gemini-pro-vision",
         "token_limit": 30720,
+    },
+    "Gemma 2B": {
+        "repo_id": "google/gemma-2b-it",
+        "model_name": "gemma-2b-it",
+        "token_limit": 8192,
+    },
+    "Gemma 7B": {
+        "repo_id": "google/gemma-7b-it",
+        "model_name": "gemma-7b-it",
+        "token_limit": 8192,
     }
 }
 
@@ -193,7 +207,12 @@ os.makedirs("lora", exist_ok=True)
 os.makedirs("history", exist_ok=True)
 for dir_name in os.listdir("models"):
     if os.path.isdir(os.path.join("models", dir_name)):
-        if dir_name not in MODELS:
+        display_name = None
+        for model_name, metadata in MODEL_METADATA.items():
+            if "model_name" in metadata and metadata["model_name"] == dir_name:
+                display_name = model_name
+                break
+        if display_name is None:
             MODELS.append(dir_name)
 
 TOKEN_OFFSET = 1000  # Subtract this from the model's token limit to get a soft cap; once the soft cap is reached, the app automatically tries to reduce token usage.
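The last hunk also changes how local checkouts under models/ are listed: a directory whose name matches a MODEL_METADATA entry's model_name (for example a downloaded gemma-2b-it) is no longer appended to MODELS under its raw directory name, since it is already reachable through its display name. A small standalone sketch of that lookup, with a trimmed-down metadata dict:

    # Trimmed-down stand-in for MODEL_METADATA in modules/presets.py.
    MODEL_METADATA = {
        "Gemma 2B": {"repo_id": "google/gemma-2b-it", "model_name": "gemma-2b-it", "token_limit": 8192},
        "Gemma 7B": {"repo_id": "google/gemma-7b-it", "model_name": "gemma-7b-it", "token_limit": 8192},
    }

    def display_name_for(dir_name):
        # Mirrors the new loop: map a local directory name back to a display name, if any.
        for model_name, metadata in MODEL_METADATA.items():
            if metadata.get("model_name") == dir_name:
                return model_name
        return None

    print(display_name_for("gemma-2b-it"))  # -> "Gemma 2B"  (skipped, already listed)
    print(display_name_for("my-finetune"))  # -> None        (still appended to MODELS)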