Update app.py
Browse files
app.py
CHANGED
@@ -52,13 +52,13 @@ footer {
|
|
52 |
}
|
53 |
"""
|
54 |
|
55 |
-
MODEL_ID = "
|
56 |
|
57 |
model = AutoModelForCausalLM.from_pretrained(
|
58 |
MODEL_ID,
|
59 |
torch_dtype=torch.float16,
|
60 |
device_map="auto"
|
61 |
-
)
|
62 |
|
63 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
64 |
|
@@ -130,24 +130,27 @@ async def gen_show(script):
|
|
130 |
|
131 |
@spaces.GPU
|
132 |
def generator(messages):
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
return_tensors='pt'
|
138 |
)
|
|
|
139 |
|
140 |
-
|
141 |
-
input_ids
|
142 |
-
eos_token_id=tokenizer.eos_token_id,
|
143 |
max_new_tokens=4096,
|
144 |
temperature=0.5,
|
145 |
repetition_penalty=1.2,
|
146 |
)
|
147 |
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
151 |
|
152 |
def extract_content(text):
|
153 |
"""Extracts the JSON content from the given text."""
|
|
|
52 |
}
|
53 |
"""
|
54 |
|
55 |
+
MODEL_ID = "Qwen/Qwen2-1.5B-Instruct"
|
56 |
|
57 |
model = AutoModelForCausalLM.from_pretrained(
|
58 |
MODEL_ID,
|
59 |
torch_dtype=torch.float16,
|
60 |
device_map="auto"
|
61 |
+
)
|
62 |
|
63 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
64 |
|
|
|
130 |
|
131 |
@spaces.GPU
def generator(messages):
    """Generate a chat completion for *messages* with the module-level model.

    Args:
        messages: Conversation in the chat-template format expected by the
            tokenizer (presumably a list of {"role": ..., "content": ...}
            dicts — confirm against the caller).

    Returns:
        The decoded model response as a string, with the prompt tokens
        stripped so only the newly generated text remains.
    """
    # Render the conversation into a single prompt string; the trailing
    # generation prompt tells the model to answer as the assistant.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # Move inputs to the model's actual device instead of hard-coding
    # GPU 0 (.to(0)), so this also works when device_map="auto" has
    # placed the weights elsewhere.
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=4096,
        # do_sample=True is required: without it, decoding is greedy and
        # the temperature setting below is silently ignored.
        do_sample=True,
        temperature=0.5,
        repetition_penalty=1.2,
    )

    # Slice off the echoed prompt tokens from each output sequence.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(response)  # NOTE(review): debug output — consider the logging module
    return response
|
154 |
|
155 |
def extract_content(text):
|
156 |
"""Extracts the JSON content from the given text."""
|