Chris4K committed · verified
Commit f9deaa6 · 1 Parent(s): ae60821

Update services/strategy.py

Files changed (1)

1. services/strategy.py +18 -7
services/strategy.py CHANGED

@@ -65,13 +65,24 @@ class BestOfN(GenerationStrategy):
             output = generator.model.generate(input_ids, **model_kwargs)
             response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
 
+
+            # Simple inference example
+            prm_output = llm(
+                "<|system|>\n{system_message}</s>\n<|user|>\n{response}</s>\n<|assistant|>", # Prompt
+                max_tokens=512, # Generate up to 512 tokens
+                stop=["</s>"], # Example stop token - not necessarily correct for this specific model! Please check before using.
+                echo=True # Whether to echo the prompt
+            )
+
+
+
             # Tokenize the response for scoring with the PRM model
-            response_inputs = generator.tokenizer(response, return_tensors="pt").to(generator.device)
+            #response_inputs = generator.tokenizer(response, return_tensors="pt").to(generator.device)
 
             # Pass the response to the PRM model based on its input requirements
-            try:
+            #try:
             # Example 1: If PRM model accepts BatchEncoding
-                prm_output = generator.prm_model(response_inputs)
+            # prm_output = generator.prm_model(response_inputs)
 
             # Example 2: If PRM model expects only input_ids
             # prm_output = generator.prm_model(response_inputs["input_ids"])
@@ -79,10 +90,10 @@ class BestOfN(GenerationStrategy):
             # Example 3: If PRM model expects raw text
             # prm_output = generator.prm_model(response)
 
-            except Exception as e:
-                print(f"Error with PRM model: {e}")
-                score = 0.0
-                continue
+            # except Exception as e:
+            #     print(f"Error with PRM model: {e}")
+            #     score = 0.0
+            #     continue
 
             # Calculate the score based on PRM output structure
             score = prm_output.logits.mean().item() if hasattr(prm_output, 'logits') else 0.0
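For context on the newly added call: the diff never shows where llm is constructed, but the call shape (a raw prompt string plus max_tokens, stop, and echo keyword arguments, with a zephyr-style chat template) matches llama-cpp-python's Llama.__call__ and its README example. Note also that the committed prompt is not an f-string, so {system_message} and {response} would be sent to the model literally. A minimal sketch, assuming that API (the model path is a hypothetical placeholder):

# A minimal sketch, assuming llm is a llama-cpp-python Llama instance;
# "./models/prm-model.gguf" is a hypothetical placeholder path.
from llama_cpp import Llama

llm = Llama(model_path="./models/prm-model.gguf")

prm_output = llm(
    "<|system|>\nYou are a judge.</s>\n<|user|>\nSome response to score.</s>\n<|assistant|>",
    max_tokens=512,
    stop=["</s>"],
    echo=True,
)

# Llama.__call__ returns a plain completion dict, not an object with a
# .logits attribute, so the later hasattr(prm_output, 'logits') check
# would fall through to score = 0.0 for this path.
print(prm_output["choices"][0]["text"])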
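The now commented-out path (Examples 1-3) scores the decoded response with the PRM directly. A minimal sketch of Example 1, assuming the PRM is a Hugging Face sequence-classification model ("prm-model-id" is a hypothetical placeholder); one caveat: a BatchEncoding is conventionally unpacked with **, whereas the diff's Example 1 passes it positionally, which would bind it to input_ids:

# A minimal sketch, assuming a transformers sequence-classification PRM;
# "prm-model-id" is a hypothetical placeholder.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("prm-model-id")
prm_model = AutoModelForSequenceClassification.from_pretrained("prm-model-id")

response = "Step 1: ... Step 2: ..."
response_inputs = tokenizer(response, return_tensors="pt")  # a BatchEncoding

with torch.no_grad():
    prm_output = prm_model(**response_inputs)  # unpack the BatchEncoding into keyword arguments

# Same scoring rule as the diff: mean over the output logits
score = prm_output.logits.mean().item() if hasattr(prm_output, "logits") else 0.0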
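The continue in the removed except-block implies this hunk sits inside a per-candidate loop. A minimal sketch of such a Best-of-N loop under that assumption (the function and variable names here are illustrative, not from the source):

# A minimal sketch of a Best-of-N loop; names like best_of_n and n_samples
# are assumptions, not taken from the repository.
def best_of_n(generator, input_ids, n_samples=4, **model_kwargs):
    best_score, best_response = float("-inf"), None
    for _ in range(n_samples):
        # Sample one candidate and decode it
        output = generator.model.generate(input_ids, do_sample=True, **model_kwargs)
        response = generator.tokenizer.decode(output[0], skip_special_tokens=True)
        try:
            # Score the candidate with the PRM (Example 1 style)
            inputs = generator.tokenizer(response, return_tensors="pt").to(generator.device)
            prm_output = generator.prm_model(**inputs)
        except Exception as e:
            print(f"Error with PRM model: {e}")
            continue  # skip unscoreable candidates, as the original except-block did
        score = prm_output.logits.mean().item() if hasattr(prm_output, "logits") else 0.0
        if score > best_score:
            best_score, best_response = score, response
    return best_response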