Chris4K committed on
Commit
63635f8
·
verified ·
1 Parent(s): 208ea59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -40
app.py CHANGED
@@ -135,6 +135,29 @@ def top_p_sampling(model, tokenizer, prompt, top_p=0.9, num_samples=5):
135
  "final_result": outputs[0]
136
  }
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  def custom_strategy(prompt, flow):
139
  intermediate_results = []
140
  for step in flow:
@@ -190,18 +213,57 @@ def test_generation():
190
 
191
 
192
  #####
193
- import gradio as gr
 
 
 
 
194
  import pandas as pd
 
 
195
  import json
 
 
 
 
 
 
 
 
 
196
 
197
- def format_outputs(outputs):
198
- if isinstance(outputs, list):
199
- return "\n\n".join([f"Output {i+1}: {out}" for i, out in enumerate(outputs)])
200
- return outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  def run_single_strategy(prompt, strategy, num_samples):
203
  if not prompt:
204
- return "Please enter a prompt."
 
 
205
 
206
  strategies = {
207
  "Majority Voting": lambda: majority_voting(llama_model, llama_tokenizer, prompt, num_samples),
@@ -210,89 +272,133 @@ def run_single_strategy(prompt, strategy, num_samples):
210
  }
211
 
212
  if strategy not in strategies:
213
- return "Invalid strategy selected."
214
 
215
  result = strategies[strategy]()
 
 
 
 
 
 
 
 
 
216
 
217
  formatted_output = f"""
218
- ### Final Result:
 
 
219
  {result['final_result']}
220
 
221
- ### All Outputs:
 
 
222
  {format_outputs(result['outputs'])}
 
 
 
 
 
 
223
  """
224
- return formatted_output
 
225
 
226
  def run_all_strategies(prompt, num_samples):
227
  if not prompt:
228
- return "Please enter a prompt."
229
 
230
- strategies_results, results_df = compare_strategies(
231
- llama_model, llama_tokenizer, prm_model, prompt, num_samples
232
- )
 
233
 
234
- # Format the output for display
235
  output_text = "# Results from All Strategies\n\n"
236
- for strategy, results in strategies_results.items():
 
 
 
 
 
 
 
 
 
 
 
 
237
  output_text += f"""
238
  ## {strategy}
239
- ### Final Result:
240
- {results['final_result']}
241
-
242
- ### All Outputs:
243
- {format_outputs(results['outputs'])}
244
-
245
  ---
246
  """
247
 
248
- return output_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
- # Create the Gradio interface
251
- with gr.Blocks(title="Text Generation Strategies") as demo:
252
- gr.Markdown("# Text Generation Strategies Demo")
253
 
254
  with gr.Row():
255
- with gr.Column():
256
  prompt_input = gr.Textbox(
257
  label="Enter your prompt",
258
  placeholder="Type your prompt here...",
259
  lines=3
260
  )
261
- num_samples = gr.Slider(
262
- minimum=1,
263
- maximum=10,
264
- value=5,
265
- step=1,
266
- label="Number of samples/beams"
267
- )
268
-
269
  with gr.Row():
 
 
 
 
 
 
 
270
  strategy_dropdown = gr.Dropdown(
271
  choices=["Majority Voting", "Best-of-N", "Beam Search"],
272
  label="Select Strategy",
273
  value="Majority Voting"
274
  )
275
-
276
  with gr.Row():
277
  single_strategy_btn = gr.Button("Run Selected Strategy")
278
  all_strategies_btn = gr.Button("Run All Strategies")
279
 
280
- with gr.Column():
281
  output_display = gr.Markdown(label="Results")
 
 
 
 
 
282
 
283
  # Set up event handlers
284
  single_strategy_btn.click(
285
  fn=run_single_strategy,
286
  inputs=[prompt_input, strategy_dropdown, num_samples],
287
- outputs=output_display
288
  )
289
 
