homer7676 committed
Commit 9d029ab · verified · 1 Parent(s): 4de6589

Update handler.py
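Refactor EndpointHandler: drop the hard-coded model_dir fallback from __init__, pin the model id inside initialize() and load with torch_dtype="auto" and low_cpu_mem_usage=True, tokenize with left-sided padding, pass explicit pad/eos token ids to generate() and drop beam search in favour of pure sampling, add a __call__ pipeline (preprocess -> inference -> postprocess), and inline the _convert_to_traditional helper into _process_response.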

Files changed (1): handler.py (+70 -59)
handler.py CHANGED
@@ -3,90 +3,99 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 from typing import Dict, Any
 import re
 
-SIMPLIFIED_TO_TRADITIONAL = {
-    '发': '發', '书': '書', '记': '記', '亚': '亞', '欧': '歐', '韩': '韓', '边': '邊',
-    '恒': '恆', '说': '說', '话': '話', '东': '東', '车': '車', '马': '馬', '样': '樣',
-    '风': '風', '专': '專', '万': '萬', '劳': '勞', '动': '動', '习': '習', '头': '頭',
-    '们': '們', '为': '為', '产': '產', '场': '場', '实': '實', '观': '觀', '见': '見',
-    '师': '師', '长': '長', '识': '識', '电': '電', '图': '圖', '华': '華', '龙': '龍',
-    '变': '變', '问': '問', '岁': '歲', '义': '義', '还': '還', '报': '報', '乐': '樂',
-    '欢': '歡', '权': '權', '态': '態', '极': '極', '环': '環', '带': '帶', '难': '難'
-}
-
 class EndpointHandler:
-    def __init__(self, model_dir=None):
-        self.tokenizer = None
-        self.model = None
+    def __init__(self, model_dir: str = None):
+        self.model_dir = model_dir
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model_dir = model_dir if model_dir else "homer7676/FrierenChatbotV1"
+        self.model = None
+        self.tokenizer = None
 
-    def initialize(self, context):
-        try:
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                self.model_dir,
-                trust_remote_code=True
-            )
-
-            if self.tokenizer.pad_token is None:
-                self.tokenizer.pad_token = self.tokenizer.eos_token
-
-            self.model = AutoModelForCausalLM.from_pretrained(
-                self.model_dir,
-                trust_remote_code=True,
-                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
-            ).to(self.device)
-
-            self.model.eval()
-
-        except Exception as e:
-            print(f"Model loading error: {str(e)}")
-            raise
+    def initialize(self, context: Dict[str, Any] = None):
+        """Initialize the model and tokenizer."""
+        model_id = "homer7676/FrierenChatbotV1"
+
+        # Initialize tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            trust_remote_code=True,
+            padding_side="left"
+        )
+
+        # Ensure pad token exists
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        # Initialize model
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            trust_remote_code=True,
+            torch_dtype="auto",
+            low_cpu_mem_usage=True
+        ).to(self.device)
+
+        self.model.eval()
+        return self
+
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        """Main prediction pipeline."""
+        inputs = self.preprocess(data)
+        outputs = self.inference(inputs)
+        return self.postprocess(outputs)
 
     def preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        """Preprocess the input data."""
+        if isinstance(data, str):
+            return {"message": data}
         inputs = data.pop("inputs", data)
-        if not isinstance(inputs, dict):
-            inputs = {"message": inputs}
-        return inputs
+        return inputs if isinstance(inputs, dict) else {"message": inputs}
 
     def inference(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Run the inference."""
         try:
+            # Prepare the inputs
             message = inputs.get("message", "")
            context = inputs.get("context", "")
            prompt = self._build_prompt(context, message)
-
-            encoding = self.tokenizer(
+
+            # Tokenize
+            inputs = self.tokenizer(
                 prompt,
                 return_tensors="pt",
-                add_special_tokens=True,
+                padding=True,
                 truncation=True,
-                max_length=2048,
-                padding=True
+                max_length=2048
             ).to(self.device)
-
+
+            # Generate
             with torch.no_grad():
-                outputs = self.model.generate(
-                    input_ids=encoding["input_ids"],
-                    attention_mask=encoding["attention_mask"],
+                generation_output = self.model.generate(
+                    input_ids=inputs["input_ids"],
+                    attention_mask=inputs["attention_mask"],
                     max_new_tokens=256,
                     temperature=0.7,
                     top_p=0.9,
                     top_k=50,
                     do_sample=True,
-                    repetition_penalty=1.2,
-                    num_beams=4,
-                    early_stopping=True
+                    pad_token_id=self.tokenizer.pad_token_id,
+                    eos_token_id=self.tokenizer.eos_token_id,
+                    repetition_penalty=1.2
                 )
 
-            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            response = self.tokenizer.decode(
+                generation_output[0],
+                skip_special_tokens=True
+            )
+
+            # Process the response
             response = response.split("芙莉蓮:")[-1].strip()
             response = self._process_response(response)
-            return {"response": response}
 
+            return {"response": response}
         except Exception as e:
-            print(f"Inference error: {str(e)}")
-            return {"response": "抱歉,在處理您的請求時發生了錯誤。請稍後再試。", "error": str(e)}
+            return {"error": f"Inference error: {str(e)}"}
 
     def _build_prompt(self, context: str, query: str) -> str:
+        """Build the prompt for the model."""
         return f"""你是芙莉蓮,需要遵守以下規則回答:
 1. 身份設定:
 - 千年精靈魔法師
@@ -105,22 +114,24 @@ class EndpointHandler:
 用戶:{query}
 芙莉蓮:"""
 
-    def _convert_to_traditional(self, text: str) -> str:
-        for simplified, traditional in SIMPLIFIED_TO_TRADITIONAL.items():
-            text = text.replace(simplified, traditional)
-        return text
-
     def _process_response(self, response: str) -> str:
+        """Process the model's response."""
         if not response or not response.strip():
             return "抱歉,我現在有點恍神,請你再問一次好嗎?"
 
-        response = self._convert_to_traditional(response)
+        # Convert to traditional Chinese
+        for simplified, traditional in SIMPLIFIED_TO_TRADITIONAL.items():
+            response = response.replace(simplified, traditional)
+
+        # Clean up whitespace
         response = re.sub(r'\s+', '', response)
 
+        # Add ending punctuation if needed
         if not response.endswith(('。', '!', '?', '~', '呢', '啊', '吶')):
             response += '呢。'
 
         return response
 
     def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        """Postprocess the output data."""
         return data
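Note: the updated _process_response still iterates over SIMPLIFIED_TO_TRADITIONAL, but this commit deletes the module-level table that defined it, so the new handler raises a NameError the first time it post-processes a non-empty response. A minimal fix, assuming the mapping is still wanted, is to reinstate the deleted table above the class (abridged here; the full set of pairs is in the removed lines above):

SIMPLIFIED_TO_TRADITIONAL = {
    '发': '發', '书': '書', '记': '記', '说': '說', '话': '話', '东': '東',
    # ... remaining simplified-to-traditional pairs from the deleted table
}

Relatedly, inference() rebinds its inputs argument to the tokenizer output; that works, but a distinct name (like the old encoding) would keep the original payload available in the except branch.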
 
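For reference, a minimal smoke test of the updated handler, assuming the file is saved as handler.py on the import path and torch/transformers are installed (the payload shape {"inputs": {...}} and the initialize()/__call__ flow are taken from the diff above):

from handler import EndpointHandler

handler = EndpointHandler()
handler.initialize()  # downloads homer7676/FrierenChatbotV1 on first use

# __call__ chains preprocess -> inference -> postprocess
result = handler({"inputs": {"message": "你好,芙莉蓮", "context": ""}})
print(result.get("response") or result.get("error"))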