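"""Embed GitHub issues or pull requests with a SentenceTransformer model.

Reads a JSON dump of issues, encodes each issue's title and body, and writes the
embeddings to `{issue_type}_embeddings.npy` together with an
`embedding_index_to_{issue_type}.json` file that maps each embedding row back to
its issue id.
"""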
import argparse
import json
import logging
import os

import numpy as np
from sentence_transformers import SentenceTransformer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def load_model(model_id: str):
    # Load the SentenceTransformer checkpoint (a Hub model id or a local path).
    return SentenceTransformer(model_id)


class EmbeddingWriter:
    """Context manager that collects embeddings in a list and writes them to disk on exit."""

    def __init__(self, output_embedding_filename, output_index_filename, update, embedding_to_issue_index) -> None:
        self.output_embedding_filename = output_embedding_filename
        self.output_index_filename = output_index_filename
        self.embeddings = []
        self.embedding_to_issue_index = embedding_to_issue_index
        self.update = update

    def __enter__(self):
        return self.embeddings

    def __exit__(self, exc_type, exc_val, exc_tb):
        if len(self.embeddings) == 0:
            return

        embeddings = np.array(self.embeddings)

        # In update mode, append the new embeddings to the previously saved ones.
        if self.update and os.path.exists(self.output_embedding_filename):
            embeddings = np.concatenate([np.load(self.output_embedding_filename), embeddings])

        logger.info(f"Saving embeddings to {self.output_embedding_filename}")
        np.save(self.output_embedding_filename, embeddings)

        logger.info(f"Saving embedding index to {self.output_index_filename}")
        with open(self.output_index_filename, "w") as f:
            json.dump(self.embedding_to_issue_index, f, indent=4)


def embed_issues(
    input_filename: str,
    model_id: str,
    issue_type: str,
    n_issues: int = -1,
    update: bool = False,
):
    """Embed the issues (or pull requests) in `input_filename` and save the results to disk."""
    model = load_model(model_id)

    output_embedding_filename = f"{issue_type}_embeddings.npy"
    output_index_filename = f"embedding_index_to_{issue_type}.json"

    with open(input_filename, "r") as f:
        issues = json.load(f)

    # In update mode, start from the existing index so new embeddings are appended after it.
    if update and os.path.exists(output_index_filename):
        with open(output_index_filename, "r") as f:
            embedding_to_issue_index = json.load(f)
        embedding_index = len(embedding_to_issue_index)
    else:
        embedding_to_issue_index = {}
        embedding_index = 0

    # Embed everything unless a positive limit was given.
    max_issues = n_issues if n_issues > 0 else len(issues)
    n_embedded = 0

    with EmbeddingWriter(
        output_embedding_filename=output_embedding_filename,
        output_index_filename=output_index_filename,
        update=update,
        embedding_to_issue_index=embedding_to_issue_index,
    ) as embeddings:
        for issue_id, issue in issues.items():
            if n_embedded >= max_issues:
                break

            # In update mode, skip issues that were embedded in a previous run.
            if update and issue_id in embedding_to_issue_index.values():
                logger.info(f"Skipping issue {issue_id} as it is already embedded")
                continue

            if "body" not in issue:
                logger.info(f"Skipping issue {issue_id} as it has no body")
                continue

            if issue_type == "pull_request" and "pull_request" not in issue:
                logger.info(f"Skipping issue {issue_id} as it is not a pull request")
                continue
            elif issue_type == "issue" and "pull_request" in issue:
                logger.info(f"Skipping issue {issue_id} as it is a pull request")
                continue

            title = issue["title"] if issue["title"] is not None else ""
            body = issue["body"] if issue["body"] is not None else ""

            logger.info(f"Embedding issue {issue_id}")
            embedding = model.encode(title + "\n" + body)
            embedding_to_issue_index[embedding_index] = issue_id
            embeddings.append(embedding)
            embedding_index += 1
            n_embedded += 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # The choices must match the `issue_type` values handled in `embed_issues`.
    parser.add_argument("issue_type", nargs="?", choices=["issue", "pull_request"], default="issue")
    parser.add_argument("--input_filename", type=str, default="issues_dict.json")
    parser.add_argument("--model_id", type=str, default="all-mpnet-base-v2")
    parser.add_argument("--n_issues", type=int, default=-1)
    parser.add_argument("--update", action="store_true")
    args = parser.parse_args()

    embed_issues(**vars(args))
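
# Example invocations (assuming this file is saved as embed_issues.py and that
# issues_dict.json was produced beforehand by a separate issue-scraping step):
#   python embed_issues.py issue --model_id all-mpnet-base-v2 --n_issues 100
#   python embed_issues.py pull_request --update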