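"""Semantic search over business transaction maps.

A user query is embedded with an EmbeddingExtractor model and matched against a FAISS
vector database of transaction-map documents; optionally, an LLM first extracts the
business transaction from a free-form question before the search is run.
"""
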
import os
from typing import Dict

from business_transaction_map.common.constants import DEVICE, DO_NORMALIZATION, COLUMN_TYPE_DOC_MAP
from business_transaction_map.components.embedding_extraction import EmbeddingExtractor
from business_transaction_map.components.faiss_vector_database import FaissVectorDatabase
from llm.common import LlmApi
from prompts import BUSINESS_TRANSACTION_PROMPT


# Both paths can be overridden via environment variables; the defaults point to the
# bundled FAISS database pickle and an empty embedding-model path.
db_files_path = os.environ.get("GLOBAL_TRANSACTION_MAPS_DATA_PATH", "transaction_maps_search_data/csv/карта_проводок_new.pkl")
model_path = os.environ.get("GLOBAL_TRANSACTION_MAPS_MODEL_PATH", "")

class TransactionMapsSearch:
    """Semantic search over transaction maps backed by an embedding model and a FAISS index."""

    def __init__(self,
                 model_name_or_path: str = model_path,
                 device: str = DEVICE):
        self.device = device
        self.model = self.load_model(
            model_name_or_path=model_name_or_path,
            device=device,
        )
        self.database = FaissVectorDatabase(str(db_files_path))
        
    @staticmethod
    async def extract_business_transaction_with_llm(question: str, llm_api: LlmApi) -> str:
        """Ask the LLM to extract the business transaction description from the user question."""
        # Substitute the user question into the prompt template ('{{ЗАПРОС}}' is the query placeholder).
        prompt = BUSINESS_TRANSACTION_PROMPT.replace('{{ЗАПРОС}}', question)
        return await llm_api.predict(prompt)

    @staticmethod
    def load_model(model_name_or_path: str = None,
                   device: str = None) -> EmbeddingExtractor:
        """Instantiate the embedding model used for query vectorization."""
        return EmbeddingExtractor(model_name_or_path, device)

    @staticmethod
    def filter_answer(answer: Dict) -> Dict:
        """
        Filters the answers, keeping only the first hit per document name.

        Args:
            answer: Dictionary with the answers and additional metadata.

        Returns:
            Dictionary of unique answers.
        """
        seen_doc_names = []
        keys_to_delete = []
        for key in answer:
            if answer[key]["doc_name"] in seen_doc_names:
                keys_to_delete.append(key)
            else:
                seen_doc_names.append(answer[key]["doc_name"])
        for key in keys_to_delete:
            answer.pop(key)
        return answer


    async def search_transaction_map(self,
                                     query: str = None,
                                     find_transaction_maps_by_question: bool = False,
                                     k_neighbours: int = 15,
                                     llm_api: LlmApi = None):
        """
        Searches the FAISS database for transaction maps matching the query.

        Args:
            query: User query text.
            find_transaction_maps_by_question: If True, first ask the LLM to extract the
                business transaction from the question and search by that text instead.
            k_neighbours: Number of nearest neighbours to retrieve.
            llm_api: LLM client; required only when find_transaction_maps_by_question is True.

        Returns:
            Tuple of (final_docs, answers): a mapping of document file names to document
            types, and the list of matching answer entries.
        """
        if find_transaction_maps_by_question:
            query = await self.extract_business_transaction_with_llm(query, llm_api)
        cleaned_text = query.replace("\n", " ")
        # cleaned_text = 'query: ' + cleaned_text  # only for e5
        query_tokens = self.model.query_tokenization(cleaned_text)
        query_embeds = self.model.query_embed_extraction(query_tokens.to(self.device), DO_NORMALIZATION)[0]
        query_embeds = query_embeds[None, :]

        # Returns distances and indices; the index equals the row number in the dataframe.
        answer = self.database.search_transaction_map(query_embeds, k_neighbours)
        answer = self.filter_answer(answer)
        final_docs = {}
        answers = []
        for value in answer.values():
            doc_type = value[COLUMN_TYPE_DOC_MAP]
            # SAP document types are reported in upper case.
            final_docs[value["doc_name"] + '.xlsx'] = doc_type.upper() if 'sap' in doc_type else doc_type
            answers.append(value)
        return final_docs, answers
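

# Usage sketch (illustrative only): a minimal example of how this class might be driven,
# assuming GLOBAL_TRANSACTION_MAPS_MODEL_PATH and GLOBAL_TRANSACTION_MAPS_DATA_PATH point
# to a real embedding model and FAISS database. The query string below is a placeholder.
if __name__ == "__main__":
    import asyncio

    searcher = TransactionMapsSearch()
    # Search directly by the query text; no LLM step, so llm_api can be omitted.
    docs, hits = asyncio.run(
        searcher.search_transaction_map(
            query="text describing a business transaction",
            find_transaction_maps_by_question=False,
            k_neighbours=15,
        )
    )
    print(docs)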