Update inference.py
Browse files- inference.py +1 -1
inference.py
CHANGED
@@ -17,7 +17,7 @@ def format_chat(messages: List[Dict[str, str]], tokenizer) -> str:
|
|
17 |
return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
18 |
|
19 |
# Model loading function for SageMaker with tensor parallelism
|
20 |
-
def model_fn(model_dir):
|
21 |
"""
|
22 |
Load the model and tokenizer from the model directory for inference.
|
23 |
This version supports tensor parallelism across 4 GPUs.
|
|
|
17 |
return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
18 |
|
19 |
# Model loading function for SageMaker with tensor parallelism
|
20 |
+
def model_fn(model_dir,context=None):
|
21 |
"""
|
22 |
Load the model and tokenizer from the model directory for inference.
|
23 |
This version supports tensor parallelism across 4 GPUs.
|