Germano Cavalcante commited on
Commit
f0d9ee1
1 Parent(s): c4a3947

Find Related: Save embeddings in cache

Browse files

This optimizes initialization on the virtual machine

routers/tool_find_related.py CHANGED
@@ -1,5 +1,7 @@
1
  # find_related.py
2
 
 
 
3
  import re
4
  import torch
5
  import threading
@@ -11,7 +13,6 @@ try:
11
  from .utils_gitea import gitea_fetch_issues, gitea_json_issue_get
12
  from config import settings
13
  except:
14
- import os
15
  import sys
16
  from utils_gitea import gitea_fetch_issues, gitea_json_issue_get
17
  sys.path.append(os.path.abspath(
@@ -58,6 +59,7 @@ class EmbeddingContext:
58
  TOKEN_LEN_MAX_FOR_EMBEDDING = 512
59
  TOKEN_LEN_MAX_BALCKLIST = 2 * TOKEN_LEN_MAX_FOR_EMBEDDING
60
  issue_attr_filter = {'number', 'title', 'body', 'state', 'updated_at'}
 
61
 
62
  # Set when creating the object
63
  lock = None
@@ -68,8 +70,8 @@ class EmbeddingContext:
68
 
69
  # Updates constantly
70
  data = {}
71
- black_list = {'blender': {'blender': {109399, 113157, 114706},
72
- 'blender-addons': set()}}
73
 
74
  def __init__(self):
75
  self.lock = threading.Lock()
@@ -165,15 +167,19 @@ class EmbeddingContext:
165
 
166
  return texts_to_embed
167
 
168
- def embeddings_generate(self, owner, repo):
169
- if not owner in self.black_list:
170
- self.black_list[owner] = {repo: {}}
171
- elif not repo in self.black_list[owner]:
172
- self.black_list[owner][repo] = {}
 
173
 
174
- black_list = self.black_list[owner][repo]
 
175
 
176
- issues = gitea_fetch_issues(owner, repo, state='open', since=None,
 
 
177
  issue_attr_filter=self.issue_attr_filter, exclude=black_list)
178
 
179
  issues = sorted(issues, key=lambda issue: int(issue['number']))
@@ -190,26 +196,21 @@ class EmbeddingContext:
190
  'embeddings': embeddings,
191
  }
192
 
193
- if not owner in self.data:
194
- self.data[owner] = {repo: {}}
195
- elif not repo in self.data[owner]:
196
- self.data[owner][repo] = {}
197
-
198
- self.data[owner][repo] = data
199
 
200
- def embeddings_updated_get(self, owner, repo):
201
  with self.lock:
202
  try:
203
- data = self.data[owner][repo]
204
  except:
205
- self.embeddings_generate(owner, repo)
206
- data = self.data[owner][repo]
207
 
208
- black_list = self.black_list[owner][repo]
209
  date_old = data['updated_at']
210
 
211
  issues = gitea_fetch_issues(
212
- owner, repo, since=date_old, issue_attr_filter=self.issue_attr_filter, exclude=black_list)
213
 
214
  # Get the most recent date
215
  date_new = _find_latest_date(issues, date_old)
@@ -361,13 +362,13 @@ def text_search(owner, repo, text_to_embed, limit=None):
361
 
362
 
363
  def find_relatedness(gitea_issue, limit=20):
364
- owner = gitea_issue['repository']['owner']
365
  repo = gitea_issue['repository']['name']
366
  title = gitea_issue['title']
367
  body = gitea_issue['body']
368
  number = int(gitea_issue['number'])
369
 
370
- data = EMBEDDING_CTX.embeddings_updated_get(owner, repo)
371
  new_embedding = None
372
 
373
  # Check if the embedding already exist.
@@ -394,38 +395,40 @@ def find_relatedness(gitea_issue, limit=20):
394
  return '\n'.join(duplicates)
395
 
396
 
397
- @router.get("/find_related")
398
- def find_related(owner: str = 'blender', repo: str = 'blender', number: int = 1, limit: int = 50):
399
- issue = gitea_json_issue_get(owner, repo, number)
400
  related = find_relatedness(issue, limit=limit)
401
  return related
402
 
403
 
404
  if __name__ == "__main__":
405
- import os
406
- import pickle
407
- repo = 'blender-addons'
408
- cache_dir = f"routers/cache/{repo}"
409
- file_path = os.path.join(cache_dir, "data.pkl")
410
-
411
- if not os.path.exists(cache_dir):
412
- os.makedirs(cache_dir, exist_ok=True)
413
- with open(file_path, "wb") as file:
414
- EMBEDDING_CTX.embeddings_generate('blender', repo)
415
- pickle.dump(
416
- EMBEDDING_CTX.data['blender'][repo], file, protocol=pickle.HIGHEST_PROTOCOL)
417
  else:
418
- with open(file_path, 'rb') as file:
419
- EMBEDDING_CTX.data['blender'] = {repo: pickle.load(file)}
 
420
 
421
- # 'blender/blender/111434' must print #96153, #83604 and #79762
422
- issue = gitea_json_issue_get('blender', repo, 105027)
423
- print(issue['title'])
424
 
425
- related = find_relatedness(issue, limit=50)
 
426
 
427
- if related == '':
428
- print("No potential duplicates found.")
429
- else:
430
  print("These are the 20 most related issues:")
431
- print(related)
 
 
 
 
1
  # find_related.py
2
 
3
+ import os
4
+ import pickle
5
  import re
6
  import torch
7
  import threading
 
13
  from .utils_gitea import gitea_fetch_issues, gitea_json_issue_get
14
  from config import settings
15
  except:
 
16
  import sys
17
  from utils_gitea import gitea_fetch_issues, gitea_json_issue_get
18
  sys.path.append(os.path.abspath(
 
59
  TOKEN_LEN_MAX_FOR_EMBEDDING = 512
60
  TOKEN_LEN_MAX_BALCKLIST = 2 * TOKEN_LEN_MAX_FOR_EMBEDDING
61
  issue_attr_filter = {'number', 'title', 'body', 'state', 'updated_at'}
62
+ cache_path = "routers/tool_find_related_cache.pkl"
63
 
64
  # Set when creating the object
65
  lock = None
 
70
 
71
  # Updates constantly
72
  data = {}
73
+ black_list = {'blender': {109399, 113157, 114706},
74
+ 'blender-addons': set()}
75
 
76
  def __init__(self):
77
  self.lock = threading.Lock()
 
167
 
168
  return texts_to_embed
169
 
170
+ def embeddings_generate(self, repo):
171
+ if os.path.exists(self.cache_path):
172
+ with open(self.cache_path, 'rb') as file:
173
+ self.data = pickle.load(file)
174
+ if repo in self.data:
175
+ return
176
 
177
+ if not repo in self.black_list:
178
+ self.black_list[repo] = {}
179
 
180
+ black_list = self.black_list[repo]
181
+
182
+ issues = gitea_fetch_issues('blender', repo, state='open', since=None,
183
  issue_attr_filter=self.issue_attr_filter, exclude=black_list)
184
 
185
  issues = sorted(issues, key=lambda issue: int(issue['number']))
 
196
  'embeddings': embeddings,
197
  }
198
 
199
+ self.data[repo] = data
 
 
 
 
 
200
 
201
+ def embeddings_updated_get(self, repo):
202
  with self.lock:
203
  try:
204
+ data = self.data[repo]
205
  except:
206
+ self.embeddings_generate(repo)
207
+ data = self.data[repo]
208
 
209
+ black_list = self.black_list[repo]
210
  date_old = data['updated_at']
211
 
212
  issues = gitea_fetch_issues(
213
+ 'blender', repo, since=date_old, issue_attr_filter=self.issue_attr_filter, exclude=black_list)
214
 
215
  # Get the most recent date
216
  date_new = _find_latest_date(issues, date_old)
 
362
 
363
 
364
  def find_relatedness(gitea_issue, limit=20):
365
+ assert gitea_issue['repository']['owner'] == 'blender'
366
  repo = gitea_issue['repository']['name']
367
  title = gitea_issue['title']
368
  body = gitea_issue['body']
369
  number = int(gitea_issue['number'])
370
 
371
+ data = EMBEDDING_CTX.embeddings_updated_get(repo)
372
  new_embedding = None
373
 
374
  # Check if the embedding already exist.
 
395
  return '\n'.join(duplicates)
396
 
397
 
398
+ @router.get("/find_related/{repo}/{number}")
399
+ def find_related(repo: str = 'blender', number: int = 104399, limit: int = 50):
400
+ issue = gitea_json_issue_get('blender', repo, number)
401
  related = find_relatedness(issue, limit=limit)
402
  return related
403
 
404
 
405
  if __name__ == "__main__":
406
+ update_cache = True
407
+ if update_cache:
408
+ EMBEDDING_CTX.embeddings_updated_get('blender')
409
+ EMBEDDING_CTX.embeddings_updated_get('blender-addons')
410
+ cache_path = EMBEDDING_CTX.cache_path
411
+ with open(cache_path, "wb") as file:
412
+ # Converting the embeddings to be CPU compatible, as the virtual machine in use currently only supports the CPU.
413
+ for val in EMBEDDING_CTX.data.values():
414
+ val['embeddings'] = val['embeddings'].to(torch.device('cpu'))
415
+
416
+ pickle.dump(EMBEDDING_CTX.data, file,
417
+ protocol=pickle.HIGHEST_PROTOCOL)
418
  else:
419
+ # Converting the embeddings to be GPU.
420
+ for val in EMBEDDING_CTX.data.values():
421
+ val['embeddings'] = val['embeddings'].to(torch.device('cuda'))
422
 
423
+ # 'blender/blender/111434' must print #96153, #83604 and #79762
424
+ issue1 = gitea_json_issue_get('blender', 'blender', 111434)
425
+ issue2 = gitea_json_issue_get('blender', 'blender-addons', 104399)
426
 
427
+ related1 = find_relatedness(issue1, limit=20)
428
+ related2 = find_relatedness(issue2, limit=20)
429
 
 
 
 
430
  print("These are the 20 most related issues:")
431
+ print(related1)
432
+ print()
433
+ print("These are the 20 most related issues:")
434
+ print(related2)
routers/tool_find_related_cache.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4278a800fae11df440b487415a95baa437f1b0651b7b007e44aa951d795049a
3
+ size 21250520