# NOTE: removed non-source viewer residue (file-size header, git-blame hash
# column, and line-number gutter) that was accidentally captured with the code.
import os
import httpx
from llama_index.core import Document
from llama_index.core import Settings
from llama_index.core import SimpleDirectoryReader
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex
from llama_index.vector_stores.chroma import ChromaVectorStore
import chromadb
import re
from llama_index.llms.cohere import Cohere
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.chat_engine import CondensePlusContextChatEngine
import gradio as gr
import uuid
# Service credentials are injected via environment variables (e.g. Space secrets).
api_key = os.environ.get("API_KEY")
# NOTE(review): base_url is read here but never used below — confirm whether a
# custom endpoint was intended for the Cohere client.
base_url = os.environ.get("BASE_URL")
# Chat LLM used to answer questions over the retrieved context.
llm = Cohere(
api_key=api_key,
model="command")
# Embedding model for indexing; int8 embeddings shrink the vector store.
embedding_model = CohereEmbedding(
api_key=api_key,
model_name="embed-multilingual-v3.0",
input_type="search_document",
embedding_type="int8",)
# Placeholder until a ChatMemoryBuffer is created on first ingestion (see infer()).
memory = ""
# Set Global settings
Settings.llm = llm
Settings.embed_model=embedding_model
# set context window
Settings.context_window = 4096
# set number of output tokens
Settings.num_output = 512
# Path of the Chroma database for the current session; assigned in infer().
db_path=""
def validate_url(url, timeout=60.0):
    """Fetch *url* and wrap the response body in a llama-index Document.

    Parameters
    ----------
    url : str
        Fully-qualified URL to fetch.
    timeout : float, optional
        Request timeout in seconds (default 60.0, matching the previous
        hard-coded value).

    Returns
    -------
    tuple[list[Document], str]
        A one-element Document list and the source tag "web".

    Raises
    ------
    gr.Error
        Surfaced to the UI on network failure, non-2xx status, or any
        unexpected error.
    """
    try:
        # follow_redirects=True: without it a 30x response body (not the
        # target page) would be indexed silently.
        response = httpx.get(url, timeout=timeout, follow_redirects=True)
        response.raise_for_status()
        text = [Document(text=response.text)]
        option = "web"
        return text, option
    except httpx.RequestError as e:
        raise gr.Error(f"An error occurred while requesting {url}: {str(e)}")
    except httpx.HTTPStatusError as e:
        raise gr.Error(f"Error response {e.response.status_code} while requesting {url}")
    except Exception as e:
        raise gr.Error(f"An unexpected error occurred: {str(e)}")
def extract_web(url):
    """Fetch *url* through the r.jina.ai reader proxy.

    The proxy converts the page to LLM-friendly text; the result is
    validated and wrapped by validate_url().
    """
    print("Entered Webpage Extraction")
    proxied_url = f"https://r.jina.ai/{url}"
    print(proxied_url)
    print("Exited Webpage Extraction")
    return validate_url(proxied_url)
def extract_doc(path):
    """Parse the uploaded file path(s) into llama-index Documents.

    Returns the parsed documents together with the source tag "doc".
    """
    reader = SimpleDirectoryReader(input_files=path)
    return reader.load_data(), "doc"
def create_col(documents):
    """Index *documents* into a fresh persistent Chroma collection.

    Returns the on-disk path of the new database so the caller can
    re-open it on later turns.
    """
    # A short random suffix isolates each ingestion in its own directory.
    collection_path = f'database/{str(uuid.uuid4())[:4]}'
    chroma_client = chromadb.PersistentClient(path=collection_path)
    collection = chroma_client.get_or_create_collection("quickstart")
    storage = StorageContext.from_defaults(
        vector_store=ChromaVectorStore(chroma_collection=collection)
    )
    # Building the index embeds the documents and persists them to disk.
    VectorStoreIndex.from_documents(documents, storage_context=storage)
    return collection_path
def infer(message: str, history: list):
    """Gradio chat handler: ingest a document/URL, then answer questions.

    Parameters
    ----------
    message : dict
        Multimodal Gradio message with "text" and "files" keys.
    history : list
        Prior chat turns, as supplied by gr.ChatInterface.

    Returns
    -------
    str
        The assistant's reply.

    Raises
    ------
    gr.Error
        When no URL/document has been provided to index.
    """
    global db_path
    global memory
    option = ""
    print(f'message: {message}')
    print(f'history: {history}')
    files_list = message["files"]
    if files_list:
        # A file upload (re-)ingests and resets the conversation memory.
        documents, option = extract_doc(files_list)
        db_path = create_col(documents)
        memory = ChatMemoryBuffer.from_defaults(token_limit=3900)
    elif message["text"].startswith(("http://", "https://")):
        # A URL message is fetched via the reader proxy and ingested.
        documents, option = extract_web(message["text"])
        db_path = create_col(documents)
        memory = ChatMemoryBuffer.from_defaults(token_limit=3900)
    elif len(history) == 0:
        # First turn must supply something to index.
        raise gr.Error("Please send an URL or document")
    if not db_path:
        # Guard: a follow-up question with no prior ingestion (e.g. after an
        # app restart with stale history) would otherwise crash in
        # PersistentClient(path="").
        raise gr.Error("Please send an URL or document")
    # Re-open the collection created during ingestion.
    load_client = chromadb.PersistentClient(path=db_path)
    chroma_collection = load_client.get_collection("quickstart")
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    index = VectorStoreIndex.from_vector_store(
        vector_store,
    )
    if option == "web" and len(history) == 0:
        # First URL turn: acknowledge ingestion instead of answering.
        response = "Getcha! Now ask your question."
    else:
        question = message['text']
        chat_engine = CondensePlusContextChatEngine.from_defaults(
            index.as_retriever(),
            memory=memory,
            context_prompt=(
                "You are an assistant for question-answering tasks."
                "Use the following context to answer the question:\n"
                "{context_str}"
                "\nIf you don't know the answer, just say that you don't know."
                "Use five sentences maximum and keep the answer concise."
                "\nInstruction: Use the previous chat history, or the context above, to interact and help the user."
            ),
            verbose=True,
        )
        response = chat_engine.chat(question)
    print(type(response))
    print(f'response: {response}')
    return str(response)
# Custom CSS: hide the Gradio footer and center the page heading.
css="""
footer {
display:none !important
}
h1 {
text-align: center;
display: block;
}
"""
# HTML banner rendered above the chat interface.
title="""
<h1>RAG demo</h1>
<p style="text-align: center">Retrieval for web and documents</p>
"""
# Chatbot widget with onboarding placeholder text shown before the first turn.
chatbot = gr.Chatbot(placeholder="Please send an URL or document file at first<br>Then ask question and get an answer.", height=800)

# Assemble the app: a multimodal chat interface backed by infer().
with gr.Blocks(theme="soft", css=css, fill_height="true") as demo:
    gr.ChatInterface(
        fn=infer,
        title=title,
        multimodal=True,
        chatbot=chatbot,
    )

if __name__ == "__main__":
    # Disable the public API surface; serve locally only.
    demo.queue(api_open=False).launch(show_api=False, share=False)