Upload 6 files
Browse files- app.py +57 -0
- models/__pycache__/watermark_faster.cpython-39.pyc +0 -0
- models/watermark_faster.py +465 -0
- models/watermark_original.py +368 -0
- options.py +14 -0
- requirements.txt +12 -0
app.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from models.watermark_faster import watermark_model
|
3 |
+
import pdb
|
4 |
+
from options import get_parser_main_model
|
5 |
+
|
6 |
+
opts = get_parser_main_model().parse_args()
|
7 |
+
model = watermark_model(language=opts.language, mode=opts.mode, tau_word=opts.tau_word, lamda=opts.lamda)
|
8 |
+
def watermark_embed_demo(raw):
|
9 |
+
|
10 |
+
watermarked_text = model.embed(raw)
|
11 |
+
return watermarked_text
|
12 |
+
|
13 |
+
def watermark_extract(raw):
|
14 |
+
is_watermark, p_value, n, ones, z_value = model.watermark_detector_fast(raw)
|
15 |
+
confidence = (1 - p_value) * 100
|
16 |
+
|
17 |
+
return f"{confidence:.2f}%"
|
18 |
+
|
19 |
+
def precise_watermark_detect(raw):
|
20 |
+
is_watermark, p_value, n, ones, z_value = model.watermark_detector_precise(raw)
|
21 |
+
confidence = (1 - p_value) * 100
|
22 |
+
|
23 |
+
return f"{confidence:.2f}%"
|
24 |
+
|
25 |
+
|
26 |
+
demo = gr.Blocks()
|
27 |
+
with demo:
|
28 |
+
with gr.Column():
|
29 |
+
gr.Markdown("# Watermarking Text Generated by Black-Box Language Models")
|
30 |
+
|
31 |
+
inputs = gr.TextArea(label="Input text", placeholder="Copy your text here...")
|
32 |
+
output = gr.Textbox(label="Watermarked Text")
|
33 |
+
analysis_button = gr.Button("Inject Watermark")
|
34 |
+
inputs_embed = [inputs]
|
35 |
+
analysis_button.click(fn=watermark_embed_demo, inputs=inputs_embed, outputs=output)
|
36 |
+
|
37 |
+
inputs_w = gr.TextArea(label="Text to Analyze", placeholder="Copy your watermarked text here...")
|
38 |
+
|
39 |
+
mode = gr.Dropdown(
|
40 |
+
label="Detection Mode", choices=["Fast", "Precise"], default="Fast"
|
41 |
+
)
|
42 |
+
output_detect = gr.Textbox(label="Confidence (the likelihood of the text containing a watermark)")
|
43 |
+
detect_button = gr.Button("Detect")
|
44 |
+
|
45 |
+
def detect_watermark(inputs_w, mode):
|
46 |
+
if mode == "Fast":
|
47 |
+
return watermark_extract(inputs_w)
|
48 |
+
else:
|
49 |
+
return precise_watermark_detect(inputs_w)
|
50 |
+
|
51 |
+
detect_button.click(fn=detect_watermark, inputs=[inputs_w, mode], outputs=output_detect)
|
52 |
+
|
53 |
+
|
54 |
+
if __name__ == "__main__":
|
55 |
+
gr.close_all()
|
56 |
+
demo.title = "Watermarking Text Generated by Black-Box Language Models"
|
57 |
+
demo.launch(share = True, server_port=8899)
|
models/__pycache__/watermark_faster.cpython-39.pyc
ADDED
Binary file (15.9 kB). View file
|
|
models/watermark_faster.py
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import nltk
|
2 |
+
from nltk.corpus import stopwords
|
3 |
+
from nltk import word_tokenize, pos_tag
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
import hashlib
|
8 |
+
from scipy.stats import norm
|
9 |
+
import gensim
|
10 |
+
import pdb
|
11 |
+
from transformers import BertForMaskedLM as WoBertForMaskedLM
|
12 |
+
from wobert import WoBertTokenizer
|
13 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
14 |
+
|
15 |
+
from transformers import BertForMaskedLM, BertTokenizer, RobertaForSequenceClassification, RobertaTokenizer
|
16 |
+
import gensim.downloader as api
|
17 |
+
import Levenshtein
|
18 |
+
import string
|
19 |
+
import spacy
|
20 |
+
import paddle
|
21 |
+
from jieba import posseg
|
22 |
+
paddle.enable_static()
|
23 |
+
import re
|
24 |
+
def cut_sent(para):
|
25 |
+
para = re.sub('([。!?\?])([^”’])', r'\1\n\2', para)
|
26 |
+
para = re.sub('([。!?\?][”’])([^,。!?\?\n ])', r'\1\n\2', para)
|
27 |
+
para = re.sub('(\.{6}|\…{2})([^”’\n])', r'\1\n\2', para)
|
28 |
+
para = re.sub('([^。!?\?]*)([::][^。!?\?\n]*)', r'\1\n\2', para)
|
29 |
+
para = re.sub('([。!?\?][”’])$', r'\1\n', para)
|
30 |
+
para = para.rstrip()
|
31 |
+
return para.split("\n")
|
32 |
+
|
33 |
+
def is_subword(token: str):
|
34 |
+
return token.startswith('##')
|
35 |
+
|
36 |
+
def binary_encoding_function(token):
|
37 |
+
hash_value = int(hashlib.sha256(token.encode('utf-8')).hexdigest(), 16)
|
38 |
+
random_bit = hash_value % 2
|
39 |
+
return random_bit
|
40 |
+
|
41 |
+
def is_similar(x, y, threshold=0.5):
|
42 |
+
distance = Levenshtein.distance(x, y)
|
43 |
+
if distance / max(len(x), len(y)) < threshold:
|
44 |
+
return True
|
45 |
+
return False
|
46 |
+
|
47 |
+
class watermark_model:
|
48 |
+
def __init__(self, language, mode, tau_word, lamda):
|
49 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
50 |
+
self.language = language
|
51 |
+
self.mode = mode
|
52 |
+
self.tau_word = tau_word
|
53 |
+
self.tau_sent = 0.8
|
54 |
+
self.lamda = lamda
|
55 |
+
self.cn_tag_black_list = set(['','x','u','j','k','zg','y','eng','uv','uj','ud','nr','nrfg','nrt','nw','nz','ns','nt','m','mq','r','w','PER','LOC','ORG'])#set(['','f','u','nr','nw','nz','m','r','p','c','w','PER','LOC','ORG'])
|
56 |
+
self.en_tag_white_list = set(['MD', 'NN', 'NNS', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'RP', 'RB', 'RBR', 'RBS', 'JJ', 'JJR', 'JJS'])
|
57 |
+
if language == 'Chinese':
|
58 |
+
self.relatedness_tokenizer = AutoTokenizer.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Similarity")
|
59 |
+
self.relatedness_model = AutoModelForSequenceClassification.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Similarity").to(self.device)
|
60 |
+
self.tokenizer = WoBertTokenizer.from_pretrained("junnyu/wobert_chinese_plus_base")
|
61 |
+
self.model = WoBertForMaskedLM.from_pretrained("junnyu/wobert_chinese_plus_base", output_hidden_states=True).to(self.device)
|
62 |
+
self.w2v_model = gensim.models.KeyedVectors.load_word2vec_format('sgns.merge.word.bz2', binary=False, unicode_errors='ignore', limit=50000)
|
63 |
+
elif language == 'English':
|
64 |
+
self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
65 |
+
self.model = BertForMaskedLM.from_pretrained('bert-base-cased', output_hidden_states=True).to(self.device)
|
66 |
+
self.relatedness_model = RobertaForSequenceClassification.from_pretrained('roberta-large-mnli').to(self.device)
|
67 |
+
self.relatedness_tokenizer = RobertaTokenizer.from_pretrained('roberta-large-mnli')
|
68 |
+
self.w2v_model = api.load("glove-wiki-gigaword-100")
|
69 |
+
nltk.download('stopwords')
|
70 |
+
self.stop_words = set(stopwords.words('english'))
|
71 |
+
self.nlp = spacy.load('en_core_web_sm')
|
72 |
+
|
73 |
+
def cut(self,ori_text,text_len):
|
74 |
+
if self.language == 'Chinese':
|
75 |
+
if len(ori_text) > text_len+5:
|
76 |
+
ori_text = ori_text[:text_len+5]
|
77 |
+
if len(ori_text) < text_len-5:
|
78 |
+
return 'Short'
|
79 |
+
return ori_text
|
80 |
+
elif self.language == 'English':
|
81 |
+
tokens = self.tokenizer.tokenize(ori_text)
|
82 |
+
if len(tokens) > text_len+5:
|
83 |
+
ori_text = self.tokenizer.convert_tokens_to_string(tokens[:text_len+5])
|
84 |
+
if len(tokens) < text_len-5:
|
85 |
+
return 'Short'
|
86 |
+
return ori_text
|
87 |
+
else:
|
88 |
+
print(f'Unsupported Language:{self.language}')
|
89 |
+
raise NotImplementedError
|
90 |
+
|
91 |
+
def sent_tokenize(self,ori_text):
|
92 |
+
if self.language == 'Chinese':
|
93 |
+
return cut_sent(ori_text)
|
94 |
+
elif self.language == 'English':
|
95 |
+
return nltk.sent_tokenize(ori_text)
|
96 |
+
|
97 |
+
def pos_filter(self, tokens, masked_token_index, input_text):
|
98 |
+
if self.language == 'Chinese':
|
99 |
+
pairs = posseg.lcut(input_text)
|
100 |
+
pos_dict = {word: pos for word, pos in pairs}
|
101 |
+
pos_list_input = [pos for _, pos in pairs]
|
102 |
+
pos = pos_dict.get(tokens[masked_token_index], '')
|
103 |
+
if pos in self.cn_tag_black_list:
|
104 |
+
return False
|
105 |
+
else:
|
106 |
+
return True
|
107 |
+
elif self.language == 'English':
|
108 |
+
pos_tags = pos_tag(tokens)
|
109 |
+
pos = pos_tags[masked_token_index][1]
|
110 |
+
if pos not in self.en_tag_white_list:
|
111 |
+
return False
|
112 |
+
if is_subword(tokens[masked_token_index]) or is_subword(tokens[masked_token_index+1]) or (tokens[masked_token_index] in self.stop_words or tokens[masked_token_index] in string.punctuation):
|
113 |
+
return False
|
114 |
+
return True
|
115 |
+
|
116 |
+
def filter_special_candidate(self, top_n_tokens, tokens,masked_token_index,input_text):
|
117 |
+
if self.language == 'English':
|
118 |
+
filtered_tokens = [tok for tok in top_n_tokens if tok not in self.stop_words and tok not in string.punctuation and pos_tag([tok])[0][1] in self.en_tag_white_list and not is_subword(tok)]
|
119 |
+
|
120 |
+
base_word = tokens[masked_token_index]
|
121 |
+
|
122 |
+
processed_tokens = [tok for tok in filtered_tokens if not is_similar(tok,base_word)]
|
123 |
+
return processed_tokens
|
124 |
+
elif self.language == 'Chinese':
|
125 |
+
pairs = posseg.lcut(input_text)
|
126 |
+
pos_dict = {word: pos for word, pos in pairs}
|
127 |
+
pos_list_input = [pos for _, pos in pairs]
|
128 |
+
pos = pos_dict.get(tokens[masked_token_index], '')
|
129 |
+
filtered_tokens = []
|
130 |
+
for tok in top_n_tokens:
|
131 |
+
watermarked_text_segtest = self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [tok] + tokens[masked_token_index+1:-1])
|
132 |
+
watermarked_text_segtest = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', watermarked_text_segtest)
|
133 |
+
pairs_tok = posseg.lcut(watermarked_text_segtest)
|
134 |
+
pos_dict_tok = {word: pos for word, pos in pairs_tok}
|
135 |
+
flag = pos_dict_tok.get(tok, '')
|
136 |
+
if flag not in self.cn_tag_black_list and flag == pos:
|
137 |
+
filtered_tokens.append(tok)
|
138 |
+
processed_tokens = filtered_tokens
|
139 |
+
return processed_tokens
|
140 |
+
|
141 |
+
def global_word_sim(self,word,ori_word):
|
142 |
+
try:
|
143 |
+
global_score = self.w2v_model.similarity(word,ori_word)
|
144 |
+
except KeyError:
|
145 |
+
global_score = 0
|
146 |
+
return global_score
|
147 |
+
|
148 |
+
def context_word_sim(self, init_candidates_list, tokens, index_space, input_text):
|
149 |
+
original_input_tensor = self.tokenizer.encode(input_text, return_tensors='pt').to(self.device)
|
150 |
+
|
151 |
+
all_cos_sims = []
|
152 |
+
|
153 |
+
for init_candidates, masked_token_index in zip(init_candidates_list, index_space):
|
154 |
+
batch_input_ids = [
|
155 |
+
[self.tokenizer.convert_tokens_to_ids(['[CLS]'] + tokens[1:masked_token_index] + [token] + tokens[masked_token_index + 1:-1] + ['[SEP]'])] for token in
|
156 |
+
init_candidates]
|
157 |
+
batch_input_tensors = torch.tensor(batch_input_ids).squeeze(1).to(self.device)
|
158 |
+
|
159 |
+
batch_input_tensors = torch.cat((batch_input_tensors, original_input_tensor), dim=0)
|
160 |
+
|
161 |
+
with torch.no_grad():
|
162 |
+
outputs = self.model(batch_input_tensors)
|
163 |
+
cos_sims = torch.zeros([len(init_candidates)]).to(self.device)
|
164 |
+
num_layers = len(outputs[1])
|
165 |
+
N = 8
|
166 |
+
i = masked_token_index
|
167 |
+
# We want to calculate similarity for the last N layers
|
168 |
+
hidden_states = outputs[1][-N:]
|
169 |
+
|
170 |
+
# Shape of hidden_states: [N, batch_size, sequence_length, hidden_size]
|
171 |
+
hidden_states = torch.stack(hidden_states)
|
172 |
+
|
173 |
+
# Separate the source and candidate hidden states
|
174 |
+
source_hidden_states = hidden_states[:, len(init_candidates):, i, :]
|
175 |
+
candidate_hidden_states = hidden_states[:, :len(init_candidates), i, :]
|
176 |
+
|
177 |
+
# Calculate cosine similarities across all layers and sum
|
178 |
+
cos_sim_sum = F.cosine_similarity(source_hidden_states.unsqueeze(2), candidate_hidden_states.unsqueeze(1), dim=-1).sum(dim=0)
|
179 |
+
|
180 |
+
cos_sim_avg = cos_sim_sum / N
|
181 |
+
cos_sims += cos_sim_avg.squeeze()
|
182 |
+
|
183 |
+
all_cos_sims.append(cos_sims.tolist())
|
184 |
+
|
185 |
+
return all_cos_sims
|
186 |
+
|
187 |
+
|
188 |
+
def sentence_sim(self, init_candidates_list, tokens, index_space, input_text):
|
189 |
+
|
190 |
+
batch_size=128
|
191 |
+
all_batch_sentences = []
|
192 |
+
all_index_lengths = []
|
193 |
+
for init_candidates, masked_token_index in zip(init_candidates_list, index_space):
|
194 |
+
if self.language == 'Chinese':
|
195 |
+
batch_sents = [self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [token] + tokens[masked_token_index + 1:-1]) for token in init_candidates]
|
196 |
+
batch_sentences = [re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', sent) for sent in batch_sents]
|
197 |
+
all_batch_sentences.extend([input_text + '[SEP]' + s for s in batch_sentences])
|
198 |
+
elif self.language == 'English':
|
199 |
+
batch_sentences = [self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [token] + tokens[masked_token_index + 1:-1]) for token in init_candidates]
|
200 |
+
all_batch_sentences.extend([input_text + '</s></s>' + s for s in batch_sentences])
|
201 |
+
|
202 |
+
all_index_lengths.append(len(init_candidates))
|
203 |
+
|
204 |
+
all_relatedness_scores = []
|
205 |
+
start_index = 0
|
206 |
+
for i in range(0, len(all_batch_sentences), batch_size):
|
207 |
+
batch_sentences = all_batch_sentences[i: i + batch_size]
|
208 |
+
encoded_dict = self.relatedness_tokenizer.batch_encode_plus(
|
209 |
+
batch_sentences,
|
210 |
+
padding=True,
|
211 |
+
truncation=True,
|
212 |
+
max_length=512,
|
213 |
+
return_tensors='pt')
|
214 |
+
|
215 |
+
input_ids = encoded_dict['input_ids'].to(self.device)
|
216 |
+
attention_masks = encoded_dict['attention_mask'].to(self.device)
|
217 |
+
|
218 |
+
with torch.no_grad():
|
219 |
+
outputs = self.relatedness_model(input_ids=input_ids, attention_mask=attention_masks)
|
220 |
+
logits = outputs[0]
|
221 |
+
probs = torch.softmax(logits, dim=1)
|
222 |
+
if self.language == 'Chinese':
|
223 |
+
relatedness_scores = probs[:, 1]#.tolist()
|
224 |
+
elif self.language == 'English':
|
225 |
+
relatedness_scores = probs[:, 2]#.tolist()
|
226 |
+
all_relatedness_scores.extend(relatedness_scores)
|
227 |
+
|
228 |
+
all_relatedness_scores_split = []
|
229 |
+
for length in all_index_lengths:
|
230 |
+
all_relatedness_scores_split.append(all_relatedness_scores[start_index:start_index + length])
|
231 |
+
start_index += length
|
232 |
+
|
233 |
+
|
234 |
+
return all_relatedness_scores_split
|
235 |
+
|
236 |
+
|
237 |
+
def candidates_gen(self, tokens, index_space, input_text, topk=64, dropout_prob=0.3):
|
238 |
+
input_ids_bert = self.tokenizer.convert_tokens_to_ids(tokens)
|
239 |
+
new_index_space = []
|
240 |
+
masked_text = self.tokenizer.convert_tokens_to_string(tokens)
|
241 |
+
# Create a tensor of input IDs
|
242 |
+
input_tensor = torch.tensor([input_ids_bert]).to(self.device)
|
243 |
+
|
244 |
+
with torch.no_grad():
|
245 |
+
embeddings = self.model.bert.embeddings(input_tensor.repeat(len(index_space), 1))
|
246 |
+
|
247 |
+
dropout = nn.Dropout2d(p=dropout_prob)
|
248 |
+
|
249 |
+
masked_indices = torch.tensor(index_space).to(self.device)
|
250 |
+
embeddings[torch.arange(len(index_space)), masked_indices] = dropout(embeddings[torch.arange(len(index_space)), masked_indices])
|
251 |
+
|
252 |
+
|
253 |
+
with torch.no_grad():
|
254 |
+
outputs = self.model(inputs_embeds=embeddings)
|
255 |
+
|
256 |
+
all_processed_tokens = []
|
257 |
+
for i, masked_token_index in enumerate(index_space):
|
258 |
+
predicted_logits = outputs[0][i][masked_token_index]
|
259 |
+
# Set the number of top predictions to return
|
260 |
+
n = topk
|
261 |
+
# Get the top n predicted tokens and their probabilities
|
262 |
+
probs = torch.nn.functional.softmax(predicted_logits, dim=-1)
|
263 |
+
top_n_probs, top_n_indices = torch.topk(probs, n)
|
264 |
+
top_n_tokens = self.tokenizer.convert_ids_to_tokens(top_n_indices.tolist())
|
265 |
+
processed_tokens = self.filter_special_candidate(top_n_tokens, tokens, masked_token_index,input_text)
|
266 |
+
|
267 |
+
if tokens[masked_token_index] not in processed_tokens:
|
268 |
+
processed_tokens = [tokens[masked_token_index]] + processed_tokens
|
269 |
+
all_processed_tokens.append(processed_tokens)
|
270 |
+
new_index_space.append(masked_token_index)
|
271 |
+
|
272 |
+
return all_processed_tokens,new_index_space
|
273 |
+
|
274 |
+
|
275 |
+
def filter_candidates(self, init_candidates_list, tokens, index_space, input_text):
|
276 |
+
|
277 |
+
all_context_word_similarity_scores = self.context_word_sim(init_candidates_list, tokens, index_space, input_text)
|
278 |
+
|
279 |
+
all_sentence_similarity_scores = self.sentence_sim(init_candidates_list, tokens, index_space, input_text)
|
280 |
+
|
281 |
+
all_filtered_candidates = []
|
282 |
+
new_index_space = []
|
283 |
+
|
284 |
+
for init_candidates, context_word_similarity_scores, sentence_similarity_scores, masked_token_index in zip(init_candidates_list, all_context_word_similarity_scores, all_sentence_similarity_scores, index_space):
|
285 |
+
filtered_candidates = []
|
286 |
+
for idx, candidate in enumerate(init_candidates):
|
287 |
+
global_word_similarity_score = self.global_word_sim(tokens[masked_token_index], candidate)
|
288 |
+
word_similarity_score = self.lamda*context_word_similarity_scores[idx]+(1-self.lamda)*global_word_similarity_score
|
289 |
+
if word_similarity_score >= self.tau_word and sentence_similarity_scores[idx] >= self.tau_sent:
|
290 |
+
filtered_candidates.append((candidate, word_similarity_score))
|
291 |
+
|
292 |
+
if len(filtered_candidates) >= 1:
|
293 |
+
all_filtered_candidates.append(filtered_candidates)
|
294 |
+
new_index_space.append(masked_token_index)
|
295 |
+
return all_filtered_candidates, new_index_space
|
296 |
+
|
297 |
+
def get_candidate_encodings(self, tokens, enhanced_candidates, index_space):
|
298 |
+
best_candidates = []
|
299 |
+
new_index_space = []
|
300 |
+
|
301 |
+
for init_candidates, masked_token_index in zip(enhanced_candidates, index_space):
|
302 |
+
filtered_candidates = []
|
303 |
+
|
304 |
+
for idx, candidate in enumerate(init_candidates):
|
305 |
+
if masked_token_index-1 in new_index_space:
|
306 |
+
bit = binary_encoding_function(best_candidates[-1]+candidate[0])
|
307 |
+
else:
|
308 |
+
bit = binary_encoding_function(tokens[masked_token_index-1]+candidate[0])
|
309 |
+
|
310 |
+
if bit==1:
|
311 |
+
filtered_candidates.append(candidate)
|
312 |
+
|
313 |
+
# Sort the candidates based on their scores
|
314 |
+
filtered_candidates = sorted(filtered_candidates, key=lambda x: x[1], reverse=True)
|
315 |
+
|
316 |
+
if len(filtered_candidates) >= 1:
|
317 |
+
best_candidates.append(filtered_candidates[0][0])
|
318 |
+
new_index_space.append(masked_token_index)
|
319 |
+
|
320 |
+
return best_candidates, new_index_space
|
321 |
+
|
322 |
+
def watermark_embed(self,text):
|
323 |
+
input_text = text
|
324 |
+
# Tokenize the input text
|
325 |
+
tokens = self.tokenizer.tokenize(input_text)
|
326 |
+
tokens = ['[CLS]'] + tokens + ['[SEP]']
|
327 |
+
masked_tokens=tokens.copy()
|
328 |
+
start_index = 1
|
329 |
+
end_index = len(tokens) - 1
|
330 |
+
|
331 |
+
index_space = []
|
332 |
+
|
333 |
+
for masked_token_index in range(start_index+1, end_index-1):
|
334 |
+
binary_encoding = binary_encoding_function(tokens[masked_token_index - 1] + tokens[masked_token_index])
|
335 |
+
if binary_encoding == 1 and masked_token_index-1 not in index_space:
|
336 |
+
continue
|
337 |
+
if not self.pos_filter(tokens,masked_token_index,input_text):
|
338 |
+
continue
|
339 |
+
index_space.append(masked_token_index)
|
340 |
+
|
341 |
+
if len(index_space)==0:
|
342 |
+
return text
|
343 |
+
init_candidates, new_index_space = self.candidates_gen(tokens,index_space,input_text, 8, 0)
|
344 |
+
if len(new_index_space)==0:
|
345 |
+
return text
|
346 |
+
enhanced_candidates, new_index_space = self.filter_candidates(init_candidates,tokens,new_index_space,input_text)
|
347 |
+
|
348 |
+
enhanced_candidates, new_index_space = self.get_candidate_encodings(tokens, enhanced_candidates, new_index_space)
|
349 |
+
|
350 |
+
for init_candidate, masked_token_index in zip(enhanced_candidates, new_index_space):
|
351 |
+
tokens[masked_token_index] = init_candidate
|
352 |
+
watermarked_text = self.tokenizer.convert_tokens_to_string(tokens[1:-1])
|
353 |
+
|
354 |
+
if self.language == 'Chinese':
|
355 |
+
watermarked_text = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', watermarked_text)
|
356 |
+
return watermarked_text
|
357 |
+
|
358 |
+
def embed(self, ori_text):
|
359 |
+
sents = self.sent_tokenize(ori_text)
|
360 |
+
sents = [s for s in sents if s.strip()]
|
361 |
+
num_sents = len(sents)
|
362 |
+
watermarked_text = ''
|
363 |
+
|
364 |
+
for i in range(0, num_sents, 2):
|
365 |
+
if i+1 < num_sents:
|
366 |
+
sent_pair = sents[i] + sents[i+1]
|
367 |
+
else:
|
368 |
+
sent_pair = sents[i]
|
369 |
+
# keywords = jieba.analyse.extract_tags(sent_pair, topK=5, withWeight=False)
|
370 |
+
if len(watermarked_text) == 0:
|
371 |
+
watermarked_text = self.watermark_embed(sent_pair)
|
372 |
+
else:
|
373 |
+
watermarked_text = watermarked_text + self.watermark_embed(sent_pair)
|
374 |
+
if len(self.get_encodings_fast(ori_text)) == 0:
|
375 |
+
# print(ori_text)
|
376 |
+
return ''
|
377 |
+
return watermarked_text
|
378 |
+
|
379 |
+
def get_encodings_fast(self,text):
|
380 |
+
sents = self.sent_tokenize(text)
|
381 |
+
sents = [s for s in sents if s.strip()]
|
382 |
+
num_sents = len(sents)
|
383 |
+
encodings = []
|
384 |
+
for i in range(0, num_sents, 2):
|
385 |
+
if i+1 < num_sents:
|
386 |
+
sent_pair = sents[i] + sents[i+1]
|
387 |
+
else:
|
388 |
+
sent_pair = sents[i]
|
389 |
+
tokens = self.tokenizer.tokenize(sent_pair)
|
390 |
+
|
391 |
+
for index in range(1,len(tokens)-1):
|
392 |
+
if not self.pos_filter(tokens,index,text):
|
393 |
+
continue
|
394 |
+
bit = binary_encoding_function(tokens[index-1]+tokens[index])
|
395 |
+
encodings.append(bit)
|
396 |
+
return encodings
|
397 |
+
|
398 |
+
def watermark_detector_fast(self, text,alpha=0.05):
|
399 |
+
p = 0.5
|
400 |
+
encodings = self.get_encodings_fast(text)
|
401 |
+
n = len(encodings)
|
402 |
+
ones = sum(encodings)
|
403 |
+
if n == 0:
|
404 |
+
z = 0
|
405 |
+
else:
|
406 |
+
z = (ones - p * n) / (n * p * (1 - p)) ** 0.5
|
407 |
+
threshold = norm.ppf(1 - alpha, loc=0, scale=1)
|
408 |
+
p_value = norm.sf(z)
|
409 |
+
# p_value = norm.sf(abs(z)) * 2
|
410 |
+
is_watermark = z >= threshold
|
411 |
+
return is_watermark, p_value, n, ones, z
|
412 |
+
|
413 |
+
def get_encodings_precise(self, text):
|
414 |
+
# pdb.set_trace()
|
415 |
+
sents = self.sent_tokenize(text)
|
416 |
+
sents = [s for s in sents if s.strip()]
|
417 |
+
num_sents = len(sents)
|
418 |
+
encodings = []
|
419 |
+
for i in range(0, num_sents, 2):
|
420 |
+
if i+1 < num_sents:
|
421 |
+
sent_pair = sents[i] + sents[i+1]
|
422 |
+
else:
|
423 |
+
sent_pair = sents[i]
|
424 |
+
|
425 |
+
tokens = self.tokenizer.tokenize(sent_pair)
|
426 |
+
|
427 |
+
tokens = ['[CLS]'] + tokens + ['[SEP]']
|
428 |
+
|
429 |
+
masked_tokens=tokens.copy()
|
430 |
+
|
431 |
+
start_index = 1
|
432 |
+
end_index = len(tokens) - 1
|
433 |
+
|
434 |
+
index_space = []
|
435 |
+
for masked_token_index in range(start_index+1, end_index-1):
|
436 |
+
if not self.pos_filter(tokens,masked_token_index,sent_pair):
|
437 |
+
continue
|
438 |
+
index_space.append(masked_token_index)
|
439 |
+
if len(index_space)==0:
|
440 |
+
continue
|
441 |
+
|
442 |
+
init_candidates, new_index_space = self.candidates_gen(tokens,index_space,sent_pair, 8, 0)
|
443 |
+
enhanced_candidates, new_index_space = self.filter_candidates(init_candidates,tokens,new_index_space,sent_pair)
|
444 |
+
|
445 |
+
# pdb.set_trace()
|
446 |
+
for j,idx in enumerate(new_index_space):
|
447 |
+
if len(enhanced_candidates[j])>1:
|
448 |
+
bit = binary_encoding_function(tokens[idx-1]+tokens[idx])
|
449 |
+
encodings.append(bit)
|
450 |
+
return encodings
|
451 |
+
|
452 |
+
|
453 |
+
def watermark_detector_precise(self,text,alpha=0.05):
|
454 |
+
p = 0.5
|
455 |
+
encodings = self.get_encodings_precise(text)
|
456 |
+
n = len(encodings)
|
457 |
+
ones = sum(encodings)
|
458 |
+
if n == 0:
|
459 |
+
z = 0
|
460 |
+
else:
|
461 |
+
z = (ones - p * n) / (n * p * (1 - p)) ** 0.5
|
462 |
+
threshold = norm.ppf(1 - alpha, loc=0, scale=1)
|
463 |
+
p_value = norm.sf(z)
|
464 |
+
is_watermark = z >= threshold
|
465 |
+
return is_watermark, p_value, n, ones, z
|
models/watermark_original.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import nltk
|
2 |
+
from nltk.corpus import stopwords
|
3 |
+
from nltk import word_tokenize, pos_tag
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
import hashlib
|
8 |
+
from scipy.stats import norm
|
9 |
+
import gensim
|
10 |
+
import pdb
|
11 |
+
from transformers import BertForMaskedLM as WoBertForMaskedLM
|
12 |
+
from wobert import WoBertTokenizer
|
13 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
14 |
+
|
15 |
+
from transformers import BertForMaskedLM, BertTokenizer, RobertaForSequenceClassification, RobertaTokenizer
|
16 |
+
import gensim.downloader as api
|
17 |
+
import Levenshtein
|
18 |
+
import string
|
19 |
+
import spacy
|
20 |
+
import paddle
|
21 |
+
from jieba import posseg
|
22 |
+
|
23 |
+
paddle.enable_static()
|
24 |
+
import re
|
25 |
+
def cut_sent(para):
|
26 |
+
para = re.sub('([。!?\?])([^”’])', r'\1\n\2', para)
|
27 |
+
para = re.sub('([。!?\?][”’])([^,。!?\?\n ])', r'\1\n\2', para)
|
28 |
+
para = re.sub('(\.{6}|\…{2})([^”’\n])', r'\1\n\2', para)
|
29 |
+
para = re.sub('([^。!?\?]*)([::][^。!?\?\n]*)', r'\1\n\2', para)
|
30 |
+
para = re.sub('([。!?\?][”’])$', r'\1\n', para)
|
31 |
+
para = para.rstrip()
|
32 |
+
return para.split("\n")
|
33 |
+
|
34 |
+
def is_subword(token: str):
|
35 |
+
return token.startswith('##')
|
36 |
+
|
37 |
+
def binary_encoding_function(token):
|
38 |
+
hash_value = int(hashlib.sha256(token.encode('utf-8')).hexdigest(), 16)
|
39 |
+
random_bit = hash_value % 2
|
40 |
+
return random_bit
|
41 |
+
|
42 |
+
def is_similar(x, y, threshold=0.5):
|
43 |
+
distance = Levenshtein.distance(x, y)
|
44 |
+
if distance / max(len(x), len(y)) < threshold:
|
45 |
+
return True
|
46 |
+
return False
|
47 |
+
|
48 |
+
class watermark_model:
|
49 |
+
def __init__(self, language, mode, tau_word, lamda):
|
50 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
51 |
+
self.language = language
|
52 |
+
self.mode = mode
|
53 |
+
self.tau_word = tau_word
|
54 |
+
self.tau_sent = 0.8
|
55 |
+
self.lamda = lamda
|
56 |
+
self.cn_tag_black_list = set(['','x','u','j','k','zg','y','eng','uv','uj','ud','nr','nrfg','nrt','nw','nz','ns','nt','m','mq','r','w','PER','LOC','ORG'])#set(['','f','u','nr','nw','nz','m','r','p','c','w','PER','LOC','ORG'])
|
57 |
+
self.en_tag_white_list = set(['MD', 'NN', 'NNS', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'RP', 'RB', 'RBR', 'RBS', 'JJ', 'JJR', 'JJS'])
|
58 |
+
if language == 'Chinese':
|
59 |
+
self.relatedness_tokenizer = AutoTokenizer.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Similarity")
|
60 |
+
self.relatedness_model = AutoModelForSequenceClassification.from_pretrained("IDEA-CCNL/Erlangshen-Roberta-330M-Similarity").to(self.device)
|
61 |
+
self.tokenizer = WoBertTokenizer.from_pretrained("junnyu/wobert_chinese_plus_base")
|
62 |
+
self.model = WoBertForMaskedLM.from_pretrained("junnyu/wobert_chinese_plus_base", output_hidden_states=True).to(self.device)
|
63 |
+
self.w2v_model = gensim.models.KeyedVectors.load_word2vec_format('sgns.merge.word.bz2', binary=False, unicode_errors='ignore', limit=50000)
|
64 |
+
elif language == 'English':
|
65 |
+
self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
66 |
+
self.model = BertForMaskedLM.from_pretrained('bert-base-cased', output_hidden_states=True).to(self.device)
|
67 |
+
self.relatedness_model = RobertaForSequenceClassification.from_pretrained('roberta-large-mnli').to(self.device)
|
68 |
+
self.relatedness_tokenizer = RobertaTokenizer.from_pretrained('roberta-large-mnli')
|
69 |
+
self.w2v_model = api.load("glove-wiki-gigaword-100")
|
70 |
+
nltk.download('stopwords')
|
71 |
+
self.stop_words = set(stopwords.words('english'))
|
72 |
+
self.nlp = spacy.load('en_core_web_sm')
|
73 |
+
|
74 |
+
def cut(self,ori_text,text_len):
|
75 |
+
if self.language == 'Chinese':
|
76 |
+
if len(ori_text) > text_len+5:
|
77 |
+
ori_text = ori_text[:text_len+5]
|
78 |
+
if len(ori_text) < text_len-5:
|
79 |
+
return 'Short'
|
80 |
+
elif self.language == 'English':
|
81 |
+
tokens = self.tokenizer.tokenize(ori_text)
|
82 |
+
if len(tokens) > text_len+5:
|
83 |
+
ori_text = self.tokenizer.convert_tokens_to_string(tokens[:text_len+5])
|
84 |
+
if len(tokens) < text_len-5:
|
85 |
+
return 'Short'
|
86 |
+
return ori_text
|
87 |
+
else:
|
88 |
+
print(f'Unsupported Language:{self.language}')
|
89 |
+
raise NotImplementedError
|
90 |
+
|
91 |
+
def sent_tokenize(self,ori_text):
|
92 |
+
if self.language == 'Chinese':
|
93 |
+
return cut_sent(ori_text)
|
94 |
+
elif self.language == 'English':
|
95 |
+
return nltk.sent_tokenize(ori_text)
|
96 |
+
|
97 |
+
def pos_filter(self, tokens, masked_token_index, input_text):
|
98 |
+
if self.language == 'Chinese':
|
99 |
+
pairs = posseg.lcut(input_text)
|
100 |
+
pos_dict = {word: pos for word, pos in pairs}
|
101 |
+
pos_list_input = [pos for _, pos in pairs]
|
102 |
+
pos = pos_dict.get(tokens[masked_token_index], '')
|
103 |
+
if pos in self.cn_tag_black_list:
|
104 |
+
return False
|
105 |
+
else:
|
106 |
+
return True
|
107 |
+
elif self.language == 'English':
|
108 |
+
pos_tags = pos_tag(tokens)
|
109 |
+
pos = pos_tags[masked_token_index][1]
|
110 |
+
if pos not in self.en_tag_white_list:
|
111 |
+
return False
|
112 |
+
if is_subword(tokens[masked_token_index]) or is_subword(tokens[masked_token_index+1]) or (tokens[masked_token_index] in self.stop_words or tokens[masked_token_index] in string.punctuation):
|
113 |
+
return False
|
114 |
+
return True
|
115 |
+
|
116 |
+
def filter_special_candidate(self, top_n_tokens, tokens,masked_token_index,input_text):
|
117 |
+
if self.language == 'English':
|
118 |
+
filtered_tokens = [tok for tok in top_n_tokens if tok not in self.stop_words and tok not in string.punctuation and pos_tag([tok])[0][1] in self.en_tag_white_list and not is_subword(tok)]
|
119 |
+
|
120 |
+
lemmatized_tokens = []
|
121 |
+
# for token in filtered_tokens:
|
122 |
+
# doc = self.nlp(token)
|
123 |
+
# lemma = doc[0].lemma_ if doc[0].lemma_ != "-PRON-" else token
|
124 |
+
# lemmatized_tokens.append(lemma)
|
125 |
+
|
126 |
+
base_word = tokens[masked_token_index]
|
127 |
+
base_word_lemma = self.nlp(base_word)[0].lemma_
|
128 |
+
processed_tokens = [base_word]+[tok for tok in filtered_tokens if self.nlp(tok)[0].lemma_ != base_word_lemma]
|
129 |
+
return processed_tokens
|
130 |
+
elif self.language == 'Chinese':
|
131 |
+
pairs = posseg.lcut(input_text)
|
132 |
+
pos_dict = {word: pos for word, pos in pairs}
|
133 |
+
pos_list_input = [pos for _, pos in pairs]
|
134 |
+
pos = pos_dict.get(tokens[masked_token_index], '')
|
135 |
+
filtered_tokens = []
|
136 |
+
for tok in top_n_tokens:
|
137 |
+
watermarked_text_segtest = self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [tok] + tokens[masked_token_index+1:-1])
|
138 |
+
watermarked_text_segtest = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', watermarked_text_segtest)
|
139 |
+
pairs_tok = posseg.lcut(watermarked_text_segtest)
|
140 |
+
pos_dict_tok = {word: pos for word, pos in pairs_tok}
|
141 |
+
flag = pos_dict_tok.get(tok, '')
|
142 |
+
if flag not in self.cn_tag_black_list and flag == pos:
|
143 |
+
filtered_tokens.append(tok)
|
144 |
+
processed_tokens = filtered_tokens
|
145 |
+
return processed_tokens
|
146 |
+
|
147 |
+
def global_word_sim(self,word,ori_word):
|
148 |
+
try:
|
149 |
+
global_score = self.w2v_model.similarity(word,ori_word)
|
150 |
+
except KeyError:
|
151 |
+
global_score = 0
|
152 |
+
return global_score
|
153 |
+
|
154 |
+
def context_word_sim(self,init_candidates, tokens, masked_token_index, input_text):
|
155 |
+
original_input_tensor = self.tokenizer.encode(input_text,return_tensors='pt').to(self.device)
|
156 |
+
batch_input_ids = [[self.tokenizer.convert_tokens_to_ids(['[CLS]'] + tokens[1:masked_token_index] + [token] + tokens[masked_token_index+1:-1]+ ['[SEP]'])] for token in init_candidates]
|
157 |
+
batch_input_tensors = torch.tensor(batch_input_ids).squeeze().to(self.device)
|
158 |
+
batch_input_tensors = torch.cat((batch_input_tensors,original_input_tensor),dim=0)
|
159 |
+
with torch.no_grad():
|
160 |
+
outputs = self.model(batch_input_tensors)
|
161 |
+
cos_sims = torch.zeros([len(init_candidates)]).to(self.device)
|
162 |
+
num_layers = len(outputs[1])
|
163 |
+
N = 8
|
164 |
+
i = masked_token_index
|
165 |
+
cos_sim_sum = 0
|
166 |
+
for layer in range(num_layers-N,num_layers):
|
167 |
+
ls_hidden_states = outputs[1][layer][0:len(init_candidates), i, :]
|
168 |
+
source_hidden_state = outputs[1][layer][len(init_candidates), i, :]
|
169 |
+
cos_sim_sum += F.cosine_similarity(source_hidden_state, ls_hidden_states, dim=1)
|
170 |
+
cos_sim_avg = cos_sim_sum / N
|
171 |
+
|
172 |
+
cos_sims += cos_sim_avg
|
173 |
+
return cos_sims.tolist()
|
174 |
+
|
175 |
+
def sentence_sim(self,init_candidates, tokens, masked_token_index, input_text):
|
176 |
+
if self.language == 'Chinese':
|
177 |
+
batch_sents = [self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [token] + tokens[masked_token_index+1:-1]) for token in init_candidates]
|
178 |
+
batch_sentences = [re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', sent) for sent in batch_sents]
|
179 |
+
roberta_inputs = [input_text + '[SEP]' + s for s in batch_sentences]
|
180 |
+
elif self.language == 'English':
|
181 |
+
batch_sentences = [self.tokenizer.convert_tokens_to_string(tokens[1:masked_token_index] + [token] + tokens[masked_token_index+1:-1]) for token in init_candidates]
|
182 |
+
roberta_inputs = [input_text + '</s></s>' + s for s in batch_sentences]
|
183 |
+
|
184 |
+
encoded_dict = self.relatedness_tokenizer.batch_encode_plus(
|
185 |
+
roberta_inputs,
|
186 |
+
padding=True,
|
187 |
+
truncation=True,
|
188 |
+
max_length=512,
|
189 |
+
return_tensors='pt')
|
190 |
+
# Extract input_ids and attention_masks
|
191 |
+
input_ids = encoded_dict['input_ids'].to(self.device)
|
192 |
+
attention_masks = encoded_dict['attention_mask'].to(self.device)
|
193 |
+
with torch.no_grad():
|
194 |
+
outputs = self.relatedness_model(input_ids=input_ids, attention_mask=attention_masks)
|
195 |
+
logits = outputs[0]
|
196 |
+
probs = torch.softmax(logits, dim=1)
|
197 |
+
if self.language == 'Chinese':
|
198 |
+
relatedness_scores = probs[:, 1].tolist()
|
199 |
+
elif self.language == 'English':
|
200 |
+
relatedness_scores = probs[:, 2].tolist()
|
201 |
+
|
202 |
+
return relatedness_scores
|
203 |
+
|
204 |
+
def candidates_gen(self,tokens,masked_token_index,input_text,topk=64, dropout_prob=0.3):
|
205 |
+
input_ids_bert = self.tokenizer.convert_tokens_to_ids(tokens)
|
206 |
+
if not self.pos_filter(tokens,masked_token_index,input_text):
|
207 |
+
return []
|
208 |
+
masked_text = self.tokenizer.convert_tokens_to_string(tokens)
|
209 |
+
# Create a tensor of input IDs
|
210 |
+
input_tensor = torch.tensor([input_ids_bert]).to(self.device)
|
211 |
+
|
212 |
+
with torch.no_grad():
|
213 |
+
embeddings = self.model.bert.embeddings(input_tensor)
|
214 |
+
dropout = nn.Dropout2d(p=dropout_prob)
|
215 |
+
# Get the predicted logits
|
216 |
+
embeddings[:, masked_token_index, :] = dropout(embeddings[:, masked_token_index, :])
|
217 |
+
with torch.no_grad():
|
218 |
+
outputs = self.model(inputs_embeds=embeddings)
|
219 |
+
|
220 |
+
predicted_logits = outputs[0][0][masked_token_index]
|
221 |
+
|
222 |
+
# Set the number of top predictions to return
|
223 |
+
n = topk
|
224 |
+
# Get the top n predicted tokens and their probabilities
|
225 |
+
probs = torch.nn.functional.softmax(predicted_logits, dim=-1)
|
226 |
+
top_n_probs, top_n_indices = torch.topk(probs, n)
|
227 |
+
top_n_tokens = self.tokenizer.convert_ids_to_tokens(top_n_indices.tolist())
|
228 |
+
processed_tokens = self.filter_special_candidate(top_n_tokens,tokens,masked_token_index)
|
229 |
+
|
230 |
+
return processed_tokens
|
231 |
+
|
232 |
+
def filter_candidates(self, init_candidates, tokens, masked_token_index, input_text):
|
233 |
+
context_word_similarity_scores = self.context_word_sim(init_candidates, tokens, masked_token_index, input_text)
|
234 |
+
sentence_similarity_scores = self.sentence_sim(init_candidates, tokens, masked_token_index, input_text)
|
235 |
+
filtered_candidates = []
|
236 |
+
for idx, candidate in enumerate(init_candidates):
|
237 |
+
global_word_similarity_score = self.global_word_sim(tokens[masked_token_index], candidate)
|
238 |
+
word_similarity_score = self.lamda*context_word_similarity_scores[idx]+(1-self.lamda)*global_word_similarity_score
|
239 |
+
if word_similarity_score >= self.tau_word and sentence_similarity_scores[idx] >= self.tau_sent:
|
240 |
+
filtered_candidates.append((candidate, word_similarity_score))#, sentence_similarity_scores[idx]))
|
241 |
+
return filtered_candidates
|
242 |
+
|
243 |
+
def watermark_embed(self,text):
|
244 |
+
input_text = text
|
245 |
+
# Tokenize the input text
|
246 |
+
tokens = self.tokenizer.tokenize(input_text)
|
247 |
+
tokens = ['[CLS]'] + tokens + ['[SEP]']
|
248 |
+
masked_tokens=tokens.copy()
|
249 |
+
start_index = 1
|
250 |
+
end_index = len(tokens) - 1
|
251 |
+
for masked_token_index in range(start_index+1, end_index-1):
|
252 |
+
# pdb.set_trace()
|
253 |
+
binary_encoding = binary_encoding_function(tokens[masked_token_index - 1] + tokens[masked_token_index])
|
254 |
+
if binary_encoding == 1:
|
255 |
+
continue
|
256 |
+
init_candidates = self.candidates_gen(tokens,masked_token_index,input_text, 32, 0.3)
|
257 |
+
if len(init_candidates) <=1:
|
258 |
+
continue
|
259 |
+
enhanced_candidates = self.filter_candidates(init_candidates,tokens,masked_token_index,input_text)
|
260 |
+
hash_top_tokens = enhanced_candidates.copy()
|
261 |
+
for i, tok in enumerate(enhanced_candidates):
|
262 |
+
binary_encoding = binary_encoding_function(tokens[masked_token_index - 1] + tok[0])
|
263 |
+
if binary_encoding != 1 or (is_similar(tok[0], tokens[masked_token_index])) or (tokens[masked_token_index - 1] in tok or tokens[masked_token_index + 1] in tok):
|
264 |
+
hash_top_tokens.remove(tok)
|
265 |
+
hash_top_tokens.sort(key=lambda x: x[1], reverse=True)
|
266 |
+
if len(hash_top_tokens) > 0:
|
267 |
+
selected_token = hash_top_tokens[0][0]
|
268 |
+
else:
|
269 |
+
selected_token = tokens[masked_token_index]
|
270 |
+
|
271 |
+
tokens[masked_token_index] = selected_token
|
272 |
+
watermarked_text = self.tokenizer.convert_tokens_to_string(tokens[1:-1])
|
273 |
+
if self.language == 'Chinese':
|
274 |
+
watermarked_text = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff,。?!、:])|(?<=[\u4e00-\u9fff,。?!、:])\s+(?=[\u4e00-\u9fff])', '', watermarked_text)
|
275 |
+
|
276 |
+
return watermarked_text
|
277 |
+
|
278 |
+
def embed(self, ori_text):
|
279 |
+
sents = self.sent_tokenize(ori_text)
|
280 |
+
sents = [s for s in sents if s.strip()]
|
281 |
+
num_sents = len(sents)
|
282 |
+
watermarked_text = ''
|
283 |
+
for i in range(0, num_sents, 2):
|
284 |
+
if i+1 < num_sents:
|
285 |
+
sent_pair = sents[i] + sents[i+1]
|
286 |
+
else:
|
287 |
+
sent_pair = sents[i]
|
288 |
+
if len(watermarked_text) == 0:
|
289 |
+
watermarked_text = self.watermark_embed(sent_pair)
|
290 |
+
else:
|
291 |
+
watermarked_text = watermarked_text + self.watermark_embed(sent_pair)
|
292 |
+
if len(self.get_encodings_fast(ori_text)) == 0:
|
293 |
+
return ''
|
294 |
+
return watermarked_text
|
295 |
+
|
296 |
+
def get_encodings_fast(self,text):
|
297 |
+
sents = self.sent_tokenize(text)
|
298 |
+
sents = [s for s in sents if s.strip()]
|
299 |
+
num_sents = len(sents)
|
300 |
+
encodings = []
|
301 |
+
for i in range(0, num_sents, 2):
|
302 |
+
if i+1 < num_sents:
|
303 |
+
sent_pair = sents[i] + sents[i+1]
|
304 |
+
else:
|
305 |
+
sent_pair = sents[i]
|
306 |
+
tokens = self.tokenizer.tokenize(sent_pair)
|
307 |
+
|
308 |
+
for index in range(1,len(tokens)-1):
|
309 |
+
if not self.pos_filter(tokens,index,text):
|
310 |
+
continue
|
311 |
+
bit = binary_encoding_function(tokens[index-1]+tokens[index])
|
312 |
+
encodings.append(bit)
|
313 |
+
return encodings
|
314 |
+
|
315 |
+
def watermark_detector_fast(self, text,alpha=0.05):
|
316 |
+
p = 0.5
|
317 |
+
encodings = self.get_encodings_fast(text)
|
318 |
+
n = len(encodings)
|
319 |
+
ones = sum(encodings)
|
320 |
+
z = (ones - p * n) / (n * p * (1 - p)) ** 0.5
|
321 |
+
threshold = norm.ppf(1 - alpha, loc=0, scale=1)
|
322 |
+
p_value = norm.sf(z)
|
323 |
+
is_watermark = z >= threshold
|
324 |
+
return is_watermark, p_value, n, ones, z
|
325 |
+
|
326 |
+
def get_encodings_precise(self, text):
|
327 |
+
sents = self.sent_tokenize(text)
|
328 |
+
sents = [s for s in sents if s.strip()]
|
329 |
+
num_sents = len(sents)
|
330 |
+
encodings = []
|
331 |
+
for i in range(0, num_sents, 2):
|
332 |
+
if i+1 < num_sents:
|
333 |
+
sent_pair = sents[i] + sents[i+1]
|
334 |
+
else:
|
335 |
+
sent_pair = sents[i]
|
336 |
+
|
337 |
+
tokens = self.tokenizer.tokenize(sent_pair)
|
338 |
+
|
339 |
+
tokens = ['[CLS]'] + tokens + ['[SEP]']
|
340 |
+
|
341 |
+
masked_tokens=tokens.copy()
|
342 |
+
|
343 |
+
start_index = 1
|
344 |
+
end_index = len(tokens) - 1
|
345 |
+
|
346 |
+
for masked_token_index in range(start_index+1, end_index-1):
|
347 |
+
init_candidates = self.candidates_gen(tokens,masked_token_index,sent_pair, 8, 0)
|
348 |
+
if len(init_candidates) <=1:
|
349 |
+
continue
|
350 |
+
enhanced_candidates = self.filter_candidates(init_candidates,tokens,masked_token_index,sent_pair)
|
351 |
+
if len(enhanced_candidates) > 1:
|
352 |
+
bit = binary_encoding_function(tokens[masked_token_index-1]+tokens[masked_token_index])
|
353 |
+
encodings.append(bit)
|
354 |
+
return encodings
|
355 |
+
|
356 |
+
def watermark_detector_precise(self,text,alpha=0.05):
|
357 |
+
p = 0.5
|
358 |
+
encodings = self.get_encodings_precise(text)
|
359 |
+
n = len(encodings)
|
360 |
+
ones = sum(encodings)
|
361 |
+
if n == 0:
|
362 |
+
z = 0
|
363 |
+
else:
|
364 |
+
z = (ones - p * n) / (n * p * (1 - p)) ** 0.5
|
365 |
+
threshold = norm.ppf(1 - alpha, loc=0, scale=1)
|
366 |
+
p_value = norm.sf(z)
|
367 |
+
is_watermark = z >= threshold
|
368 |
+
return is_watermark, p_value, n, ones, z
|
options.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
# TODO: add help for the parameters
|
3 |
+
|
4 |
+
def get_parser_main_model():
|
5 |
+
parser = argparse.ArgumentParser()
|
6 |
+
# TODO: basic parameters training related
|
7 |
+
|
8 |
+
# for embed
|
9 |
+
parser.add_argument('--language', type=str, default='English', help='text language')
|
10 |
+
parser.add_argument('--mode', type=str, choices=['embed', 'fast_detect', 'precise_detect'], default='embed', help='Mode options: embed (default), fast_detect, precise_detect')
|
11 |
+
parser.add_argument('--tau_word', type=float, default=0.8, help='word-level similarity thresh')
|
12 |
+
parser.add_argument('--lamda', type=float, default=0.83, help='word-level similarity weight')
|
13 |
+
|
14 |
+
return parser
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gensim==4.3.0
|
2 |
+
gradio==3.30.0
|
3 |
+
jieba==0.42.1
|
4 |
+
nltk==3.8.1
|
5 |
+
paddle==1.0.2
|
6 |
+
paddlepaddle==2.4.2
|
7 |
+
python_Levenshtein==0.21.0
|
8 |
+
scipy==1.7.3
|
9 |
+
spacy==3.5.0
|
10 |
+
torch==1.11.0
|
11 |
+
transformers==4.26.1
|
12 |
+
wobert==0.0.1
|