# chat-with-idetc / main.py
# ccm's picture
# Update main.py
# 15e16ab verified
import collections.abc # for annotating the streaming generator
import threading # to allow streaming response
import time # to pave the delivery of the message

import datasets # for loading RAG database
import faiss # to create a search index
import gradio # for the interface
import numpy # to work with vectors
import sentence_transformers # to load an embedding model
import spaces # for GPU
import transformers # to load an LLM
# Opening message shown by the agent at the start of a conversation.
# The sentences are joined with single spaces to form one paragraph of markdown.
GREETING = " ".join(
    [
        "Howdy! I'm an AI agent that uses [retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation)",
        "to answer questions about research published at [ASME IDETC](https://asmedigitalcollection.asme.org/IDETC-CIE) within the last 10 years or so.",
        "I always try to cite my sources, but sometimes things get a little weird.",
        "What can I tell you about today?",
    ]
)

# Sample questions offered to the user in the interface
EXAMPLE_QUERIES = [
    "What's the difference between a markov chain and a hidden markov model?",
    "What can you tell me about analytical target cascading?",
    "What is known about different modes for human-AI teaming?",
    "What are some examples of opportunistic versus restrictive design for additive manufacturing? Format your answer as a table with two columns (opportunistic, restrictive).",
]

# Name of the sentence-transformers model used to embed queries
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"

# Name of the chat LLM that writes the answers
LLM_MODEL_NAME = "Qwen/Qwen2-7B-Instruct"
# Load the RAG dataset (train split) and convert it to a pandas DataFrame
data = datasets.load_dataset("ccm/rag-idetc")["train"].to_pandas()

# Load the model for later use in embeddings
embedding_model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)

# Load the tokenizer, a token streamer, and the LLM that we can send queries to
tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
streamer = transformers.TextIteratorStreamer(
    tokenizer, skip_prompt=True, skip_special_tokens=True
)
chat_model = transformers.AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME, torch_dtype="auto", device_map="auto"
)

# Create a FAISS index for fast similarity search.
# The stored vectors are L2-normalized in place, so the inner product used by
# the index equals cosine similarity.
vectors = numpy.stack(data["embedding"].tolist(), axis=0).astype("float32")
faiss.normalize_L2(vectors)
# Build an inner-product index directly rather than creating an L2 index and
# overwriting its metric afterwards; a flat index needs no training step.
excerpt_index = faiss.IndexFlatIP(vectors.shape[1])
excerpt_index.add(vectors)
def preprocess(query: str, k: int) -> tuple[str, str]:
    """
    Searches the dataset for the top k most relevant excerpts to the query and returns a prompt and references

    Args:
        query (str): The user's query
        k (int): The maximum number of excerpts to retrieve
    Returns:
        tuple[str, str]: A tuple containing the LLM prompt and a collapsible markdown block of references
    """
    # Embed the query and normalize it so it is comparable with the normalized index vectors
    encoded_query = numpy.expand_dims(embedding_model.encode(query), axis=0)
    faiss.normalize_L2(encoded_query)
    _, indices = excerpt_index.search(encoded_query, k)

    # FAISS pads the result with -1 when fewer than k neighbors exist; drop those.
    # The returned ids are row positions in the index, hence positional lookup.
    hits = [int(position) for position in indices[0] if position >= 0]
    top_k = data.iloc[hits]
    print(top_k["text"].values)  # debug output of the retrieved excerpts

    # Number the excerpts for the prompt (starting at 1) and group their texts
    # by paper so that each paper is listed once in the references
    references: dict[str, list[str]] = {}
    excerpt_blocks = []
    rows = zip(top_k["title"].values, top_k["id"].values, top_k["text"].values)
    for number, (title, paper_id, text) in enumerate(rows, start=1):
        excerpt_blocks.append(
            f"{number}. This excerpt is from: '{title}':\n{text}\n"
        )
        header = f"[{title.title()}](https://doi.org/10.1115/{paper_id})\n"
        references.setdefault(header, []).append(text)
    research_excerpts = "".join(excerpt_blocks)

    # Assemble the prompt in one step so that placeholder-like text inside an
    # excerpt or the query can never be substituted by accident
    prompt = (
        "You are an AI assistant who delights in helping people learn about research from the IDETC Conference. "
        "Your main task is to provide an ANSWER to the USER_QUERY based on the RESEARCH_EXCERPTS. "
        "Your ANSWER should be concise.\n\n"
        f"RESEARCH_EXCERPTS:\n{research_excerpts}\n\n"
        f"USER_QUERY:\n{query}\n\n"
        "ANSWER:\n"
    )
    print(references)  # debug output of the grouped references

    # One markdown section per paper: a linked heading followed by block-quoted excerpts
    list_of_references = "\n".join(
        "### "
        + hyperlinked_title
        + "".join('\n\n> "...' + excerpt + '..."' for excerpt in excerpts)
        for hyperlinked_title, excerpts in references.items()
    )

    return (
        prompt,
        "\n\n<details><summary><h3>References</h3></summary>\n\n"
        + list_of_references
        + "\n\n</details>",
    )
