Update modeling_custom.py
Browse files
- modeling_custom.py (+1 -1)
modeling_custom.py
CHANGED
@@ -166,7 +166,7 @@ class LlamaForRewardModelWithGating(LlamaPreTrainedModel):
         with torch.autocast(device_type=rewards.device.type, dtype=torch.float32):
             # [B, num_quantiles, num_objectives]
             reward_quantiles = torch.mean(
-                gating_output.unsqueeze(-1).repeat(1, 1, self.  [rest of removed line truncated in the page capture]
+                gating_output.unsqueeze(-1).repeat(1, 1, self.num_quantiles) * rewards,
                 dim=1
             )
 