import os

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from torch.utils.data import DataLoader, TensorDataset
from transformers import EncoderDecoderModel, PreTrainedModel, RobertaTokenizerFast


class DependencyAnalyzer(nn.Module, PyTorchModelHubMixin):
    """CodeBERT encoder with a two-way classification head for code dependency prediction."""

    def __init__(self, encoder: PreTrainedModel | None = None,
                 match_tokenizer: RobertaTokenizerFast | None = None):
        super().__init__()
        if not encoder:
            # Default backbone: the encoder half of a CodeBERT encoder-decoder.
            encoder: PreTrainedModel = EncoderDecoderModel.from_encoder_decoder_pretrained(
                "microsoft/codebert-base", "microsoft/codebert-base"
            ).encoder
        if match_tokenizer:
            # Grow the embedding matrix to cover tokens added to the tokenizer.
            encoder.resize_token_embeddings(len(match_tokenizer))
            encoder.config.decoder_start_token_id = match_tokenizer.cls_token_id
            encoder.config.pad_token_id = match_tokenizer.pad_token_id
            encoder.config.eos_token_id = match_tokenizer.sep_token_id
            # len(tokenizer) counts added special tokens; tokenizer.vocab_size does not.
            encoder.config.vocab_size = len(match_tokenizer)
        self.encoder = encoder
        self.dense = nn.Linear(encoder.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Classify from the pooled [CLS] representation.
        pooler_output = outputs.pooler_output
        output_2d = self.dense(pooler_output)
        return output_2d


def load_model_and_tokenizer(model_dir, directly_load=True, model_with_structure_dir=None):
    if directly_load:
        tokenizer = RobertaTokenizerFast.from_pretrained(model_dir)
        if model_with_structure_dir:
            model = DependencyAnalyzer.from_pretrained(model_with_structure_dir)
        else:
            model = DependencyAnalyzer(match_tokenizer=tokenizer)
            model.load_state_dict(torch.load(os.path.join(model_dir, 'pytorch_model.bin'), map_location='cpu'))
        return model, tokenizer

    # Otherwise, rebuild the analyzer from a fine-tuned encoder-decoder checkpoint.
    model = EncoderDecoderModel.from_pretrained(model_dir)
    if not isinstance(model, EncoderDecoderModel):
        raise RuntimeError(f"Model read from {model_dir} is not valid")
    model = model.encoder
    if not isinstance(model, PreTrainedModel):
        raise RuntimeError("Encoder of original model is not valid")

    tokenizer: RobertaTokenizerFast = RobertaTokenizerFast.from_pretrained("microsoft/codebert-base")
    if not isinstance(tokenizer, RobertaTokenizerFast):
        raise RuntimeError("Cannot read tokenizer as microsoft/codebert-base")

    # Markers that delimit the two code snippets in a pair.
    special_tokens = ['<from>', '<to>']
    tokenizer.add_tokens(special_tokens, special_tokens=True)

    model = DependencyAnalyzer(model, tokenizer)

    return model, tokenizer


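# Hedged usage sketch (not part of the original module): the two loading paths of
# load_model_and_tokenizer. "trained_model/" is a placeholder directory; it is assumed
# to hold the tokenizer files plus pytorch_model.bin, or a full DependencyAnalyzer
# checkpoint when model_with_structure_dir is given.
#
#   model, tokenizer = load_model_and_tokenizer("trained_model/")
#   model, tokenizer = load_model_and_tokenizer("trained_model/", model_with_structure_dir="trained_model/")

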
class DependencyClassifier:
    def __init__(self, load_dir, load_with_model_structure=False):
        self.model, self.tokenizer = load_model_and_tokenizer(load_dir, model_with_structure_dir=load_dir) \
            if load_with_model_structure \
            else load_model_and_tokenizer(load_dir)
        # Keep the target device in one place; falls back to CPU when no GPU is visible.
        self.device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')
        self.model.to(self.device)
        # Inference only: switch off dropout.
        self.model.eval()

    def construct_pair(self, code_1: str, code_2: str):
        return '<from>' + code_1 + '<to>' + code_2

    def construct_corpus_pair(self, corpus: list[tuple[str, str]]):
        return [self.construct_pair(code_1, code_2) for code_1, code_2 in corpus]

    def gen(self, text: str):
        sigmoid = nn.Sigmoid()
        token_input = self.tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
        token_input = token_input.to(self.device)

        with torch.no_grad():
            outputs = self.model(
                input_ids=token_input['input_ids'],
                attention_mask=token_input['attention_mask']
            )[0]
        outputs = sigmoid(outputs).detach().cpu()
        # Probability of the positive (dependency) class.
        return outputs[1]

    def batch_gen(self, corpus_pair: list[str]):
        sigmoid = nn.Sigmoid()
        token_input = self.tokenizer(corpus_pair, return_tensors='pt', padding=True, truncation=True, max_length=512)
        dataset = TensorDataset(token_input["input_ids"], token_input["attention_mask"])
        dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

        preds = []
        with torch.no_grad():
            for batch in dataloader:
                batch_input, attention_mask = [item.to(self.device) for item in batch]
                outputs = self.model(input_ids=batch_input, attention_mask=attention_mask)
                # Keep only the probability of the positive (dependency) class.
                outputs = sigmoid(outputs)[:, 1]
                preds.append(outputs.detach().cpu())
        preds = torch.cat(preds, dim=0)
        return preds.numpy()
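

# Hedged end-to-end sketch (assumption, not from the original): shows the intended call
# pattern of DependencyClassifier. "path/to/checkpoint" is a placeholder directory and the
# code snippets are toy inputs.
if __name__ == "__main__":
    classifier = DependencyClassifier("path/to/checkpoint")
    # Score a single <from>/<to> pair for a dependency relation.
    pair = classifier.construct_pair("def read(path): ...", "def parse(raw): ...")
    print(float(classifier.gen(pair)))
    # Score several pairs at once; returns a numpy array of probabilities.
    pairs = classifier.construct_corpus_pair([("a = 1", "b = a + 1"), ("x = 2", "print('x')")])
    print(classifier.batch_gen(pairs))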