Spaces:
Running
Running
import streamlit as st | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import pandas as pd | |
from fpdf import FPDF | |
# User interface: page metadata (browser-tab title, favicon) and a wide layout.
st.set_page_config(
    layout="wide",
    page_icon="📝",
    page_title="Traduction d'une phrase en pictogrammes ARASAAC",
)
# Load the translation model and its tokenizer.
# Alternative checkpoint that can be swapped in: "Propicto/t2p-t5-large-orfeo".
checkpoint = "Propicto/t2p-nllb-200-distilled-600M-all"


@st.cache_resource
def _load_translation_model(model_checkpoint):
    """Return ``(tokenizer, model)`` loaded from *model_checkpoint*.

    Streamlit re-executes the whole script on every user interaction.
    Caching the loaded objects as a shared resource means the model weights
    are loaded once per server process instead of on every rerun.
    """
    return (
        AutoTokenizer.from_pretrained(model_checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint),
    )


tokenizer, model = _load_translation_model(checkpoint)
# Read the pictogram lexicon.
def read_lexicon(lexicon):
    """Load the tab-separated pictogram lexicon into a DataFrame.

    A ``keyword_no_cat`` column is added, derived from ``lemma`` in three steps:
    - drop everything from the first ``" #"`` onward (presumably a category
      tag — confirm against the lexicon file);
    - strip surrounding whitespace;
    - replace inner spaces with underscores.

    Args:
        lexicon: path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        The lexicon DataFrame, including the new ``keyword_no_cat`` column.
    """
    lexicon_df = pd.read_csv(lexicon, sep='\t')
    base_form = lexicon_df['lemma'].str.split(' #').str[0]
    lexicon_df['keyword_no_cat'] = base_form.str.strip().str.replace(' ', '_')
    return lexicon_df
# Parse the lexicon once per server process instead of on every Streamlit
# rerun. cache_resource hands back the same DataFrame object without copying;
# that is safe because the lookups in this script only read from it.
lexicon = st.cache_resource(read_lexicon)("lexicon.csv")
# Post-process the translation output.
def process_output_trad(pred):
    """Split the decoded model output into lemma tokens on any whitespace."""
    tokens = pred.split()
    return tokens
def get_id_picto_from_predicted_lemma(df_lexicon, lemma):
    """Look up the ARASAAC pictogram id for one predicted lemma.

    A single trailing ``"!"`` is removed before the lookup.

    Args:
        df_lexicon: DataFrame with ``keyword_no_cat`` and ``id_picto`` columns.
        lemma: one token of the model's output.

    Returns:
        ``(id_picto, cleaned_lemma)`` for the first matching row. If no row
        matches, returns ``(0, cleaned_lemma)``; 0 is the "not found" marker
        that callers skip.
    """
    cleaned = lemma[:-1] if lemma.endswith("!") else lemma
    matches = df_lexicon['id_picto'][df_lexicon['keyword_no_cat'] == cleaned].tolist()
    if not matches:
        return (0, cleaned)
    return (matches[0], cleaned)
# Build the HTML content used to display the pictograms.
def generate_html(ids):
    """Build a standalone HTML page showing the pictograms of a translation.

    Args:
        ids: iterable of ``(picto_id, lemma)`` pairs. Entries whose
            ``picto_id`` is ``0`` (lemma not found in the lexicon) are skipped.

    Returns:
        The HTML document as a string. It holds one ``<figure>`` per valid
        pictogram, each with the ARASAAC image (500 px source, shown at
        100x100) and the lemma as its caption.
    """
    # Local import: the script assigns a module-level variable named `html`,
    # so the stdlib module is bound under a different name here.
    from html import escape

    style = '''
figure {
    display: inline-block;
    text-align: center;
    font-family: Arial, sans-serif;
    margin: 0;
}
figcaption {
    color: black;
    background-color: white;
    border-radius: 5px;
}
img {
    background-color: white;
    margin: 0;
    padding: 0;
    border-radius: 6px;
}
'''
    parts = ['<html><head><style>', style, '</style></head><body>']
    for picto_id, lemma in ids:
        if picto_id == 0:  # lemma not in the lexicon: nothing to show
            continue
        img_url = f"https://static.arasaac.org/pictograms/{picto_id}/{picto_id}_500.png"
        # The lemma comes from model output driven by user input: escape it
        # so quotes or angle brackets cannot break the attribute or inject markup.
        safe_lemma = escape(str(lemma))
        parts.append(f'''
<figure>
    <img src="{img_url}" alt="{safe_lemma}" width="100" height="100"/>
    <figcaption>{safe_lemma}</figcaption>
</figure>
''')
    parts.append('</body></html>')
    return ''.join(parts)