290
  all_strategies_btn.click(
291
  fn=run_all_strategies,
292
  inputs=[prompt_input, num_samples],
293
- outputs=output_display
294
  )
295
 
296
- # Launch the interface
297
  if __name__ == "__main__":
298
  demo.launch(debug=True)
 
135
  "final_result": outputs[0]
136
  }
137
 
138
def _generate_and_score(text, max_new_tokens=50):
    """Extend ``text`` with the Llama model and score the decoded result.

    Returns a ``(response, score)`` tuple where ``score`` is the mean of the
    PRM model's logits for the decoded response.
    """
    input_ids = llama_tokenizer(text, return_tensors="pt").input_ids.to(device)
    output = llama_model.generate(input_ids, max_new_tokens=max_new_tokens)
    response = llama_tokenizer.decode(output[0], skip_special_tokens=True)
    prm_inputs = prm_tokenizer(response, return_tensors="pt").to(device)
    # Scoring is inference only; skip autograd bookkeeping.
    with torch.no_grad():
        score = prm_model(**prm_inputs).logits.mean().item()
    return response, score


def dvts(prompt, depth=3, breadth=2, max_new_tokens=50):
    """Simplified DVTS: grow a pool of candidate responses and return the best.

    ``breadth`` candidates are generated from ``prompt``. Then, ``depth - 1``
    times, the ``breadth`` highest-scoring candidates in the pool are each
    extended once and the extensions are added to the pool. The candidate
    with the highest PRM score overall is returned.

    Args:
        prompt: Text to start generation from.
        depth: Total number of rounds (the initial round plus expansions).
        breadth: Candidates generated initially and expanded per round.
        max_new_tokens: Generation budget for every ``generate`` call.

    Returns:
        The decoded text of the highest-scoring candidate.

    Raises:
        ValueError: If ``breadth`` is smaller than 1, because no candidate
            would be generated.
    """
    if breadth < 1:
        raise ValueError("breadth must be at least 1")

    results = [_generate_and_score(prompt, max_new_tokens) for _ in range(breadth)]

    # Select the top responses and expand them.
    for _ in range(depth - 1):
        best_responses = sorted(results, key=lambda item: item[1], reverse=True)[:breadth]
        results.extend(
            _generate_and_score(response, max_new_tokens)
            for response, _score in best_responses
        )

    # Return the best overall response.
    return max(results, key=lambda item: item[1])[0]
160
+
161
  def custom_strategy(prompt, flow):
162
  intermediate_results = []
163
  for step in flow:
 
213
 
214
 
215
  #####
216
+ import torch
217
+ from transformers import AutoModelForCausalLM, AutoTokenizer
218
+ from llama_cpp import Llama
219
+ from huggingface_hub import hf_hub_download
220
+ import matplotlib.pyplot as plt
221
  import pandas as pd
222
+ import gradio as gr
223
+ import time
224
  import json
225
+ import numpy as np
226
+ from datetime import datetime
227
+
228
def calculate_metrics(text):
    """Compute simple size statistics for a piece of generated text.

    Args:
        text: The text to measure. ``None`` is treated as an empty string.

    Returns:
        A dict with:
        - ``token_count``: number of whitespace-separated words,
        - ``char_count``: number of characters,
        - ``sentence_count``: number of non-empty segments between sentence
          terminators (``.``, ``!`` or ``?``).
    """
    import re

    text = "" if text is None else str(text)
    # Split on runs of terminators so "?" and "!" end sentences too and
    # an ellipsis ("...") does not produce empty segments.
    sentences = [part for part in re.split(r"[.!?]+", text) if part.strip()]
    return {
        'token_count': len(text.split()),
        'char_count': len(text),
        'sentence_count': len(sentences),
    }
234
 
