Imran1 committed
Commit 5c36d2f
1 Parent(s): 4676515

Update code/inference.py

Files changed (1)
  1. code/inference.py +37 -56
code/inference.py CHANGED
@@ -3,14 +3,14 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from typing import List, Dict
 from accelerate import load_checkpoint_and_dispatch
-import fcntl # For file locking
-import os # For file operations
-import time # For sleep function
+import fcntl
+import os
+import time
+
 # Global variables to persist the model and tokenizer between invocations
 model = None
 tokenizer = None
 
-# Function to format chat messages using Qwen's chat template
 def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
     """
     Format chat messages using Qwen's chat template.
@@ -20,83 +20,74 @@ def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
 def model_fn(model_dir, context=None):
     global model, tokenizer
 
-    # Path to lock file for ensuring single loading
     lock_file = "/tmp/model_load.lock"
-    # Path to in-progress file indicating model loading is happening
     in_progress_file = "/tmp/model_loading_in_progress"
 
-    if model is not None:
-        print("Model already loaded, skipping reload.")
+    if model is not None and tokenizer is not None:
+        print("Model and tokenizer already loaded, skipping reload.")
         return model, tokenizer
 
-    # Attempt to acquire the lock
     with open(lock_file, 'w') as lock:
         print("Attempting to acquire model load lock...")
-        fcntl.flock(lock, fcntl.LOCK_EX) # Exclusive lock
+        fcntl.flock(lock, fcntl.LOCK_EX)
 
         try:
-            # Check if another worker is in the process of loading
            if os.path.exists(in_progress_file):
                 print("Another worker is currently loading the model, waiting...")
-
-                # Poll the in-progress flag until the other worker finishes loading
                 while os.path.exists(in_progress_file):
-                    time.sleep(5) # Wait for 5 seconds before checking again
-
-                print("Loading complete by another worker, skipping reload.")
-                return model, tokenizer
-
-            # If no one is loading, start loading the model and set the in-progress flag
-            print("No one is loading, proceeding to load the model.")
+                    time.sleep(5)
+                print("Loading complete by another worker.")
+                if model is not None and tokenizer is not None:
+                    return model, tokenizer
+
+            print("Proceeding to load the model and tokenizer.")
             with open(in_progress_file, 'w') as f:
                 f.write("loading")
 
-            if model is None:
-                print("Loading the model and tokenizer...")
+            print("Loading the model and tokenizer...")
+            offload_dir = "/tmp/offload_dir"
+            os.makedirs(offload_dir, exist_ok=True)
 
-                offload_dir = "/tmp/offload_dir"
-                os.makedirs(offload_dir, exist_ok=True)
+            # Load the tokenizer first
+            tokenizer = AutoTokenizer.from_pretrained(model_dir)
 
-                # Load and dispatch model across 8 GPUs using tensor parallelism
-                model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype="auto")
-                model = load_checkpoint_and_dispatch(
-                    model,
-                    model_dir,
-                    device_map="auto", # Automatically map layers across GPUs
-                    offload_folder=offload_dir, # Offload parts to disk if needed
-                    max_memory={i: "24GiB" for i in range(8)} # Set memory limit per GPU
-                )
+            # Load and dispatch model across GPUs
+            model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype="auto")
+            model = load_checkpoint_and_dispatch(
+                model,
+                model_dir,
+                device_map="auto",
+                offload_folder=offload_dir,
+                max_memory={i: "24GiB" for i in range(8)}
+            )
 
-                # Load the tokenizer
-                tokenizer = AutoTokenizer.from_pretrained(model_dir)
+            print("Model and tokenizer loaded successfully.")
 
+        except Exception as e:
+            print(f"Error loading model: {str(e)}")
+            raise
         finally:
-            # Remove the in-progress flag once the loading is complete
             if os.path.exists(in_progress_file):
                 os.remove(in_progress_file)
-
-            # Release the lock
             fcntl.flock(lock, fcntl.LOCK_UN)
 
     return model, tokenizer
 
-# Custom predict function for SageMaker
-def predict_fn(input_data, model_and_tokenizer,context=None):
+def predict_fn(input_data, model_and_tokenizer, context=None):
     """
     Generate predictions for the input data.
     """
     try:
         model, tokenizer = model_and_tokenizer
-        data = json.loads(input_data)
+        if model is None or tokenizer is None:
+            raise ValueError("Model or tokenizer is None. Please ensure they are loaded correctly.")
 
-        # Format the prompt using Qwen's chat template
+        data = json.loads(input_data)
         messages = data.get("messages", [])
         formatted_prompt = format_chat(messages, tokenizer)
 
-        # Tokenize the input
-        inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0") # Send input to GPU 0 for generation
+        inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0")
 
-        # Generate output
         outputs = model.generate(
             inputs['input_ids'],
             max_new_tokens=data.get("max_new_tokens", 512),
@@ -107,10 +98,8 @@ def predict_fn(input_data, model_and_tokenizer,context=None):
             do_sample=True
         )
 
-        # Decode the output
         generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
 
-        # Build response
         response = {
             "id": "chatcmpl-uuid",
             "object": "chat.completion",
@@ -134,16 +123,8 @@ def predict_fn(input_data, model_and_tokenizer,context=None):
     except Exception as e:
         return {"error": str(e), "details": repr(e)}
 
-# Define input format for SageMaker
-def input_fn(serialized_input_data, content_type,context=None):
-    """
-    Prepare the input data for inference.
-    """
+def input_fn(serialized_input_data, content_type, context=None):
     return serialized_input_data
 
-# Define output format for SageMaker
 def output_fn(prediction_output, accept, context=None):
-    """
-    Convert the model output to a JSON response.
-    """
     return json.dumps(prediction_output)
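
For reference, a minimal smoke test of the handler chain touched by this commit could look like the sketch below. It assumes the functions above are importable from inference.py and that the model weights and GPUs configured in model_fn are actually present; the model directory path and the request values are illustrative placeholders, not part of this commit.

import json

from inference import model_fn, input_fn, predict_fn, output_fn  # assumed import path

# Hypothetical chat-completion request; predict_fn parses the body with
# json.loads and reads at least the "messages" and "max_new_tokens" keys
# shown in the diff above.
payload = json.dumps({
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"}
    ],
    "max_new_tokens": 256
})

model_and_tokenizer = model_fn("/opt/ml/model")     # /opt/ml/model is SageMaker's usual model dir
body = input_fn(payload, "application/json")        # passthrough in this handler
prediction = predict_fn(body, model_and_tokenizer)  # formats the chat and generates on cuda:0
print(output_fn(prediction, "application/json"))    # JSON chat.completion-style response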