'''
Inspiré de https://huggingface.co./spaces/Tonic/Lucie-7B
'''
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch # avec pip sur windows, ERROR: No matching distribution found for torch
# à utiliser dans un environnement GPU, Colab ou Space
import os
from src.amodel import AModel
AModel.load_env_variables()
MODEL_ID = "OpenLLM-France/Lucie-7B-Instruct"
TOKENIZER = AutoTokenizer.from_pretrained(
MODEL_ID,
token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
trust_remote_code=True
)
MODEL = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
device_map="auto",
torch_dtype=torch.bfloat16,
trust_remote_code=True
)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
with gr.Blocks(title="Lucie",
fill_height=True,
analytics_enabled=False,
css="footer {visibility: hidden}",
) as demo:
@spaces.GPU
def send(question):
system_prompt = '''Tu es Lucie, une assistante IA française serviable et amicale.
Tu réponds toujours en français de manière précise et utile.
Tu es honnête et si tu ne sais pas quelque chose, tu le dis simplement.
Tu termines ta réponse par .'''
full_prompt = f"""
{system_prompt}
{question}
"""
inputs = TOKENIZER(full_prompt, return_tensors="pt").to(DEVICE)
# inputs = TOKENIZER(question, return_tensors="pt").to(DEVICE)
# Tous les paramètres sont les paramètres par défaut de Tonic/Lucie-7B
outputs = MODEL.generate(
**inputs,
# max_new_tokens=max_new_tokens, # TODO: S'occuper des max_tokens avec tous les modèles
max_new_tokens=512,
temperature=0.1,
top_p=0.9,
top_k=50,
repetition_penalty=1.2,
do_sample=True,
pad_token_id=TOKENIZER.eos_token_id
)
# response = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
response = TOKENIZER.decode(outputs[0], skip_special_tokens=False)
r = response.split("")[1].strip()
r = r.split("")[0].strip()
# r = response
return r
with gr.Row():
gr.Image("./files/drane.png", show_download_button=False,
show_fullscreen_button=False, show_label=False, show_share_button=False,
interactive=False, container=False)
# https://www.svgrepo.com/svg/403600/girl
gr.Image("./files/lucie.png", show_download_button=False,
show_fullscreen_button=False, show_label=False, show_share_button=False,
interactive=False, container=False)
with gr.Row():
gr.Markdown("# Lucie d'OpenLLM")
gr.Markdown("## Discute avec Lucie")
# gr.HTML('''''')
with gr.Row():
question = gr.Textbox(
"",
placeholder="Pose ta question ici",
show_copy_button=False,
show_label=False,
container=False,
lines=2,
autofocus=True,
scale=10
)
send_btn = gr.Button("Ok", scale=1)
resp = gr.Textbox("", show_copy_button=False,
show_label=False,
container=False,
max_lines=15)
# full_resp = gr.Textbox("", show_copy_button=False,
# show_label=False,
# container=False,
# max_lines=15)
send_btn.click(send, inputs=[question], outputs=[resp])
# send_btn.click(send, inputs=[question], outputs=[resp, full_resp])
if __name__ == "__main__":
demo.queue().launch()