Aratako commited on
Commit
c56f80b
·
verified ·
1 Parent(s): 80452dd

add Like/Dislike collection

Browse files
Files changed (1) hide show
  1. app.py +55 -80
app.py CHANGED
@@ -1,16 +1,14 @@
1
  import gradio as gr
 
2
  from openai import OpenAI
3
  import os
4
- import json
5
- from datetime import datetime
6
- from zoneinfo import ZoneInfo
7
- import uuid
8
- from pathlib import Path
9
- from huggingface_hub import CommitScheduler
10
 
11
  openai_api_key = os.getenv('api_key')
12
  openai_api_base = os.getenv('url')
13
- model_name = "weblab-GENIAC/Tanuki-8x8B-dpo-v1.0"
 
 
14
  """
15
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
  """
@@ -20,59 +18,27 @@ client = OpenAI(
20
  base_url=openai_api_base,
21
  )
22
 
23
- # Define the file where to save the data. Use UUID to make sure not to overwrite existing data from a previous run.
24
- feedback_file = Path("user_feedback/") / f"data_{uuid.uuid4()}.json"
25
- feedback_folder = feedback_file.parent
26
-
27
- # Schedule regular uploads. Remote repo and local folder are created if they don't already exist.
28
- scheduler = CommitScheduler(
29
- repo_id="team-hatakeyama-phase2/8x8b-server-original-data", # Replace with your actual repo ID
30
- repo_type="dataset",
31
- folder_path=feedback_folder,
32
- path_in_repo="data",
33
- every=60, # Upload every 1 minutes
34
- )
35
 
36
- def save_or_update_conversation(conversation_id, message, response, message_index, liked=None):
37
- """
38
- Save or update conversation data in a JSON Lines file.
39
- If the entry already exists (same id and message_index), update the 'label' field.
40
- Otherwise, append a new entry.
41
- """
42
- with scheduler.lock:
43
- # Read existing data
44
- data = []
45
- if feedback_file.exists():
46
- with feedback_file.open("r") as f:
47
- data = [json.loads(line) for line in f if line.strip()]
48
-
49
- # Find if an entry with the same id and message_index exists
50
- entry_index = next((i for i, entry in enumerate(data) if entry['id'] == conversation_id and entry['message_index'] == message_index), None)
51
-
52
- if entry_index is not None:
53
- # Update existing entry
54
- data[entry_index]['label'] = liked
55
- else:
56
- # Append new entry
57
- data.append({
58
- "id": conversation_id,
59
- "timestamp": datetime.now(ZoneInfo("Asia/Tokyo")).isoformat(),
60
- "prompt": message,
61
- "completion": response,
62
- "message_index": message_index,
63
- "label": liked
64
- })
65
-
66
- # Write updated data back to file
67
- with feedback_file.open("w") as f:
68
- for entry in data:
69
- f.write(json.dumps(entry) + "\n")
70
 
71
 