235
def create_performance_plot(times, strategies):
    """Build a bar chart of generation time per strategy.

    The figure is created directly from ``matplotlib.figure.Figure`` rather
    than through ``pyplot``. Returning the ``pyplot`` module made every
    consumer render whatever the *current* figure was, so two plots built
    back to back showed the same chart. Figures created through ``pyplot``
    also stay registered until explicitly closed.

    Args:
        times: Generation times in seconds, one per strategy.
        strategies: Strategy names, in the same order as ``times``.

    Returns:
        The ``matplotlib.figure.Figure`` holding the chart.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()
    ax.bar(strategies, times)
    ax.set_title('Generation Time by Strategy')
    ax.set_ylabel('Time (seconds)')
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()
    return fig
243
+
244
def create_token_plot(tokens, strategies):
    """Build a bar chart of output token count per strategy.

    The figure is created directly from ``matplotlib.figure.Figure`` rather
    than through ``pyplot``. Returning the ``pyplot`` module made every
    consumer render whatever the *current* figure was, so two plots built
    back to back showed the same chart. Figures created through ``pyplot``
    also stay registered until explicitly closed.

    Args:
        tokens: Token counts, one per strategy.
        strategies: Strategy names, in the same order as ``tokens``.

    Returns:
        The ``matplotlib.figure.Figure`` holding the chart.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 5))
    ax = fig.subplots()
    ax.bar(strategies, tokens)
    ax.set_title('Output Token Count by Strategy')
    ax.set_ylabel('Number of Tokens')
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()
    return fig
252
+
253
def format_metrics(metrics):
    """Render a metrics dict as a Markdown bullet list.

    Args:
        metrics: Dict with ``token_count``, ``char_count`` and
            ``sentence_count``. ``generation_time`` (seconds) is optional;
            when absent it is shown as ``n/a`` instead of raising.

    Returns:
        A Markdown string that starts and ends with a newline.
    """
    generation_time = metrics.get('generation_time')
    if generation_time is None:
        time_text = "n/a"
    else:
        time_text = f"{generation_time:.2f}s"

    # Joined line by line so no source indentation leaks into the Markdown
    # (indented lines would render as a code block).
    lines = [
        "",
        "### Metrics",
        f"- Token Count: {metrics['token_count']}",
        f"- Character Count: {metrics['char_count']}",
        f"- Sentence Count: {metrics['sentence_count']}",
        f"- Generation Time: {time_text}",
        "",
    ]
    return "\n".join(lines)
261
 
262
  def run_single_strategy(prompt, strategy, num_samples):
263
  if not prompt:
264
+ return "Please enter a prompt.", None, None, None
265
+
266
+ start_time = time.time()
267
 
268
  strategies = {
269
  "Majority Voting": lambda: majority_voting(llama_model, llama_tokenizer, prompt, num_samples),
 
272
  }
273
 
274
  if strategy not in strategies:
275
+ return "Invalid strategy selected.", None, None, None
276
 
277
  result = strategies[strategy]()
278
+ generation_time = time.time() - start_time
279
+
280
+ # Calculate metrics
281
+ metrics = calculate_metrics(result['final_result'])
282
+ metrics['generation_time'] = generation_time
283
+
284
+ # Create visualizations
285
+ performance_fig = create_performance_plot([generation_time], [strategy])
286
+ token_fig = create_token_plot([metrics['token_count']], [strategy])
287
 
288
  formatted_output = f"""
289
+ # Results for {strategy}
290
+
291
+ ## Final Result
292
  {result['final_result']}
293
 
294
+ {format_metrics(metrics)}
295
+
296
+ ## All Outputs
297
  {format_outputs(result['outputs'])}
298
+
299
+ ## Generation Details
300
+ - Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
301
+ - Number of samples: {num_samples}
302
+ - Model: {MODEL_NAME}
303
+ - Device: {device}
304
  """
305
+
306
+ return formatted_output, performance_fig, token_fig, metrics
307
 
