# Qwen2.5-72B-Instruct-FP8 / inference.py
import os
import json
import subprocess
import sys
import time
from typing import List, Dict
# Ensure vllm is installed; the pinned wheel must match the CUDA build on the host
try:
    import vllm
except ImportError:
    # Install vllm with CUDA 11.8 support
    vllm_version = "v0.6.1.post1"
    pip_cmd = [
        sys.executable,
        "-m", "pip", "install",
        f"https://github.com/vllm-project/vllm/releases/download/{vllm_version}/vllm-{vllm_version}+cu118-cp310-cp310-manylinux1_x86_64.whl",
        "--extra-index-url", "https://download.pytorch.org/whl/cu118",
    ]
    subprocess.check_call(pip_cmd)
# Import the necessary modules after installation
from vllm import LLM, SamplingParams
from vllm.utils import random_uuid
# Function to format chat messages using Qwen's chat template
def format_chat(messages: List[Dict[str, str]]) -> str:
    """
    Format chat messages using Qwen's chat template.
    """
    formatted_text = ""
    for message in messages:
        role = message["role"]
        content = message["content"]
        if role == "system":
            formatted_text += f"<|im_start|>system\n{content}<|im_end|>\n"
        elif role == "user":
            formatted_text += f"<|im_start|>user\n{content}<|im_end|>\n"
        elif role == "assistant":
            formatted_text += f"<|im_start|>assistant\n{content}<|im_end|>\n"
    # Add the final assistant prompt
    formatted_text += "<|im_start|>assistant\n"
    return formatted_text
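# Illustrative example (not in the original file): with the ChatML-style template above,
# format_chat([{"role": "user", "content": "Hi"}]) returns
# "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"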
# Model loading function for SageMaker
def model_fn(model_dir):
    # Load the quantized model from the model directory
    model = LLM(
        model=model_dir,
        trust_remote_code=True,
        gpu_memory_utilization=0.9,  # use up to 90% of GPU memory for weights and KV cache
    )
    return model
# Custom predict function for SageMaker
def predict_fn(input_data, model):
    try:
        data = json.loads(input_data)
        # Format the prompt using Qwen's chat template
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages)
        # Build sampling parameters (vLLM has no do_sample flag; sampling is
        # controlled by temperature/top_p/top_k, matching the OpenAI-style API)
        sampling_params = SamplingParams(
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            max_tokens=data.get("max_new_tokens", 512),  # SamplingParams expects max_tokens
            top_k=data.get("top_k", -1),  # -1 disables top-k sampling
            repetition_penalty=data.get("repetition_penalty", 1.0),
            length_penalty=data.get("length_penalty", 1.0),
            stop_token_ids=data.get("stop_token_ids", None),
            skip_special_tokens=data.get("skip_special_tokens", True),
        )
        # Generate output
        outputs = model.generate(formatted_prompt, sampling_params)
        generated_text = outputs[0].outputs[0].text
        # Build an OpenAI-style chat completion response
        response = {
            "id": f"chatcmpl-{random_uuid()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                # Count tokens (not characters) using the token IDs returned by vLLM
                "prompt_tokens": len(outputs[0].prompt_token_ids),
                "completion_tokens": len(outputs[0].outputs[0].token_ids),
                "total_tokens": len(outputs[0].prompt_token_ids) + len(outputs[0].outputs[0].token_ids),
            }
        }
        return response
    except Exception as e:
        return {"error": str(e), "details": repr(e)}
# Define input and output formats for SageMaker
def input_fn(serialized_input_data, content_type):
    return serialized_input_data

def output_fn(prediction_output, accept):
    return json.dumps(prediction_output)
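
# Illustrative local smoke test (a sketch, not part of the SageMaker handler contract).
# It assumes a local checkpoint directory; "/opt/ml/model" below is only a placeholder
# (the usual SageMaker mount point), and the payload mirrors the keys predict_fn reads
# above (messages, temperature, max_new_tokens).
if __name__ == "__main__":
    payload = json.dumps({
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello in one sentence."},
        ],
        "temperature": 0.7,
        "max_new_tokens": 64,
    })
    # Wire the handlers together the way SageMaker would: input_fn -> predict_fn -> output_fn
    llm = model_fn("/opt/ml/model")  # placeholder model directory
    request = input_fn(payload, "application/json")
    prediction = predict_fn(request, llm)
    print(output_fn(prediction, "application/json"))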