# (removed non-Python scrape residue: file-size banner, git-blame hashes, and line-number gutter)
import os
from bs4 import BeautifulSoup
from llama_index.core import Document
from llama_index.core import Settings
from llama_index.core import SimpleDirectoryReader
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex
from llama_index.readers.web import SimpleWebPageReader
from llama_index.vector_stores.chroma import ChromaVectorStore
import chromadb
import re
from llama_index.llms.cohere import Cohere
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.core import PromptTemplate
from llama_index.core.llms import ChatMessage
import gradio as gr
import uuid
# Credentials are read from the environment so they never live in source control.
api_key = os.environ.get("API_KEY")
# NOTE(review): base_url is read but never used below — confirm whether a
# custom endpoint was meant to be passed to the Cohere clients.
base_url = os.environ.get("BASE_URL")
# Chat LLM used for answer generation.
# NOTE(review): llama-index's Cohere LLM documents a `model` parameter;
# confirm `model_name` is accepted here.
llm = Cohere(
api_key=api_key,
model_name="command")
# Embedding model; input_type="search_query" matches the retrieval use case,
# int8 embeddings trade a little accuracy for smaller storage.
embedding_model = CohereEmbedding(
api_key=api_key,
model_name="embed-multilingual-v3.0",
input_type="search_query",
embedding_type="int8",)
# Set Global settings so every index/query engine picks up these models.
Settings.llm = llm
Settings.embed_model = embedding_model
# Path of the most recently created Chroma database; rebound inside infer().
db_path=""
def extract_web(url):
    """Fetch a web page and return its paragraph text as llama-index Documents.

    Args:
        url: URL of the page to download.

    Returns:
        Tuple ``(documents, option)`` where ``documents`` is a one-element
        list of ``Document`` wrapping the concatenated ``<p>``-tag text and
        ``option`` is the literal string ``"web"``.
    """
    web_documents = SimpleWebPageReader().load_data([url])
    html_content = web_documents[0].text
    # Parse the raw HTML and keep only paragraph text (navigation, scripts,
    # etc. are dropped). find_all is the non-deprecated spelling of findAll.
    soup = BeautifulSoup(html_content, 'html.parser')
    # join() builds the string in one pass instead of quadratic `+=` in a loop;
    # each paragraph keeps its trailing newline, matching the original output.
    text_content = "".join(p.text + "\n" for p in soup.find_all('p'))
    # Convert back to Document format so it can be indexed downstream.
    documents = [Document(text=text_content)]
    return documents, "web"
def extract_doc(path):
    """Load local files into llama-index Documents.

    Args:
        path: List of file paths to read.

    Returns:
        Tuple ``(documents, "doc")`` — the loaded documents plus a tag
        telling the caller which ingestion route was used.
    """
    loaded = SimpleDirectoryReader(input_files=path).load_data()
    return loaded, "doc"
def create_col(documents):
    """Embed *documents* into a fresh on-disk Chroma collection.

    Args:
        documents: llama-index ``Document`` objects to index.

    Returns:
        Filesystem path of the newly created Chroma database directory.
    """
    # A short random suffix gives every ingestion its own database directory.
    path = f'database/{str(uuid.uuid4())[:4]}'
    client = chromadb.PersistentClient(path=path)
    collection = client.get_or_create_collection("quickstart")
    # Wire the collection into llama-index storage.
    store = ChromaVectorStore(chroma_collection=collection)
    ctx = StorageContext.from_defaults(vector_store=store)
    # Building the index embeds the documents and persists them via the
    # storage context; the index object itself is not needed afterwards.
    VectorStoreIndex.from_documents(documents, storage_context=ctx)
    return path
def infer(message: str, history: list):
    """Gradio chat handler: ingest uploaded files or a URL, then answer via RAG.

    Args:
        message: Multimodal message dict with "text" and "files" keys.
        history: Prior chat turns as (prompt, answer) pairs; file uploads
            appear as tuple prompts.

    Returns:
        A status string right after a web page is ingested, otherwise the
        query engine's response to the user's question.

    Raises:
        gr.Error: If the very first message supplies neither a URL nor a file.
    """
    global db_path
    option = ""
    print(f'message: {message}')
    print(f'history: {history}')
    messages = []
    files_list = message["files"]
    for prompt, answer in history:
        # BUG FIX: `prompt is tuple` compared against the type object and was
        # always False; isinstance() is the correct check for file-upload turns.
        if isinstance(prompt, tuple):
            # BUG FIX: `files_list += prompt[0]` extended the list with the
            # *characters* of the path string; append the path itself.
            files_list.append(prompt[0])
        else:
            messages.append(ChatMessage(role="user", content=prompt))
            messages.append(ChatMessage(role="assistant", content=answer))
    if files_list:
        documents, option = extract_doc(files_list)
        db_path = create_col(documents)
    elif message["text"].startswith(("http://", "https://")):
        documents, option = extract_web(message["text"])
        db_path = create_col(documents)
    elif len(history) == 0:
        # BUG FIX: gr.Error was instantiated but never raised, so the user
        # never saw the message; Gradio surfaces it only when raised.
        raise gr.Error("Please input an url or upload file at first.")
    # Load from disk: reopen the persisted collection and rebuild an index
    # view over the stored embeddings.
    load_client = chromadb.PersistentClient(path=db_path)
    chroma_collection = load_client.get_collection("quickstart")
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    index = VectorStoreIndex.from_vector_store(vector_store)
    template = (
        """ You are an assistant for question-answering tasks.
Use the following context to answer the question.
If you don't know the answer, just say that you don't know.
Use five sentences maximum and keep the answer concise.\n
Question: {query_str} \nContext: {context_str} \nAnswer:"""
    )
    llm_prompt = PromptTemplate(template)
    print(llm_prompt)
    if option == "web" and len(history) == 0:
        # First turn of a web ingestion: acknowledge instead of querying.
        response = "Get the web data! You can ask it."
    else:
        question = message['text']
        query_engine = index.as_query_engine(text_qa_template=llm_prompt)
        response = query_engine.query(question)
    return response
# Shared chatbot component rendered by the chat interface below.
chatbot = gr.Chatbot()
with gr.Blocks(theme="soft") as demo:
    # multimodal=True lets users upload files alongside text messages,
    # which infer() reads from message["files"].
    gr.ChatInterface(
        fn = infer,
        title = "RAG demo",
        multimodal = True,
        chatbot=chatbot,
    )
if __name__ == "__main__":
    # API access is disabled on purpose: the app is meant to be used
    # through the web UI only, and share=False keeps it local.
    demo.queue(api_open=False).launch(show_api=False, share=False)