import json
import os
from typing import Dict, List

import torch
from accelerate import load_checkpoint_and_dispatch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Global variables to persist the model and tokenizer between invocations
model = None
tokenizer = None


# Function to format chat messages using Qwen's chat template
def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
    """
    Format chat messages using Qwen's chat template.
    """
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )


# Model loading function for SageMaker with tensor parallelism and offloading
def model_fn(model_dir, context=None):
    """
    Load the model and tokenizer from the model directory for inference.
    Supports tensor parallelism across multiple GPUs with offloading.
    The model is loaded only once and stored in a global variable.
    """
    global model, tokenizer  # Declare as global to persist across invocations
    if model is None:  # Only load on the first invocation
        print("Loading the model and tokenizer...")

        # Offload directory for any model components that can't fit in GPU memory
        offload_dir = "/tmp/offload_dir"
        os.makedirs(offload_dir, exist_ok=True)  # Ensure the directory exists; SageMaker has write access to /tmp

        # Explicitly map the model layers across 8 GPUs
        device_map = {
            "transformer.h.0": 0, "transformer.h.1": 0,
            "transformer.h.2": 1, "transformer.h.3": 1,
            "transformer.h.4": 2, "transformer.h.5": 2,
            "transformer.h.6": 3, "transformer.h.7": 3,
            "transformer.h.8": 4, "transformer.h.9": 4,
            "transformer.h.10": 5, "transformer.h.11": 5,
            "transformer.h.12": 6, "transformer.h.13": 6,
            "transformer.h.14": 7, "transformer.h.15": 7,
            "transformer.ln_f": 7,
            "lm_head": 7,
        }

        # Load the weights, then dispatch them across multiple GPUs with offloading
        model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
        model = load_checkpoint_and_dispatch(
            model,
            model_dir,
            device_map=device_map,       # Explicit layer-to-GPU mapping across 8 GPUs
            offload_folder=offload_dir,  # Offload to disk if GPU memory is insufficient
        )

        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_dir)

    return model, tokenizer


# Custom predict function for SageMaker
def predict_fn(input_data, model_and_tokenizer):
    """
    Generate predictions for the input data.
""" try: model, tokenizer = model_and_tokenizer data = json.loads(input_data) # Format the prompt using Qwen's chat template messages = data.get("messages", []) formatted_prompt = format_chat(messages, tokenizer) # Tokenize the input inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0") # Send input to GPU 0 for generation # Generate output outputs = model.generate( inputs['input_ids'], max_new_tokens=data.get("max_new_tokens", 512), temperature=data.get("temperature", 0.7), top_p=data.get("top_p", 0.9), repetition_penalty=data.get("repetition_penalty", 1.0), length_penalty=data.get("length_penalty", 1.0), do_sample=True ) # Decode the output generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] # Build response response = { "id": "chatcmpl-uuid", "object": "chat.completion", "model": "qwen-72b", "choices": [{ "index": 0, "message": { "role": "assistant", "content": generated_text }, "finish_reason": "stop" }], "usage": { "prompt_tokens": len(inputs['input_ids'][0]), "completion_tokens": len(outputs[0]), "total_tokens": len(inputs['input_ids'][0]) + len(outputs[0]) } } return response except Exception as e: return {"error": str(e), "details": repr(e)} # Define input format for SageMaker def input_fn(serialized_input_data, content_type, context=None): """ Prepare the input data for inference. """ return serialized_input_data # Define output format for SageMaker def output_fn(prediction_output, accept, context=None): """ Convert the model output to a JSON response. """ return json.dumps(prediction_output)