import json
import subprocess
import sys
from typing import Dict, List

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Install accelerate dynamically through the current interpreter (though it is
# better to pre-install it in the container image or a requirements.txt).
subprocess.check_call([sys.executable, "-m", "pip", "install", "accelerate"])

# Imported after the dynamic install above; dispatch_model is no longer needed
# because the weights are loaded directly onto the inferred device map below.
from accelerate import infer_auto_device_map, init_empty_weights


# Format chat messages using Qwen's chat template.
def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
    """
    Format chat messages using Qwen's chat template.
    """
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
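
# For illustration (a sketch, not called by SageMaker): with a message list
# like the one below, format_chat renders the ChatML-style prompt defined by
# Qwen's chat template, ending with the generation prompt for the assistant
# turn.
#
#   example_messages = [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "What does this endpoint do?"},
#   ]
#   prompt = format_chat(example_messages, tokenizer)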


# Model loading function for SageMaker. Accelerate's device map gives
# layer-wise model parallelism across the 4 GPUs (each decoder block lives on
# one GPU) rather than true tensor parallelism.
def model_fn(model_dir, context=None):
    """
    Load the model and tokenizer from the model directory for inference,
    sharding the model across 4 GPUs.
    """
    # Build a weightless skeleton on the meta device so the device map can be
    # inferred without materializing the full 72B parameters in CPU RAM.
    config = AutoConfig.from_pretrained(model_dir)
    with init_empty_weights():
        empty_model = AutoModelForCausalLM.from_config(config)

    # Infer a device map across 4 GPUs. The no_split_module_classes entry must
    # match the model's decoder-layer class ("QWenBlock" for first-generation
    # Qwen; Qwen2 models use "Qwen2DecoderLayer") so a block is never split
    # across devices.
    device_map = infer_auto_device_map(
        empty_model,
        max_memory={i: "25GiB" for i in range(4)},
        no_split_module_classes=["QWenBlock"],
    )

    # Load the real weights directly onto the mapped devices. (Loading inside
    # init_empty_weights and then calling dispatch_model would dispatch a
    # weightless meta model.)
    model = AutoModelForCausalLM.from_pretrained(
        model_dir, torch_dtype="auto", device_map=device_map
    )
    model.eval()

    # Load the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    return model, tokenizer


# Custom predict function for SageMaker.
def predict_fn(input_data, model_and_tokenizer):
    """
    Generate an OpenAI-style chat completion for the input data.
    """
    try:
        model, tokenizer = model_and_tokenizer
        data = json.loads(input_data)

        # Format the prompt using Qwen's chat template.
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages, tokenizer)

        # Tokenize the input and place it on GPU 0, where the inferred device
        # map keeps the embedding layer.
        inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0")
        prompt_tokens = inputs["input_ids"].shape[1]

        # Generate output; passing the attention mask along with the input ids
        # avoids the padding warning from generate().
        outputs = model.generate(
            **inputs,
            max_new_tokens=data.get("max_new_tokens", 512),
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            repetition_penalty=data.get("repetition_penalty", 1.0),
            length_penalty=data.get("length_penalty", 1.0),
            do_sample=True,
        )

        # Decode only the newly generated tokens; generate() returns the
        # prompt and the completion concatenated.
        completion_ids = outputs[0][prompt_tokens:]
        generated_text = tokenizer.decode(completion_ids, skip_special_tokens=True)
        completion_tokens = len(completion_ids)

        # Build an OpenAI-style chat-completion response.
        response = {
            "id": "chatcmpl-uuid",
            "object": "chat.completion",
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text,
                },
                "finish_reason": "stop",
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }
        return response
    except Exception as e:
        return {"error": str(e), "details": repr(e)}


# Define input format for SageMaker.
def input_fn(serialized_input_data, content_type):
    """
    Validate and pass through the request body; predict_fn parses the JSON.
    """
    if content_type != "application/json":
        raise ValueError(f"Unsupported content type: {content_type}")
    return serialized_input_data


# Define output format for SageMaker.
def output_fn(prediction_output, accept):
    """
    Serialize the model output to a JSON response.
    """
    return json.dumps(prediction_output)
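

if __name__ == "__main__":
    # Local smoke test (a sketch, not part of the SageMaker contract): chains
    # the handlers the way the inference toolkit would at runtime. MODEL_DIR
    # is an assumed environment variable pointing at a local copy of the
    # model artifacts; it falls back to SageMaker's default model path.
    import os

    model_and_tokenizer = model_fn(os.environ.get("MODEL_DIR", "/opt/ml/model"))
    request_body = json.dumps({
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_new_tokens": 64,
    })
    data = input_fn(request_body, "application/json")
    prediction = predict_fn(data, model_and_tokenizer)
    print(output_fn(prediction, "application/json"))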