InspirationYF commited on
Commit
d62ccf6
·
1 Parent(s): 7ca365c

feat: add env config

Browse files
Files changed (2) hide show
  1. app.py +8 -3
  2. env_config.py +5 -0
app.py CHANGED
@@ -5,8 +5,13 @@ import gradio as gr
5
  from huggingface_hub import login
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
 
 
 
 
 
 
 
8
  # 登录 Hugging Face API
9
- api_token = os.environ.get("HF_API_TOKEN")
10
  login(api_token)
11
 
12
  # 模型加载函数
@@ -19,7 +24,7 @@ def get_llm(model_id):
19
  @spaces.GPU(duration=120)
20
  def retriever_qa(file, query):
21
  # 加载模型和分词器
22
- model_id = 'mistralai/Mistral-7B-Instruct-v0.2'
23
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
24
 
25
  # 确保 CUDA 初始化不在主线程
@@ -47,7 +52,7 @@ def retriever_qa(file, query):
47
  print('Start Inference')
48
 
49
  # 推理
50
- generated_ids = llm.generate(model_inputs, max_new_tokens=50, do_sample=True)
51
  # generated_ids = llm.generate(input_ids=model_inputs['input_ids'], attention_mask=model_inputs['attention_mask'], max_new_tokens=50, do_sample=True)
52
  print(f'Generated ids: {generated_ids}')
53
  # 解码输出
 
5
  from huggingface_hub import login
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
 
8
+ import env_config
9
+
10
+ api_token = env_config.HF_API_TOKEN
11
+ max_new_tokens = env_config.MAX_NEW_TOKENS
12
+ model_id = env_config.MODEL_ID
13
+
14
  # 登录 Hugging Face API
 
15
  login(api_token)
16
 
17
  # 模型加载函数
 
24
  @spaces.GPU(duration=120)
25
  def retriever_qa(file, query):
26
  # 加载模型和分词器
27
+ # model_id = 'mistralai/Mistral-7B-Instruct-v0.2'
28
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
29
 
30
  # 确保 CUDA 初始化不在主线程
 
52
  print('Start Inference')
53
 
54
  # 推理
55
+ generated_ids = llm.generate(model_inputs, max_new_tokens=max_new_tokens, do_sample=True)
56
  # generated_ids = llm.generate(input_ids=model_inputs['input_ids'], attention_mask=model_inputs['attention_mask'], max_new_tokens=50, do_sample=True)
57
  print(f'Generated ids: {generated_ids}')
58
  # 解码输出
env_config.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import os
2
+
3
+ HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
4
+ MAX_NEW_TOKENS = os.environ.get("MAX_NEW_TOKENS", 1024)
5
+ MODEL_ID = os.environ.get("MODEL_ID", "mistralai/Mistral-7B-Instruct-v0.2")