"""Definitions for refine data summarizer.""" | |
from typing import Any, List, Dict | |
from langchain.chat_models.base import BaseChatModel | |
from langchain.docstore.document import Document | |
from langchain.prompts import PromptTemplate | |
from langchain.document_loaders import TextLoader | |
from langchain.text_splitter import TokenTextSplitter, CharacterTextSplitter | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.chains.summarize import load_summarize_chain | |
class RefineDataSummarizer:
    """Refine data summarizer."""
    token_limit = {
        "gpt-3.5-turbo": 4096,
        "gpt-4": 8192,
        "gpt-3.5-turbo-16k": 16385,
        "gpt-3.5-turbo-1106": 16385,
        "gpt-4-1106-preview": 128000,
        "gemini-pro": 32768,
        "codechat-bison": 8192,
        "chat-bison": 8192,
    }
    def __init__(
        self,
        llm: BaseChatModel
    ):
        """Initialize the data summarizer."""
        self.llm = llm
        # Assumes the chat model exposes a `model_name` attribute
        # (true for e.g. ChatOpenAI and ChatVertexAI).
        self.llm_model = self.llm.model_name
        prompt_template_bullet_point = (
            "Write a summary of the following text.\n"
            "Return your response in bullet points which cover the key points of the text.\n"
            "TEXT: {text}\n"
            "BULLET POINT SUMMARY:\n"
        )
        prompt_bullet_point = PromptTemplate(
            template=prompt_template_bullet_point, input_variables=["text"]
        )
        # The refine step must see the running summary ({existing_answer}) so each
        # pass extends the summary of earlier chunks instead of discarding it.
        refine_prompt_template_bullet_point = (
            "We have provided an existing bullet point summary up to a certain point:\n"
            "{existing_answer}\n"
            "Refine the existing summary (only if needed) with the additional text below,\n"
            "delimited by triple backquotes.\n"
            "```{text}```\n"
            "Return your response in bullet points which cover the key points of the text.\n"
            "BULLET POINT SUMMARY:\n"
        )
        refine_prompt_bullet_point = PromptTemplate(
            template=refine_prompt_template_bullet_point,
            input_variables=["existing_answer", "text"],
        )
        prompt_template = (
            "Write a concise summary of the following:\n"
            "{text}\n"
            "CONCISE SUMMARY:\n"
        )
        prompt = PromptTemplate.from_template(prompt_template)
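        # Generic initial/refine prompt pair, used when bullet_point=False.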
        refine_template = (
            "Your job is to produce a final summary.\n"
            "We have provided an existing summary up to a certain point: {existing_answer}\n"
            "We have the opportunity to refine the existing summary "
            "(only if needed) with some more context below.\n"
            "------------\n"
            "{text}\n"
            "------------\n"
            "Given the new context, refine the original summary.\n"
            "If the context isn't useful, return the original summary."
        )
        refine_prompt = PromptTemplate.from_template(refine_template)
        self.prompt = prompt
        self.refine_prompt = refine_prompt
        self.prompt_bullet_point = prompt_bullet_point
        self.refine_prompt_bullet_point = refine_prompt_bullet_point
    def get_summarization(self,
                          text: str,
                          chunk_num: int = 5,
                          chunk_overlap: int = 30,
                          bullet_point: bool = True) -> Dict:
        """Summarize text by splitting it into chunks and refining iteratively.

        Args:
            text: The raw text to summarize.
            chunk_num: Divisor applied to the model's token limit to set the
                chunk size (larger values mean smaller chunks).
            chunk_overlap: Number of tokens shared between adjacent chunks.
            bullet_point: If True, return the summary as bullet points.
        """
        if bullet_point:
            prompt = self.prompt_bullet_point
            refine_prompt = self.refine_prompt_bullet_point
        else:
            prompt = self.prompt
            refine_prompt = self.refine_prompt
        # Size each chunk to a fraction of the model's context window; fall back
        # to a conservative 4096-token window for models not in the table.
        text_splitter = TokenTextSplitter(
            chunk_size=self.token_limit.get(self.llm_model, 4096) // chunk_num,
            chunk_overlap=chunk_overlap,
        )
        docs = [
            Document(page_content=t, metadata={"source": "local"})
            for t in text_splitter.split_text(text)
        ]
        chain = load_summarize_chain(
            llm=self.llm,
            chain_type="refine",
            question_prompt=prompt,
            refine_prompt=refine_prompt,
            return_intermediate_steps=True,
            input_key="input_documents",
            output_key="output_text",
            verbose=True,
        )
        # Returns a dict with the final summary under "output_text" and the
        # per-chunk summaries under "intermediate_steps".
        result = chain({"input_documents": docs}, return_only_outputs=True)
        return result
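
# A minimal usage sketch (not part of the original module). It assumes an
# OpenAI API key in the environment and the langchain ChatOpenAI wrapper;
# any BaseChatModel that exposes a `model_name` attribute should work.
if __name__ == "__main__":
    from langchain.chat_models import ChatOpenAI

    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
    summarizer = RefineDataSummarizer(llm=llm)
    long_text = "..."  # replace with the document you want to summarize
    result = summarizer.get_summarization(long_text, chunk_num=5, bullet_point=True)
    print(result["output_text"])         # final refined summary
    print(result["intermediate_steps"])  # per-chunk summaries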