Imran1 commited on
Commit
eea1116
1 Parent(s): 272e727

Create inference.py

Browse files
Files changed (1) hide show
  1. code/inference.py +162 -0
code/inference.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from typing import List, Dict
5
+ from accelerate import load_checkpoint_and_dispatch
6
+ import fcntl # For file locking
7
+ import os # For file operations
8
+ import time # For sleep function
9
+
10
+ # Set the max_split_size globally at the start
11
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
12
+
13
+ # Print to verify the environment variable is correctly set
14
+ print(f"PYTORCH_CUDA_ALLOC_CONF: {os.environ.get('PYTORCH_CUDA_ALLOC_CONF')}")
15
+
16
+ # Global variables to persist the model and tokenizer between invocations
17
+ model = None
18
+ tokenizer = None
19
+
20
+ # Function to format chat messages using Qwen's chat template
21
+ def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
22
+ """
23
+ Format chat messages using Qwen's chat template.
24
+ """
25
+ return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
26
+
27
+ def model_fn(model_dir, context=None):
28
+ global model, tokenizer
29
+
30
+ # Path to lock file for ensuring single loading
31
+ lock_file = "/tmp/model_load.lock"
32
+ # Path to in-progress file indicating model loading is happening
33
+ in_progress_file = "/tmp/model_loading_in_progress"
34
+
35
+ if model is not None and tokenizer is not None:
36
+ print("Model and tokenizer already loaded, skipping reload.")
37
+ return model, tokenizer
38
+
39
+ # Attempt to acquire the lock
40
+ with open(lock_file, 'w') as lock:
41
+ print("Attempting to acquire model load lock...")
42
+ fcntl.flock(lock, fcntl.LOCK_EX) # Exclusive lock
43
+
44
+ try:
45
+ # Check if another worker is in the process of loading
46
+ if os.path.exists(in_progress_file):
47
+ print("Another worker is currently loading the model, waiting...")
48
+
49
+ # Poll the in-progress flag until the other worker finishes loading
50
+ while os.path.exists(in_progress_file):
51
+ time.sleep(5) # Wait for 5 seconds before checking again
52
+
53
+ print("Loading complete by another worker, skipping reload.")
54
+ return model, tokenizer
55
+
56
+ # If no one is loading, start loading the model and set the in-progress flag
57
+ print("No one is loading, proceeding to load the model.")
58
+ with open(in_progress_file, 'w') as f:
59
+ f.write("loading")
60
+
61
+ # Loading the model and tokenizer
62
+ if model is None or tokenizer is None:
63
+ print("Loading the model and tokenizer...")
64
+
65
+ offload_dir = "/tmp/offload_dir"
66
+ os.makedirs(offload_dir, exist_ok=True)
67
+
68
+ # Load and dispatch model across 4 GPUs using tensor parallelism
69
+ model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
70
+ model = load_checkpoint_and_dispatch(
71
+ model,
72
+ model_dir,
73
+ device_map="auto", # Automatically map layers across GPUs
74
+ offload_folder=offload_dir, # Offload parts to disk if needed
75
+ max_memory = {i: "15GiB" for i in range(torch.cuda.device_count())} # Example for reducing usage per GPU
76
+ no_split_module_classes=["QwenForCausalLM"] # Ensure model is split across the GPUs
77
+ )
78
+
79
+ # Load the tokenizer
80
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
81
+
82
+ except Exception as e:
83
+ print(f"Error loading model and tokenizer: {e}")
84
+ raise
85
+
86
+ finally:
87
+ # Remove the in-progress flag once the loading is complete
88
+ if os.path.exists(in_progress_file):
89
+ os.remove(in_progress_file)
90
+
91
+ # Release the lock
92
+ fcntl.flock(lock, fcntl.LOCK_UN)
93
+
94
+ return model, tokenizer
95
+
96
+ # Custom predict function for SageMaker
97
+ def predict_fn(input_data, model_and_tokenizer, context=None):
98
+ """
99
+ Generate predictions for the input data.
100
+ """
101
+ try:
102
+ model, tokenizer = model_and_tokenizer
103
+ data = json.loads(input_data)
104
+
105
+ # Format the prompt using Qwen's chat template
106
+ messages = data.get("messages", [])
107
+ formatted_prompt = format_chat(messages, tokenizer)
108
+
109
+ # Tokenize the input
110
+ inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0") # Send input to GPU 0 for generation
111
+
112
+ # Generate output
113
+ outputs = model.generate(
114
+ inputs['input_ids'],
115
+ max_new_tokens=data.get("max_new_tokens", 512),
116
+ temperature=data.get("temperature", 0.7),
117
+ top_p=data.get("top_p", 0.9),
118
+ repetition_penalty=data.get("repetition_penalty", 1.0),
119
+ length_penalty=data.get("length_penalty", 1.0),
120
+ do_sample=True
121
+ )
122
+
123
+ # Decode the output
124
+ generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
125
+
126
+ # Build response
127
+ response = {
128
+ "id": "chatcmpl-uuid",
129
+ "object": "chat.completion",
130
+ "model": "qwen-72b",
131
+ "choices": [{
132
+ "index": 0,
133
+ "message": {
134
+ "role": "assistant",
135
+ "content": generated_text
136
+ },
137
+ "finish_reason": "stop"
138
+ }],
139
+ "usage": {
140
+ "prompt_tokens": len(inputs['input_ids'][0]),
141
+ "completion_tokens": len(outputs[0]),
142
+ "total_tokens": len(inputs['input_ids'][0]) + len(outputs[0])
143
+ }
144
+ }
145
+ return response
146
+
147
+ except Exception as e:
148
+ return {"error": str(e), "details": repr(e)}
149
+
150
+ # Define input format for SageMaker
151
+ def input_fn(serialized_input_data, content_type, context=None):
152
+ """
153
+ Prepare the input data for inference.
154
+ """
155
+ return serialized_input_data
156
+
157
+ # Define output format for SageMaker
158
+ def output_fn(prediction_output, accept, context=None):
159
+ """
160
+ Convert the model output to a JSON response.
161
+ """
162
+ return json.dumps(prediction_output)