# RAG demo: question answering over chalcogenide perovskite papers
# with zephyr-7b-beta, BAAI/bge-base-en-v1.5 embeddings, FAISS, and Gradio.

import glob

import gradio as gr
import torch
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_huggingface.llms import HuggingFacePipeline
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

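# Load every Markdown paper in md_files/ into LangChain documents and flatten the result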
md_path = glob.glob("md_files/*.md")
docs = [UnstructuredMarkdownLoader(md).load() for md in md_path]
docs_list = [item for sublist in docs for item in sublist]

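# Split the documents into ~1000-token chunks with 200-token overlap (token counts via tiktoken)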
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=1000, chunk_overlap=200
)
doc_splits = text_splitter.split_documents(docs_list)

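# Embed the chunks with BAAI/bge-base-en-v1.5 and index them in a FAISS vector store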
db = FAISS.from_documents(
    doc_splits,
    HuggingFaceEmbeddings(model_name='BAAI/bge-base-en-v1.5'),
)

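# Quantization config: load the 7B model in 4-bit NF4 with double quantization and bfloat16 compute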
model_name = 'HuggingFaceH4/zephyr-7b-beta'

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

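# Wrap the quantized model in a transformers text-generation pipeline and expose it to LangChain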
text_generation_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    temperature=0.2,
    do_sample=True,
    repetition_penalty=1.1,
    return_full_text=False,  # return only the generated answer, not the echoed prompt
    max_new_tokens=512,
)

llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

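# Prompt: answer only from the retrieved context, in at most three sentences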
prompt_template = '''You are an assistant for question-answering tasks.

Here is the context to use to answer the question:

{context}

Think carefully about the above context.

Now, review the user question:

{question}

Provide an answer to this question using only the above context.

Use three sentences maximum and keep the answer concise.

Answer:'''

prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)

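# Plain chain: prompt -> LLM -> string output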
llm_chain = prompt | llm | StrOutputParser()

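# RAG chain: the retriever fills {context} with the chunks matched to the question
# (the retrieved Document list is inserted into the prompt as its string representation)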
retriever = db.as_retriever()

rag_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | llm_chain
)

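# Route the question through the RAG chain or the plain LLM chain, based on the radio selection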
def get_output(is_RAG: str, questions: str):
    # Both chains end in StrOutputParser, so invoke() already returns a plain string
    if is_RAG == "RAG":
        # Retrieve context from the vector store, then answer
        return rag_chain.invoke(questions)
    else:
        # Answer from the bare LLM with an empty context
        return llm_chain.invoke({"context": "", "question": questions})

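# Gradio UI: a RAG/No-RAG toggle, a question box, and example prompts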
demo = gr.Interface(
    fn=get_output,
    inputs=[
        gr.Radio(
            choices=["RAG", "No RAG"],
            type="value",
            value="RAG",
            label="RAG or not",
        ),
        gr.Textbox(label="Input Questions", info="Enter a question on chalcogenide perovskites"),
    ],
    outputs="markdown",
    title="RAG with the zephyr-7b-beta LLM and the BAAI/bge-base-en-v1.5 embedding model, over chalcogenide perovskite papers",
    description="""
    ## Ask a question about chalcogenide perovskites, or click one of the examples below.
    """,
    examples=[
        ["RAG", "What is the advantage of BaZrS3?"],
        ["RAG", "What is the bandgap of SrHfS3?"],
        ["RAG", "Why are chalcogenide perovskites important?"],
    ],
)

if __name__ == "__main__":
    demo.launch(share=False)