import fcntl
import json
import os
import time
import uuid
from typing import List, Dict

import torch
from accelerate import load_checkpoint_and_dispatch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Global variables to persist the model and tokenizer between invocations
model = None
tokenizer = None
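# NOTE: each SageMaker inference worker is a separate process, so these globals
# hold a per-worker copy of the model; the file lock in model_fn only staggers
# loading across workers, it does not share the loaded model between processes.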

def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
    """
    Format chat messages using Qwen's chat template.
    """
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
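
# Sketch of the "messages" payload this handler expects (OpenAI-style
# role/content dicts); the example below is illustrative only.
#
#   messages = [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "What does device_map='auto' do?"},
#   ]
#   prompt = format_chat(messages, tokenizer)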

def model_fn(model_dir, context=None):
    """
    Load the model and tokenizer once per worker process. A file lock plus an
    "in progress" marker file serialize loading across workers so they do not
    all hit the disk and GPUs at the same time.
    """
    global model, tokenizer

    lock_file = "/tmp/model_load.lock"
    in_progress_file = "/tmp/model_loading_in_progress"
    
    if model is not None and tokenizer is not None:
        print("Model and tokenizer already loaded, skipping reload.")
        return model, tokenizer

    with open(lock_file, 'w') as lock:
        print("Attempting to acquire model load lock...")
        fcntl.flock(lock, fcntl.LOCK_EX)

        try:
            if os.path.exists(in_progress_file):
                print("Another worker is currently loading the model, waiting...")
                while os.path.exists(in_progress_file):
                    time.sleep(5)
                print("Loading complete by another worker.")
                # The other worker runs in a separate process, so these globals
                # are normally still None here and this worker falls through to
                # load its own copy below; the wait only staggers the loads.
                if model is not None and tokenizer is not None:
                    return model, tokenizer
            
            print("Proceeding to load the model and tokenizer.")
            with open(in_progress_file, 'w') as f:
                f.write("loading")

            print("Loading the model and tokenizer...")
            offload_dir = "/tmp/offload_dir"
            os.makedirs(offload_dir, exist_ok=True)

            # Load the tokenizer first
            tokenizer = AutoTokenizer.from_pretrained(model_dir)

            # Load the weights on CPU first (torch_dtype="auto" keeps the dtype
            # stored in the checkpoint), then let accelerate shard the model
            # across the visible GPUs, spilling any overflow to offload_dir.
            model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype="auto")
            model = load_checkpoint_and_dispatch(
                model,
                model_dir,
                device_map="auto",
                offload_folder=offload_dir,
                # Assumes roughly 24GiB of usable memory per visible GPU.
                max_memory={i: "24GiB" for i in range(torch.cuda.device_count())}
            )

            print("Model and tokenizer loaded successfully.")

        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise
        finally:
            if os.path.exists(in_progress_file):
                os.remove(in_progress_file)
            fcntl.flock(lock, fcntl.LOCK_UN)
    
    return model, tokenizer

def predict_fn(input_data, model_and_tokenizer, context=None):
    """
    Generate predictions for the input data.
    """
    try:
        model, tokenizer = model_and_tokenizer
        if model is None or tokenizer is None:
            raise ValueError("Model or tokenizer is None. Please ensure they are loaded correctly.")

        data = json.loads(input_data)
        messages = data.get("messages", [])
        formatted_prompt = format_chat(messages, tokenizer)

        inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0")
        prompt_length = inputs['input_ids'].shape[1]

        with torch.no_grad():
            outputs = model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=data.get("max_new_tokens", 512),
                temperature=data.get("temperature", 0.7),
                top_p=data.get("top_p", 0.9),
                repetition_penalty=data.get("repetition_penalty", 1.0),
                length_penalty=data.get("length_penalty", 1.0),
                do_sample=True
            )

        # Decode only the newly generated tokens so the response does not echo
        # the prompt back to the caller.
        completion_ids = outputs[0][prompt_length:]
        generated_text = tokenizer.decode(completion_ids, skip_special_tokens=True)

        response = {
            "id": f"chatcmpl-{uuid.uuid4().hex}",
            "object": "chat.completion",
            "model": "qwen-72b",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": generated_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_length,
                "completion_tokens": len(completion_ids),
                "total_tokens": prompt_length + len(completion_ids)
            }
        }
        return response

    except Exception as e:
        return {"error": str(e), "details": repr(e)}

def input_fn(serialized_input_data, content_type, context=None):
    # Pass the raw JSON request body through; predict_fn handles parsing.
    return serialized_input_data

def output_fn(prediction_output, accept, context=None):
    return json.dumps(prediction_output)
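
# Minimal local smoke test of the handler chain (ignored by SageMaker, which
# calls the *_fn hooks directly). This is a sketch that assumes the model
# weights sit in the standard SageMaker model directory /opt/ml/model; adjust
# the path when running outside a container.
if __name__ == "__main__":
    artifacts = model_fn("/opt/ml/model")
    request_body = json.dumps({"messages": [{"role": "user", "content": "Hello!"}]})
    payload = input_fn(request_body, "application/json")
    result = predict_fn(payload, artifacts)
    print(output_fn(result, "application/json"))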