# Aidan-Bench / app.py
# Last commit a4e6a71 ("x") by Presidentlin; file size 8.01 kB.
import streamlit as st
from main import get_novelty_score
from models import chat_with_model, embed
from prompts import questions as predefined_questions, create_gen_prompt, create_judge_prompt
import requests
import numpy as np
import os
st.title("Aiden Bench - Generator")

# API Key Inputs with Security and User Experience Enhancements
st.warning("Please keep your API keys secure and confidential. This app does not store or log your API keys.")
# TODO: add links or instructions for obtaining Open Router and OpenAI API keys.
st.write("Learn how to obtain API keys from Open Router and OpenAI.")

# Seed the confirmed keys in session state so they persist across Streamlit
# reruns and can pre-fill the inputs below.
for _state_key in ("open_router_key", "openai_api_key"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = ""

# Strip surrounding whitespace: a key pasted with a leading/trailing space would
# otherwise be stored verbatim and sent to the APIs as part of the key.
open_router_key = st.text_input(
    "Enter your Open Router API Key:", type="password", value=st.session_state.open_router_key
).strip()
openai_api_key = st.text_input(
    "Enter your OpenAI API Key:", type="password", value=st.session_state.openai_api_key
).strip()

# Keys are only committed to session state once both are provided and confirmed.
if st.button("Confirm API Keys"):
    if open_router_key and openai_api_key:
        st.session_state.open_router_key = open_router_key
        st.session_state.openai_api_key = openai_api_key
        st.success("API keys confirmed!")
    else:
        st.warning("Please enter both API keys.")
# Benchmark configuration.
OPENROUTER_MODELS_URL = "https://openrouter.ai/api/v1/models"
MODELS_REQUEST_TIMEOUT = 30  # seconds; without a timeout requests.get can block the app indefinitely
JUDGE_MODEL = "openai/gpt-4o-mini"  # model that grades the coherence of each generated answer
INCOHERENCE_CUTOFF = 3  # an answer with coherence <= this ends the question
NOVELTY_CUTOFF = 0.1  # an answer with novelty < this ends the question


def _fetch_model_names():
    """Return the ids of the models listed by the OpenRouter API, sorted alphabetically.

    On a network/HTTP error or an unexpected response body the error is shown
    in the UI and an empty list is returned.
    """
    try:
        response = requests.get(OPENROUTER_MODELS_URL, timeout=MODELS_REQUEST_TIMEOUT)
        response.raise_for_status()  # Raise an exception for bad status codes
        models = response.json()["data"]
        return sorted(model["id"] for model in models)
    except (requests.exceptions.RequestException, KeyError, TypeError, ValueError) as e:
        st.error(f"Error fetching models from OpenRouter API: {e}")
        return []


def _select_questions():
    """Render the question-selection UI and return the questions the user selected."""
    if "user_questions" not in st.session_state:
        st.session_state.user_questions = []

    workflow = st.radio("Select Workflow:", ["Use Predefined Questions", "Use User-Defined Questions"])

    if workflow == "Use Predefined Questions":
        st.header("Question Selection")
        return st.multiselect(
            "Select questions to benchmark:",
            predefined_questions,
            predefined_questions,  # Select all by default
        )

    # User-defined questions: let the user grow a list kept in session state.
    st.header("Question Input")
    new_question = st.text_input("Enter a new question:")
    if st.button("Add Question") and new_question:
        new_question = new_question.strip()  # Remove leading/trailing whitespace
        if new_question and new_question not in st.session_state.user_questions:
            st.session_state.user_questions.append(new_question)
            st.success(f"Question '{new_question}' added successfully.")
        else:
            st.warning("Question already exists or is empty!")

    return st.multiselect(
        "Select your custom questions:",
        options=st.session_state.user_questions,
        default=st.session_state.user_questions,
    )


def _parse_coherence_score(judge_response):
    """Return the integer enclosed in <coherence_score>...</coherence_score> in the judge's reply.

    Raises:
        ValueError: if the tags are missing or the enclosed text is not an integer.
    """
    _, opening, remainder = judge_response.partition("<coherence_score>")
    score_text, closing, _ = remainder.partition("</coherence_score>")
    if not (opening and closing):
        raise ValueError(f"Judge response has no <coherence_score> tag: {judge_response!r}")
    return int(score_text)


def _benchmark_question(question, model_name):
    """Keep asking *model_name* for new answers to *question* until one is rejected.

    Each round generates an answer conditioned on the previously accepted ones,
    has JUDGE_MODEL grade its coherence, and scores its novelty against the
    earlier answers. The loop ends on an incoherent or redundant answer, on an
    API error, or on any other exception; failures are shown in the UI rather
    than raised, and the answers accepted so far are kept.

    Returns:
        One dict per accepted answer with keys "answer", "coherence_score" and
        "novelty_score" -- the scores are those of that specific answer.
    """
    open_router_key = st.session_state.open_router_key
    openai_api_key = st.session_state.openai_api_key
    accepted = []
    previous_answers = []
    try:
        while True:
            gen_prompt = create_gen_prompt(question, previous_answers)
            try:
                new_answer = chat_with_model(
                    prompt=gen_prompt,
                    model=model_name,
                    open_router_key=open_router_key,
                    openai_api_key=openai_api_key,
                )
            except requests.exceptions.RequestException as e:
                st.error(f"API Error: {e}")
                break

            judge_prompt = create_judge_prompt(question, new_answer)
            try:
                judge_response = chat_with_model(
                    prompt=judge_prompt,
                    model=JUDGE_MODEL,
                    open_router_key=open_router_key,
                    openai_api_key=openai_api_key,
                )
            except requests.exceptions.RequestException as e:
                st.error(f"API Error (Judge): {e}")
                break

            coherence_score = _parse_coherence_score(judge_response)
            if coherence_score <= INCOHERENCE_CUTOFF:
                st.warning("Output is incoherent. Moving to next question.")
                break

            novelty_score = get_novelty_score(new_answer, previous_answers, openai_api_key)
            if novelty_score < NOVELTY_CUTOFF:
                st.warning("Output is redundant. Moving to next question.")
                break

            st.write(f"New Answer:\n{new_answer}")
            st.write(f"Coherence Score: {coherence_score}")
            st.write(f"Novelty Score: {novelty_score}")
            previous_answers.append(new_answer)
            accepted.append({
                "answer": new_answer,
                "coherence_score": coherence_score,
                "novelty_score": novelty_score,
            })
    except Exception as e:  # broad on purpose: report and keep the answers gathered so far
        st.error(f"Error processing question: {e}")
    return accepted


def _render_results(results):
    """Show one table row per accepted answer, then each question's summed novelty."""
    st.write("Results:")
    st.table([
        {
            "Question": result["question"],
            "Answer": entry["answer"],
            "Coherence Score": entry["coherence_score"],
            "Novelty Score": entry["novelty_score"],
        }
        for result in results
        for entry in result["answers"]
    ])
    # Sum of novelty scores over each question's accepted answers.
    st.write("Total Novelty per Question:")
    st.table([
        {
            "Question": result["question"],
            "Accepted Answers": len(result["answers"]),
            "Total Novelty": sum(entry["novelty_score"] for entry in result["answers"]),
        }
        for result in results
    ])


def _run_benchmark(selected_questions, model_name):
    """Benchmark *model_name* on every selected question, tracking progress, then show results."""
    progress_bar = st.progress(0)
    num_questions = len(selected_questions)
    results = []
    for i, question in enumerate(selected_questions):
        st.write(f"Processing question {i+1}/{num_questions}: {question}")
        results.append({"question": question, "answers": _benchmark_question(question, model_name)})
        progress_bar.progress((i + 1) / num_questions)
    st.success("Benchmark completed!")
    _render_results(results)


# Access API keys from session state
if st.session_state.open_router_key and st.session_state.openai_api_key:
    # Model Selection
    model_names = _fetch_model_names()
    if not model_names:
        st.error("No models available. Please check your API connection.")
        st.stop()  # Stop execution if no models are available
    model_name = st.selectbox("Select a Language Model", model_names)

    selected_questions = _select_questions()
    # Display selected questions
    st.write("Selected Questions:", selected_questions)

    # Benchmark Execution
    if st.button("Start Benchmark"):
        if selected_questions:
            _run_benchmark(selected_questions, model_name)
        else:
            st.warning("Please select at least one question.")
else:
    st.warning("Please confirm your API keys first.")