Spaces:
Runtime error
Runtime error
Armandoliv
commited on
Commit
•
c828564
1
Parent(s):
fb1bff8
Create new file
Browse files
app.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
4 |
+
|
5 |
+
tokenizer = AutoTokenizer.from_pretrained("Armandoliv/gpt2-tweetml-generator")
|
6 |
+
|
7 |
+
model = AutoModelForCausalLM.from_pretrained("Armandoliv/gpt2-tweetml-generator")
|
8 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
9 |
+
model = model.to(device)
|
10 |
+
|
11 |
+
def main_generator(text):
|
12 |
+
max_input_length = 70
|
13 |
+
preprocess_text = text.strip().replace("\n"," ").strip()
|
14 |
+
prompt = f"<|startoftext|> {preprocess_text}"
|
15 |
+
generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
|
16 |
+
generated = generated.to(device)
|
17 |
+
|
18 |
+
sample_outputs = model.generate(
|
19 |
+
generated,
|
20 |
+
do_sample=True,
|
21 |
+
top_k=20,
|
22 |
+
max_length = 70,
|
23 |
+
top_p=0.98,
|
24 |
+
num_return_sequences=10,
|
25 |
+
temperature=0.95
|
26 |
+
|
27 |
+
)
|
28 |
+
output = ""
|
29 |
+
|
30 |
+
for i, sample_output in enumerate(sample_outputs):
|
31 |
+
output += "{}: {}\n\n".format(i+1, tokenizer.decode(sample_output, skip_special_tokens=True))
|
32 |
+
|
33 |
+
|
34 |
+
return output
|
35 |
+
|
36 |
+
inputs = [gr.Textbox(lines=1, placeholder="Text Here...", label="Input")]
|
37 |
+
outputs = gr.Text( label="10 Tweets Generated")
|
38 |
+
title="Tweets generation app"
|
39 |
+
description = "This demo uses AI Models to create tweets.\nIt focus on Data Science and Machine Learning tweets creation."
|
40 |
+
|
41 |
+
io = gr.Interface(fn=main_generator, inputs=inputs, outputs=outputs, title=title, description = description,
|
42 |
+
|
43 |
+
css= """.gr-button-primary { background: -webkit-linear-gradient(
|
44 |
+
90deg, #355764 0%, #55a8a1 100% ) !important; background: #355764;
|
45 |
+
background: linear-gradient(
|
46 |
+
90deg, #355764 0%, #55a8a1 100% ) !important;
|
47 |
+
background: -moz-linear-gradient( 90deg, #355764 0%, #55a8a1 100% ) !important;
|
48 |
+
background: -webkit-linear-gradient(
|
49 |
+
90deg, #355764 0%, #55a8a1 100% ) !important;
|
50 |
+
color:white !important}"""
|
51 |
+
)
|
52 |
+
|
53 |
+
io.launch()
|