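"""Gradio RAG demo for chalcogenide perovskite papers.

Builds a FAISS index over local markdown files with BAAI/bge-base-en-v1.5
embeddings and answers questions with a 4-bit quantized zephyr-7b-beta model,
either with or without retrieval.
"""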
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from langchain_huggingface.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from transformers import pipeline
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import glob
import gradio as gr
# Prepare the data: load every markdown paper and flatten into one list of documents
md_paths = glob.glob("md_files/*.md")
docs = [UnstructuredMarkdownLoader(md).load() for md in md_paths]
docs_list = [item for sublist in docs for item in sublist]
# Split documents
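# from_tiktoken_encoder measures chunk_size and chunk_overlap in tokens, not characters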
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=1000, chunk_overlap=200
)
doc_splits = text_splitter.split_documents(docs_list)
# Create the embeddings + retriever
db = FAISS.from_documents(
    doc_splits,
    HuggingFaceEmbeddings(model_name='BAAI/bge-base-en-v1.5'),
)
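# The FAISS index lives in memory; db.as_retriever() below is what the RAG chain queries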
# Load quantized model
model_name = 'HuggingFaceH4/zephyr-7b-beta'
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
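# 4-bit NF4 quantization (bitsandbytes) cuts the 7B model's memory footprint to a few GB;
# loading with this config requires a CUDA-capable GPU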
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Set up the text-generation pipeline
text_generation_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    temperature=0.2,
    do_sample=True,
    repetition_penalty=1.1,
    return_full_text=True,
    max_new_tokens=512,
)
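# Wrap the transformers pipeline so it can be composed with LangChain (LCEL) runnables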
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
# Prompt shared by the plain LLM chain and the RAG chain
prompt_template = '''You are an assistant for question-answering tasks.
Here is the context to use to answer the question:
{context}
Think carefully about the above context.
Now, review the user question:
{question}
Provide an answer to this question using only the above context.
Use three sentences maximum and keep the answer concise.
Answer:'''
prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)
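# Plain chain: format the prompt, run zephyr, and parse the output into a plain string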
llm_chain = prompt | llm | StrOutputParser()
retriever = db.as_retriever()
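# RAG chain: the retriever supplies the retrieved documents for {context}, while
# RunnablePassthrough() forwards the user's question unchanged into {question}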
rag_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | llm_chain
)
#question = "what is advantage of chalcogenide perovskite?"
def get_output(is_RAG: str, questions: str):
    """Answer the question with or without retrieval, depending on the radio choice."""
    if is_RAG == "RAG":
        # Both chains end in StrOutputParser, so invoke() already returns a plain string
        return rag_chain.invoke(questions)
    else:
        return llm_chain.invoke({"context": "", "question": questions})
demo = gr.Interface(
    fn=get_output,
    inputs=[
        gr.Radio(
            choices=["RAG", "No RAG"],
            type="value",
            value="RAG",  # default to retrieval-augmented generation
            label="RAG or not",
        ),
        gr.Textbox(
            label="Input Question",
            info="Ask a question about chalcogenide perovskites",
        ),
    ],
    outputs="markdown",
    title="RAG with the zephyr-7b-beta LLM and BAAI/bge-base-en-v1.5 embeddings, built on chalcogenide perovskite papers",
    description="""
## Ask a question about chalcogenide perovskites, or click one of the examples below.
""",
    examples=[
        ["RAG", "What is the advantage of BaZrS3?"],
        ["RAG", "What is the bandgap of SrHfS3?"],
        ["RAG", "Why are chalcogenide perovskites important?"],
    ],
)
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(share=False)