Spaces:
Running
Running
File size: 11,012 Bytes
ed15883 98caf15 f0d9ee1 a78f82b ed15883 f92bafd ed15883 98caf15 f92bafd ed15883 c7f8eb7 9a6a74b 98caf15 ed15883 5def575 98caf15 ed15883 5def575 98caf15 92238be ed15883 9a6a74b ed15883 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b f92bafd 9a6a74b f92bafd 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b f0d9ee1 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b f92bafd 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b 5def575 9a6a74b f92bafd 9a6a74b f92bafd 9a6a74b f92bafd 9a6a74b 6923641 9a6a74b f92bafd 9a6a74b f92bafd 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b f92bafd 9a6a74b 98caf15 9a6a74b 98caf15 9a6a74b ed15883 9a6a74b 98caf15 f0d9ee1 9a6a74b f0d9ee1 ed15883 f0d9ee1 9a6a74b ed15883 9a6a74b ed15883 9a6a74b ed15883 |
|
# routers/find_related.py
import os
import pickle
import torch
import re
from typing import List
from datetime import datetime, timedelta
from enum import Enum
from sentence_transformers import util
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
try:
from .embedding import EMBEDDING_CTX
from .utils_gitea import gitea_fetch_issues, gitea_json_issue_get, gitea_issues_body_updated_at_get
except:
from embedding import EMBEDDING_CTX
from utils_gitea import gitea_fetch_issues, gitea_json_issue_get, gitea_issues_body_updated_at_get
# FastAPI router on which this module registers its endpoints.
router = APIRouter()

# Issue fields passed as ``issue_attr_filter`` to ``gitea_fetch_issues``.
issue_attr_filter = {
    'number',
    'title',
    'body',
    'state',
    'updated_at',
    'created_at',
}
class State(str, Enum):
    """Which issues a related-issue query may return.

    ``opened`` and ``closed`` select the per-repo mask of the same name;
    ``all`` selects the union of both masks.
    """
    opened = 'opened'
    closed = 'closed'
    all = 'all'
