Imran1 commited on
Commit
54088e7
1 Parent(s): 2744dc1

Update code/inference.py

Browse files
Files changed (1) hide show
  1. code/inference.py +1 -1
code/inference.py CHANGED
@@ -66,7 +66,7 @@ def model_fn(model_dir, context=None):
66
  os.makedirs(offload_dir, exist_ok=True)
67
 
68
  # Load and dispatch model across 4 GPUs using tensor parallelism
69
- model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)
70
  model = load_checkpoint_and_dispatch(
71
  model,
72
  model_dir,
 
66
  os.makedirs(offload_dir, exist_ok=True)
67
 
68
  # Load and dispatch model across 4 GPUs using tensor parallelism
69
+ model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype="auto")
70
  model = load_checkpoint_and_dispatch(
71
  model,
72
  model_dir,