import os

import requests
from huggingface_hub import InferenceClient
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import CTransformers
from langchain_core.vectorstores import VectorStoreRetriever


class LLMModel:
    """Retrieval-augmented question answering around Llama-2, served either
    locally via CTransformers (quantized GGUF weights) or through the
    Hugging Face Inference API."""

    base_model = "TheBloke/Llama-2-7B-GGUF"
    specific_model = "llama-2-7b.Q4_K_M.gguf"
    token_model = "meta-llama/Llama-2-7b-hf"
    llm_config = {'context_length': 2048, 'max_new_tokens': 1024, 'temperature': 0.3, 'top_p': 1.0}

    question_answer_system_prompt = """You are a helpful question-answering assistant. Given the following context and a question, provide a set of potential questions and answers.
        Keep answers brief and well-structured. Do not give one-word answers."""
    final_assistant_system_prompt = """You are a helpful assistant. Given the following list of relevant questions and answers, generate an answer based on this list only.
        Keep answers brief and well-structured. Do not give one-word answers.
        If the answer is not found in the list, kindly state "I don't know." Don't try to make up an answer."""
    template = """<s>[INST] <<SYS>>
        You are a question-answering assistant. Given the following context and a question, generate an answer based on this context only.
        Keep answers brief and well-structured. Do not give one-word answers.
        If the answer is not found in the context, kindly state "I don't know." Don't try to make up an answer.
        <</SYS>>

        Context: {context}

        Question: Give me a step-by-step explanation of {question}[/INST]
        Answer:"""
    qa_chain_prompt = PromptTemplate.from_template(template)
    retriever = None

    # Credentials and the inference endpoint URL are read from the environment.
    hf_token = os.getenv('HF_TOKEN')
    api_url = os.getenv('API_URL')
    headers = {"Authorization": f"Bearer {hf_token}"}
    client = InferenceClient(api_url, token=hf_token)

    # Local-inference alternative: load the quantized GGUF weights with CTransformers.
    # llm = CTransformers(model=base_model, model_file=specific_model, config=llm_config, hf=True)
    llm = None

    def __init__(self, retriever: VectorStoreRetriever):
        self.retriever = retriever

    def create_qa_chain(self):
        # Build a "stuff" RetrievalQA chain over the retriever. Requires self.llm
        # to be set, e.g. by enabling the CTransformers loader above.
        return RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.qa_chain_prompt},
        )
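
    # A minimal usage sketch for the chain; the query text below is
    # illustrative, not part of the original module:
    #
    #   model = LLMModel(retriever)
    #   qa_chain = model.create_qa_chain()
    #   result = qa_chain({"query": "How does document ingestion work?"})
    #   print(result["result"])
    #   print(result["source_documents"])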

    def format_retrieved_docs(self, docs):
        # Render each retrieved chunk as "Document: <source>\nContent: ...",
        # skipping chunks that carry no source metadata.
        all_docs = []
        for doc in docs:
            if "source" in doc.metadata:
                all_docs.append(f"Document: {doc.metadata['source']}\nContent: {doc.page_content}\n\n")
        return all_docs

    def format_query(self, question, context, system_prompt):
        # Wrap the system prompt, context, and question in Llama-2 [INST] format.
        prompt = f"""[INST] {system_prompt}

        Context: {context}

        Question: Give me a step-by-step explanation of {question}[/INST]"""
        return prompt

    def format_question(self, question):
        # Retrieve relevant chunks and join them into a single context string;
        # passing the raw list would interpolate its Python repr into the prompt.
        relevant_docs = self.retriever.get_relevant_documents(question)
        formatted_docs = "".join(self.format_retrieved_docs(relevant_docs))
        return self.format_query(question, formatted_docs, self.final_assistant_system_prompt)

    def get_potential_question_answer(self, document_chunk: str):
        # Ask the model to propose Q&A pairs for a single document chunk.
        prompt = self.format_query("potential questions and answers.", document_chunk, self.question_answer_system_prompt)
        return self.client.text_generation(prompt, max_new_tokens=512, temperature=0.4)

    def answer_question_inference_text_gen(self, question):
        # Answer via the Inference API's free-form text-generation task.
        prompt = self.format_question(question)
        return self.client.text_generation(prompt, max_new_tokens=512, temperature=0.4)

    def answer_question_inference(self, question):
        # Answer via the Inference API's extractive question-answering task.
        relevant_docs = self.retriever.get_relevant_documents(question)
        formatted_docs = "".join(self.format_retrieved_docs(relevant_docs))
        if not formatted_docs:
            return "No documents uploaded. Please try uploading a document on the left side."
        print(formatted_docs)  # Debug: show the context sent to the endpoint.
        return self.client.question_answering(question=question, context=formatted_docs)

    def answer_question_api(self, question):
        # Stream the generated answer from the raw inference endpoint.
        formatted_prompt = self.format_question(question)
        resp = requests.post(self.api_url, headers=self.headers, json={"inputs": formatted_prompt}, stream=True)
        resp.raise_for_status()
        # chunk_size=None yields chunks as they arrive instead of byte-by-byte.
        for chunk in resp.iter_content(chunk_size=None):
            yield chunk
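

if __name__ == "__main__":
    # Minimal end-to-end sketch. The embedding model, the FAISS index, and the
    # sample text below are assumptions for illustration; they are not part of
    # the original module, which receives its retriever from the caller.
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS

    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    store = FAISS.from_texts(
        ["LangChain connects vector-store retrievers to large language models."],
        embeddings,
    )
    model = LLMModel(store.as_retriever())
    # Requires HF_TOKEN and API_URL to be set in the environment.
    print(model.answer_question_inference("What does LangChain connect?"))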