Imran1 committed
Commit f74b603
1 Parent(s): 1d61c11

Update code/inference.py

Files changed (1)
  1. code/inference.py +40 -39
code/inference.py CHANGED
@@ -3,51 +3,50 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from typing import List, Dict
 from accelerate import load_checkpoint_and_dispatch
-
 # Global variables to persist the model and tokenizer between invocations
 model = None
 tokenizer = None
 
 # Function to format chat messages using Qwen's chat template
 def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
-    """
-    Format chat messages using Qwen's chat template.
-    """
     return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
-# Model loading function for SageMaker with tensor parallelism and offloading
+# Model loading function for SageMaker with tensor parallelism and FP8 quantization
 def model_fn(model_dir, context=None):
-    """
-    Load the model and tokenizer from the model directory for inference.
-    Supports tensor parallelism across multiple GPUs with offloading.
-    The model is loaded only once and stored in a global variable.
-    """
-    global model, tokenizer  # Declare model and tokenizer as global to persist across invocations
+    global model, tokenizer
 
-    if model is None:  # Check if the model is already loaded
-        print("Loading the model and tokenizer...")
-        # Define an offload directory for any model components that can't fit in GPU memory
-        offload_dir = "/tmp/offload_dir"  # Ensure SageMaker has write access to this directory
+    if model is None:
+        print("Loading the FP8 quantized model and tokenizer...")
+
+        # Define an offload directory
+        offload_dir = "/tmp/offload_dir"
+        os.makedirs(offload_dir, exist_ok=True)
+
+        # Load the tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(model_dir)
 
-        # Load and dispatch the model across multiple GPUs
-        model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
+        # Load the FP8 quantized model
+        model = AutoModelForCausalLM.from_pretrained(
+            model_dir,
+            torch_dtype=torch.float8,  # Specify FP8 dtype
+            low_cpu_mem_usage=True,
+            device_map="auto",
+            offload_folder=offload_dir,
+        )
+
+        # Use load_checkpoint_and_dispatch for tensor parallelism
         model = load_checkpoint_and_dispatch(
             model,
             model_dir,
-            device_map="auto",  # Automatically map model layers across devices
-            offload_folder=offload_dir,  # Offload parts of the model to disk if GPU memory is insufficient
+            device_map="auto",
+            offload_folder=offload_dir,
+            no_split_module_classes=["QWenLMHeadModel"],  # Adjust if needed for Qwen architecture
         )
-
-        # Load the tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(model_dir)
 
     return model, tokenizer
 
 # Custom predict function for SageMaker
 def predict_fn(input_data, model_and_tokenizer):
-    """
-    Generate predictions for the input data.
-    """
     try:
         model, tokenizer = model_and_tokenizer
         data = json.loads(input_data)
@@ -57,27 +56,28 @@ def predict_fn(input_data, model_and_tokenizer):
         formatted_prompt = format_chat(messages, tokenizer)
 
         # Tokenize the input
-        inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda:0")  # Send input to GPU 0 for generation
+        inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)
 
         # Generate output
-        outputs = model.generate(
-            inputs['input_ids'],
-            max_new_tokens=data.get("max_new_tokens", 512),
-            temperature=data.get("temperature", 0.7),
-            top_p=data.get("top_p", 0.9),
-            repetition_penalty=data.get("repetition_penalty", 1.0),
-            length_penalty=data.get("length_penalty", 1.0),
-            do_sample=True
-        )
+        with torch.no_grad():
+            outputs = model.generate(
+                inputs['input_ids'],
+                max_new_tokens=data.get("max_new_tokens", 512),
+                temperature=data.get("temperature", 0.7),
+                top_p=data.get("top_p", 0.9),
+                repetition_penalty=data.get("repetition_penalty", 1.0),
+                length_penalty=data.get("length_penalty", 1.0),
+                do_sample=True
+            )
 
         # Decode the output
         generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
 
         # Build response
         response = {
-            "id": "chatcmpl-uuid",
+            "id": "chatcmpl-fp8-quantized",
             "object": "chat.completion",
-            "model": "qwen-72b",
+            "model": "qwen-72b-fp8",
             "choices": [{
                 "index": 0,
                 "message": {
@@ -88,8 +88,8 @@ def predict_fn(input_data, model_and_tokenizer):
             }],
             "usage": {
                 "prompt_tokens": len(inputs['input_ids'][0]),
-                "completion_tokens": len(outputs[0]),
-                "total_tokens": len(inputs['input_ids'][0]) + len(outputs[0])
+                "completion_tokens": len(outputs[0]) - len(inputs['input_ids'][0]),
+                "total_tokens": len(outputs[0])
             }
         }
         return response
@@ -97,6 +97,7 @@ def predict_fn(input_data, model_and_tokenizer):
     except Exception as e:
         return {"error": str(e), "details": repr(e)}
 
+
 # Define input format for SageMaker
 def input_fn(serialized_input_data, content_type, context=None):
     """