Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -53,32 +53,25 @@ latex_delimiters_set = [{
 
 @spaces.GPU()
 def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
+    # Format history with a given chat template
 
-
+
+    stop_tokens = ["<|endoftext|>", "<|im_end|>","|im_end|"]
     instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
     for user, assistant in history:
         instruction += f'<|im_start|>user\n{user}\n<|im_end|>\n<|im_start|>assistant\n{assistant}\n<|im_end|>\n'
     instruction += f'<|im_start|>user\n{message}\n<|im_end|>\n<|im_start|>assistant\n'
 
-    print(
-
-
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
+    print(instruction)
 
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
     enc = tokenizer(instruction, return_tensors="pt", padding=True, truncation=True)
     input_ids, attention_mask = enc.input_ids, enc.attention_mask
 
-
     if input_ids.shape[1] > CONTEXT_LENGTH:
         input_ids = input_ids[:, -CONTEXT_LENGTH:]
         attention_mask = attention_mask[:, -CONTEXT_LENGTH:]
 
-
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-    # Define the generation parameters
     generate_kwargs = dict(
         input_ids=input_ids.to(device),
         attention_mask=attention_mask.to(device),
@@ -88,24 +81,18 @@ def predict(message, history, system_prompt, temperature, max_new_tokens, top_k,
         max_new_tokens=max_new_tokens,
         top_k=top_k,
         repetition_penalty=repetition_penalty,
-        top_p=top_p
-        pad_token_id=tokenizer.pad_token_id,  # Explicitly set pad_token_id
-        eos_token_id=tokenizer.eos_token_id,  # Explicitly set eos_token_id
+        top_p=top_p
     )
-
-    # Start the generation in a separate thread
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
-
-    # Stream the output token by token
     outputs = []
     for new_token in streamer:
         outputs.append(new_token)
-        if
+        if new_token in stop_tokens:
+
             break
     yield "".join(outputs)
 
-
 # Load model
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 quantization_config = BitsAndBytesConfig(