Artteiv tosanoob commited on
Commit
f93a31b
·
verified ·
1 Parent(s): f3de9b9

fix query bug no.2 (#11)

Browse files

- fix query bug no.2 (bca1d0f824d83dd5dd04871551143bfaa4b05b89)


Co-authored-by: Trương Tấn Cường <[email protected]>

Files changed (1) hide show
  1. chat/arxiv_bot/arxiv_bot_utils.py +273 -252
chat/arxiv_bot/arxiv_bot_utils.py CHANGED
@@ -1,276 +1,297 @@
1
- # import chromadb
2
- # from chromadb import Documents, EmbeddingFunction, Embeddings
3
- # from transformers import AutoModel
4
- # import json
5
- # from numpy.linalg import norm
6
- # import sqlite3
7
- # import urllib
8
- # from django.conf import settings
 
9
 
 
10
 
11
- # # this module act as a singleton class
 
 
 
12
 
13
- # class JinaAIEmbeddingFunction(EmbeddingFunction):
14
- # def __init__(self, model):
15
- # super().__init__()
16
- # self.model = model
17
 
18
- # def __call__(self, input: Documents) -> Embeddings:
19
- # embeddings = self.model.encode(input)
20
- # return embeddings.tolist()
 
21
 
22
- # # instance of embedding_model
23
- # embedding_model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en',
24
- # trust_remote_code=True,
25
- # cache_dir='models')
26
 
27
- # # instance of JinaAIEmbeddingFunction
28
- # ef = JinaAIEmbeddingFunction(embedding_model)
 
 
 
29
 
30
- # # list of topics
31
- # topic_descriptions = json.load(open("topic_descriptions.txt"))
32
- # topics = list(dict.keys(topic_descriptions))
33
- # embeddings = [embedding_model.encode(topic_descriptions[key]) for key in topic_descriptions]
34
- # cos_sim = lambda a,b: (a @ b.T) / (norm(a)*norm(b))
35
 
36
- # def choose_topic(summary):
37
- # embed = embedding_model.encode(summary)
38
- # topic = ""
39
- # max_sim = 0.
40
- # for i,key in enumerate(topics):
41
- # sim = cos_sim(embed,embeddings[i])
42
- # if sim > max_sim:
43
- # topic = key
44
- # max_sim = sim
45
- # return topic
46
 
47
- # def authors_list_to_str(authors):
48
- # """input a list of authors, return a string represent authors"""
49
- # text = ""
50
- # for author in authors:
51
- # text+=author+", "
52
- # return text[:-3]
53
 
54
- # def authors_str_to_list(string):
55
- # """input a string of authors, return a list of authors"""
56
- # authors = []
57
- # list_auth = string.split("and")
58
- # for author in list_auth:
59
- # if author != "et al.":
60
- # authors.append(author.strip())
61
- # return authors
62
 
63
- # def chunk_texts(text, max_char=400):
64
- # """
65
- # Chunk a long text into several chunks, with each chunk about 300-400 characters long,
66
- # but make sure no word is cut in half.
67
- # Args:
68
- # text: The long text to be chunked.
69
- # max_char: The maximum number of characters per chunk (default: 400).
70
- # Returns:
71
- # A list of chunks.
72
- # """
73
- # chunks = []
74
- # current_chunk = ""
75
- # words = text.split()
76
- # for word in words:
77
- # if len(current_chunk) + len(word) + 1 >= max_char:
78
- # chunks.append(current_chunk)
79
- # current_chunk = " "
80
- # else:
81
- # current_chunk += " " + word
82
- # chunks.append(current_chunk.strip())
83
- # return chunks
84
 
85
- # def trimming(txt):
86
- # start = txt.find("{")
87
- # end = txt.rfind("}")
88
- # return txt[start:end+1].replace("\n"," ")
89
 
90
- # # crawl data
91
 
92
- # def extract_tag(txt,tagname):
93
- # return txt[txt.find("<"+tagname+">")+len(tagname)+2:txt.find("</"+tagname+">")]
94
 
95
- # def get_record(extract):
96
- # id = extract_tag(extract,"id")
97
- # updated = extract_tag(extract,"updated")
98
- # published = extract_tag(extract,"published")
99
- # title = extract_tag(extract,"title").replace("\n ","").strip()
100
- # summary = extract_tag(extract,"summary").replace("\n","").strip()
101
- # authors = []
102
- # while extract.find("<author>")!=-1:
103
- # author = extract_tag(extract,"name")
104
- # extract = extract[extract.find("</author>")+9:]
105
- # authors.append(author)
106
- # pattern = '<link title="pdf" href="'
107
- # link_start = extract.find('<link title="pdf" href="')
108
- # link = extract[link_start+len(pattern):extract.find("rel=",link_start)-2]
109
- # return [id, updated, published, title, authors, link, summary]
110
 
111
- # def crawl_exact_paper(title,author,max_results=3):
112
- # authors = authors_list_to_str(author)
113
- # records = []
114
- # url = 'http://export.arxiv.org/api/query?search_query=ti:{title}+AND+au:{author}&max_results={max_results}'.format(title=title,author=authors,max_results=max_results)
115
- # url = url.replace(" ","%20")
116
- # try:
117
- # arxiv_page = urllib.request.urlopen(url,timeout=100).read()
118
- # xml = str(arxiv_page,encoding="utf-8")
119
- # while xml.find("<entry>") != -1:
120
- # extract = xml[xml.find("<entry>")+7:xml.find("</entry>")]
121
- # xml = xml[xml.find("</entry>")+8:]
122
- # extract = get_record(extract)
123
- # topic = choose_topic(extract[6])
124
- # records.append([topic,*extract])
125
- # return records
126
- # except Exception as e:
127
- # return "Error: "+str(e)
128
 
129
- # def crawl_arxiv(keyword_list, max_results=100):
130
- # baseurl = 'http://export.arxiv.org/api/query?search_query='
131
- # records = []
132
- # for i,keyword in enumerate(keyword_list):
133
- # if i ==0:
134
- # url = baseurl + 'all:' + keyword
135
- # else:
136
- # url = url + '+OR+' + 'all:' + keyword
137
- # url = url+ '&max_results=' + str(max_results)
138
- # url = url.replace(' ', '%20')
139
- # try:
140
- # arxiv_page = urllib.request.urlopen(url,timeout=100).read()
141
- # xml = str(arxiv_page,encoding="utf-8")
142
- # while xml.find("<entry>") != -1:
143
- # extract = xml[xml.find("<entry>")+7:xml.find("</entry>")]
144
- # xml = xml[xml.find("</entry>")+8:]
145
- # extract = get_record(extract)
146
- # topic = choose_topic(extract[6])
147
- # records.append([topic,*extract])
148
- # return records
149
- # except Exception as e:
150
- # return "Error: "+str(e)
151
 
152
- # class ArxivSQL:
153
- # def __init__(self, table="arxivsql", name="db.sqlite3"):
154
- # self.con = sqlite3.connect(name)
155
- # self.cur = self.con.cursor()
156
- # self.table = table
 
 
 
 
157
 
158
- # def query(self, title="", author=[]):
159
- # if len(title)>0:
160
- # query_title = 'title like "%{}%"'.format(title)
161
- # else:
162
- # query_title = "True"
163
- # if len(author)>0:
164
- # query_author = 'authors like '
165
- # for auth in author:
166
- # query_author += "'%{}%' or ".format(auth)
167
- # query_author = query_author[:-4]
168
- # else:
169
- # query_author = "True"
170
- # query = "select * from {} where {} and {}".format(self.table,query_title,query_author)
171
- # result = self.cur.execute(query)
172
- # return result.fetchall()
173
 
174
- # def query_id(self, ids=[]):
175
- # try:
176
- # if len(ids) == 0:
177
- # return None
178
- # query = "select * from {} where id in (".format(self.table)
179
- # for id in ids:
180
- # query+="'"+id+"',"
181
- # query = query[:-1] + ")"
182
- # result = self.cur.execute(query)
183
- # return result.fetchall()
184
- # except Exception as e:
185
- # print(e)
186
- # print("Error query: ",query)
187
-
188
- # def add(self, crawl_records):
189
- # """
190
- # Add crawl_records (list) obtained from arxiv_crawlers
191
- # A record is a list of 8 columns:
192
- # [topic, id, updated, published, title, author, link, summary]
193
- # Return the final length of the database table
194
- # """
195
- # results = ""
196
- # for record in crawl_records:
197
- # try:
198
- # query = """insert into arxivsql values("{}","{}","{}","{}","{}","{}","{}")""".format(
199
- # record[1][21:],
200
- # record[0],
201
- # record[4].replace('"',"'"),
202
- # authors_list_to_str(record[5]),
203
- # record[2][:10],
204
- # record[3][:10],
205
- # record[6]
206
- # )
207
- # self.cur.execute(query)
208
- # self.con.commit()
209
- # except Exception as e:
210
- # result+=str(e)
211
- # result+="\n" + query + "\n"
212
- # finally:
213
- # return results
214
-
215
- # # instance of ArxivSQL
216
- # sqldb = ArxivSQL()
217
 
218
- # class ArxivChroma:
219
- # """
220
- # Create an interface to arxivdb, which only support query and addition.
221
- # This interface do not support edition and deletion procedures.
222
- # """
223
- # def __init__(self, table="arxiv_records", name="arxivdb/"):
224
- # self.client = chromadb.PersistentClient(name)
225
- # self.model = embedding_model
226
- # self.collection = self.client.get_or_create_collection(table,
227
- # embedding_function=JinaAIEmbeddingFunction(
228
- # model = self.model
229
- # ))
230
-
231
- # def query_relevant(self, keywords, query_texts, n_results=3):
232
- # """
233
- # Perform a query using a list of keywords (str),
234
- # or using a relavant string
235
- # """
236
- # contains = []
237
- # for keyword in keywords:
238
- # contains.append({"$contains":keyword.lower()})
239
- # return self.collection.query(
240
- # query_texts=query_texts,
241
- # where_document={
242
- # "$or":contains
243
- # },
244
- # n_results=n_results,
245
- # )
246
-
247
- # def query_exact(self, id):
248
- # ids = ["{}_{}".format(id,j) for j in range(0,10)]
249
- # return self.collection.get(ids=ids)
 
 
 
 
 
 
 
 
 
 
 
250
 
251
- # def add(self, crawl_records):
252
- # """
253
- # Add crawl_records (list) obtained from arxiv_crawlers
254
- # A record is a list of 8 columns:
255
- # [topic, id, updated, published, title, author, link, summary]
256
- # Return the final length of the database table
257
- # """
258
- # for record in crawl_records:
259
- # embed_text = """
260
- # Topic: {},
261
- # Title: {},
262
- # Summary: {}
263
- # """.format(record[0],record[4],record[7])
264
- # chunks = chunk_texts(embed_text)
265
- # ids = [record[1][21:]+"_"+str(j) for j in range(len(chunks))]
266
- # paper_ids = [{"paper_id":record[1][21:]} for _ in range(len(chunks))]
267
- # self.collection.add(
268
- # documents = chunks,
269
- # metadatas=paper_ids,
270
- # ids = ids
271
- # )
272
- # return self.collection.count()
273
 
274
- # # instance of ArxivChroma
275
- # db = ArxivChroma()
 
 
 
 
 
 
 
 
 
 
 
 
276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chromadb
2
+ from chromadb import Documents, EmbeddingFunction, Embeddings
3
+ from transformers import AutoModel
4
+ import json
5
+ from numpy.linalg import norm
6
+ import sqlite3
7
+ import urllib.request
8
+ from django.conf import settings
9
+ import Levenshtein
10
 
11
+ # this module act as a singleton class
12
 
13
+ class JinaAIEmbeddingFunction(EmbeddingFunction):
14
+ def __init__(self, model):
15
+ super().__init__()
16
+ self.model = model
17
 
18
+ def __call__(self, input: Documents) -> Embeddings:
19
+ embeddings = self.model.encode(input)
20
+ return embeddings.tolist()
 
21
 
22
+ # instance of embedding_model
23
+ embedding_model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en',
24
+ trust_remote_code=True,
25
+ cache_dir='models')
26
 
