"""Gradio chat app that answers questions about ASME IDETC research using RAG.

Excerpts are retrieved from the ``ccm/rag-idetc`` dataset with a FAISS
cosine-similarity index and passed, together with the user's question, to an
instruction-tuned LLM whose answer is streamed back to the chat interface with
the cited excerpts appended.
"""

import threading  # to allow streaming response
import time  # to pace the delivery of the message
from collections.abc import Iterator  # to type the streaming reply

import datasets  # for loading RAG database
import faiss  # to create a search index
import gradio  # for the interface
import numpy  # to work with vectors
import sentence_transformers  # to load an embedding model
import spaces  # for GPU
import transformers  # to load an LLM

# The greeting supplied by the agent when it starts
GREETING = (
    "Howdy! I'm an AI agent that uses [retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) "
    "to answer questions about research published at [ASME IDETC](https://asmedigitalcollection.asme.org/IDETC-CIE) within the last 10 years or so. "
    "I always try to cite my sources, but sometimes things get a little weird. "
    "What can I tell you about today?"
)

# Example queries supplied in the interface
EXAMPLE_QUERIES = [
    "What's the difference between a markov chain and a hidden markov model?",
    "What can you tell me about analytical target cascading?",
    "What is known about different modes for human-AI teaming?",
    "What are some examples of opportunistic versus restrictive design for additive manufacturing? Format your answer as a table with two columns (opportunistic, restrictive).",
]

# The embedding model used
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"

# The conversational model used
LLM_MODEL_NAME = "Qwen/Qwen2-7B-Instruct"

# Load the dataset and convert to pandas
data = datasets.load_dataset("ccm/rag-idetc")["train"].to_pandas()

# Load the model for later use in embeddings
embedding_model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)

# Load the tokenizer and chat model. A fresh TextIteratorStreamer is created per
# request inside `reply`, so concurrent conversations never share a token queue.
tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
chat_model = transformers.AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME, torch_dtype="auto", device_map="auto"
)

# Create a FAISS index for fast similarity search. Stored vectors (and, in
# `_search`, the query vector) are L2-normalized, so the inner product computed
# by IndexFlatIP is cosine similarity. Flat indexes need no training.
vectors = numpy.stack(data["embedding"].tolist(), axis=0).astype("float32")
faiss.normalize_L2(vectors)
excerpt_index = faiss.IndexFlatIP(vectors.shape[1])
excerpt_index.add(vectors)


def _search(query: str, k: int):
    """Return the dataset rows of the (at most) k excerpts most similar to `query`, best first."""
    # Never ask FAISS for more neighbours than the index holds
    k = min(k, excerpt_index.ntotal)
    if k <= 0:
        return data.iloc[[]]
    encoded_query = numpy.asarray(
        embedding_model.encode(query), dtype="float32"
    ).reshape(1, -1)
    faiss.normalize_L2(encoded_query)
    _, indices = excerpt_index.search(encoded_query, k)
    # FAISS returns row positions, so select positionally
    return data.iloc[indices[0]]


def _build_prompt(query: str, matches) -> str:
    """Assemble the instruction prompt from the user's query and the retrieved excerpt rows."""
    research_excerpts = "".join(
        f"{rank}. This excerpt is from: '{title}':\n{text}\n"
        for rank, (title, text) in enumerate(
            zip(matches["title"].values, matches["text"].values), start=1
        )
    )
    return (
        "You are an AI assistant who delights in helping people learn about research from the IDETC Conference. "
        "Your main task is to provide an ANSWER to the USER_QUERY based on the RESEARCH_EXCERPTS. "
        "Your ANSWER should be concise.\n\n"
        f"RESEARCH_EXCERPTS:\n{research_excerpts}\n\n"
        f"USER_QUERY:\n{query}\n\n"
        "ANSWER:\n"
    )


