|
''' |
|
Inspiré de https://huggingface.co./spaces/Tonic/Lucie-7B |
|
''' |
|
|
|
import gradio as gr |
|
import spaces |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
import torch |
|
|
|
import os |
|
|
|
from src.amodel import AModel |
|
AModel.load_env_variables() |
|
|
|
# Hugging Face Hub id of the instruction-tuned Lucie 7B checkpoint.
MODEL_ID = "OpenLLM-France/Lucie-7B-Instruct"

# Tokenizer for MODEL_ID. The HF token is read from the environment
# (populated earlier by AModel.load_env_variables()).
TOKENIZER = AutoTokenizer.from_pretrained(
    MODEL_ID,
    token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
    trust_remote_code=True
)

# Model weights loaded in bfloat16; device_map="auto" lets accelerate
# place (and possibly shard) the model across available devices.
MODEL = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
)

# NOTE(review): with device_map="auto" the model may be sharded across
# devices; moving inputs to this single DEVICE assumes single-device
# placement — MODEL.device would be the safer target. Confirm.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
# Top-level Gradio UI. The CSS rule hides the Gradio footer; analytics
# are disabled and the layout stretches to fill the viewport height.
with gr.Blocks(title="Lucie",
               fill_height=True,
               analytics_enabled=False,
               css="footer {visibility: hidden}",
               ) as demo:
|
|
|
@spaces.GPU |
|
def send(question): |
|
system_prompt = '''Tu es Lucie, une assistante IA française serviable et amicale. |
|
Tu réponds toujours en français de manière précise et utile. |
|
Tu es honnête et si tu ne sais pas quelque chose, tu le dis simplement. |
|
Tu termines ta réponse par </assistant>.''' |
|
full_prompt = f"""<system> |
|
{system_prompt} |
|
</system> |
|
<user> |
|
{question} |
|
</user> |
|
<assistant>""" |
|
inputs = TOKENIZER(full_prompt, return_tensors="pt").to(DEVICE) |
|
|
|
|
|
outputs = MODEL.generate( |
|
**inputs, |
|
|
|
max_new_tokens=512, |
|
temperature=0.1, |
|
top_p=0.9, |
|
top_k=50, |
|
repetition_penalty=1.2, |
|
do_sample=True, |
|
pad_token_id=TOKENIZER.eos_token_id |
|
) |
|
|
|
response = TOKENIZER.decode(outputs[0], skip_special_tokens=False) |
|
r = response.split("<assistant>")[1].strip() |
|
r = r.split("</")[0].strip() |
|
|
|
return r |
|
|
|
    # Header row: two decorative logo images with all Gradio chrome
    # (download/fullscreen/share buttons, labels, containers) disabled.
    with gr.Row():
        gr.Image("./files/drane.png", show_download_button=False,
                 show_fullscreen_button=False, show_label=False, show_share_button=False,
                 interactive=False, container=False)

        gr.Image("./files/lucie.png", show_download_button=False,
                 show_fullscreen_button=False, show_label=False, show_share_button=False,
                 interactive=False, container=False)

    # Title row.
    with gr.Row():
        gr.Markdown("# Lucie d'OpenLLM")
        gr.Markdown("## Discute avec Lucie")

    # Input row: the question textbox (wide, autofocused) next to the
    # submit button (narrow, scale 10:1).
    with gr.Row():
        question = gr.Textbox(
            "",
            placeholder="Pose ta question ici",
            show_copy_button=False,
            show_label=False,
            container=False,
            lines=2,
            autofocus=True,
            scale=10
        )
        send_btn = gr.Button("Ok", scale=1)

    # Output area where the model's answer is displayed.
    resp = gr.Textbox("", show_copy_button=False,
                      show_label=False,
                      container=False,
                      max_lines=15)

    # Wire the button: clicking it calls send(question) and writes the
    # returned string into resp.
    send_btn.click(send, inputs=[question], outputs=[resp])
|
|
|
|
|
if __name__ == "__main__":
    # Enable the request queue (required for @spaces.GPU workloads),
    # then start the Gradio server.
    app = demo.queue()
    app.launch()