Imran1 committed on
Commit 5106444
1 Parent(s): 28a42cc

Update code/inference.py

Files changed (1)
  1. code/inference.py +153 -37
code/inference.py CHANGED
@@ -1,47 +1,163 @@
  import json
- import logging
- import os

- from vllm import LLM, SamplingParams

- logger = logging.getLogger()
- logger.setLevel(logging.INFO)

- def model_fn(model_dir):
-     model = LLM(
-         model=model_dir,
-         trust_remote_code=True,
-         dtype="float16",
-         tensor_parallel_size=4,  # Use 4 GPUs for parallelization
-         gpu_memory_utilization=0.9,
-     )
-     return model

- def predict_fn(data, model, context=None):
      try:
-         input_text = data.pop("inputs", data)
-         parameters = data.pop("parameters", {})
-
-         sampling_params = SamplingParams(
-             temperature=parameters.get("temperature", 0.7),
-             top_p=parameters.get("top_p", 0.9),
-             max_tokens=parameters.get("max_tokens", 512),
          )
-
-         outputs = model.generate(input_text, sampling_params)
-         generated_text = outputs[0].outputs[0].text
-
-         return {"generated_text": generated_text}
-
      except Exception as e:
-         logger.error(f"Exception during prediction: {e}")
-         return {"error": str(e)}

- def input_fn(request_body, request_content_type, context=None):
-     if request_content_type == "application/json":
-         return json.loads(request_body)
-     else:
-         raise ValueError(f"Unsupported content type: {request_content_type}")

- def output_fn(prediction, accept, context=None):
-     return json.dumps(prediction)
  import json
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from typing import List, Dict
+ from accelerate import load_checkpoint_and_dispatch
+ import fcntl  # For file locking
+ import os  # For file operations
+ import time  # For sleep function

+ # Global variables to persist the model and tokenizer between invocations
+ model = None
+ tokenizer = None

+ # Function to format chat messages using Qwen's chat template
+ def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
+     """
+     Format chat messages using Qwen's chat template.
+     """
+     return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

+ def model_fn(model_dir, context=None):
+     global model, tokenizer

+     # Path to lock file for ensuring single loading
+     lock_file = "/tmp/model_load.lock"
+     # Path to in-progress file indicating model loading is happening
+     in_progress_file = "/tmp/model_loading_in_progress"
+
+     if model is not None and tokenizer is not None:
+         print("Model and tokenizer already loaded, skipping reload.")
+         return model, tokenizer
+
+     # Attempt to acquire the lock
+     with open(lock_file, 'w') as lock:
+         print("Attempting to acquire model load lock...")
+         fcntl.flock(lock, fcntl.LOCK_EX)  # Exclusive lock
+
+         try:
+             # Check if another worker is in the process of loading
+             if os.path.exists(in_progress_file):
+                 print("Another worker is currently loading the model, waiting...")
+
+                 # Poll the in-progress flag until the other worker finishes loading
+                 while os.path.exists(in_progress_file):
+                     time.sleep(5)  # Wait for 5 seconds before checking again
+
+                 print("Loading complete by another worker, skipping reload.")
+                 return model, tokenizer
+
+             # If no one is loading, start loading the model and set the in-progress flag
+             print("No one is loading, proceeding to load the model.")
+             with open(in_progress_file, 'w') as f:
+                 f.write("loading")
+
+             # Loading the model and tokenizer
+             if model is None or tokenizer is None:
+                 print("Loading the model and tokenizer...")
+
+                 offload_dir = "/tmp/offload_dir"
+                 os.makedirs(offload_dir, exist_ok=True)
+
+                 # Reduce memory fragmentation by setting the max split size
+                 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
+
+                 # Load and dispatch model across GPUs with tensor parallelism
+                 model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype="auto")
+                 model = load_checkpoint_and_dispatch(
+                     model,
+                     model_dir,
+                     device_map="auto",  # Automatically map layers across GPUs
+                     offload_folder=offload_dir,  # Offload parts to disk if needed
+                     max_memory={i: "20GiB" for i in range(torch.cuda.device_count())}  # Adjust memory per GPU
+                 )
+
+                 # Load the tokenizer
+                 tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+         except Exception as e:
+             print(f"Error loading model and tokenizer: {e}")
+             raise
+
+         finally:
+             # Remove the in-progress flag once the loading is complete
+             if os.path.exists(in_progress_file):
+                 os.remove(in_progress_file)
+
+             # Release the lock
+             fcntl.flock(lock, fcntl.LOCK_UN)
+
+     return model, tokenizer
+
+ # Custom predict function for SageMaker
+ def predict_fn(input_data, model_and_tokenizer, context=None):
+     """
+     Generate predictions for the input data.
+     """
      try:
+         model, tokenizer = model_and_tokenizer
+         data = json.loads(input_data)
+
+         # Format the prompt using Qwen's chat template
+         messages = data.get("messages", [])
+         formatted_prompt = format_chat(messages, tokenizer)
+
+         # Tokenize the input
+         inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0")  # Send input to GPU 0 for generation
+
+         # Generate output
+         outputs = model.generate(
+             inputs['input_ids'],
+             max_new_tokens=data.get("max_new_tokens", 512),
+             temperature=data.get("temperature", 0.7),
+             top_p=data.get("top_p", 0.9),
+             repetition_penalty=data.get("repetition_penalty", 1.0),
+             length_penalty=data.get("length_penalty", 1.0),
+             do_sample=True
          )
+
+         # Decode the output
+         generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+
+         # Build response
+         response = {
+             "id": "chatcmpl-uuid",
+             "object": "chat.completion",
+             "model": "qwen-72b",
+             "choices": [{
+                 "index": 0,
+                 "message": {
+                     "role": "assistant",
+                     "content": generated_text
+                 },
+                 "finish_reason": "stop"
+             }],
+             "usage": {
+                 "prompt_tokens": len(inputs['input_ids'][0]),
+                 "completion_tokens": len(outputs[0]),
+                 "total_tokens": len(inputs['input_ids'][0]) + len(outputs[0])
+             }
+         }
+         return response
+
      except Exception as e:
+         return {"error": str(e), "details": repr(e)}
+
+ # Define input format for SageMaker
+ def input_fn(serialized_input_data, content_type, context=None):
+     """
+     Prepare the input data for inference.
+     """
+     return serialized_input_data

+ # Define output format for SageMaker
+ def output_fn(prediction_output, accept, context=None):
+     """
+     Convert the model output to a JSON response.
+     """
+     return json.dumps(prediction_output)

+ # Memory tracker for debugging
+ def track_memory():
+     print(f"Total allocated memory on GPU 0: {torch.cuda.memory_allocated(0) / 1024 ** 3:.2f} GB")
+     print(f"Total reserved memory on GPU 0: {torch.cuda.memory_reserved(0) / 1024 ** 3:.2f} GB")