# Armandoliv — app.py (commit d319358): GPT-2 tweet generator demo
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the fine-tuned GPT-2 tweet-generation checkpoint from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained("Armandoliv/gpt2-tweetml-generator")
model = AutoModelForCausalLM.from_pretrained("Armandoliv/gpt2-tweetml-generator")
# Prefer the first GPU when available; fall back to CPU otherwise.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
def main_generator(text):
    """Generate 10 sampled tweet continuations of *text* with the GPT-2 model.

    Args:
        text: Seed text that the generated tweets should continue.

    Returns:
        One string containing the 10 numbered samples, each entry followed
        by a blank line (e.g. "1: ...\\n\\n2: ...\\n\\n").
    """
    # Collapse newlines so the prompt is a single line of text.
    preprocess_text = text.strip().replace("\n", " ").strip()
    prompt = f"<|startoftext|> {preprocess_text}"
    # Encode to a (1, seq_len) tensor and move it to the model's device.
    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
    # Inference only: no_grad avoids building the autograd graph,
    # cutting memory use during generation.
    with torch.no_grad():
        sample_outputs = model.generate(
            generated,
            do_sample=True,
            top_k=20,
            max_length=70,
            top_p=0.98,
            num_return_sequences=10,
            temperature=0.95,
        )
    # Build the result with a single join instead of repeated string
    # concatenation; each entry keeps the original trailing blank line.
    return "".join(
        "{}: {}\n\n".format(i + 1, tokenizer.decode(out, skip_special_tokens=True))
        for i, out in enumerate(sample_outputs)
    )
# --- Gradio UI wiring ---------------------------------------------------
# Single text prompt in, one text box with all 10 generated tweets out.
inputs = [gr.Textbox(lines=1, placeholder="Text Here...", label="Input")]
outputs = gr.Text(label="10 Tweets Generated")
title = "Tweets generation app"
# Fixed grammar in the user-facing description ("It focus on" -> "It focuses on").
description = "This demo uses AI Models to create tweets.\nIt focuses on Data Science and Machine Learning tweets creation."
examples = ['I wonder']

# Custom CSS restyles the primary button with a teal gradient
# (vendor-prefixed fallbacks included for older browsers).
io = gr.Interface(fn=main_generator, inputs=inputs, outputs=outputs, title=title, description=description, examples=examples,
                  css=""".gr-button-primary { background: -webkit-linear-gradient(
                        90deg, #355764 0%, #55a8a1 100% ) !important; background: #355764;
                        background: linear-gradient(
                        90deg, #355764 0%, #55a8a1 100% ) !important;
                        background: -moz-linear-gradient( 90deg, #355764 0%, #55a8a1 100% ) !important;
                        background: -webkit-linear-gradient(
                        90deg, #355764 0%, #55a8a1 100% ) !important;
                        color:white !important}"""
                  )
io.launch()