27
+ # instance of JinaAIEmbeddingFunction
28
+ ef = JinaAIEmbeddingFunction(embedding_model)
 
 
29
 
30
+ # list of topics
31
+ topic_descriptions = json.load(open("topic_descriptions.txt"))
32
+ topics = list(dict.keys(topic_descriptions))
33
+ embeddings = [embedding_model.encode(topic_descriptions[key]) for key in topic_descriptions]
34
+ cos_sim = lambda a,b: (a @ b.T) / (norm(a)*norm(b))
35
 
36
+ def lev_sim(a,b): return Levenshtein.distance(a,b)
 
 
 
 
37
 
38
+ def choose_topic(summary):
39
+ embed = embedding_model.encode(summary)
40
+ topic = ""
41
+ max_sim = 0.
42
+ for i,key in enumerate(topics):
43
+ sim = cos_sim(embed,embeddings[i])
44
+ if sim > max_sim:
45
+ topic = key
46
+ max_sim = sim
47
+ return topic
48
 
49
+ def authors_list_to_str(authors):
50
+ """input a list of authors, return a string represent authors"""
51
+ text = ""
52
+ for author in authors:
53
+ text+=author+", "
54
+ return text[:-3]
55
 
56
+ def authors_str_to_list(string):
57
+ """input a string of authors, return a list of authors"""
58
+ authors = []
59
+ list_auth = string.split("and")
60
+ for author in list_auth:
61
+ if author != "et al.":
62
+ authors.append(author.strip())
63
+ return authors
64
 
