ccm committed · verified
Commit 41ef5eb · 1 Parent(s): 0400fe2

Update main.py

Files changed (1)
  1. main.py +37 -22
main.py CHANGED
@@ -1,4 +1,3 @@
-import json # to work with JSON
 import threading # to allow streaming response
 import time # to pave the deliver of the message
 
@@ -6,46 +5,50 @@ import datasets # for loading RAG database
 import faiss # to create a search index
 import gradio # for the interface
 import numpy # to work with vectors
-import pandas # to work with pandas
 import sentence_transformers # to load an embedding model
 import spaces # for GPU
 import transformers # to load an LLM
 
-# Constants
+# The greeting supplied by the agent when it starts
 GREETING = (
     "Howdy! I'm an AI agent that uses [retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) "
     "to answer questions about research published at [ASME IDETC](https://asmedigitalcollection.asme.org/IDETC-CIE) within the last 10 years or so. "
     "I always try to cite my sources, but sometimes things get a little weird. "
     "What can I tell you about today?"
 )
+
+# Example queries supplied in the interface
 EXAMPLE_QUERIES = [
     "What's the difference between a markov chain and a hidden markov model?",
     "What can you tell me about analytical target cascading?",
     "What is known about different modes for human-AI teaming?",
 ]
+
+# The embedding model used
 EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
+
+# The conversational model used
 LLM_MODEL_NAME = "Qwen/Qwen2-7B-Instruct"
 
 # Load the dataset and convert to pandas
 data = datasets.load_dataset("ccm/rag-idetc")["train"].to_pandas()
 
 # Load the model for later use in embeddings
-model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)
+embedding_model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)
 
 # Create an LLM pipeline that we can send queries to
 tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
 streamer = transformers.TextIteratorStreamer(
     tokenizer, skip_prompt=True, skip_special_tokens=True
 )
-chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
+chat_model = transformers.AutoModelForCausalLM.from_pretrained(
     LLM_MODEL_NAME, torch_dtype="auto", device_map="auto"
 )
 
 # Create a FAISS index for fast similarity search
-metric = faiss.METRIC_INNER_PRODUCT
-vectors = numpy.stack(data["embedding"].tolist(), axis=0).astype('float32')
+vectors = numpy.stack(data["embedding"].tolist(), axis=0).astype("float32")
 index = faiss.IndexFlatL2(len(data["embedding"][0]))
-index.metric_type = metric
+index.metric_type = faiss.METRIC_INNER_PRODUCT
 faiss.normalize_L2(vectors)
 index.train(vectors)
 index.add(vectors)
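With L2-normalized embeddings, maximum inner product ranks candidates the same way as cosine similarity, which is what the METRIC_INNER_PRODUCT setting above is after. A minimal sketch of the same retrieval setup, assuming toy random vectors in place of the dataset's precomputed embeddings (384 dimensions matches all-MiniLM-L6-v2) and using faiss.IndexFlatIP as the more direct way to express an inner-product flat index:

import faiss
import numpy

# Toy stand-ins for the precomputed "embedding" column (hypothetical data)
vectors = numpy.random.rand(1000, 384).astype("float32")
faiss.normalize_L2(vectors)  # in-place; inner product over unit vectors == cosine similarity

index = faiss.IndexFlatIP(vectors.shape[1])  # flat index scored by inner product
index.add(vectors)

query = numpy.random.rand(1, 384).astype("float32")
faiss.normalize_L2(query)
scores, ids = index.search(query, 5)  # top-5 most similar rows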
@@ -60,7 +63,7 @@ def preprocess(query: str, k: int) -> tuple[str, str]:
     Returns:
         tuple[str, str]: A tuple containing the prompt and references
     """
-    encoded_query = numpy.expand_dims(model.encode(query), axis=0)
+    encoded_query = numpy.expand_dims(embedding_model.encode(query), axis=0)
     faiss.normalize_L2(encoded_query)
     D, I = index.search(encoded_query, k)
     top_five = data.loc[I[0]]
@@ -68,16 +71,16 @@ def preprocess(query: str, k: int) -> tuple[str, str]:
     print(top_five["text"].values)
 
     prompt = (
-        "You are an AI assistant who delights in helping people learn about research from the IDETC Conference."
+        "You are an AI assistant who delights in helping people learn about research from the IDETC Conference."
         "Your main task is to provide an ANSWER to the USER_QUERY based on the RESEARCH_EXCERPTS."
         "Your ANSWER should be concise.\n\n"
-        "RESEARCH_EXCERPTS:\n{{ABSTRACTS_GO_HERE}}\n\n"
+        "RESEARCH_EXCERPTS:\n{{EXCERPTS_GO_HERE}}\n\n"
         "USER_GUERY:\n{{QUERY_GOES_HERE}}\n\n"
         "ANSWER:\n"
     )
 
     references = {}
-    research_abstracts = ""
+    research_excerpts = ""
 
     for i in range(k):
         title = top_five["title"].values[i]
@@ -86,21 +89,32 @@ def preprocess(query: str, k: int) -> tuple[str, str]:
         path = top_five["path"].values[i]
         text = top_five["text"].values[i]
 
-        research_abstracts += str(i + i) + ". This excerpt is from: '" + title + "':\n" + text + "\n"
+        research_excerpts += (
+            str(i + i) + ". This excerpt is from: '" + title + "':\n" + text + "\n"
+        )
         header = "[" + title.title() + "](" + url + ")\n"
 
         if header not in references.keys():
             references[header] = []
 
         references[header].append(text)
-
-    prompt = prompt.replace("{{ABSTRACTS_GO_HERE}}", research_abstracts)
+
+    prompt = prompt.replace("{{EXCERPTS_GO_HERE}}", research_excerpts)
     prompt = prompt.replace("{{QUERY_GOES_HERE}}", query)
 
     print(references)
-
-    return prompt, "\n\n### References\n\n"+"\n".join([str(i+1)+". " + ref + "\n - ".join(["", *["\"..." + x + "...\"" for x in references[ref]]]) for i, ref in enumerate(references.keys())])
-
+
+    return prompt, "\n\n### References\n\n" + "\n".join(
+        [
+            str(i + 1)
+            + ". "
+            + ref
+            + "\n - ".join(["", *['"...' + x + '..."' for x in references[ref]]])
+            for i, ref in enumerate(references.keys())
+        ]
+    )
+
+
 def postprocess(response: str, bypass_from_preprocessing: str) -> str:
     """
     Applies a postprocessing step to the LLM's response before the user receives it
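The markdown block assembled by the new return expression can be previewed in isolation. A small sketch with made-up titles, URLs, and excerpts (the real entries come from the ccm/rag-idetc dataset):

# Hypothetical reference data; keys mirror the "[Title](url)\n" headers built above
references = {
    "[A Toy Paper On Markov Models](https://example.com/toy1)\n": [
        "first excerpt",
        "second excerpt",
    ],
    "[Another Toy Paper](https://example.com/toy2)\n": ["third excerpt"],
}

print(
    "\n\n### References\n\n"
    + "\n".join(
        [
            str(i + 1)
            + ". "
            + ref
            + "\n - ".join(["", *['"...' + x + '..."' for x in references[ref]]])
            for i, ref in enumerate(references.keys())
        ]
    )
)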
@@ -142,7 +156,7 @@ def reply(message: str, history: list[str]) -> str:
     model_inputs = tokenizer([text], return_tensors="pt").to("cuda:0")
 
     generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=512)
-    t = threading.Thread(target=chatmodel.generate, kwargs=generate_kwargs)
+    t = threading.Thread(target=chat_model.generate, kwargs=generate_kwargs)
     t.start()
 
     partial_message = ""
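generate() blocks until decoding finishes, so the script runs it on a worker thread and drains TextIteratorStreamer from the main thread. A self-contained sketch of that pattern with the same Qwen/Qwen2-7B-Instruct checkpoint the script loads (a GPU is assumed, as on the Space):

import threading
import transformers

MODEL = "Qwen/Qwen2-7B-Instruct"

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
chat_model = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL, torch_dtype="auto", device_map="auto"
)
streamer = transformers.TextIteratorStreamer(
    tokenizer, skip_prompt=True, skip_special_tokens=True
)

model_inputs = tokenizer(
    ["What is analytical target cascading?"], return_tensors="pt"
).to(chat_model.device)
generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=512)

# Run generation in the background; the streamer yields text chunks as they decode
thread = threading.Thread(target=chat_model.generate, kwargs=generate_kwargs)
thread.start()

partial_message = ""
for new_text in streamer:
    partial_message += new_text
    print(new_text, end="", flush=True)
thread.join()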
@@ -160,7 +174,10 @@ gradio.ChatInterface(
     reply,
     examples=EXAMPLE_QUERIES,
     chatbot=gradio.Chatbot(
-        avatar_images=[None, "https://event.asme.org/Events/media/library/images/IDETC-CIE/IDETC-Logo-Announcements.png?ext=.png"],
+        avatar_images=(
+            None,
+            "https://event.asme.org/Events/media/library/images/IDETC-CIE/IDETC-Logo-Announcements.png?ext=.png",
+        ),
         show_label=False,
         show_share_button=False,
         show_copy_button=False,
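avatar_images takes a (user, assistant) pair, so only the second slot is given the IDETC logo. A stripped-down interface showing where that tuple plugs in (the echo handler is a placeholder, not part of the Space):

import gradio

def echo(message, history):
    # Placeholder handler; main.py wires in its reply() function here
    return "You said: " + message

gradio.ChatInterface(
    echo,
    chatbot=gradio.Chatbot(
        avatar_images=(
            None,  # keep the default user avatar
            "https://event.asme.org/Events/media/library/images/IDETC-CIE/IDETC-Logo-Announcements.png?ext=.png",
        ),
        show_label=False,
    ),
).launch()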
@@ -172,5 +189,3 @@ gradio.ChatInterface(
     undo_btn=None,
     clear_btn=None,
 ).launch(debug=True)
-
-