"""Definitions for refine data summarizer."""
from typing import Any, ClassVar, Dict, List, Optional

from langchain.chat_models.base import BaseChatModel
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from langchain.document_loaders import TextLoader
from langchain.text_splitter import TokenTextSplitter, CharacterTextSplitter
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.summarize import load_summarize_chain


class RefineDataSummarizer:
    """Summarize long text with LangChain's ``refine`` summarization chain.

    The text is split into token-based chunks whose size is derived from the
    model's token limit, then passed to ``load_summarize_chain`` with
    ``chain_type="refine"``, using ``prompt_template`` as the question prompt
    and ``refine_template`` as the refine prompt.
    """

    # Maximum token count per known model name. get_summarization() divides
    # this budget by ``chunk_num`` to size each text chunk.
    token_limit: ClassVar[Dict[str, int]] = {
        "gpt-3.5-turbo": 4096,
        "gpt-4": 8192,
        "gpt-3.5-turbo-16k": 16385,
        "gpt-3.5-turbo-1106": 16385,
        "gpt-4-1106-preview": 128000,
        "gemini-pro": 32768,
        "codechat-bison": 8192,
        "chat-bison": 8192,
    }

    def __init__(
        self,
        llm: BaseChatModel,
        prompt_template: str,
        refine_template: str,
        context_window: Optional[int] = None,
    ):
        """Initialize the data summarizer.

        Args:
            llm: Chat model that runs the summarization chain. Its model name
                is read from ``model_name`` (falling back to ``model``) and
                used to look up its entry in ``token_limit``.
            prompt_template: Template for the question (initial) prompt;
                leading and trailing whitespace is stripped.
            refine_template: Template for the refine prompt; leading and
                trailing whitespace is stripped.
            context_window: Optional token budget that overrides the
                ``token_limit`` lookup, e.g. for models not listed there.
        """
        self.llm = llm
        # Not every LangChain chat model exposes ``model_name``; some use
        # ``model`` instead, so fall back to it rather than failing here.
        self.llm_model = getattr(llm, "model_name", None) or getattr(
            llm, "model", None
        )
        self.context_window = context_window
        self.prompt = PromptTemplate.from_template(prompt_template.strip())
        self.refine_prompt = PromptTemplate.from_template(refine_template.strip())

    def _get_token_budget(self) -> int:
        """Return the token budget from which chunk sizes are derived."""
        if self.context_window is not None:
            return self.context_window
        try:
            return self.token_limit[self.llm_model]
        except KeyError:
            # Keep the KeyError type callers may rely on, but make it actionable.
            raise KeyError(
                f"Unknown token limit for model {self.llm_model!r}; pass "
                "context_window=... or add the model to token_limit. "
                f"Known models: {sorted(self.token_limit)}"
            ) from None

    def get_summarization(
        self,
        text: str,
        chunk_num: int = 5,
        chunk_overlap: int = 30,
        verbose: bool = True,
    ) -> Dict:
        """Summarize ``text`` with the refine chain.

        Args:
            text: Text to summarize.
            chunk_num: Divisor applied to the model's token budget to obtain
                the chunk size in tokens; must be >= 1.
            chunk_overlap: Number of tokens shared by consecutive chunks;
                must be non-negative and smaller than the chunk size.
            verbose: Forwarded to ``load_summarize_chain`` (default ``True``,
                matching the previous hard-coded behavior).

        Returns:
            The chain outputs (``return_only_outputs=True``): the summary
            under ``"output_text"`` plus the intermediate steps requested via
            ``return_intermediate_steps=True``. If ``text`` yields no chunks,
            ``{"output_text": "", "intermediate_steps": []}`` is returned
            without calling the model.

        Raises:
            ValueError: If ``chunk_num`` or ``chunk_overlap`` is out of range.
            KeyError: If the model's token limit is unknown and no
                ``context_window`` was supplied.
        """
        if chunk_num < 1:
            raise ValueError(f"chunk_num must be >= 1, got {chunk_num}")
        chunk_size = self._get_token_budget() // chunk_num
        # The token splitter advances by (chunk_size - chunk_overlap) tokens
        # per chunk, so that step must be positive or splitting never ends.
        if not 0 <= chunk_overlap < chunk_size:
            raise ValueError(
                f"chunk_overlap must be in [0, {chunk_size}), got {chunk_overlap}"
            )

        text_splitter = TokenTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        docs = [
            Document(page_content=chunk, metadata={"source": "local"})
            for chunk in text_splitter.split_text(text)
        ]
        if not docs:
            # The refine chain needs at least one document to start from.
            return {"output_text": "", "intermediate_steps": []}

        chain = load_summarize_chain(
            llm=self.llm,
            chain_type="refine",
            question_prompt=self.prompt,
            refine_prompt=self.refine_prompt,
            return_intermediate_steps=True,
            input_key="input_documents",
            output_key="output_text",
            verbose=verbose,
        )
        return chain({"input_documents": docs}, return_only_outputs=True)