Update code/inference.py
Files changed: code/inference.py (+24 -16)
code/inference.py
CHANGED
@@ -2,7 +2,11 @@ import json
|
|
2 |
import torch
|
3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
from typing import List, Dict
|
5 |
-
from accelerate import load_checkpoint_and_dispatch
|
|
|
|
|
|
|
|
|
6 |
|
7 |
# Function to format chat messages using Qwen's chat template
|
8 |
def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
|
@@ -16,25 +20,29 @@ def model_fn(model_dir, context=None):
|
|
16 |
"""
|
17 |
Load the model and tokenizer from the model directory for inference.
|
18 |
Supports tensor parallelism across multiple GPUs with offloading.
|
|
|
19 |
"""
|
20 |
-
|
21 |
-
offload_dir = "/tmp/offload_dir" # Ensure SageMaker has write access to this directory
|
22 |
-
|
23 |
-
# Use `Accelerate` to load and dispatch the model across GPUs
|
24 |
-
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
|
25 |
-
|
26 |
-
# Distribute the model across multiple GPUs
|
27 |
-
model = load_checkpoint_and_dispatch(
|
28 |
-
model,
|
29 |
-
model_dir,
|
30 |
-
device_map="auto", # Automatically map model layers across devices
|
31 |
-
offload_folder=offload_dir, # Offload parts of the model to disk if GPU memory is insufficient
|
32 |
-
)
|
33 |
|
34 |
-
# Load the tokenizer
|
35 |
-
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
return model, tokenizer
|
|
|
38 |
# Custom predict function for SageMaker
|
39 |
def predict_fn(input_data, model_and_tokenizer):
|
40 |
"""
|
|
|
2 |
import torch
|
3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
4 |
from typing import List, Dict
|
5 |
+
from accelerate import load_checkpoint_and_dispatch
|
6 |
+
|
7 |
+
# Global variables to persist the model and tokenizer between invocations
|
8 |
+
model = None
|
9 |
+
tokenizer = None
|
10 |
|
11 |
# Function to format chat messages using Qwen's chat template
|
12 |
def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
|
|
|
20 |
"""
|
21 |
Load the model and tokenizer from the model directory for inference.
|
22 |
Supports tensor parallelism across multiple GPUs with offloading.
|
23 |
+
The model is loaded only once and stored in a global variable.
|
24 |
"""
|
25 |
+
global model, tokenizer # Declare model and tokenizer as global to persist across invocations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
+
if model is None: # Check if the model is already loaded
|
28 |
+
print("Loading the model and tokenizer...")
|
29 |
+
# Define an offload directory for any model components that can't fit in GPU memory
|
30 |
+
offload_dir = "/tmp/offload_dir" # Ensure SageMaker has write access to this directory
|
31 |
+
|
32 |
+
# Load and dispatch the model across multiple GPUs
|
33 |
+
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
|
34 |
+
model = load_checkpoint_and_dispatch(
|
35 |
+
model,
|
36 |
+
model_dir,
|
37 |
+
device_map="auto", # Automatically map model layers across devices
|
38 |
+
offload_folder=offload_dir, # Offload parts of the model to disk if GPU memory is insufficient
|
39 |
+
)
|
40 |
+
|
41 |
+
# Load the tokenizer
|
42 |
+
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
43 |
|
44 |
return model, tokenizer
|
45 |
+
|
46 |
# Custom predict function for SageMaker
|
47 |
def predict_fn(input_data, model_and_tokenizer):
|
48 |
"""
|