65
+ def chunk_texts(text, max_char=400):
66
+ """
67
+ Chunk a long text into several chunks, with each chunk about 300-400 characters long,
68
+ but make sure no word is cut in half.
69
+ Args:
70
+ text: The long text to be chunked.
71
+ max_char: The maximum number of characters per chunk (default: 400).
72
+ Returns:
73
+ A list of chunks.
74
+ """
75
+ chunks = []
76
+ current_chunk = ""
77
+ words = text.split()
78
+ for word in words:
79
+ if len(current_chunk) + len(word) + 1 >= max_char:
80
+ chunks.append(current_chunk)
81
+ current_chunk = " "
82
+ else:
83
+ current_chunk += " " + word
84
+ chunks.append(current_chunk.strip())
85
+ return chunks
86
 
87
+ def trimming(txt):
88
+ start = txt.find("{")
89
+ end = txt.rfind("}")
90
+ return txt[start:end+1].replace("\n"," ")
91
 
92
+ # crawl data
93
 
94
+ def extract_tag(txt,tagname):
95
+ return txt[txt.find("<"+tagname+">")+len(tagname)+2:txt.find("</"+tagname+">")]
96
 
97
+ def get_record(extract):
98
+ id = extract_tag(extract,"id")
99
+ updated = extract_tag(extract,"updated")
100
+ published = extract_tag(extract,"published")
101
+ title = extract_tag(extract,"title").replace("\n ","").strip()
102
+ summary = extract_tag(extract,"summary").replace("\n","").strip()
103
+ authors = []
104
+ while extract.find("<author>")!=-1:
105
+ author = extract_tag(extract,"name")
106
+ extract = extract[extract.find("</author>")+9:]
107
+ authors.append(author)
108
+ pattern = '<link title="pdf" href="'
109
+ link_start = extract.find('<link title="pdf" href="')
110
+ link = extract[link_start+len(pattern):extract.find("rel=",link_start)-2]
111
+ return [id, updated, published, title, authors, link, summary]
112
 
