Germano Cavalcante committed on
Commit
6923641
1 Parent(s): f92bafd

Fix: Find Related: State all getting garbage

Browse files
routers/tool_find_related.py CHANGED
@@ -319,21 +319,24 @@ def _sort_similarity(data: dict,
319
  state: State = State.opened) -> list:
320
  duplicates = []
321
  embeddings = data['embeddings']
322
- true_indices = None
323
 
324
- if state != State.all:
 
 
325
  mask = data[state.value]
326
- embeddings = embeddings[mask]
327
- true_indices = mask.nonzero(as_tuple=True)[0]
 
328
 
329
  ret = util.semantic_search(
330
  query_emb, embeddings, top_k=limit, score_function=util.dot_score)
331
 
332
  for score in ret[0]:
333
  corpus_id = score['corpus_id']
334
- number = true_indices[corpus_id].item(
335
- ) if true_indices is not None else corpus_id
336
- text = f"#{number}: {data['titles'][number]}"
337
  duplicates.append(text)
338
 
339
  return duplicates
@@ -390,7 +393,8 @@ if __name__ == "__main__":
390
  val['embeddings'] = val['embeddings'].to(torch.device('cuda'))
391
 
392
  # 'blender/blender/111434' must print #96153, #83604 and #79762
393
- related1 = find_relatedness('blender', 111434, limit=20)
 
394
  related2 = find_relatedness('blender-addons', 104399, limit=20)
395
 
396
  print("These are the 20 most related issues:")
 
319
  state: State = State.opened) -> list:
320
  duplicates = []
321
  embeddings = data['embeddings']
322
+ mask_opened = data["opened"]
323
 
324
+ if state == State.all:
325
+ mask = mask_opened | data["closed"]
326
+ else:
327
  mask = data[state.value]
328
+
329
+ embeddings = embeddings[mask]
330
+ true_indices = mask.nonzero(as_tuple=True)[0]
331
 
332
  ret = util.semantic_search(
333
  query_emb, embeddings, top_k=limit, score_function=util.dot_score)
334
 
335
  for score in ret[0]:
336
  corpus_id = score['corpus_id']
337
+ number = true_indices[corpus_id].item()
338
+ closed_char = "" if mask_opened[number] else "~~"
339
+ text = f"{closed_char}#{number}{closed_char}: {data['titles'][number]}"
340
  duplicates.append(text)
341
 
342
  return duplicates
 
393
  val['embeddings'] = val['embeddings'].to(torch.device('cuda'))
394
 
395
  # 'blender/blender/111434' must print #96153, #83604 and #79762
396
+ related1 = find_relatedness(
397
+ 'blender', 111434, limit=20, state=State.all)
398
  related2 = find_relatedness('blender-addons', 104399, limit=20)
399
 
400
  print("These are the 20 most related issues:")
routers/tool_find_related_cache.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:46c42973a8caaa2f0d4a76ebc6ff16c0b8df927c9b16ba645c3f7155cce84f6a
3
- size 723382452
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:178cabd52e35e69184b0e49a0bdae18478e99d1b5cec6f590840e7d7c65576d8
3
+ size 723396066