import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict
from accelerate import load_checkpoint_and_dispatch
import fcntl
import os
import time
import uuid

# Global variables to persist the model and tokenizer between invocations
model = None
tokenizer = None

# flock()-ed file that serializes model loading across worker processes on this host.
LOCK_FILE = "/tmp/model_load.lock"
# Marker that exists only while a worker is inside the loading critical section.
IN_PROGRESS_FILE = "/tmp/model_loading_in_progress"
# Disk offload folder passed to from_pretrained for weights that exceed the GPU budget.
OFFLOAD_DIR = "/tmp/offload_dir"
# Per-GPU memory budget used when inferring the device map.
MAX_MEMORY_PER_GPU = "24GiB"


def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
    """
    Format chat messages using Qwen's chat template.

    Returns the rendered prompt string (not token ids), with the generation
    prompt appended so the model continues as the assistant.
    """
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)


def model_fn(model_dir, context=None):
    """Load the tokenizer and model once per process and cache them in module globals.

    An exclusive ``fcntl.flock`` on LOCK_FILE serializes loading, so workers on
    the same host load one at a time instead of all at once.

    Args:
        model_dir: Directory containing the tokenizer files and model checkpoint.
        context: Unused; accepted to match the serving handler signature.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        Any exception raised while loading is printed and re-raised. The module
        globals are only updated after a fully successful load, so a failed load
        is retried on the next call instead of returning a half-built model.
    """
    global model, tokenizer

    if model is not None and tokenizer is not None:
        print("Model and tokenizer already loaded, skipping reload.")
        return model, tokenizer

    with open(LOCK_FILE, 'w') as lock:
        print("Attempting to acquire model load lock...")
        fcntl.flock(lock, fcntl.LOCK_EX)
        try:
            # Another caller in this process may have finished loading while we
            # were blocked on the lock; don't load a second copy.
            if model is not None and tokenizer is not None:
                return model, tokenizer

            # The marker is always removed (in ``finally``) before the lock is
            # released, so if it exists now its creator died without cleanup.
            # Waiting for it here, while holding the lock, would hang forever.
            if os.path.exists(IN_PROGRESS_FILE):
                print("Removing stale loading marker left by a previous worker.")
                os.remove(IN_PROGRESS_FILE)

            print("Proceeding to load the model and tokenizer.")
            with open(IN_PROGRESS_FILE, 'w') as marker:
                marker.write("loading")

            print("Loading the model and tokenizer...")
            os.makedirs(OFFLOAD_DIR, exist_ok=True)

            # Build into locals and publish to the globals only on success.
            loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir)
            # Single-pass load straight onto the inferred device map, instead of
            # materialising the whole checkpoint in host memory and then loading
            # it a second time for dispatch.
            loaded_model = AutoModelForCausalLM.from_pretrained(
                model_dir,
                torch_dtype="auto",
                device_map="auto",
                max_memory={i: MAX_MEMORY_PER_GPU for i in range(torch.cuda.device_count())},
                offload_folder=OFFLOAD_DIR,
            )

            model, tokenizer = loaded_model, loaded_tokenizer
            print("Model and tokenizer loaded successfully.")
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise
        finally:
            if os.path.exists(IN_PROGRESS_FILE):
                os.remove(IN_PROGRESS_FILE)
            fcntl.flock(lock, fcntl.LOCK_UN)

    return model, tokenizer


def predict_fn(input_data, model_and_tokenizer, context=None):
    """
    Generate a chat completion for a JSON request.

    ``input_data`` is a JSON document (str or bytes) with a ``messages`` list and
    optional generation settings: ``max_new_tokens`` (default 512),
    ``temperature`` (0.7; 0 selects greedy decoding), ``top_p`` (0.9),
    ``repetition_penalty`` (1.0) and ``length_penalty`` (1.0).

    Returns an OpenAI-style ``chat.completion`` dict. On any failure, returns
    ``{"error": ..., "details": ...}`` instead of raising.
    """
    try:
        model, tokenizer = model_and_tokenizer
        if model is None or tokenizer is None:
            raise ValueError("Model or tokenizer is None. Please ensure they are loaded correctly.")

        data = json.loads(input_data)
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages, tokenizer)

        inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0")
        prompt_len = inputs['input_ids'].shape[1]

        # Coerce JSON numbers: transformers rejects int temperatures and
        # temperature == 0 when sampling, so 0 maps to greedy decoding.
        max_new_tokens = int(data.get("max_new_tokens", 512))
        temperature = float(data.get("temperature", 0.7))
        do_sample = temperature > 0
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": float(data.get("repetition_penalty", 1.0)),
            "length_penalty": float(data.get("length_penalty", 1.0)),
            "do_sample": do_sample,
        }
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = float(data.get("top_p", 0.9))

        outputs = model.generate(
            inputs['input_ids'],
            attention_mask=inputs.get('attention_mask'),
            **gen_kwargs,
        )

        # For a decoder-only model the returned sequence is prompt + continuation;
        # keep only the newly generated tokens.
        completion_ids = outputs[0][prompt_len:]
        completion_len = len(completion_ids)
        generated_text = tokenizer.decode(completion_ids, skip_special_tokens=True)

        # Generation ends either on an end-of-sequence token or by exhausting the
        # token budget. Approximation: an EOS landing exactly on the last allowed
        # token is also reported as "length".
        finish_reason = "length" if completion_len >= max_new_tokens else "stop"

        response = {
            "id": f"chatcmpl-{uuid.uuid4().hex}",
            "object": "chat.completion",
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": finish_reason
            }],
            "usage": {
                "prompt_tokens": prompt_len,
                "completion_tokens": completion_len,
                "total_tokens": prompt_len + completion_len
            }
        }
        return response
    except Exception as e:
        print(f"Error during prediction: {e!r}")
        return {"error": str(e), "details": repr(e)}


def input_fn(serialized_input_data, content_type, context=None):
    """Pass the raw request body through unchanged; ``predict_fn`` parses the JSON."""
    return serialized_input_data


def output_fn(prediction_output, accept, context=None):
    """Serialize the prediction dict to a JSON string (``accept`` is ignored)."""
    return json.dumps(prediction_output)