Imran1 committed
Commit 544c001
1 Parent(s): 13252ff

Update code/inference.py

Files changed (1)
  code/inference.py  +24 -16
code/inference.py CHANGED
@@ -2,7 +2,11 @@ import json
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from typing import List, Dict
-from accelerate import infer_auto_device_map, dispatch_model, load_checkpoint_and_dispatch
+from accelerate import load_checkpoint_and_dispatch
+
+# Global variables to persist the model and tokenizer between invocations
+model = None
+tokenizer = None
 
 # Function to format chat messages using Qwen's chat template
 def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
@@ -16,25 +20,29 @@ def model_fn(model_dir, context=None):
     """
     Load the model and tokenizer from the model directory for inference.
     Supports tensor parallelism across multiple GPUs with offloading.
+    The model is loaded only once and stored in a global variable.
     """
-    # Define an offload directory for any model components that can't fit in GPU memory
-    offload_dir = "/tmp/offload_dir"  # Ensure SageMaker has write access to this directory
-
-    # Use `Accelerate` to load and dispatch the model across GPUs
-    model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
-
-    # Distribute the model across multiple GPUs
-    model = load_checkpoint_and_dispatch(
-        model,
-        model_dir,
-        device_map="auto",  # Automatically map model layers across devices
-        offload_folder=offload_dir,  # Offload parts of the model to disk if GPU memory is insufficient
-    )
-
-    # Load the tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_dir)
+    global model, tokenizer  # Declare model and tokenizer as global to persist across invocations
+
+    if model is None:  # Check if the model is already loaded
+        print("Loading the model and tokenizer...")
+        # Define an offload directory for any model components that can't fit in GPU memory
+        offload_dir = "/tmp/offload_dir"  # Ensure SageMaker has write access to this directory
+
+        # Load and dispatch the model across multiple GPUs
+        model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
+        model = load_checkpoint_and_dispatch(
+            model,
+            model_dir,
+            device_map="auto",  # Automatically map model layers across devices
+            offload_folder=offload_dir,  # Offload parts of the model to disk if GPU memory is insufficient
+        )
+
+        # Load the tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(model_dir)
 
     return model, tokenizer
+
 # Custom predict function for SageMaker
 def predict_fn(input_data, model_and_tokenizer):
     """