import os
import logging
import gradio as gr

from typing import List, Tuple
from huggingface_hub import InferenceClient
from langchain_openai import AzureChatOpenAI
from langchain.chains import LLMChain, SimpleSequentialChain
from langchain.prompts import PromptTemplate


# huggingface_key = os.getenv('HUGGINGFACE_KEY')
# print(huggingface_key)
# login(huggingface_key) # Huggingface api token

# Configure logging
logging.basicConfig(filename='factchecking.log', level=logging.DEBUG,
                    format='%(asctime)s - %(levelname)s - %(message)s')


class FactChecking:

    def __init__(self):
        # Azure-hosted chat model used by the LangChain fact-checking chains
        self.llm = AzureChatOpenAI(
            azure_deployment="GPT-3"
        )

        # Hugging Face Inference client for the Mixtral model that drafts the answers
        self.client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

    def format_prompt(self, question: str) -> str:
        """
        Formats the input question into a specific structure for text generation.

        Args:
            question (str): The user's question to be formatted.

        Returns:
            str: The formatted prompt including instructions and the question.
        """
        # Combine the instruction template with the user's question
        prompt = f"[INST] you are the ai assitant your task is answr for the user question[/INST]"
        prompt1 = f"[INST] {question} [/INST]"
        return prompt+prompt1

    def mixtral_response(self, prompt, temperature=0.9, max_new_tokens=5000, top_p=0.95, repetition_penalty=1.0):
        """
        Generates a response to the given prompt using text generation parameters.

        Args:
            prompt (str): The user's question.
            temperature (float): Controls randomness in response generation.
            max_new_tokens (int): The maximum number of tokens to generate.
            top_p (float): Nucleus sampling parameter controlling diversity.
            repetition_penalty (float): Penalty for repeating tokens.

        Returns:
            str: The generated response to the input prompt.
        """

        # Adjust temperature and top_p values within acceptable ranges
        temperature = float(temperature)
        if temperature < 1e-2:
            temperature = 1e-2
        top_p = float(top_p)

        generate_kwargs = dict(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=42,
        )
        # Stream the generation from the Hugging Face Inference API and collect the tokens
        formatted_prompt = self.format_prompt(prompt)
        stream = self.client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        output = ""

        for response in stream:
            output += response.token.text

        return output.replace("</s>", "")



    def find_different_sentences(self, chain_answer, llm_answer):
        """
        Maps the True/False verdicts in chain_answer onto the sentences of llm_answer.

        Args:
            chain_answer (str): Fact-checking output containing "sentence (True or False)" lines.
            llm_answer (str): The original model answer whose sentences should be tagged.

        Returns:
            List[Tuple[str, str]]: Sentences paired with a "factual" or "hallucinated" tag.
        """
        # Extract the parenthesised truth values from chain_answer. Entries may be
        # separated by single or double newlines, so skip blank or unparseable lines.
        truth_values = []
        for line in chain_answer.split('\n'):
            line = line.strip()
            if not line or ' (' not in line:
                continue
            truth_values.append(line.split(' (')[1][:-1])

        tags = []
        for tag in truth_values:
            if "True" in tag:
                tags.append("factual")
            else:
                tags.append("hallucinated")

        # Splitting llm_answer into sentences
        llm_sentences = llm_answer.split('. ')

        # Mapping the truth values to sentences in llm_answer
        tagged_sentences = []
        for sentence, truth_value in zip(llm_sentences, tags):
            # Extracting the sentence without any trailing truth-value annotation
            sentence_text = sentence.split(' (')[0]
            # Appending the sentence together with its tag
            tagged_sentences.append((sentence_text + ".", truth_value))

        return tagged_sentences



    def find_hallucinatted_sentence(self, question: str) -> Tuple[str, List[Tuple[str, str]], str]:
        """
        Finds hallucinated sentences in the response to a given question.

        Args:
            question (str): The input question.

        Returns:
            Tuple[str, List[Tuple[str, str]], str]: The original Mixtral response, a list of
            sentences tagged as "factual" or "hallucinated", and the fact-checking chain output.
        """
        try:

          # Generate initial response using contract generator
          mixtral_response = self.mixtral_response(question)
            
          template = """Given some text, extract a list of facts from the text.

          Format your output as a bulleted list.

          Text:
          {question}

          Facts:"""
          prompt_template = PromptTemplate(input_variables=["question"], template=template)
          question_chain = LLMChain(llm=self.llm, prompt=prompt_template)

          template = """You are an expert fact checker. You have been hired by a major news organization to fact check a very important story.

          Here is a bullet point list of facts:
          {statement}

          For each fact, determine whether it is true or false about the subject. If you are unable to determine whether the fact is true or false, output "Undetermined".
          If the fact is false, explain why."""
          prompt_template = PromptTemplate(input_variables=["statement"], template=template)
          assumptions_chain = LLMChain(llm=self.llm, prompt=prompt_template)
          extra_template = f" Original Summary:{mixtral_response} Using these checked assertions to write the original summary with true or false in sentence wised.          For each fact, determine whether it is true or false about the subject. If you are unable to determine whether the fact is true or false, output 'Undetermined'.***format: sentence (True or False) in braces.***"
          template = """Below are some assertions that have been fact checked and are labeled as true of false. If the answer is false, a suggestion is given for a correction.

          Checked Assertions:
          {assertions}
          """
          template += extra_template
          prompt_template = PromptTemplate(input_variables=["assertions"], template=template)
          answer_chain = LLMChain(llm=self.llm, prompt=prompt_template)
          overall_chain = SimpleSequentialChain(chains=[question_chain, assumptions_chain, answer_chain], verbose=True)

          answer = overall_chain.run(mixtral_response)

          # Tag each sentence of the original response using the fact-checking result
          prediction_list = self.find_different_sentences(answer, mixtral_response)

          # Return the original response, the tagged sentences, and the fact-checking output
          return mixtral_response, prediction_list, answer

        except Exception as e:
            print(f"Error occurred in find_hallucinatted_sentence: {e}")
            # Return empty values for all three outputs expected by the Gradio callback
            return "", [], ""
            
    def interface(self):
      css=""".gradio-container {background: rgb(157,228,255);
        background: radial-gradient(circle, rgba(157,228,255,1) 0%, rgba(18,115,106,1) 100%);}"""

      with gr.Blocks(css=css) as demo:
        gr.HTML("""
            <center><h1 style="color:#fff">Detect Hallucination</h1></center>""")
        with gr.Row():
          question = gr.Textbox(label="Question")
        with gr.Row():
          button = gr.Button(value="Submit")
        with gr.Row():
          mixtral_response = gr.Textbox(label="LLM answer")
        with gr.Row():
          fact_checking_result = gr.Textbox(label="Fact-checking result")
        with gr.Row():
          highlighted_prediction = gr.HighlightedText(
                                  label="Sentence Hallucination detection",
                                  combine_adjacent=True,
                                  color_map={"hallucinated": "red", "factual": "green"},
                                  show_legend=True)
        button.click(self.find_hallucinatted_sentence, question, [mixtral_response, highlighted_prediction, fact_checking_result])
      demo.launch()


hallucination_detection = FactChecking()
hallucination_detection.interface()
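

# A minimal usage sketch (hypothetical): calling find_hallucinatted_sentence directly,
# without the Gradio UI. It returns the raw Mixtral answer, the per-sentence
# "factual"/"hallucinated" tags, and the fact-checking chain output. The question text
# below is only an example. Note that interface() above blocks until the Gradio app is
# closed, so this would typically be run instead of, not after, launching the interface.
#
# checker = FactChecking()
# answer, tagged_sentences, verdicts = checker.find_hallucinatted_sentence(
#     "Who wrote the novel Moby-Dick?"
# )
# for sentence, tag in tagged_sentences:
#     print(f"[{tag}] {sentence}")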