mgchavez committed on
Commit d8cb89c · verified · 1 Parent(s): ac23530

Update app.py

Files changed (1)
  1. app.py +81 -2
app.py CHANGED
@@ -29,6 +29,8 @@ load_dotenv()
 
 os.environ['API_KEY_PROJ3'] = os.getenv('API_KEY_PROJ3')
 
+collection_name = 'collection'
+
 client = OpenAI(
     base_url="https://api.endpoints.anyscale.com/v1",
     api_key=os.environ['API_KEY_PROJ3']
@@ -38,15 +40,20 @@ client = OpenAI(
 embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
 
 # Load the persisted vectorDB
+vectorstore_persisted = Chroma(
+    collection_name=collection_name,
+    persist_directory='./proj3_db',
+    embedding_function=embedding_model
+)
 persisted_vectordb_location = './proj3_db'
 
 # Prepare the logging functionality
 
 log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
 log_folder = log_file.parent
-
+
 scheduler = CommitScheduler(
-    repo_id="dataset",
+    repo_id="---------",
     repo_type="dataset",
     folder_path=log_folder,
     path_in_repo="data",
@@ -76,3 +83,75 @@ Here are some documents that are relevant to the question mentioned below.
 ###Question
 {question}
 """
+
+# Define the predict function that runs when 'Submit' is clicked or when an API request is made
+def predict(user_input, company):
+
+    filter = "dataset/" + company + "-10-k-2023.pdf"
+    relevant_document_chunks = vectorstore_persisted.similarity_search(user_input, k=5, filter={"source": filter})
+
+    # Create context_for_query by joining the page_content of the retrieved Documents
+    # (similarity_search returns Document objects, not strings)
+    context_for_query = ". ".join(d.page_content for d in relevant_document_chunks)
+
+    # Create messages
+    prompt = [
+        {'role': 'system', 'content': qna_system_message},
+        {'role': 'user', 'content': qna_user_message_template.format(
+            context=context_for_query,
+            question=user_input
+        )}
+    ]
+    model_name = 'mlabonne/NeuralHermes-2.5-Mistral-7B'
+    # Get response from the LLM
+    try:
+        response = client.chat.completions.create(
+            model=model_name,
+            messages=prompt,
+            temperature=0
+        )
+
+        prediction = response.choices[0].message.content.strip()
+    except Exception as e:
+        prediction = f'Sorry, I encountered the following error: \n {e}'
+
+    # After the prediction is made, log both the inputs and outputs to a local log file
+    # While writing to the log file, hold the commit scheduler's lock to avoid
+    # parallel access
+    with scheduler.lock:
+        with log_file.open("a") as f:
+            f.write(json.dumps(
+                {
+                    'user_input': user_input,
+                    'retrieved_context': context_for_query,
+                    'model_response': prediction
+                }
+            ))
+            f.write("\n")
+
+    return prediction
+
+# Set up the Gradio UI
+# Add a text box and a radio button to the interface
+# The radio button selects the company 10-K report from which the context is retrieved
+lst_companies = ['aws', 'google', 'IBM', 'Meta', 'msft']
+textbox = gr.Textbox(label="User input")
+company = gr.Radio(choices=lst_companies, label="Company")
+
+model_output = gr.Label(label="Charge predictor")
+
+# Create the interface
+# For the inputs parameter of Interface provide [textbox, company]
+demo = gr.Interface(
+    fn=predict,
+    inputs=[textbox, company],
+    outputs=model_output,
+    title="Charge Predictor",
+    description="This API allows you to predict insurance charges",
+    allow_flagging="auto",
+    concurrency_limit=8
+)
+
+demo.queue()
+demo.launch(share=False)
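
The commit assumes a Chroma collection already persisted at ./proj3_db whose chunks carry a "source" metadata key of the form dataset/<company>-10-k-2023.pdf, since that is what predict() filters on. A minimal ingestion sketch that could produce such a store follows; the loaders, splitter settings, and file paths are assumptions, not part of this commit.

# Hypothetical ingestion script -- not part of this commit.
# Builds ./proj3_db so that similarity_search(..., filter={"source": ...}) works.
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter

embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

chunks = []
for company in ['aws', 'google', 'IBM', 'Meta', 'msft']:
    # PyPDFLoader records the file path in each page's 'source' metadata,
    # which is the key predict() filters on.
    pages = PyPDFLoader(f"dataset/{company}-10-k-2023.pdf").load()
    chunks.extend(splitter.split_documents(pages))

Chroma.from_documents(
    documents=chunks,
    embedding=embedding_model,
    collection_name='collection',    # must match collection_name in app.py
    persist_directory='./proj3_db'   # must match persist_directory in app.py
)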
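Once the app is launched, the same predict endpoint can be exercised programmatically. A sketch using gradio_client; the URL and the "/predict" endpoint name assume Gradio's local defaults.

from gradio_client import Client

# Assumes the app is running locally on Gradio's default port.
client = Client("http://127.0.0.1:7860/")
answer = client.predict(
    "What was total revenue in fiscal year 2023?",  # textbox input
    "google",                                       # radio selection
    api_name="/predict"
)
print(answer)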
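Because each call appends json.dumps(...) followed by a newline, the files that CommitScheduler uploads are effectively JSON Lines. A small sketch for reading the logs back; the glob pattern mirrors the data_<uuid>.json naming above.

import json
from pathlib import Path

# Each predict() call appends one JSON object per line (JSON Lines).
for log_path in Path("logs/").glob("data_*.json"):
    with log_path.open() as f:
        for line in f:
            record = json.loads(line)
            print(record['user_input'], '->', record['model_response'][:80])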