tyoyo
commited on
Commit
•
782e1e8
1
Parent(s):
7a773a3
fix: 入力が空文字, ハイパラがintのときのバグ修正
Browse files
app.py
CHANGED
@@ -100,7 +100,17 @@ def generate(
|
|
100 |
raise ValueError
|
101 |
|
102 |
history = history_with_input[:-1]
|
103 |
-
generator = run(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
try:
|
105 |
first_response = next(generator)
|
106 |
yield history + [(message, first_response)]
|
@@ -130,7 +140,12 @@ def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
|
|
130 |
def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
|
131 |
input_token_length = get_input_token_length(message, chat_history, system_prompt)
|
132 |
if input_token_length > MAX_INPUT_TOKEN_LENGTH:
|
133 |
-
raise gr.Error(
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
|
136 |
def convert_history_to_str(history: list[tuple[str, str]]) -> str:
|
@@ -360,6 +375,11 @@ graphvizで、AからB、BからC、CからAに有向エッジが生えている
|
|
360 |
api_name=False,
|
361 |
queue=False,
|
362 |
).then(
|
|
|
|
|
|
|
|
|
|
|
363 |
fn=display_input,
|
364 |
inputs=[saved_input, chatbot],
|
365 |
outputs=chatbot,
|
@@ -373,11 +393,6 @@ graphvizで、AからB、BからC、CからAに有向エッジが生えている
|
|
373 |
fn=output_log,
|
374 |
inputs=[chatbot, uuid_list],
|
375 |
).then(
|
376 |
-
fn=check_input_token_length,
|
377 |
-
inputs=[saved_input, chatbot, system_prompt],
|
378 |
-
api_name=False,
|
379 |
-
queue=False,
|
380 |
-
).success(
|
381 |
fn=generate,
|
382 |
inputs=[
|
383 |
saved_input,
|
@@ -412,6 +427,11 @@ graphvizで、AからB、BからC、CからAに有向エッジが生えている
|
|
412 |
api_name=False,
|
413 |
queue=False,
|
414 |
).then(
|
|
|
|
|
|
|
|
|
|
|
415 |
fn=display_input,
|
416 |
inputs=[saved_input, chatbot],
|
417 |
outputs=chatbot,
|
@@ -424,11 +444,6 @@ graphvizで、AからB、BからC、CからAに有向エッジが生えている
|
|
424 |
).then(
|
425 |
fn=output_log,
|
426 |
inputs=[chatbot, uuid_list],
|
427 |
-
).then(
|
428 |
-
fn=check_input_token_length,
|
429 |
-
inputs=[saved_input, chatbot, system_prompt],
|
430 |
-
api_name=False,
|
431 |
-
queue=False,
|
432 |
).success(
|
433 |
fn=generate,
|
434 |
inputs=[
|
@@ -464,6 +479,11 @@ graphvizで、AからB、BからC、CからAに有向エッジが生えている
|
|
464 |
api_name=False,
|
465 |
queue=False,
|
466 |
).then(
|
|
|
|
|
|
|
|
|
|
|
467 |
fn=display_input,
|
468 |
inputs=[saved_input, chatbot],
|
469 |
outputs=chatbot,
|
|
|
100 |
raise ValueError
|
101 |
|
102 |
history = history_with_input[:-1]
|
103 |
+
generator = run(
|
104 |
+
message,
|
105 |
+
history,
|
106 |
+
system_prompt,
|
107 |
+
max_new_tokens,
|
108 |
+
float(temperature),
|
109 |
+
float(top_p),
|
110 |
+
top_k,
|
111 |
+
do_sample,
|
112 |
+
float(repetition_penalty),
|
113 |
+
)
|
114 |
try:
|
115 |
first_response = next(generator)
|
116 |
yield history + [(message, first_response)]
|
|
|
140 |
def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
|
141 |
input_token_length = get_input_token_length(message, chat_history, system_prompt)
|
142 |
if input_token_length > MAX_INPUT_TOKEN_LENGTH:
|
143 |
+
raise gr.Error(
|
144 |
+
f"合計対話長が長すぎます ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH})。入力文章を短くするか、「🗑️ これまでの出力を消す」ボタンを押してから再実行してください。"
|
145 |
+
)
|
146 |
+
|
147 |
+
if len(message) <= 0:
|
148 |
+
raise gr.Error("入力が空です。1文字以上の文字列を入力してください。")
|
149 |
|
150 |
|
151 |
def convert_history_to_str(history: list[tuple[str, str]]) -> str:
|
|
|
375 |
api_name=False,
|
376 |
queue=False,
|
377 |
).then(
|
378 |
+
fn=check_input_token_length,
|
379 |
+
inputs=[saved_input, chatbot, system_prompt],
|
380 |
+
api_name=False,
|
381 |
+
queue=False,
|
382 |
+
).success(
|
383 |
fn=display_input,
|
384 |
inputs=[saved_input, chatbot],
|
385 |
outputs=chatbot,
|
|
|
393 |
fn=output_log,
|
394 |
inputs=[chatbot, uuid_list],
|
395 |
).then(
|
|
|
|
|
|
|
|
|
|
|
396 |
fn=generate,
|
397 |
inputs=[
|
398 |
saved_input,
|
|
|
427 |
api_name=False,
|
428 |
queue=False,
|
429 |
).then(
|
430 |
+
fn=check_input_token_length,
|
431 |
+
inputs=[saved_input, chatbot, system_prompt],
|
432 |
+
api_name=False,
|
433 |
+
queue=False,
|
434 |
+
).success(
|
435 |
fn=display_input,
|
436 |
inputs=[saved_input, chatbot],
|
437 |
outputs=chatbot,
|
|
|
444 |
).then(
|
445 |
fn=output_log,
|
446 |
inputs=[chatbot, uuid_list],
|
|
|
|
|
|
|
|
|
|
|
447 |
).success(
|
448 |
fn=generate,
|
449 |
inputs=[
|
|
|
479 |
api_name=False,
|
480 |
queue=False,
|
481 |
).then(
|
482 |
+
fn=check_input_token_length,
|
483 |
+
inputs=[saved_input, chatbot, system_prompt],
|
484 |
+
api_name=False,
|
485 |
+
queue=False,
|
486 |
+
).success(
|
487 |
fn=display_input,
|
488 |
inputs=[saved_input, chatbot],
|
489 |
outputs=chatbot,
|