guy-dar committed
Commit 93ea3ba · Parent: 1b1fb8d

add param tags instead of neurons

Files changed (3)
  1. .DS_Store +0 -0
  2. app.py +4 -4
  3. speaking_probes/generate.py +84 -16
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
app.py CHANGED
@@ -17,17 +17,17 @@ def load_model(model_name):
 col1, col2, col3, *_ = st.columns(5)
 model_name = col1.selectbox("Select a model: ", options=['gpt2', 'gpt2-medium', 'gpt2-large'])
 model, model_params, tokenizer = load_model(model_name)
-neuron_layer = col2.text_input("Layer: ", value='0')
-neuron_dim = col3.text_input("Dim: ", value='0')
-
-neurons = model_params.K_heads[int(neuron_layer), int(neuron_dim)]
+# neuron_layer = col2.text_input("Layer: ", value='0')
+# neuron_dim = col3.text_input("Dim: ", value='0')
+# neurons = model_params.K_heads[int(neuron_layer), int(neuron_dim)]
+
 prompt = st.text_area("Prompt: ")
 submitted = st.button("Send!")
 
 if submitted:
     with st.spinner('Wait for it..'):
         model, model_params, tokenizer = map(deepcopy, (model, model_params, tokenizer))
-        decoded = speaking_probe(model, model_params, tokenizer, prompt, *neurons,
+        decoded = speaking_probe(model, model_params, tokenizer, prompt,
                                  repetition_penalty=2., num_generations=3,
                                  min_length=1, do_sample=True,
                                  max_new_tokens=100)
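
With the Layer/Dim widgets commented out, the app leans on speaking_probe's new prompt preprocessing: a neuron is named inline with a <param_layer_dim> tag rather than passed as an explicit *neurons argument. A minimal sketch of the resulting call, where the indices (layer 2, dim 13) are purely illustrative:

    # The tag is resolved inside speaking_probe to model_params.K_heads[2, 13],
    # so no *neurons argument is needed.
    decoded = speaking_probe(model, model_params, tokenizer,
                             "What does <param_2_13> mean?",
                             repetition_penalty=2., num_generations=3,
                             min_length=1, do_sample=True,
                             max_new_tokens=100)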
speaking_probes/generate.py CHANGED
@@ -1,3 +1,4 @@
+import re
 import numpy as np
 from copy import deepcopy
 import matplotlib.pyplot as plt
@@ -29,11 +30,20 @@ from argparse import ArgumentParser
 class ModelParameters:
     K_heads: torch.Tensor
     num_layers: int
-    d_int: int
-
+    d_int: int
+    num_heads: int
+    hidden_dim: int
+    head_size: int
+    V_heads: torch.Tensor = None
+    W_Q_heads: torch.Tensor = None
+    W_K_heads: torch.Tensor = None
+    W_V_heads: torch.Tensor = None
+    W_O_heads: torch.Tensor = None
+    emb: torch.Tensor = None
+
+
 
-def extract_gpt_parameters(model):
-    emb = model.get_output_embeddings().weight.data.T
+def extract_gpt_parameters(model, full=False):
     num_layers = model.config.n_layer
     num_heads = model.config.n_head
     hidden_dim = model.config.n_embd
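
The dataclass now carries the model geometry (num_heads, hidden_dim, head_size) and optional slots for the attention, value, and embedding matrices; only the MLP-key fields are mandatory. A sketch of direct construction, using gpt2's real sizes as assumed values:

    # 12 layers, 3072 MLP keys per layer, hidden size 768, 12 heads of size 64.
    params = ModelParameters(K_heads=torch.zeros(12, 3072, 768), num_layers=12,
                             d_int=3072, num_heads=12, hidden_dim=768, head_size=64)
    assert params.W_Q_heads is None  # attention matrices are only filled on demand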
@@ -41,23 +51,63 @@ def extract_gpt_parameters(model):
 
     K = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.c_fc.weight").T
                    for j in range(num_layers)]).detach()
-    V = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.c_proj.weight")
+
+    K_heads = K.reshape(num_layers, -1, hidden_dim)
+    d_int = K_heads.shape[1]
+    model_params = ModelParameters(K_heads=K_heads, num_layers=num_layers, d_int=d_int,
+                                   hidden_dim=hidden_dim, head_size=head_size,
+                                   num_heads=num_heads)
+
+    if full:
+        emb = model.get_output_embeddings().weight.data.T
+        V = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.c_proj.weight")
                    for j in range(num_layers)]).detach()
-    W_Q, W_K, W_V = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_attn.weight")
-                               for j in range(num_layers)]).detach().chunk(3, dim=-1)
-    W_O = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_proj.weight")
+        W_Q, W_K, W_V = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_attn.weight")
+                                   for j in range(num_layers)]).detach().chunk(3, dim=-1)
+        W_O = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_proj.weight")
+                         for j in range(num_layers)]).detach()
+
+        model_params.V_heads = V.reshape(num_layers, -1, hidden_dim)
+        model_params.W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
+        model_params.W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)
+        model_params.W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
+        model_params.W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
+        model_params.emb = emb
+    return model_params
+
+
+def extract_gpt_j_parameters(model, full=False):
+    num_layers = model.config.n_layer
+    num_heads = model.config.n_head
+    hidden_dim = model.config.n_embd
+    head_size = hidden_dim // num_heads
+
+    K = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.fc_in.weight").T
                    for j in range(num_layers)]).detach()
 
     K_heads = K.reshape(num_layers, -1, hidden_dim)
-    V_heads = V.reshape(num_layers, -1, hidden_dim)
     d_int = K_heads.shape[1]
+    model_params = ModelParameters(K_heads=K_heads, num_layers=num_layers, d_int=d_int,
+                                   hidden_dim=hidden_dim, head_size=head_size,
+                                   num_heads=num_heads)
 
-    W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
-    W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)
-    W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
-    W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
+    if full:
+        raise NotImplementedError
+        # emb = model.get_output_embeddings().weight.data.T
+        # V = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.c_proj.weight")
+        #                for j in range(num_layers)]).detach()
+        # W_Q, W_K, W_V = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_attn.weight")
+        #                            for j in range(num_layers)]).detach().chunk(3, dim=-1)
+        # W_O = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_proj.weight")
+        #                  for j in range(num_layers)]).detach()
 
-    return ModelParameters(K_heads=K_heads, num_layers=num_layers, d_int=d_int)
+        # model_params.V_heads = V.reshape(num_layers, -1, hidden_dim)
+        # model_params.W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
+        # model_params.W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)
+        # model_params.W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
+        # model_params.W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
+        # model_params.emb = emb
+    return model_params
 
 
 def encode(token, tokenizer):
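
The full flag splits the extractor into a cheap path that touches only the MLP key matrices and an opt-in path that also materializes the attention, value, and embedding tensors. A usage sketch (shapes follow the reshapes above; gpt2 numbers shown for concreteness):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained('gpt2')
    params = extract_gpt_parameters(model)  # K_heads only; attention fields stay None
    # params.K_heads.shape -> (12, 3072, 768): (num_layers, d_int, hidden_dim)
    full_params = extract_gpt_parameters(model, full=True)
    # full_params.W_Q_heads.shape -> (12, 12, 768, 64): (layers, heads, hidden_dim, head_size)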
@@ -125,14 +175,29 @@ class ParamListStructureEnforcer(LogitsProcessor):
 
 
 # speaking probe
+def _preprocess_prompt(model_params, prompt):
+    K_heads = model_params.K_heads
+    prompt = re.sub(r'([^ ]|\A)(<neuron>|<param_\d+_\d+>)', lambda m: f'{m.group(1)} {m.group(2)}', prompt)
+    param_neuron_idxs = [(int(a), int(b)) for a, b in re.findall(r' <param_(\d+)_(\d+)>', prompt)]
+    param_neuron_tokens = [f' <param_{a}_{b}>' for a, b in param_neuron_idxs]
+    param_neurons = [deepcopy(K_heads[a, b]) for a, b in param_neuron_idxs]
+    return prompt, param_neuron_tokens, param_neurons
+
+
 def speaking_probe(model, model_params, tokenizer, prompt, *neurons,
                    num_generations=1, layer_range=None, bad_words_ids=[], output_neurons=False,
                    return_outputs=False, logits_processor=LogitsProcessorList([]), **kwargs):
     num_non_neuron_tokens = len(tokenizer)
     tokenizer_with_neurons = deepcopy(tokenizer)
+
+    # adding neurons to the tokenizer
+    neuron_tokens = [f" <neuron{i+1 if i > 0 else ''}>" for i in range(len(neurons))]
+    prompt, param_neuron_tokens, param_neurons = _preprocess_prompt(model_params, prompt)
+    neuron_tokens.extend(param_neuron_tokens)
+    neurons = neurons + tuple(param_neurons)
     has_extra_neurons = len(neurons) > 0
     if has_extra_neurons:
-        tokenizer_with_neurons.add_tokens([f" <neuron{i+1 if i > 0 else ''}>" for i in range(len(neurons))])
+        tokenizer_with_neurons.add_tokens(neuron_tokens)
         model.resize_token_embeddings(len(tokenizer_with_neurons))
         model.transformer.wte.weight.data[-len(neurons):] = torch.stack(neurons, dim=0)
 
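
_preprocess_prompt is where the commit's param tags are resolved: the re.sub pass makes sure every tag is preceded by a space (the tokens added to the tokenizer are defined with a leading space), and the findall pass maps each <param_layer_dim> occurrence to the matching row of K_heads. A worked example with made-up indices:

    prompt = "What does<param_2_13> mean?"
    # after re.sub:        "What does <param_2_13> mean?"
    # param_neuron_idxs:   [(2, 13)]
    # param_neuron_tokens: [' <param_2_13>']
    # param_neurons:       [a copy of model_params.K_heads[2, 13]]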
@@ -160,7 +225,8 @@ def speaking_probe(model, model_params, tokenizer, prompt, *neurons,
                              **kwargs)
 
     decoded = tokenizer_with_neurons.batch_decode(outputs.sequences, skip_special_tokens=True)
-
+
+    # TODO: add `finally` statement
     if has_extra_neurons:
         model.resize_token_embeddings(num_non_neuron_tokens)
         model.transformer.wte.weight.data = model.transformer.wte.weight.data[:num_non_neuron_tokens]
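
The new TODO flags a real hazard: if generation raises, the temporarily enlarged embedding matrix never gets shrunk back. One shape the fix could take, sketched under the assumption that the generate call moves into the try body (generate_kwargs is a stand-in name, not from the commit):

    try:
        outputs = model.generate(**generate_kwargs)
        decoded = tokenizer_with_neurons.batch_decode(outputs.sequences, skip_special_tokens=True)
    finally:
        # always restore the original vocabulary size and embedding rows
        if has_extra_neurons:
            model.resize_token_embeddings(num_non_neuron_tokens)
            model.transformer.wte.weight.data = model.transformer.wte.weight.data[:num_non_neuron_tokens]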
@@ -188,6 +254,7 @@ if __name__ == "__main__":
     parser.add_argument('--max_length', type=int, default=100)
     parser.add_argument('--max_new_tokens', type=int, default=None)
     parser.add_argument('--repetition_penalty', type=float, default=2.)
+    parser.add_argument('--temperature', type=float, default=1.)
 
     args = parser.parse_args()
     # TODO: first make them mutually exclusive
@@ -213,6 +280,7 @@ if __name__ == "__main__":
                              num_generations=args.num_generations,
                              repetition_penalty=args.repetition_penalty,
                              num_beams=args.num_beams, top_p=args.top_p, top_k=args.top_k,
+                             temperature=args.temperature,
                              min_length=args.min_length, do_sample=not args.no_sample,
                              max_length=args.max_length, max_new_tokens=args.max_new_tokens)
     for i in range(len(decoded)):
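
End to end, the commit makes a neuron addressable from the command line purely through the prompt text. A hypothetical invocation (--repetition_penalty and --temperature appear in the parser hunks above; --model_name, --prompt, and --num_generations are assumed flag names, and the tag indices are illustrative):

    python speaking_probes/generate.py \
        --model_name gpt2 \
        --prompt "What does <param_2_13> mean?" \
        --num_generations 3 --repetition_penalty 2. --temperature 0.8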