vincentmin committed
Commit a326270 · 1 Parent(s): 461450f

Update app.py

Files changed (1):
  1. app.py +42 -6
app.py CHANGED
@@ -4,19 +4,54 @@ from langchain.document_loaders import ArxivLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import Chroma
 from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.llms import HuggingFaceHub
+# from langchain.llms import FakeListLLM
+from langchain.chains import LLMChain, StuffDocumentsChain
+from langchain.prompts import PromptTemplate
 from langchain.schema import Document
 
-CHUNK_SIZE = 1000
-LOAD_MAX_DOCS = 100
-text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE)
 
+LOAD_MAX_DOCS = 100
 min_date = (date.today() - timedelta(days=2)).strftime('%Y%m%d')
 max_date = date.today().strftime('%Y%m%d')
 query = f"cat:hep-th AND submittedDate:[{min_date} TO {max_date}]"
 loader = ArxivLoader(query=query, load_max_docs=LOAD_MAX_DOCS)
 
+# CHUNK_SIZE = 1000
+# text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE)
+
 embeddings = HuggingFaceEmbeddings()
 
+
+document_prompt = PromptTemplate(
+    template="Title: {Title}\nContent: {page_content}",
+    input_variables=["page_content", "Title"],
+)
+prompt = PromptTemplate(
+    template=
+"""Write a personalised newsletter for a researcher. The researcher describes his work as follows:"{context}". Base the newsletter on the following articles:\n\n"{text}"\n\nNEWSLETTER:""",
+    input_variables=["context", "text"])
+
+# llm = FakeListLLM(responses=list(map(str, range(100))))
+REPO_ID = "HuggingFaceH4/starchat-beta"
+llm = HuggingFaceHub(
+    repo_id=REPO_ID,
+    model_kwargs={
+        "max_new_tokens": 1024,
+        "do_sample": True,
+        "temperature": 0.8,
+        "top_p": 0.9
+    }
+)
+
+llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=True)
+stuff_chain = StuffDocumentsChain(
+    llm_chain=llm_chain,
+    document_variable_name="text",
+    document_prompt=document_prompt,
+    verbose=True,
+)
+
 def process_document(doc: Document):
     metadata = doc.metadata
     metadata["Body"] = doc.page_content
@@ -29,10 +64,11 @@ def get_data(user_query: str):
     retriever = db.as_retriever()
     relevant_docs = retriever.get_relevant_documents(user_query)
     print(relevant_docs[0].metadata)
-    output = ""
+    articles = ""
     for doc in relevant_docs:
-        output += f"**Title: {doc.metadata['Title']}**\nAbstract: {doc.metadata['Summary']}\n\n"
-    return output
+        articles += f"**Title: {doc.metadata['Title']}**\n\nAbstract: {doc.metadata['Summary']}\n\n"
+    output = stuff_chain({"input_documents": relevant_docs, "context": user_query})
+    return f"{output['output_text']}\n\n\n\nUsed articles:\n\n{articles}"
 
 demo = gr.Interface(
     fn=get_data,
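
A minimal sketch of how the newly added chain composes, assuming the same legacy langchain APIs the commit imports. It swaps in the FakeListLLM stub that the commit leaves commented out, so it runs without a Hugging Face Hub token; the document contents and prompt wording below are illustrative, not taken from app.py:

from langchain.llms import FakeListLLM
from langchain.chains import LLMChain, StuffDocumentsChain
from langchain.prompts import PromptTemplate
from langchain.schema import Document

# Each retrieved Document is first rendered through document_prompt,
# pulling "Title" from doc.metadata and the body from page_content.
document_prompt = PromptTemplate(
    template="Title: {Title}\nContent: {page_content}",
    input_variables=["page_content", "Title"],
)

# The rendered documents are then joined ("stuffed") into the {text}
# slot of the outer prompt; {context} carries the user's description.
prompt = PromptTemplate(
    template='Researcher: "{context}"\nArticles:\n\n{text}\n\nNEWSLETTER:',
    input_variables=["context", "text"],
)

llm = FakeListLLM(responses=["(stub newsletter)"])  # stand-in for starchat-beta

chain = StuffDocumentsChain(
    llm_chain=LLMChain(llm=llm, prompt=prompt),
    document_variable_name="text",   # prompt variable that receives the stuffed docs
    document_prompt=document_prompt,
)

docs = [Document(page_content="We study...", metadata={"Title": "A toy hep-th paper"})]
result = chain({"input_documents": docs, "context": "I work on string dualities"})
print(result["output_text"])  # StuffDocumentsChain returns its answer under "output_text"

The "output_text" key is why get_data can build its return string from output['output_text'] right after calling stuff_chain.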