fix mentioning param twice
speaking_probes/generate.py CHANGED
@@ -82,7 +82,7 @@ def extract_gpt_j_parameters(model, full=False):
     hidden_dim = model.config.n_embd
     head_size = hidden_dim // num_heads
 
-    K = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.fc_in.weight")
+    K = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.fc_in.weight")
                    for j in range(num_layers)]).detach()
 
     K_heads = K.reshape(num_layers, -1, hidden_dim)
@@ -93,20 +93,20 @@ def extract_gpt_j_parameters(model, full=False):
 
     if full:
         raise NotImplementedError
-
-
-
+    emb = model.get_output_embeddings().weight.data.T
+    V = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.c_out.weight").T
+                   for j in range(num_layers)]).detach()
     # W_Q, W_K, W_V = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_attn.weight")
     #                            for j in range(num_layers)]).detach().chunk(3, dim=-1)
     # W_O = torch.cat([model.get_parameter(f"transformer.h.{j}.attn.c_proj.weight")
     #                  for j in range(num_layers)]).detach()
 
-
+    model_params.V_heads = V.reshape(num_layers, -1, hidden_dim)
     # model_params.W_V_heads = W_V.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
     # model_params.W_O_heads = W_O.reshape(num_layers, num_heads, head_size, hidden_dim)
     # model_params.W_Q_heads = W_Q.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
     # model_params.W_K_heads = W_K.reshape(num_layers, hidden_dim, num_heads, head_size).permute(0, 2, 1, 3)
-
+    model_params.emb = emb
     return model_params
 
 
@@ -177,8 +177,8 @@ class ParamListStructureEnforcer(LogitsProcessor):
 # speaking probe
 def _preprocess_prompt(model_params, prompt):
     K_heads = model_params.K_heads
-    prompt = re.sub(r'([^ ]|\A)(<neuron
-    param_neuron_idxs = [(int(a), int(b)) for a, b in re.findall(r' <param_(\d+)_(\d+)>', prompt)]
+    prompt = re.sub(r'([^ ]|\A)(<neuron\d*>|<param_\d+_\d+>)', lambda m: f'{m.group(1)} {m.group(2)}', prompt)
+    param_neuron_idxs = set([(int(a), int(b)) for a, b in re.findall(r' <param_(\d+)_(\d+)>', prompt)])
     param_neuron_tokens = [f' <param_{a}_{b}>' for a, b in param_neuron_idxs]
     param_neurons = [deepcopy(K_heads[a, b]) for a, b in param_neuron_idxs]
     return prompt, param_neuron_tokens, param_neurons
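
The change the commit message refers to is in _preprocess_prompt: the extracted (layer, neuron) index pairs are now wrapped in set(...), so a prompt that mentions the same <param_i_j> token more than once only registers that parameter neuron once. Below is a minimal, self-contained sketch of the before/after behaviour; the regexes are copied from the diff, while the toy K_heads tensor and the example prompt are assumptions made purely for illustration.

import re
import torch
from copy import deepcopy

# Toy stand-in for model_params.K_heads: 2 layers x 4 MLP neurons x hidden size 3.
K_heads = torch.arange(24, dtype=torch.float).reshape(2, 4, 3)

prompt = "what do<param_0_1> and <param_0_1> have in common?"

# Same normalisation as in the diff: make sure every special token is preceded by a space.
prompt = re.sub(r'([^ ]|\A)(<neuron\d*>|<param_\d+_\d+>)',
                lambda m: f'{m.group(1)} {m.group(2)}', prompt)

# Before the fix: a list keeps the duplicate mention, so the same neuron is collected twice.
before = [(int(a), int(b)) for a, b in re.findall(r' <param_(\d+)_(\d+)>', prompt)]
# After the fix: a set registers each mentioned parameter neuron exactly once.
after = set([(int(a), int(b)) for a, b in re.findall(r' <param_(\d+)_(\d+)>', prompt)])

print(before)  # [(0, 1), (0, 1)]
print(after)   # {(0, 1)}

# The downstream lists are then built from the de-duplicated indices, as in the unchanged lines that follow.
param_neuron_tokens = [f' <param_{a}_{b}>' for a, b in after]
param_neurons = [deepcopy(K_heads[a, b]) for a, b in after]
print(param_neuron_tokens)  # [' <param_0_1>']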
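
For context on the other additions (they are not part of the de-duplication fix itself): emb is the transposed output-embedding matrix and V_heads collects the second-layer MLP ("value") vectors for every layer, alongside the existing K_heads. A plausible use, offered only as an assumption about how the rest of generate.py reads these fields, is projecting a value vector onto the vocabulary to see which tokens it promotes. Note that the sketch below uses GPT-J's actual parameter name mlp.fc_out.weight, whereas the diff writes mlp.c_out.weight.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical usage sketch, not taken from the repository.
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6b")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")

num_layers = model.config.n_layer
hidden_dim = model.config.n_embd

# Transposed output embedding: [hidden_dim, vocab_size].
emb = model.get_output_embeddings().weight.data.T

# MLP "value" vectors, one per neuron per layer (GPT-J names this parameter fc_out).
V = torch.cat([model.get_parameter(f"transformer.h.{j}.mlp.fc_out.weight").T
               for j in range(num_layers)]).detach()
V_heads = V.reshape(num_layers, -1, hidden_dim)  # [layers, neurons, hidden_dim]

# Project one neuron's value vector onto the vocabulary and inspect its top tokens.
layer, neuron = 10, 42  # arbitrary example indices
logits = V_heads[layer, neuron] @ emb
top_ids = logits.topk(10).indices.tolist()
print(tokenizer.convert_ids_to_tokens(top_ids))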