import re
from typing import Any, Dict, List, Optional, Union

import torch

from models.model import ModelForSTaRKQA
from models.vss import VSS
from stark_qa.tools.api import get_llm_output


def find_floating_number(text: str) -> List[float]:
    """
    Extract floating point numbers between 0.0 and 1.0 from the given text.

    Args:
        text (str): Input text from which to extract numbers.

    Returns:
        List[float]: List of extracted floating point numbers.
    """
    pattern = r'0\.\d+|1\.0'
    matches = re.findall(pattern, text)
    return [round(float(match), 4) for match in matches if float(match) <= 1.1]
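

# A quick sanity check of the extractor; these example strings are illustrative,
# not real model output:
#
#   find_floating_number("The numeric score of this product is: 0.85")  # -> [0.85]
#   find_floating_number("Scores: 0.3 and 1.0; 2.5 is out of range")    # -> [0.3, 1.0]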


class LLMReranker(ModelForSTaRKQA):

    def __init__(self,
                 kb,
                 llm_model: str,
                 emb_model: str,
                 query_emb_dir: str,
                 candidates_emb_dir: str,
                 sim_weight: float = 0.1,
                 max_cnt: int = 3,
                 max_k: int = 100,
                 device: str = 'cuda'):
        """
        Initializes the LLMReranker model.

        Args:
            kb (SemiStruct): Knowledge base.
            llm_model (str): Name of the LLM used for reranking.
            emb_model (str): Name of the embedding model.
            query_emb_dir (str): Directory containing the query embeddings.
            candidates_emb_dir (str): Directory containing the candidate embeddings.
            sim_weight (float): Weight of the similarity score in the final score.
            max_cnt (int): Maximum number of retries for the LLM response.
            max_k (int): Maximum number of top candidates to rerank.
            device (str): Device on which to run the first-stage VSS retriever.
        """
        super().__init__(kb)
        self.max_k = max_k
        self.emb_model = emb_model
        self.llm_model = llm_model
        self.sim_weight = sim_weight
        self.max_cnt = max_cnt
        self.query_emb_dir = query_emb_dir
        self.candidates_emb_dir = candidates_emb_dir
        # First-stage retriever: vector similarity search over all candidates.
        self.parent_vss = VSS(kb, query_emb_dir, candidates_emb_dir,
                              emb_model=emb_model, device=device)

    def forward(self,
                query: Union[str, List[str]],
                query_id: Optional[Union[int, List[int]]] = None,
                **kwargs: Any) -> Dict[int, float]:
        """
        Forward pass to compute predictions for the given query using LLM reranking.

        Args:
            query (Union[str, list]): Query string or a list of query strings.
            query_id (Union[int, list, None]): Query index (optional).

        Returns:
            pred_dict (dict): A dictionary mapping candidate ids to predicted scores.
        """
        initial_score_dict = self.parent_vss(query, query_id)
        node_ids = list(initial_score_dict.keys())
        node_scores = list(initial_score_dict.values())

        # Get the ids with the top k highest similarity scores.
        top_k_idx = torch.topk(
            torch.FloatTensor(node_scores),
            min(self.max_k, len(node_scores)),
            dim=-1
        ).indices.view(-1).tolist()
        top_k_node_ids = [node_ids[i] for i in top_k_idx]
        cand_len = len(top_k_node_ids)

        pred_dict = {}
        for idx, node_id in enumerate(top_k_node_ids):
            node_type = self.skb.get_node_type_by_id(node_id)
            prompt = (
                f'You are a helpful assistant that examines if a {node_type} '
                f'satisfies the requirements in a given query and assigns a score from 0.0 to 1.0. '
                f'If the {node_type} does not satisfy any requirement in the query, the score should be 0.0. '
                f'If there exists explicit and strong evidence supporting that the {node_type} '
                f'satisfies all aspects mentioned by the query, the score should be 1.0. If only partial or weak '
                f'evidence exists, the score should be between 0.0 and 1.0.\n'
                f'Here is the query:\n\"{query}\"\n'
                f'Here is the information about the {node_type}:\n' +
                self.skb.get_doc_info(node_id, add_rel=True) + '\n\n' +
                f'Please score the {node_type} based on how well it satisfies the query. '
                f'ONLY output the floating point score WITHOUT anything else. '
                f'Output: The numeric score of this {node_type} is: '
            )

            success = False
            # Retry up to max_cnt times in case of API errors or malformed output.
            for _ in range(self.max_cnt):
                try:
                    answer = get_llm_output(
                        prompt,
                        self.llm_model,
                        max_tokens=5
                    )
                    answer = find_floating_number(answer)
                    if len(answer) == 1:
                        answer = answer[0]
                        success = True
                        break
                except Exception as e:
                    print(f'Error: {e}, retrying...')

            if success:
                llm_score = float(answer)
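                # A worked example of the score fusion, assuming the default
                # sim_weight=0.1: the top-ranked of 100 candidates (idx=0) has
                # sim_score=1.0, so an LLM score of 0.8 becomes 0.8 + 0.1 * 1.0 = 0.9.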
                sim_score = (cand_len - idx) / cand_len
                score = llm_score + self.sim_weight * sim_score
                pred_dict[node_id] = score
            else:
                # Fall back to the first-stage VSS scores if the LLM never
                # returned a usable score for this candidate.
                return initial_score_dict
        return pred_dict
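

if __name__ == '__main__':
    # A minimal usage sketch, not part of the module proper. The dataset name,
    # model names, query, and embedding directories below are illustrative
    # assumptions; substitute your own. `load_skb` is assumed to be importable
    # from the stark_qa package.
    from stark_qa import load_skb

    skb = load_skb('amazon')
    reranker = LLMReranker(
        kb=skb,
        llm_model='gpt-4o',
        emb_model='text-embedding-ada-002',
        query_emb_dir='emb/amazon/query',
        candidates_emb_dir='emb/amazon/doc',
    )
    pred_dict = reranker.forward('dental chew toys for puppies', query_id=0)
    # Print the five highest-scoring candidate ids.
    print(sorted(pred_dict.items(), key=lambda kv: kv[1], reverse=True)[:5])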