Chris4K committed on
Commit
ae60821
·
verified ·
1 Parent(s): 3a7547b

Update services/strategy.py

Browse files
Files changed (1) hide show
  1. services/strategy.py +18 -8
services/strategy.py CHANGED
@@ -68,14 +68,23 @@ class BestOfN(GenerationStrategy):
68
  # Tokenize the response for scoring with the PRM model
69
  response_inputs = generator.tokenizer(response, return_tensors="pt").to(generator.device)
70
 
71
- # Extract the necessary inputs for prm_model
72
- prm_input_ids = response_inputs["input_ids"] # Always present
73
- attention_mask = response_inputs["attention_mask"] # Optional, depending on your model
74
-
75
- # Pass only the required tensors to prm_model
76
- prm_output = generator.prm_model(input_ids=prm_input_ids, attention_mask=attention_mask)
77
-
78
- # Check the expected output structure for prm_model and use it accordingly
 
 
 
 
 
 
 
 
 
79
  score = prm_output.logits.mean().item() if hasattr(prm_output, 'logits') else 0.0
80
 
81
  # Append the response and its score
@@ -86,6 +95,7 @@ class BestOfN(GenerationStrategy):
86
 
87
 
88
 
 
89
  class BeamSearch(GenerationStrategy):
90
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
91
  input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
 
68
  # Tokenize the response for scoring with the PRM model
69
  response_inputs = generator.tokenizer(response, return_tensors="pt").to(generator.device)
70
 
71
+ # Pass the response to the PRM model based on its input requirements
72
+ try:
73
+ # Example 1: If PRM model accepts BatchEncoding
74
+ prm_output = generator.prm_model(response_inputs)
75
+
76
+ # Example 2: If PRM model expects only input_ids
77
+ # prm_output = generator.prm_model(response_inputs["input_ids"])
78
+
79
+ # Example 3: If PRM model expects raw text
80
+ # prm_output = generator.prm_model(response)
81
+
82
+ except Exception as e:
83
+ print(f"Error with PRM model: {e}")
84
+ score = 0.0
85
+ continue
86
+
87
+ # Calculate the score based on PRM output structure
88
  score = prm_output.logits.mean().item() if hasattr(prm_output, 'logits') else 0.0
89
 
90
  # Append the response and its score
 
95
 
96
 
97
 
98
+
99
  class BeamSearch(GenerationStrategy):
100
  def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
101
  input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)