ravi6k committed
Commit 9c082ed · verified · 1 Parent(s): 3a1fcef

Update app.py

Creating the app

Files changed (1)
  1. app.py +156 -46
app.py CHANGED
@@ -1,63 +1,173 @@
  import gradio as gr
  from huggingface_hub import InferenceClient

- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
+
+ ## Setup
+ # Import the necessary libraries
+ import os
+ import uuid
+ import joblib
+ import json
+
+ import gradio as gr
+ import pandas as pd
+
+ from huggingface_hub import CommitScheduler
+ from pathlib import Path
+ from langchain_community.embeddings.sentence_transformer import (
+     SentenceTransformerEmbeddings
+ )
+ from langchain_community.vectorstores import Chroma
+
+ from openai import OpenAI
+
+ # Read the Anyscale API key from a Space secret / environment variable
+ # (ANYSCALE_API_KEY is an assumed name; the original code imported Colab's
+ # userdata, which is unavailable in a Space)
+ anyscale_api_key = os.environ.get("ANYSCALE_API_KEY")
+
+ # Create the client; Anyscale Endpoints expose an OpenAI-compatible API
+ client = OpenAI(
+     base_url="https://api.endpoints.anyscale.com/v1",
+     api_key=anyscale_api_key
+ )
+
+ # Model served by the endpoint (assumed name; set to the model you actually use)
+ model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+
+ # Define the embedding model and the vectorstore
+ embedding_model_name = 'thenlper/gte-large'
+ embedding_model = SentenceTransformerEmbeddings(model_name=embedding_model_name)
+ collection_name_qna = 'report_10K_db'
+ persisted_vectordb_location = 'report_10K_db'
+
+ # Load the persisted vectorDB
+ vectorstore_persisted = Chroma(
+     collection_name=collection_name_qna,
+     persist_directory=persisted_vectordb_location,
+     embedding_function=embedding_model
+ )
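+
+ # How the persisted collection is assumed to have been built offline (a
+ # minimal sketch; PyPDFDirectoryLoader and the chunking parameters are
+ # assumptions, not part of this commit):
+ #   docs = PyPDFDirectoryLoader("dataset/").load()
+ #   chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
+ #   Chroma.from_documents(chunks, embedding=embedding_model,
+ #                         collection_name=collection_name_qna,
+ #                         persist_directory=persisted_vectordb_location)
+ # Each chunk's "source" metadata holds the originating PDF path, which
+ # predict() filters on below.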
+
+ # Prepare the logging functionality
+ log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
+ log_folder = log_file.parent
+
+ scheduler = CommitScheduler(
+     repo_id="---------",
+     repo_type="dataset",
+     folder_path=log_folder,
+     path_in_repo="data",
+     every=2
+ )
+
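+ # CommitScheduler pushes the contents of log_folder to the dataset repo
+ # every 2 minutes (`every` is in minutes); repo_id is left as a placeholder.
+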
+ # Define the Q&A system message
+ qna_system_message = """
+ You are an assistant to a financial technology services firm who answers user queries on 10-K reports.
+ User input will have the context required by you to answer user questions.
+ This context will begin with the token: ###Context.
+ The context contains references to specific portions of a document relevant to the user query.
+
+ User questions will begin with the token: ###Question.
+
+ When crafting your response:
+ 1. Select only context relevant to answer the question.
+ 2. Include the source links in your response.
+ 3. If the question is irrelevant to 10-K reports, respond with - "I am an assistant for 10-K reports. I can only help you with questions about 10-K reports."
+
+ Please adhere to the following guidelines:
+ - Your response should only be about the question asked and nothing else.
+ - Answer only using the context provided.
+ - Do not mention anything about the context in your final answer.
+ - If the answer is not found in the context, it is very important for you to respond with "I don't know. Please check the docs @ '/content/dataset/'"
+ - Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
+ - Do not make up sources. Use the links provided in the sources section of the context and nothing else. You are prohibited from providing other links/sources.
+
+ Please answer only using the context provided in the input. Do not mention anything about the context in your final answer.
+
+ Here is an example of how to structure your response:
+
+ Answer:
+ [Answer]
+
+ Source:
+ [Source]
+ """
+ # Define the user message template
92
+ qna_user_message_template = """
93
+ ###Context
94
+ Here are some documents and their source links that are relevant to the question mentioned below.
95
+ {context}
96
+
97
+ ###Question
98
+ {question}
99
  """
+
+ # Define the predict function that runs when 'Submit' is clicked or when an API request is made
+ def predict(user_input, company):
+
+     # Retrieve the five report chunks most similar to the query, restricted to
+     # the selected company's 10-K via the "source" metadata filter
+     relevant_document_chunks = vectorstore_persisted.similarity_search(
+         user_input,
+         k=5,
+         filter={"source": company}
+     )
+
+     # Create context_for_query
+     context_list = [d.page_content for d in relevant_document_chunks]
+     context_for_query = ". ".join(context_list)
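+     # (company arrives as the selected radio value, e.g.
+     # "/content/dataset/Meta-10-k-2023.pdf", so the filter only matches chunks
+     # indexed from that file)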
+
+     # Create messages
+     prompt = [
+         {'role': 'system', 'content': qna_system_message},
+         {'role': 'user', 'content': qna_user_message_template.format(
+             context=context_for_query,
+             question=user_input
+         )}
+     ]
+
+     # Get response from the LLM
+     try:
+         response = client.chat.completions.create(
+             model=model_name,
+             messages=prompt,
+             temperature=0
+         )
+         prediction = response.choices[0].message.content.strip()
+     except Exception as e:
+         prediction = f'Sorry, I encountered the following error: \n {e}'
+
+     # While the prediction is made, log both the inputs and outputs to a local log file
+     # While writing to the log file, ensure that the commit scheduler is locked to avoid
+     # parallel access
+     with scheduler.lock:
+         with log_file.open("a") as f:
+             f.write(json.dumps(
+                 {
+                     'user_input': user_input,
+                     'retrieved_context': context_for_query,
+                     'model_response': prediction
+                 }
+             ))
+             f.write("\n")
+
+     return prediction
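+
+ # Each predict() call appends one JSON object per line, so the log is a JSONL
+ # file that the scheduler periodically commits to the dataset repo.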
+
+ # Set up the Gradio UI
+ # Add a text box and a radio button to the interface
+ # The radio button selects the company 10-K report from which context is retrieved
+ textbox = gr.Textbox(placeholder="Enter your query.", lines=6)
+ company = gr.Radio(
+     [
+         "/content/dataset/Meta-10-k-2023.pdf",
+         "/content/dataset/aws-10-k-2023.pdf",
+         "/content/dataset/google-10-k-2023.pdf",
+         "/content/dataset/IBM-10-k-2023.pdf",
+         "/content/dataset/msft-10-k-2023.pdf",
+     ],
+     label="Company Reports"
+ )
+
+ # Create the interface
+ # For the inputs parameter of Interface provide [textbox, company]
+ demo = gr.Interface(
+     fn=predict,
+     inputs=[textbox, company],
+     outputs="text",
+     title="Information from 10-K reports",
+     description="A system to streamline the extraction and analysis of key information from 10-K reports",
+     allow_flagging="auto",
+     concurrency_limit=12
+ )
+
  if __name__ == "__main__":
+     demo.queue()
      demo.launch()
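
Once the Space is running, the same predict endpoint can be exercised programmatically with gradio_client. A minimal sketch, assuming a hypothetical Space id and the Meta report as the selected company:

    from gradio_client import Client

    client = Client("ravi6k/10-K-reports")  # hypothetical Space id
    result = client.predict(
        "What was the total revenue in fiscal 2023?",  # textbox query
        "/content/dataset/Meta-10-k-2023.pdf",         # company radio choice
        api_name="/predict",
    )
    print(result)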