Imran1 committed
Commit 325c145
1 Parent(s): 629b5aa

Update code/inference.py

Files changed (1)
  1. code/inference.py +29 -9
code/inference.py CHANGED
@@ -4,7 +4,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 from typing import List, Dict
 from accelerate import load_checkpoint_and_dispatch
 import fcntl # For file locking
-
+import os
 # Global variables to persist the model and tokenizer between invocations
 model = None
 tokenizer = None
@@ -16,12 +16,13 @@ def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
     """
     return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

-# Model loading function for SageMaker with tensor parallelism and offloading
 def model_fn(model_dir, context=None):
     global model, tokenizer
-
+
     # Path to lock file for ensuring single loading
     lock_file = "/tmp/model_load.lock"
+    # Path to in-progress file indicating model loading is happening
+    in_progress_file = "/tmp/model_loading_in_progress"

     if model is not None:
         print("Model already loaded, skipping reload.")
@@ -33,27 +34,46 @@ def model_fn(model_dir, context=None):
         fcntl.flock(lock, fcntl.LOCK_EX) # Exclusive lock

         try:
+            # Check if another worker is in the process of loading
+            if os.path.exists(in_progress_file):
+                print("Another worker is currently loading the model, waiting...")
+
+                # Poll the in-progress flag until the other worker finishes loading
+                while os.path.exists(in_progress_file):
+                    time.sleep(5) # Wait for 5 seconds before checking again
+
+                print("Loading complete by another worker, skipping reload.")
+                return model, tokenizer
+
+            # If no one is loading, start loading the model and set the in-progress flag
+            print("No one is loading, proceeding to load the model.")
+            with open(in_progress_file, 'w') as f:
+                f.write("loading")
+
             if model is None:
                 print("Loading the model and tokenizer...")

                 offload_dir = "/tmp/offload_dir"
                 os.makedirs(offload_dir, exist_ok=True)

-                # Load and dispatch model
-                model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
+                # Load and dispatch model across 8 GPUs using tensor parallelism
+                model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype="auto")
                 model = load_checkpoint_and_dispatch(
                     model,
                     model_dir,
-                    device_map="auto",
-                    offload_folder=offload_dir
+                    device_map="auto", # Automatically map layers across GPUs
+                    offload_folder=offload_dir, # Offload parts to disk if needed
+                    max_memory={i: "24GiB" for i in range(8)} # Set memory limit per GPU
                 )

                 # Load the tokenizer
                 tokenizer = AutoTokenizer.from_pretrained(model_dir)

-            else:
-                print("Another process loaded the model while waiting for the lock.")
         finally:
+            # Remove the in-progress flag once the loading is complete
+            if os.path.exists(in_progress_file):
+                os.remove(in_progress_file)
+
             # Release the lock
             fcntl.flock(lock, fcntl.LOCK_UN)
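The new polling branch calls time.sleep, so the module also needs an import time alongside the new import os. Beyond that, the commit only touches model_fn. As a usage illustration, here is a minimal sketch of how the (model, tokenizer) pair returned by model_fn could be consumed by the rest of inference.py, assuming the standard SageMaker inference-toolkit hooks (input_fn and predict_fn) and a chat-style JSON request. The hook names come from the toolkit; the request shape, generation parameters, and defaults below are assumptions for illustration, not code from this repository.

# Illustrative only: not part of this commit. Assumes the SageMaker inference
# toolkit calls model_fn once per worker and passes its return value to
# predict_fn for every request.
import json

import torch


def input_fn(request_body, request_content_type="application/json"):
    # Assumed request shape: {"messages": [{"role": ..., "content": ...}], "max_new_tokens": 256}
    if request_content_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return json.loads(request_body)


def predict_fn(data, model_and_tokenizer):
    model, tokenizer = model_and_tokenizer

    # Reuse the format_chat helper already defined in inference.py
    prompt = format_chat(data["messages"], tokenizer)

    # Send inputs to the device holding the first model shard
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=data.get("max_new_tokens", 256),  # illustrative default
            do_sample=False,
        )

    # Decode only the newly generated tokens
    new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
    return {"generated_text": tokenizer.decode(new_tokens, skip_special_tokens=True)}

Because model and tokenizer are cached in module-level globals and dispatched with device_map="auto", repeated predict_fn calls in the same worker reuse the already-placed weights rather than reloading them.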