from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from langchain_huggingface.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import glob
import gradio as gr

# Prepare the data: load every markdown file in md_files/
md_path = glob.glob("md_files/*.md")
docs = [UnstructuredMarkdownLoader(md).load() for md in md_path]
docs_list = [item for sublist in docs for item in sublist]

# Split documents into overlapping chunks
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=1000, chunk_overlap=200
)
doc_splits = text_splitter.split_documents(docs_list)

# Create the embeddings + retriever
db = FAISS.from_documents(
    doc_splits,
    HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5"),
)

# Load the quantized model
model_name = "HuggingFaceH4/zephyr-7b-beta"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Set up the LLM chain
text_generation_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    temperature=0.2,
    do_sample=True,
    repetition_penalty=1.1,
    return_full_text=True,
    max_new_tokens=512,
)
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

# Prompt used both with and without retrieved context
prompt_template = """You are an assistant for question-answering tasks.

Here is the context to use to answer the question:

{context}

Think carefully about the above context.

Now, review the user question:

{question}

Provide an answer to this question using only the above context.
Use three sentences maximum and keep the answer concise.

Answer:"""

prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)

llm_chain = prompt | llm | StrOutputParser()

# RAG chain: retrieve relevant chunks from the vector database as context
retriever = db.as_retriever()
rag_chain = {"context": retriever, "question": RunnablePassthrough()} | llm_chain


def get_output(is_RAG: str, questions: str) -> str:
    """Answer the question with retrieval (RAG) or from the bare LLM chain."""
    if is_RAG == "RAG":
        # StrOutputParser already returns a plain string
        return rag_chain.invoke(questions)
    return llm_chain.invoke({"context": "", "question": questions})


demo = gr.Interface(
    fn=get_output,
    inputs=[
        gr.Radio(
            choices=["RAG", "No RAG"],
            type="value",
            value="RAG",  # default to RAG
            label="RAG or not",
        ),
        gr.Textbox(
            label="Input Questions",
            info="Ask a question about chalcogenide perovskites",
        ),
    ],
    outputs="markdown",
    title="RAG with the LLM zephyr-7b-beta and the embedding model BAAI/bge-base-en-v1.5, built on chalcogenide perovskite papers",
    description="""
    ## Ask a question about chalcogenide perovskites, or click one of the examples below.
    """,
    examples=[
        ["RAG", "what is the advantage of BaZrS3?"],
        ["RAG", "what is the bandgap of SrHfS3?"],
        ["RAG", "why are chalcogenide perovskites important?"],
    ],
)

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(share=False)