Imran1 committed on
Commit 6c47fa9
Parent: cdf8b4a

Update code/inference.py

Files changed (1)
  1. code/inference.py +42 -44
code/inference.py CHANGED
@@ -1,53 +1,53 @@
- import json
+ import json
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer
  from typing import List, Dict
  from accelerate import load_checkpoint_and_dispatch
- import os
+
  # Global variables to persist the model and tokenizer between invocations
  model = None
  tokenizer = None

  # Function to format chat messages using Qwen's chat template
  def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
+     """
+     Format chat messages using Qwen's chat template.
+     """
      return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

- # Model loading function for SageMaker with tensor parallelism and FP8 quantization
+ # Model loading function for SageMaker with tensor parallelism and offloading
  def model_fn(model_dir, context=None):
-     global model, tokenizer
+     """
+     Load the model and tokenizer from the model directory for inference.
+     Supports tensor parallelism across multiple GPUs with offloading.
+     The model is loaded only once and stored in a global variable.
+     """
+     global model, tokenizer  # Declare model and tokenizer as global to persist across invocations

-     if model is None:
-         print("Loading the FP8 quantized model and tokenizer...")
-
-         # Define an offload directory
-         offload_dir = "/tmp/offload_dir"
-         os.makedirs(offload_dir, exist_ok=True)
-
-         # Load the tokenizer
-         tokenizer = AutoTokenizer.from_pretrained(model_dir)
+     if model is None:  # Check if the model is already loaded
+         print("Loading the model and tokenizer...")
+         # Define an offload directory for any model components that can't fit in GPU memory
+         offload_dir = "/tmp/offload_dir"  # Ensure SageMaker has write access to this directory

-         # Load the FP8 quantized model
-         model = AutoModelForCausalLM.from_pretrained(
-             model_dir,
-             torch_dtype=torch.float8,  # Specify FP8 dtype
-             low_cpu_mem_usage=True,
-             device_map="auto",
-             offload_folder=offload_dir,
-         )
-
-         # Use load_checkpoint_and_dispatch for tensor parallelism
+         # Load and dispatch the model across multiple GPUs
+         model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
          model = load_checkpoint_and_dispatch(
              model,
              model_dir,
-             device_map="auto",
-             offload_folder=offload_dir,
-             no_split_module_classes=["QWenLMHeadModel"],  # Adjust if needed for Qwen architecture
+             device_map="auto",  # Automatically map model layers across devices
+             offload_folder=offload_dir,  # Offload parts of the model to disk if GPU memory is insufficient
          )
+
+         # Load the tokenizer
+         tokenizer = AutoTokenizer.from_pretrained(model_dir)

      return model, tokenizer

  # Custom predict function for SageMaker
  def predict_fn(input_data, model_and_tokenizer):
+     """
+     Generate predictions for the input data.
+     """
      try:
          model, tokenizer = model_and_tokenizer
          data = json.loads(input_data)
@@ -57,28 +57,27 @@ def predict_fn(input_data, model_and_tokenizer):
          formatted_prompt = format_chat(messages, tokenizer)

          # Tokenize the input
-         inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)
+         inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0")  # Send input to GPU 0 for generation

          # Generate output
-         with torch.no_grad():
-             outputs = model.generate(
-                 inputs['input_ids'],
-                 max_new_tokens=data.get("max_new_tokens", 512),
-                 temperature=data.get("temperature", 0.7),
-                 top_p=data.get("top_p", 0.9),
-                 repetition_penalty=data.get("repetition_penalty", 1.0),
-                 length_penalty=data.get("length_penalty", 1.0),
-                 do_sample=True
-             )
+         outputs = model.generate(
+             inputs['input_ids'],
+             max_new_tokens=data.get("max_new_tokens", 512),
+             temperature=data.get("temperature", 0.7),
+             top_p=data.get("top_p", 0.9),
+             repetition_penalty=data.get("repetition_penalty", 1.0),
+             length_penalty=data.get("length_penalty", 1.0),
+             do_sample=True
+         )

          # Decode the output
          generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

          # Build response
          response = {
-             "id": "chatcmpl-fp8-quantized",
+             "id": "chatcmpl-uuid",
              "object": "chat.completion",
-             "model": "qwen-72b-fp8",
+             "model": "qwen-72b",
              "choices": [{
                  "index": 0,
                  "message": {
@@ -89,8 +88,8 @@ def predict_fn(input_data, model_and_tokenizer):
              }],
              "usage": {
                  "prompt_tokens": len(inputs['input_ids'][0]),
-                 "completion_tokens": len(outputs[0]) - len(inputs['input_ids'][0]),
-                 "total_tokens": len(outputs[0])
+                 "completion_tokens": len(outputs[0]),
+                 "total_tokens": len(inputs['input_ids'][0]) + len(outputs[0])
              }
          }
          return response
@@ -98,16 +97,15 @@ def predict_fn(input_data, model_and_tokenizer):
      except Exception as e:
          return {"error": str(e), "details": repr(e)}

-
  # Define input format for SageMaker
- def input_fn(serialized_input_data, content_type, context=None):
+ def input_fn(serialized_input_data, content_type,context=None):
      """
      Prepare the input data for inference.
      """
      return serialized_input_data

  # Define output format for SageMaker
- def output_fn(prediction_output, accept , context=None):
+ def output_fn(prediction_output, accept, context=None):
      """
      Convert the model output to a JSON response.
      """