# Aidan-Bench / app.py
# Last commit by Presidentlin: 2cb6075 ("x") · 8 kB
import os
import re

import numpy as np
import requests
import streamlit as st

from main import get_novelty_score
from models import chat_with_model, embed
from prompts import questions as predefined_questions, create_gen_prompt, create_judge_prompt
# Set the title in the browser tab
st.set_page_config(page_title="Aidan Bench - Generator")
st.title("Aidan Bench - Generator")

# API key inputs. Confirmed keys are kept in st.session_state so they survive
# the script reruns Streamlit performs on every widget interaction; the
# text boxes below are pre-filled from those stored values.
st.warning("Please keep your API keys secure and confidential. This app does not store or log your API keys.")
if "open_router_key" not in st.session_state:
    st.session_state.open_router_key = ""
if "openai_api_key" not in st.session_state:
    st.session_state.openai_api_key = ""

# .strip() removes whitespace picked up when pasting a key, so a blank or
# space-padded entry is not accepted as a real key by the check below.
open_router_key = st.text_input(
    "Enter your Open Router API Key:",
    type="password",
    value=st.session_state.open_router_key,
).strip()
openai_api_key = st.text_input(
    "Enter your OpenAI API Key:",
    type="password",
    value=st.session_state.openai_api_key,
).strip()

# Keys are only committed to session state when the user confirms and both
# fields are non-empty.
if st.button("Confirm API Keys"):
    if open_router_key and openai_api_key:
        st.session_state.open_router_key = open_router_key
        st.session_state.openai_api_key = openai_api_key
        st.success("API keys confirmed!")
    else:
        st.warning("Please enter both API keys.")
# Judge model that rates the coherence of every generated answer.
JUDGE_MODEL = "openai/gpt-4o-mini"
# A question's run stops at the first answer whose coherence score is <= this.
COHERENCE_THRESHOLD = 3
# A question's run stops at the first answer whose novelty score is < this.
NOVELTY_THRESHOLD = 0.1
OPENROUTER_MODELS_URL = "https://openrouter.ai/api/v1/models"
# Seconds to wait for the model list; requests never times out on its own.
MODELS_REQUEST_TIMEOUT = 30
# Captures the integer the judge wraps in <coherence_score>...</coherence_score>.
_COHERENCE_SCORE_RE = re.compile(r"<coherence_score>\s*([+-]?\d+)\s*</coherence_score>")


def _fetch_model_names():
    """Return the OpenRouter model IDs, sorted alphabetically.

    Network/HTTP failures and unexpected response bodies are reported with
    ``st.error`` and yield an empty list instead of raising.
    """
    try:
        response = requests.get(OPENROUTER_MODELS_URL, timeout=MODELS_REQUEST_TIMEOUT)
        response.raise_for_status()  # Raise an exception for bad status codes
        return sorted(model["id"] for model in response.json()["data"])
    except (requests.exceptions.RequestException, KeyError, TypeError, ValueError) as e:
        st.error(f"Error fetching models from OpenRouter API: {e}")
        return []


def _chat(prompt, model):
    """Send *prompt* to *model* via ``chat_with_model`` using the session's API keys."""
    return chat_with_model(
        prompt=prompt,
        model=model,
        open_router_key=st.session_state.open_router_key,
        openai_api_key=st.session_state.openai_api_key,
    )


def _parse_coherence_score(judge_response):
    """Return the integer inside the judge's <coherence_score> tags, or None.

    None means the tags are missing or do not wrap an integer.
    """
    match = _COHERENCE_SCORE_RE.search(judge_response or "")
    return int(match.group(1)) if match else None


