"""Gradio demo that chains an English-to-Russian translator into ruDALL-E.

The user's prompt is first sent through the ``facebook/wmt19-en-ru``
translation interface; the translated text is then passed to
``dalle_wrapper``, which generates a single image with ruDALL-E.
"""

import html
import random

import torch
import gradio as gr
from gradio.mix import Series
from rudalle.pipelines import generate_images
from rudalle import get_rudalle_model, get_tokenizer, get_vae

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): fp16=True is requested even when `device` is the CPU;
# half precision may not be supported there -- confirm against rudalle.
dalle = get_rudalle_model("Malevich", pretrained=True, fp16=True, device=device)
tokenizer = get_tokenizer()
vae = get_vae().to(device)

# (top_k, top_p) sampling presets; one pair is picked at random per request
# so that repeated submissions of the same prompt give different images.
_SAMPLING_PRESETS = [
    (1024, 0.98),
    (512, 0.97),
    (384, 0.96),
]


def dalle_wrapper(prompt: str):
    """Generate one image for *prompt* and return it with a caption.

    Args:
        prompt: Text passed to ``generate_images`` (in this app it is the
            output of the translation interface).

    Returns:
        A ``(title, image)`` pair: ``title`` is the HTML-escaped prompt for
        the HTML output component, ``image`` is the first (and only)
        element of the image list returned by ``generate_images``.
    """
    top_k, top_p = random.choice(_SAMPLING_PRESETS)

    # generate_images returns a pair; only the image list is used here.
    images, _ = generate_images(
        prompt,
        tokenizer,
        dalle,
        vae,
        top_k=top_k,
        images_num=1,
        top_p=top_p,
    )

    # The title is rendered by an HTML output component, so escape it to
    # keep characters such as "<" or "&" from being interpreted as markup.
    title = html.escape(prompt)
    return title, images[0]


translator = gr.Interface.load(
    "huggingface/facebook/wmt19-en-ru",
    inputs=[gr.inputs.Textbox(label="What would you like to see?")],
)

outputs = [
    gr.outputs.HTML(label=""),
    gr.outputs.Image(label=""),
]

generator = gr.Interface(fn=dalle_wrapper, inputs="text", outputs=outputs)

description = (
    "ruDALL-E is a 1.3B params text-to-image model by SberAI (links at the bottom). "
    "This demo uses an English-Russian translation model to adapt the prompts. "
    "Try pressing [Submit] multiple times to generate new images!"
)

# NOTE(review): the description promises "links at the bottom", but this
# text only carries the link labels -- the link markup/URLs appear to be
# missing. TODO: restore the original anchors.
article = (
    "\n"
    "GitHub | "
    "Article (in Russian)"
    "\n"
)

examples = [
    ["A still life of grapes and a bottle of wine"],
    ["Город в стиле киберпанк"],
    ["A colorful photo of a coral reef"],
    ["A white cat sitting in a cardboard box"],
]

# Chain the two interfaces: translator output feeds the image generator.
series = Series(
    translator,
    generator,
    title='Kinda-English ruDALL-E',
    description=description,
    article=article,
    layout='horizontal',
    theme='huggingface',
    examples=examples,
    allow_flagging=False,
    live=False,
    enable_queue=True,
)

series.launch()