import gradio as gr
import utils
from langchain_mistralai import ChatMistralAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.runnables import RunnablePassthrough
import torch
import os

# SECURITY: the Mistral API key used to be hard-coded here, which exposes it to
# anyone who can read this file or its history. Treat that key as compromised
# and revoke it. ChatMistralAI is constructed below without an explicit key,
# so it must find MISTRAL_API_KEY in the environment (shell export, .env
# loader, or a deployment secret).
if not os.environ.get('MISTRAL_API_KEY'):
    raise RuntimeError(
        "MISTRAL_API_KEY is not set. Export it in the environment before "
        "starting the app."
    )

# System prompt for the RAG chain: {context} is filled from the retriever's
# output and {question} (in the human message) from the user's input.
SYSTEM_PROMPT = (
    "Answer the question based on the given context. "
    "Dont give any ans if context is not valid to question. \n"
    "Always give the source of context: {context} "
)


class VectorData:
    """Own the vector store, retriever, and RAG chain used by the Gradio UI.

    Documents are embedded with a HuggingFace sentence-similarity model and
    stored in a Chroma collection persisted under ``chroma_db``. Questions
    are answered by an LCEL chain: retriever -> prompt -> Mistral -> string.

    Attributes:
        ingested_files: Base names of files ingested via ``add_file`` during
            this process's lifetime, in ingestion order.
            NOTE(review): this starts empty on every start even though
            ``chroma_db`` is persisted, so documents ingested in earlier runs
            are searchable but not listed (and cannot be deleted by name).
    """

    def __init__(self):
        embedding_model_name = 'l3cube-pune/punjabi-sentence-similarity-sbert'
        model_kwargs = {
            'device': 'cuda' if torch.cuda.is_available() else 'cpu',
            # NOTE(review): trust_remote_code lets the model repo execute
            # code on load; keep it only if this model actually requires it.
            "trust_remote_code": True,
        }
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_name,
            model_kwargs=model_kwargs,
        )
        self.vectorstore = Chroma(
            persist_directory="chroma_db",
            embedding_function=self.embeddings,
        )
        self.retriever = self.vectorstore.as_retriever()
        self.ingested_files = []
        self.prompt = ChatPromptTemplate.from_messages(
            [
                ("system", SYSTEM_PROMPT),
                ("human", "{question}"),
            ]
        )
        self.llm = ChatMistralAI(model="mistral-large-latest")
        self._rebuild_chain()

    def _rebuild_chain(self):
        """(Re)build ``self.rag_chain`` around the current ``self.retriever``.

        The dict below evaluates ``self.retriever`` once, when the chain is
        built, so the chain must be rebuilt every time the retriever is
        replaced. Otherwise it keeps querying the old one.
        """
        # NOTE(review): the retriever's output goes into {context} with no
        # formatting step, i.e. via its default string form.
        self.rag_chain = (
            {"context": self.retriever, "question": RunnablePassthrough()}
            | self.prompt
            | self.llm
            | StrOutputParser()
        )

    def _file_rows(self):
        """Return ``ingested_files`` as single-column rows for ``gr.Dataframe``."""
        return [[name] for name in self.ingested_files]

    def add_file(self, file):
        """Ingest an uploaded file into the vector store.

        Args:
            file: The value of the ``gr.File`` component, or ``None`` when
                nothing was uploaded (a no-op). Either an object exposing the
                upload path as ``.name`` or a plain path string is accepted
                for the displayed name. NOTE(review): the raw value is passed
                to ``utils.add_doc`` unchanged; confirm it handles the shape
                your Gradio version produces.

        Returns:
            The ingested file names as single-column rows.
        """
        if file is not None:
            path = getattr(file, 'name', file)
            self.retriever, self.vectorstore = utils.add_doc(file, self.vectorstore)
            # Record the name only after ingestion succeeded, so a failed
            # upload does not leave a phantom entry in the list.
            self.ingested_files.append(os.path.basename(path))
            self._rebuild_chain()
        return self._file_rows()

    def delete_file_by_name(self, file_name):
        """Remove a previously ingested file's documents from the store.

        Unknown names are ignored.

        Args:
            file_name: Base name as listed in ``ingested_files``.

        Returns:
            The remaining ingested file names as single-column rows.
        """
        if file_name in self.ingested_files:
            self.retriever, self.vectorstore = utils.delete_doc(file_name, self.vectorstore)
            self.ingested_files.remove(file_name)
            # The retriever was replaced; without a rebuild the chain would
            # keep answering from the pre-deletion retriever.
            self._rebuild_chain()
        return self._file_rows()

    def delete_all_files(self):
        """Remove every document from the store and clear the file list.

        Returns:
            An empty list (no remaining rows).
        """
        # Clear the store first so the list is not emptied if deletion fails.
        self.retriever, self.vectorstore = utils.delete_all_doc(self.vectorstore)
        self.ingested_files.clear()
        self._rebuild_chain()
        return self._file_rows()


