Spaces:
Running
Running
Germano Cavalcante
commited on
Commit
•
f0d9ee1
1
Parent(s):
c4a3947
Find Related: Save embeddings in cache
Browse filesThis optimizes initialization on the virtual machine
routers/tool_find_related.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
# find_related.py
|
2 |
|
|
|
|
|
3 |
import re
|
4 |
import torch
|
5 |
import threading
|
@@ -11,7 +13,6 @@ try:
|
|
11 |
from .utils_gitea import gitea_fetch_issues, gitea_json_issue_get
|
12 |
from config import settings
|
13 |
except:
|
14 |
-
import os
|
15 |
import sys
|
16 |
from utils_gitea import gitea_fetch_issues, gitea_json_issue_get
|
17 |
sys.path.append(os.path.abspath(
|
@@ -58,6 +59,7 @@ class EmbeddingContext:
|
|
58 |
TOKEN_LEN_MAX_FOR_EMBEDDING = 512
|
59 |
TOKEN_LEN_MAX_BALCKLIST = 2 * TOKEN_LEN_MAX_FOR_EMBEDDING
|
60 |
issue_attr_filter = {'number', 'title', 'body', 'state', 'updated_at'}
|
|
|
61 |
|
62 |
# Set when creating the object
|
63 |
lock = None
|
@@ -68,8 +70,8 @@ class EmbeddingContext:
|
|
68 |
|
69 |
# Updates constantly
|
70 |
data = {}
|
71 |
-
black_list = {'blender': {
|
72 |
-
|
73 |
|
74 |
def __init__(self):
|
75 |
self.lock = threading.Lock()
|
@@ -165,15 +167,19 @@ class EmbeddingContext:
|
|
165 |
|
166 |
return texts_to_embed
|
167 |
|
168 |
-
def embeddings_generate(self,
|
169 |
-
if
|
170 |
-
self.
|
171 |
-
|
172 |
-
|
|
|
173 |
|
174 |
-
|
|
|
175 |
|
176 |
-
|
|
|
|
|
177 |
issue_attr_filter=self.issue_attr_filter, exclude=black_list)
|
178 |
|
179 |
issues = sorted(issues, key=lambda issue: int(issue['number']))
|
@@ -190,26 +196,21 @@ class EmbeddingContext:
|
|
190 |
'embeddings': embeddings,
|
191 |
}
|
192 |
|
193 |
-
|
194 |
-
self.data[owner] = {repo: {}}
|
195 |
-
elif not repo in self.data[owner]:
|
196 |
-
self.data[owner][repo] = {}
|
197 |
-
|
198 |
-
self.data[owner][repo] = data
|
199 |
|
200 |
-
def embeddings_updated_get(self,
|
201 |
with self.lock:
|
202 |
try:
|
203 |
-
data = self.data[
|
204 |
except:
|
205 |
-
self.embeddings_generate(
|
206 |
-
data = self.data[
|
207 |
|
208 |
-
black_list = self.black_list[
|
209 |
date_old = data['updated_at']
|
210 |
|
211 |
issues = gitea_fetch_issues(
|
212 |
-
|
213 |
|
214 |
# Get the most recent date
|
215 |
date_new = _find_latest_date(issues, date_old)
|
@@ -361,13 +362,13 @@ def text_search(owner, repo, text_to_embed, limit=None):
|
|
361 |
|
362 |
|
363 |
def find_relatedness(gitea_issue, limit=20):
|
364 |
-
|
365 |
repo = gitea_issue['repository']['name']
|
366 |
title = gitea_issue['title']
|
367 |
body = gitea_issue['body']
|
368 |
number = int(gitea_issue['number'])
|
369 |
|
370 |
-
data = EMBEDDING_CTX.embeddings_updated_get(
|
371 |
new_embedding = None
|
372 |
|
373 |
# Check if the embedding already exist.
|
@@ -394,38 +395,40 @@ def find_relatedness(gitea_issue, limit=20):
|
|
394 |
return '\n'.join(duplicates)
|
395 |
|
396 |
|
397 |
-
@router.get("/find_related")
|
398 |
-
def find_related(
|
399 |
-
issue = gitea_json_issue_get(
|
400 |
related = find_relatedness(issue, limit=limit)
|
401 |
return related
|
402 |
|
403 |
|
404 |
if __name__ == "__main__":
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
pickle.dump(
|
416 |
-
|
417 |
else:
|
418 |
-
|
419 |
-
|
|
|
420 |
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
|
425 |
-
|
|
|
426 |
|
427 |
-
if related == '':
|
428 |
-
print("No potential duplicates found.")
|
429 |
-
else:
|
430 |
print("These are the 20 most related issues:")
|
431 |
-
print(
|
|
|
|
|
|
|
|
1 |
# find_related.py
|
2 |
|
3 |
+
import os
|
4 |
+
import pickle
|
5 |
import re
|
6 |
import torch
|
7 |
import threading
|
|
|
13 |
from .utils_gitea import gitea_fetch_issues, gitea_json_issue_get
|
14 |
from config import settings
|
15 |
except:
|
|
|
16 |
import sys
|
17 |
from utils_gitea import gitea_fetch_issues, gitea_json_issue_get
|
18 |
sys.path.append(os.path.abspath(
|
|
|
59 |
TOKEN_LEN_MAX_FOR_EMBEDDING = 512
|
60 |
TOKEN_LEN_MAX_BALCKLIST = 2 * TOKEN_LEN_MAX_FOR_EMBEDDING
|
61 |
issue_attr_filter = {'number', 'title', 'body', 'state', 'updated_at'}
|
62 |
+
cache_path = "routers/tool_find_related_cache.pkl"
|
63 |
|
64 |
# Set when creating the object
|
65 |
lock = None
|
|
|
70 |
|
71 |
# Updates constantly
|
72 |
data = {}
|
73 |
+
black_list = {'blender': {109399, 113157, 114706},
|
74 |
+
'blender-addons': set()}
|
75 |
|
76 |
def __init__(self):
|
77 |
self.lock = threading.Lock()
|
|
|
167 |
|
168 |
return texts_to_embed
|
169 |
|
170 |
+
def embeddings_generate(self, repo):
|
171 |
+
if os.path.exists(self.cache_path):
|
172 |
+
with open(self.cache_path, 'rb') as file:
|
173 |
+
self.data = pickle.load(file)
|
174 |
+
if repo in self.data:
|
175 |
+
return
|
176 |
|
177 |
+
if not repo in self.black_list:
|
178 |
+
self.black_list[repo] = {}
|
179 |
|
180 |
+
black_list = self.black_list[repo]
|
181 |
+
|
182 |
+
issues = gitea_fetch_issues('blender', repo, state='open', since=None,
|
183 |
issue_attr_filter=self.issue_attr_filter, exclude=black_list)
|
184 |
|
185 |
issues = sorted(issues, key=lambda issue: int(issue['number']))
|
|
|
196 |
'embeddings': embeddings,
|
197 |
}
|
198 |
|
199 |
+
self.data[repo] = data
|
|
|
|
|
|
|
|
|
|
|
200 |
|
201 |
+
def embeddings_updated_get(self, repo):
|
202 |
with self.lock:
|
203 |
try:
|
204 |
+
data = self.data[repo]
|
205 |
except:
|
206 |
+
self.embeddings_generate(repo)
|
207 |
+
data = self.data[repo]
|
208 |
|
209 |
+
black_list = self.black_list[repo]
|
210 |
date_old = data['updated_at']
|
211 |
|
212 |
issues = gitea_fetch_issues(
|
213 |
+
'blender', repo, since=date_old, issue_attr_filter=self.issue_attr_filter, exclude=black_list)
|
214 |
|
215 |
# Get the most recent date
|
216 |
date_new = _find_latest_date(issues, date_old)
|
|
|
362 |
|
363 |
|
364 |
def find_relatedness(gitea_issue, limit=20):
|
365 |
+
assert gitea_issue['repository']['owner'] == 'blender'
|
366 |
repo = gitea_issue['repository']['name']
|
367 |
title = gitea_issue['title']
|
368 |
body = gitea_issue['body']
|
369 |
number = int(gitea_issue['number'])
|
370 |
|
371 |
+
data = EMBEDDING_CTX.embeddings_updated_get(repo)
|
372 |
new_embedding = None
|
373 |
|
374 |
# Check if the embedding already exist.
|
|
|
395 |
return '\n'.join(duplicates)
|
396 |
|
397 |
|
398 |
+
@router.get("/find_related/{repo}/{number}")
|
399 |
+
def find_related(repo: str = 'blender', number: int = 104399, limit: int = 50):
|
400 |
+
issue = gitea_json_issue_get('blender', repo, number)
|
401 |
related = find_relatedness(issue, limit=limit)
|
402 |
return related
|
403 |
|
404 |
|
405 |
if __name__ == "__main__":
|
406 |
+
update_cache = True
|
407 |
+
if update_cache:
|
408 |
+
EMBEDDING_CTX.embeddings_updated_get('blender')
|
409 |
+
EMBEDDING_CTX.embeddings_updated_get('blender-addons')
|
410 |
+
cache_path = EMBEDDING_CTX.cache_path
|
411 |
+
with open(cache_path, "wb") as file:
|
412 |
+
# Converting the embeddings to be CPU compatible, as the virtual machine in use currently only supports the CPU.
|
413 |
+
for val in EMBEDDING_CTX.data.values():
|
414 |
+
val['embeddings'] = val['embeddings'].to(torch.device('cpu'))
|
415 |
+
|
416 |
+
pickle.dump(EMBEDDING_CTX.data, file,
|
417 |
+
protocol=pickle.HIGHEST_PROTOCOL)
|
418 |
else:
|
419 |
+
# Converting the embeddings to be GPU.
|
420 |
+
for val in EMBEDDING_CTX.data.values():
|
421 |
+
val['embeddings'] = val['embeddings'].to(torch.device('cuda'))
|
422 |
|
423 |
+
# 'blender/blender/111434' must print #96153, #83604 and #79762
|
424 |
+
issue1 = gitea_json_issue_get('blender', 'blender', 111434)
|
425 |
+
issue2 = gitea_json_issue_get('blender', 'blender-addons', 104399)
|
426 |
|
427 |
+
related1 = find_relatedness(issue1, limit=20)
|
428 |
+
related2 = find_relatedness(issue2, limit=20)
|
429 |
|
|
|
|
|
|
|
430 |
print("These are the 20 most related issues:")
|
431 |
+
print(related1)
|
432 |
+
print()
|
433 |
+
print("These are the 20 most related issues:")
|
434 |
+
print(related2)
|
routers/tool_find_related_cache.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a4278a800fae11df440b487415a95baa437f1b0651b7b007e44aa951d795049a
|
3 |
+
size 21250520
|