# Source: transcript_summary/utils/refine_summary.py
# Provenance: Hugging Face commit 14fc586 (verified), "Update utils/refine_summary.py"
# NOTE: web-page chrome ("raw / history blame / 4.08 kB") removed — it was scrape
# residue, not part of the module.
"""Definitions for refine data summarizer."""
from typing import Any, List, Dict
from langchain.chat_models.base import BaseChatModel
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.document_loaders import TextLoader
from langchain.text_splitter import TokenTextSplitter, CharacterTextSplitter
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.summarize import load_summarize_chain
class RefineDataSummarizer:
    """Summarize long text with LangChain's "refine" summarization chain.

    The input text is split into token-sized chunks that fit the model's
    context window; the first chunk is summarized with an initial prompt,
    and each subsequent chunk is used to refine the running summary.
    """

    # Context-window sizes (in tokens) for the supported chat models.
    token_limit = {"gpt-3.5-turbo": 4096,
                   "gpt-4": 8192,
                   "gpt-3.5-turbo-16k": 16385,
                   "gpt-3.5-turbo-1106": 16385,
                   "gpt-4-1106-preview": 128000,
                   "gemini-pro": 32768,
                   "codechat-bison": 8192,
                   "chat-bison": 8192
                   }
    # Conservative fallback window for models missing from token_limit.
    default_token_limit = 4096

    def __init__(
        self,
        llm: BaseChatModel
    ):
        """Initialize the data summarizer.

        Args:
            llm: A LangChain chat model. OpenAI-style wrappers expose the
                model id as ``model_name``; Vertex AI / Gemini wrappers
                expose it as ``model`` — both are accepted.
        """
        self.llm = llm
        # Original code read `llm.model_name` unconditionally, which raises
        # AttributeError for Vertex/Gemini wrappers (whose id lives in
        # `llm.model`) even though token_limit lists those models.
        self.llm_model = getattr(llm, "model_name", None) or getattr(llm, "model", "")

        # Initial prompt for the bullet-point variant.
        prompt_template_bullet_point = (
            "Write a summary of the following text.\n"
            "TEXT: {text}\n"
            "SUMMARY:\n"
        )
        prompt_bullet_point = PromptTemplate(
            template=prompt_template_bullet_point, input_variables=["text"]
        )
        # Refine prompt for the bullet-point variant.
        refine_prompt_template_bullet_point = (
            "Write a concise summary of the following text delimited by triple backquotes.\n"
            "Return your response in bullet points which covers the key points of the text.\n"
            " ```{text}```\n"
            "BULLET POINT SUMMARY:\n"
        )
        refine_prompt_bullet_point = PromptTemplate(
            template=refine_prompt_template_bullet_point, input_variables=["text"]
        )
        # Initial prompt for the plain (non-bullet) variant.
        prompt_template = (
            "Write a concise summary of the following:\n"
            "{text}\n"
            "CONCISE SUMMARY:\n"
        )
        prompt = PromptTemplate.from_template(prompt_template)
        # Refine prompt: feeds the running summary plus the next chunk.
        refine_template = (
            "Your job is to produce a final summary\n"
            "We have provided an existing summary up to a certain point: {existing_answer}\n"
            "We have the opportunity to refine the existing summary"
            "(only if needed) with some more context below.\n"
            "------------\n"
            "{text}\n"
            "------------\n"
            "Given the new context, refine the original summary.\n"
            "If the context isn't useful, return the original summary."
        )
        refine_prompt = PromptTemplate.from_template(refine_template)
        self.prompt = prompt
        self.refine_prompt = refine_prompt
        self.prompt_bullet_point = prompt_bullet_point
        self.refine_prompt_bullet_point = refine_prompt_bullet_point

    def get_summarization(self,
                          text: str,
                          chunk_num: int = 5,
                          chunk_overlap: int = 30,
                          bullet_point: bool = True) -> Dict:
        """Summarize ``text`` with the refine chain.

        Args:
            text: The raw text to summarize.
            chunk_num: Approximate number of chunks to split the text into;
                chunk size is the model's token limit divided by this value.
            chunk_overlap: Token overlap between adjacent chunks.
            bullet_point: If True, produce a bullet-point summary; otherwise
                a concise prose summary.

        Returns:
            The chain output dict, containing ``output_text`` (final summary)
            and ``intermediate_steps`` (per-chunk refinements).
        """
        if bullet_point:
            prompt = self.prompt_bullet_point
            refine_prompt = self.refine_prompt_bullet_point
        else:
            prompt = self.prompt
            refine_prompt = self.refine_prompt
        # Unknown models fall back to a conservative window instead of
        # raising KeyError; chunk_num is clamped to avoid division by zero.
        window = self.token_limit.get(self.llm_model, self.default_token_limit)
        text_splitter = TokenTextSplitter(
            chunk_size=window // max(chunk_num, 1),
            chunk_overlap=chunk_overlap,
        )
        docs = [Document(page_content=t, metadata={"source": "local"}) for t in text_splitter.split_text(text)]
        chain = load_summarize_chain(
            llm=self.llm,
            chain_type="refine",
            question_prompt=prompt,
            refine_prompt=refine_prompt,
            return_intermediate_steps=True,
            input_key="input_documents",
            output_key="output_text",
            verbose=True,
        )
        result = chain({"input_documents": docs}, return_only_outputs=True)
        return result