def _format_references(matches) -> str:
    """Render the retrieved excerpts as a collapsible Markdown section grouped by source paper."""
    # Group excerpts under a DOI-hyperlinked title; dicts keep insertion order,
    # so papers appear in the order of their best-ranked excerpt.
    references: dict[str, list[str]] = {}
    for title, doi_suffix, text in zip(
        matches["title"].values, matches["id"].values, matches["text"].values
    ):
        header = f"[{title.title()}](https://doi.org/10.1115/{doi_suffix})\n"
        references.setdefault(header, []).append(text)

    # Each paper becomes a level-3 heading followed by its excerpts as blockquotes
    list_of_references = "\n".join(
        "### " + header + "".join(f'\n\n> "...{excerpt}..."' for excerpt in excerpts)
        for header, excerpts in references.items()
    )
    return (
        "\n\n<details><summary>References</summary>\n\n"
        + list_of_references
        + "\n\n</details>"
    )


def preprocess(query: str, k: int) -> tuple[str, str]:
    """
    Searches the dataset for the k excerpts most relevant to the query and builds the LLM prompt

    Args:
        query (str): The user's query
        k (int): The number of excerpts to retrieve (clamped to the index size)

    Returns:
        tuple[str, str]: The prompt to send to the LLM, and a references section
        to append to the LLM's answer
    """
    matches = _search(query, k)
    return _build_prompt(query, matches), _format_references(matches)


def postprocess(response: str, bypass_from_preprocessing: str) -> str:
    """
    Applies a postprocessing step to the LLM's response before the user receives it

    Args:
        response (str): The LLM's response
        bypass_from_preprocessing (str): The bypass variable from the preprocessing step

    Returns:
        str: The postprocessed response
    """
    return response + bypass_from_preprocessing


@spaces.GPU
def reply(message: str, history: list[list[str]]) -> Iterator[str]:
    """
    Streams a retrieval-augmented answer to the user's message

    Args:
        message (str): The user's message
        history (list[list[str]]): Prior [user, assistant] message pairs; either
            entry may be None (the greeting, for example, has no user message)

    Yields:
        str: The answer generated so far; the final value has the references appended
    """
    # Apply preprocessing
    prompt, bypass = preprocess(message, 10)

    # Convert the [user, assistant] pairs into the chat-template message format,
    # skipping empty slots, then append the retrieval-augmented prompt
    conversation = [
        {"role": role, "content": content}
        for message_pair in history
        for role, content in zip(("user", "assistant"), message_pair)
        if content is not None
    ] + [{"role": "user", "content": prompt}]

    text = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to("cuda:0")

    # One streamer per request so concurrent generations cannot interleave tokens
    streamer = transformers.TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    def generate() -> None:
        try:
            chat_model.generate(**model_inputs, streamer=streamer, max_new_tokens=512)
        except Exception:
            # Close the stream so the consuming loop below cannot block forever
            streamer.end()
            raise

    generation_thread = threading.Thread(target=generate)
    generation_thread.start()

    partial_message = ""
    for new_token in streamer:
        # NOTE(review): text chunks that are exactly "<" are dropped; the reason is
        # not evident from this file — confirm it is still needed
        if new_token != "<":
            partial_message += new_token
            time.sleep(0.05)
            yield partial_message
    generation_thread.join()

    yield postprocess(partial_message, bypass)


# Create and run the gradio interface
demo = gradio.ChatInterface(
    reply,
    examples=EXAMPLE_QUERIES,
    chatbot=gradio.Chatbot(
        avatar_images=(
            None,
            "https://event.asme.org/Events/media/library/images/IDETC-CIE/IDETC-Logo-Announcements.png?ext=.png",
        ),
        show_label=False,
        show_share_button=False,
        show_copy_button=False,
        value=[[None, GREETING]],
        height="60vh",
        bubble_full_width=False,
    ),
    retry_btn=None,
    undo_btn=None,
    clear_btn=None,
)
demo.launch(debug=True)