guy-dar committed on
Commit
cea7b8e
·
1 Parent(s): defd911

fix mentioning param twice

Browse files
Files changed (1) hide show
  1. speaking_probes/generate.py +8 -8
speaking_probes/generate.py CHANGED
@@ -82,7 +82,7 @@ def extract_gpt_j_parameters(model, full=False):
82
  hidden_dim = model.config.n_embd
83
  head_size = hidden_dim // num_heads
84
 
85
- K = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.fc_in.weight").T
86
  for j in range(num_layers)]).detach()
87
 
88
  K_heads = K.reshape(num_layers, -1, hidden_dim)
@@ -93,20 +93,20 @@ def extract_gpt_j_parameters(model, full=False):
93
 
94
  if full:
95
  raise NotImplementedError
96
- # emb = model.get_output_embeddings().weight.data.T
97
- # V = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.c_proj.weight")
98
- # for j in range(num_layers)]).detach()
99
  # W_Q, W_K, W_V = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_attn.weight")
100
  # for j in range(num_layers)]).detach().chunk(3, dim=-1)
101
  # W_O = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_proj.weight")
102
  # for j in range(num_layers)]).detach()
103
 
104
- # model_params.V_heads = V.reshape(num_layers, -1, hidden_dim)
105
  # model_params.W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
106
  # model_params.W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)
107
  # model_params.W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
108
  # model_params.W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
109
- # model_params.emb = emb
110
  return model_params
111
 
112
 
@@ -177,8 +177,8 @@ class ParamListStructureEnforcer(LogitsProcessor):
177
  # speaking probe
178
  def _preprocess_prompt(model_params, prompt):
179
  K_heads = model_params.K_heads
180
- prompt = re.sub(r'([^ ]|\A)(<neuron>|<param_\d+_\d+>)', lambda m: f'{m.group(1)} {m.group(2)}', prompt)
181
- param_neuron_idxs = [(int(a), int(b)) for a, b in re.findall(r' <param_(\d+)_(\d+)>', prompt)]
182
  param_neuron_tokens = [f' <param_{a}_{b}>' for a, b in param_neuron_idxs]
183
  param_neurons = [deepcopy(K_heads[a, b]) for a, b in param_neuron_idxs]
184
  return prompt, param_neuron_tokens, param_neurons
 
82
  hidden_dim = model.config.n_embd
83
  head_size = hidden_dim // num_heads
84
 
85
+ K = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.fc_in.weight")
86
  for j in range(num_layers)]).detach()
87
 
88
  K_heads = K.reshape(num_layers, -1, hidden_dim)
 
93
 
94
  if full:
95
  raise NotImplementedError
96
+ emb = model.get_output_embeddings().weight.data.T
97
+ V = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.c_out.weight").T
98
+ for j in range(num_layers)]).detach()
99
  # W_Q, W_K, W_V = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_attn.weight")
100
  # for j in range(num_layers)]).detach().chunk(3, dim=-1)
101
  # W_O = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_proj.weight")
102
  # for j in range(num_layers)]).detach()
103
 
104
+ model_params.V_heads = V.reshape(num_layers, -1, hidden_dim)
105
  # model_params.W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
106
  # model_params.W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)
107
  # model_params.W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
108
  # model_params.W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
109
+ model_params.emb = emb
110
  return model_params
111
 
112
 
 
177
  # speaking probe
178
  def _preprocess_prompt(model_params, prompt):
179
  K_heads = model_params.K_heads
180
+ prompt = re.sub(r'([^ ]|\A)(<neuron\d*>|<param_\d+_\d+>)', lambda m: f'{m.group(1)} {m.group(2)}', prompt)
181
+ param_neuron_idxs = set([(int(a), int(b)) for a, b in re.findall(r' <param_(\d+)_(\d+)>', prompt)])
182
  param_neuron_tokens = [f' <param_{a}_{b}>' for a, b in param_neuron_idxs]
183
  param_neurons = [deepcopy(K_heads[a, b]) for a, b in param_neuron_idxs]
184
  return prompt, param_neuron_tokens, param_neurons