import json
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict
# Global variables to persist the model and tokenizer between invocations
model = None
tokenizer = None


# Format chat messages using Qwen's chat template
def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Model loading function for SageMaker: multi-GPU dispatch of the FP8-quantized checkpoint
def model_fn(model_dir, context=None):
    global model, tokenizer
    if model is None:
        print("Loading the FP8 quantized model and tokenizer...")

        # Offload directory for any layers that do not fit in GPU memory
        offload_dir = "/tmp/offload_dir"
        os.makedirs(offload_dir, exist_ok=True)

        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_dir)

        # Load the FP8 quantized model. device_map="auto" (backed by accelerate)
        # shards whole decoder layers across the available GPUs and offloads any
        # overflow to offload_dir, so no separate dispatch step is required.
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype="auto",  # the FP8 quantization config saved with the checkpoint sets the weight dtypes
            low_cpu_mem_usage=True,
            device_map="auto",
            offload_folder=offload_dir,
        )
    return model, tokenizer
# Custom predict function for SageMaker
def predict_fn(input_data, model_and_tokenizer):
    try:
        model, tokenizer = model_and_tokenizer
        data = json.loads(input_data)

        # Format the prompt using Qwen's chat template
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages, tokenizer)

        # Tokenize the input and move it to the model's device
        inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)
        prompt_tokens = inputs["input_ids"].shape[1]

        # Generate output
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                max_new_tokens=data.get("max_new_tokens", 512),
                temperature=data.get("temperature", 0.7),
                top_p=data.get("top_p", 0.9),
                repetition_penalty=data.get("repetition_penalty", 1.0),
                length_penalty=data.get("length_penalty", 1.0),
                do_sample=True,
            )

        # Decode only the newly generated tokens so the prompt is not echoed back
        generated_text = tokenizer.decode(outputs[0][prompt_tokens:], skip_special_tokens=True)

        # Build an OpenAI-style chat completion response
        response = {
            "id": "chatcmpl-fp8-quantized",
            "object": "chat.completion",
            "model": "qwen-72b-fp8",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": len(outputs[0]) - prompt_tokens,
                "total_tokens": len(outputs[0])
            }
        }
        return response
    except Exception as e:
        return {"error": str(e), "details": repr(e)}
# Define input format for SageMaker
def input_fn(serialized_input_data, content_type, context=None):
    """
    Prepare the input data for inference.
    """
    return serialized_input_data
# Define output format for SageMaker
def output_fn(prediction_output, accept, context=None):
    """
    Convert the model output to a JSON response.
    """
    return json.dumps(prediction_output)
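For reference, here is a minimal client-side sketch of how the deployed endpoint could be invoked with this handler. The endpoint name and the example messages are placeholders; the payload simply mirrors what predict_fn expects (a "messages" list plus optional sampling settings), and the response is the OpenAI-style dict built above, serialized by output_fn.

import json
import boto3

# Hypothetical endpoint name -- substitute whatever name was used at deployment time.
ENDPOINT_NAME = "qwen-72b-fp8-endpoint"

runtime = boto3.client("sagemaker-runtime")

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize what FP8 quantization does."},
    ],
    "max_new_tokens": 256,
    "temperature": 0.7,
}

# The JSON body flows through input_fn -> predict_fn -> output_fn on the endpoint
result = runtime.invoke_endpoint(
    EndpointName=ENDPOINT_NAME,
    ContentType="application/json",
    Body=json.dumps(payload),
)

completion = json.loads(result["Body"].read())
print(completion["choices"][0]["message"]["content"])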