homer7676 committed
Commit 11ffb79 · verified · 1 Parent(s): 9d029ab

Update handler.py

Files changed (1)
  1. handler.py +47 -92
handler.py CHANGED
@@ -1,63 +1,62 @@
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from typing import Dict, Any
-import re
 
 class EndpointHandler:
-    def __init__(self, model_dir: str = None):
-        self.model_dir = model_dir
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model = None
+    def __init__(self):
         self.tokenizer = None
+        self.model = None
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
-    def initialize(self, context: Dict[str, Any] = None):
-        """Initialize the model and tokenizer."""
-        model_id = "homer7676/FrierenChatbotV1"
-
-        # Initialize tokenizer
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        """Make the handler callable."""
+        inputs = self.preprocess(data)
+        outputs = self.inference(inputs)
+        return self.postprocess(outputs)
+
+    def initialize(self, context):
+        """Initialize the model and tokenizer."""
         self.tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            trust_remote_code=True,
-            padding_side="left"
+            "homer7676/FrierenChatbotV1",
+            trust_remote_code=True
         )
-
-        # Ensure pad token exists
-        if self.tokenizer.pad_token is None:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        # Initialize model
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_id,
+            "homer7676/FrierenChatbotV1",
             trust_remote_code=True,
-            torch_dtype="auto",
-            low_cpu_mem_usage=True
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
         ).to(self.device)
-
         self.model.eval()
-        return self
-
-    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        """Main prediction pipeline."""
-        inputs = self.preprocess(data)
-        outputs = self.inference(inputs)
-        return self.postprocess(outputs)
 
     def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """Preprocess the input data."""
-        if isinstance(data, str):
-            return {"message": data}
         inputs = data.pop("inputs", data)
-        return inputs if isinstance(inputs, dict) else {"message": inputs}
+        if not isinstance(inputs, dict):
+            inputs = {"message": inputs}
+        return inputs
 
     def inference(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         """Run the inference."""
         try:
-            # Prepare the input
             message = inputs.get("message", "")
             context = inputs.get("context", "")
-            prompt = self._build_prompt(context, message)
+            prompt = f"""你是芙莉蓮,需要遵守以下規則回答:
+1. 身份設定:
+- 千年精靈魔法師
+- 態度溫柔但帶著些許嘲諷
+- 說話優雅且有距離感
+2. 重要關係:
+- 弗蘭梅是我的師傅
+- 費倫是我的學生
+- 欣梅爾是我的摯友
+- 海塔是我的故友
+3. 回答規則:
+- 使用繁體中文
+- 必須提供具體詳細的內容
+- 保持回答的連貫性和完整性
+相關資訊:{context}
+用戶:{message}
+芙莉蓮:"""
 
-            # Tokenize
             inputs = self.tokenizer(
                 prompt,
                 return_tensors="pt",
@@ -66,72 +65,28 @@ class EndpointHandler:
                 max_length=2048
             ).to(self.device)
 
-            # Generate
             with torch.no_grad():
-                generation_output = self.model.generate(
-                    input_ids=inputs["input_ids"],
-                    attention_mask=inputs["attention_mask"],
+                outputs = self.model.generate(
+                    **inputs,
                     max_new_tokens=256,
                     temperature=0.7,
                     top_p=0.9,
                     top_k=50,
                     do_sample=True,
+                    repetition_penalty=1.2,
                     pad_token_id=self.tokenizer.pad_token_id,
-                    eos_token_id=self.tokenizer.eos_token_id,
-                    repetition_penalty=1.2
+                    eos_token_id=self.tokenizer.eos_token_id
                 )
 
-            response = self.tokenizer.decode(
-                generation_output[0],
-                skip_special_tokens=True
-            )
-
-            # Process the response
+            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
             response = response.split("芙莉蓮:")[-1].strip()
-            response = self._process_response(response)
 
-            return {"response": response}
-        except Exception as e:
-            return {"error": f"Inference error: {str(e)}"}
-
-    def _build_prompt(self, context: str, query: str) -> str:
-        """Build the prompt for the model."""
-        return f"""你是芙莉蓮,需要遵守以下規則回答:
-1. 身份設定:
-- 千年精靈魔法師
-- 態度溫柔但帶著些許嘲諷
-- 說話優雅且有距離感
-2. 重要關係:
-- 弗蘭梅是我的師傅
-- 費倫是我的學生
-- 欣梅爾是我的摯友
-- 海塔是我的故友
-3. 回答規則:
-- 使用繁體中文
-- 必須提供具體詳細的內容
-- 保持回答的連貫性和完整性
-相關資訊:{context}
-用戶:{query}
-芙莉蓮:"""
-
-    def _process_response(self, response: str) -> str:
-        """Process the model's response."""
-        if not response or not response.strip():
-            return "抱歉,我現在有點恍神,請你再問一次好嗎?"
-
-        # Convert to traditional Chinese
-        for simplified, traditional in SIMPLIFIED_TO_TRADITIONAL.items():
-            response = response.replace(simplified, traditional)
-
-        # Clean up whitespace
-        response = re.sub(r'\s+', '', response)
-
-        # Add ending punctuation if needed
-        if not response.endswith(('。', '!', '?', '~', '呢', '啊', '吶')):
-            response += '呢。'
-
-        return response
+            return {"generated_text": response}
+
+        except Exception as e:
+            print(f"Inference error: {str(e)}")
+            return {"error": str(e)}
 
     def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """Postprocess the output data."""
         return data
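Note on the refactor: the previous revision guarded against a missing pad token before generation, and that guard did not survive the rewrite, even though generate() is still passed pad_token_id=self.tokenizer.pad_token_id. If the tokenizer ships without a pad token, that argument will be None. A minimal fix, lifted verbatim from the removed code, would be to re-add the fallback inside initialize() right after the tokenizer is loaded:

# Sketch: restore the pad-token fallback from the previous revision.
# Assumes the tokenizer defines an eos_token, as causal LM tokenizers normally do.
if self.tokenizer.pad_token is None:
    self.tokenizer.pad_token = self.tokenizer.eos_token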
 
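For a quick local smoke test of the committed handler, a minimal driver along the following lines should work. This is a sketch, not part of the commit: the example message is illustrative, it assumes the homer7676/FrierenChatbotV1 weights are reachable, and it calls initialize() manually because nothing else in this file does.

# Sketch: local smoke test for the refactored EndpointHandler.
from handler import EndpointHandler

handler = EndpointHandler()
handler.initialize(None)  # the context argument is accepted but unused here

payload = {"inputs": {"message": "你還記得欣梅爾嗎?", "context": ""}}
result = handler(payload)  # __call__ chains preprocess -> inference -> postprocess
print(result.get("generated_text", result.get("error")))

Clients consuming this endpoint should also note that the commit renames the success key from "response" to "generated_text".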