Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -39,78 +39,47 @@ def load_model(model_name, quantized=False, quantized_model_path=None):
|
|
39 |
llama_model, llama_tokenizer = load_model(MODEL_NAME)
|
40 |
prm_model, _ = load_model(None, quantized=True, quantized_model_path=QUANTIZED_PRM_PATH)
|
41 |
|
42 |
-
|
|
|
43 |
outputs = []
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
# Prepare inputs
|
50 |
-
input_ids = tokenizer(prompt, return_tensors="pt", padding=True).input_ids.to(device)
|
51 |
-
|
52 |
-
for _ in range(num_samples):
|
53 |
-
output = model.generate(
|
54 |
-
input_ids,
|
55 |
-
max_new_tokens=50,
|
56 |
-
pad_token_id=tokenizer.pad_token_id,
|
57 |
-
)
|
58 |
-
outputs.append(tokenizer.decode(output[0], skip_special_tokens=True))
|
59 |
-
|
60 |
-
return {
|
61 |
-
"outputs": outputs,
|
62 |
-
"final_result": max(set(outputs), key=outputs.count)
|
63 |
-
}
|
64 |
|
65 |
-
def best_of_n(
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
input_ids = tokenizer(prompt, return_tensors="pt", padding=True).input_ids.to(device)
|
75 |
-
|
76 |
-
for _ in range(num_samples):
|
77 |
-
output = model.generate(
|
78 |
-
input_ids,
|
79 |
-
max_new_tokens=50,
|
80 |
-
pad_token_id=tokenizer.pad_token_id,
|
81 |
-
)
|
82 |
-
response = tokenizer.decode(output[0], skip_special_tokens=True)
|
83 |
-
score = len(response.split())
|
84 |
-
outputs.append((response, score))
|
85 |
-
|
86 |
-
outputs.sort(key=lambda x: x[1], reverse=True)
|
87 |
-
return {
|
88 |
-
"outputs": outputs,
|
89 |
-
"final_result": outputs[0][0]
|
90 |
-
}
|
91 |
|
92 |
-
def beam_search(
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
|
115 |
|
116 |
def temperature_sampling(model, tokenizer, prompt, temperature=0.7, num_samples=5):
|
@@ -135,29 +104,6 @@ def top_p_sampling(model, tokenizer, prompt, top_p=0.9, num_samples=5):
|
|
135 |
"final_result": outputs[0]
|
136 |
}
|
137 |
|
138 |
-
def dvts(prompt, depth=3, breadth=2):
|
139 |
-
"""
|
140 |
-
Simplified implementation of DVTS: generates a tree of solutions and evaluates branches using PRM.
|
141 |
-
"""
|
142 |
-
results = []
|
143 |
-
for _ in range(breadth):
|
144 |
-
input_ids = llama_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
145 |
-
output = llama_model.generate(input_ids, max_new_tokens=50)
|
146 |
-
response = llama_tokenizer.decode(output[0], skip_special_tokens=True)
|
147 |
-
score = prm_model(**prm_tokenizer(response, return_tensors="pt").to(device)).logits.mean().item()
|
148 |
-
results.append((response, score))
|
149 |
-
# Select the top responses and expand them recursively
|
150 |
-
for _ in range(depth - 1):
|
151 |
-
best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
|
152 |
-
for response, _ in best_responses:
|
153 |
-
input_ids = llama_tokenizer(response, return_tensors="pt").input_ids.to(device)
|
154 |
-
output = llama_model.generate(input_ids, max_new_tokens=50)
|
155 |
-
extended_response = llama_tokenizer.decode(output[0], skip_special_tokens=True)
|
156 |
-
score = prm_model(**prm_tokenizer(extended_response, return_tensors="pt").to(device)).logits.mean().item()
|
157 |
-
results.append((extended_response, score))
|
158 |
-
# Return the best overall response
|
159 |
-
return max(results, key=lambda x: x[1])[0]
|
160 |
-
|
161 |
def custom_strategy(prompt, flow):
|
162 |
intermediate_results = []
|
163 |
for step in flow:
|
@@ -231,7 +177,7 @@ from datetime import datetime
|
|
231 |
|
232 |
def calculate_metrics(text):
|
233 |
return {
|
234 |
-
'token_count': len(text.split()),
|
235 |
'char_count': len(text),
|
236 |
'sentence_count': len([s for s in text.split('.') if s.strip()]),
|
237 |
}
|
@@ -255,12 +201,14 @@ def create_token_plot(tokens, strategies):
|
|
255 |
return plt
|
256 |
|
257 |
def format_metrics(metrics):
|
|
|
|
|
258 |
return f"""
|
259 |
### Metrics
|
260 |
-
- Token Count: {metrics['token_count']}
|
261 |
-
- Character Count: {metrics['char_count']}
|
262 |
-
- Sentence Count: {metrics['sentence_count']}
|
263 |
-
- Generation Time: {metrics['generation_time']:.2f}s
|
264 |
"""
|
265 |
|
266 |
def run_single_strategy(prompt, strategy, num_samples):
|
|
|
39 |
llama_model, llama_tokenizer = load_model(MODEL_NAME)
|
40 |
prm_model, _ = load_model(None, quantized=True, quantized_model_path=QUANTIZED_PRM_PATH)
|
41 |
|
42 |
+
# Strategies
|
43 |
+
def majority_voting(prompt, num_samples=5):
|
44 |
outputs = []
|
45 |
+
for _ in range(num_samples):
|
46 |
+
input_ids = llama_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
47 |
+
output = llama_model.generate(input_ids, max_new_tokens=50)
|
48 |
+
outputs.append(llama_tokenizer.decode(output[0], skip_special_tokens=True))
|
49 |
+
return max(set(outputs), key=outputs.count)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
+
def best_of_n(prompt, num_samples=5):
|
52 |
+
scored_outputs = []
|
53 |
+
for _ in range(num_samples):
|
54 |
+
input_ids = llama_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
55 |
+
output = llama_model.generate(input_ids, max_new_tokens=50)
|
56 |
+
response = llama_tokenizer.decode(output[0], skip_special_tokens=True)
|
57 |
+
score = prm_model(**prm_tokenizer(response, return_tensors="pt").to(device)).logits.mean().item()
|
58 |
+
scored_outputs.append((response, score))
|
59 |
+
return max(scored_outputs, key=lambda x: x[1])[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
+
def beam_search(prompt, num_beams=5):
|
62 |
+
input_ids = llama_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
63 |
+
outputs = llama_model.generate(input_ids, max_new_tokens=50, num_beams=num_beams, num_return_sequences=num_beams)
|
64 |
+
return [llama_tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
|
65 |
+
|
66 |
+
def dvts(prompt, depth=3, breadth=2):
|
67 |
+
results = []
|
68 |
+
for _ in range(breadth):
|
69 |
+
input_ids = llama_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
70 |
+
output = llama_model.generate(input_ids, max_new_tokens=50)
|
71 |
+
response = llama_tokenizer.decode(output[0], skip_special_tokens=True)
|
72 |
+
score = prm_model(**prm_tokenizer(response, return_tensors="pt").to(device)).logits.mean().item()
|
73 |
+
results.append((response, score))
|
74 |
+
for _ in range(depth - 1):
|
75 |
+
best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
|
76 |
+
for response, _ in best_responses:
|
77 |
+
input_ids = llama_tokenizer(response, return_tensors="pt").input_ids.to(device)
|
78 |
+
output = llama_model.generate(input_ids, max_new_tokens=50)
|
79 |
+
extended_response = llama_tokenizer.decode(output[0], skip_special_tokens=True)
|
80 |
+
score = prm_model(**prm_tokenizer(extended_response, return_tensors="pt").to(device)).logits.mean().item()
|
81 |
+
results.append((extended_response, score))
|
82 |
+
return max(results, key=lambda x: x[1])[0]
|
83 |
|
84 |
|
85 |
def temperature_sampling(model, tokenizer, prompt, temperature=0.7, num_samples=5):
|
|
|
104 |
"final_result": outputs[0]
|
105 |
}
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
def custom_strategy(prompt, flow):
|
108 |
intermediate_results = []
|
109 |
for step in flow:
|
|
|
177 |
|
178 |
def calculate_metrics(text):
|
179 |
return {
|
180 |
+
'token_count': len(text.split()),
|
181 |
'char_count': len(text),
|
182 |
'sentence_count': len([s for s in text.split('.') if s.strip()]),
|
183 |
}
|
|
|
201 |
return plt
|
202 |
|
203 |
def format_metrics(metrics):
|
204 |
+
print(type(metrics)) # Check if it's a list or dictionary
|
205 |
+
print(metrics) # Inspect its contents
|
206 |
return f"""
|
207 |
### Metrics
|
208 |
+
- Token Count: {metrics[0]['token_count']}
|
209 |
+
- Character Count: {metrics[0]['char_count']}
|
210 |
+
- Sentence Count: {metrics[0]['sentence_count']}
|
211 |
+
- Generation Time: {metrics[0]['generation_time']:.2f}s
|
212 |
"""
|
213 |
|
214 |
def run_single_strategy(prompt, strategy, num_samples):
|