import pandas as pd | |
import streamlit as st | |
from langchain import PromptTemplate, HuggingFaceHub, LLMChain | |
from langchain.llms import OpenAI | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import os | |
import re | |
def extract_positive_negative(text):
    """Return every whole-word "positive" / "negative" found in *text*.

    Matching is case-sensitive and anchored on word boundaries, so
    "Positive" or "negatively" are not reported.

    Args:
        text: The string to scan.

    Returns:
        The matched words as a list, in order of appearance; empty when
        nothing matches.
    """
    sentiment_word = re.compile(r'\b(?:positive|negative)\b')
    return [hit.group(0) for hit in sentiment_word.finditer(text)]
def classify_text(text, llm_chain, api):
    """Classify one text with the given chain and return a normalised label.

    Args:
        text: The text to classify; handed to the chain as its single input.
        llm_chain: Object exposing ``run(text)`` that returns the model's
            answer as a string.
        api: Name of the backend, either ``"HuggingFace"`` or ``"OpenAI"``.

    Returns:
        The model's answer, lower-cased and with all whitespace removed.

    Raises:
        ValueError: If *api* is not one of the two supported backends.
    """
    if api not in ("HuggingFace", "OpenAI"):
        raise ValueError(f"Unsupported API: {api!r}")

    # NOTE(review): the right-hand side of the per-API assignments was missing
    # from the source. Both backends are assumed to be invoked through the
    # chain's single-input `run` method -- confirm against the LangChain
    # version in use.
    classification = llm_chain.run(text)

    # Models often pad the answer with spaces/newlines; strip every whitespace
    # character so the label compares cleanly.
    classification = re.sub(r'\s', '', classification)
    return classification.lower()
def classify_csv(df, llm_chain, api):
    """Add model predictions to a frame that already carries gold labels.

    The ``label`` column is moved to a new ``label_gold`` column, and a
    ``label_pred`` column is filled by running ``classify_text`` on every
    value of the ``text`` column. The frame is modified in place.

    Args:
        df: DataFrame with ``text`` and ``label`` columns.
        llm_chain: Chain forwarded to ``classify_text``.
        api: Backend name forwarded to ``classify_text``.

    Returns:
        The same DataFrame object, now with ``label_gold`` and ``label_pred``.
    """
    # pop() removes "label" and hands back its values in one step.
    df["label_gold"] = df.pop("label")

    def predict(sample):
        return classify_text(sample, llm_chain=llm_chain, api=api)

    df["label_pred"] = df["text"].apply(predict)
    return df
def classify_csv_zero(zero_file, llm_chain, api):
    """Read an unlabelled CSV and label every row with the model.

    Args:
        zero_file: Path or file-like object of a semicolon-separated CSV
            that has a ``text`` column.
        llm_chain: Chain forwarded to ``classify_text``.
        api: Backend name forwarded to ``classify_text``.

    Returns:
        The loaded DataFrame with a ``label`` column holding the predictions.
    """
    frame = pd.read_csv(zero_file, sep=';')
    frame["label"] = frame["text"].apply(
        lambda sample: classify_text(sample, llm_chain=llm_chain, api=api)
    )
    return frame
def evaluate_performance(df):
    """Return the percentage of rows whose prediction equals the gold label.

    Labels are compared exactly (``label_gold == label_pred``).

    Args:
        df: DataFrame with ``label_gold`` and ``label_pred`` columns.

    Returns:
        The share of matching rows as a float between 0 and 100. An empty
        frame yields 0.0 instead of dividing by zero.
    """
    total_preds = len(df)
    if total_preds == 0:
        # Nothing to score; avoid ZeroDivisionError on an empty upload.
        return 0.0

    correct_preds = int((df["label_gold"] == df["label_pred"]).sum())
    return correct_preds / total_preds * 100
