File size: 4,665 Bytes
8364e36
5379f04
 
 
b2c2e74
5379f04
 
 
 
 
 
 
 
f77c387
 
5379f04
 
 
a710661
5379f04
 
 
f77c387
5379f04
f77c387
 
 
 
 
 
 
 
5379f04
 
 
 
 
 
f77c387
5379f04
a9460a2
5379f04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2c2e74
5379f04
 
b2c2e74
 
5379f04
 
226eb83
5379f04
 
 
 
 
 
 
 
 
 
 
 
 
b2c2e74
5379f04
f229ceb
9a6b2aa
5379f04
 
8364e36
5379f04
 
 
 
 
 
8364e36
5379f04
 
 
b2c2e74
5379f04
 
a9460a2
5379f04
 
1ebf8b3
a9460a2
5379f04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2c2e74
8364e36
5379f04
 
 
 
 
 
 
 
 
8364e36
5379f04
 
 
 
 
 
 
 
51b1469
 
5379f04
 
8364e36
dd3fe36
 
de693c7
 
 
 
62fd8b9
5379f04
de693c7
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os
from bs4 import BeautifulSoup
from llama_index.core import Document
from llama_index.core import Settings
from llama_index.core import SimpleDirectoryReader
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex
from llama_index.readers.web import SimpleWebPageReader

from llama_index.vector_stores.chroma import ChromaVectorStore

import chromadb
import re
from llama_index.llms.cohere import Cohere
from llama_index.embeddings.cohere import CohereEmbedding

from llama_index.core import PromptTemplate
from llama_index.core.llms import ChatMessage
import gradio as gr
import uuid

# Credentials/endpoint are read from the environment (None when unset).
api_key, base_url = os.environ.get("API_KEY"), os.environ.get("BASE_URL")

# Cohere chat model that writes the answers.
llm = Cohere(model_name="command", api_key=api_key)

# Cohere multilingual embedder with int8 embeddings.
# NOTE(review): input_type="search_query" is also what gets used when documents
# are embedded at indexing time; Cohere recommends "search_document" for stored
# documents — confirm whether CohereEmbedding switches this automatically.
embedding_model = CohereEmbedding(
    model_name="embed-multilingual-v3.0",
    api_key=api_key,
    embedding_type="int8",
    input_type="search_query",
)

# Register both as the LlamaIndex-wide defaults.
Settings.llm, Settings.embed_model = llm, embedding_model

# Path of the most recently built Chroma store; reassigned by infer().
db_path = ""

def extract_web(url):
    """Fetch *url* and keep only the text of its ``<p>`` elements.

    Args:
        url: address of the page to download.

    Returns:
        ``(documents, "web")`` where *documents* is a one-element list holding
        a ``Document`` whose text is every paragraph, each followed by a newline.
    """
    web_documents = SimpleWebPageReader().load_data([url])
    html_content = web_documents[0].text
    soup = BeautifulSoup(html_content, "html.parser")
    # find_all is the current BeautifulSoup API (findAll is a deprecated alias).
    # join avoids quadratic string building; every paragraph, including the
    # last, stays newline-terminated as before.
    text_content = "".join(p.text + "\n" for p in soup.find_all("p"))
    return [Document(text=text_content)], "web"

def extract_doc(path):
    """Load local files into LlamaIndex documents.

    Args:
        path: a list of file paths (the shape Gradio's multimodal input
            provides), or a single path string / ``os.PathLike``.

    Returns:
        ``(documents, "doc")`` with whatever ``SimpleDirectoryReader`` loaded.
    """
    # input_files iterates its argument; a bare string would otherwise be
    # walked character by character and treated as one-letter file names.
    if isinstance(path, (str, os.PathLike)):
        path = [path]
    documents = SimpleDirectoryReader(input_files=path).load_data()
    return documents, "doc"


