import re
from typing import Any, Dict, List, Optional, Union

import torch

from models.vss import VSS
from models.model import ModelForSTaRKQA
from stark_qa.tools.api import get_llm_output


def find_floating_number(text: str) -> List[float]:
    """
    Extract floating point numbers from the given text.

    Args:
        text (str): Input text from which to extract numbers.

    Returns:
        List[float]: List of extracted floating point numbers.
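
    Example:
        >>> find_floating_number("The numeric score of this paper is: 0.85")
        [0.85]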
    """
    pattern = r'0\.\d+|1\.0'
    matches = re.findall(pattern, text)
    return [round(float(match), 4) for match in matches if float(match) <= 1.1]


class LLMReranker(ModelForSTaRKQA):
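    """Reranker that rescores the top VSS candidates with point-wise LLM prompting."""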
    
    def __init__(self,
                 kb,
                 llm_model: str,
                 emb_model: str,
                 query_emb_dir: str,
                 candidates_emb_dir: str,
                 sim_weight: float = 0.1,
                 max_cnt: int = 3,
                 max_k: int = 100,
                 device: str = 'cuda'):
        """
        Initializes the LLMReranker model.

        Args:
            kb (SemiStruct): Knowledge base.
            llm_model (str): Name of the LLM model.
            emb_model (str): Embedding model name.
            query_emb_dir (str): Directory to query embeddings.
            candidates_emb_dir (str): Directory to candidate embeddings.
            sim_weight (float): Weight for similarity score.
            max_cnt (int): Maximum count for retrying LLM response.
            max_k (int): Maximum number of top candidates to consider.
        """
        super().__init__(kb)
        self.max_k = max_k
        self.emb_model = emb_model
        self.llm_model = llm_model
        self.sim_weight = sim_weight
        self.max_cnt = max_cnt

        self.query_emb_dir = query_emb_dir
        self.candidates_emb_dir = candidates_emb_dir
        self.parent_vss = VSS(kb, query_emb_dir, candidates_emb_dir, 
                              emb_model=emb_model, device=device)

    def forward(self,
                query: Union[str, List[str]],
                query_id: Optional[Union[int, List[int]]] = None,
                **kwargs: Any) -> Dict[int, float]:
        """
        Forward pass to compute predictions for the given query using LLM reranking.

        Args:
            query (Union[str, list]): Query string or a list of query strings.
            query_id (Union[int, list, None]): Query index (optional).

        Returns:
            pred_dict (Dict[int, float]): Mapping from candidate node id to reranked score.
        """
        initial_score_dict = self.parent_vss(query, query_id)
        node_ids = list(initial_score_dict.keys())
        node_scores = list(initial_score_dict.values())

        # Get the ids with top k highest scores
        top_k_idx = torch.topk(
            torch.FloatTensor(node_scores), 
            min(self.max_k, len(node_scores)), 
            dim=-1
        ).indices.view(-1).tolist()
        top_k_node_ids = [node_ids[i] for i in top_k_idx]
        cand_len = len(top_k_node_ids)

        pred_dict = {}
        for idx, node_id in enumerate(top_k_node_ids):
            node_type = self.skb.get_node_type_by_id(node_id)
            prompt = (
                f'You are a helpful assistant that examines whether a {node_type} '
                f'satisfies the requirements in a given query and assigns a score from 0.0 to 1.0. '
                f'If the {node_type} does not satisfy any requirement in the query, the score should be 0.0. '
                f'If there is explicit and strong evidence that the {node_type} '
                f'satisfies all aspects mentioned by the query, the score should be 1.0. '
                f'If only partial or weak evidence exists, the score should be between 0.0 and 1.0.\n'
                f'Here is the query:\n"{query}"\n'
                f'Here is the information about the {node_type}:\n' +
                self.skb.get_doc_info(node_id, add_rel=True) + '\n\n' +
                f'Please score the {node_type} based on how well it satisfies the query. '
                f'ONLY output the floating point score WITHOUT anything else. '
                f'Output: The numeric score of this {node_type} is: '
            )

            # Query the LLM up to max_cnt times until exactly one score can be parsed.
            success = False
            for _ in range(self.max_cnt):
                try:
                    answer = get_llm_output(
                        prompt, 
                        self.llm_model, 
                        max_tokens=5
                    )
                    answer = find_floating_number(answer)
                    if len(answer) == 1:
                        answer = answer[0]
                        success = True
                        break
                except Exception as e:
                    print(f'Error: {e}, retrying...')

            if success:
                llm_score = float(answer)
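                # Rank-based similarity bonus in (0, 1]: the top VSS candidate
                # gets 1.0, decreasing linearly with rank.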
                sim_score = (cand_len - idx) / cand_len
                score = llm_score + self.sim_weight * sim_score
                pred_dict[node_id] = score
            else:
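                # If the LLM never produced a parseable score for this candidate,
                # fall back to the raw VSS scores for the whole query.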
                return initial_score_dict
        return pred_dict
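

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original pipeline: the dataset
    # name, model names, and embedding directories below are illustrative
    # assumptions; query and candidate embeddings must already exist on disk.
    from stark_qa import load_skb

    skb = load_skb('amazon', download_processed=True)
    reranker = LLMReranker(
        kb=skb,
        llm_model='gpt-4-turbo',              # assumed to be accepted by get_llm_output
        emb_model='text-embedding-ada-002',
        query_emb_dir='emb/amazon/query',     # hypothetical paths
        candidates_emb_dir='emb/amazon/doc',
        max_k=20,
        device='cuda' if torch.cuda.is_available() else 'cpu',
    )
    pred_dict = reranker.forward('chew toys for aggressive chewers')
    top5 = sorted(pred_dict.items(), key=lambda kv: kv[1], reverse=True)[:5]
    print(top5)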