import os

import gradio as gr
import spaces
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

# Log in to the Hugging Face Hub (the token is expected in the HF_API_TOKEN secret)
api_token = os.environ.get("HF_API_TOKEN")
if api_token:
    login(api_token)

# Model loading helper
def get_llm(model_id):
    # `device_map="auto"` places the weights on the available device(s) automatically
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    return model

# Question-answering logic
@spaces.GPU(duration=120)
def retriever_qa(file, query):
    # Load the tokenizer; the model itself is loaded inside the GPU-decorated call
    model_id = 'mistralai/Mistral-7B-Instruct-v0.2'
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    # With @spaces.GPU, CUDA must be initialised inside the decorated function,
    # not at import time in the main process
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Device: {device}')

    # Helper that loads the model and runs inference
    def process_inference(file, query):
        # Load the model
        llm = get_llm(model_id)
        # Only the first line of the uploaded document is used as context
        with open(file, 'r') as f:
            first_line = f.readline()
        # Build the chat prompt: document context followed by the user query
        messages = [
            {"role": "user", "content": first_line + query}
        ]
        print(messages)
        # Tokenize the input; `apply_chat_template` returns a tensor of input ids here
        model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
        print(f"Model Inputs: {model_inputs}")
        print('Start inference')
        # Generate a response
        generated_ids = llm.generate(model_inputs, max_new_tokens=50, do_sample=True)
        print(f'Generated ids: {generated_ids}')
        # Decode only the newly generated tokens, skipping the echoed prompt
        print('Start detokenize')
        new_tokens = generated_ids[:, model_inputs.shape[-1]:]
        response = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]
        print(response)
        return response

    # Run the inference logic
    response = process_inference(file, query)
    return response

# Gradio interface
rag_application = gr.Interface(
    fn=retriever_qa,
    allow_flagging="never",
    inputs=[
        gr.File(label="Upload TXT File", file_count="single", file_types=['.txt'], type="filepath"),  # TXT files only
        gr.Textbox(label="Input Query", lines=2, placeholder="Type your question here...")  # query input box
    ],
    outputs=gr.Textbox(label="Output"),  # output display box
    title="RAG Chatbot",
    description="Upload a TXT document and ask any question. The chatbot will try to answer using the provided document."
)

# Launch the Gradio app
rag_application.launch(share=True)
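
# Usage note (a sketch, not part of the app itself): on a GPU machine with a
# Hugging Face token that has access to the gated Mistral weights, the script
# could be run as, e.g. (the file name and token value below are placeholders):
#
#   HF_API_TOKEN=<your-token> python app.py
#
# then open the printed URL, upload a .txt file, and type a question.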