File size: 6,187 Bytes
7b79b7e 5106444 5c36d2f acb4176 939342b 5106444 40362ea 5106444 d952a77 5106444 325c145 5106444 939342b 5106444 939342b acb4176 5106444 acb4176 5106444 40362ea 5106444 6c47fa9 5106444 40362ea 5106444 40362ea 5106444 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict
from accelerate import load_checkpoint_and_dispatch
import fcntl # For file locking
import os # For file operations
import time # For sleep function
# Process-wide environment tweaks, applied at import time:
#   * max_split_size_mb is set globally to prevent memory fragmentation
#   * TORCH_DISTRIBUTED_DEBUG=DETAIL enables detailed distributed logs
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:64",
        "TORCH_DISTRIBUTED_DEBUG": "DETAIL",
    }
)
# Print the allocator setting back out to verify it was applied
print(f"PYTORCH_CUDA_ALLOC_CONF: {os.environ.get('PYTORCH_CUDA_ALLOC_CONF')}")
# Module-level cache so the model and tokenizer persist between invocations
model, tokenizer = None, None
# Render a list of chat messages into a single prompt string
def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
    """
    Build the prompt for a conversation via the tokenizer's chat template.

    The template is applied with tokenization turned off and the generation
    prompt appended; whatever ``apply_chat_template`` returns is handed back
    unchanged.
    """
    template_options = {"tokenize": False, "add_generation_prompt": True}
    prompt = tokenizer.apply_chat_template(messages, **template_options)
    return prompt
def model_fn(model_dir, context=None):
    """
    Load the model and tokenizer once per process and return them.

    The loaded pair is cached in the module-level ``model`` and ``tokenizer``
    globals, so later calls in the same process return immediately. Loading
    is serialized across workers with an exclusive ``flock`` on a lock file,
    which is held for the whole duration of the load.

    Module globals are not shared between processes, so a worker can never
    reuse a model loaded by a different worker: once it owns the lock it
    always performs its own load.

    Args:
        model_dir: Directory holding the model checkpoint and tokenizer files.
        context: Unused.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        Exception: Any error raised while loading is printed and re-raised.
    """
    global model, tokenizer

    # Path to lock file for ensuring only one worker loads at a time
    lock_file = "/tmp/model_load.lock"
    # Marker file that exists only while a load is running
    in_progress_file = "/tmp/model_loading_in_progress"

    if model is not None and tokenizer is not None:
        print("Model and tokenizer already loaded, skipping reload.")
        return model, tokenizer

    with open(lock_file, 'w') as lock:
        print("Attempting to acquire model load lock...")
        fcntl.flock(lock, fcntl.LOCK_EX)  # Exclusive lock, blocks until free
        try:
            # Another caller in this process may have finished loading while
            # we were blocked on the lock.
            if model is not None and tokenizer is not None:
                print("Model and tokenizer already loaded, skipping reload.")
                return model, tokenizer

            # The flag is created and removed while the exclusive lock is
            # held, so a flag that exists now cannot belong to a live loader:
            # it was left behind by a worker that died mid-load. Waiting on it
            # (while holding the lock) would hang every worker forever.
            if os.path.exists(in_progress_file):
                print("Found stale in-progress flag from an interrupted load, removing it.")
                os.remove(in_progress_file)

            print("No one is loading, proceeding to load the model.")
            with open(in_progress_file, 'w') as f:
                f.write("loading")

            print("Loading the model and tokenizer...")
            offload_dir = "/tmp/offload_dir"
            os.makedirs(offload_dir, exist_ok=True)

            # Load the weights and dispatch the model across all visible GPUs.
            # NOTE(review): from_pretrained builds the full model before
            # load_checkpoint_and_dispatch reads the checkpoint again; passing
            # device_map straight to from_pretrained would presumably avoid
            # the double read - confirm before changing.
            loaded_model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype="auto")
            loaded_model = load_checkpoint_and_dispatch(
                loaded_model,
                model_dir,
                device_map="balanced",  # Evenly distribute across GPUs
                offload_folder=offload_dir,
                max_memory={i: "18GiB" for i in range(torch.cuda.device_count())},  # 18 GiB per GPU
                # NOTE(review): this names a top-level model class rather than
                # a decoder-layer class, so it presumably does not constrain
                # where individual layers are split - confirm against the
                # checkpoint's architecture.
                no_split_module_classes=["QwenForCausalLM"]
            )

            # Load the tokenizer
            loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir)

            # Free up any unused memory after loading
            torch.cuda.empty_cache()

            # Publish only after both loads succeeded, so a failure never
            # leaves a half-initialized cache behind.
            model, tokenizer = loaded_model, loaded_tokenizer
            return model, tokenizer
        except Exception as e:
            print(f"Error loading model and tokenizer: {e}")
            raise
        finally:
            # Remove the in-progress flag whether the load succeeded or failed
            try:
                os.remove(in_progress_file)
            except FileNotFoundError:
                pass
            # Release the lock
            fcntl.flock(lock, fcntl.LOCK_UN)
