|
import streamlit as st |
|
import pandas as pd |
|
|
|
@st.cache_data
def load_dataframe() -> pd.DataFrame:
    """
    Load the dataframe from the csv file in the public directory

    The file is read from "public/datasets/models_scores.csv" and the
    "Unnamed: 0" column is removed when it is present.

    Returns
        dataframe: a pd.DataFrame of the average scores of the LLMs on each task
    """
    dataframe = pd.read_csv("public/datasets/models_scores.csv")
    # "Unnamed: 0" is presumably a saved index column; errors="ignore" keeps
    # the load working when the csv was written without it.
    return dataframe.drop(columns="Unnamed: 0", errors="ignore")
|
|
|
@st.cache_data
def show_dataframe_top(n: int, dataframe: pd.DataFrame) -> pd.DataFrame:
    """
    Return only the first n rows of the dataframe

    Arguments:
        - n: an integer giving the number of rows to keep
        - dataframe: the dataframe to slice

    Returns:
        a pd.DataFrame holding the first n rows of dataframe
    """
    top_rows = dataframe.head(n)
    return top_rows
|
|
|
@st.cache_data
def sort_by(dataframe: pd.DataFrame, column_name: str, ascending: bool = False) -> pd.DataFrame:
    """
    Order the rows of the dataframe by the values of column_name

    Arguments:
        - dataframe: a pandas dataframe to sort
        - column_name: a string naming the column to sort the dataframe by
        - ascending: a boolean, True to sort in ascending order, False
          (the default) to sort in descending order

    Returns:
        a sorted dataframe
    """
    ordered = dataframe.sort_values(by=column_name, ascending=ascending)
    return ordered
|
|
|
@st.cache_data
def search_by_name(name: str) -> pd.DataFrame:
    """
    Search a model by its name

    The search is a case-sensitive, literal substring match on the
    "model_name" column of the dataframe returned by load_dataframe().

    Arguments:
        - name: the name of the model or part of it

    Returns:
        a pandas Dataframe of every row whose model_name contains name
    """
    dataframe = load_dataframe()
    # regex=False: treat name as plain text, so characters such as "+" or "("
    # are matched literally instead of being parsed as a regular expression.
    # na=False: rows with a missing model_name count as "no match" rather
    # than putting NaN in the mask, which would make the indexing below fail.
    matches = dataframe["model_name"].str.contains(name, regex=False, na=False)
    return dataframe[matches]
|
|
|
def validate_categories(categories: list) -> bool:
    """
    Validate a list of categories against the columns in the dataframe

    Arguments:
        - categories: a list of categories for the ordering of the columns in the dataframe

    This expects a list with six elements that should be (not necessarily in order):
        - ARC
        - GSM8K
        - TruthfulQA
        - Winogrande
        - HellaSwag
        - MMLU

    Returns:
        - True if the list has the right number of elements and the right elements
        - False otherwise
    """
    expected_categories = ("ARC", "GSM8K", "TruthfulQA", "Winogrande", "HellaSwag", "MMLU")
    # Same length plus every expected name present means the list is exactly
    # a reordering of the expected categories (no extras, no duplicates).
    return len(categories) == len(expected_categories) and all(
        category in categories for category in expected_categories
    )