113
+ def crawl_exact_paper(title,author,max_results=3):
114
+ authors = authors_list_to_str(author)
115
+ records = []
116
+ url = 'http://export.arxiv.org/api/query?search_query=ti:{title}+AND+au:{author}&max_results={max_results}'.format(title=title,author=authors,max_results=max_results)
117
+ url = url.replace(" ","%20")
118
+ try:
119
+ arxiv_page = urllib.request.urlopen(url,timeout=100).read()
120
+ xml = str(arxiv_page,encoding="utf-8")
121
+ while xml.find("<entry>") != -1:
122
+ extract = xml[xml.find("<entry>")+7:xml.find("</entry>")]
123
+ xml = xml[xml.find("</entry>")+8:]
124
+ extract = get_record(extract)
125
+ topic = choose_topic(extract[6])
126
+ records.append([topic,*extract])
127
+ return records
128
+ except Exception as e:
129
+ return "Error: "+str(e)
130
 
131
+ def crawl_arxiv(keyword_list, max_results=100):
132
+ baseurl = 'http://export.arxiv.org/api/query?search_query='
133
+ records = []
134
+ for i,keyword in enumerate(keyword_list):
135
+ if i ==0:
136
+ url = baseurl + 'all:' + keyword
137
+ else:
138
+ url = url + '+OR+' + 'all:' + keyword
139
+ url = url+ '&max_results=' + str(max_results)
140
+ url = url.replace(' ', '%20')
141
+ try:
142
+ arxiv_page = urllib.request.urlopen(url,timeout=100).read()
143
+ xml = str(arxiv_page,encoding="utf-8")
144
+ while xml.find("<entry>") != -1:
145
+ extract = xml[xml.find("<entry>")+7:xml.find("</entry>")]
146
+ xml = xml[xml.find("</entry>")+8:]
147
+ extract = get_record(extract)
148
+ topic = choose_topic(extract[6])
149
+ records.append([topic,*extract])
150
+ return records
151
+ except Exception as e:
152
+ return "Error: "+str(e)
153
 
