Spaces:
Running
on
Zero
Running
on
Zero
tori29umai
commited on
Commit
•
20e3524
1
Parent(s):
160851b
Upload 5 files
Browse files- app.py +240 -0
- custom.html +12 -0
- requirements.txt +2 -0
- test_prompt.jinja2 +22 -0
- utils/dl_utils.py +19 -0
app.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from jinja2 import Template
|
3 |
+
from llama_cpp import Llama
|
4 |
+
import os
|
5 |
+
import configparser
|
6 |
+
from utils.dl_utils import dl_guff_model
|
7 |
+
|
8 |
+
# モデルディレクトリが存在しない場合は作成
|
9 |
+
if not os.path.exists("models"):
|
10 |
+
os.makedirs("models")
|
11 |
+
|
12 |
+
# 使用するモデルのファイル名を指定
|
13 |
+
model_filename = "Llama-3.1-70B-EZO-1.1-it-Q4_K_M.gguf"
|
14 |
+
model_path = os.path.join("models", model_filename)
|
15 |
+
|
16 |
+
# モデルファイルが存在しない場合はダウンロード
|
17 |
+
if not os.path.exists(model_path):
|
18 |
+
dl_guff_model("models", f"https://huggingface.co/mmnga/Llama-3.1-70B-EZO-1.1-it-gguf/resolve/main/{model_filename}")
|
19 |
+
|
20 |
+
# 設定をINIファイルに保存する関数
|
21 |
+
def save_settings_to_ini(settings, filename='character_settings.ini'):
|
22 |
+
config = configparser.ConfigParser()
|
23 |
+
config['Settings'] = {
|
24 |
+
'name': settings['name'],
|
25 |
+
'gender': settings['gender'],
|
26 |
+
'situation': '\n'.join(settings['situation']),
|
27 |
+
'orders': '\n'.join(settings['orders']),
|
28 |
+
'dirty_talk_list': '\n'.join(settings['dirty_talk_list']),
|
29 |
+
'example_quotes': '\n'.join(settings['example_quotes'])
|
30 |
+
}
|
31 |
+
with open(filename, 'w', encoding='utf-8') as configfile:
|
32 |
+
config.write(configfile)
|
33 |
+
|
34 |
+
# INIファイルから設定を読み込む関数
|
35 |
+
def load_settings_from_ini(filename='character_settings.ini'):
|
36 |
+
if not os.path.exists(filename):
|
37 |
+
return None
|
38 |
+
|
39 |
+
config = configparser.ConfigParser()
|
40 |
+
config.read(filename, encoding='utf-8')
|
41 |
+
|
42 |
+
if 'Settings' not in config:
|
43 |
+
return None
|
44 |
+
|
45 |
+
try:
|
46 |
+
settings = {
|
47 |
+
'name': config['Settings']['name'],
|
48 |
+
'gender': config['Settings']['gender'],
|
49 |
+
'situation': config['Settings']['situation'].split('\n'),
|
50 |
+
'orders': config['Settings']['orders'].split('\n'),
|
51 |
+
'dirty_talk_list': config['Settings']['dirty_talk_list'].split('\n'),
|
52 |
+
'example_quotes': config['Settings']['example_quotes'].split('\n')
|
53 |
+
}
|
54 |
+
return settings
|
55 |
+
except KeyError:
|
56 |
+
return None
|
57 |
+
|
58 |
+
# LlamaCppのラッパークラス
|
59 |
+
class LlamaCppAdapter:
|
60 |
+
def __init__(self, model_path, n_ctx=4096):
|
61 |
+
print(f"モデルの初期化: {model_path}")
|
62 |
+
self.llama = Llama(model_path=model_path, n_ctx=n_ctx, n_gpu_layers=-1)
|
63 |
+
|
64 |
+
def generate(self, prompt, max_new_tokens=4096, temperature=0.5, top_p=0.7, top_k=80, stop=["<END>"]):
|
65 |
+
return self._generate(prompt, max_new_tokens, temperature, top_p, top_k, stop)
|
66 |
+
|
67 |
+
def _generate(self, prompt: str, max_new_tokens: int, temperature: float, top_p: float, top_k: int, stop: list):
|
68 |
+
return self.llama(
|
69 |
+
prompt,
|
70 |
+
temperature=temperature,
|
71 |
+
max_tokens=max_new_tokens,
|
72 |
+
top_p=top_p,
|
73 |
+
top_k=top_k,
|
74 |
+
stop=stop,
|
75 |
+
repeat_penalty=1.2,
|
76 |
+
)
|
77 |
+
|
78 |
+
# キャラクターメーカークラス
|
79 |
+
class CharacterMaker:
|
80 |
+
def __init__(self):
|
81 |
+
self.llama = LlamaCppAdapter(model_path)
|
82 |
+
self.history = []
|
83 |
+
self.settings = load_settings_from_ini()
|
84 |
+
if not self.settings:
|
85 |
+
self.settings = {
|
86 |
+
"name": "ナツ",
|
87 |
+
"gender": "女性",
|
88 |
+
"situation": [
|
89 |
+
"あなたは人工知能アシスタントです。",
|
90 |
+
"ユーザーの日常生活をサポートし、より良い生活を送るお手伝いをします。",
|
91 |
+
"AIアシスタント『ナツ』として、ユーザーの健康と幸福をケアし、様々な質問に答えたり課題解決を手伝ったりします。"
|
92 |
+
],
|
93 |
+
"orders": [
|
94 |
+
"丁寧な言葉遣いを心がけてください。",
|
95 |
+
"ユーザーとの対話を通じてサポートを提供します。",
|
96 |
+
"ユーザーのことは『ユーザー様』と呼んでください。"
|
97 |
+
],
|
98 |
+
"conversation_topics": [
|
99 |
+
"健康管理",
|
100 |
+
"目標設定",
|
101 |
+
"時間管理"
|
102 |
+
],
|
103 |
+
"example_quotes": [
|
104 |
+
"ユーザー様の健康と幸福が何より大切です。どのようなサポートが必要でしょうか?",
|
105 |
+
"私はユーザー様の生活をより良いものにするためのアシスタントです。お手伝いできることがありましたらお申し付けください。",
|
106 |
+
"目標達成に向けて一緒に頑張りましょう。具体的な計画を立てるお手伝いをさせていただきます。",
|
107 |
+
"効率的な時間管理のコツをお教えします。まずは1日のスケジュールを確認してみましょう。",
|
108 |
+
"ストレス解消法についてアドバイスいたします。リラックスするための簡単な呼吸法から始めてみませんか?"
|
109 |
+
]
|
110 |
+
}
|
111 |
+
save_settings_to_ini(self.settings)
|
112 |
+
|
113 |
+
def make(self, input_str: str):
|
114 |
+
prompt = self._generate_aki(input_str)
|
115 |
+
print(prompt)
|
116 |
+
print("-----------------")
|
117 |
+
res = self.llama.generate(prompt, max_new_tokens=1000, stop=["<END>", "\n"])
|
118 |
+
res_text = res["choices"][0]["text"]
|
119 |
+
self.history.append({"user": input_str, "assistant": res_text})
|
120 |
+
return res_text
|
121 |
+
|
122 |
+
def make_prompt(self, name: str, gender: str, situation: list, orders: list, dirty_talk_list: list, example_quotes: list, input_str: str):
|
123 |
+
with open('test_prompt.jinja2', 'r', encoding='utf-8') as f:
|
124 |
+
prompt = f.readlines()
|
125 |
+
fix_example_quotes = [quote+"<END>" for quote in example_quotes]
|
126 |
+
prompt = "".join(prompt)
|
127 |
+
prompt = Template(prompt).render(name=name, gender=gender, situation=situation, orders=orders, dirty_talk_list=dirty_talk_list, example_quotes=fix_example_quotes, histories=self.history, input_str=input_str)
|
128 |
+
return prompt
|
129 |
+
|
130 |
+
def _generate_aki(self, input_str: str):
|
131 |
+
prompt = self.make_prompt(
|
132 |
+
self.settings["name"],
|
133 |
+
self.settings["gender"],
|
134 |
+
self.settings["situation"],
|
135 |
+
self.settings["orders"],
|
136 |
+
self.settings["dirty_talk_list"],
|
137 |
+
self.settings["example_quotes"],
|
138 |
+
input_str
|
139 |
+
)
|
140 |
+
print(prompt)
|
141 |
+
return prompt
|
142 |
+
|
143 |
+
def update_settings(self, new_settings):
|
144 |
+
self.settings.update(new_settings)
|
145 |
+
save_settings_to_ini(self.settings)
|
146 |
+
|
147 |
+
def reset(self):
|
148 |
+
self.history = []
|
149 |
+
self.llama = LlamaCppAdapter(model_path)
|
150 |
+
|
151 |
+
character_maker = CharacterMaker()
|
152 |
+
|
153 |
+
# 設定を更新する関数
|
154 |
+
def update_settings(name, gender, situation, orders, dirty_talk_list, example_quotes):
|
155 |
+
new_settings = {
|
156 |
+
"name": name,
|
157 |
+
"gender": gender,
|
158 |
+
"situation": [s.strip() for s in situation.split('\n') if s.strip()],
|
159 |
+
"orders": [o.strip() for o in orders.split('\n') if o.strip()],
|
160 |
+
"dirty_talk_list": [d.strip() for d in dirty_talk_list.split('\n') if d.strip()],
|
161 |
+
"example_quotes": [e.strip() for e in example_quotes.split('\n') if e.strip()]
|
162 |
+
}
|
163 |
+
character_maker.update_settings(new_settings)
|
164 |
+
return "設定が更新されました。"
|
165 |
+
|
166 |
+
# チャット機能の関数
|
167 |
+
def chat_with_character(message, history):
|
168 |
+
character_maker.history = [{"user": h[0], "assistant": h[1]} for h in history]
|
169 |
+
response = character_maker.make(message)
|
170 |
+
return response
|
171 |
+
|
172 |
+
# チャットをクリアする関数
|
173 |
+
def clear_chat():
|
174 |
+
character_maker.reset()
|
175 |
+
return []
|
176 |
+
|
177 |
+
# カスタムCSS
|
178 |
+
custom_css = """
|
179 |
+
#chatbot {
|
180 |
+
height: 60vh !important;
|
181 |
+
overflow-y: auto;
|
182 |
+
}
|
183 |
+
"""
|
184 |
+
|
185 |
+
# カスタムJavaScript(HTML内に埋め込む)
|
186 |
+
custom_js = """
|
187 |
+
<script>
|
188 |
+
function adjustChatbotHeight() {
|
189 |
+
var chatbot = document.querySelector('#chatbot');
|
190 |
+
if (chatbot) {
|
191 |
+
chatbot.style.height = window.innerHeight * 0.6 + 'px';
|
192 |
+
}
|
193 |
+
}
|
194 |
+
|
195 |
+
// ページ読み込み時と画面サイズ変更時にチャットボットの高さを調整
|
196 |
+
window.addEventListener('load', adjustChatbotHeight);
|
197 |
+
window.addEventListener('resize', adjustChatbotHeight);
|
198 |
+
</script>
|
199 |
+
"""
|
200 |
+
|
201 |
+
# Gradioインターフェースの設定
|
202 |
+
with gr.Blocks(css=custom_css) as iface:
|
203 |
+
chatbot = gr.Chatbot(elem_id="chatbot")
|
204 |
+
|
205 |
+
with gr.Tab("チャット"):
|
206 |
+
gr.ChatInterface(
|
207 |
+
chat_with_character,
|
208 |
+
chatbot=chatbot,
|
209 |
+
textbox=gr.Textbox(placeholder="メッセージを入力してください...", container=False, scale=7),
|
210 |
+
theme="soft",
|
211 |
+
retry_btn="もう一度生成",
|
212 |
+
undo_btn="前のメッセージを取り消す",
|
213 |
+
clear_btn="チャットをクリア",
|
214 |
+
)
|
215 |
+
|
216 |
+
with gr.Tab("設定"):
|
217 |
+
gr.Markdown("## キャラクター設定")
|
218 |
+
name_input = gr.Textbox(label="名前", value=character_maker.settings["name"])
|
219 |
+
gender_input = gr.Textbox(label="性別", value=character_maker.settings["gender"])
|
220 |
+
situation_input = gr.Textbox(label="状況設定", value="\n".join(character_maker.settings["situation"]), lines=5)
|
221 |
+
orders_input = gr.Textbox(label="指示", value="\n".join(character_maker.settings["orders"]), lines=5)
|
222 |
+
dirty_talk_input = gr.Textbox(label="淫語リスト", value="\n".join(character_maker.settings["dirty_talk_list"]), lines=5)
|
223 |
+
example_quotes_input = gr.Textbox(label="例文", value="\n".join(character_maker.settings["example_quotes"]), lines=5)
|
224 |
+
|
225 |
+
update_button = gr.Button("設定を更新")
|
226 |
+
update_output = gr.Textbox(label="更新状態")
|
227 |
+
|
228 |
+
update_button.click(
|
229 |
+
update_settings,
|
230 |
+
inputs=[name_input, gender_input, situation_input, orders_input, dirty_talk_input, example_quotes_input],
|
231 |
+
outputs=[update_output]
|
232 |
+
)
|
233 |
+
|
234 |
+
# Gradioアプリの起動
|
235 |
+
if __name__ == "__main__":
|
236 |
+
iface.launch(
|
237 |
+
share=True,
|
238 |
+
allowed_paths=["models"],
|
239 |
+
favicon_path="custom.html"
|
240 |
+
)
|
custom.html
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<script>
|
2 |
+
function adjustChatbotHeight() {
|
3 |
+
var chatbot = document.querySelector('#chatbot');
|
4 |
+
if (chatbot) {
|
5 |
+
chatbot.style.height = window.innerHeight * 0.6 + 'px';
|
6 |
+
}
|
7 |
+
}
|
8 |
+
|
9 |
+
// ページ読み込み時と画面サイズ変更時にチャットボットの高さを調整
|
10 |
+
window.addEventListener('load', adjustChatbotHeight);
|
11 |
+
window.addEventListener('resize', adjustChatbotHeight);
|
12 |
+
</script>
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
https://github.com/abetlen/llama-cpp-python/releases/download/v0.2.81-cu124/llama_cpp_python-0.2.81-cp310-cp310-linux_x86_64.whl
|
test_prompt.jinja2
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
・キャラクター設定
|
2 |
+
名前:{{name}}
|
3 |
+
性別:{{gender}}
|
4 |
+
|
5 |
+
{%for situation in situation %}
|
6 |
+
{{situation}}{%endfor%}
|
7 |
+
|
8 |
+
・今回のユーザーのオーダー
|
9 |
+
{%for order in orders %}
|
10 |
+
{{order}}{%endfor%}
|
11 |
+
|
12 |
+
・使ってほしい淫語表現
|
13 |
+
{%for dirty_talk in dirty_talk_list %}
|
14 |
+
{{dirty_talk}}{%endfor%}
|
15 |
+
・キャラクターの発言例
|
16 |
+
{%for example_quote in example_quotes %}
|
17 |
+
{{example_quote}}{%endfor%}
|
18 |
+
|
19 |
+
{%for history in histories %}user: {{history.user}}
|
20 |
+
{{name}}: {{history.assistant}}{%endfor%}
|
21 |
+
user: {{input_str}}
|
22 |
+
{{name}}:
|
utils/dl_utils.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import requests
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
|
6 |
+
def dl_guff_model(model_dir, url):
|
7 |
+
file_name = url.split('/')[-1]
|
8 |
+
folder = model_dir
|
9 |
+
file_path = os.path.join(folder, file_name)
|
10 |
+
if not os.path.exists(file_path):
|
11 |
+
response = requests.get(url, allow_redirects=True)
|
12 |
+
if response.status_code == 200:
|
13 |
+
with open(file_path, 'wb') as f:
|
14 |
+
f.write(response.content)
|
15 |
+
print(f'Downloaded {file_name}')
|
16 |
+
else:
|
17 |
+
print(f'Failed to download {file_name}')
|
18 |
+
else:
|
19 |
+
print(f'{file_name} already exists.')
|