guy-dar committed
Commit 25ed840 · 1 Parent(s): 1aff1a7

first commit

Files changed (3)
  1. .DS_Store +0 -0
  2. app.py +28 -0
  3. speaking_probes/generate.py +218 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
app.py ADDED
@@ -0,0 +1,28 @@
+ import streamlit as st
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from speaking_probes.generate import extract_gpt_parameters, speaking_probe
+
+
+ # allow_output_mutation is needed because speaking_probe resizes the cached model's embeddings
+ @st.cache(allow_output_mutation=True)
+ def load_model(model_name):
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     tokenizer.pad_token = tokenizer.eos_token
+     model = AutoModelForCausalLM.from_pretrained(model_name)
+     model_params = extract_gpt_parameters(model)
+     return model, model_params, tokenizer
+
+
+ model_name = st.selectbox("Select a model: ", options=['gpt2', 'gpt2-medium', 'gpt2-large'])
+ model, model_params, tokenizer = load_model(model_name)
+
+ neuron_layer = st.text_input("Layer: ")
+ neuron_dim = st.text_input("Dim: ")
+ prompt = st.text_area("Prompt: ")
+ submitted = st.button("Send!")
+
+
+ if submitted:
+     # look up the requested MLP key vector ("neuron") only after submission,
+     # so empty inputs on the first render don't crash int()
+     neurons = [model_params.K_heads[int(neuron_layer), int(neuron_dim)]]
+     speaking_probe(model, model_params, tokenizer, prompt, *neurons, num_generations=1,
+                    repetition_penalty=2.,
+                    num_beams=3, min_length=1, do_sample=True,
+                    max_new_tokens=100)
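
(Assuming a standard Streamlit install, the demo can be launched locally with: streamlit run app.py)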
speaking_probes/generate.py ADDED
@@ -0,0 +1,218 @@
+ import gc
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import matplotlib.pyplot as plt
+ import transformers
+ from copy import deepcopy
+ from dataclasses import dataclass
+ from argparse import ArgumentParser
+ from torch import nn
+ from torch.utils.data import DataLoader
+ from datasets import load_dataset
+ from tqdm.auto import tqdm
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
+ from transformers import DataCollatorWithPadding, MaxLengthCriteria, StoppingCriteriaList
+ from transformers import LogitsProcessor, LogitsProcessorList, LogitsWarper
+
+
+ @dataclass
+ class ModelParameters:
+     # per-layer MLP key vectors ("neurons"), shape (num_layers, d_int, hidden_dim)
+     K_heads: torch.Tensor
+     num_layers: int
+     d_int: int
+
+
+ def extract_gpt_parameters(model):
+     """Pull the MLP and attention weight matrices out of a GPT-2-style model."""
+     emb = model.get_output_embeddings().weight.data.T
+     num_layers = model.config.n_layer
+     num_heads = model.config.n_head
+     hidden_dim = model.config.n_embd
+     head_size = hidden_dim // num_heads
+
+     # MLP "keys" (first layer) and "values" (second layer), stacked across layers
+     K = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.c_fc.weight").T
+                    for j in range(num_layers)]).detach()
+     V = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.c_proj.weight")
+                    for j in range(num_layers)]).detach()
+     W_Q, W_K, W_V = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_attn.weight")
+                                for j in range(num_layers)]).detach().chunk(3, dim=-1)
+     W_O = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_proj.weight")
+                      for j in range(num_layers)]).detach()
+
+     K_heads = K.reshape(num_layers, -1, hidden_dim)
+     V_heads = V.reshape(num_layers, -1, hidden_dim)
+     d_int = K_heads.shape[1]
+
+     # per-head views of the attention matrices (currently unused downstream)
+     W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
+     W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)
+     W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
+     W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
+
+     return ModelParameters(K_heads=K_heads, num_layers=num_layers, d_int=d_int)
+
+
+ def encode(token, tokenizer):
+     """Encode a string that must map to exactly one token id."""
+     assert type(token) == str
+     encoded = tokenizer.encode(token)
+     assert len(encoded) == 1
+     return encoded[0]
+
+
+ def read_and_go(path):
+     with open(path, 'r') as f:
+         return f.read()
+
+
+ def extend_model_and_tokenizer(model, model_params, tokenizer, min_layer=0,
+                                max_layer=None):
+     """Return copies of the model/tokenizer with one new token per MLP neuron,
+     whose embedding is set to that neuron's key vector."""
+     if max_layer is None:
+         max_layer = len(model_params.K_heads) - 1
+     relevant_neurons = model_params.K_heads[min_layer:max_layer+1]
+     num_regular_tokens = len(tokenizer)
+     new_tokens = [f" <param_{layer}_{dim}>" for layer in range(min_layer, max_layer+1)
+                   for dim in range(relevant_neurons.shape[1])]
+
+     tokenizer_extended = deepcopy(tokenizer)
+     model_extended = deepcopy(model)
+
+     tokenizer_extended.add_tokens(new_tokens)
+     model_extended.resize_token_embeddings(len(tokenizer_extended))
+     model_extended.transformer.wte.weight.data[-len(new_tokens):] = relevant_neurons.flatten(0, -2)
+     return model_extended, tokenizer_extended
+
+
+ # logit processors
+ class NeuronTokenBan(LogitsWarper):
+     """Bans the appended neuron tokens from being generated as output."""
+     def __init__(self, num_non_neuron_tokens, ban_penalty=-np.inf):
+         self.ban_penalty = ban_penalty
+         self.num_non_neuron_tokens = num_non_neuron_tokens
+
+     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
+         scores[:, self.num_non_neuron_tokens:] = self.ban_penalty
+         return scores
+
+
+ class ParamListStructureEnforcer(LogitsProcessor):
+     """Forces generations to alternate between neuron tokens and commas (EOS always allowed)."""
+     def __init__(self, tokenizer, num_regular_tokens):
+         self.tokenizer = tokenizer
+         self.num_regular_tokens = num_regular_tokens
+
+     def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
+         last_input_id = input_ids[0, -1]
+         tokenizer = self.tokenizer
+         num_regular_tokens = self.num_regular_tokens
+
+         comma_id = encode(',', tokenizer)
+         eos_score, comma_score = deepcopy(scores[:, tokenizer.eos_token_id]), deepcopy(scores[:, comma_id])
+
+         # after a neuron token only a comma may follow; after a regular token, only neuron tokens
+         if last_input_id >= num_regular_tokens:
+             scores[:] = -np.inf
+             scores[:, comma_id] = comma_score
+         else:
+             scores[:, :num_regular_tokens] = -np.inf
+
+         scores[:, tokenizer.eos_token_id] = eos_score
+         return scores
+
+
+ # speaking probe
+ def speaking_probe(model, model_params, tokenizer, prompt, *neurons,
+                    num_generations=1, layer_range=None, bad_words_ids=[], output_neurons=False,
+                    return_outputs=False, logits_processor=LogitsProcessorList([]), **kwargs):
+     num_non_neuron_tokens = len(tokenizer)
+     tokenizer_with_neurons = deepcopy(tokenizer)
+     has_extra_neurons = len(neurons) > 0
+     if has_extra_neurons:
+         # add " <neuron>", " <neuron2>", ... tokens whose embeddings are the probed vectors
+         tokenizer_with_neurons.add_tokens([f" <neuron{i+1 if i > 0 else ''}>" for i in range(len(neurons))])
+         model.resize_token_embeddings(len(tokenizer_with_neurons))
+         model.transformer.wte.weight.data[-len(neurons):] = torch.stack(neurons, dim=0)
+
+     logits_processor = deepcopy(logits_processor)
+
+     if not output_neurons:
+         logits_processor.append(NeuronTokenBan(num_non_neuron_tokens))
+
+     if layer_range is not None:
+         num_layers = model_params.num_layers
+         min_layer, max_layer = layer_range
+         bad_words_ids = deepcopy(bad_words_ids)
+         bad_words_ids.extend([[encode(f" <param_{i}_{j}>", tokenizer)]
+                               for j in range(model_params.d_int)
+                               for i in [*range(min_layer), *range(max_layer+1, num_layers)]])
+     if len(bad_words_ids) == 0:
+         bad_words_ids = None
+
+     print(prompt)
+     input_ids = tokenizer_with_neurons.encode(prompt, return_tensors='pt').to(model.device)
+     input_ids = torch.cat([deepcopy(input_ids) for _ in range(num_generations)], dim=0)
+     outputs = model.generate(input_ids, pad_token_id=model.config.eos_token_id,
+                              logits_processor=logits_processor,
+                              bad_words_ids=bad_words_ids,
+                              return_dict_in_generate=True,
+                              **kwargs)
+
+     decoded = tokenizer_with_neurons.batch_decode(outputs.sequences, skip_special_tokens=True)
+
+     for i in range(len(decoded)):
+         print("\n\ngenerate:", decoded[i])
+
+     # restore the original vocabulary size before returning
+     if has_extra_neurons:
+         model.resize_token_embeddings(num_non_neuron_tokens)
+         model.transformer.wte.weight.data = model.transformer.wte.weight.data[:num_non_neuron_tokens]
+
+     if return_outputs:
+         return outputs
+
+
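+ # A hypothetical usage sketch (mirroring the __main__ block below): probe one
+ # MLP neuron by passing its key vector; the prompt refers to it as " <neuron>":
+ #     neuron = model_params.K_heads[10, 2]
+ #     speaking_probe(model, model_params, tokenizer,
+ #                    "The term <neuron> means", neuron,
+ #                    num_generations=3, do_sample=True, max_new_tokens=50)
+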
+ # main
+ if __name__ == "__main__":
+     parser = ArgumentParser()
+     parser.add_argument('-p', '--prompt', type=str, default=None)
+     parser.add_argument('--model', type=str, default='gpt2-large')
+     parser.add_argument('--neuron', type=str, default=None)
+     parser.add_argument('--device', type=str, default='cuda')
+     parser.add_argument('--prompt_file', type=str, default=None)
+     parser.add_argument('--no_sample', action='store_true')
+     parser.add_argument('--num_beams', type=int, default=3)
+     parser.add_argument('--num_generations', type=int, default=1)
+     parser.add_argument('--min_length', type=int, default=20)
+     parser.add_argument('--top_p', type=float, default=None)
+     parser.add_argument('--top_k', type=int, default=None)
+     parser.add_argument('--max_length', type=int, default=100)
+     parser.add_argument('--max_new_tokens', type=int, default=None)
+     parser.add_argument('--repetition_penalty', type=float, default=2.)
+
+     args = parser.parse_args()
+     # TODO: first make them mutually exclusive
+     if args.max_new_tokens is not None:
+         args.max_length = None
+
+     print("loading model and tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(args.model)
+     tokenizer.pad_token = tokenizer.eos_token
+     model = AutoModelForCausalLM.from_pretrained(args.model)
+     model_params = extract_gpt_parameters(model)
+     prompt = args.prompt or read_and_go(args.prompt_file)
+     device = args.device
+     model = model.to(device)
+
+     # parse "--neuron layer,dim" into an index into K_heads
+     i1, i2 = map(lambda x: int(x.strip()), args.neuron.split(','))
+     neuron = model_params.K_heads[i1, i2]
+     neurons = [neuron]
+
+     speaking_probe(model, model_params, tokenizer, prompt, *neurons,
+                    num_generations=args.num_generations,
+                    repetition_penalty=args.repetition_penalty,
+                    num_beams=args.num_beams, top_p=args.top_p, top_k=args.top_k,
+                    min_length=args.min_length, do_sample=not args.no_sample,
+                    max_length=args.max_length, max_new_tokens=args.max_new_tokens)
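
(Assuming dependencies are installed, a hypothetical command-line invocation of this script:
python speaking_probes/generate.py --model gpt2-large --neuron "10,2" --prompt "The term <neuron> means")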