Spaces: Runtime error
import html
import random

import gradio as gr
import torch
from gradio.mix import Series
from rudalle import get_rudalle_model, get_tokenizer, get_vae
from rudalle.pipelines import generate_images
# Prefer the GPU; fall back to the CPU when CUDA is not available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Request half precision only on CUDA. The original passed fp16=True
# unconditionally, which also converts the model to float16 on the CPU
# fallback, where many PyTorch kernels have no Half implementation and
# generation fails.
# NOTE(review): confirm against the torch version deployed on the Space.
dalle = get_rudalle_model(
    "Malevich",
    pretrained=True,
    fp16=device.type == "cuda",
    device=device,
)
tokenizer = get_tokenizer()
# Keep the VAE on the same device as the transformer.
vae = get_vae().to(device)
def dalle_wrapper(prompt: str):
    """Generate a single ruDALL-E image for *prompt*.

    One (top_k, top_p) sampling pair is chosen at random on every call, so
    submitting the same prompt repeatedly gives varied results.

    Args:
        prompt: Text to condition generation on. In this app it is the output
            of the translation interface, so it is derived from user input.

    Returns:
        A ``(title_html, image)`` pair: the HTML-escaped prompt wrapped in
        ``<b>`` for the HTML output component, and the first image returned
        by ``generate_images``.
    """
    # Candidate (top_k, top_p) sampling settings; one is picked per request.
    top_k, top_p = random.choice([
        (1024, 0.98),
        (512, 0.97),
        (384, 0.96),
    ])
    images, _ = generate_images(
        prompt,
        tokenizer,
        dalle,
        vae,
        top_k=top_k,
        images_num=1,
        top_p=top_p,
    )
    # The prompt originates from user input and is rendered as HTML, so it
    # must be escaped to prevent markup/script injection into the page.
    title = f"<b>{html.escape(prompt)}</b>"
    return title, images[0]
# First stage: the hosted facebook/wmt19-en-ru English->Russian model,
# loaded through Gradio's "huggingface/" prefix. It owns the prompt textbox
# the user types into.
translator = gr.Interface.load(
    "huggingface/facebook/wmt19-en-ru",
    inputs=[
        gr.inputs.Textbox(label="What would you like to see?"),
    ],
)
# Second stage outputs, in order: an HTML caption, then the generated image.
outputs = [
    component(label="")
    for component in (gr.outputs.HTML, gr.outputs.Image)
]
# Second stage: takes plain text and runs it through dalle_wrapper.
generator = gr.Interface(
    fn=dalle_wrapper,
    outputs=outputs,
    inputs="text",
)
# Blurb displayed under the page title.
description = " ".join([
    "ruDALL-E is a 1.3B params text-to-image model by SberAI (links at the bottom).",
    "This demo uses an English-Russian translation model to adapt the prompts.",
    "Try pressing [Submit] multiple times to generate new images!",
])
# Centered footer with links to the project sources, separated by " | ".
article = "<p style='text-align: center'>{}</p>".format(
    " | ".join(
        f"<a href='{href}'>{text}</a>"
        for href, text in (
            ("https://github.com/sberbank-ai/ru-dalle", "GitHub"),
            ("https://habr.com/ru/company/sberbank/blog/586926/", "Article (in Russian)"),
        )
    )
)
# Clickable example prompts; each example is a one-element list because the
# interface has a single text input. The second one is already in Russian
# ("A city in cyberpunk style").
examples = [
    [sample]
    for sample in (
        "A still life of grapes and a bottle of wine",
        "Город в стиле киберпанк",
        "A colorful photo of a coral reef",
        "A white cat sitting in a cardboard box",
    )
]
# Chain the two interfaces: the translator's output is fed to the generator.
series = Series(
    translator,
    generator,
    # Page text.
    title="Kinda-English ruDALL-E",
    description=description,
    article=article,
    examples=examples,
    # Presentation.
    layout="horizontal",
    theme="huggingface",
    # Behavior: run only on Submit, no flag button, requests queued.
    live=False,
    allow_flagging=False,
    enable_queue=True,
)
series.launch()