print('Loading dependencies...')

import os
import re
import time
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
import spacy
import torch
from nltk.tokenize import sent_tokenize, word_tokenize
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BertTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    LlamaForCausalLM,
    LlamaTokenizer,
)

DEVICE = torch.device('cpu')


@dataclass
class LexicalUnits:
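    """A list of lexical units of one type ('sent', 'phrase', or 'token'),
    paired with the self-information score of each unit."""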
    unit_type: str
    text: List[str]
    self_info: List[float] = None

    def __add__(self, other):
        assert self.unit_type == other.unit_type, 'Cannot add two different unit types'
        return LexicalUnits(self.unit_type, self.text + other.text, self.self_info + other.self_info)

    def __radd__(self, other):
        # Support sum() over a list of LexicalUnits, which starts from 0.
        if other == 0:
            return self
        return NotImplemented

    def add_to_head(self, token, self_info):
        return LexicalUnits(self.unit_type, [token] + self.text, [self_info] + self.self_info)

    def add_to_tail(self, token, self_info):
        return LexicalUnits(self.unit_type, self.text + [token], self.self_info + [self_info])


class SelectiveContext:
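    """Reduce a context by removing content with low self-information.

    A causal language model (GPT-2, the legacy OpenAI curie endpoint, or LLaMA 2)
    assigns each token a self-information score, -log p(token | preceding tokens);
    units whose score falls below a percentile threshold are masked out at the
    sentence, phrase, or token level.

    Example (a minimal sketch):
        sc = SelectiveContext(model_type='gpt2', lang='en')
        reduced, masked = sc(long_context, reduce_ratio=0.35, reduce_level='phrase')
    """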

    def __init__(self, model_type='gpt2', lang='en', device='cpu'):

        self.model_type = model_type
        self.lang = lang

        global DEVICE
        DEVICE = device

        self.sent_level_self_info = True

        self._prepare_phrase_tokenizer()
        self.sent_tokenize_pattern = r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s"
        self.phrase_mask_token = ''
        self.sent_mask_token = "<...some content omitted.>"
        # Sentence-level masking options (default values chosen here): optionally keep
        # the first num_lead_words words of a masked sentence.
        self.keep_leading_word = False
        self.num_lead_words = 3

        self._prepare_model()

    def _prepare_phrase_tokenizer(self):
        # spaCy pipeline with noun chunks merged into single tokens, so iterating
        # a Doc yields noun-phrase-level units.
        lang = self.lang
        if lang == "en":
            self.nlp = spacy.load("en_core_web_sm", disable=["ner"])
            self.nlp.add_pipe('merge_noun_chunks')
        elif lang == "zh":
            self.nlp = spacy.load('zh_core_web_sm', disable=["ner"])

    def _prepare_model(self):
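        """Load the scoring model and bind the matching self-information function.

        Supported model_type values: 'gpt2' (gpt2, or uer/gpt2-chinese-cluecorpussmall
        when lang is 'zh'), 'curie' (legacy OpenAI Completion API), and 'llama'
        (meta-llama/Llama-2-7b-chat-hf, which requires a Hugging Face access token).
        """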

        if self.lang == 'zh':
            self.tokenizer = BertTokenizer.from_pretrained('uer/gpt2-chinese-cluecorpussmall')
        elif self.lang == 'en':
            self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        else:
            raise NotImplementedError()

        if self.model_type == 'gpt2':
            if self.lang == 'zh':
                self.model = GPT2LMHeadModel.from_pretrained('uer/gpt2-chinese-cluecorpussmall')
            else:
                self.model = GPT2LMHeadModel.from_pretrained('gpt2')
            self.model.to(DEVICE)
            self.model.eval()

            print('model loaded')

            self.max_token_length = self.model.config.n_positions
            self.get_self_information = self._get_self_info_via_gpt2

        elif self.model_type == 'curie':
            global openai
            import openai
            self.max_token_length = 2048

            self.get_self_information = self._get_self_info_via_curie

        elif self.model_type == 'llama':
            # 'LLaMA TOKEN' is a placeholder; a valid Hugging Face access token is required.
            print('loading LLaMA tokenizer and model...')
            self.tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf', token='LLaMA TOKEN')
            config = AutoConfig.from_pretrained('meta-llama/Llama-2-7b-chat-hf', token='LLaMA TOKEN')
            self.model = LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-7b-chat-hf', config=config, token='LLaMA TOKEN')
            self.model.to(DEVICE)
            self.model.eval()

            print('model loaded')

            self.max_token_length = self.model.config.max_position_embeddings
            self.get_self_information = self._get_self_info_via_llama

    def get_self_information(self, text: str) -> Tuple[List[str], List[float]]:
        # Rebound to a model-specific implementation in _prepare_model.
        raise NotImplementedError

    def _get_self_info_via_gpt2(self, text: str) -> Tuple[List[str], List[float]]:
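        """Score each token of `text` with GPT-2 self-information.

        The text is prefixed with a BOS-like token so that the first real token is
        also conditioned on something; the logits at position i predict token i+1,
        so each token's score is -log p(token_i | tokens_<i).
        """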
        if self.lang == 'en':
            text = f"<|endoftext|>{text}"
        elif self.lang == 'zh':
            text = f"[CLS]{text}"
        with torch.no_grad():
            encoding = self.tokenizer(text, add_special_tokens=False, return_tensors='pt')
            encoding = encoding.to(DEVICE)
            outputs = self.model(**encoding)
            logits = outputs.logits
            probs = torch.softmax(logits, dim=-1)
            self_info = -torch.log(probs)

        input_ids = encoding['input_ids']
        # Shift by one: the logits at position i are the distribution over token i+1.
        input_ids_expanded = input_ids[:, 1:].unsqueeze(-1)

        tokens = [self.tokenizer.decode(token_) for token_ in input_ids.squeeze().tolist()[1:]]
        return tokens, self_info[:, :-1].gather(-1, input_ids_expanded).squeeze(-1).squeeze(0).tolist()

    def _get_self_info_via_curie(self, text: str) -> Tuple[List[str], List[float]]:
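        """Score tokens with the legacy OpenAI Completion API (openai<1.0).

        `echo=True` with `logprobs` returns log-probabilities for the prompt tokens
        themselves, and `max_tokens=0` means nothing new is generated, so the call
        only scores the input.
        """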
        num_retry = 3
        openai.api_key = os.environ["OPENAI_API_KEY"]

        for _ in range(num_retry):
            try:
                r = openai.Completion.create(
                    model="curie",
                    prompt=f"<|endoftext|>{text}",
                    max_tokens=0,
                    temperature=0,
                    echo=True,
                    logprobs=0,
                )
                break
            except Exception as e:
                print(e)
                time.sleep(1)
        else:
            raise RuntimeError('Failed to get a response from the OpenAI API after retries.')

        result = r['choices'][0]
        tokens, logprobs = result["logprobs"]["tokens"][1:], result["logprobs"]["token_logprobs"][1:]

        assert len(tokens) == len(logprobs), f"Expected {len(tokens)} logprobs, got {len(logprobs)}"

        self_info = [-logprob for logprob in logprobs]
        return tokens, self_info

    def _get_self_info_via_llama(self, text: str) -> Tuple[List[str], List[float]]:
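        """Score each token of `text` with LLaMA self-information.

        Mirrors the GPT-2 path: the logits at position i give the distribution over
        token i+1, so each actual token's score is gathered from the shifted logits.
        The leading BOS token added by the tokenizer is dropped.
        """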
        with torch.no_grad():
            encoding = self.tokenizer(text, return_tensors="pt").to(DEVICE)
            outputs = self.model(**encoding)
            logits = outputs.logits
            probs = torch.softmax(logits, dim=-1)
            self_info = -torch.log(probs)

        input_ids = encoding['input_ids']
        # Shift by one and gather the self-information of each actual next token.
        input_ids_expanded = input_ids[:, 1:].unsqueeze(-1)

        # Decode each id to a string, mirroring the GPT-2 path.
        tokens = [self.tokenizer.decode(token_) for token_ in input_ids.squeeze().tolist()[1:]]
        return tokens, self_info[:, :-1].gather(-1, input_ids_expanded).squeeze(-1).squeeze(0).tolist()

    def _lexical_unit(self, sents):
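        """Compute lexical units at three granularities for a list of sentences.

        Returns three LexicalUnits objects: sentence-level (self-information is the
        mean over the sentence's tokens), noun-phrase-level, and token-level.
        """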
        if self.sent_level_self_info:
            sent_self_info = []
            all_noun_phrases = []
            all_noun_phrases_info = []
            all_tokens = []
            all_token_self_info = []

            for sent in sents:
                tokens, self_info = self.get_self_information(sent)
                sent_self_info.append(np.mean(self_info))

                all_tokens.extend(tokens)
                all_token_self_info.extend(self_info)

                noun_phrases, noun_phrases_info = self._calculate_lexical_unit(tokens, self_info)

                # Prepend a space so phrases from consecutive sentences don't run
                # together when the surviving phrases are later joined back together.
                if len(all_noun_phrases) != 0:
                    noun_phrases[0] = f" {noun_phrases[0]}"
                all_noun_phrases.extend(noun_phrases)
                all_noun_phrases_info.extend(noun_phrases_info)

            return [
                LexicalUnits('sent', text=sents, self_info=sent_self_info),
                LexicalUnits('phrase', text=all_noun_phrases, self_info=all_noun_phrases_info),
                LexicalUnits('token', text=all_tokens, self_info=all_token_self_info)
            ]

    def _calculate_lexical_unit(self, tokens, self_info):
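        """Regroup token-level self-information into noun-phrase-level scores.

        The sentence is reconstructed by concatenating the model's tokens, split into
        noun phrases with spaCy, and each phrase's score is the mean of the scores of
        the model tokens that fall inside it (matched by character position). A token
        that straddles a phrase boundary has its score split evenly across the phrases
        it covers.
        """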
        def _unit_info(tokens, self_info, units):
            # Walk through the model tokens by character position and accumulate each
            # token's self-information into the unit (noun phrase) it belongs to.
            current_unit_idx = 0
            current_position = 0
            unit_self_info = [[] for _ in range(len(units))]

            for token, info in zip(tokens, self_info):
                current_position += len(token)
                if current_position == len(units[current_unit_idx]):
                    unit_self_info[current_unit_idx].append(info)
                    current_position = current_position - len(units[current_unit_idx])
                    current_unit_idx += 1
                elif current_position > len(units[current_unit_idx]):
                    # The token straddles a unit boundary: split its information evenly
                    # across all the units it covers.
                    counter_ = 1
                    current_position = current_position - len(units[current_unit_idx])
                    current_unit_idx += 1
                    while current_position >= len(units[current_unit_idx]):
                        counter_ += 1
                        current_position = current_position - len(units[current_unit_idx])
                        current_unit_idx += 1
                        if current_unit_idx >= len(units):
                            break
                    partial_info = info / counter_
                    for _ in range(counter_):
                        unit_self_info[(current_unit_idx - 1) - _].append(partial_info)
                else:
                    if token == " ":
                        continue
                    unit_self_info[current_unit_idx].append(info)

            unit_self_info_ = [np.mean(info) for info in unit_self_info]
            return unit_self_info_

        def _noun_phrases(sent):
            # With merge_noun_chunks in the pipeline, each spaCy token is either a noun
            # chunk or a single word; keep the preceding whitespace so that the phrase
            # lengths add up to the length of the sentence.
            noun_phrases = []
            doc = self.nlp(sent)
            for index, chunk in enumerate(doc):
                if index == 0:
                    noun_phrases.append(chunk.text)
                else:
                    noun_phrases.append(doc[index - 1].whitespace_ + chunk.text)
            return noun_phrases

        if self.sent_level_self_info:
            # The sentence is the concatenation of the model tokens, so character
            # positions in the tokens and in the noun phrases line up.
            sent = ''.join(tokens)
            noun_phrases = _noun_phrases(sent)
            noun_phrases_info = _unit_info(tokens, self_info, noun_phrases)

            return noun_phrases, noun_phrases_info

    def beautify_context(self, context: str) -> str:
        context = re.sub(r"\s+", " ", context)
        return context

    def self_info_mask(self, sents: List[str], self_info: List[float], mask_level):
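        """Mask the units whose self-information falls below a percentile threshold.

        The threshold is the `mask_ratio` percentile of the scores, so roughly that
        fraction of units is filtered out. Returns the masked context and the list of
        units that were removed.
        """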
        sents_after_mask = []
        masked_sents = []

        self.ppl_threshold = np.nanpercentile(self_info, self.mask_ratio * 100)

        for sent, info in zip(sents, self_info):
            if info < self.ppl_threshold:
                masked_sents.append(sent)
                sents_after_mask.append(self.mask_a_sent(sent, mask_level))
            else:
                sents_after_mask.append(sent)
        masked_context = " ".join(sents_after_mask) if mask_level == 'sent' else "".join(sents_after_mask)

        return masked_context, masked_sents

    def mask_a_sent(self, sent, level):
        # Return the replacement string for a unit that is being filtered out.
        if level == 'phrase':
            return self.phrase_mask_token
        elif level == 'sent':
            if self.keep_leading_word:
                leading_few_words = " ".join(word_tokenize(sent)[:self.num_lead_words]) + " "
            else:
                leading_few_words = ""
            return leading_few_words + self.sent_mask_token
        elif level == 'token':
            return ''

    def __call__(self, text: str, reduce_ratio: float = 0.35, reduce_level: str = 'phrase') -> Tuple[str, List[str]]:
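        """Reduce `text` by masking its least informative units.

        reduce_ratio is the fraction of units to remove; reduce_level selects the
        granularity ('sent', 'phrase', or 'token'). Returns the reduced context and
        the list of masked units.
        """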
        context = self.beautify_context(text)

        self.mask_ratio = reduce_ratio

        sents = re.split(self.sent_tokenize_pattern, context)
        sents = [sent.strip() for sent in sents if sent.strip()]

        assert reduce_level in ['sent', 'phrase', 'token'], f"reduce_level should be one of ['sent', 'phrase', 'token'], got {reduce_level}"
        sent_lus, phrase_lus, token_lus = self._lexical_unit(sents)

        lexical_level = {
            'sent': sent_lus,
            'phrase': phrase_lus,
            'token': token_lus
        }

        context, masked_sents = self.self_info_mask(lexical_level[reduce_level].text, lexical_level[reduce_level].self_info, reduce_level)
        return context, masked_sents


def main(
    model_type='gpt2',
    lang='en',
    file_to_process: str = None,
    file_to_save: str = None,
):

    global DEVICE
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {DEVICE}")

    # Pass the chosen device explicitly; SelectiveContext would otherwise reset the
    # global DEVICE to its default of 'cpu'.
    sc = SelectiveContext(model_type=model_type, lang=lang, device=DEVICE)

    if file_to_process is None:
        while True:
            text = input("Please input the text you want to reduce: ")
            if text == 'exit':
                break
            context, masked_sents = sc(text)
            print('***********\nThe resulting context is: \n')
            print(context, '\n\n')

            print('***********\nThe content that has been filtered out is: \n')
            print(masked_sents, '\n\n')
    else:
        with open(file_to_process, 'r') as f:
            text = f.read()
        context, masked_sents = sc(text)

        with open(file_to_save, 'w') as f:
            f.write(context)


if __name__ == "__main__":
    main(model_type='gpt2', lang='zh')