rudall-e / app.py
import random
import torch
import gradio as gr
from gradio.mix import Series
from rudalle.pipelines import generate_images
from rudalle import get_rudalle_model, get_tokenizer, get_vae
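
# Pick the GPU when available and load the 1.3B-parameter Malevich checkpoint
# in fp16, along with its tokenizer and the VAE image decoder.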
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dalle = get_rudalle_model("Malevich", pretrained=True, fp16=True, device=device)
tokenizer = get_tokenizer()
vae = get_vae().to(device)
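
# Generate one image for the (already translated) prompt. A sampling preset
# (top_k, top_p) is picked at random so repeated submissions give varied results.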
def dalle_wrapper(prompt: str):
    top_k, top_p = random.choice([
        (1024, 0.98),
        (512, 0.97),
        (384, 0.96),
    ])
    images, _ = generate_images(
        prompt,
        tokenizer,
        dalle,
        vae,
        top_k=top_k,
        images_num=1,
        top_p=top_p,
    )
    title = f"<b>{prompt}</b>"
    return title, images[0]
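
# Load the facebook/wmt19-en-ru model (via the Hugging Face inference API) to
# translate English prompts into Russian before they reach ruDALL-E.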
translator = gr.Interface.load("huggingface/facebook/wmt19-en-ru",
                               inputs=[gr.inputs.Textbox(label="What would you like to see?")])
outputs = [
    gr.outputs.HTML(label=""),
    gr.outputs.Image(label=""),
]
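
# Text-to-image interface: returns the prompt as an HTML title plus the generated image.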
generator = gr.Interface(fn=dalle_wrapper, inputs="text", outputs=outputs)
description = (
    "ruDALL-E is a 1.3B-parameter text-to-image model by SberAI (links at the bottom). "
    "This demo uses an English-Russian translation model to adapt the prompts. "
    "Try pressing [Submit] multiple times to generate new images!"
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://github.com/sberbank-ai/ru-dalle'>GitHub</a> | "
    "<a href='https://habr.com/ru/company/sberbank/blog/586926/'>Article (in Russian)</a>"
    "</p>"
)
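
# Example prompts; the second one is Russian for "A city in cyberpunk style".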
examples = [
    ["A still life of grapes and a bottle of wine"],
    ["Город в стиле киберпанк"],
    ["A colorful photo of a coral reef"],
    ["A white cat sitting in a cardboard box"],
]
series = Series(translator, generator,
                title='Kinda-English ruDALL-E',
                description=description,
                article=article,
                layout='horizontal',
                theme='huggingface',
                examples=examples,
                allow_flagging=False,
                live=False,
                enable_queue=True,
                )
series.launch()