|
import gradio as gr |
|
import utils |
|
from langchain_mistralai import ChatMistralAI |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_community.vectorstores import Chroma |
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
from langchain_core.runnables import RunnablePassthrough |
|
import torch |
|
|
|
import os

# SECURITY: a real-looking API key is hardcoded here and committed to source.
# Rotate this key and supply it via the environment / a secrets manager.
# setdefault (instead of plain assignment) keeps an externally provided
# MISTRAL_API_KEY from being silently clobbered at import time.
os.environ.setdefault('MISTRAL_API_KEY', 'XuyOObDE7trMbpAeI7OXYr3dnmoWy3L0')
|
|
|
class VectorData():
    """Owns the Chroma vector store, the embedding model, and the Mistral
    RAG chain, and tracks the display names of ingested files.

    The chain is rebuilt whenever the retriever changes (after ingesting a
    document) via the private `_build_rag_chain` helper.
    """

    def __init__(self):
        # Punjabi sentence-similarity model used to embed both documents
        # and incoming queries.
        embedding_model_name = 'l3cube-pune/punjabi-sentence-similarity-sbert'

        model_kwargs = {
            'device': 'cuda' if torch.cuda.is_available() else 'cpu',
            "trust_remote_code": True,
        }

        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_name,
            model_kwargs=model_kwargs
        )

        # Persisted on disk so previously ingested documents survive restarts.
        self.vectorstore = Chroma(persist_directory="chroma_db", embedding_function=self.embeddings)
        self.retriever = self.vectorstore.as_retriever()
        # Display names of files currently in the store (drives the UI table).
        self.ingested_files = []
        self.prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """Answer the question based on the given context. Dont give any ans if context is not valid to question. Always give the source of context:
{context}
""",
                ),
                ("human", "{question}"),
            ]
        )
        self.llm = ChatMistralAI(model="mistral-large-latest")
        self.rag_chain = self._build_rag_chain()

    def _build_rag_chain(self):
        """Compose retriever -> prompt -> LLM -> string parser.

        Factored out because the chain must be rebuilt every time the
        retriever is replaced (the original duplicated this pipeline in
        both __init__ and add_file).
        """
        return (
            {"context": self.retriever, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    def add_file(self, file):
        """Ingest an uploaded file into the vector store.

        Args:
            file: a Gradio file object (has a `.name` path) or None.

        Returns:
            Rows ([[name], ...]) for the Gradio dataframe of ingested files.
        """
        if file is not None:
            # os.path.basename is cross-platform; the previous
            # file.name.split('/')[-1] broke on Windows-style paths.
            self.ingested_files.append(os.path.basename(file.name))
            self.retriever, self.vectorstore = utils.add_doc(file, self.vectorstore)
            self.rag_chain = self._build_rag_chain()
        return [[name] for name in self.ingested_files]

    def delete_file_by_name(self, file_name):
        """Remove one ingested document by its display name.

        No-op if the name is unknown. Returns refreshed dataframe rows.
        """
        if file_name in self.ingested_files:
            self.retriever, self.vectorstore = utils.delete_doc(file_name, self.vectorstore)
            self.ingested_files.remove(file_name)
        return [[name] for name in self.ingested_files]

    def delete_all_files(self):
        """Purge every document from the store and clear the tracked names.

        Returns an empty list so the UI table is emptied.
        """
        self.ingested_files.clear()
        self.retriever, self.vectorstore = utils.delete_all_doc(self.vectorstore)
        return []
|
|
|
data_obj = VectorData() |
|
|
|
|
|
def answer_question(question):
    """Run a user question through the RAG chain.

    Blank / whitespace-only input short-circuits with a prompt to the user
    instead of invoking the chain.
    """
    if not question.strip():
        return "Please enter a question."
    return f'{data_obj.rag_chain.invoke(question)}'
|
|
|
|
|
|
|
# Top-level Gradio UI: left column manages document ingestion/deletion,
# right column asks questions against the RAG chain. Component creation
# order inside the `with` contexts defines the on-screen layout.
with gr.Blocks() as rag_interface:

    gr.Markdown("# RAG Interface")
    gr.Markdown("Manage documents and ask questions with a Retrieval-Augmented Generation (RAG) system.")

    with gr.Row():

        # --- Left column: file management -------------------------------
        with gr.Column():
            gr.Markdown("### File Management")

            file_input = gr.File(label="Upload File to Ingest")
            add_file_button = gr.Button("Ingest File")

            # Read-only table mirroring data_obj.ingested_files.
            ingested_files_box = gr.Dataframe(
                headers=["Files"],
                datatype="str",
                row_count=4,
                interactive=False,
            )

            delete_option = gr.Radio(
                choices=["Delete by File Name", "Delete All Files"],
                label="Delete Option",
            )
            # Only shown when deleting a single file by name.
            file_name_input = gr.Textbox(label="Enter File Name to Delete", visible=False)
            delete_button = gr.Button("Delete Selected")

            def toggle_file_input(option):
                """Show the file-name textbox only for per-file deletion."""
                return gr.update(visible=option == "Delete by File Name")

            delete_option.change(
                fn=toggle_file_input,
                inputs=delete_option,
                outputs=file_name_input,
            )

            add_file_button.click(
                fn=data_obj.add_file,
                inputs=file_input,
                outputs=ingested_files_box,
            )

            def delete_action(delete_option, file_name):
                """Dispatch the selected delete mode; returns refreshed rows."""
                if delete_option == "Delete All Files":
                    return data_obj.delete_all_files()
                if delete_option == "Delete by File Name" and file_name:
                    return data_obj.delete_file_by_name(file_name)
                # Nothing actionable selected: re-render the current list.
                return [[name] for name in data_obj.ingested_files]

            delete_button.click(
                fn=delete_action,
                inputs=[delete_option, file_name_input],
                outputs=ingested_files_box,
            )

        # --- Right column: question answering ---------------------------
        with gr.Column():
            gr.Markdown("### Ask a Question")

            question_input = gr.Textbox(label="Enter your question")

            ask_button = gr.Button("Get Answer")
            answer_output = gr.Textbox(label="Answer", interactive=False)

            ask_button.click(fn=answer_question, inputs=question_input, outputs=answer_output)

rag_interface.launch()
|
|