154
+ # This class act as a module
155
+ class ArxivChroma:
156
+ """
157
+ Create an interface to arxivdb, which only support query and addition.
158
+ This interface do not support edition and deletion procedures.
159
+ """
160
+ client = None
161
+ model = None
162
+ collection = None
163
 
164
+ @staticmethod
165
+ def connect(table="arxiv_records", name="arxivdb/"):
166
+ ArxivChroma.client = chromadb.PersistentClient(name)
167
+ ArxivChroma.model = embedding_model
168
+ ArxivChroma.collection = ArxivChroma.client.get_or_create_collection(table,
169
+ embedding_function=JinaAIEmbeddingFunction(
170
+ model = ArxivChroma.model
171
+ ))
 
 
 
 
 
 
 
172
 
173
+ @staticmethod
174
+ def query_relevant(keywords, query_texts, n_results=3):
175
+ """
176
+ Perform a query using a list of keywords (str),
177
+ or using a relavant string
178
+ """
179
+ contains = []
180
+ for keyword in keywords:
181
+ contains.append({"$contains":keyword.lower()})
182
+ return ArxivChroma.collection.query(
183
+ query_texts=query_texts,
184
+ where_document={
185
+ "$or":contains
186
+ },
187
+ n_results=n_results,
188
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
+ @staticmethod
191
+ def query_exact(id):
192
+ ids = ["{}_{}".format(id,j) for j in range(0,10)]
193
+ return ArxivChroma.collection.get(ids=ids)
194
+
195
+ @staticmethod
196
+ def add(crawl_records):
197
+ """
198
+ Add crawl_records (list) obtained from arxiv_crawlers
199
+ A record is a list of 8 columns:
200
+ [topic, id, updated, published, title, author, link, summary]
201
+ Return the final length of the database table
202
+ """
203
+ for record in crawl_records:
204
+ embed_text = """
205
+ Topic: {},
206
+ Title: {},
207
+ Summary: {}
208
+ """.format(record[0],record[4],record[7])
209
+ chunks = chunk_texts(embed_text)
210
+ ids = [record[1][21:]+"_"+str(j) for j in range(len(chunks))]
211
+ paper_ids = [{"paper_id":record[1][21:]} for _ in range(len(chunks))]
212
+ ArxivChroma.collection.add(
213
+ documents = chunks,
214
+ metadatas=paper_ids,
215
+ ids = ids
216
+ )
217
+ return ArxivChroma.collection.count()
218
+
219
+ @staticmethod
220
+ def close_connection():
221
+ pass
222
+
223
+ # This class act as a module
224
+ class ArxivSQL:
225
+ table = "arxivsql"
226
+ con = None
227
+ cur = None
228
+
229
+ @staticmethod
230
+ def connect(name="db.sqlite3"):
231
+ ArxivSQL.con = sqlite3.connect(name, check_same_thread=False)
232
+ ArxivSQL.cur = ArxivSQL.con.cursor()
233
 
234
+ @staticmethod
235
+ def query(title="", author=[], threshold = 15):
236
+ if len(author)>0:
237
+ query_author= " OR ".join([f"authors LIKE '%{a}%'" for a in author])
238
+ else:
239
+ query_author= "True"
240
+ # Execute the query
241
+ query = f"select * from {ArxivSQL.table} where {query_author}"
242
+ results = ArxivSQL.cur.execute(query).fetchall()
243
+ if len(title) == 0:
244
+ return results
245
+ else:
246
+ sim_score = {}
247
+ for row in results:
248
+ row_title = row[2]
249
+ row_id = row[0]
250
+ score = lev_sim(title, row_title)
251
+ if score < threshold:
252
+ sim_score[row_id] = score
253
+ sorted_results = sorted(sim_score.items(), key=lambda x: x[1])
254
+ return ArxivSQL.query_id(sorted_results)
 
255
 
256
+ @staticmethod
257
+ def query_id(ids=[]):
258
+ try:
259
+ if len(ids) == 0:
260
+ return None
261
+ query = "select * from {} where id in (".format(ArxivSQL.table)
262
+ for id in ids:
263
+ query+="'"+id+"',"
264
+ query = query[:-1] + ")"
265
+ result = ArxivSQL.cur.execute(query)
266
+ return result.fetchall()
267
+ except Exception as e:
268
+ print(e)
269
+ print("Error query: ",query)
270
 
271
+ @staticmethod
272
+ def add(crawl_records):
273
+ """
274
+ Add crawl_records (list) obtained from arxiv_crawlers
275
+ A record is a list of 8 columns:
276
+ [topic, id, updated, published, title, author, link, summary]
277
+ Return the final length of the database table
278
+ """
279
+ results = ""
280
+ for record in crawl_records:
281
+ try:
282
+ query = """insert into arxivsql values("{}","{}","{}","{}","{}","{}","{}")""".format(
283
+ record[1][21:],
284
+ record[0],
285
+ record[4].replace('"',"'"),
286
+ authors_list_to_str(record[5]),
287
+ record[2][:10],
288
+ record[3][:10],
289
+ record[6]
290
+ )
291
+ ArxivSQL.cur.execute(query)
292
+ ArxivSQL.con.commit()
293
+ except Exception as e:
294
+ results+=str(e)
295
+ results+="\n" + query + "\n"
296
+ finally:
297
+ return results