import json
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict

# Global variables to persist the model and tokenizer between invocations
model = None
tokenizer = None

# Function to format chat messages using Qwen's chat template
def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Model loading function for SageMaker: loads the FP8-quantized checkpoint and shards it across the available GPUs
def model_fn(model_dir, context=None):
    global model, tokenizer
    
    if model is None:
        print("Loading the FP8 quantized model and tokenizer...")
        
        # Define an offload directory
        offload_dir = "/tmp/offload_dir"
        os.makedirs(offload_dir, exist_ok=True)

        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_dir)

        # Load the FP8 quantized model; device_map="auto" shards the layers
        # across all available GPUs via accelerate, with disk offload as a fallback
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype="auto",  # the FP8 quantization settings are read from the checkpoint config
            low_cpu_mem_usage=True,
            device_map="auto",
            offload_folder=offload_dir,
        )

    return model, tokenizer

# Custom predict function for SageMaker
def predict_fn(input_data, model_and_tokenizer):
    try:
        model, tokenizer = model_and_tokenizer
        data = json.loads(input_data)

        # Format the prompt using Qwen's chat template
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages, tokenizer)

        # Tokenize the input
        inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)

        # Generate output
        with torch.no_grad():
            outputs = model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=data.get("max_new_tokens", 512),
                temperature=data.get("temperature", 0.7),
                top_p=data.get("top_p", 0.9),
                repetition_penalty=data.get("repetition_penalty", 1.0),
                length_penalty=data.get("length_penalty", 1.0),
                do_sample=True
            )

        # Decode only the newly generated tokens so the prompt is not echoed back
        prompt_length = inputs['input_ids'].shape[1]
        generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        # Build response
        response = {
            "id": "chatcmpl-fp8-quantized",
            "object": "chat.completion",
            "model": "qwen-72b-fp8",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": len(inputs['input_ids'][0]),
                "completion_tokens": len(outputs[0]) - len(inputs['input_ids'][0]),
                "total_tokens": len(outputs[0])
            }
        }
        return response

    except Exception as e:
        return {"error": str(e), "details": repr(e)}


# Define input format for SageMaker
def input_fn(serialized_input_data, content_type, context=None):
    """
    Prepare the input data for inference.
    """
    return serialized_input_data

# Define output format for SageMaker
def output_fn(prediction_output, accept, context=None):
    """
    Convert the model output to a JSON response.
    """
    return json.dumps(prediction_output)

You have my code.
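
For reference, here is a minimal sketch of how an endpoint deployed from this script could be called with boto3; the endpoint name is a placeholder, and the payload fields mirror what predict_fn reads above:

import json
import boto3

runtime = boto3.client("sagemaker-runtime")

payload = {
    "messages": [{"role": "user", "content": "Hello, who are you?"}],
    "max_new_tokens": 256,
    "temperature": 0.7,
}

# The JSON body flows through input_fn -> predict_fn -> output_fn on the endpoint
response = runtime.invoke_endpoint(
    EndpointName="qwen-72b-fp8-endpoint",  # placeholder endpoint name
    ContentType="application/json",
    Body=json.dumps(payload),
)

result = json.loads(response["Body"].read())
print(result["choices"][0]["message"]["content"])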