CMLL committed on
Commit
6e999ef
1 Parent(s): 6c69482

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -1
app.py CHANGED
@@ -56,7 +56,8 @@ def generate(
56
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
57
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
58
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
59
- input_ids = input_ids.to(model.device)
 
60
 
61
  outputs = []
62
  generated_ids = model.generate(
@@ -69,9 +70,13 @@ def generate(
69
  num_beams=1,
70
  repetition_penalty=repetition_penalty
71
  )
 
 
 
72
  outputs.append(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
73
  return "".join(outputs)
74
 
 
75
  chat_interface = gr.ChatInterface(
76
  fn=generate,
77
  additional_inputs=[
 
56
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
57
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
58
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
59
+
60
+ input_ids = input_ids.to(device) # Ensure the input tensor is on the correct device
61
 
62
  outputs = []
63
  generated_ids = model.generate(
 
70
  num_beams=1,
71
  repetition_penalty=repetition_penalty
72
  )
73
+
74
+ generated_ids = generated_ids.to(device) # Ensure the generated ids are moved to the device
75
+
76
  outputs.append(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
77
  return "".join(outputs)
78
 
79
+
80
  chat_interface = gr.ChatInterface(
81
  fn=generate,
82
  additional_inputs=[