Update code/inference.py
Browse files- code/inference.py +1 -1
code/inference.py
CHANGED
@@ -66,7 +66,7 @@ def model_fn(model_dir, context=None):
|
|
66 |
os.makedirs(offload_dir, exist_ok=True)
|
67 |
|
68 |
# Load and dispatch model across 4 GPUs using tensor parallelism
|
69 |
-
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=
|
70 |
model = load_checkpoint_and_dispatch(
|
71 |
model,
|
72 |
model_dir,
|
|
|
66 |
os.makedirs(offload_dir, exist_ok=True)
|
67 |
|
68 |
# Load and dispatch model across 4 GPUs using tensor parallelism
|
69 |
+
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype="auto")
|
70 |
model = load_checkpoint_and_dispatch(
|
71 |
model,
|
72 |
model_dir,
|