# find_related.py
import os
import pickle
import re
import torch
import threading
from datetime import datetime, timedelta
from sentence_transformers import SentenceTransformer, util
from fastapi import APIRouter
try:
    from .utils_gitea import gitea_fetch_issues, gitea_json_issue_get
except ImportError:
    from utils_gitea import gitea_fetch_issues, gitea_json_issue_get
def _create_issue_string(title, body):
cleaned_body = body.replace('\r', '')
cleaned_body = cleaned_body.replace('**System Information**\n', '')
cleaned_body = cleaned_body.replace('**Blender Version**\n', '')
cleaned_body = cleaned_body.replace(
'Worked: (newest version of Blender that worked as expected)\n', '')
cleaned_body = cleaned_body.replace('**Short description of error**\n', '')
cleaned_body = cleaned_body.replace('**Addon Information**\n', '')
cleaned_body = cleaned_body.replace(
'**Exact steps for others to reproduce the error**\n', '')
cleaned_body = cleaned_body.replace(
'[Please describe the exact steps needed to reproduce the issue]\n', '')
cleaned_body = cleaned_body.replace(
'[Please fill out a short description of the error here]\n', '')
cleaned_body = cleaned_body.replace(
'[Based on the default startup or an attached .blend file (as simple as possible)]\n', '')
cleaned_body = re.sub(
r', branch: .+?, commit date: \d{4}-\d{2}-\d{2} \d{2}:\d{2}, hash: `.+?`', '', cleaned_body)
cleaned_body = re.sub(
r'\/?attachments\/[a-zA-Z0-9\-]+', 'attachment', cleaned_body)
cleaned_body = re.sub(
r'https?:\/\/[^\s/]+(?:\/[^\s/]+)*\/([^\s/]+)', lambda match: match.group(1), cleaned_body)
return title + '\n' + cleaned_body
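# Usage sketch for `_create_issue_string` (hypothetical report following
# Blender's bug-report template); boilerplate headers are stripped and the
# title is prepended:
#
#   >>> _create_issue_string(
#   ...     'Crash on startup',
#   ...     '**System Information**\nOperating system: Linux\n\n'
#   ...     '**Short description of error**\nBlender crashes.\n')
#   'Crash on startup\nOperating system: Linux\n\nBlender crashes.\n'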
def _find_latest_date(issues, default_str=None):
    # ISO-8601 timestamps compare chronologically when compared as plain
    # strings, and `default` already handles the empty case.
    return max((issue['updated_at'] for issue in issues), default=default_str)
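# Example (hypothetical timestamps):
#
#   >>> _find_latest_date([{'updated_at': '2024-01-02T10:00:00Z'},
#   ...                    {'updated_at': '2024-03-01T08:30:00Z'}])
#   '2024-03-01T08:30:00Z'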
class EmbeddingContext:
# These don't change
TOKEN_LEN_MAX_FOR_EMBEDDING = 512
    TOKEN_LEN_MAX_BLACKLIST = 2 * TOKEN_LEN_MAX_FOR_EMBEDDING
issue_attr_filter = {'number', 'title', 'body', 'state', 'updated_at'}
cache_path = "routers/tool_find_related_cache.pkl"
# Set when creating the object
lock = None
model = None
openai_client = None
model_name = ''
config_type = ''
# Updates constantly
data = {}
black_list = {'blender': {109399, 113157, 114706},
'blender-addons': set()}
def __init__(self):
self.lock = threading.Lock()
try:
from config import settings
        except ImportError:
import sys
sys.path.append(os.path.abspath(
os.path.join(os.path.dirname(__file__), '..')))
from config import settings
config_type = settings.embedding_api
model_name = settings.embedding_model
if config_type == 'sbert':
self.model = SentenceTransformer(model_name, use_auth_token=False)
self.model.max_seq_length = self.TOKEN_LEN_MAX_FOR_EMBEDDING
print("Max Sequence Length:", self.model.max_seq_length)
self.encode = self.encode_sbert
if torch.cuda.is_available():
self.model = self.model.to('cuda')
elif config_type == 'openai':
from openai import OpenAI
self.openai_client = OpenAI(
# base_url = settings.openai_api_base
api_key=settings.OPENAI_API_KEY,
)
self.encode = self.encode_openai
self.model_name = model_name
self.config_type = config_type
    def encode(self, texts_to_embed):
        # Placeholder; __init__ rebinds this to encode_sbert or encode_openai.
        raise NotImplementedError
def encode_sbert(self, texts_to_embed):
return self.model.encode(texts_to_embed, show_progress_bar=True, convert_to_tensor=True, normalize_embeddings=True)
def encode_openai(self, texts_to_embed):
import math
import time
tokens_count = 0
for text in texts_to_embed:
tokens_count += len(self.get_tokens(text))
        # The OpenAI embeddings API limits how many tokens can be sent per
        # request, so split the batch into chunks of at most ~500k tokens.
        chunks_num = math.ceil(tokens_count / 500000)
        chunk_size = math.ceil(len(texts_to_embed) / chunks_num)
embeddings = []
for i in range(chunks_num):
start = i * chunk_size
end = start + chunk_size
chunk = texts_to_embed[start:end]
embeddings_tmp = self.openai_client.embeddings.create(
model=self.model_name,
input=chunk,
).data
if embeddings_tmp is None:
break
embeddings.extend(embeddings_tmp)
if i < chunks_num - 1:
time.sleep(60) # Wait 1 minute before the next call
return torch.stack([torch.tensor(embedding.embedding, dtype=torch.float32) for embedding in embeddings])
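    # Chunking sketch (hypothetical numbers): 1,200,000 total tokens across
    # 3,000 texts gives chunks_num = ceil(1200000 / 500000) = 3 and
    # chunk_size = ceil(3000 / 3) = 1000, i.e. three requests of up to
    # 1,000 texts each with a 60-second pause between them.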
    def get_tokens(self, text):
        if self.model:
            return self.model.tokenizer.tokenize(text)
        # Fallback for the OpenAI backend: a rough split on word boundaries
        # and non-word characters (not a real BPE tokenizer, but close
        # enough for counting).
        tokens = []
        for token in re.split(r'(\W|\b)', text):
if token.strip():
tokens.append(token)
return tokens
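    # Fallback tokenizer sketch (hypothetical input): after filtering out
    # whitespace-only fragments, words and punctuation survive as tokens,
    # e.g. "Crash on file open!" -> ['Crash', 'on', 'file', 'open', '!'].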
    def create_strings_to_embed(self, issues, black_list):
        texts_to_embed = [_create_issue_string(
            issue['title'], issue['body']) for issue in issues]
        # Add oversized issues to the blacklist (to keep track of them)
token_count = 0
for i, text in enumerate(texts_to_embed):
tokens = self.get_tokens(text)
tokens_len = len(tokens)
token_count += tokens_len
            if tokens_len > self.TOKEN_LEN_MAX_BLACKLIST:
                # Blacklist oversized issues; for OpenAI, also truncate the
                # text to the first TOKEN_LEN_MAX_BLACKLIST tokens so the
                # request still succeeds.
                black_list.add(int(issues[i]['number']))
                if self.config_type == 'openai':
                    texts_to_embed[i] = ' '.join(
                        tokens[:self.TOKEN_LEN_MAX_BLACKLIST])
return texts_to_embed
def embeddings_generate(self, repo):
if os.path.exists(self.cache_path):
with open(self.cache_path, 'rb') as file:
self.data = pickle.load(file)
if repo in self.data:
return
        if repo not in self.black_list:
            self.black_list[repo] = set()
black_list = self.black_list[repo]
issues = gitea_fetch_issues('blender', repo, state='open', since=None,
issue_attr_filter=self.issue_attr_filter, exclude=black_list)
issues = sorted(issues, key=lambda issue: int(issue['number']))
print("Embedding Issues...")
        texts_to_embed = self.create_strings_to_embed(issues, black_list)
embeddings = self.encode(texts_to_embed)
data = {
# Get the most recent date
'updated_at': _find_latest_date(issues),
'numbers': [int(issue['number']) for issue in issues],
'titles': [issue['title'] for issue in issues],
'embeddings': embeddings,
}
self.data[repo] = data
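    # Incremental refresh: fetch issues updated since the cached
    # 'updated_at', re-embed changed open issues in place, drop closed
    # ones, append new ones, and keep all arrays sorted by issue number.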
def embeddings_updated_get(self, repo):
with self.lock:
            try:
                data = self.data[repo]
            except KeyError:
                self.embeddings_generate(repo)
                data = self.data[repo]
black_list = self.black_list[repo]
date_old = data['updated_at']
issues = gitea_fetch_issues(
'blender', repo, since=date_old, issue_attr_filter=self.issue_attr_filter, exclude=black_list)
# Get the most recent date
date_new = _find_latest_date(issues, date_old)
if date_new == date_old:
# Nothing changed
return data
data['updated_at'] = date_new
            # autopep8: off
            # WORKAROUND:
            # Treat an issue whose 'updated_at' still equals the cached date
            # as already processed.
            issues = [issue for issue in issues if issue['updated_at'] != date_old]
numbers_old = data['numbers']
titles_old = data['titles']
embeddings_old = data['embeddings']
last_index = len(numbers_old) - 1
issues = sorted(issues, key=lambda issue: int(issue['number']))
issues_clos = [issue for issue in issues if issue['state'] == 'closed']
issues_open = [issue for issue in issues if issue['state'] == 'open']
numbers_clos = [int(issue['number']) for issue in issues_clos]
numbers_open = [int(issue['number']) for issue in issues_open]
            # Map closed issues to their indices in the cached arrays.
            old_closed = []
for number_clos in numbers_clos:
for i_old in range(last_index, -1, -1):
number_old = numbers_old[i_old]
if number_old < number_clos:
break
if number_old == number_clos:
old_closed.append(i_old)
break
if not old_closed and not issues_open:
return data
            # mask_open marks open issues that are new to the cache;
            # change_map pairs (cached index, open index) for updated issues.
            mask_open = torch.ones(len(numbers_open), dtype=torch.bool)
            need_sort = False
            change_map = []
for i_open, number_open in enumerate(numbers_open):
for i_old in range(last_index, -1, -1):
number_old = numbers_old[i_old]
if number_old < number_open:
need_sort = need_sort or (i_old != last_index)
break
if number_old == number_open:
change_map.append((i_old, i_open))
mask_open[i_open] = False
break
if issues_open:
                texts_to_embed = self.create_strings_to_embed(issues_open, black_list)
embeddings = self.encode(texts_to_embed)
for i_old, i_open in change_map:
titles_old[i_old] = issues_open[i_open]['title']
embeddings_old[i_old] = embeddings[i_open]
if old_closed:
total = (len(numbers_old) - len(old_closed)) + (len(numbers_open) - len(change_map))
numbers_new = [None] * total
titles_new = [None] * total
embeddings_new = torch.empty((total, *embeddings_old.shape[1:]), dtype=embeddings_old.dtype, device=embeddings_old.device)
i_new = 0
i_old = 0
for i_closed in old_closed + [len(numbers_old)]:
while i_old < i_closed:
numbers_new[i_new] = numbers_old[i_old]
titles_new[i_new] = titles_old[i_old]
embeddings_new[i_new] = embeddings_old[i_old]
i_new += 1
i_old += 1
i_old += 1
for i_open in range(len(numbers_open)):
if not mask_open[i_open]:
continue
titles_new[i_new] = issues_open[i_open]['title']
numbers_new[i_new] = numbers_open[i_open]
embeddings_new[i_new] = embeddings[i_open]
i_new += 1
assert i_new == total
elif mask_open.any():
titles_new = titles_old + [issue['title'] for i, issue in enumerate(issues_open) if mask_open[i]]
numbers_new = numbers_old + [number for i, number in enumerate(numbers_open) if mask_open[i]]
embeddings_new = torch.cat([embeddings_old, embeddings[mask_open]])
else:
                # Only existing entries changed; they were updated in place above.
return data
if need_sort:
sorted_indices = sorted(range(len(numbers_new)), key=lambda k: numbers_new[k])
titles_new = [titles_new[i] for i in sorted_indices]
numbers_new = [numbers_new[i] for i in sorted_indices]
embeddings_new = embeddings_new[sorted_indices]
data['titles'] = titles_new
data['numbers'] = numbers_new
data['embeddings'] = embeddings_new
# autopep8: on
return data
router = APIRouter()
EMBEDDING_CTX = EmbeddingContext()
# EMBEDDING_CTX.embeddings_generate('blender')
# EMBEDDING_CTX.embeddings_generate('blender-addons')
def _sort_similarity(data, query_emb, limit):
duplicates = []
    # Embeddings are L2-normalized, so dot score equals cosine similarity.
    ret = util.semantic_search(
query_emb, data['embeddings'], top_k=limit, score_function=util.dot_score)
for score in ret[0]:
corpus_id = score['corpus_id']
text = f"#{data['numbers'][corpus_id]}: {data['titles'][corpus_id]}"
duplicates.append(text)
return duplicates
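# Usage sketch for `_sort_similarity` (hypothetical query); each result line
# has the form "#<number>: <title>":
#
#   data = EMBEDDING_CTX.embeddings_updated_get('blender')
#   query_emb = EMBEDDING_CTX.encode(['Crash when opening a .blend file'])
#   print('\n'.join(_sort_similarity(data, query_emb, limit=5)))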
cached_search = {'text': '', 'repo': '', 'issues': []}
def text_search(owner, repo, text_to_embed, limit=None):
global cached_search
global EMBEDDING_CTX
if not text_to_embed:
return []
if text_to_embed == cached_search['text'] and repo == cached_search['repo']:
return cached_search['issues'][:limit]
    # `owner` is currently unused: the Gitea helpers hardcode 'blender'.
    data = EMBEDDING_CTX.embeddings_updated_get(repo)
new_embedding = EMBEDDING_CTX.encode([text_to_embed])
result = _sort_similarity(data, new_embedding, 500)
cached_search = {'text': text_to_embed, 'repo': repo, 'issues': result}
return result[:limit]
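# Example (assumes the 'blender' embeddings cache can be loaded or built):
#
#   for line in text_search('blender', 'blender', 'Crash on startup', limit=10):
#       print(line)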
def find_relatedness(gitea_issue, limit=20):
assert gitea_issue['repository']['owner'] == 'blender'
repo = gitea_issue['repository']['name']
title = gitea_issue['title']
body = gitea_issue['body']
number = int(gitea_issue['number'])
data = EMBEDDING_CTX.embeddings_updated_get(repo)
new_embedding = None
    # Check whether the embedding already exists in the cache.
for i in range(len(data['numbers']) - 1, -1, -1):
number_cached = data['numbers'][i]
if number_cached < number:
break
if number_cached == number:
new_embedding = data['embeddings'][i]
break
if new_embedding is None:
text_to_embed = _create_issue_string(title, body)
new_embedding = EMBEDDING_CTX.encode([text_to_embed])
duplicates = _sort_similarity(data, new_embedding, limit=limit)
if not duplicates:
return ''
number_cached = int(re.search(r'#(\d+):', duplicates[0]).group(1))
if number_cached == number:
return '\n'.join(duplicates[1:])
return '\n'.join(duplicates)
@router.get("/find_related/{repo}/{number}")
def find_related(repo: str = 'blender', number: int = 104399, limit: int = 50):
issue = gitea_json_issue_get('blender', repo, number)
related = find_relatedness(issue, limit=limit)
return related
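# Example request (assuming this router is mounted at the application root):
#
#   GET /find_related/blender/111434?limit=20
#
# The response is a newline-separated list of "#<number>: <title>" entries.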
if __name__ == "__main__":
update_cache = True
if update_cache:
EMBEDDING_CTX.embeddings_updated_get('blender')
EMBEDDING_CTX.embeddings_updated_get('blender-addons')
cache_path = EMBEDDING_CTX.cache_path
with open(cache_path, "wb") as file:
            # Move the embeddings to the CPU, since the current deployment
            # VM is CPU-only.
for val in EMBEDDING_CTX.data.values():
val['embeddings'] = val['embeddings'].to(torch.device('cpu'))
pickle.dump(EMBEDDING_CTX.data, file,
protocol=pickle.HIGHEST_PROTOCOL)
    else:
        # Move the embeddings to the GPU.
        for val in EMBEDDING_CTX.data.values():
            val['embeddings'] = val['embeddings'].to(torch.device('cuda'))
    # 'blender/blender/111434' should match #96153, #83604 and #79762
issue1 = gitea_json_issue_get('blender', 'blender', 111434)
issue2 = gitea_json_issue_get('blender', 'blender-addons', 104399)
related1 = find_relatedness(issue1, limit=20)
related2 = find_relatedness(issue2, limit=20)
print("These are the 20 most related issues:")
print(related1)
print()
print("These are the 20 most related issues:")
print(related2)