Commit
·
7eefd64
1
Parent(s):
fb2eb9b
Init commit
Browse files- assets/prompt-order-experiment.svg +3 -0
- prompt_order_exeriment/__init__.py +0 -0
- prompt_order_exeriment/prompt_order_exeriment.py +214 -0
- prompt_order_exeriment/sidebar.py +107 -0
- requirements.txt +6 -0
- rxconfig.py +5 -0
assets/prompt-order-experiment.svg
ADDED
prompt_order_exeriment/__init__.py
ADDED
File without changes
|
prompt_order_exeriment/prompt_order_exeriment.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import plotly.express as px
|
3 |
+
import reflex as rx
|
4 |
+
from datasets import load_dataset
|
5 |
+
from reflex_ag_grid import ag_grid
|
6 |
+
from sklearn.metrics import accuracy_score
|
7 |
+
|
8 |
+
from .sidebar import sidebar
|
9 |
+
|
10 |
+
# Load the HF dataset
|
11 |
+
dataset = load_dataset("derek-thomas/labeled-multiple-choice-explained-mistral-results")
|
12 |
+
|
13 |
+
# Convert the dataset to a Pandas DataFrame
|
14 |
+
df = dataset['train'].to_pandas()
|
15 |
+
|
16 |
+
# Columns to analyze
|
17 |
+
cols_to_analyze = [
|
18 |
+
"predictions_base",
|
19 |
+
"predictions_FA",
|
20 |
+
"predictions_RFA_mistral",
|
21 |
+
"predictions_FAR_mistral",
|
22 |
+
"predictions_RFA_gpt3_5",
|
23 |
+
"predictions_FAR_gpt3_5",
|
24 |
+
]
|
25 |
+
|
26 |
+
# Mapping for renaming models
|
27 |
+
model_names = {
|
28 |
+
"predictions_base": "Base Model",
|
29 |
+
"predictions_FA": "Final Answer",
|
30 |
+
"predictions_RFA_mistral": "Reasoning (Mistral) -> Final Answer)",
|
31 |
+
"predictions_FAR_mistral": "Final Answer -> Reasoning (Mistral)",
|
32 |
+
"predictions_RFA_gpt3_5": "Reasoning (GPT-3.5 ) -> Final Answer",
|
33 |
+
"predictions_FAR_gpt3_5": "Final Answer -> Reasoning(GPT-3.5)",
|
34 |
+
}
|
35 |
+
|
36 |
+
# Compute metrics for each model
|
37 |
+
metrics_data = []
|
38 |
+
for col in cols_to_analyze:
|
39 |
+
accuracy = round(accuracy_score(df["answer_key"], df[col]) * 100, 2)
|
40 |
+
metrics_data.append({"Prediction Type": model_names[col], "Accuracy (%)": accuracy})
|
41 |
+
|
42 |
+
# Create a DataFrame for metrics
|
43 |
+
metrics_df = pd.DataFrame(metrics_data)
|
44 |
+
|
45 |
+
# Column definitions for the metrics table
|
46 |
+
metrics_column_defs = [
|
47 |
+
ag_grid.column_def(field="Prediction Type", header_name="Prediction Type", width=250),
|
48 |
+
ag_grid.column_def(field="Accuracy (%)", header_name="Accuracy (%)"),
|
49 |
+
]
|
50 |
+
|
51 |
+
|
52 |
+
# Function to generate the topic performance star chart
|
53 |
+
def topic_star_chart():
|
54 |
+
# Calculate per-topic accuracy
|
55 |
+
topic_accuracy = []
|
56 |
+
for topic in df["topic"].unique():
|
57 |
+
topic_data = df[df["topic"] == topic]
|
58 |
+
for col in cols_to_analyze:
|
59 |
+
accuracy = round((topic_data[col] == topic_data["answer_key"]).mean() * 100, 2)
|
60 |
+
topic_accuracy.append({"Topic": topic, "Prediction Type": model_names[col], "Accuracy (%)": accuracy})
|
61 |
+
|
62 |
+
# Create DataFrame for visualization
|
63 |
+
topic_df = pd.DataFrame(topic_accuracy)
|
64 |
+
|
65 |
+
# Find the top 10 topics by number of rows
|
66 |
+
topic_counts = df["topic"].value_counts().head(10).index
|
67 |
+
filtered_topic_df = topic_df[topic_df["Topic"].isin(topic_counts)]
|
68 |
+
|
69 |
+
# Create star chart (radar chart)
|
70 |
+
fig = px.line_polar(
|
71 |
+
filtered_topic_df,
|
72 |
+
r="Accuracy (%)",
|
73 |
+
theta="Topic",
|
74 |
+
color="Prediction Type",
|
75 |
+
title="Top 10 Topics: Per-Topic Performance Star Chart",
|
76 |
+
line_close=True, # Close the lines to form a star shape
|
77 |
+
)
|
78 |
+
fig.update_layout(width=900, height=900)
|
79 |
+
|
80 |
+
return fig
|
81 |
+
|
82 |
+
|
83 |
+
p1 = '''
|
84 |
+
# Prompt Order Experiment
|
85 |
+
## 🎯 Goal
|
86 |
+
This experiment aims to explore various scenarios for **prompt fine-tuning** using structured generation. We'll test how the order of elements in a prompt affects model performance. The elements we consider are:
|
87 |
+
- **(Q)**: Question
|
88 |
+
- **(AC)**: Answer Choices
|
89 |
+
- **(R)**: Reasoning
|
90 |
+
- **(FA)**: Final Answer
|
91 |
+
|
92 |
+
### Scenarios
|
93 |
+
We will evaluate the following prompt orders:
|
94 |
+
|
95 |
+
1. **Scenario 1: `Q - AC - R - FA`** (Mistral and GPT3.5)
|
96 |
+
- This is the most natural order. The model generates reasoning before the final answer, providing the most information prior to making a selection. This order leverages decoding mechanics effectively.
|
97 |
+
|
98 |
+
This is our user message, we can see the question and answer choices
|
99 |
+
```
|
100 |
+
\<s>[INST] Answer the Question and include your Reasoning and the Final Answer in a json like: {"Reasoning: "...", "Final Answer": "x"} where x is a letter that corresponds to the answer choice which is a letter between a and h.
|
101 |
+
Question: What is satellite technology used for predicting?
|
102 |
+
Answer Choices: (a) Seconds and minutes (b) The strength and magnitude of an earthquake (c) What it's like outside each day (d) 70-75 degrees fahrenheit (e) Rapid changes occur (f) Dead-ends and false starts. (g) Snow, ice, and rock (h) Around 5 to 27 degrees celsius[/INST]
|
103 |
+
```
|
104 |
+
|
105 |
+
This is our assistant message, you can see that we are forcing a JSON (note I added spacing for visual purposes), and we are putting the reasoning first.
|
106 |
+
```
|
107 |
+
{
|
108 |
+
"Reasoning": "a) Seconds and minutes: This option is incorrect because satellite technology is not used for predicting time intervals. Satellite technology is used for various purposes such as communication, navigation, and weather forecasting, but it is not used for predicting time intervals.\n\nb) The strength and magnitude of an earthquake: This option is incorrect because satellite technology is not used for predicting earthquakes. Earthquake prediction is a complex process that involves seismology and other scientific methods, but satellite technology is not one of them.\n\nd) 70-75 degrees Fahrenheit: This option is incorrect because satellite technology is not used for predicting specific temperature ranges. While satellite technology can provide temperature data, it is not used for predicting specific temperature ranges.\n\ne) Rapid changes occur: This option is too vague and does not provide enough information to determine whether it is correct or not. Satellite technology can be used to monitor changes in various environmental factors, but it is not used specifically for predicting rapid changes.\n\nf) Dead-ends and false starts: This option is incorrect because it is not related to satellite technology or any type of prediction.\n\ng) Snow, ice, and rock: This option is incorrect because it is too specific and does not cover the broad range of predictions that satellite technology can be used for. While satellite technology can be used to monitor snow, ice, and rock formations, it is not used exclusively for this purpose.\n\nh) Around 5 to 27 degrees Celsius: This option is incorrect because it is too specific and does not cover the broad range of temperature predictions that satellite technology can be used for. While satellite technology can provide temperature data, it is not used exclusively for predicting temperatures within a specific range.\n\nTherefore, the correct answer is c) what it's like outside each day, as satellite technology is commonly used for weather forecasting and predicting daily weather conditions.",
|
109 |
+
"Final Answer": "c"
|
110 |
+
}</s>
|
111 |
+
```
|
112 |
+
|
113 |
+
2. **Scenario 2: `Q - AC - FA - R`** (Mistral and GPT3.5)
|
114 |
+
- An awkward order, placing reasoning after the final answer. While it is faster, it assumes the model can "know" reasoning internally before generating it. This approach saves tokens but is a skeptical case worth testing.
|
115 |
+
|
116 |
+
```
|
117 |
+
\<s>[INST] Answer the Question and include your Final Answer and the Reasoning in a json like: {"Final Answer": "x", "Reasoning: "..."} where x is a letter that corresponds to the answer choice which is a letter between a and h.
|
118 |
+
Question: What is satellite technology used for predicting?
|
119 |
+
Answer Choices: (a) Seconds and minutes (b) The strength and magnitude of an earthquake (c) What it's like outside each day (d) 70-75 degrees fahrenheit (e) Rapid changes occur (f) Dead-ends and false starts. (g) Snow, ice, and rock (h) Around 5 to 27 degrees celsius[/INST]
|
120 |
+
```
|
121 |
+
|
122 |
+
```
|
123 |
+
{
|
124 |
+
'Final Answer': 'c',
|
125 |
+
'Reasoning': "a) Seconds and minutes: This option is incorrect because satellite technology is not used for predicting time intervals. Satellite technology is used for various purposes such as communication, navigation, and weather forecasting, but it is not used for predicting time intervals.\n\nb) The strength and magnitude of an earthquake: This option is incorrect because satellite technology is not used for predicting earthquakes. Earthquake prediction is a complex process that involves seismology and other scientific methods, but satellite technology is not one of them.\n\nd) 70-75 degrees Fahrenheit: This option is incorrect because satellite technology is not used for predicting specific temperature ranges. While satellite technology can provide temperature data, it is not used for predicting specific temperature ranges.\n\ne) Rapid changes occur: This option is too vague and does not provide enough information to determine whether it is correct or not. Satellite technology can be used to monitor changes in various environmental factors, but it is not used specifically for predicting rapid changes.\n\nf) Dead-ends and false starts: This option is incorrect because it is not related to satellite technology or any type of prediction.\n\ng) Snow, ice, and rock: This option is incorrect because it is too specific and does not cover the broad range of predictions that satellite technology can be used for. While satellite technology can be used to monitor snow, ice, and rock formations, it is not used exclusively for this purpose.\n\nh) Around 5 to 27 degrees Celsius: This option is incorrect because it is too specific and does not cover the broad range of temperature predictions that satellite technology can be used for. While satellite technology can provide temperature data, it is not used exclusively for predicting temperatures within a specific range.\n\nTherefore, the correct answer is c) what it's like outside each day, as satellite technology is commonly used for weather forecasting and predicting daily weather conditions."
|
126 |
+
}</s>
|
127 |
+
```
|
128 |
+
|
129 |
+
3. **Scenario 3: `Q - AC - FA`**
|
130 |
+
- This serves as a fine-tuning control. No reasoning is provided in the output.
|
131 |
+
|
132 |
+
4. **Scenario 4: Base**
|
133 |
+
- An un-fine-tuned control for comparison purposes.
|
134 |
+
|
135 |
+
### Structured Generation
|
136 |
+
Structured generation ensures consistent response formats, which is crucial for reliable fine-tuning. Initial experiments faced difficulties with response consistency and structured generation can solve this.
|
137 |
+
|
138 |
+
'''
|
139 |
+
|
140 |
+
p2 = '''
|
141 |
+
# Steps
|
142 |
+
### Dataset Selection
|
143 |
+
We begin with the [layoric/labeled-multiple-choice-explained](https://huggingface.co/datasets/layoric/labeled-multiple-choice-explained) dataset, which includes reasoning provided by GPT-3.5-turbo. reasoning explanations serve as a starting point but may differ from Mistral's reasoning style.
|
144 |
+
|
145 |
+
0. 00-poe-generate-mistral-reasoning.ipynb: To align with Mistral, we need to create a refined dataset: [derek-thomas/labeled-multiple-choice-explained-mistral-reasoning](https://huggingface.co/datasets/derek-thomas/labeled-multiple-choice-explained-mistral-reasoning) with [poe-generate-mistral-reasoning.ipynb](./poe-generate-mistral-reasoning.ipynb).
|
146 |
+
1. 01-poe-dataset-creation.ipynb: Then we need to create our prompt experiments
|
147 |
+
2. 02-autotrain.ipynb: We generate autotrain jobs on spaces to train our models
|
148 |
+
3. 03-poe-token-count-exploration.ipynb: We do some quick analysis so we can optimize our TGI settings
|
149 |
+
4. 04-poe-eval.ipynb: We finally evaluate our trained models
|
150 |
+
|
151 |
+
'''
|
152 |
+
|
153 |
+
|
154 |
+
def mermaid_svg():
|
155 |
+
with open('assets/poe.svg', 'r') as file:
|
156 |
+
svg_content = file.read()
|
157 |
+
|
158 |
+
return rx.html(
|
159 |
+
f'<div style="width: 300%; height: auto;">{svg_content}</div>'
|
160 |
+
)
|
161 |
+
|
162 |
+
|
163 |
+
def overview():
|
164 |
+
return rx.hstack(
|
165 |
+
sidebar(),
|
166 |
+
rx.vstack(
|
167 |
+
rx.markdown(p2),
|
168 |
+
mermaid_svg(),
|
169 |
+
)
|
170 |
+
)
|
171 |
+
|
172 |
+
|
173 |
+
def results():
|
174 |
+
return rx.hstack(
|
175 |
+
sidebar(),
|
176 |
+
rx.vstack(
|
177 |
+
rx.heading("Results", size="8", margin="20px 0"),
|
178 |
+
# rx.markdown(diagram),
|
179 |
+
ag_grid(
|
180 |
+
id="ag_grid_metrics",
|
181 |
+
row_data=metrics_df.to_dict("records"),
|
182 |
+
column_defs=metrics_column_defs,
|
183 |
+
width="60%",
|
184 |
+
margin="20px auto", # Center the table
|
185 |
+
),
|
186 |
+
rx.divider(),
|
187 |
+
rx.heading("Top 10 Topics: Per-Topic Performance Star Chart", size="8", margin="20px 0"),
|
188 |
+
rx.text(
|
189 |
+
"The chart below shows how each model performed across the top 10 topics by row count. "
|
190 |
+
"Each line represents a model, and the chart allows you to compare their performance across topics.",
|
191 |
+
font_size="md",
|
192 |
+
padding="10px",
|
193 |
+
),
|
194 |
+
rx.plotly(data=topic_star_chart()), # Render the radar chart
|
195 |
+
padding="20px",
|
196 |
+
|
197 |
+
)
|
198 |
+
)
|
199 |
+
|
200 |
+
|
201 |
+
# Main page of the Reflex app
|
202 |
+
def index():
|
203 |
+
return rx.hstack(
|
204 |
+
sidebar(),
|
205 |
+
rx.vstack(
|
206 |
+
rx.markdown(p1),
|
207 |
+
))
|
208 |
+
|
209 |
+
|
210 |
+
# Initialize the Reflex app with polished layout
|
211 |
+
app = rx.App()
|
212 |
+
app.add_page(overview, title='Overview', route='/overview') # Add the index page
|
213 |
+
app.add_page(results, title='Results', route='/results') # Add the index page
|
214 |
+
app.add_page(index) # Add the index page
|
prompt_order_exeriment/sidebar.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import reflex as rx
|
2 |
+
|
3 |
+
|
4 |
+
def sidebar_item(
|
5 |
+
text: str, icon: str, href: str
|
6 |
+
) -> rx.Component:
|
7 |
+
return rx.link(
|
8 |
+
rx.hstack(
|
9 |
+
rx.icon(icon),
|
10 |
+
rx.text(text, size="4"),
|
11 |
+
width="100%",
|
12 |
+
padding_x="0.5rem",
|
13 |
+
padding_y="0.75rem",
|
14 |
+
align="center",
|
15 |
+
style={
|
16 |
+
"_hover": {
|
17 |
+
"bg": rx.color("accent", 4),
|
18 |
+
"color": rx.color("accent", 11),
|
19 |
+
},
|
20 |
+
"border-radius": "0.5em",
|
21 |
+
},
|
22 |
+
),
|
23 |
+
href=href,
|
24 |
+
underline="none",
|
25 |
+
weight="medium",
|
26 |
+
width="100%",
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
def sidebar_items() -> rx.Component:
|
31 |
+
return rx.vstack(
|
32 |
+
sidebar_item("Prompt Order Experiment", "square-library", "/#"),
|
33 |
+
sidebar_item("Overview", "layout-dashboard", "/overview"),
|
34 |
+
sidebar_item("Results", "bar-chart-4", "/results"),
|
35 |
+
# sidebar_item("Messages", "mail", "/#"),
|
36 |
+
spacing="1",
|
37 |
+
width="100%",
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
def sidebar() -> rx.Component:
|
42 |
+
return rx.box(
|
43 |
+
rx.desktop_only(
|
44 |
+
rx.vstack(
|
45 |
+
rx.hstack(
|
46 |
+
rx.image(
|
47 |
+
src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg",
|
48 |
+
width="2.25em",
|
49 |
+
height="auto",
|
50 |
+
border_radius="25%",
|
51 |
+
),
|
52 |
+
rx.heading(
|
53 |
+
"Prompt Order Experiment", size="7", weight="bold"
|
54 |
+
),
|
55 |
+
align="center",
|
56 |
+
justify="start",
|
57 |
+
padding_x="0.5rem",
|
58 |
+
width="100%",
|
59 |
+
),
|
60 |
+
sidebar_items(),
|
61 |
+
spacing="5",
|
62 |
+
# position="fixed",
|
63 |
+
# left="0px",
|
64 |
+
# top="0px",
|
65 |
+
# z_index="5",
|
66 |
+
padding_x="1em",
|
67 |
+
padding_y="1.5em",
|
68 |
+
bg=rx.color("accent", 3),
|
69 |
+
align="start",
|
70 |
+
# height="100%",
|
71 |
+
height="650px",
|
72 |
+
width="16em",
|
73 |
+
),
|
74 |
+
),
|
75 |
+
rx.mobile_and_tablet(
|
76 |
+
rx.drawer.root(
|
77 |
+
rx.drawer.trigger(
|
78 |
+
rx.icon("align-justify", size=30)
|
79 |
+
),
|
80 |
+
rx.drawer.overlay(z_index="5"),
|
81 |
+
rx.drawer.portal(
|
82 |
+
rx.drawer.content(
|
83 |
+
rx.vstack(
|
84 |
+
rx.box(
|
85 |
+
rx.drawer.close(
|
86 |
+
rx.icon("x", size=30)
|
87 |
+
),
|
88 |
+
width="100%",
|
89 |
+
),
|
90 |
+
sidebar_items(),
|
91 |
+
spacing="5",
|
92 |
+
width="100%",
|
93 |
+
),
|
94 |
+
top="auto",
|
95 |
+
right="auto",
|
96 |
+
height="100%",
|
97 |
+
width="20em",
|
98 |
+
padding="1.5em",
|
99 |
+
bg=rx.color("accent", 2),
|
100 |
+
),
|
101 |
+
width="100%",
|
102 |
+
),
|
103 |
+
direction="left",
|
104 |
+
),
|
105 |
+
padding="1em",
|
106 |
+
),
|
107 |
+
)
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
reflex==0.6.6.post3
|
2 |
+
reflex-ag-grid==0.0.10
|
3 |
+
scikit-learn==1.5.2
|
4 |
+
plotly==5.24.1
|
5 |
+
pandas==2.2.3
|
6 |
+
datasets==3.1.0
|
rxconfig.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import reflex as rx
|
2 |
+
|
3 |
+
config = rx.Config(
|
4 |
+
app_name="prompt_order_exeriment",
|
5 |
+
)
|