def _run_question(question, model_name):
    """Generate answers to *question* until one is rejected.

    Each round asks *model_name* for a new answer (given the answers accepted
    so far), has JUDGE_MODEL rate its coherence, and scores its novelty
    against the accepted answers. The run ends on an API error, an
    unparseable judge reply, an incoherent or redundant answer, or any other
    exception (reported with ``st.error``).

    Returns:
        list[dict]: one entry per accepted answer, in generation order, with
        keys "answer", "coherence_score" and "novelty_score". Answers accepted
        before a failure are kept.
    """
    accepted = []
    previous_answers = []
    try:
        # NOTE(review): there is no cap on rounds; a model that keeps giving
        # coherent, novel answers keeps this loop (and API spend) running.
        while True:
            try:
                new_answer = _chat(create_gen_prompt(question, previous_answers), model_name)
            except requests.exceptions.RequestException as e:
                st.error(f"API Error: {e}")
                break
            try:
                judge_response = _chat(create_judge_prompt(question, new_answer), JUDGE_MODEL)
            except requests.exceptions.RequestException as e:
                st.error(f"API Error (Judge): {e}")
                break
            coherence_score = _parse_coherence_score(judge_response)
            if coherence_score is None:
                st.warning("Could not read a coherence score from the judge. Moving to next question.")
                break
            if coherence_score <= COHERENCE_THRESHOLD:
                st.warning("Output is incoherent. Moving to next question.")
                break
            novelty_score = get_novelty_score(new_answer, previous_answers, st.session_state.openai_api_key)
            if novelty_score < NOVELTY_THRESHOLD:
                st.warning("Output is redundant. Moving to next question.")
                break
            st.write(f"New Answer:\n{new_answer}")
            st.write(f"Coherence Score: {coherence_score}")
            st.write(f"Novelty Score: {novelty_score}")
            previous_answers.append(new_answer)
            accepted.append({
                "answer": new_answer,
                "coherence_score": coherence_score,
                "novelty_score": novelty_score,
            })
    except Exception as e:  # per-question boundary: report, keep partial results
        st.error(f"Error processing question: {e}")
    return accepted


# Access API keys from session state; the app only runs once both are confirmed.
if st.session_state.open_router_key and st.session_state.openai_api_key:
    # Model Selection
    model_names = _fetch_model_names()
    if not model_names:
        st.error("No models available. Please check your API connection.")
        st.stop()  # Stop execution if no models are available
    model_name = st.selectbox("Select a Language Model", model_names)

    # User-defined questions persist across reruns in session state.
    if "user_questions" not in st.session_state:
        st.session_state.user_questions = []

    # Workflow Selection
    workflow = st.radio("Select Workflow:", ["Use Predefined Questions", "Use User-Defined Questions"])

    if workflow == "Use Predefined Questions":
        st.header("Question Selection")
        selected_questions = st.multiselect(
            "Select questions to benchmark:",
            predefined_questions,
            predefined_questions,  # Select all by default
        )
    else:  # "Use User-Defined Questions"; `else` keeps selected_questions always bound
        st.header("Question Input")
        new_question = st.text_input("Enter a new question:")
        if st.button("Add Question") and new_question:
            new_question = new_question.strip()  # Remove leading/trailing whitespace
            if new_question and new_question not in st.session_state.user_questions:
                st.session_state.user_questions.append(new_question)
                st.success(f"Question '{new_question}' added successfully.")
            else:
                st.warning("Question already exists or is empty!")
        selected_questions = st.multiselect(
            "Select your custom questions:",
            options=st.session_state.user_questions,
            default=st.session_state.user_questions,
        )

    st.write("Selected Questions:", selected_questions)

    # Benchmark Execution
    if st.button("Start Benchmark"):
        if not selected_questions:
            st.warning("Please select at least one question.")
        else:
            progress_bar = st.progress(0)
            num_questions = len(selected_questions)
            results = []
            for i, question in enumerate(selected_questions, start=1):
                st.write(f"Processing question {i}/{num_questions}: {question}")
                results.append({"question": question, "answers": _run_question(question, model_name)})
                progress_bar.progress(i / num_questions)
            st.success("Benchmark completed!")

            # One row per accepted answer, each with its own scores.
            st.write("Results:")
            results_table = []
            for result in results:
                for entry in result["answers"]:
                    results_table.append({
                        "Question": result["question"],
                        "Answer": entry["answer"],
                        "Coherence Score": entry["coherence_score"],
                        "Novelty Score": entry["novelty_score"],
                    })
            st.table(results_table)

            # Per-question summary: number of accepted answers and their summed novelty.
            st.write("Total Novelty per Question:")
            st.table([
                {
                    "Question": result["question"],
                    "Accepted Answers": len(result["answers"]),
                    "Total Novelty": sum(entry["novelty_score"] for entry in result["answers"]),
                }
                for result in results
            ])
else:
    st.warning("Please confirm your API keys first.")