Chris4K committed on
Commit 440099e · verified · 1 Parent(s): 6bea409

Create app.py

Files changed (1)
  1. app.py +134 -0
app.py ADDED
@@ -0,0 +1,134 @@
+ # Install necessary libraries
+ #!pip install transformers accelerate gradio
+
+ # Import libraries
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import gradio as gr
+
+ # Load Model and Tokenizer
+ MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
+ PRM_NAME = "RLHFlow/Llama3.1-8B-PRM"
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load a model and its tokenizer. device_map="auto" lets accelerate place the
+ # weights, so the model must not be moved again with .to(device).
+ def load_model(model_name):
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
+     return model, tokenizer
+
+ llama_model, llama_tokenizer = load_model(MODEL_NAME)
+
+ # Load Process Reward Model (PRM)
+ prm_model, prm_tokenizer = load_model(PRM_NAME)
+
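+ # --- Shared helpers (a sketch added for clarity, not part of the original recipe) ---
+ # Scoring reuses the file's own heuristic: the mean of the PRM's output logits
+ # as a scalar quality score. That assumes a causal-LM-style head; published PRM
+ # checkpoints often expose per-step scores instead, so treat this as a rough proxy.
+ def score_with_prm(text):
+     inputs = prm_tokenizer(text, return_tensors="pt").to(device)
+     with torch.no_grad():
+         return prm_model(**inputs).logits.mean().item()
+
+ # Sampled generation: do_sample=True makes repeated calls return different
+ # candidates; with the default greedy decoding every "sample" would be identical
+ # and the voting/scoring strategies below would be pointless.
+ def generate_candidate(prompt, max_new_tokens=50):
+     input_ids = llama_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+     output = llama_model.generate(input_ids, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.8)
+     return llama_tokenizer.decode(output[0], skip_special_tokens=True)
+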
+ # Strategies
+ # Majority voting (self-consistency): sample several candidates and return the
+ # one that appears most often.
+ def majority_voting(prompt, num_samples=5):
+     outputs = [generate_candidate(prompt) for _ in range(num_samples)]
+     # Return the most common result
+     return max(set(outputs), key=outputs.count)
+
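+ # Example of the voting rule: given candidates [A, A, B, A, C] this returns A
+ # (3 of 5 votes); ties resolve arbitrarily because sets are unordered.
+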
+ # Best-of-N: sample N candidates independently, score each with the PRM, and
+ # keep the single highest-scoring one.
+ def best_of_n(prompt, num_samples=5):
+     scored_outputs = []
+     for _ in range(num_samples):
+         response = generate_candidate(prompt)
+         scored_outputs.append((response, score_with_prm(response)))
+     # Return the highest scored response
+     return max(scored_outputs, key=lambda x: x[1])[0]
+
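+ # Design note: majority voting needs no verifier but only helps when correct
+ # answers repeat across samples; Best-of-N leans on the PRM instead, so a single
+ # good sample suffices, at the cost of one extra scoring pass per candidate.
+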
+ # Beam search: a deterministic alternative that keeps num_beams hypotheses per
+ # step and returns the finished beams.
+ def beam_search(prompt, num_beams=5):
+     input_ids = llama_tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+     outputs = llama_model.generate(input_ids, max_new_tokens=50, num_beams=num_beams, num_return_sequences=num_beams)
+     # Join the beams so Gradio's single text output can display all of them.
+     return "\n\n".join(llama_tokenizer.decode(output, skip_special_tokens=True) for output in outputs)
+
+ def dvts(prompt, depth=3, breadth=2):
+     """
+     Simplified DVTS: grow a shallow tree of candidate solutions and use the
+     PRM to decide which branches to extend.
+     """
+     # Root level: sample `breadth` independent candidates.
+     results = []
+     for _ in range(breadth):
+         response = generate_candidate(prompt)
+         results.append((response, score_with_prm(response)))
+     # Expand the current best-scoring branches for the remaining depth.
+     for _ in range(depth - 1):
+         best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
+         for response, _ in best_responses:
+             # Continue generating from the branch text (a crude expansion step).
+             extended_response = generate_candidate(response)
+             results.append((extended_response, score_with_prm(extended_response)))
+     # Return the best overall response
+     return max(results, key=lambda x: x[1])[0]
+
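+ # Cost check: each outer iteration extends `breadth` branches, so dvts makes
+ # depth * breadth generation calls in total; with the defaults (3, 2) that is
+ # 2 roots plus 2 + 2 extensions = 6 generations, each with one PRM scoring pass.
+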
+ # Gradio Interface
+ def inference(prompt, strategy, num_samples, depth, breadth):
+     # Slider values arrive as floats, so cast before using them as counts.
+     if strategy == "Majority Voting":
+         return majority_voting(prompt, int(num_samples))
+     elif strategy == "Best-of-N":
+         return best_of_n(prompt, int(num_samples))
+     elif strategy == "Beam Search":
+         # The "Number of Samples" slider doubles as the beam count here.
+         return beam_search(prompt, int(num_samples))
+     elif strategy == "DVTS":
+         return dvts(prompt, int(depth), int(breadth))
+     else:
+         return "Invalid Strategy"
+
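+ # Quick programmatic check (hypothetical prompt; uncomment to try without the UI):
+ # print(inference("What is 12 * 17?", "Best-of-N", 3, 3, 2))
+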
+ gr.Interface(
+     fn=inference,
+     inputs=[
+         gr.Textbox(label="Problem Statement", placeholder="Enter your problem here"),
+         gr.Radio(
+             ["Majority Voting", "Best-of-N", "Beam Search", "DVTS"],
+             label="Inference Strategy",
+         ),
+         gr.Slider(1, 10, step=1, value=5, label="Number of Samples"),
+         gr.Slider(1, 5, step=1, value=3, label="Depth (DVTS Only)"),
+         gr.Slider(1, 5, step=1, value=2, label="Breadth (DVTS Only)"),
+     ],
+     outputs="text",
+     title="Dynamic Inference Toolkit",
+     description="Explore test-time compute scaling strategies with Meta's LLaMA model.",
+ ).launch()