# Custom predict function for SageMaker
def predict_fn(input_data, model_and_tokenizer, context=None):
    """
    Generate a chat completion for one request.

    Args:
        input_data: JSON document with a ``messages`` list and optional
            ``max_new_tokens``, ``temperature``, ``top_p``,
            ``repetition_penalty`` and ``length_penalty`` settings.
        model_and_tokenizer: ``(model, tokenizer)`` pair.
        context: Unused.

    Returns:
        A dict shaped like a ``chat.completion`` response. The assistant
        message contains only the newly generated text, and ``usage`` reports
        prompt and completion token counts separately. If anything fails, a
        dict with ``error`` and ``details`` keys is returned instead of
        raising.
    """
    import uuid  # local import: only needed to build the response id

    try:
        model, tokenizer = model_and_tokenizer
        data = json.loads(input_data)

        # Format the prompt using the chat template
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages, tokenizer)

        # Tokenize the input and send it to GPU 0 for generation
        inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0")
        input_ids = inputs["input_ids"]
        prompt_tokens = len(input_ids[0])

        # Generate output; the attention mask is passed explicitly so
        # generate() does not have to infer it from the input ids.
        outputs = model.generate(
            input_ids,
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=data.get("max_new_tokens", 512),
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            repetition_penalty=data.get("repetition_penalty", 1.0),
            length_penalty=data.get("length_penalty", 1.0),
            do_sample=True
        )

        # The returned sequence starts with the prompt tokens; drop them so
        # the reply and the token accounting cover only the new tokens.
        completion_ids = outputs[:, prompt_tokens:]
        completion_tokens = len(completion_ids[0])
        generated_text = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0]

        # Build response
        response = {
            "id": f"chatcmpl-{uuid.uuid4().hex}",  # unique per response
            "object": "chat.completion",
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                # NOTE(review): always "stop", even when generation ended
                # because max_new_tokens was reached.
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        }
        return response
    except Exception as e:
        # Best-effort: report the failure in the response body
        return {"error": str(e), "details": repr(e)}
# Input hook for SageMaker: the raw request body is forwarded as-is
def input_fn(serialized_input_data, content_type, context=None):
    """
    Hand the request payload through untouched.

    Nothing is decoded or validated here; ``content_type`` and ``context``
    are accepted but ignored, and the payload is returned exactly as it was
    received.
    """
    request_payload = serialized_input_data
    return request_payload
# Output hook for SageMaker: the prediction is serialized to JSON text
def output_fn(prediction_output, accept, context=None):
    """
    Serialize the prediction into a JSON string.

    ``accept`` and ``context`` are accepted but ignored; the result is always
    whatever ``json.dumps`` produces for ``prediction_output``.
    """
    serialized = json.dumps(prediction_output)
    return serialized
|