308
def run_all_strategies(prompt, num_samples):
    """Run every strategy on ``prompt`` and build a combined report.

    Args:
        prompt: The user prompt; a falsy value short-circuits with a message.
        num_samples: Number of samples/beams forwarded to each strategy.

    Returns:
        A 4-tuple ``(markdown_report, performance_figure, token_figure,
        metrics_by_strategy)``. The last three items are ``None`` when no
        prompt was given.
    """
    if not prompt:
        return "Please enter a prompt.", None, None, None

    strategies = ["Majority Voting", "Best-of-N", "Beam Search"]
    all_metrics = {}
    sections = ["# Results from All Strategies\n\n"]

    for strategy in strategies:
        start_time = time.time()
        # The per-strategy figures are not used here; only the comparison
        # charts built below are returned.
        result, _perf_fig, _token_fig, metrics = run_single_strategy(
            prompt, strategy, num_samples
        )
        elapsed = time.time() - start_time

        if metrics is None:
            # Error path of run_single_strategy: measure the message itself.
            metrics = calculate_metrics(result)
            metrics['generation_time'] = elapsed
        else:
            # Reuse the metrics computed on the generated text. Measuring
            # ``result`` would count the whole Markdown report, and timing
            # here would include plotting.
            metrics = dict(metrics)
        all_metrics[strategy] = metrics

        sections.append(f"\n## {strategy}\n{result}\n---\n")

    # Create comparison visualizations
    all_times = [all_metrics[name]['generation_time'] for name in strategies]
    all_tokens = [all_metrics[name]['token_count'] for name in strategies]
    performance_fig = create_performance_plot(all_times, strategies)
    token_fig = create_token_plot(all_tokens, strategies)

    # Add comparison summary
    sections.append("\n# Strategy Comparison Summary\n")
    for strategy, metrics in all_metrics.items():
        sections.append(f"\n## {strategy}\n{format_metrics(metrics)}\n")

    return "".join(sections), performance_fig, token_fig, all_metrics
352
 
353
+ # Create the enhanced Gradio interface
354
+ with gr.Blocks(title="Advanced Text Generation Strategies") as demo:
355
+ gr.Markdown("# Advanced Text Generation Strategies Demo")
356
 
357
  with gr.Row():
358
+ with gr.Column(scale=2):
359
  prompt_input = gr.Textbox(
360
  label="Enter your prompt",
361
  placeholder="Type your prompt here...",
362
  lines=3
363
  )
 
 
 
 
 
 
 
 
364
  with gr.Row():
365
+ num_samples = gr.Slider(
366
+ minimum=1,
367
+ maximum=10,
368
+ value=5,
369
+ step=1,
370
+ label="Number of samples/beams"
371
+ )
372
  strategy_dropdown = gr.Dropdown(
373
  choices=["Majority Voting", "Best-of-N", "Beam Search"],
374
  label="Select Strategy",
375
  value="Majority Voting"
376
  )
377
+
378
  with gr.Row():
379
  single_strategy_btn = gr.Button("Run Selected Strategy")
380
  all_strategies_btn = gr.Button("Run All Strategies")
381
 
382
+ with gr.Column(scale=3):
383
  output_display = gr.Markdown(label="Results")
384
+ with gr.Row():
385
+ performance_plot = gr.Plot(label="Performance Comparison")
386
+ token_plot = gr.Plot(label="Token Count Comparison")
387
+
388
+ metrics_display = gr.JSON(label="Detailed Metrics")
389
 
390
  # Set up event handlers
391
  single_strategy_btn.click(
392
  fn=run_single_strategy,
393
  inputs=[prompt_input, strategy_dropdown, num_samples],
394
+ outputs=[output_display, performance_plot, token_plot, metrics_display]
395
  )
396
 
397
  all_strategies_btn.click(
398
  fn=run_all_strategies,
399
  inputs=[prompt_input, num_samples],
400
+ outputs=[output_display, performance_plot, token_plot, metrics_display]
401
  )
402
 
 
403
  if __name__ == "__main__":
404
  demo.launch(debug=True)