Spaces: Runtime error
Rohan Kataria committed on
Commit · f3405cb
1 Parent(s): 0489dff
- .gitignore +5 -0
- app.py +66 -0
- requirements.txt +5 -0
- src/main.py +124 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
+.history/
+.vscode/
+notebooks/
+.*ipynb
+__pycache__/
app.py
ADDED
@@ -0,0 +1,66 @@
+import streamlit as st
+from src.main import ConversationalResponse
+import os
+
+# Constants
+ROLE_USER = "user"
+ROLE_ASSISTANT = "assistant"
+
+st.set_page_config(page_title="Chat with Git Codes", page_icon="🦜")
+st.title("Chat with Git Codes 🤖📄")
+st.markdown("by [Rohan Kataria](https://www.linkedin.com/in/imrohan/) view more at [VEW.AI](https://vew.ai/)")
+st.markdown("This app allows you to chat with a Git repository. You can paste a link to the repository and ask questions about it. In the background it uses the GitLoader and ConversationalRetrievalChain from LangChain, and Streamlit for the UI.")
+
+@st.cache_resource(ttl="1h")
+def load_agent(url, branch, file_filter):
+    with st.spinner('Loading Git documents...'):
+        agent = ConversationalResponse(url, branch, file_filter)
+    st.success("Git Loaded Successfully")
+    return agent
+
+def main():
+    api_key = st.sidebar.text_input("Enter your OpenAI API Key", type="password")
+    if api_key:
+        os.environ["OPENAI_API_KEY"] = api_key
+    else:
+        st.sidebar.error("Please enter your OpenAI API Key.")
+        return
+
+    git_link = st.sidebar.text_input("Enter your Git Link")
+    branch = st.sidebar.text_input("Enter your Git Branch")
+    file_filter = st.sidebar.text_input("Enter the Extension of Files to Load e.g. py,sql,r (no spaces)")
+
+    if "agent" not in st.session_state:
+        st.session_state["agent"] = None
+
+    if st.sidebar.button("Load Agent"):
+        if git_link and branch and file_filter:
+            try:
+                st.session_state["agent"] = load_agent(git_link, branch, file_filter)
+                st.session_state["messages"] = [{"role": ROLE_ASSISTANT, "content": "How can I help you?"}]
+            except Exception as e:
+                st.sidebar.error(f"Error loading Git repository: {str(e)}")
+                return
+
+    if st.session_state["agent"]:  # Chat will only appear once the agent is loaded
+        for msg in st.session_state.messages:
+            st.chat_message(msg["role"]).write(msg["content"])
+
+        user_query = st.chat_input(placeholder="Ask me anything!")
+
+        if user_query:
+            st.session_state.messages.append({"role": ROLE_USER, "content": user_query})
+            st.chat_message(ROLE_USER).write(user_query)
+
+            # Generate the response
+            with st.spinner("Generating response"):
+                response = st.session_state["agent"](user_query)
+
+            # Display the response immediately
+            st.chat_message(ROLE_ASSISTANT).write(response)
+
+            # Add the response to the message history
+            st.session_state.messages.append({"role": ROLE_ASSISTANT, "content": response})
+
+if __name__ == "__main__":
+    main()
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+openai
+streamlit
+langchain
+langchain[docarray]
+tiktoken
src/main.py
ADDED
@@ -0,0 +1,124 @@
+
+import os
+import openai
+import sys
+sys.path.append('../..')
+from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
+from langchain.vectorstores import DocArrayInMemorySearch
+from langchain.document_loaders import TextLoader
+from langchain.chains import RetrievalQA, ConversationalRetrievalChain
+from langchain.memory import ConversationBufferMemory
+from langchain.chat_models import ChatOpenAI
+from langchain.document_loaders import GitLoader
+from langchain.llms import OpenAI
+from langchain.vectorstores import Chroma
+from langchain.prompts import PromptTemplate
+import datetime
+import shutil
+
+# Function to load the data from GitHub using LangChain's GitLoader, given string url, branch, and file_filter
+def loader(url: str, branch: str, file_filter: str):
+    repo_path = "./github_repo"
+    if os.path.exists(repo_path):
+        shutil.rmtree(repo_path)
+
+    loader = GitLoader(
+        clone_url=url,
+        repo_path="./github_repo/",
+        branch=branch,
+        file_filter=lambda file_path: file_path.endswith(tuple(file_filter.split(',')))  # Only load files with the given extensions, although the whole repo is cloned
+    )
+
+    data = loader.load()
+    return data
+
+
+# Function to split the data into chunks using a recursive character text splitter
+def split_data(data):
+    splitter = RecursiveCharacterTextSplitter(
+        chunk_size=1000,
+        chunk_overlap=150,
+        length_function=len,  # Function to measure the length of chunks while splitting
+        add_start_index=True  # Include the starting position of each chunk in metadata
+    )
+    chunks = splitter.split_documents(data)
+    return chunks
+
+# Function to ingest the chunks into an in-memory vector store
+def ingest_chunks(chunks):
+    embedding = OpenAIEmbeddings()
+    vector_store = DocArrayInMemorySearch.from_documents(chunks, embedding)
+
+    repo_path = "./github_repo"
+    if os.path.exists(repo_path):
+        shutil.rmtree(repo_path)
+
+    return vector_store
+
+# Retrieval function to build the chain that answers the user from the vector store
+def retrieval(vector_store):
+    # Selecting the right model
+    current_date = datetime.datetime.now().date()
+    if current_date < datetime.date(2023, 9, 2):
+        llm_name = "gpt-3.5-turbo-0301"
+    else:
+        llm_name = "gpt-3.5-turbo"
+
+    # Creating the LLM
+    llm = ChatOpenAI(model=llm_name, temperature=0)
+
+    # Creating the prompt template
+    template = """
+    You're a code summarisation assistant. Given the following extracted parts of a long document and a question, create a final answer with "CODE SNIPPETS" from "SOURCE DOCUMENTS".
+    If you don't know the answer, just say that you don't know. Don't try to make up an answer.
+    ALWAYS return a "CODE SNIPPETS" from "SOURCE DOCUMENTS" part in your answer.
+
+    QUESTION: {question}
+    =========
+    CONTEXT: {context}
+    =========
+    FINAL ANSWER:"""
+
+    PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)
+
+    # Creating memory
+    memory = ConversationBufferMemory(
+        memory_key="chat_history",
+        input_key="question",
+        output_key="answer",
+        return_messages=True)
+
+    # Creating the retriever; this could also be a contextual compression retriever
+    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})  # search_type can be "similarity" or "mmr"
+
+    chain = ConversationalRetrievalChain.from_llm(
+        llm=llm,
+        chain_type="stuff",  # chain_type can be refine, stuff, or map_reduce
+        retriever=retriever,
+        memory=memory,
+        return_source_documents=True,  # The output then includes the source documents alongside the answer, so the memory needs explicit input and output keys
+        combine_docs_chain_kwargs=dict({"prompt": PROMPT})
+    )
+
+    return chain
+
+# Class using all the above components to create the QA system
+class ConversationalResponse:
+    def __init__(self, url, branch, file_filter):
+        self.url = url
+        self.branch = branch
+        self.file_filter = file_filter
+        self.data = loader(self.url, self.branch, self.file_filter)
+        self.chunks = split_data(self.data)
+        self.vector_store = ingest_chunks(self.chunks)
+        self.chain_type = "stuff"
+        self.k = 5
+        self.chain = retrieval(self.vector_store)
+
+    def __call__(self, question):
+        agent = self.chain(question)
+        return agent['answer']