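"""find_related.py: suggest Gitea issues related to a given issue or text.

Open issues are fetched from the Gitea API for the 'blender' organization,
embedded with Sentence-Transformers or the OpenAI embeddings API, cached to
disk, and ranked with semantic search to find the most similar reports.
"""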
# find_related.py
import os
import pickle
import re
import threading
from datetime import datetime, timedelta

import torch
from fastapi import APIRouter
from sentence_transformers import SentenceTransformer, util

try:
    from .utils_gitea import gitea_fetch_issues, gitea_json_issue_get
except ImportError:
    from utils_gitea import gitea_fetch_issues, gitea_json_issue_get

def _create_issue_string(title, body):
    """Merge an issue's title and body into one text for embedding,
    stripping the issue-template boilerplate and noisy tokens."""
    cleaned_body = body.replace('\r', '')

    # Drop the template headings and placeholder instructions.
    for boilerplate in (
            '**System Information**\n',
            '**Blender Version**\n',
            'Worked: (newest version of Blender that worked as expected)\n',
            '**Short description of error**\n',
            '**Addon Information**\n',
            '**Exact steps for others to reproduce the error**\n',
            '[Please describe the exact steps needed to reproduce the issue]\n',
            '[Please fill out a short description of the error here]\n',
            '[Based on the default startup or an attached .blend file (as simple as possible)]\n',
    ):
        cleaned_body = cleaned_body.replace(boilerplate, '')

    # Strip build metadata, normalize attachment links, and keep only the
    # last path component of URLs.
    cleaned_body = re.sub(
        r', branch: .+?, commit date: \d{4}-\d{2}-\d{2} \d{2}:\d{2}, hash: `.+?`', '', cleaned_body)
    cleaned_body = re.sub(
        r'/?attachments/[a-zA-Z0-9\-]+', 'attachment', cleaned_body)
    cleaned_body = re.sub(
        r'https?://[^\s/]+(?:/[^\s/]+)*/([^\s/]+)', lambda match: match.group(1), cleaned_body)

    return title + '\n' + cleaned_body

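# For illustration (hypothetical body text), the template heading is stripped
# so that:
#   _create_issue_string('Crash', '**System Information**\nOperating system: Linux')
# returns 'Crash\nOperating system: Linux'.
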
def _find_latest_date(issues, default_str=None):
    # `default` already covers the case where 'issues' is empty.
    return max((issue['updated_at'] for issue in issues), default=default_str)

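# For example, assuming Gitea's ISO 8601 'updated_at' strings (which compare
# chronologically as plain strings):
#   _find_latest_date([{'updated_at': '2024-01-02T03:04:05Z'},
#                      {'updated_at': '2024-03-04T05:06:07Z'}])
# returns '2024-03-04T05:06:07Z'.
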
class EmbeddingContext:
    # These don't change
    TOKEN_LEN_MAX_FOR_EMBEDDING = 512
    TOKEN_LEN_MAX_BLACKLIST = 2 * TOKEN_LEN_MAX_FOR_EMBEDDING
    issue_attr_filter = {'number', 'title', 'body', 'state', 'updated_at'}
    cache_path = "routers/tool_find_related_cache.pkl"

    # Set when creating the object
    lock = None
    model = None
    openai_client = None
    model_name = ''
    config_type = ''

    # Updated constantly
    data = {}
    black_list = {'blender': {109399, 113157, 114706},
                  'blender-addons': set()}

    def __init__(self):
        self.lock = threading.Lock()

        try:
            from config import settings
        except ImportError:
            import sys
            sys.path.append(os.path.abspath(
                os.path.join(os.path.dirname(__file__), '..')))
            from config import settings

        config_type = settings.embedding_api
        model_name = settings.embedding_model

        if config_type == 'sbert':
            self.model = SentenceTransformer(model_name, use_auth_token=False)
            self.model.max_seq_length = self.TOKEN_LEN_MAX_FOR_EMBEDDING
            print("Max Sequence Length:", self.model.max_seq_length)

            self.encode = self.encode_sbert
            if torch.cuda.is_available():
                self.model = self.model.to('cuda')

        elif config_type == 'openai':
            from openai import OpenAI
            self.openai_client = OpenAI(
                # base_url=settings.openai_api_base
                api_key=settings.OPENAI_API_KEY,
            )
            self.encode = self.encode_openai

        self.model_name = model_name
        self.config_type = config_type

    def encode(self, texts_to_embed):
        # Placeholder; rebound in __init__ to encode_sbert or encode_openai.
        pass

    def encode_sbert(self, texts_to_embed):
        return self.model.encode(texts_to_embed,
                                 show_progress_bar=True,
                                 convert_to_tensor=True,
                                 normalize_embeddings=True)

    def encode_openai(self, texts_to_embed):
        import math
        import time

        # Batch the requests so each call stays below 500,000 tokens,
        # pausing between calls (presumably to respect API rate limits).
        tokens_count = 0
        for text in texts_to_embed:
            tokens_count += len(self.get_tokens(text))

        chunks_num = math.ceil(tokens_count / 500000)
        chunk_size = math.ceil(len(texts_to_embed) / chunks_num)

        embeddings = []
        for i in range(chunks_num):
            start = i * chunk_size
            end = start + chunk_size
            chunk = texts_to_embed[start:end]

            embeddings_tmp = self.openai_client.embeddings.create(
                model=self.model_name,
                input=chunk,
            ).data

            if embeddings_tmp is None:
                break

            embeddings.extend(embeddings_tmp)

            if i < chunks_num - 1:
                time.sleep(60)  # Wait 1 minute before the next call

        return torch.stack([torch.tensor(embedding.embedding, dtype=torch.float32)
                            for embedding in embeddings])

    def get_tokens(self, text):
        if self.model:
            return self.model.tokenizer.tokenize(text)

        # Fallback: crude tokenization on word boundaries when no SBERT
        # tokenizer is available (e.g. with the OpenAI backend).
        tokens = []
        for token in re.split(r'(\W|\b)', text):
            if token.strip():
                tokens.append(token)
        return tokens

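    # For example, with no SBERT model loaded, the fallback tokenizer yields:
    #   get_tokens('Crash on render!')  ->  ['Crash', 'on', 'render', '!']
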
    def create_strings_to_embed(self, issues, black_list):
        texts_to_embed = [_create_issue_string(
            issue['title'], issue['body']) for issue in issues]

        # Blacklist issues that are too long to embed (for keeping track).
        token_count = 0
        for i, text in enumerate(texts_to_embed):
            tokens = self.get_tokens(text)
            tokens_len = len(tokens)
            token_count += tokens_len
            if tokens_len > self.TOKEN_LEN_MAX_BLACKLIST:
                black_list.add(int(issues[i]['number']))
                if self.config_type == 'openai':
                    # Only use the first TOKEN_LEN_MAX_BLACKLIST tokens.
                    texts_to_embed[i] = ' '.join(
                        tokens[:self.TOKEN_LEN_MAX_BLACKLIST])

        return texts_to_embed

    def embeddings_generate(self, repo):
        if os.path.exists(self.cache_path):
            with open(self.cache_path, 'rb') as file:
                self.data = pickle.load(file)
                if repo in self.data:
                    return

        if repo not in self.black_list:
            # Must be a set (not a dict) so that `black_list.add` works.
            self.black_list[repo] = set()

        black_list = self.black_list[repo]

        issues = gitea_fetch_issues('blender', repo, state='open', since=None,
                                    issue_attr_filter=self.issue_attr_filter, exclude=black_list)
        # Keep the issues sorted by number; the update code scans from the end.
        issues = sorted(issues, key=lambda issue: int(issue['number']))

        print("Embedding Issues...")
        texts_to_embed = self.create_strings_to_embed(issues, black_list)
        embeddings = self.encode(texts_to_embed)

        data = {
            # Get the most recent date
            'updated_at': _find_latest_date(issues),
            'numbers': [int(issue['number']) for issue in issues],
            'titles': [issue['title'] for issue in issues],
            'embeddings': embeddings,
        }

        self.data[repo] = data

    def embeddings_updated_get(self, repo):
        with self.lock:
            try:
                data = self.data[repo]
            except KeyError:
                self.embeddings_generate(repo)
                data = self.data[repo]

            black_list = self.black_list[repo]
            date_old = data['updated_at']

            issues = gitea_fetch_issues(
                'blender', repo, since=date_old, issue_attr_filter=self.issue_attr_filter, exclude=black_list)

            # Get the most recent date
            date_new = _find_latest_date(issues, date_old)

            if date_new == date_old:
                # Nothing changed
                return data

            data['updated_at'] = date_new

            # autopep8: off
            # WORKAROUND:
            # Consider that if the time hasn't changed, it's the same issue.
            issues = [issue for issue in issues if issue['updated_at'] != date_old]

            numbers_old = data['numbers']
            titles_old = data['titles']
            embeddings_old = data['embeddings']

            last_index = len(numbers_old) - 1

            issues = sorted(issues, key=lambda issue: int(issue['number']))
            issues_clos = [issue for issue in issues if issue['state'] == 'closed']
            issues_open = [issue for issue in issues if issue['state'] == 'open']

            numbers_clos = [int(issue['number']) for issue in issues_clos]
            numbers_open = [int(issue['number']) for issue in issues_open]

            # Find the cached indices of issues that have been closed.
            # `numbers_old` is sorted, so scan backwards from the end.
            old_closed = []
            for number_clos in numbers_clos:
                for i_old in range(last_index, -1, -1):
                    number_old = numbers_old[i_old]
                    if number_old < number_clos:
                        break
                    if number_old == number_clos:
                        old_closed.append(i_old)
                        break

            if not old_closed and not issues_open:
                return data

            # `mask_open` marks updated open issues not yet cached;
            # `change_map` pairs cached indices with their updated counterparts.
            mask_open = torch.ones(len(numbers_open), dtype=torch.bool)
            need_sort = False
            change_map = []
            for i_open, number_open in enumerate(numbers_open):
                for i_old in range(last_index, -1, -1):
                    number_old = numbers_old[i_old]
                    if number_old < number_open:
                        need_sort = need_sort or (i_old != last_index)
                        break
                    if number_old == number_open:
                        change_map.append((i_old, i_open))
                        mask_open[i_open] = False
                        break

            if issues_open:
                texts_to_embed = self.create_strings_to_embed(issues_open, black_list)
                embeddings = self.encode(texts_to_embed)

                # Refresh the entries that are already cached.
                for i_old, i_open in change_map:
                    titles_old[i_old] = issues_open[i_open]['title']
                    embeddings_old[i_old] = embeddings[i_open]

            if old_closed:
                # Rebuild the arrays, skipping closed issues and appending the
                # newly tracked open ones.
                total = (len(numbers_old) - len(old_closed)) + (len(numbers_open) - len(change_map))
                numbers_new = [None] * total
                titles_new = [None] * total
                embeddings_new = torch.empty((total, *embeddings_old.shape[1:]), dtype=embeddings_old.dtype, device=embeddings_old.device)

                i_new = 0
                i_old = 0
                for i_closed in old_closed + [len(numbers_old)]:
                    while i_old < i_closed:
                        numbers_new[i_new] = numbers_old[i_old]
                        titles_new[i_new] = titles_old[i_old]
                        embeddings_new[i_new] = embeddings_old[i_old]
                        i_new += 1
                        i_old += 1
                    i_old += 1

                for i_open in range(len(numbers_open)):
                    if not mask_open[i_open]:
                        continue
                    titles_new[i_new] = issues_open[i_open]['title']
                    numbers_new[i_new] = numbers_open[i_open]
                    embeddings_new[i_new] = embeddings[i_open]
                    i_new += 1

                assert i_new == total
            elif mask_open.any():
                # No closed issues: just append the new open issues.
                titles_new = titles_old + [issue['title'] for i, issue in enumerate(issues_open) if mask_open[i]]
                numbers_new = numbers_old + [number for i, number in enumerate(numbers_open) if mask_open[i]]
                embeddings_new = torch.cat([embeddings_old, embeddings[mask_open]])
            else:
                # Only already-cached entries changed.
                return data

            if need_sort:
                sorted_indices = sorted(range(len(numbers_new)), key=lambda k: numbers_new[k])
                titles_new = [titles_new[i] for i in sorted_indices]
                numbers_new = [numbers_new[i] for i in sorted_indices]
                embeddings_new = embeddings_new[sorted_indices]

            data['titles'] = titles_new
            data['numbers'] = numbers_new
            data['embeddings'] = embeddings_new
            # autopep8: on

            return data

router = APIRouter()

EMBEDDING_CTX = EmbeddingContext()
# EMBEDDING_CTX.embeddings_generate('blender')
# EMBEDDING_CTX.embeddings_generate('blender-addons')

def _sort_similarity(data, query_emb, limit):
    duplicates = []
    ret = util.semantic_search(
        query_emb, data['embeddings'], top_k=limit, score_function=util.dot_score)
    for score in ret[0]:
        corpus_id = score['corpus_id']
        text = f"#{data['numbers'][corpus_id]}: {data['titles'][corpus_id]}"
        duplicates.append(text)
    return duplicates

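# Note: encode_sbert() returns L2-normalized embeddings
# (normalize_embeddings=True), so util.dot_score here is equivalent to
# cosine similarity. OpenAI embeddings are documented to be unit-length,
# so the same holds for the 'openai' backend.
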
cached_search = {'text': '', 'repo': '', 'issues': []}


def text_search(owner, repo, text_to_embed, limit=None):
    global cached_search
    global EMBEDDING_CTX

    if not text_to_embed:
        return []

    if text_to_embed == cached_search['text'] and repo == cached_search['repo']:
        return cached_search['issues'][:limit]

    # NOTE: `owner` is currently unused; `embeddings_updated_get` assumes
    # the 'blender' organization.
    data = EMBEDDING_CTX.embeddings_updated_get(repo)

    new_embedding = EMBEDDING_CTX.encode([text_to_embed])
    result = _sort_similarity(data, new_embedding, 500)

    cached_search = {'text': text_to_embed, 'repo': repo, 'issues': result}
    return result[:limit]

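# Usage sketch (hypothetical query text):
#   text_search('blender', 'blender', 'Crash when rendering with Cycles', limit=5)
# returns up to 5 strings formatted as '#<number>: <title>'.
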
def find_relatedness(gitea_issue, limit=20):
    assert gitea_issue['repository']['owner'] == 'blender'
    repo = gitea_issue['repository']['name']
    title = gitea_issue['title']
    body = gitea_issue['body']
    number = int(gitea_issue['number'])

    data = EMBEDDING_CTX.embeddings_updated_get(repo)
    new_embedding = None

    # Check if the embedding already exists in the cache.
    # `data['numbers']` is sorted, so scan backwards from the end.
    for i in range(len(data['numbers']) - 1, -1, -1):
        number_cached = data['numbers'][i]
        if number_cached < number:
            break
        if number_cached == number:
            new_embedding = data['embeddings'][i]
            break

    if new_embedding is None:
        text_to_embed = _create_issue_string(title, body)
        new_embedding = EMBEDDING_CTX.encode([text_to_embed])

    duplicates = _sort_similarity(data, new_embedding, limit=limit)
    if not duplicates:
        return ''

    # If the queried issue itself is the top match, drop it from the results.
    number_cached = int(re.search(r'#(\d+):', duplicates[0]).group(1))
    if number_cached == number:
        return '\n'.join(duplicates[1:])

    return '\n'.join(duplicates)

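# find_relatedness() returns a newline-separated report, one issue per line,
# each formatted as '#<number>: <title>'.
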
def find_related(repo: str = 'blender', number: int = 104399, limit: int = 50):
    issue = gitea_json_issue_get('blender', repo, number)
    related = find_relatedness(issue, limit=limit)
    return related

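# NOTE: `router` is created above, but no route is registered in this file.
# A minimal sketch of how `find_related` could be exposed (the path and
# parameter style are assumptions, not taken from this file):
#
#   @router.get("/find_related/{repo}/{number}")
#   def find_related_endpoint(repo: str, number: int, limit: int = 50):
#       return find_related(repo, number, limit)
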
if __name__ == "__main__":
    update_cache = True
    if update_cache:
        EMBEDDING_CTX.embeddings_updated_get('blender')
        EMBEDDING_CTX.embeddings_updated_get('blender-addons')
        cache_path = EMBEDDING_CTX.cache_path
        with open(cache_path, "wb") as file:
            # Convert the embeddings to CPU tensors, as the virtual machine
            # currently in use only supports the CPU.
            for val in EMBEDDING_CTX.data.values():
                val['embeddings'] = val['embeddings'].to(torch.device('cpu'))
            pickle.dump(EMBEDDING_CTX.data, file,
                        protocol=pickle.HIGHEST_PROTOCOL)
    else:
        # Move the embeddings to the GPU.
        for val in EMBEDDING_CTX.data.values():
            val['embeddings'] = val['embeddings'].to(torch.device('cuda'))

    # 'blender/blender/111434' must print #96153, #83604 and #79762
    issue1 = gitea_json_issue_get('blender', 'blender', 111434)
    issue2 = gitea_json_issue_get('blender', 'blender-addons', 104399)

    related1 = find_relatedness(issue1, limit=20)
    related2 = find_relatedness(issue2, limit=20)

    print("These are the 20 most related issues:")
    print(related1)
    print()
    print("These are the 20 most related issues:")
    print(related2)