data_obj = VectorData()


def answer_question(question):
    """Answer *question* using the shared RAG chain.

    Returns a prompt to enter a question, without calling the LLM, when
    *question* is ``None``, empty, or whitespace-only.
    """
    if question and question.strip():
        return str(data_obj.rag_chain.invoke(question))
    return "Please enter a question."
# Labels for the delete-mode radio. They are shared by the component and the
# handlers so that renaming a label cannot silently break the comparisons.
_DELETE_BY_NAME = "Delete by File Name"
_DELETE_ALL = "Delete All Files"

# Define the Gradio interface
with gr.Blocks() as rag_interface:
    # Title and description
    gr.Markdown("# RAG Interface")
    gr.Markdown("Manage documents and ask questions with a Retrieval-Augmented Generation (RAG) system.")

    with gr.Row():
        # Left column: file management
        with gr.Column():
            gr.Markdown("### File Management")

            # File upload and ingest
            file_input = gr.File(label="Upload File to Ingest")
            add_file_button = gr.Button("Ingest File")

            # Table of ingested file names, refreshed by the ingest/delete handlers.
            ingested_files_box = gr.Dataframe(
                headers=["Files"],
                datatype="str",
                # NOTE(review): row_count sets the table's row count; it is
                # not a scroll-viewport height. Verify against the installed
                # Gradio version if a fixed-height scrollable table is wanted.
                row_count=4,
                interactive=False,
            )

            # Delete mode selector; the file-name box is shown only for per-file deletion.
            delete_option = gr.Radio(choices=[_DELETE_BY_NAME, _DELETE_ALL], label="Delete Option")
            file_name_input = gr.Textbox(label="Enter File Name to Delete", visible=False)
            delete_button = gr.Button("Delete Selected")

            def toggle_file_input(option):
                """Show the file-name textbox only when per-file deletion is selected."""
                return gr.update(visible=(option == _DELETE_BY_NAME))

            delete_option.change(fn=toggle_file_input, inputs=delete_option, outputs=file_name_input)

            # Ingest the uploaded file and refresh the table.
            add_file_button.click(
                fn=data_obj.add_file,
                inputs=file_input,
                outputs=ingested_files_box,
            )

            def delete_action(option, file_name):
                """Run the selected delete mode and return the refreshed file rows.

                Surrounding whitespace is stripped from the typed file name,
                so a stray space does not make an existing file unmatched.
                """
                file_name = (file_name or "").strip()
                if option == _DELETE_BY_NAME and file_name:
                    return data_obj.delete_file_by_name(file_name)
                if option == _DELETE_ALL:
                    return data_obj.delete_all_files()
                # No mode selected, or per-file mode with no name: table unchanged.
                return [[name] for name in data_obj.ingested_files]

            delete_button.click(
                fn=delete_action,
                inputs=[delete_option, file_name_input],
                outputs=ingested_files_box,
            )

        # Right column: question answering
        with gr.Column():
            gr.Markdown("### Ask a Question")

            # Question input
            question_input = gr.Textbox(label="Enter your question")

            # Get-answer button and answer output
            ask_button = gr.Button("Get Answer")
            answer_output = gr.Textbox(label="Answer", interactive=False)
            ask_button.click(fn=answer_question, inputs=question_input, outputs=answer_output)

# Launch the Gradio interface
rag_interface.launch()