import json
import subprocess
import sys
import uuid
from typing import Dict, List

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Install accelerate dynamically (pre-installing it in the container image is preferable)
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'accelerate'])

from accelerate import init_empty_weights, infer_auto_device_map

# Function to format chat messages using Qwen's chat template
def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
    """
    Format chat messages using Qwen's chat template
    """
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
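
# Illustrative only: the OpenAI-style message structure that format_chat (and predict_fn below)
# expects; the example content is an assumption, not part of the original handler.
#   messages = [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Give me a short introduction to large language models."},
#   ]
# format_chat(messages, tokenizer) returns a single prompt string wrapped in Qwen's chat
# template, ready to be tokenized.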

# Model loading function for SageMaker: shards the model across 4 GPUs with accelerate
def model_fn(model_dir, context=None):
    """
    Load the model and tokenizer from the model directory for inference.
    The model is partitioned across 4 GPUs via accelerate's automatic device map
    (layer-wise model parallelism).
    """
    config = AutoConfig.from_pretrained(model_dir)

    # Build an empty (meta-device) copy of the model purely to plan the device map;
    # no weights are materialized at this point.
    with init_empty_weights():
        empty_model = AutoModelForCausalLM.from_config(config)

    # Infer a device map that distributes whole decoder blocks across the 4 GPUs.
    # The no-split class name must match the model's block class
    # (e.g. "QWenBlock" for Qwen, "Qwen2DecoderLayer" for Qwen2).
    device_map = infer_auto_device_map(
        empty_model,
        max_memory={i: "25GiB" for i in range(4)},
        no_split_module_classes=["QwenBlock"],
        dtype=config.torch_dtype,
    )

    # Load the real weights directly onto the mapped devices. Loading with from_pretrained
    # and an explicit device_map replaces the original dispatch_model call, which never
    # populated the meta-initialized weights.
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        torch_dtype="auto",
        device_map=device_map,
    )
    model.eval()

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    return model, tokenizer
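
# Illustrative only (an assumption, not captured from a real run): the inferred device_map is a
# dict mapping submodule names to device indices, e.g.
#   {"transformer.wte": 0, "transformer.h.0": 0, ..., "transformer.h.N": 3, "lm_head": 3}
# for the remote-code Qwen architecture (Qwen2 uses names like "model.embed_tokens",
# "model.layers.N", "lm_head").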

# Custom predict function for SageMaker
def predict_fn(input_data, model_and_tokenizer):
    """
    Generate predictions for the input data.
    """
    try:
        model, tokenizer = model_and_tokenizer
        data = json.loads(input_data)

        # Format the prompt using Qwen's chat template
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages, tokenizer)

        # Tokenize the input and move it to the first GPU, where the embedding layer
        # sits under the device map
        inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0")

        # Generate output (attention_mask is passed explicitly alongside the input ids)
        with torch.no_grad():
            outputs = model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=data.get("max_new_tokens", 512),
                temperature=data.get("temperature", 0.7),
                top_p=data.get("top_p", 0.9),
                repetition_penalty=data.get("repetition_penalty", 1.0),
                length_penalty=data.get("length_penalty", 1.0),
                do_sample=True
            )

        # Decode only the newly generated tokens (generate returns prompt + completion)
        prompt_length = inputs['input_ids'].shape[1]
        generated_text = tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)[0]

        # Build an OpenAI-style chat completion response
        completion_tokens = len(outputs[0]) - prompt_length
        response = {
            "id": "chatcmpl-" + uuid.uuid4().hex,
            "object": "chat.completion",
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_length,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_length + completion_tokens
            }
        }
        return response

    except Exception as e:
        return {"error": str(e), "details": repr(e)}

# Define input format for SageMaker
def input_fn(serialized_input_data, content_type):
    """
    Validate the content type and pass the raw JSON payload through to predict_fn.
    """
    if content_type != 'application/json':
        raise ValueError(f"Unsupported content type: {content_type}")
    return serialized_input_data

# Define output format for SageMaker
def output_fn(prediction_output, accept):
    """
    Convert the model output to a JSON response.
    """
    return json.dumps(prediction_output)
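
# For illustration only: how a client might invoke this handler once it is deployed behind a
# SageMaker endpoint. The endpoint name is a placeholder, and this snippet is not part of the
# inference handler itself (it would normally live in a separate client script).
#
#   import boto3, json
#
#   runtime = boto3.client("sagemaker-runtime")
#   payload = {
#       "messages": [
#           {"role": "user", "content": "Give me a short introduction to large language models."}
#       ],
#       "max_new_tokens": 256,
#       "temperature": 0.7,
#   }
#   result = runtime.invoke_endpoint(
#       EndpointName="qwen-72b-endpoint",  # placeholder name
#       ContentType="application/json",
#       Body=json.dumps(payload),
#   )
#   print(json.loads(result["Body"].read()))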