def postprocess(response: str, bypass_from_preprocessing: str) -> str:
    """
    Finalizes the LLM's response by appending the text carried over from preprocessing

    Args:
        response (str): The LLM's response
        bypass_from_preprocessing (str): Text produced during preprocessing that skips the LLM
    Returns:
        str: The response followed by the bypassed text
    """
    pieces = (response, bypass_from_preprocessing)
    return "".join(pieces)
@spaces.GPU
def reply(message: str, history: list[list[str]]) -> collections.abc.Iterator[str]:
    """
    Streams a retrieval-augmented response to the user's message

    Args:
        message (str): The user's message
        history (list[list[str]]): Earlier turns as [user, assistant] pairs; either entry of a pair may be None
    Yields:
        str: The response generated so far, and finally the complete response with the references appended
    """
    # Apply preprocessing: build the RAG prompt and the references block that bypasses the LLM
    prompt, bypass = preprocess(message, 10)

    # Flatten the history pairs into chat-template messages, skipping empty
    # slots, and append the new prompt as the latest user turn
    conversation = [
        {"role": role, "content": content}
        for message_pair in history
        for role, content in zip(("user", "assistant"), message_pair)
        if content is not None
    ]
    conversation.append({"role": "user", "content": prompt})

    text = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to("cuda:0")

    # Use a streamer private to this call: a shared streamer would hand tokens
    # left over from an abandoned generation to the next request
    token_streamer = transformers.TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(model_inputs, streamer=token_streamer, max_new_tokens=512)

    # Generate in a background thread so tokens can be consumed as they arrive
    generation_thread = threading.Thread(
        target=chat_model.generate, kwargs=generate_kwargs
    )
    generation_thread.start()

    partial_message = ""
    for new_token in token_streamer:
        # A chunk consisting solely of "<" is dropped (presumably a fragment of
        # special-token markup -- TODO confirm)
        if new_token != "<":
            partial_message += new_token
            time.sleep(0.05)  # pace the delivery of the message
            yield partial_message

    # Finish by appending the references gathered during preprocessing
    yield postprocess(partial_message, bypass)
# Create and run the gradio interface
gradio.ChatInterface(
reply,
examples=EXAMPLE_QUERIES,
chatbot=gradio.Chatbot(
avatar_images=(
None,
"https://event.asme.org/Events/media/library/images/IDETC-CIE/IDETC-Logo-Announcements.png?ext=.png",
),
show_label=False,
show_share_button=False,
show_copy_button=False,
value=[[None, GREETING]],
height="60vh",
bubble_full_width=False,
),
retry_btn=None,
undo_btn=None,
clear_btn=None,
).launch(debug=True)