class _Data(dict):
    """Issue embeddings per Gitea repository, keyed by repository name.

    Each value is a dict with these keys:

    - ``'updated_at'``: newest ``updated_at`` value already processed, or
      ``None``.
    - ``'arrays_size'``: allocated length of the per-issue arrays below.
    - ``'titles'``: list indexed by issue number; ``None`` where unknown.
    - ``'embeddings'``: tensor whose first dimension is the issue number.
    - ``'opened'`` / ``'closed'``: boolean tensors indexed by issue number.
    """

    cache_path = "routers/embedding/embeddings_issues.pkl"

    # Issue-template boilerplate stripped from bodies before embedding.
    # Removed in this order, one ``str.replace`` per entry.
    _BODY_BOILERPLATE = (
        '**System Information**\n',
        '**Blender Version**\n',
        'Worked: (newest version of Blender that worked as expected)\n',
        '**Short description of error**\n',
        '**Addon Information**\n',
        '**Exact steps for others to reproduce the error**\n',
        '[Please describe the exact steps needed to reproduce the issue]\n',
        '[Please fill out a short description of the error here]\n',
        '[Based on the default startup or an attached .blend file (as simple as possible)]\n',
    )
    # Build details (branch, commit date, hash) following a version string.
    _RE_BUILD_INFO = re.compile(
        r', branch: .+?, commit date: \d{4}-\d{2}-\d{2} \d{2}:\d{2}, hash: `.+?`')
    # Attachment paths, collapsed to the word "attachment".
    _RE_ATTACHMENT = re.compile(r'\/?attachments\/[a-zA-Z0-9\-]+')
    # URLs, reduced to their last path segment (capture group 1).
    _RE_URL = re.compile(r'https?:\/\/[^\s/]+(?:\/[^\s/]+)*\/([^\s/]+)')

    @staticmethod
    def _create_issue_string(title, body):
        """Return ``title`` and a cleaned-up ``body`` joined by a newline.

        Carriage returns, issue-template boilerplate and build details are
        removed, attachment paths become ``attachment`` and URLs are reduced
        to their last path segment. A ``None`` body is treated as empty.
        """
        cleaned_body = (body or '').replace('\r', '')
        for boilerplate in _Data._BODY_BOILERPLATE:
            cleaned_body = cleaned_body.replace(boilerplate, '')
        cleaned_body = _Data._RE_BUILD_INFO.sub('', cleaned_body)
        cleaned_body = _Data._RE_ATTACHMENT.sub('attachment', cleaned_body)
        cleaned_body = _Data._RE_URL.sub(r'\1', cleaned_body)
        return title + '\n' + cleaned_body

    @staticmethod
    def _find_latest_date(issues, default_str=None):
        """Return the greatest ``issue['updated_at']``, or ``default_str`` if none."""
        if not issues:
            return default_str
        return max((issue['updated_at'] for issue in issues), default=default_str)

    @staticmethod
    def _max_issue_number(issues):
        """Return the highest issue number in ``issues`` (0 when empty).

        Does not rely on the order in which Gitea returns issues.
        """
        return max((int(issue['number']) for issue in issues), default=0)

    @classmethod
    def _create_strings_to_embbed(cls, issues):
        """Return the text to embed for each issue, in the same order."""
        return [cls._create_issue_string(issue['title'], issue['body'])
                for issue in issues]

    def _data_ensure_size(self, repo, size_new):
        """Grow the arrays of ``repo`` so that index ``size_new`` is valid.

        Arrays grow in multiples of ``ARRAY_CHUNK_SIZE``. Titles, embeddings,
        state masks and ``updated_at`` of an existing entry are preserved.
        When growing, ``self[repo]`` is replaced by a new dict, so callers
        must re-read ``self[repo]`` afterwards instead of reusing an older
        reference.
        """
        ARRAY_CHUNK_SIZE = 4096

        data_old = self.get(repo)
        if data_old is None:
            arrays_size_old = 0
            updated_at_old = None
            titles_old = []
        else:
            arrays_size_old = data_old['arrays_size']
            # Index ``size_new`` must exist, so equality still requires growth.
            if size_new < arrays_size_old:
                return
            updated_at_old = data_old['updated_at']
            titles_old = data_old['titles']

        arrays_size_new = ARRAY_CHUNK_SIZE * (size_new // ARRAY_CHUNK_SIZE + 1)
        data_new = {
            'updated_at': updated_at_old,
            'arrays_size': arrays_size_new,
            # Pad from the current list length so the result always has
            # exactly ``arrays_size_new`` slots.
            'titles': titles_old + [None] * (arrays_size_new - len(titles_old)),
            'embeddings': torch.empty((arrays_size_new, *EMBEDDING_CTX.embedding_shape),
                                      dtype=EMBEDDING_CTX.embedding_dtype,
                                      device=EMBEDDING_CTX.embedding_device),
            'opened': torch.zeros(arrays_size_new, dtype=torch.bool),
            'closed': torch.zeros(arrays_size_new, dtype=torch.bool),
        }
        if data_old is not None:
            data_new['embeddings'][:arrays_size_old] = data_old['embeddings']
            data_new['opened'][:arrays_size_old] = data_old['opened']
            data_new['closed'][:arrays_size_old] = data_old['closed']
        self[repo] = data_new

    def _embeddings_generate(self, repo):
        """Fill ``self`` from the pickle cache, or embed every issue of ``repo``.

        If ``cache_path`` exists its contents are merged into ``self`` and
        nothing else happens when ``repo`` is among them. Otherwise all issues
        of ``blender/<repo>`` are fetched and embedded.
        """
        if os.path.exists(self.cache_path):
            # NOTE: pickle executes code on load; this is only acceptable
            # because the cache is written by this module's __main__ script.
            with open(self.cache_path, 'rb') as file:
                data = pickle.load(file)
            self.update(data)
            if repo in self:
                return

        issues = gitea_fetch_issues('blender', repo, state='all', since=None,
                                    issue_attr_filter=issue_attr_filter)

        print("Embedding Issues...")
        self._data_ensure_size(repo, self._max_issue_number(issues))
        data_repo = self[repo]
        data_repo['updated_at'] = self._find_latest_date(issues)
        if not issues:
            return

        texts_to_embed = self._create_strings_to_embbed(issues)
        embeddings = EMBEDDING_CTX.encode(texts_to_embed)
        for i, issue in enumerate(issues):
            number = int(issue['number'])
            data_repo['titles'][number] = issue['title']
            data_repo['embeddings'][number] = embeddings[i]
            data_repo['opened'][number] = issue['state'] == 'open'
            data_repo['closed'][number] = issue['state'] == 'closed'

    def _embeddings_updated_get(self, repo):
        """Bring ``repo`` up to date with Gitea and return its data dict.

        Issues updated since the stored ``updated_at`` are fetched. Their
        state flags and titles are refreshed. Issues whose title changed, or
        whose date from ``gitea_issues_body_updated_at_get`` is not older than
        the last sync, are re-embedded. Runs while holding
        ``EMBEDDING_CTX.lock``.
        """
        with EMBEDDING_CTX.lock:
            if repo not in self:
                self._embeddings_generate(repo)
            date_old = self[repo]['updated_at']

            issues = gitea_fetch_issues(
                'blender', repo, since=date_old, issue_attr_filter=issue_attr_filter)

            date_new = self._find_latest_date(issues, date_old)
            if date_new == date_old:
                # Nothing changed
                return self[repo]

            # An issue whose update time equals ``date_old`` was already
            # handled by the previous sync.
            issues = [issue for issue in issues if issue['updated_at'] != date_old]

            # This may replace ``self[repo]``, so only read it afterwards.
            self._data_ensure_size(repo, self._max_issue_number(issues))
            data_repo = self[repo]

            updated_at = gitea_issues_body_updated_at_get(issues)
            issues_to_embed = []
            for i, issue in enumerate(issues):
                number = int(issue['number'])
                # Assign both flags so an issue that changed state loses the
                # old one (e.g. a closed issue must no longer count as open).
                data_repo['opened'][number] = issue['state'] == 'open'
                data_repo['closed'][number] = issue['state'] == 'closed'

                if data_repo['titles'][number] != issue['title']:
                    data_repo['titles'][number] = issue['title']
                    issues_to_embed.append(issue)
                elif date_old is None or updated_at[i] >= date_old:
                    issues_to_embed.append(issue)

            if issues_to_embed:
                count = len(issues_to_embed)
                print(f"Embedding {count} issue{'s' if count > 1 else ''}")
                texts_to_embed = self._create_strings_to_embbed(issues_to_embed)
                embeddings = EMBEDDING_CTX.encode(texts_to_embed)
                for i, issue in enumerate(issues_to_embed):
                    data_repo['embeddings'][int(issue['number'])] = embeddings[i]

            # Advance the sync marker only after everything above succeeded,
            # so a failed encode is retried on the next call.
            data_repo['updated_at'] = date_new
            return data_repo

    def _sort_similarity(self,
                         repo: str,
                         query_emb: List[torch.Tensor],
                         limit: int,
                         state: State = State.opened,
                         exclude=None) -> list:
        """Return up to ``limit`` formatted lines for issues most similar to ``query_emb``.

        Candidates are restricted by ``state``. Each line is
        ``#<number>: <title>``, with the number wrapped in ``~~`` when the
        issue is not marked open. Issue number ``exclude`` (if given) is
        skipped without reducing the number of results.
        """
        data = self[repo]
        mask_opened = data['opened']
        if state == State.all:
            mask = mask_opened | data['closed']
        else:
            mask = data[state.value]

        # Maps positions in the masked corpus back to issue numbers.
        true_indices = mask.nonzero(as_tuple=True)[0]
        if limit <= 0 or true_indices.numel() == 0:
            return []
        embeddings = data['embeddings'][mask]

        # Ask for one extra hit so dropping ``exclude`` still leaves ``limit``.
        top_k = limit + 1 if exclude is not None else limit
        ret = util.semantic_search(
            query_emb, embeddings, top_k=top_k, score_function=util.dot_score)

        duplicates = []
        for score in ret[0]:
            number = true_indices[score['corpus_id']].item()
            if number == exclude:
                continue
            closed_char = "" if mask_opened[number] else "~~"
            duplicates.append(
                f"{closed_char}#{number}{closed_char}: {data['titles'][number]}")
        return duplicates[:limit]

    def find_relatedness(self, repo: str, number: int, limit: int = 20, state: State = State.opened):
        """Return issues related to ``blender/<repo>`` issue ``number``.

        The stored embedding is used when the issue is known; otherwise the
        issue is fetched from Gitea and embedded on the fly. The result is up
        to ``limit`` lines joined by newlines (empty string when none),
        never including issue ``number`` itself.
        """
        data = self._embeddings_updated_get(repo)
        titles = data['titles']
        # Bounds-check so negative numbers do not wrap around and numbers
        # beyond the arrays fall back to fetching the issue.
        if 0 <= number < len(titles) and titles[number] is not None:
            new_embedding = data['embeddings'][number]
        else:
            gitea_issue = gitea_json_issue_get('blender', repo, number)
            text_to_embed = self._create_issue_string(
                gitea_issue['title'], gitea_issue['body'])
            new_embedding = EMBEDDING_CTX.encode([text_to_embed])

        duplicates = self._sort_similarity(
            repo, new_embedding, limit=limit, state=state, exclude=number)
        return '\n'.join(duplicates)
# Module-level embedding store used by every request to the route below.
G_data = _Data()


# GET /find_related/{repo}/{number}: plain-text list of related issues,
# one per line, as produced by _Data.find_relatedness.
@router.get("/find_related/{repo}/{number}", response_class=PlainTextResponse)
def find_related(repo: str = 'blender', number: int = 104399, limit: int = 15, state: State = State.opened) -> str:
    return G_data.find_relatedness(repo, number, limit=limit, state=state)
if __name__ == "__main__":
    update_cache = True
    if update_cache:
        G_data._embeddings_updated_get('blender')
        G_data._embeddings_updated_get('blender-addons')

        # The virtual machine in use currently only supports the CPU, so the
        # cache stores CPU copies of the embeddings. Copies are built per repo
        # so the live tensors in G_data are not moved.
        data_cpu = {
            repo: {**val, 'embeddings': val['embeddings'].to(torch.device('cpu'))}
            for repo, val in G_data.items()
        }
        cache_dir = os.path.dirname(G_data.cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        # Dump to a temporary file and swap it in, so a failed dump cannot
        # leave a truncated cache behind.
        tmp_path = G_data.cache_path + '.tmp'
        with open(tmp_path, "wb") as file:
            pickle.dump(data_cpu, file, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_path, G_data.cache_path)

        # Embeddings loaded from the CPU cache are moved to the device the
        # embedding context uses, rather than assuming CUDA is available.
        for val in G_data.values():
            val['embeddings'] = val['embeddings'].to(EMBEDDING_CTX.embedding_device)

    # 'blender/blender/111434' must print #96153, #83604 and #79762
    related1 = G_data.find_relatedness(
        'blender', 111434, limit=20, state=State.all)
    related2 = G_data.find_relatedness('blender-addons', 104399, limit=20)

    print("These are the 20 most related issues:")
    print(related1)
    print()
    print("These are the 20 most related issues:")
    print(related2)
|