Kazuki Yoda committed
Commit f5b8cbf · 1 Parent(s): 97f91f5

Implement the logic to get a predefined answer

Files changed (2):
  1. app.py +12 -0
  2. predefined.py +100 -0
app.py CHANGED
@@ -1,7 +1,11 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
 
+from predefined import get_predefined_answer_for_closest_predefined_question
+
 """
+Copied and modified from HuggingFace Gradio default ChatInterface space
+
 For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 """
 client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
@@ -15,6 +19,14 @@ def respond(
     temperature,
     top_p,
 ):
+    ### Modified from here ###
+    predefined_answer = get_predefined_answer_for_closest_predefined_question(message)
+
+    if predefined_answer:
+        yield predefined_answer
+        return
+    ### Modified until here ###
+
     messages = [{"role": "system", "content": system_message}]
 
     for val in history:
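
The hunks above show only the changed region of app.py. For context, here is a minimal sketch of the full respond() handler after this patch; the streaming body is an assumption reconstructed from the stock HuggingFace Gradio ChatInterface template that the new module docstring says app.py was copied from, not part of this commit:

    # Sketch only: the unchanged template portion below is assumed, not shown in the diff.
    def respond(message, history, system_message, max_tokens, temperature, top_p):
        # New short-circuit added by this commit: serve a predefined answer
        # when one of the predefined questions matches closely enough.
        predefined_answer = get_predefined_answer_for_closest_predefined_question(message)
        if predefined_answer:
            yield predefined_answer
            return

        # Assumed template code: rebuild the chat history and stream the LLM reply.
        messages = [{"role": "system", "content": system_message}]
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})

        response = ""
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content
            if token:  # guard against empty final chunks
                response += token
            yield response

Because respond() is a generator, yielding the predefined answer once and returning immediately renders it as a single complete message instead of a token stream.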
predefined.py ADDED
@@ -0,0 +1,100 @@
+"""This entire file was solely written by the applicant, Kazuki Yoda."""
+
+import json
+from typing import Optional
+
+# # For Debugging only
+# from scipy.spatial import distance_matrix
+# from sklearn.metrics.pairwise import cosine_similarity
+
+from huggingface_hub import InferenceClient
+
+zero_shot_classification_client = InferenceClient("facebook/bart-large-mnli")
+
+
+def load_predefined_questions_to_answers_as_dict(path="predefined.json"
+                                                 ) -> dict[str, str]:
+    """Load the predefined question-answer pairs as a dict of
+    key: question (str), value: answer (str)."""
+
+    with open(path) as file:
+        data = json.load(file)
+
+    if "questions" not in data:
+        raise ValueError("`questions` key is expected but missing.")
+
+    question_to_answer = dict()
+
+    for item in data.get("questions"):
+        question = item.get("question")
+        answer = item.get("answer")
+
+        # Skip the pair if either "question" or "answer" is missing
+        if question and answer:
+            question_to_answer[question] = answer
+
+    return question_to_answer
+
+
+def get_embeddings(texts: list[str]):
+    client = InferenceClient("efederici/sentence-bert-base")
+
+    return [client.feature_extraction(text) for text in texts]
+
+
+def get_predefined_answer_for_closest_predefined_question(
+    question: str,
+    cutoff=0.5,  # Minimum classification score to use the predefined answer
+) -> Optional[str]:
+
+    question_to_answer = load_predefined_questions_to_answers_as_dict()
+    labels = list(question_to_answer.keys())
+
+    zero_shot_classification_result = zero_shot_classification_client.zero_shot_classification(
+        text=question,
+        labels=labels,
+        multi_label=True,
+    )
+    max_score_result = max(zero_shot_classification_result,
+                           key=lambda x: x.score)
+
+    if max_score_result.score > cutoff:
+        closest_predefined_question = max_score_result.label
+        return question_to_answer[closest_predefined_question]
+    else:
+        # Fall back to the normal LLM response
+        return None
+
+
+if __name__ == "__main__":
+    """Run some print-debugging checks. Not executed from the Gradio app."""
+
+    question_to_answer = load_predefined_questions_to_answers_as_dict()
+    print(question_to_answer)
+
+    additional_questions = [
+        "What does EVA do?",
+        "How does PHIL work?",
+        "Thoughtful AI",
+        ### Irrelevant but confusing questions ###
+        "Who is the CEO of Thoughtful AI?",
+        "How much does Thoughtful AI pay for its ML engineers?",
+        "What's Evangelion (EVA)?"
+    ]
+    predefined_questions = list(question_to_answer.keys())
+    questions = predefined_questions + additional_questions
+
+    embeddings = get_embeddings(questions)
+
+    for embedding in embeddings:
+        print(embedding.shape)
+
+    # For DEBUG, check the embeddings
+    # print(distance_matrix(embeddings, embeddings[:len(predefined_questions)]))
+    # print(cosine_similarity(embeddings, embeddings[:len(predefined_questions)]))
+
+    for question in questions:
+        predefined_answer = get_predefined_answer_for_closest_predefined_question(question)
+        print(f"question: {question}")
+        print(f"predefined_answer: {predefined_answer}")
+        print()
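
Note that predefined.json itself is not included in this commit. Based on how load_predefined_questions_to_answers_as_dict() parses it, the file is presumably shaped as in the sketch below; the answer string is a hypothetical placeholder, not taken from the actual data:

    # Minimal sketch (assumption): the expected shape of predefined.json,
    # inferred from the loader above. The answer text is a placeholder.
    import json

    sample = {
        "questions": [
            {
                "question": "What does EVA do?",
                "answer": "<predefined answer text>",  # hypothetical, not from the commit
            },
        ]
    }

    with open("predefined.json", "w") as file:
        json.dump(sample, file, indent=2)

Any entry missing either key is silently skipped by the loader, so partially filled records in the file do not break the app.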