alimrb commited on
Commit
2d38204
1 Parent(s): 7dc358b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -0
app.py CHANGED
@@ -20,6 +20,9 @@ def make_inference(product_name, product_description):
20
  return_tensors="pt",
21
  )
22
 
 
 
 
23
  with torch.cuda.amp.autocast():
24
  output_tokens = model.generate(**batch, max_new_tokens=50)
25
 
 
20
  return_tensors="pt",
21
  )
22
 
23
+ # Move batch to the same device as the model
24
+ batch = {k: v.to(model.device) for k, v in batch.items()}
25
+
26
  with torch.cuda.amp.autocast():
27
  output_tokens = model.generate(**batch, max_new_tokens=50)
28