minko186 committed (verified)
Commit 704db80 · 1 Parent(s): 410b2f1

Create ai_generate.py

Files changed (1)
  1. ai_generate.py +223 -0
ai_generate.py ADDED
@@ -0,0 +1,223 @@
+ import os
+ from langchain_community.document_loaders import PyMuPDFLoader
+ from langchain_core.documents import Document
+ from langchain_community.embeddings.sentence_transformer import (
+     SentenceTransformerEmbeddings,
+ )
+ from langchain_community.vectorstores import Chroma
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain import hub
+ from langchain_core.output_parsers import StrOutputParser, XMLOutputParser
+ from langchain_core.runnables import RunnablePassthrough
+ from langchain_groq import ChatGroq
+ from langchain_openai import ChatOpenAI
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_anthropic import ChatAnthropic
+ from dotenv import load_dotenv
+ from langchain.prompts import ChatPromptTemplate
+
+ load_dotenv()
+
+ # suppress grpc and glog logs for gemini
+ os.environ["GRPC_VERBOSITY"] = "ERROR"
+ os.environ["GLOG_minloglevel"] = "2"
+
+ # RAG parameters
+ CHUNK_SIZE = 1024
+ CHUNK_OVERLAP = CHUNK_SIZE // 8
+ K = 10  # documents returned by the retriever
+ FETCH_K = 20  # candidates fetched before MMR re-ranking
+
+ # UI model labels -> provider model identifiers
+ llm_model_translation = {
+     "LLaMA 3": "llama3-70b-8192",
+     "OpenAI GPT 4o Mini": "gpt-4o-mini",
+     "OpenAI GPT 4o": "gpt-4o",
+     "OpenAI GPT 4": "gpt-4-turbo",
+     "Gemini 1.5 Pro": "gemini-1.5-pro",
+     "Claude Sonnet 3.5": "claude-3-5-sonnet-20240620",
+ }
+
+ # model identifiers -> LangChain chat model classes
+ llm_classes = {
+     "llama3-70b-8192": ChatGroq,
+     "gpt-4o-mini": ChatOpenAI,
+     "gpt-4o": ChatOpenAI,
+     "gpt-4-turbo": ChatOpenAI,
+     "gemini-1.5-pro": ChatGoogleGenerativeAI,
+     "claude-3-5-sonnet-20240620": ChatAnthropic,
+ }
+
+ xml_system = """You're a helpful AI assistant. Given a user prompt and some related sources, \
+ fulfill all the requirements of the prompt and provide citations. If a part of the generated text does \
+ not use any of the sources, don't put a citation for that part. Otherwise, list all sources used for that part of the text.
+ At the end of each relevant part, add a citation in square brackets, numbered sequentially starting from [0], regardless of the source's original ID.
+
+
+ Remember, you must return both the requested text and citations. A citation consists of a VERBATIM quote that \
+ justifies the text and a sequential number (starting from 0) for the quote's article. Return a citation for every quote across all articles \
+ that justify the text. Use the following format for your final output:
+
+ <cited_text>
+ <text></text>
+ <citations>
+ <citation><source_id></source_id><source></source><quote></quote></citation>
+ <citation><source_id></source_id><source></source><quote></quote></citation>
+ ...
+ </citations>
+ </cited_text>
+
+ Here are the sources:{context}"""
+ xml_prompt = ChatPromptTemplate.from_messages(
+     [("system", xml_system), ("human", "{input}")]
+ )
+
+ def format_docs_xml(docs: list[Document]) -> str:
+     """Format retrieved documents as an XML <sources> block for the citation prompt."""
+     formatted = []
+     for i, doc in enumerate(docs):
+         # metadata may lack "title" (e.g. URL-scraped documents), so use .get();
+         # include a sequential id so the model can populate <source_id>
+         doc_str = f"""\
+ <source id="{i}">
+ <source>{doc.metadata.get('source', '')}</source>
+ <title>{doc.metadata.get('title', '')}</title>
+ <article_snippet>{doc.page_content}</article_snippet>
+ </source>"""
+         formatted.append(doc_str)
+     return "\n\n<sources>" + "\n".join(formatted) + "</sources>"
+
+
+ def citations_to_html(citations_data):
+     if citations_data:
+         html_output = "<ul>"
+
+         for index, citation in enumerate(citations_data):
+             # XMLOutputParser returns each <citation> as a list of single-key dicts
+             source_id = citation['citation'][0]['source_id']
+             source = citation['citation'][1]['source']
+             quote = citation['citation'][2]['quote']
+
+             html_output += f"""
+             <li>
+                 [{index}] - "{source}" <br>
+                 "{quote}"
+             </li>
+             """
+
+         html_output += "</ul>"
+         return html_output
+     return ""
+
+
+ def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
+     """Instantiate the chat model for a UI model label; provider API keys are read from the environment."""
+     model_name = llm_model_translation.get(model)
+     llm_class = llm_classes.get(model_name)
+     if not llm_class:
+         raise ValueError(f"Model {model} not supported.")
+     try:
+         # ChatGoogleGenerativeAI expects `model`; the other chat classes accept it as an alias
+         llm = llm_class(model=model_name, temperature=temperature, max_tokens=max_length)
+     except Exception as e:
+         print(f"An error occurred: {e}")
+         llm = None
+     return llm
+
+
+ def create_db_with_langchain(path: list[str], url_content: dict):
+     """Build an in-memory Chroma vector store from PDF files and scraped URL content."""
+     all_docs = []
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
+     embedding_function = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
+     if path:
+         for file in path:
+             loader = PyMuPDFLoader(file)
+             data = loader.load()
+             # split it into chunks
+             docs = text_splitter.split_documents(data)
+             all_docs.extend(docs)
+
+     if url_content:
+         for url, content in url_content.items():
+             doc = Document(page_content=content, metadata={"source": url})
+             # split it into chunks
+             docs = text_splitter.split_documents([doc])
+             all_docs.extend(docs)
+
+     # print docs
+     for idx, doc in enumerate(all_docs):
+         print(f"Doc: {idx} | Length = {len(doc.page_content)}")
+
+     assert len(all_docs) > 0, "No PDFs or scraped data provided"
+     db = Chroma.from_documents(all_docs, embedding_function)
+     return db
+
+
+ def generate_rag(
+     prompt: str,
+     topic: str,
+     model: str,
+     url_content: dict,
+     path: list[str],
+     temperature: float = 1.0,
+     max_length: int = 2048,
+     api_key: str = "",
+     sys_message="",
+ ):
+     llm = load_llm(model, api_key, temperature, max_length)
+     if llm is None:
+         print("Failed to load LLM. Aborting operation.")
+         return None
+     db = create_db_with_langchain(path, url_content)
+     retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})
+     rag_prompt = hub.pull("rlm/rag-prompt")
+
+     def format_docs(docs):
+         if all(isinstance(doc, Document) for doc in docs):
+             return "\n\n".join(doc.page_content for doc in docs)
+         else:
+             raise TypeError("All items in docs must be instances of Document.")
+
+     docs = retriever.get_relevant_documents(topic)
+     # formatted_docs = format_docs(docs)
+     # rag_chain = (
+     #     {"context": lambda _: formatted_docs, "question": RunnablePassthrough()} | rag_prompt | llm | StrOutputParser()
+     # )
+     # return rag_chain.invoke(prompt)
+
+     formatted_docs = format_docs_xml(docs)
+     rag_chain = (
+         RunnablePassthrough.assign(context=lambda _: formatted_docs)
+         | xml_prompt
+         | llm
+         | XMLOutputParser()
+     )
+     result = rag_chain.invoke({"input": prompt})
+     from pprint import pprint
+     pprint(result)
+     return result['cited_text'][0]['text'], citations_to_html(result['cited_text'][1]['citations'])
+
+ def generate_base(
+     prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
+ ):
+     llm = load_llm(model, api_key, temperature, max_length)
+     if llm is None:
+         print("Failed to load LLM. Aborting operation.")
+         return None
+     try:
+         output = llm.invoke(prompt).content
+         return output
+     except Exception as e:
+         print(f"An error occurred while running the model: {e}")
+         return None
+
+
+ def generate(
+     prompt: str,
+     topic: str,
+     model: str,
+     url_content: dict,
+     path: list[str],
+     temperature: float = 1.0,
+     max_length: int = 2048,
+     api_key: str = "",
+     sys_message="",
+ ):
+     # route to RAG generation when source documents are provided, otherwise call the model directly
+     if path or url_content:
+         return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
+     else:
+         return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
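
A minimal usage sketch of the new module (not part of the commit): the prompt, topic, PDF path, and URL below are hypothetical placeholders, and the provider API keys are assumed to be supplied via the environment / .env file that load_dotenv() reads.

# illustrative sketch only: all input values below are made up for this example
from ai_generate import generate

# with PDF paths and/or scraped URL content, generate() routes to generate_rag()
# and returns (cited_text, citations_html); with neither, it returns plain text
cited_text, citations_html = generate(
    prompt="Summarize the attached sources and cite them.",
    topic="retrieval-augmented generation",
    model="OpenAI GPT 4o Mini",
    url_content={"https://example.com/article": "scraped article text ..."},
    path=["paper.pdf"],
    temperature=0.7,
    max_length=1024,
)
print(cited_text)
print(citations_html)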