# Hugging Face Space status: Build error
import torch | |
import transformers | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
import gradio as gr | |
import os | |
# Hub id of the BART checkpoint fine-tuned to write Yelp-style reviews.
model_name = 'eliolio/bart-finetuned-yelpreviews'
# Optional Hugging Face access token read from the `private_token`
# environment variable (e.g. a Space secret). When the variable is unset this
# is None and the checkpoint is fetched without credentials — presumably the
# repo must then be public; confirm if the model is private.
access_token = os.environ.get('private_token')
# `token=` supersedes the deprecated `use_auth_token=` argument of
# `from_pretrained`, which newer transformers releases warn about / drop.
# NOTE(review): requires a transformers version that accepts `token`
# (added in the 4.3x series) — confirm against the Space's requirements.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=access_token)
tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token)
def create_prompt(stars, useful, funny, cool):
    """Return the text prompt fed to the review-generation model.

    The four values are interpolated, in the fixed order stars, useful,
    funny, cool, into a ``"Generate review: ..."`` template. Presumably this
    template mirrors the one used when fine-tuning the model, so it should
    not be reworded casually (NOTE(review): confirm against training data).
    """
    prompt = f"Generate review: stars: {stars}, useful: {useful}, funny: {funny}, cool: {cool}"
    return prompt
def generate_reviews(stars, useful, funny, cool, num_beams=5, num_reviews=3):
    """Generate candidate reviews conditioned on the given review metadata.

    Args:
        stars, useful, funny, cool: Values formatted into the prompt by
            ``create_prompt`` (the UI supplies them from 0-5 sliders).
        num_beams: Beam width for beam search. transformers rejects a
            ``num_return_sequences`` larger than the beam width, so this must
            be >= ``num_reviews``.
        num_reviews: Number of sequences to return (``num_return_sequences``).
            The Gradio UI wires the result to three textboxes, hence the
            default of 3.

    Returns:
        A tuple of ``num_reviews`` decoded strings, one per generated sequence.
    """
    prompt = create_prompt(stars, useful, funny, cool)
    inputs = tokenizer(prompt, return_tensors='pt')
    out = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        num_beams=num_beams,
        num_return_sequences=num_reviews,
    )
    # Decode every returned sequence in one call; a tuple lets Gradio map each
    # element onto its own output component (previously hard-coded as [0..2]).
    return tuple(tokenizer.batch_decode(out, skip_special_tokens=True))
# Styling for the centred header (#ctr) and the gradient button (#btn); the
# selectors match the elem_id values passed to the components below.
css = """
#ctr {text-align: center;}
#btn {color: white; background: linear-gradient(90deg, #00d2ff 0%, #3a47d5 100%);}
"""
md_text = """## Generating Yelp reviews with BART-base ⭐⭐⭐"""

demo = gr.Blocks(css=css)
with demo:
    with gr.Row():
        gr.Markdown(md_text, elem_id='ctr')
    with gr.Row():
        # The legacy `gr.inputs.Slider(default=...)` namespace was removed in
        # Gradio 4; top-level `gr.Slider(value=...)` is the equivalent. The
        # list comprehension is evaluated eagerly, so all four sliders are
        # created inside this Row, in label order.
        stars, useful, funny, cool = [
            gr.Slider(minimum=0, maximum=5, step=1, value=0, label=label)
            for label in ("stars", "useful", "funny", "cool")
        ]
    with gr.Row():
        button = gr.Button("Generate reviews !", elem_id='btn')
    with gr.Row():
        output1 = gr.Textbox(label="Review #1")
        output2 = gr.Textbox(label="Review #2")
        output3 = gr.Textbox(label="Review #3")
    # Event listeners must be registered inside the Blocks context; the three
    # strings returned by generate_reviews fill the textboxes in order.
    button.click(
        fn=generate_reviews,
        inputs=[stars, useful, funny, cool],
        outputs=[output1, output2, output3],
    )
demo.launch()