SiddarthaRachakonda committed
Commit
afe6333
1 Parent(s): cab27a8

added app:main

Files changed (12)
  1. Dockerfile +1 -1
  2. app.py +2 -1
  3. app/callbacks.py +24 -0
  4. app/chains.py +53 -0
  5. app/crud.py +23 -0
  6. app/data_indexing.py +150 -0
  7. app/database.py +12 -0
  8. app/main.py +89 -0
  9. app/models.py +21 -0
  10. app/prompts.py +51 -0
  11. app/schemas.py +19 -0
  12. requirements.txt +12 -1
Dockerfile CHANGED
@@ -16,4 +16,4 @@ RUN pip install --no-cache-dir --upgrade -r requirements.txt
  # Again, ensure the copied files are owned by 'user'
  COPY --chown=user . /app
  # Specify the command to run when the container starts
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app.py CHANGED
@@ -1,5 +1,6 @@
  from fastapi import FastAPI
+
  app = FastAPI()
  @app.get("/")
  def greet_json():
-     return {"Hello": "World!"}
+     return {"Hello": "World!"}
app/callbacks.py ADDED
@@ -0,0 +1,24 @@
+ from typing import Dict, Any, List
+ from langchain_core.callbacks import BaseCallbackHandler
+ import schemas
+ import crud
+
+
+ class LogResponseCallback(BaseCallbackHandler):
+
+     def __init__(self, user_request: schemas.UserRequest, db):
+         super().__init__()
+         self.user_request = user_request
+         self.db = db
+
+     def on_llm_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
+         """Run when the LLM finishes generating its response."""
+         # TODO: on_llm_end is called when the LLM stops sending the response.
+         # Use the crud.add_message function to capture that response.
+         raise NotImplementedError
+
+     def on_llm_start(
+         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+     ) -> Any:
+         for prompt in prompts:
+             print(prompt)
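A possible completion of on_llm_end, sketched rather than asserted: it assumes crud.add_message and schemas.MessageBase end up implemented as described later in this commit, and that LangChain hands the callback an LLMResult whose generated text can be read from its first generation.

# Sketch only; requires `from datetime import datetime` at the top of callbacks.py.
def on_llm_end(self, outputs, **kwargs: Any) -> Any:
    """Log the AI response for the user who asked the question."""
    # Assumption: `outputs` is a LangChain LLMResult.
    generated_text = outputs.generations[0][0].text
    message = schemas.MessageBase(
        message=generated_text,      # assumed MessageBase fields (see schemas.py)
        type="AI",
        timestamp=datetime.now(),
    )
    crud.add_message(self.db, message, self.user_request.username)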
app/chains.py ADDED
@@ -0,0 +1,53 @@
+ import os
+ from langchain_huggingface import HuggingFaceEndpoint
+ from langchain_core.runnables import RunnablePassthrough
+ import schemas
+ from prompts import (
+     raw_prompt,
+     raw_prompt_formatted,
+     format_context,
+     tokenizer
+ )
+ from data_indexing import DataIndexer
+
+ data_indexer = DataIndexer()
+
+ llm = HuggingFaceEndpoint(
+     repo_id="meta-llama/Meta-Llama-3-8B-Instruct",
+     huggingfacehub_api_token=os.environ['HF_TOKEN'],
+     max_new_tokens=512,
+     stop_sequences=[tokenizer.eos_token],
+     streaming=True,
+ )
+
+ simple_chain = (raw_prompt | llm).with_types(input_type=schemas.UserQuestion)
+
+ # TODO: create formatted_chain by piping raw_prompt_formatted and the LLM endpoint.
+ formatted_chain = raw_prompt_formatted | llm
+
+ # TODO: use history_prompt_formatted and HistoryInput to create the history_chain
+ history_chain = None
+
+ # TODO: construct the standalone_chain by piping standalone_prompt_formatted with the LLM
+ standalone_chain = None
+
+ input_1 = RunnablePassthrough.assign(new_question=standalone_chain)
+ input_2 = {
+     'context': lambda x: format_context(data_indexer.search(x['new_question'])),
+     'standalone_question': lambda x: x['new_question']
+ }
+ input_to_rag_chain = input_1 | input_2
+
+ # TODO: use input_to_rag_chain, rag_prompt_formatted,
+ # HistoryInput and the LLM to build the rag_chain.
+ rag_chain = None
+
+ # TODO: Implement the filtered_rag_chain. It should be the
+ # same as the rag_chain but with hybrid_search = True.
+ filtered_rag_chain = None
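One plausible wiring for the chains left as TODOs, assuming prompts.py eventually defines history_prompt_formatted, standalone_prompt_formatted and rag_prompt_formatted, and schemas.py defines HistoryInput (all of which are still TODOs in this commit):

# Sketch only; the imported names below are assumed to exist once the TODOs are done.
from prompts import (
    history_prompt_formatted,
    standalone_prompt_formatted,
    rag_prompt_formatted,
)

history_chain = (history_prompt_formatted | llm).with_types(input_type=schemas.HistoryInput)

standalone_chain = (standalone_prompt_formatted | llm).with_types(input_type=schemas.HistoryInput)

# input_to_rag_chain rewrites the question, retrieves context, then the prompt + LLM answer.
rag_chain = (input_to_rag_chain | rag_prompt_formatted | llm).with_types(
    input_type=schemas.HistoryInput)

# Same chain, but retrieval filters on the most relevant source files first.
filtered_input_2 = {
    'context': lambda x: format_context(
        data_indexer.search(x['new_question'], hybrid_search=True)),
    'standalone_question': lambda x: x['new_question'],
}
filtered_rag_chain = (input_1 | filtered_input_2 | rag_prompt_formatted | llm).with_types(
    input_type=schemas.HistoryInput)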
app/crud.py ADDED
@@ -0,0 +1,23 @@
+ from sqlalchemy.orm import Session
+ import models, schemas
+
+
+ def get_or_create_user(db: Session, username: str):
+     user = db.query(models.User).filter(models.User.username == username).first()
+     if not user:
+         user = models.User(username=username)
+         db.add(user)
+         db.commit()
+         db.refresh(user)
+     return user
+
+ def add_message(db: Session, message: schemas.MessageBase, username: str):
+     # TODO: Implement the add_message function. It should:
+     # - get or create the user with the username
+     # - create a models.Message instance
+     # - pass the retrieved user to the message instance
+     # - save the message instance to the database
+     raise NotImplementedError
+
+ def get_user_chat_history(db: Session, username: str):
+     raise NotImplementedError
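The two TODO functions could look like the sketch below; the Message column names (message, type, timestamp, user, user_id) are assumptions that mirror the TODO description in models.py.

def add_message(db: Session, message: schemas.MessageBase, username: str):
    # Sketch: attach the message to its (possibly new) user and persist it.
    user = get_or_create_user(db, username)
    db_message = models.Message(**message.dict())   # assumes matching column names
    db_message.user = user
    db.add(db_message)
    db.commit()
    db.refresh(db_message)
    return db_message

def get_user_chat_history(db: Session, username: str):
    # Sketch: return the user's messages in chronological order.
    user = db.query(models.User).filter(models.User.username == username).first()
    if not user:
        return []
    return (
        db.query(models.Message)
        .filter(models.Message.user_id == user.id)
        .order_by(models.Message.timestamp)
        .all()
    )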
app/data_indexing.py ADDED
@@ -0,0 +1,150 @@
+ import os
+ import uuid
+ from pathlib import Path
+ from pinecone.grpc import PineconeGRPC as Pinecone
+ from pinecone import ServerlessSpec
+ from langchain_community.vectorstores import Chroma
+ from langchain_openai import OpenAIEmbeddings
+
+ current_dir = Path(__file__).resolve().parent
+
+
+ class DataIndexer:
+
+     source_file = os.path.join(current_dir, 'sources.txt')
+
+     def __init__(self, index_name='langchain-repo') -> None:
+
+         # TODO: choose your embedding model
+         # self.embedding_client = InferenceClient(
+         #     "dunzhang/stella_en_1.5B_v5",
+         #     token=os.environ['HF_TOKEN'],
+         # )
+         self.embedding_client = OpenAIEmbeddings()
+         self.index_name = index_name
+         self.pinecone_client = Pinecone(api_key=os.environ.get('PINECONE_API_KEY'))
+
+         if index_name not in self.pinecone_client.list_indexes().names():
+             # TODO: create your index if it doesn't exist. Use the create_index function.
+             # Make sure to choose the dimension that corresponds to your embedding model
+             pass
+
+         self.index = self.pinecone_client.Index(self.index_name)
+         # TODO: make sure to build the source index.
+         self.source_index = None
+
+     def get_source_index(self):
+         if not os.path.isfile(self.source_file):
+             print('No source file')
+             return None
+
+         print('create source index')
+
+         with open(self.source_file, 'r') as file:
+             sources = file.readlines()
+
+         sources = [s.rstrip('\n') for s in sources]
+         vectorstore = Chroma.from_texts(
+             sources, embedding=self.embedding_client
+         )
+         return vectorstore
+
+     def index_data(self, docs, batch_size=32):
+
+         with open(self.source_file, 'a') as file:
+             for doc in docs:
+                 file.writelines(doc.metadata['source'] + '\n')
+
+         for i in range(0, len(docs), batch_size):
+             batch = docs[i: i + batch_size]
+
+             # TODO: create a list of the vector representations of each text data in the batch
+             # TODO: choose your embedding model
+             # values = self.embedding_client.embed_documents([
+             #     doc.page_content for doc in batch
+             # ])
+
+             # values = self.embedding_client.feature_extraction([
+             #     doc.page_content for doc in batch
+             # ])
+             values = None
+
+             # TODO: create a list of unique identifiers for each element in the batch with the uuid package.
+             vector_ids = None
+
+             # TODO: create a list of dictionaries representing the metadata. Capture the text data
+             # with the "text" key, and make sure to capture the rest of the doc.metadata.
+             metadatas = None
+
+             # create a list of dictionaries with keys "id" (the unique identifiers), "values"
+             # (the vector representation), and "metadata" (the metadata).
+             vectors = [{
+                 'id': vector_id,
+                 'values': value,
+                 'metadata': metadata
+             } for vector_id, value, metadata in zip(vector_ids, values, metadatas)]
+
+             try:
+                 # TODO: Use the upsert function to upload the data to the database.
+                 upsert_response = None
+                 print(upsert_response)
+             except Exception as e:
+                 print(e)
+
+     def search(self, text_query, top_k=5, hybrid_search=False):
+
+         filter = None
+         if hybrid_search and self.source_index:
+             # The filtering step pulls the 50 most relevant file names
+             # for the question. Adjust this number as you see fit.
+             source_docs = self.source_index.similarity_search(text_query, 50)
+             filter = {"source": {"$in": [doc.page_content for doc in source_docs]}}
+
+         # TODO: embed the text_query by using the embedding model
+         # TODO: choose your embedding model
+         # vector = self.embedding_client.feature_extraction(text_query)
+         # vector = self.embedding_client.embed_query(text_query)
+         vector = None
+
+         # TODO: use the vector representation of the text_query to
+         # search the database by using the query function.
+         result = None
+
+         docs = []
+         for res in result["matches"]:
+             # TODO: From the result's metadata, extract the "text" element.
+             pass
+
+         return docs
+
+
+ if __name__ == '__main__':
+
+     from langchain_community.document_loaders import GitLoader
+     from langchain_text_splitters import (
+         Language,
+         RecursiveCharacterTextSplitter,
+     )
+
+     loader = GitLoader(
+         clone_url="https://github.com/langchain-ai/langchain",
+         repo_path="./code_data/langchain_repo/",
+         branch="master",
+     )
+
+     python_splitter = RecursiveCharacterTextSplitter.from_language(
+         language=Language.PYTHON, chunk_size=10000, chunk_overlap=100
+     )
+
+     docs = loader.load()
+     docs = [doc for doc in docs if doc.metadata['file_type'] in ['.py', '.md']]
+     docs = [doc for doc in docs if len(doc.page_content) < 50000]
+     docs = python_splitter.split_documents(docs)
+     for doc in docs:
+         doc.page_content = '# {}\n\n'.format(doc.metadata['source']) + doc.page_content
+
+     indexer = DataIndexer()
+     with open('/app/sources.txt', 'a') as file:
+         for doc in docs:
+             file.writelines(doc.metadata['source'] + '\n')
+     indexer.index_data(docs)
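The TODOs above could be filled in roughly as follows. This is a sketch, not the reference solution: it assumes the default OpenAIEmbeddings model (1536-dimensional) and a serverless Pinecone index in aws/us-east-1, so the dimension, cloud and region should be adapted to whatever embedding model and Pinecone plan are actually used.

# In __init__, when the index does not exist yet:
self.pinecone_client.create_index(
    name=index_name,
    dimension=1536,                     # must match the embedding model
    metric='cosine',
    spec=ServerlessSpec(cloud='aws', region='us-east-1'),
)
# ...and build the source index used by hybrid search:
self.source_index = self.get_source_index()

# In index_data, for each batch:
values = self.embedding_client.embed_documents([doc.page_content for doc in batch])
vector_ids = [str(uuid.uuid4()) for _ in batch]
metadatas = [{'text': doc.page_content, **doc.metadata} for doc in batch]
# ...then upload the batch:
upsert_response = self.index.upsert(vectors=vectors)

# In search:
vector = self.embedding_client.embed_query(text_query)
result = self.index.query(
    vector=vector,
    top_k=top_k,
    filter=filter,
    include_metadata=True,
)
docs = [match['metadata']['text'] for match in result['matches']]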
app/database.py ADDED
@@ -0,0 +1,12 @@
+ from sqlalchemy import create_engine
+ from sqlalchemy.ext.declarative import declarative_base
+ from sqlalchemy.orm import sessionmaker
+
+ SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
+
+ engine = create_engine(
+     SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
+ )
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+ Base = declarative_base()
app/main.py ADDED
@@ -0,0 +1,89 @@
+ from langchain_core.runnables import Runnable
+ from langchain_core.callbacks import BaseCallbackHandler
+ from fastapi import FastAPI, Request, Depends
+ from sse_starlette.sse import EventSourceResponse
+ from langserve.serialization import WellKnownLCSerializer
+ from typing import List
+ from sqlalchemy.orm import Session
+
+ import schemas
+ from chains import simple_chain, formatted_chain
+ import crud, models, schemas
+ from database import SessionLocal, engine
+ from callbacks import LogResponseCallback
+
+
+ models.Base.metadata.create_all(bind=engine)
+
+ app = FastAPI()
+
+ def get_db():
+     db = SessionLocal()
+     try:
+         yield db
+     finally:
+         db.close()
+
+
+ async def generate_stream(input_data: schemas.BaseModel, runnable: Runnable, callbacks: List[BaseCallbackHandler] = []):
+     for output in runnable.stream(input_data.dict(), config={"callbacks": callbacks}):
+         data = WellKnownLCSerializer().dumps(output).decode("utf-8")
+         yield {'data': data, "event": "data"}
+     yield {"event": "end"}
+
+
+ @app.post("/simple/stream")
+ async def simple_stream(request: Request):
+     data = await request.json()
+     user_question = schemas.UserQuestion(**data['input'])
+     return EventSourceResponse(generate_stream(user_question, simple_chain))
+
+
+ @app.post("/formatted/stream")
+ async def formatted_stream(request: Request):
+     # TODO: use the formatted_chain to implement the "/formatted/stream" endpoint.
+     data = await request.json()
+     user_question = schemas.UserQuestion(**data['input'])
+     return EventSourceResponse(generate_stream(user_question, formatted_chain))
+
+
+ @app.post("/history/stream")
+ async def history_stream(request: Request, db: Session = Depends(get_db)):
+     # TODO: Implement the "/history/stream" endpoint. The endpoint should follow these steps:
+     # - The endpoint receives the request
+     # - The request is parsed into a user request
+     # - The user request is used to pull the chat history of the user
+     # - We add the current question to the user history by using add_message.
+     # - We create an instance of HistoryInput by using format_chat_history.
+     # - We use the history input within the history chain.
+     raise NotImplementedError
+
+
+ @app.post("/rag/stream")
+ async def rag_stream(request: Request, db: Session = Depends(get_db)):
+     # TODO: Implement the "/rag/stream" endpoint. The endpoint should follow these steps:
+     # - The endpoint receives the request
+     # - The request is parsed into a user request
+     # - The user request is used to pull the chat history of the user
+     # - We add the current question to the user history by using add_message.
+     # - We create an instance of HistoryInput by using format_chat_history.
+     # - We use the history input within the rag chain.
+     raise NotImplementedError
+
+
+ @app.post("/filtered_rag/stream")
+ async def filtered_rag_stream(request: Request, db: Session = Depends(get_db)):
+     # TODO: Implement the "/filtered_rag/stream" endpoint. The endpoint should follow these steps:
+     # - The endpoint receives the request
+     # - The request is parsed into a user request
+     # - The user request is used to pull the chat history of the user
+     # - We add the current question to the user history by using add_message.
+     # - We create an instance of HistoryInput by using format_chat_history.
+     # - We use the history input within the filtered rag chain.
+     raise NotImplementedError
+
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run("main:app", host="localhost", reload=True, port=8000)
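A possible shape for the /history/stream endpoint, assuming the pieces it depends on (a UserRequest with a question field, MessageBase, add_message, get_user_chat_history, format_chat_history, HistoryInput and history_chain) are completed as outlined in the other files of this commit; /rag/stream and /filtered_rag/stream would follow the same pattern with their respective chains.

# Sketch only; the extra imports below are not in main.py yet.
from datetime import datetime
from chains import history_chain
from prompts import format_chat_history

@app.post("/history/stream")
async def history_stream(request: Request, db: Session = Depends(get_db)):
    data = await request.json()
    user_request = schemas.UserRequest(**data['input'])
    # Pull the existing history before logging the new question.
    chat_history = crud.get_user_chat_history(db, user_request.username)
    crud.add_message(
        db,
        schemas.MessageBase(              # assumed MessageBase fields
            message=user_request.question,
            type="Human",
            timestamp=datetime.now(),
        ),
        user_request.username,
    )
    history_input = schemas.HistoryInput(
        chat_history=format_chat_history(chat_history),
        question=user_request.question,
    )
    # LogResponseCallback stores the AI answer once streaming finishes.
    return EventSourceResponse(generate_stream(
        history_input,
        history_chain,
        [LogResponseCallback(user_request, db)],
    ))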
app/models.py ADDED
@@ -0,0 +1,21 @@
+ from sqlalchemy import Column, ForeignKey, Integer, String, DateTime
+ from sqlalchemy.orm import relationship
+
+ from database import Base
+
+ class User(Base):
+     __tablename__ = "users"
+
+     id = Column(Integer, primary_key=True, index=True)
+     username = Column(String, unique=True, index=True)
+     messages = relationship("Message", back_populates="user")
+
+ # TODO: Implement the Message SQLAlchemy model. Message should have a primary key,
+ # a message attribute to store the content of messages, a type (AI or Human,
+ # depending on whether it is a user question or an AI response), a timestamp to
+ # order by time, and a user attribute to get the user instance associated
+ # with the message. We also need a user_id that will use the User.id
+ # attribute as a foreign key.
+ class Message(Base):
+     __tablename__ = "messages"
+     pass
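A Message model matching that TODO could look like the sketch below; the column names are choices rather than requirements, but they are the ones the other sketches in this commit assume.

from datetime import datetime

class Message(Base):
    __tablename__ = "messages"

    id = Column(Integer, primary_key=True, index=True)
    message = Column(String)                        # content of the message
    type = Column(String)                           # "AI" or "Human"
    timestamp = Column(DateTime, default=datetime.now)
    user_id = Column(Integer, ForeignKey("users.id"))
    user = relationship("User", back_populates="messages")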
app/prompts.py ADDED
@@ -0,0 +1,51 @@
+ from langchain_core.prompts import PromptTemplate
+ from typing import List
+ import models
+
+
+ def format_prompt(prompt) -> PromptTemplate:
+     # TODO: format the input prompt by using the model-specific instruction template
+     # TODO: return a langchain PromptTemplate
+     return PromptTemplate.from_template(prompt)
+
+ def format_chat_history(messages: List[models.Message]):
+     # TODO: implement format_chat_history to format
+     # the list of Message into a text of chat history.
+     raise NotImplementedError
+
+ def format_context(docs: List[str]):
+     # TODO: the output of DataIndexer.search is a list of text,
+     # so we need to concatenate that list into a text that can fit into
+     # the rag_prompt_formatted. Implement format_context, which takes a
+     # list of strings and returns the context as one string.
+     raise NotImplementedError
+
+ raw_prompt = "{question}"
+
+ # TODO: Create the history_prompt that will capture the question and the conversation history.
+ # The history_prompt needs a {chat_history} placeholder and a {question} placeholder.
+ history_prompt: str = None
+
+ # TODO: Create the standalone_prompt that will capture the question and the chat history
+ # to generate a standalone question. It needs a {chat_history} placeholder and a {question} placeholder.
+ standalone_prompt: str = None
+
+ # TODO: Create the rag_prompt that will capture the context and the standalone question to generate
+ # a final answer to the question.
+ rag_prompt: str = None
+
+ # TODO: create raw_prompt_formatted by using format_prompt
+ raw_prompt_formatted = format_prompt(raw_prompt)
+ raw_prompt = PromptTemplate.from_template(raw_prompt)
+
+ # TODO: use format_prompt to create history_prompt_formatted
+ history_prompt_formatted: PromptTemplate = None
+ # TODO: use format_prompt to create standalone_prompt_formatted
+ standalone_prompt_formatted: PromptTemplate = None
+ # TODO: use format_prompt to create rag_prompt_formatted
+ rag_prompt_formatted: PromptTemplate = None
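chains.py imports tokenizer from this module, so prompts.py also needs to define one. The sketch below assumes the gated meta-llama/Meta-Llama-3-8B-Instruct tokenizer is accessible with HF_TOKEN (and that transformers gets added to requirements.txt); the prompt wording is illustrative only, and the chat-history formatting assumes the Message columns sketched in models.py.

import os
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", token=os.environ['HF_TOKEN'])

def format_prompt(prompt) -> PromptTemplate:
    # Wrap the raw prompt in the model's chat template so the endpoint
    # receives properly delimited instructions.
    chat = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True)
    return PromptTemplate.from_template(formatted)

def format_chat_history(messages: List[models.Message]):
    return '\n'.join('{}: {}'.format(m.type, m.message) for m in messages)

def format_context(docs: List[str]):
    return '\n\n'.join(docs)

history_prompt: str = (
    "Given the conversation history, answer the question.\n\n"
    "Chat history:\n{chat_history}\n\nQuestion: {question}")

standalone_prompt: str = (
    "Given the conversation history and a follow-up question, rephrase the "
    "follow-up question as a standalone question.\n\n"
    "Chat history:\n{chat_history}\n\nFollow-up question: {question}\n\n"
    "Standalone question:")

rag_prompt: str = (
    "Answer the question using only the context below.\n\n"
    "Context:\n{context}\n\nQuestion: {standalone_question}")

history_prompt_formatted: PromptTemplate = format_prompt(history_prompt)
standalone_prompt_formatted: PromptTemplate = format_prompt(standalone_prompt)
rag_prompt_formatted: PromptTemplate = format_prompt(rag_prompt)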
app/schemas.py ADDED
@@ -0,0 +1,19 @@
+ from pydantic.v1 import BaseModel
+
+
+ class UserQuestion(BaseModel):
+     question: str
+
+ # TODO: create a HistoryInput data model with chat_history and question attributes.
+ class HistoryInput(BaseModel):
+     pass
+
+ # TODO: create a UserRequest data model with a question and a username attribute.
+ # This will be used to parse the input request.
+ class UserRequest(BaseModel):
+     username: str
+
+ # TODO: implement MessageBase as a schema mapping from the database model to the
+ # FastAPI data model. MessageBase should have the same attributes as models.Message.
+ class MessageBase(BaseModel):
+     pass
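The TODO schemas could be completed like this sketch; the MessageBase fields mirror the Message columns assumed in the models.py sketch.

from datetime import datetime

class HistoryInput(BaseModel):
    chat_history: str
    question: str

class UserRequest(BaseModel):
    question: str
    username: str

class MessageBase(BaseModel):
    message: str
    type: str
    timestamp: datetime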
requirements.txt CHANGED
@@ -1,2 +1,13 @@
  fastapi
- uvicorn[standard]
+ uvicorn[standard]
+ langchain
+ langserve
+ sqlalchemy
+ pydantic
+ sse-starlette
+ requests
+ pinecone-client
+ langchain_huggingface
+ langchain_core
+ langchain_community
+ langchain_openai