import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict
from accelerate import load_checkpoint_and_dispatch
import fcntl
import os
import time
# Global variables to persist the model and tokenizer between invocations
model = None
tokenizer = None


def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
    """
    Format chat messages using Qwen's chat template.
    """
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
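
# Illustrative usage (the message contents here are made up):
#   format_chat([{"role": "user", "content": "Hello"}], tokenizer)
# returns a single prompt string rendered by the tokenizer's chat template,
# ending with the assistant turn marker so generation continues as the assistant.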


def model_fn(model_dir, context=None):
    """
    Load the model and tokenizer once per worker, coordinating with other
    workers via a file lock so only one process loads at a time.
    """
    global model, tokenizer

    lock_file = "/tmp/model_load.lock"
    in_progress_file = "/tmp/model_loading_in_progress"

    # Reuse the already-loaded model within this worker process.
    if model is not None and tokenizer is not None:
        print("Model and tokenizer already loaded, skipping reload.")
        return model, tokenizer

    with open(lock_file, 'w') as lock:
        print("Attempting to acquire model load lock...")
        fcntl.flock(lock, fcntl.LOCK_EX)
        try:
            if os.path.exists(in_progress_file):
                print("Another worker is currently loading the model, waiting...")
                while os.path.exists(in_progress_file):
                    time.sleep(5)
                print("Loading complete by another worker.")
                if model is not None and tokenizer is not None:
                    return model, tokenizer

            print("Proceeding to load the model and tokenizer.")
            with open(in_progress_file, 'w') as f:
                f.write("loading")

            print("Loading the model and tokenizer...")
            offload_dir = "/tmp/offload_dir"
            os.makedirs(offload_dir, exist_ok=True)

            # Load the tokenizer first
            tokenizer = AutoTokenizer.from_pretrained(model_dir)

            # Load and dispatch model across GPUs
            model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype="auto")
            model = load_checkpoint_and_dispatch(
                model,
                model_dir,
                device_map="auto",
                offload_folder=offload_dir,
                # Per-GPU memory cap; assumes 8 visible GPUs with ~24 GiB each.
                max_memory={i: "24GiB" for i in range(8)}
            )

            print("Model and tokenizer loaded successfully.")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise
        finally:
            if os.path.exists(in_progress_file):
                os.remove(in_progress_file)
            fcntl.flock(lock, fcntl.LOCK_UN)

    return model, tokenizer


def predict_fn(input_data, model_and_tokenizer, context=None):
    """
    Generate predictions for the input data.
    """
    try:
        model, tokenizer = model_and_tokenizer
        if model is None or tokenizer is None:
            raise ValueError("Model or tokenizer is None. Please ensure they are loaded correctly.")

        data = json.loads(input_data)
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages, tokenizer)

        inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0")
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=data.get("max_new_tokens", 512),
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            repetition_penalty=data.get("repetition_penalty", 1.0),
            length_penalty=data.get("length_penalty", 1.0),
            do_sample=True
        )

        # generate() returns the prompt followed by the new tokens; decode only
        # the newly generated portion so the response does not echo the prompt.
        prompt_tokens = inputs['input_ids'].shape[1]
        generated_tokens = outputs[0][prompt_tokens:]
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        response = {
            "id": "chatcmpl-uuid",
            "object": "chat.completion",
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": len(generated_tokens),
                "total_tokens": prompt_tokens + len(generated_tokens)
            }
        }
        return response
    except Exception as e:
        return {"error": str(e), "details": repr(e)}


def input_fn(serialized_input_data, content_type, context=None):
    # Pass the raw request body through; predict_fn parses the JSON itself.
    return serialized_input_data


def output_fn(prediction_output, accept, context=None):
    return json.dumps(prediction_output)
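

# Local smoke test (a sketch, not one of the hosting entry points). It mimics
# how a SageMaker-style container chains the handlers:
# model_fn -> input_fn -> predict_fn -> output_fn. The model directory below
# is a placeholder assumption; point it at a real Qwen checkpoint to run this.
if __name__ == "__main__":
    artifacts = model_fn("/opt/ml/model")
    payload = json.dumps({
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_new_tokens": 64
    })
    body = input_fn(payload, "application/json")
    prediction = predict_fn(body, artifacts)
    print(output_fn(prediction, "application/json"))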