Spaces:
Runtime error
Runtime error
import os | |
import re | |
from dataclasses import asdict | |
import pandas as pd | |
from langchain.callbacks import get_openai_callback | |
from langchain.chains import LLMChain | |
from langchain.chat_models import ChatOpenAI | |
from langchain.document_loaders import DataFrameLoader | |
from langchain.prompts import PromptTemplate | |
from langchain.text_splitter import TokenTextSplitter | |
from tqdm import tqdm | |
from wandb.integration.langchain import WandbTracer | |
import wandb | |
from config import config | |
def get_data(artifact_name: str = "gladiator/gradient_dissent_bot/summary_data:latest"): | |
podcast_artifact = wandb.use_artifact(artifact_name, type="dataset") | |
podcast_artifact_dir = podcast_artifact.download(config.root_data_dir) | |
filename = artifact_name.split(":")[0].split("/")[-1] | |
df = pd.read_csv(os.path.join(podcast_artifact_dir, f"{filename}.csv")) | |
return df | |
def extract_questions(episode_df: pd.DataFrame): | |
# load docs into langchain format | |
loader = DataFrameLoader(episode_df, page_content_column="transcript") | |
data = loader.load() | |
# split the documents | |
text_splitter = TokenTextSplitter.from_tiktoken_encoder(chunk_size=1000, chunk_overlap=0) | |
docs = text_splitter.split_documents(data) | |
print(f"Number of documents for podcast {data[0].metadata['title']}: {len(docs)}") | |
# initialize LLM | |
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0) | |
# define prompt | |
prompt = """You are provided with a short transcript from a podcast episode. | |
Your task is to extract the relevant and most important questions one might ask from the transcript and present them in a bullet-point list. | |
Ensure that the total number of questions is no more than 3. | |
TRANSCRIPT: | |
{text} | |
QUESTIONS:""" | |
prompt_template = PromptTemplate(template=prompt, input_variables=["text"]) | |
pattern = r"\d+\.\s" | |
que_by_llm = [] | |
for doc in docs: | |
llm_chain = LLMChain(llm=llm, prompt=prompt_template) | |
out = llm_chain.run(doc) | |
cleaned_ques = re.sub(pattern, "", out).split("\n") | |
que_by_llm.extend(cleaned_ques) | |
return que_by_llm | |
if __name__ == "__main__": | |
# initialize wandb tracer | |
WandbTracer.init( | |
{ | |
"project": "gradient_dissent_bot", | |
"name": "extract_questions", | |
"job_type": "extract_questions", | |
"config": asdict(config), | |
} | |
) | |
# get data | |
df = get_data(artifact_name=config.summarized_data_artifact) | |
questions = [] | |
with get_openai_callback() as cb: | |
for episode in tqdm( | |
df.iterrows(), total=len(df), desc="Extracting questions from episodes" | |
): | |
episode_data = episode[1].to_frame().T | |
episode_questions = extract_questions(episode_data) | |
questions.append(episode_questions) | |
print("*" * 25) | |
print(cb) | |
print("*" * 25) | |
wandb.log( | |
{ | |
"total_prompt_tokens": cb.prompt_tokens, | |
"total_completion_tokens": cb.completion_tokens, | |
"total_tokens": cb.total_tokens, | |
"total_cost": cb.total_cost, | |
} | |
) | |
df["questions"] = questions | |
# log to wandb artifact | |
path_to_save = os.path.join(config.root_data_dir, "summary_que_data.csv") | |
df.to_csv(path_to_save, index=False) | |
artifact = wandb.Artifact("summary_que_data", type="dataset") | |
artifact.add_file(path_to_save) | |
wandb.log_artifact(artifact) | |
# create wandb table | |
table = wandb.Table(dataframe=df) | |
wandb.log({"summary_que_data": table}) | |
WandbTracer.finish() | |