def generate_pdf(ids, pdf_path="pictograms.pdf"):
    """Render the pictograms of a translation into a landscape A4 PDF.

    Each valid pictogram is drawn as a 50x50 mm image fetched from the ARASAAC
    static server, with its lemma centred underneath. Pictograms are laid out
    left to right. They wrap onto a new row when the page width is reached,
    and onto a new page when a row would run into the bottom margin.

    Args:
        ids: iterable of ``(picto_id, lemma)`` pairs. Entries whose
            ``picto_id`` is ``0`` (lemma not found in the lexicon) are skipped.
        pdf_path: where to write the PDF. Defaults to ``"pictograms.pdf"`` in
            the current working directory. Concurrent callers should pass
            distinct paths so they do not overwrite each other's file.

    Returns:
        The path the PDF was written to.
    """
    pdf = FPDF(orientation='L', unit='mm', format='A4')  # 'L' for landscape orientation
    pdf.add_page()
    bottom_margin = 15
    pdf.set_auto_page_break(auto=True, margin=bottom_margin)
    pdf.set_font("Arial", size=12)  # set once; kept when new pages are added

    # Layout, in mm
    x_start = 10
    y_start = 10
    img_width = 50
    img_height = 50
    spacing = 1
    caption_offset = 5  # gap between the image bottom and the caption cell
    caption_height = 10
    max_width = 297  # A4 landscape width
    page_height = 210  # A4 landscape height
    # Vertical footprint of one row: image, gap, then caption. The row advance
    # must cover all of it or captions overlap the next row's images.
    row_height = img_height + caption_offset + caption_height

    current_x = x_start
    current_y = y_start
    for picto_id, lemma in ids:
        if picto_id == 0:  # lemma not in the lexicon: nothing to draw
            continue
        # Images placed at explicit coordinates do not trigger FPDF's automatic
        # page break, so start a new page ourselves when the row would not fit
        # above the bottom margin.
        if current_y + row_height > page_height - bottom_margin:
            pdf.add_page()
            current_y = y_start
        img_url = f"https://static.arasaac.org/pictograms/{picto_id}/{picto_id}_500.png"
        pdf.image(img_url, x=current_x, y=current_y, w=img_width, h=img_height)
        pdf.set_xy(current_x, current_y + img_height + caption_offset)
        pdf.cell(img_width, caption_height, txt=lemma, ln=1, align='C')
        current_x += img_width + spacing
        # Move to the next row when the next image would exceed the page width
        if current_x + img_width > max_width:
            current_x = x_start
            current_y += row_height + spacing

    pdf.output(pdf_path)
    return pdf_path
st.title("Traduction d'une phrase en pictogrammes ARASAAC")
st.info("Text-to-Pictograms traduit une phrase en français en pictogrammes ARASAAC. Renseignez une phrase, puis validez. Vous pouvez sauvegarder la traduction au format PDF en cliquant sur le bouton en bas de page.", icon='ℹ️')

# (picto_id, lemma) pairs for the current sentence; empty until one is entered.
pictogram_ids = []
sentence = st.text_input("Entrez une phrase en français:")
if sentence:
    # Streamlit re-executes this whole script on every interaction, including
    # a click on the download button below. Generation samples
    # (do_sample=True), so translating again on each rerun would replace the
    # pictograms on screen with a different translation than the PDF just
    # downloaded. Translate only when the sentence changes, and keep the
    # result in this user's session.
    if st.session_state.get("translated_sentence") != sentence:
        with st.spinner("Affichage des pictogrammes..."):
            inputs = tokenizer(sentence, return_tensors="pt").input_ids
            outputs = model.generate(inputs, max_new_tokens=40, do_sample=True, top_k=30, top_p=0.95)
            pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
            sentence_to_map = process_output_trad(pred)
            st.session_state["translated_pictograms"] = [
                get_id_picto_from_predicted_lemma(lexicon, lemma) for lemma in sentence_to_map
            ]
            # Record the sentence last, so a failure above does not leave a
            # stale "already translated" marker behind.
            st.session_state["translated_sentence"] = sentence
    pictogram_ids = st.session_state["translated_pictograms"]
    html = generate_html(pictogram_ids)
    st.components.v1.html(html, height=200, scrolling=True)

# Offer the PDF only when at least one pictogram was found; otherwise the
# document would be a blank page.
if any(picto_id != 0 for picto_id, _ in pictogram_ids):
    pdf_path = generate_pdf(pictogram_ids)
    with open(pdf_path, "rb") as pdf_file:
        st.download_button(label="Télécharger la traduction en PDF", data=pdf_file, file_name="pictograms.pdf", mime="application/pdf")