72
  def respond(
73
  message,
74
  history: list[tuple[str, str]],
75
- conversation_id,
76
  system_message,
77
  max_tokens,
78
  temperature,
@@ -90,7 +56,8 @@ def respond(
90
  messages.append({"role": "user", "content": message})
91
 
92
  response = ""
93
- for chunk in client.chat.completions.create(
 
94
  model=model_name,
95
  messages=messages,
96
  max_tokens=max_tokens,
@@ -98,32 +65,23 @@ def respond(
98
  temperature=temperature,
99
  top_p=top_p,
100
  ):
101
- if chunk.choices[0].delta.content is not None:
102
- response += chunk.choices[0].delta.content
 
 
103
  yield response
104
 
105
- # Save conversation after the full response is generated
106
- message_index = len(history)
107
- save_or_update_conversation(conversation_id, message, response, message_index)
108
-
109
- def vote(data: gr.LikeData, history, conversation_id):
110
- """
111
- Update user feedback (like/dislike) in the local file.
112
- """
113
- message_index = data.index[0]
114
- liked = data.liked
115
- save_or_update_conversation(conversation_id, None, None, message_index, liked)
116
-
117
- def create_conversation_id():
118
- return str(uuid.uuid4())
119
-
120
  """
121
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
122
  """
123
 
124
  description = """
125
  ### [Tanuki-8x8B-dpo-v1.0](https://huggingface.co/weblab-GENIAC/Tanuki-8x8B-dpo-v1.0)との会話(期間限定での公開)
126
- - 人工知能開発の��め、原則として**このChatBotの入出力データは全て著作権フリー(CC0)で公開する**ため、ご注意ください。著作物、個人情報、機密情報、誹謗中傷などのデータを入力しないでください。
127
  - **上記の条件に同意する場合のみ**、以下のChatbotを利用してください。
128
  """
129
 
@@ -133,8 +91,25 @@ FOOTER = """### 注意
133
  - コンテクスト長が4096までなので、あまり会話が長くなると、エラーで停止します。ページを再読み込みしてください。
134
  - GPUサーバーが不安定なので、応答しないことがあるかもしれません。"""
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  def run():
137
- conversation_id = gr.State(create_conversation_id)
138
  chatbot = gr.Chatbot(
139
  elem_id="chatbot",
140
  scale=1,
@@ -144,7 +119,7 @@ def run():
144
  )
145
  with gr.Blocks(fill_height=True) as demo:
146
  gr.Markdown(HEADER)
147
- chat_interface = gr.ChatInterface(
148
  fn=respond,
149
  stop_btn="Stop Generation",
150
  cache_examples=False,
@@ -154,7 +129,6 @@ def run():
154
  label="Parameters", open=False, render=False
155
  ),
156
  additional_inputs=[
157
- conversation_id,
158
  gr.Textbox(value="以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。",
159
  label="System message(試験用: 変えると性能が低下する可能性があります。)",
160
  render=False,),
@@ -188,10 +162,11 @@ def run():
188
  ],
189
  analytics_enabled=False,
190
  )
191
- chatbot.like(vote, [chatbot, conversation_id], None)
192
  gr.Markdown(FOOTER)
193
- demo.queue(max_size=256, api_open=True)
194
- demo.launch(share=True, quiet=True)
 
195
 
196
  if __name__ == "__main__":
197
- run()
 
1
  import gradio as gr
2
+ # from huggingface_hub import InferenceClient
3
  from openai import OpenAI
4
  import os
5
+ import requests
 
 
 
 
 
6
 
7
  openai_api_key = os.getenv('api_key')
8
  openai_api_base = os.getenv('url')
9
+ db_url = os.getenv('db_url')
10
+ db_api_key = os.getenv('db_api_key')
11
+ model_name = "gpt-4o"
12
  """
13
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
14
  """
 
18
  base_url=openai_api_base,
19
  )
20
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ def save_conversation(history, system_message):
23
+ conversation_data = {
24
+ "conversation": history,
25
+ "index": (len(history) - 1, 1), # 最新の応答のインデックス
26
+ "liked": None, # 評価はnull(None)
27
+ "system_message": system_message,
28
+ }
29
+ headers = {
30
+ "X-API-Key": db_api_key
31
+ }
32
+ response = requests.post(db_url, json=conversation_data, headers=headers)
33
+ if response.status_code == 200:
34
+ print("Conversation saved successfully")
35
+ else:
36
+ print(f"Failed to save conversation: {response.status_code}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
 
39
  def respond(
40
  message,
41
  history: list[tuple[str, str]],
 
42
  system_message,
43
  max_tokens,
44
  temperature,
 
56
  messages.append({"role": "user", "content": message})
57
 
58
  response = ""
59
+
60
+ for new_response in client.chat.completions.create(
61
  model=model_name,
62
  messages=messages,
63
  max_tokens=max_tokens,
 
65
  temperature=temperature,
66
  top_p=top_p,
67
  ):
68
+ token = new_response.choices[0].delta.content
69
+
70
+ if token is not None:
71
+ response += (token)
72
  yield response
73
 
74
+ new_history = history + [(message, response)]
75
+ save_conversation(new_history, system_message)
76
+
77
+
 
 
 
 
 
 
 
 
 
 
 
78
  """
79
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
80
  """
81
 
82
  description = """
83
  ### [Tanuki-8x8B-dpo-v1.0](https://huggingface.co/weblab-GENIAC/Tanuki-8x8B-dpo-v1.0)との会話(期間限定での公開)
84
+ - 人工知能開発のため、原則として**このChatBotの入出力データは全て著作権フリー(CC0)で公開予定です**ので、ご注意ください。著作物、個人情報、機密情報、誹謗中傷などのデータを入力しないでください。
85
  - **上記の条件に同意する場合のみ**、以下のChatbotを利用してください。
86
  """
87
 
 
91
  - コンテクスト長が4096までなので、あまり会話が長くなると、エラーで停止します。ページを再読み込みしてください。
92
  - GPUサーバーが不安定なので、応答しないことがあるかもしれません。"""
93
 
94
+
95
+ def vote(data: gr.LikeData, history):
96
+ vote_data = {
97
+ "conversation": history,
98
+ "index": data.index,
99
+ "liked": data.liked,
100
+ "system_message": None,
101
+ }
102
+ headers = {
103
+ "X-API-Key": db_api_key # APIキーを設定
104
+ }
105
+ response = requests.post(db_url, json=vote_data, headers=headers)
106
+ if response.status_code == 200:
107
+ print("Vote recorded successfully")
108
+ else:
109
+ print(f"Failed to record vote: {response.status_code}")
110
+
111
+
112
  def run():
 
113
  chatbot = gr.Chatbot(
114
  elem_id="chatbot",
115
  scale=1,
 
119
  )
120
  with gr.Blocks(fill_height=True) as demo:
121
  gr.Markdown(HEADER)
122
+ gr.ChatInterface(
123
  fn=respond,
124
  stop_btn="Stop Generation",
125
  cache_examples=False,
 
129
  label="Parameters", open=False, render=False
130
  ),
131
  additional_inputs=[
 
132
  gr.Textbox(value="以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。",
133
  label="System message(試験用: 変えると性能が低下する可能性があります。)",
134
  render=False,),
 
162
  ],
163
  analytics_enabled=False,
164
  )
165
+ chatbot.like(vote, chatbot, None)
166
  gr.Markdown(FOOTER)
167
+ demo.queue(max_size=256, api_open=False)
168
+ demo.launch(share=False, quiet=True)
169
+
170
 
171
  if __name__ == "__main__":
172
+ run()