Imran1 commited on
Commit
28a42cc
1 Parent(s): 892e588

Update code/inference.py

Browse files
Files changed (1) hide show
  1. code/inference.py +5 -4
code/inference.py CHANGED
@@ -12,11 +12,12 @@ def model_fn(model_dir):
12
  model=model_dir,
13
  trust_remote_code=True,
14
  dtype="float16",
 
15
  gpu_memory_utilization=0.9,
16
  )
17
  return model
18
 
19
- def predict_fn(data, model,context = None):
20
  try:
21
  input_text = data.pop("inputs", data)
22
  parameters = data.pop("parameters", {})
@@ -36,11 +37,11 @@ def predict_fn(data, model,context = None):
36
  logger.error(f"Exception during prediction: {e}")
37
  return {"error": str(e)}
38
 
39
- def input_fn(request_body, request_content_type, context = None):
40
  if request_content_type == "application/json":
41
  return json.loads(request_body)
42
  else:
43
  raise ValueError(f"Unsupported content type: {request_content_type}")
44
 
45
- def output_fn(prediction, accept, context = None):
46
- return json.dumps(prediction)
 
12
  model=model_dir,
13
  trust_remote_code=True,
14
  dtype="float16",
15
+ tensor_parallel_size=4, # Use 4 GPUs for parallelization
16
  gpu_memory_utilization=0.9,
17
  )
18
  return model
19
 
20
+ def predict_fn(data, model, context=None):
21
  try:
22
  input_text = data.pop("inputs", data)
23
  parameters = data.pop("parameters", {})
 
37
  logger.error(f"Exception during prediction: {e}")
38
  return {"error": str(e)}
39
 
40
+ def input_fn(request_body, request_content_type, context=None):
41
  if request_content_type == "application/json":
42
  return json.loads(request_body)
43
  else:
44
  raise ValueError(f"Unsupported content type: {request_content_type}")
45
 
46
+ def output_fn(prediction, accept, context=None):
47
+ return json.dumps(prediction)