srinuksv commited on
Commit
12fc0e4
·
verified ·
1 Parent(s): 63283a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -20
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import time
3
- from fastapi import FastAPI,Request
4
  from fastapi.responses import HTMLResponse
5
  from fastapi.staticfiles import StaticFiles
6
  from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, Settings
@@ -15,12 +15,12 @@ from fastapi.templating import Jinja2Templates
15
  from huggingface_hub import InferenceClient
16
  import json
17
  import re
18
-
19
-
20
 
21
  # Define Pydantic model for incoming request body
22
  class MessageRequest(BaseModel):
23
  message: str
 
24
  repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
25
  llm_client = InferenceClient(
26
  model=repo_id,
@@ -29,10 +29,8 @@ llm_client = InferenceClient(
29
 
30
  os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")
31
 
32
-
33
  app = FastAPI()
34
 
35
-
36
  @app.middleware("http")
37
  async def add_security_headers(request: Request, call_next):
38
  response = await call_next(request)
@@ -40,7 +38,6 @@ async def add_security_headers(request: Request, call_next):
40
  response.headers["X-Frame-Options"] = "ALLOWALL"
41
  return response
42
 
43
-
44
  # Allow CORS requests from any domain
45
  app.add_middleware(
46
  CORSMiddleware,
@@ -50,17 +47,14 @@ app.add_middleware(
50
  allow_headers=["*"],
51
  )
52
 
53
-
54
-
55
-
56
  @app.get("/favicon.ico")
57
  async def favicon():
58
  return HTMLResponse("") # or serve a real favicon if you have one
59
 
60
-
61
  app.mount("/static", StaticFiles(directory="static"), name="static")
62
 
63
  templates = Jinja2Templates(directory="static")
 
64
  # Configure Llama index settings
65
  Settings.llm = HuggingFaceInferenceAPI(
66
  model_name="meta-llama/Meta-Llama-3-8B-Instruct",
@@ -82,6 +76,7 @@ os.makedirs(PDF_DIRECTORY, exist_ok=True)
82
  os.makedirs(PERSIST_DIR, exist_ok=True)
83
  chat_history = []
84
  current_chat_history = []
 
85
  def data_ingestion_from_directory():
86
  documents = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
87
  storage_context = StorageContext.from_defaults()
@@ -92,6 +87,7 @@ def initialize():
92
  start_time = time.time()
93
  data_ingestion_from_directory() # Process PDF ingestion at startup
94
  print(f"Data ingestion time: {time.time() - start_time} seconds")
 
95
  def split_name(full_name):
96
  # Split the name by spaces
97
  words = full_name.strip().split()
@@ -111,7 +107,6 @@ def split_name(full_name):
111
 
112
  initialize() # Run initialization tasks
113
 
114
-
115
  def handle_query(query):
116
  chat_text_qa_msgs = [
117
  (
@@ -133,19 +128,23 @@ def handle_query(query):
133
  if past_query.strip():
134
  context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"
135
 
136
-
137
  query_engine = index.as_query_engine(text_qa_template=text_qa_template, context_str=context_str)
138
  answer = query_engine.query(query)
139
 
140
  if hasattr(answer, 'response'):
141
- response=answer.response
142
  elif isinstance(answer, dict) and 'response' in answer:
143
- response =answer['response']
144
  else:
145
- response ="Sorry, I couldn't find an answer."
146
  current_chat_history.append((query, response))
147
  return response
148
 
 
 
 
 
 
149
  @app.post("/hist/")
150
  async def save_chat_history(history: dict):
151
  # Check if 'userId' is present in the incoming dictionary
@@ -160,10 +159,23 @@ async def save_chat_history(history: dict):
160
  hist = ''.join([f"'{entry['sender']}: {entry['message']}'\n" for entry in history['history']])
161
  hist = "You are a Redfernstech summarize model. Your aim is to use this conversation to identify user interests solely based on that conversation: " + hist
162
  print(hist)
 
163
  # Get the summarized result from the client model
164
- result = hist
 
165
  return {"summary": result, "message": "Chat history saved"}
166
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  @app.post("/chat/")
169
  async def chat(request: MessageRequest):
@@ -177,7 +189,7 @@ async def chat(request: MessageRequest):
177
  }
178
  chat_history.append(message_data)
179
  return {"response": response}
180
- @app.get("/", response_class=HTMLResponse)
181
- async def load_chat(request: Request, id: str):
182
- return templates.TemplateResponse("index.html", {"request": request, "user_id": id})
183
- # Route to save chat history
 
1
  import os
2
  import time
3
+ from fastapi import FastAPI, Request
4
  from fastapi.responses import HTMLResponse
5
  from fastapi.staticfiles import StaticFiles
6
  from llama_index.core import StorageContext, load_index_from_storage, VectorStoreIndex, SimpleDirectoryReader, ChatPromptTemplate, Settings
 
15
  from huggingface_hub import InferenceClient
16
  import json
17
  import re
18
+ from gradio_client import Client
 
19
 
20
  # Define Pydantic model for incoming request body
21
  class MessageRequest(BaseModel):
22
  message: str
23
+
24
  repo_id = "meta-llama/Meta-Llama-3-8B-Instruct"
25
  llm_client = InferenceClient(
26
  model=repo_id,
 
29
 
30
  os.environ["HF_TOKEN"] = os.getenv("HF_TOKEN")
31
 
 
32
  app = FastAPI()
33
 
 
34
  @app.middleware("http")
35
  async def add_security_headers(request: Request, call_next):
36
  response = await call_next(request)
 
38
  response.headers["X-Frame-Options"] = "ALLOWALL"
39
  return response
40
 
 
41
  # Allow CORS requests from any domain
42
  app.add_middleware(
43
  CORSMiddleware,
 
47
  allow_headers=["*"],
48
  )
49
 
 
 
 
50
  @app.get("/favicon.ico")
51
  async def favicon():
52
  return HTMLResponse("") # or serve a real favicon if you have one
53
 
 
54
  app.mount("/static", StaticFiles(directory="static"), name="static")
55
 
56
  templates = Jinja2Templates(directory="static")
57
+
58
  # Configure Llama index settings
59
  Settings.llm = HuggingFaceInferenceAPI(
60
  model_name="meta-llama/Meta-Llama-3-8B-Instruct",
 
76
  os.makedirs(PERSIST_DIR, exist_ok=True)
77
  chat_history = []
78
  current_chat_history = []
79
+
80
  def data_ingestion_from_directory():
81
  documents = SimpleDirectoryReader(PDF_DIRECTORY).load_data()
82
  storage_context = StorageContext.from_defaults()
 
87
  start_time = time.time()
88
  data_ingestion_from_directory() # Process PDF ingestion at startup
89
  print(f"Data ingestion time: {time.time() - start_time} seconds")
90
+
91
  def split_name(full_name):
92
  # Split the name by spaces
93
  words = full_name.strip().split()
 
107
 
108
  initialize() # Run initialization tasks
109
 
 
110
  def handle_query(query):
111
  chat_text_qa_msgs = [
112
  (
 
128
  if past_query.strip():
129
  context_str += f"User asked: '{past_query}'\nBot answered: '{response}'\n"
130
 
 
131
  query_engine = index.as_query_engine(text_qa_template=text_qa_template, context_str=context_str)
132
  answer = query_engine.query(query)
133
 
134
  if hasattr(answer, 'response'):
135
+ response = answer.response
136
  elif isinstance(answer, dict) and 'response' in answer:
137
+ response = answer['response']
138
  else:
139
+ response = "Sorry, I couldn't find an answer."
140
  current_chat_history.append((query, response))
141
  return response
142
 
143
+ @app.get("/ch/{id}", response_class=HTMLResponse)
144
+ async def load_chat(request: Request, id: str):
145
+ return templates.TemplateResponse("index.html", {"request": request, "user_id": id})
146
+
147
+ # Route to save chat history
148
  @app.post("/hist/")
149
  async def save_chat_history(history: dict):
150
  # Check if 'userId' is present in the incoming dictionary
 
159
  hist = ''.join([f"'{entry['sender']}: {entry['message']}'\n" for entry in history['history']])
160
  hist = "You are a Redfernstech summarize model. Your aim is to use this conversation to identify user interests solely based on that conversation: " + hist
161
  print(hist)
162
+
163
  # Get the summarized result from the client model
164
+ result = hist
165
+
166
  return {"summary": result, "message": "Chat history saved"}
167
 
168
+ @app.post("/webhook")
169
+ async def receive_form_data(request: Request):
170
+ form_data = await request.json()
171
+ # Generate a unique ID (for tracking user)
172
+ unique_id = str(uuid.uuid4())
173
+
174
+ # Here you can do something with form_data like saving it to a database
175
+ print("Received form data:", form_data)
176
+
177
+ # Send back the unique id to the frontend
178
+ return JSONResponse({"id": unique_id})
179
 
180
  @app.post("/chat/")
181
  async def chat(request: MessageRequest):
 
189
  }
190
  chat_history.append(message_data)
191
  return {"response": response}
192
+
193
+ @app.get("/")
194
+ def read_root():
195
+ return {"message": "Welcome to the API"}