|
import gradio as gr |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from optimum.intel import OVModelForCausalLM |
|
|
|
|
|
# Hugging Face Hub id of the fine-tuned Pythia checkpoint served by this demo.
model_name = "DarwinAnim8or/Pythia-Greentext-1.4b"

# Load the tokenizer first, then the model. export=True asks optimum-intel to
# convert the checkpoint to OpenVINO format while loading (per the
# optimum-intel docs).
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    OVModelForCausalLM.from_pretrained(model_name, export=True),
)
|
|
|
def generate(text, length=100, penalty=3, temperature=0.8, topk=40):
    """Continue a greentext story from the user's opening line(s).

    A fixed instruction prompt is prepended to ``text``. A leading ">" is
    added when ``text`` lacks one. A fresh ">" line is then opened, and the
    model generates a continuation.

    Args:
        text: Opening line(s) typed by the user; ``None`` is treated as "".
        length: Number of new tokens to generate after the prompt.
        penalty: Forwarded as ``no_repeat_ngram_size``: n-grams of this size
            may not repeat in the output.
        temperature: Sampling temperature. Values <= 0 fall back to greedy
            decoding, because sampling requires a strictly positive
            temperature (the UI slider allows 0.0).
        topk: Number of highest-probability tokens kept when sampling.

    Returns:
        The decoded continuation only (prompt tokens are excluded), with
        special tokens removed.
    """
    input_text = "Write a greentext from 4chan.org. The story should be like a bullet-point list using > as the start of each line. Most greentexts are humorous or absurd in nature. Most greentexts have a twist near the end.\n"

    text = text or ""
    # Every greentext line starts with ">": add one if the user omitted it,
    # then open a new ">" line for the model to continue.
    if not text.startswith(">"):
        input_text += ">" + text + "\n>"
    else:
        input_text += text + "\n>"

    encoded = tokenizer(input_text, return_tensors="pt")
    # NOTE(review): the final prompt token is dropped (presumably the
    # trailing ">"), so the model regenerates it and the returned text starts
    # on a new line. This is kept from the original; intent unconfirmed.
    input_ids = encoded["input_ids"][:, :-1]
    attention_mask = encoded["attention_mask"][:, :-1]
    prompt_len = input_ids.size(1)

    # Gradio sliders may deliver floats; generate() expects ints for these.
    temperature = float(temperature)
    gen_kwargs = {
        "max_new_tokens": int(length),
        "no_repeat_ngram_size": int(penalty),
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_k=int(topk))
    else:
        # temperature == 0 with do_sample=True raises in transformers;
        # treat it as "no randomness", i.e. greedy decoding.
        gen_kwargs["do_sample"] = False

    output = model.generate(input_ids, attention_mask=attention_mask, **gen_kwargs)

    # Decode only the newly generated tokens of the single batch row.
    return tokenizer.decode(output[0, prompt_len:], skip_special_tokens=True)
|
|
|
# Sample opening lines offered under the UI. Each inner list holds a single
# value, which fills only the first input (the text box).
examples = [[opener] for opener in ("be me", "be going to heaven", "be a plague doctor")]
|
|
|
# Web UI definition. The five inputs map positionally onto
# generate(text, length, penalty, temperature, topk).
# Uses the top-level gr.Textbox / gr.Slider components with value= instead of
# the deprecated gr.inputs / gr.outputs namespaces and default=, which were
# removed in Gradio 4.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(lines=5, label="Input Text"),
        gr.Slider(minimum=5, maximum=200, value=100, step=5, label='Length'),
        gr.Slider(minimum=1, maximum=10, value=2, step=1, label='no repeat ngram size'),
        gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.2,
            step=0.1,
            label='Temperature - control randomness',
        ),
        gr.Slider(minimum=10, maximum=100, value=40, step=10, label="top_k"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    examples=examples,
    title="Pythia-Greentext Playground",
    description="Using the 1.4b size model. You may need to run it a few times in order to get something good!",
)
|
|
|
# Start the web server only when executed as a script, so importing this
# module (e.g. from tests or tooling) does not block inside launch().
if __name__ == "__main__":
    demo.launch()