johnsmith253325 committed
Commit d3fb4a3 · 1 Parent(s): 06281ff

feat: Added Gemma support (needs hf auth token)

config_example.json CHANGED
@@ -17,6 +17,7 @@
     "claude_api_secret":"",// Your Claude API Secret, used for the Claude chat model
     "ernie_api_key": "",// Your ERNIE Bot API Key from Baidu Cloud, used for the ERNIE Bot chat model
     "ernie_secret_key": "",// Your ERNIE Bot Secret Key from Baidu Cloud, used for the ERNIE Bot chat model
+    "huggingface_auth_token": "", // Your Hugging Face API Token, used to access gated models


     //== Azure ==
modules/config.py CHANGED
@@ -116,6 +116,10 @@ google_genai_api_key = config.get("google_palm_api_key", google_genai_api_key)
 google_genai_api_key = config.get("google_genai_api_key", google_genai_api_key)
 os.environ["GOOGLE_GENAI_API_KEY"] = google_genai_api_key

+huggingface_auth_token = os.environ.get("HF_AUTH_TOKEN", "")
+huggingface_auth_token = config.get("hf_auth_token", huggingface_auth_token)
+os.environ["HF_AUTH_TOKEN"] = huggingface_auth_token
+
 xmchat_api_key = config.get("xmchat_api_key", "")
 os.environ["XMCHAT_API_KEY"] = xmchat_api_key

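With this change the Hugging Face token defaults to the HF_AUTH_TOKEN environment variable, can be overridden by an hf_auth_token entry in config.json, and is written back to the environment so the rest of the app only has to read HF_AUTH_TOKEN. As a minimal sketch (not part of the commit) of how downstream code can authenticate against the Hub with the resolved token; login() is the standard huggingface_hub helper:

import os

from huggingface_hub import login

token = os.environ.get("HF_AUTH_TOKEN", "")
if token:
    # Register the token with huggingface_hub so gated repos such as
    # google/gemma-2b-it can be fetched by from_pretrained()/hf_hub_download().
    login(token=token)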
 
modules/models/GoogleGemma.py ADDED
@@ -0,0 +1,101 @@
+import logging
+from threading import Thread
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+from ..presets import *
+from .base_model import BaseLLMModel
+
+
+class GoogleGemmaClient(BaseLLMModel):
+    def __init__(self, model_name, api_key, user_name="") -> None:
+        super().__init__(model_name=model_name, user=user_name)
+
+        global GEMMA_TOKENIZER, GEMMA_MODEL
+        # self.deinitialize()
+        self.max_generation_token = self.token_upper_limit
+        if GEMMA_TOKENIZER is None or GEMMA_MODEL is None:
+            model_path = None
+            if os.path.exists("models"):
+                model_dirs = os.listdir("models")
+                if model_name in model_dirs:
+                    model_path = f"models/{model_name}"
+            if model_path is not None:
+                model_source = model_path
+            else:
+                if os.path.exists(
+                    os.path.join("models", MODEL_METADATA[model_name]["model_name"])
+                ):
+                    model_source = os.path.join(
+                        "models", MODEL_METADATA[model_name]["model_name"]
+                    )
+                else:
+                    try:
+                        model_source = MODEL_METADATA[model_name]["repo_id"]
+                    except:
+                        model_source = model_name
+            dtype = torch.bfloat16
+            GEMMA_TOKENIZER = AutoTokenizer.from_pretrained(
+                model_source, use_auth_token=os.environ["HF_AUTH_TOKEN"]
+            )
+            GEMMA_MODEL = AutoModelForCausalLM.from_pretrained(
+                model_source,
+                device_map="auto",
+                torch_dtype=dtype,
+                trust_remote_code=True,
+                resume_download=True,
+                use_auth_token=os.environ["HF_AUTH_TOKEN"],
+            )
+
+    def deinitialize(self):
+        global GEMMA_TOKENIZER, GEMMA_MODEL
+        GEMMA_TOKENIZER = None
+        GEMMA_MODEL = None
+        self.clear_cuda_cache()
+        logging.info("GEMMA deinitialized")
+
+    def _get_gemma_style_input(self):
+        global GEMMA_TOKENIZER
+        # messages = [{"role": "system", "content": self.system_prompt}, *self.history]  # system prompt is not supported
+        messages = self.history
+        prompt = GEMMA_TOKENIZER.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        inputs = GEMMA_TOKENIZER.encode(
+            prompt, add_special_tokens=True, return_tensors="pt"
+        )
+        return inputs
+
+    def get_answer_at_once(self):
+        global GEMMA_TOKENIZER, GEMMA_MODEL
+        inputs = self._get_gemma_style_input()
+        outputs = GEMMA_MODEL.generate(
+            input_ids=inputs.to(GEMMA_MODEL.device),
+            max_new_tokens=self.max_generation_token,
+        )
+        generated_token_count = outputs.shape[1] - inputs.shape[1]
+        outputs = GEMMA_TOKENIZER.decode(outputs[0], skip_special_tokens=True)
+        outputs = outputs.split("<start_of_turn>model\n")[-1][:-5]
+        self.clear_cuda_cache()
+        return outputs, generated_token_count
+
+    def get_answer_stream_iter(self):
+        global GEMMA_TOKENIZER, GEMMA_MODEL
+        inputs = self._get_gemma_style_input()
+        streamer = TextIteratorStreamer(
+            GEMMA_TOKENIZER, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+        )
+        input_kwargs = dict(
+            input_ids=inputs.to(GEMMA_MODEL.device),
+            max_new_tokens=self.max_generation_token,
+            streamer=streamer,
+        )
+        t = Thread(target=GEMMA_MODEL.generate, kwargs=input_kwargs)
+        t.start()
+
+        partial_text = ""
+        for new_text in streamer:
+            partial_text += new_text
+            yield partial_text
+        self.clear_cuda_cache()
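The new client caches the tokenizer and model in the module-level GEMMA_TOKENIZER / GEMMA_MODEL globals declared in presets.py, builds prompts with Gemma's chat template, and streams by running generate() on a background thread behind a TextIteratorStreamer. Roughly the same pattern can be exercised standalone as in the sketch below. This is illustrative rather than project code: the repo id, prompt, and max_new_tokens value are assumptions, HF_AUTH_TOKEN must already be set, and use_auth_token mirrors the kwarg used above (newer transformers releases call it token).

import os
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

repo_id = "google/gemma-2b-it"  # same repo MODEL_METADATA maps "Gemma 2B" to
auth = os.environ["HF_AUTH_TOKEN"]

tokenizer = AutoTokenizer.from_pretrained(repo_id, use_auth_token=auth)
model = AutoModelForCausalLM.from_pretrained(
    repo_id, device_map="auto", torch_dtype=torch.bfloat16, use_auth_token=auth
)

# Gemma's chat template does not support a system role (see the comment in _get_gemma_style_input).
messages = [{"role": "user", "content": "Describe the Matterhorn in one sentence."}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt")

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(
    target=model.generate,
    kwargs=dict(input_ids=inputs.to(model.device), max_new_tokens=256, streamer=streamer),
).start()

for piece in streamer:  # pieces arrive as soon as tokens are generated
    print(piece, end="", flush=True)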
modules/models/LLaMA.py CHANGED
@@ -2,14 +2,12 @@ from __future__ import annotations

 import json
 import os
-
-from huggingface_hub import hf_hub_download
 from llama_cpp import Llama

 from ..index_func import *
 from ..presets import *
 from ..utils import *
-from .base_model import BaseLLMModel
+from .base_model import BaseLLMModel, download

 SYS_PREFIX = "<<SYS>>\n"
 SYS_POSTFIX = "\n<</SYS>>\n\n"
@@ -19,34 +17,6 @@ OUTPUT_PREFIX = "[/INST] "
 OUTPUT_POSTFIX = "</s>"


-def download(repo_id, filename, retry=10):
-    if os.path.exists("./models/downloaded_models.json"):
-        with open("./models/downloaded_models.json", "r") as f:
-            downloaded_models = json.load(f)
-        if repo_id in downloaded_models:
-            return downloaded_models[repo_id]["path"]
-    else:
-        downloaded_models = {}
-    while retry > 0:
-        try:
-            model_path = hf_hub_download(
-                repo_id=repo_id,
-                filename=filename,
-                cache_dir="models",
-                resume_download=True,
-            )
-            downloaded_models[repo_id] = {"path": model_path}
-            with open("./models/downloaded_models.json", "w") as f:
-                json.dump(downloaded_models, f)
-            break
-        except:
-            print("Error downloading model, retrying...")
-            retry -= 1
-    if retry == 0:
-        raise Exception("Error downloading model, please try again later.")
-    return model_path
-
-
 class LLaMA_Client(BaseLLMModel):
     def __init__(self, model_name, lora_path=None, user_name="") -> None:
         super().__init__(model_name=model_name, user=user_name)
@@ -115,7 +85,7 @@ class LLaMA_Client(BaseLLMModel):
         iter = self.model(
             context,
             max_tokens=self.max_generation_token,
-            stop=[SYS_PREFIX, SYS_POSTFIX, INST_PREFIX, OUTPUT_PREFIX,OUTPUT_POSTFIX],
+            stop=[SYS_PREFIX, SYS_POSTFIX, INST_PREFIX, OUTPUT_PREFIX, OUTPUT_POSTFIX],
             echo=False,
             stream=True,
         )
modules/models/base_model.py CHANGED
@@ -1,43 +1,41 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING, List

-import logging
+import asyncio
+import gc
 import json
-import commentjson as cjson
+import logging
 import os
-import sys
-import requests
-import urllib3
-import traceback
 import pathlib
 import shutil
+import sys
+import traceback
+from collections import deque
+from enum import Enum
+from itertools import islice
+from threading import Condition, Thread
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

-from tqdm import tqdm
+import aiohttp
 import colorama
+import commentjson as cjson
+import requests
+import torch
+import urllib3
 from duckduckgo_search import DDGS
-from itertools import islice
-import asyncio
-import aiohttp
-from enum import Enum
-
+from huggingface_hub import hf_hub_download
+from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain.callbacks.base import BaseCallbackManager
-
-from typing import Any, Dict, List, Optional, Union
-
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.input import print_text
-from langchain.schema import AgentAction, AgentFinish, LLMResult
-from threading import Thread, Condition
-from collections import deque
 from langchain.chat_models.base import BaseChatModel
-from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
+from langchain.input import print_text
+from langchain.schema import (AgentAction, AgentFinish, AIMessage, BaseMessage,
+                              HumanMessage, LLMResult, SystemMessage)
+from tqdm import tqdm

-from ..presets import *
-from ..index_func import *
-from ..utils import *
 from .. import shared
 from ..config import retrieve_proxy
+from ..index_func import *
+from ..presets import *
+from ..utils import *


 class CallbackToIterator:
@@ -155,6 +153,7 @@ class ModelType(Enum):
     ERNIE = 17
     DALLE3 = 18
     GoogleGemini = 19
+    GoogleGemma = 20

     @classmethod
     def get_type(cls, model_name: str):
@@ -201,11 +200,41 @@ class ModelType(Enum):
             model_type = ModelType.ERNIE
         elif "dall" in model_name_lower:
             model_type = ModelType.DALLE3
+        elif "gemma" in model_name_lower:
+            model_type = ModelType.GoogleGemma
         else:
             model_type = ModelType.LLaMA
         return model_type


+def download(repo_id, filename, retry=10):
+    if os.path.exists("./models/downloaded_models.json"):
+        with open("./models/downloaded_models.json", "r") as f:
+            downloaded_models = json.load(f)
+        if repo_id in downloaded_models:
+            return downloaded_models[repo_id]["path"]
+    else:
+        downloaded_models = {}
+    while retry > 0:
+        try:
+            model_path = hf_hub_download(
+                repo_id=repo_id,
+                filename=filename,
+                cache_dir="models",
+                resume_download=True,
+            )
+            downloaded_models[repo_id] = {"path": model_path}
+            with open("./models/downloaded_models.json", "w") as f:
+                json.dump(downloaded_models, f)
+            break
+        except:
+            print("Error downloading model, retrying...")
+            retry -= 1
+    if retry == 0:
+        raise Exception("Error downloading model, please try again later.")
+    return model_path
+
+
 class BaseLLMModel:
     def __init__(
         self,
@@ -371,10 +400,10 @@ class BaseLLMModel:
         status = i18n("总结完成")
         logging.info(i18n("生成内容总结中……"))
         os.environ["OPENAI_API_KEY"] = self.api_key
+        from langchain.callbacks import StdOutCallbackHandler
         from langchain.chains.summarize import load_summarize_chain
-        from langchain.prompts import PromptTemplate
         from langchain.chat_models import ChatOpenAI
-        from langchain.callbacks import StdOutCallbackHandler
+        from langchain.prompts import PromptTemplate

         prompt_template = (
             "Write a concise summary of the following:\n\n{text}\n\nCONCISE SUMMARY IN "
@@ -1055,6 +1084,10 @@ class BaseLLMModel:
         """deinitialize the model, implement if needed"""
         pass

+    def clear_cuda_cache(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+

 class Base_Chat_Langchain_Client(BaseLLMModel):
     def __init__(self, model_name, user_name=""):
modules/models/models.py CHANGED
@@ -138,8 +138,14 @@ def get_model(
             from .DALLE3 import OpenAI_DALLE3_Client
             access_key = os.environ.get("OPENAI_API_KEY", access_key)
             model = OpenAI_DALLE3_Client(model_name, api_key=access_key, user_name=user_name)
+        elif model_type == ModelType.GoogleGemma:
+            from .GoogleGemma import GoogleGemmaClient
+            model = GoogleGemmaClient(
+                model_name, access_key, user_name=user_name)
         elif model_type == ModelType.Unknown:
-            raise ValueError(f"未知模型: {model_name}")
+            raise ValueError(f"Unknown model: {model_name}")
+        else:
+            raise ValueError(f"Unimplemented model type: {model_type}")
         logging.info(msg)
     except Exception as e:
         import traceback
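With this dispatch, any display name containing "gemma" (for example the new "Gemma 2B" / "Gemma 7B" entries in presets.py) is routed to GoogleGemmaClient. A rough way to exercise the client outside the Gradio UI is sketched below; instantiating it directly and hand-filling history are illustrative shortcuts (the app normally drives it through its own chat flow), and the placeholder token must be replaced with a real one.

import os

os.environ.setdefault("HF_AUTH_TOKEN", "hf_xxx")  # placeholder; normally set by modules/config.py

from modules.models.GoogleGemma import GoogleGemmaClient

client = GoogleGemmaClient("Gemma 2B", "", user_name="demo")
client.history = [{"role": "user", "content": "Say hello in one short sentence."}]
for partial in client.get_answer_stream_iter():
    print(partial)  # each item is the accumulated reply so far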
modules/presets.py CHANGED
@@ -10,6 +10,8 @@ CHATGLM_MODEL = None
 CHATGLM_TOKENIZER = None
 LLAMA_MODEL = None
 LLAMA_INFERENCER = None
+GEMMA_MODEL = None
+GEMMA_TOKENIZER = None

 # ChatGPT settings
 INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
@@ -67,6 +69,8 @@ ONLINE_MODELS = [
     "Gemini Pro",
     "Gemini Pro Vision",
     "GooglePaLM",
+    "Gemma 2B",
+    "Gemma 7B",
     "xmchat",
     "Azure OpenAI",
     "yuanai-1.0-base_10B",
@@ -178,6 +182,16 @@ MODEL_METADATA = {
     "Gemini Pro Vision": {
         "model_name": "gemini-pro-vision",
         "token_limit": 30720,
+    },
+    "Gemma 2B": {
+        "repo_id": "google/gemma-2b-it",
+        "model_name": "gemma-2b-it",
+        "token_limit": 8192,
+    },
+    "Gemma 7B": {
+        "repo_id": "google/gemma-7b-it",
+        "model_name": "gemma-7b-it",
+        "token_limit": 8192,
     }
 }

@@ -193,7 +207,12 @@ os.makedirs("lora", exist_ok=True)
 os.makedirs("history", exist_ok=True)
 for dir_name in os.listdir("models"):
     if os.path.isdir(os.path.join("models", dir_name)):
-        if dir_name not in MODELS:
+        display_name = None
+        for model_name, metadata in MODEL_METADATA.items():
+            if "model_name" in metadata and metadata["model_name"] == dir_name:
+                display_name = model_name
+                break
+        if display_name is None:
             MODELS.append(dir_name)

 TOKEN_OFFSET = 1000  # Subtracting this value from the model's token limit gives a soft cap; once the soft cap is reached, the app automatically tries to reduce token usage.