import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict
from accelerate import load_checkpoint_and_dispatch
import fcntl # For file locking
import os # For file operations
import time # For sleep function
# Set max_split_size globally to prevent memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
# Enable detailed distributed logs
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
# Print to verify the environment variable is correctly set
print(f"PYTORCH_CUDA_ALLOC_CONF: {os.environ.get('PYTORCH_CUDA_ALLOC_CONF')}")
# Global variables to persist the model and tokenizer between invocations
model = None
tokenizer = None
# Function to format chat messages using Qwen's chat template
def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
"""
Format chat messages using Qwen's chat template.
"""
return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
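# Illustrative example (not used by the handler): the request carries OpenAI-style
# chat messages, which apply_chat_template wraps in Qwen's special tokens and
# terminates with the assistant generation prompt, e.g.:
#   example_messages = [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hello!"},
#   ]
#   prompt = format_chat(example_messages, tokenizer)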
def model_fn(model_dir, context=None):
    global model, tokenizer
    # Lock file used to serialize model loading across workers
    lock_file = "/tmp/model_load.lock"
    # Flag file indicating that a model load is in progress
    in_progress_file = "/tmp/model_loading_in_progress"
    if model is not None and tokenizer is not None:
        print("Model and tokenizer already loaded, skipping reload.")
        return model, tokenizer
    # Attempt to acquire the lock
    with open(lock_file, 'w') as lock:
        print("Attempting to acquire model load lock...")
        fcntl.flock(lock, fcntl.LOCK_EX)  # Exclusive lock
        try:
            # Check whether another worker left an in-progress flag behind
            if os.path.exists(in_progress_file):
                print("Another worker is currently loading the model, waiting...")
                # Poll the in-progress flag until the other worker finishes loading
                while os.path.exists(in_progress_file):
                    time.sleep(5)  # Wait 5 seconds before checking again
                # Globals are per-process, so only skip the reload if this process
                # actually has the model; otherwise fall through and load it here.
                if model is not None and tokenizer is not None:
                    print("Loading complete by another worker, skipping reload.")
                    return model, tokenizer
            # Set the in-progress flag so other workers wait instead of racing
            print("Proceeding to load the model in this worker.")
            with open(in_progress_file, 'w') as f:
                f.write("loading")
            # Load the model and tokenizer
            if model is None or tokenizer is None:
print("Loading the model and tokenizer...")
offload_dir = "/tmp/offload_dir"
os.makedirs(offload_dir, exist_ok=True)
# Load and dispatch model across 4 GPUs using tensor parallelism
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype="auto")
model = load_checkpoint_and_dispatch(
model,
model_dir,
device_map="balanced", # Evenly distribute across GPUs
offload_folder=offload_dir,
max_memory={i: "18GiB" for i in range(torch.cuda.device_count())}, # Allocate 18 GiB per GPU
no_split_module_classes=["QwenForCausalLM"] # Split model across GPUs
)
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_dir)
# Free up any unused memory after loading
torch.cuda.empty_cache()
        except Exception as e:
            print(f"Error loading model and tokenizer: {e}")
            raise
        finally:
            # Remove the in-progress flag once loading is complete
            if os.path.exists(in_progress_file):
                os.remove(in_progress_file)
            # Release the lock
            fcntl.flock(lock, fcntl.LOCK_UN)
    return model, tokenizer
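# Sketch of an alternative, lower-RAM loading path (not wired into the handler):
# from_pretrained above materializes the full checkpoint in CPU memory before
# load_checkpoint_and_dispatch re-reads it from model_dir. Initializing the model
# with empty (meta) weights first avoids that double cost. The helper name is
# hypothetical and assumes the checkpoint shards live directly in model_dir.
def _load_model_with_empty_init(model_dir, offload_dir="/tmp/offload_dir"):
    from accelerate import init_empty_weights
    from transformers import AutoConfig

    os.makedirs(offload_dir, exist_ok=True)
    config = AutoConfig.from_pretrained(model_dir)
    with init_empty_weights():
        # Build the architecture without allocating real weight tensors
        empty_model = AutoModelForCausalLM.from_config(config)
    # Stream the checkpoint shards straight onto the target devices
    return load_checkpoint_and_dispatch(
        empty_model,
        model_dir,
        device_map="balanced",
        offload_folder=offload_dir,
        max_memory={i: "18GiB" for i in range(torch.cuda.device_count())},
    )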
# Custom predict function for SageMaker
def predict_fn(input_data, model_and_tokenizer, context=None):
"""
Generate predictions for the input data.
"""
try:
model, tokenizer = model_and_tokenizer
data = json.loads(input_data)
# Format the prompt using Qwen's chat template
messages = data.get("messages", [])
formatted_prompt = format_chat(messages, tokenizer)
# Tokenize the input
inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0") # Send input to GPU 0 for generation
        # Generate the completion (pass the attention mask explicitly)
        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=data.get("max_new_tokens", 512),
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            repetition_penalty=data.get("repetition_penalty", 1.0),
            length_penalty=data.get("length_penalty", 1.0),
            do_sample=True
        )
        # Decode only the newly generated tokens; generate() returns prompt + completion
        prompt_length = inputs['input_ids'].shape[1]
        generated_text = tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)[0]
        # Build an OpenAI-style chat completion response
        response = {
            "id": "chatcmpl-uuid",
            "object": "chat.completion",
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_length,
                "completion_tokens": len(outputs[0]) - prompt_length,
                "total_tokens": len(outputs[0])
            }
        }
        return response
    except Exception as e:
        return {"error": str(e), "details": repr(e)}
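# Illustrative request body (every generation field is optional and falls back to
# the defaults used above):
#   {
#       "messages": [{"role": "user", "content": "Hello!"}],
#       "max_new_tokens": 256,
#       "temperature": 0.7,
#       "top_p": 0.9,
#       "repetition_penalty": 1.0,
#       "length_penalty": 1.0
#   }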
# Define input format for SageMaker
def input_fn(serialized_input_data, content_type, context=None):
"""
Prepare the input data for inference.
"""
return serialized_input_data
# Define output format for SageMaker
def output_fn(prediction_output, accept, context=None):
"""
Convert the model output to a JSON response.
"""
return json.dumps(prediction_output)
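# Optional local smoke test (illustrative only). SageMaker imports this module, so
# the block below never runs inside the endpoint container; /opt/ml/model is the
# directory where SageMaker extracts the model artifact.
if __name__ == "__main__":
    example_request = json.dumps({
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_new_tokens": 64,
    })
    artifacts = model_fn("/opt/ml/model")
    prediction = predict_fn(input_fn(example_request, "application/json"), artifacts)
    print(output_fn(prediction, "application/json"))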