Spaces:
Runtime error
Runtime error
File size: 3,602 Bytes
5c0e14a 21958b2 5c0e14a 88c296d ac9642b 88c296d 5c0e14a c0be996 5c0e14a c0be996 ac9642b c0be996 3fcfbe3 c0be996 5c0e14a c0be996 5c0e14a c0be996 5c0e14a c0be996 5c0e14a a0fcf12 5c0e14a 17fba42 c0be996 5c8d81a c0be996 5c0e14a c0be996 5c0e14a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from huggingface_hub import InferenceClient
import gradio as gr
import os
# Space secrets: auth token and endpoint URL for the external MT service
# (DeepL-style API, judging by the auth_key/source_lang/target_lang params below).
APIKEY = os.environ["MT_TOKEN"]  # raises KeyError at startup if unset
EP = os.environ["MT_EP"]  # full URL of the translation endpoint
import requests
def translate(source_lang, target_lang, text):
    """Translate *text* between languages via the external MT endpoint.

    Parameters
    ----------
    source_lang : str
        Source language code, e.g. "EN" or "JA".
    target_lang : str
        Target language code.
    text : str
        Text to translate.

    Returns
    -------
    str
        The translated text from the first translation in the response.

    Raises
    ------
    requests.RequestException
        On network failure, timeout, or non-2xx HTTP status.
    KeyError, IndexError, ValueError
        If the response body is not the expected JSON shape.
    """
    params = {
        'auth_key': APIKEY,
        'text': text,
        'source_lang': source_lang,
        "target_lang": target_lang,
    }
    # timeout= keeps a hung MT endpoint from blocking the whole app forever;
    # raise_for_status() surfaces HTTP errors here instead of a confusing
    # JSON decode / KeyError further down.
    response = requests.post(EP, data=params, timeout=30)
    response.raise_for_status()
    result = response.json()
    return result["translations"][0]["text"]
def translate_en_to_ja(text):
    """Translate English *text* to Japanese; a None input passes through unchanged."""
    return None if text is None else translate("EN", "JA", text)
def translate_ja_to_en(text):
    """Translate Japanese *text* to English; a None input passes through unchanged."""
    return None if text is None else translate("JA", "EN", text)
# Hosted-inference client for the instruction-tuned Mistral 7B model.
client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1"
)
# Module-level conversation state:
#   message_en   — the latest user message after JA->EN translation
#                  (set by format_prompt, read by generate)
#   real_history — English-language [user, bot] pairs accumulated across turns
# NOTE(review): module-level globals are shared by every visitor to the Space,
# so concurrent users would interleave histories — confirm single-user intent.
message_en = None
real_history = []
def format_prompt(message, history):
    """Build a Mistral-Instruct prompt from the English-side conversation.

    Translates the incoming (Japanese) *message* to English, stores it in the
    module-level ``message_en`` for later use by ``generate``, and wraps each
    past English [user, bot] exchange in ``[INST] ... [/INST]`` tags.
    Note: the *history* argument is not used; ``real_history`` is the source
    of truth for past turns.
    """
    global message_en
    print(message)
    message_en = translate_ja_to_en(message)
    print(message_en)
    pieces = ["<s>"]
    for past_user, past_bot in real_history:
        pieces.append(f"[INST] {past_user} [/INST]")
        pieces.append(f" {past_bot}</s> ")
    pieces.append(f"[INST] {message_en} [/INST]")
    return "".join(pieces)
def generate(
    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Chat handler: translate the JA prompt, query Mistral, translate back.

    Builds the English prompt via ``format_prompt``, runs a non-streaming
    text-generation call, records the English exchange in the module-level
    ``real_history``, and returns the reply translated to Japanese.
    """
    # Sampling needs a strictly positive temperature; clamp at 1e-2.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    sampling_options = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(prompt, history)
    output = client.text_generation(
        formatted_prompt,
        **sampling_options,
        stream=False,
        details=True,
        return_full_text=False,
    )
    english_reply = output.generated_text
    # format_prompt left the English version of the user message in message_en.
    real_history.append([message_en, english_reply])
    return translate_en_to_ja(english_reply)
# Extra sliders exposed by gr.ChatInterface; their values are passed to
# generate() positionally after (prompt, history), so the order here must
# match generate's parameter order.
additional_inputs=[
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1048,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]
# Custom CSS injected into the Blocks app: makes any #mkd element a
# fixed-height scrollable panel.
css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
# Assemble the Gradio UI and launch it.
with gr.Blocks(css=css) as demo:
    # Fixed malformed HTML: the headings repeated the opening tags
    # ("<h1><center>") instead of closing them.
    gr.HTML("<h1><center>Mistral 7B + Japanese MT</center></h1>")
    gr.HTML("<h3><center><a href='https://huggingface.co./mistralai/Mistral-7B-Instruct-v0.1'>Mistral-7B-Instruct</a> の入出力に機械翻訳をかけたものです。💬</center></h3>")
    gr.HTML("<h3><center>モデルの詳細については<a href='https://huggingface.co./docs/transformers/main/model_doc/mistral'>ここから</a>. 📚</center></h3>")
    gr.ChatInterface(
        generate,
        additional_inputs=additional_inputs,
        examples=[["人生の秘密は何ですか?"], ["パンケーキのレシピを書いて。"]],
    )

# queue() enables request queuing for the hosted Space. Removed the stray
# trailing "|" after this call, which was a syntax error as written.
demo.queue().launch(debug=True)