import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Dict
from accelerate import load_checkpoint_and_dispatch
import fcntl  # For file locking
import os  # For file operations
import time  # For sleep function
import uuid  # For generating unique response IDs

# Set max_split_size globally to prevent memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"

# Enable detailed distributed logs
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

# Print to verify the environment variable is correctly set
print(f"PYTORCH_CUDA_ALLOC_CONF: {os.environ.get('PYTORCH_CUDA_ALLOC_CONF')}")

# Global variables to persist the model and tokenizer between invocations
model = None
tokenizer = None

# Function to format chat messages using Qwen's chat template
def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
    """
    Format chat messages using Qwen's chat template.
    """
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
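
# Illustrative example (not executed here): `messages` follows the standard chat
# schema accepted by tokenizer.apply_chat_template, e.g.
#   format_chat(
#       [{"role": "system", "content": "You are a helpful assistant."},
#        {"role": "user", "content": "Hello!"}],
#       tokenizer,
#   )
# returns a single prompt string with the model's chat template applied and the
# generation prompt appended, ready for tokenization.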

def model_fn(model_dir, context=None):
    global model, tokenizer

    # Path to lock file for ensuring single loading
    lock_file = "/tmp/model_load.lock"
    # Path to in-progress file indicating model loading is happening
    in_progress_file = "/tmp/model_loading_in_progress"

    if model is not None and tokenizer is not None:
        print("Model and tokenizer already loaded, skipping reload.")
        return model, tokenizer

    # Attempt to acquire the lock
    with open(lock_file, 'w') as lock:
        print("Attempting to acquire model load lock...")
        fcntl.flock(lock, fcntl.LOCK_EX)  # Exclusive lock

        try:
            # Check if another worker is in the process of loading
            if os.path.exists(in_progress_file):
                print("Another worker is currently loading the model, waiting...")

                # Poll the in-progress flag until the other worker finishes loading
                while os.path.exists(in_progress_file):
                    time.sleep(5)  # Wait for 5 seconds before checking again

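                # Note: `model` and `tokenizer` are per-process globals, so this early
                # return only helps callers within the same worker process; a separate
                # worker process will still see None here and load its own copy.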
                print("Loading complete by another worker, skipping reload.")
                return model, tokenizer

            # If no one is loading, start loading the model and set the in-progress flag
            print("No one is loading, proceeding to load the model.")
            with open(in_progress_file, 'w') as f:
                f.write("loading")

            # Loading the model and tokenizer
            if model is None or tokenizer is None:
                print("Loading the model and tokenizer...")

                offload_dir = "/tmp/offload_dir"
                os.makedirs(offload_dir, exist_ok=True)

                # Load the model, then shard its layers across the available GPUs
                # (layer placement via accelerate's device map, not tensor parallelism)
                model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype="auto")
                model = load_checkpoint_and_dispatch(
                    model,
                    model_dir,
                    device_map="balanced",  # Evenly distribute layers across GPUs
                    offload_folder=offload_dir,
                    max_memory={i: "18GiB" for i in range(torch.cuda.device_count())},  # Allocate 18 GiB per GPU
                    no_split_module_classes=["Qwen2DecoderLayer"]  # Keep each decoder block on one device (class name assumes a Qwen2 checkpoint)
                )

                # Load the tokenizer
                tokenizer = AutoTokenizer.from_pretrained(model_dir)

                # Free up any unused memory after loading
                torch.cuda.empty_cache()

        except Exception as e:
            print(f"Error loading model and tokenizer: {e}")
            raise

        finally:
            # Remove the in-progress flag once the loading is complete
            if os.path.exists(in_progress_file):
                os.remove(in_progress_file)

            # Release the lock
            fcntl.flock(lock, fcntl.LOCK_UN)

    return model, tokenizer

# Custom predict function for SageMaker
def predict_fn(input_data, model_and_tokenizer, context=None):
    """
    Generate predictions for the input data.
    """
    try:
        model, tokenizer = model_and_tokenizer
        data = json.loads(input_data)

        # Format the prompt using Qwen's chat template
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages, tokenizer)

        # Tokenize the input and move it to the device holding the first model shard
        inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0")

        # Generate output (pass input IDs and attention mask together)
        outputs = model.generate(
            **inputs,
            max_new_tokens=data.get("max_new_tokens", 512),
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.9),
            repetition_penalty=data.get("repetition_penalty", 1.0),
            length_penalty=data.get("length_penalty", 1.0),
            do_sample=True
        )

        # Decode only the newly generated tokens (generate() echoes the prompt)
        prompt_length = inputs['input_ids'].shape[1]
        generated_text = tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)[0]

        # Build response
        response = {
            "id": "chatcmpl-uuid",
            "object": "chat.completion",
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": len(inputs['input_ids'][0]),
                "completion_tokens": len(outputs[0]),
                "total_tokens": len(inputs['input_ids'][0]) + len(outputs[0])
            }
        }
        return response

    except Exception as e:
        return {"error": str(e), "details": repr(e)}

# Define input format for SageMaker
def input_fn(serialized_input_data, content_type, context=None):
    """
    Prepare the input data for inference.
    """
    return serialized_input_data

# Define output format for SageMaker
def output_fn(prediction_output, accept, context=None):
    """
    Convert the model output to a JSON response.
    """
    return json.dumps(prediction_output)
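
# Illustrative request body (not executed here): the field names below match what
# predict_fn reads via data.get(...); the values are only examples.
#
#   {
#       "messages": [
#           {"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "Explain what this endpoint does."}
#       ],
#       "max_new_tokens": 256,
#       "temperature": 0.7,
#       "top_p": 0.9,
#       "repetition_penalty": 1.0,
#       "length_penalty": 1.0
#   }
#
# At serving time, the SageMaker inference toolkit calls model_fn once per worker to
# load the model, then input_fn -> predict_fn -> output_fn for each request.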