def display_home():
    """Render the annotation page and run a classification on request.

    The page collects, in order: the API and model, the API key, the
    sampling temperature, the setup ("Test" with gold labels or "Zero-Shot"
    without), the CSV upload and the prompt template. When the run button
    is pressed the inputs are validated, an ``LLMChain`` is built, and the
    uploaded file is classified. Results are shown with a download button;
    in the "Test" setup the agreement with the gold labels is reported too.

    Each validation failure shows a warning and stops the run, so the
    classification code is never reached with a missing chain or file.
    """
    st.write("Please select an API and a model to classify the text. We currently support HuggingFace and OpenAI.")
    api = st.selectbox("Select an API", ["HuggingFace", "OpenAI"])
    if api == "HuggingFace":
        model = st.selectbox("Select a model", ["google/flan-t5-xl", "databricks/dolly-v1-6b"])
        api_key = st.text_input("HuggingFace API Key")
    else:
        model = None
        api_key = st.text_input("OpenAI API Key")

    st.write("Please select a temperature for the model. The higher the temperature, the more creative the model will be.")
    temperature = st.slider("Set the temperature", min_value=0.0, max_value=1.0, value=0.0, step=0.01)

    st.write("We provide two different setups for the annotation task. In the first setup (**Test**), you can upload a CSV file with gold labels and evaluate the performance of the model. In the second setup (**Zero-Shot**), you can upload a CSV file without gold labels and use the model to classify the text.")
    setup = st.selectbox("Setup", ["Test", "Zero-Shot"])
    # Bind both upload handles on every path so neither is ever undefined.
    gold_file = None
    zero_file = None
    if setup == "Test":
        gold_file = st.file_uploader("Upload Gold Labels CSV file with a text and a label column", type=["csv"])
    elif setup == "Zero-Shot":
        zero_file = st.file_uploader("Upload CSV file with a text column", type=["csv"])

    st.write("Please enter the prompt template below. You can use the following variables: {text} (text to classify).")
    prompt_template = st.text_area("Enter your task description", """Instruction: Identify the sentiment of a text. Please read the text and provide one of these responses: "positive" or "negative".\nText to classify in "positive" or "negative": {text}\nAnswer:""", height=200)
    classify_button = st.button("Run Classification/ Annotation")

    if not classify_button:
        return

    # The emptiness check must look at the raw template text: the
    # PromptTemplate object does not exist yet when the text area is empty.
    if not prompt_template:
        st.warning("Please enter a prompt question to classify the text.")
        return
    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["text"]
    )

    # Without a key no chain can be built, so stop here instead of falling
    # through to the classification code with `llm_chain` undefined.
    if api == "HuggingFace":
        if not api_key:
            st.warning("Please enter your HuggingFace API key to classify the text.")
            return
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
        llm_chain = LLMChain(prompt=prompt, llm=HuggingFaceHub(repo_id=model, model_kwargs={"temperature": temperature, "max_length": 128}))
    else:
        if not api_key:
            st.warning("Please enter your OpenAI API key to classify the text.")
            return
        os.environ["OPENAI_API_KEY"] = api_key
        llm_chain = LLMChain(prompt=prompt, llm=OpenAI(temperature=temperature))

    if setup == "Zero-Shot":
        if zero_file is None:
            st.warning("Please upload a CSV file with a text column to classify.")
            return
        df_predicted = classify_csv_zero(zero_file, llm_chain, api)
        st.write(df_predicted)
        st.download_button(
            label="Download CSV",
            data=df_predicted.to_csv(index=False),
            file_name="classified_zero-shot_data.csv",
            mime="text/csv"
        )
    elif setup == "Test":
        if gold_file is None:
            st.warning("Please upload a gold labels CSV file to evaluate the performance of the model.")
            return
        df = pd.read_csv(gold_file, sep=';')
        if "label" not in df.columns:
            st.warning("Please make sure that the gold labels CSV file contains a column named 'label'.")
            return
        df = classify_csv(df, llm_chain, api)
        st.write(df)
        st.download_button(
            label="Download CSV",
            data=df.to_csv(index=False),
            file_name="classified_test_data.csv",
            mime="text/csv"
        )
        percentage_overlap = evaluate_performance(df)
        st.write("**Performance Evaluation**")
        st.write(f"Percentage overlap between gold labels and predicted labels: {percentage_overlap:.2f}%")
def main():
    """Configure the Streamlit page, draw the sidebar and show the home view."""
    st.set_page_config(page_title="PromptCards Playground", page_icon=":pencil2:")
    st.title("AInnotator")

    # Seed the session-state keys once; later reruns keep whatever is stored.
    if "current_page" not in st.session_state:
        st.session_state.current_page = "homepage"
    if "selected_prompt" not in st.session_state:
        st.session_state.selected_prompt = ""

    st.sidebar.title("About")
    # NOTE(review): the non-ASCII runs in the sidebar strings below look like
    # mis-decoded emoji, and the markdown link in the second paragraph has no
    # target URL. The strings are kept as found -- restore them from the
    # upstream file.
    st.sidebar.write("AInnotator π€π·οΈ is a tool for creating artificial labels/ annotations. It is based on the concept of PromptCards, which are small, self-contained descriptions of a task that can be used to generate labels for a wide range of NLP tasks. Check out the GitHub repository and the PromptCards Archive for more information.")
    st.sidebar.write("---")
    st.sidebar.write("Check out the [PromptCards archive]( to find a wide range of prompts for different NLP tasks.")
    st.sidebar.write("---")
    st.sidebar.write("Made with β€οΈ and π€.")

    display_home()
# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()