Update app.py
app.py CHANGED
@@ -56,7 +56,8 @@ def generate(
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-
+
+    input_ids = input_ids.to(device)  # Ensure the input tensor is on the correct device
 
     outputs = []
     generated_ids = model.generate(
@@ -69,9 +70,13 @@ def generate(
         num_beams=1,
         repetition_penalty=repetition_penalty
     )
+
+    generated_ids = generated_ids.to(device)  # Ensure the generated ids are moved to the device
+
     outputs.append(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
     return "".join(outputs)
 
+
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
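For context, the runtime error this commit addresses is the usual device mismatch: the model sits on the GPU while the tokenized inputs are CPU tensors, so generate() fails. Below is a minimal, self-contained sketch of the pattern the diff applies; the checkpoint name, token limit, and generation arguments are illustrative assumptions, not the Space's actual configuration.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MAX_INPUT_TOKEN_LENGTH = 4096   # assumed limit, for illustration only
model_id = "gpt2"               # placeholder checkpoint, not the Space's model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

def generate(message: str) -> str:
    input_ids = tokenizer(message, return_tensors="pt").input_ids
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    # The fix from the diff: move the inputs onto the model's device before
    # calling generate(); feeding CPU tensors to a CUDA model raises a RuntimeError.
    input_ids = input_ids.to(device)
    generated_ids = model.generate(input_ids, max_new_tokens=256, num_beams=1)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)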