def create_col(documents):
    """Embed *documents* into a new persistent Chroma store and return its path.

    Each call writes to a fresh directory under ``database/`` and stores the
    vectors in a collection named ``"quickstart"``.

    Args:
        documents: LlamaIndex documents to embed and persist.

    Returns:
        The directory path of the newly created store.
    """
    # Use the full UUID: a 4-character prefix allows only 65,536 names, and a
    # collision would make get_or_create_collection silently append these
    # documents to an unrelated, previously created store.
    db_path = f"database/{uuid.uuid4().hex}"
    client = chromadb.PersistentClient(path=db_path)
    chroma_collection = client.get_or_create_collection("quickstart")

    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    # Embedding and persisting happen inside from_documents; the returned
    # index object itself is not needed here.
    VectorStoreIndex.from_documents(documents, storage_context=storage_context)
    return db_path

def infer(message: dict, history: list):
    """Gradio multimodal ChatInterface handler: index a source or answer from it.

    A turn that carries uploaded files, or whose text starts with ``http://`` or
    ``https://``, builds a new vector store from that source and stores its path
    in the module-level ``db_path``. Any other turn is answered by querying the
    store at ``db_path``.

    Args:
        message: the multimodal payload, read as a dict with ``"text"`` and
            ``"files"`` keys.
        history: previous turns supplied by Gradio; only its emptiness is used.

    Returns:
        The answer, or an acknowledgement after loading a source, as ``str``.

    Raises:
        gr.Error: when no source has been provided yet, so there is nothing
            to query.
    """
    global db_path
    # NOTE(review): db_path is one module-level value shared by every browser
    # session, so one user's upload replaces the store another user queries.
    # Fixing that needs per-session state (e.g. gr.State), an interface change.
    option = ""
    print(f'message: {message}')
    print(f'history: {history}')
    text = message["text"]
    # Copy so the list inside Gradio's payload is never mutated. Files from
    # earlier turns are already in the current store, so history is not
    # re-scanned for them (the old scan never ran: `prompt is tuple` is always
    # False).
    files_list = list(message["files"])

    if files_list:
        documents, option = extract_doc(files_list)
        db_path = create_col(documents)
    elif text.startswith(("http://", "https://")):
        documents, option = extract_web(text)
        db_path = create_col(documents)
    elif not history or not db_path:
        # Must be raised, not merely constructed, for Gradio to display it and
        # to stop before opening a Chroma store at an empty path.
        raise gr.Error("Please input an url or upload file at first.")

    if option == "web":
        # A URL turn only loads data; the URL itself is not a question.
        return "Get the web data! You can ask it."
    if option == "doc" and not text.strip():
        # File uploaded without a question: nothing to query yet.
        return "Get the document data! You can ask it."

    # Reopen the persisted store (written by create_col) and wrap it as an index.
    load_client = chromadb.PersistentClient(path=db_path)
    chroma_collection = load_client.get_collection("quickstart")
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    index = VectorStoreIndex.from_vector_store(vector_store)

    template = (
        """ You are an assistant for question-answering tasks.
    Use the following context to answer the question.
    If you don't know the answer, just say that you don't know.
    Use five sentences maximum and keep the answer concise.\n
    Question: {query_str} \nContext: {context_str} \nAnswer:"""
    )
    llm_prompt = PromptTemplate(template)
    print(llm_prompt)

    query_engine = index.as_query_engine(text_qa_template=llm_prompt)
    # query() returns a Response object; Gradio's Chatbot expects a str.
    return str(query_engine.query(text))
    






# Transcript component, built up front and handed to the ChatInterface below.
chatbot = gr.Chatbot()

with gr.Blocks(theme="soft") as demo:
    # multimodal=True makes infer() receive a dict with "text" and "files".
    gr.ChatInterface(
        infer,
        multimodal=True,
        title="RAG demo",
        chatbot=chatbot,
    )

if __name__ == "__main__":
    # Enable request queuing with the API closed; no API page, no share link.
    demo.queue(api_open=False).launch(share=False, show_api=False)