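# Spaces: Runtime error
# (The line above appears to be hosting-page status text captured together
# with this file; it is not part of the program.)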
import functools
import os
import random
import re
import urllib
import urllib.parse
import urllib.request
from contextlib import nullcontext
from typing import List
from xml.etree import ElementTree

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from torch import autocast
# The eighteen elemental types a generated Pokémon can be given.
pokemon_types = [
    "Normal",
    "Water",
    "Fire",
    "Ice",
    "Psychic",
    "Rock",
    "Dark",
    "Electric",
    "Grass",
    "Fighting",
    "Poison",
    "Ground",
    "Flying",
    "Bug",
    "Ghost",
    "Dragon",
    "Steel",
    "Fairy",
]

# Selectable options: two special entries followed by every concrete type.
type_choices = ["None", "Random", *pokemon_types]
# Module-level slot for a paper title; starts out empty.
paper_name = None

# Use the GPU with half precision when CUDA is available; otherwise run on
# the CPU in full precision with a no-op context in place of autocast.
device = "cuda" if torch.cuda.is_available() else "cpu"
context = autocast if device == "cuda" else nullcontext
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionPipeline.from_pretrained("lambdalabs/sd-pokemon-diffusers", torch_dtype=dtype)
pipe = pipe.to(device)

# Sometimes the nsfw checker is confused by the Pokémon images, you can disable
# it at your own risk here
disable_safety = True

if disable_safety:

    def null_safety(images, **kwargs):
        """Stand-in safety checker that passes every image through unflagged.

        Returns the images unchanged together with one ``False`` flag per
        image, so code that iterates over the flags keeps working.
        """
        return images, [False] * len(images)

    pipe.safety_checker = null_safety
def infer(prompt, n_samples, steps, scale):
    """Generate images for ``prompt`` with the module-level pipeline.

    Args:
        prompt: Text prompt; it is repeated once per requested image.
        n_samples: Number of images to generate.
        steps: Number of inference steps passed to the pipeline.
        scale: Guidance scale passed to the pipeline.

    Returns:
        The ``images`` attribute of the pipeline result.
    """
    # UI number widgets may hand back floats (e.g. 2.0); list repetition
    # needs a true int, so coerce both counts up front.
    n_samples = int(n_samples)
    steps = int(steps)

    # `context` is either torch's autocast or nullcontext (chosen at module
    # level); both accept the positional "cuda" argument.
    with context("cuda"):
        images = pipe(
            [prompt] * n_samples,
            guidance_scale=scale,
            num_inference_steps=steps,
        ).images
    return images
def get_paper_name(url: str) -> str:
    """Look up the title of an arXiv paper via the arXiv export API.

    Args:
        url: An arXiv abstract link, a PDF link, or a bare paper ID. Only the
            last path component is used; a ``.pdf`` suffix, query string,
            fragment and trailing slashes are ignored.

    Returns:
        The paper title with all runs of whitespace collapsed to single spaces.

    Raises:
        ValueError: If no paper ID can be extracted from ``url`` or the API
            response contains no titled entry for that ID.
        urllib.error.URLError: If the request to the API fails.
    """
    atom = "{http://www.w3.org/2005/Atom}"

    # Reduce the input to the bare ID before touching the network.
    cleaned = (url or "").strip().split("#", 1)[0].split("?", 1)[0].rstrip("/")
    paper_id = os.path.basename(cleaned).split(".pdf")[0]
    if not paper_id:
        raise ValueError("No arXiv paper ID could be extracted from the given link.")

    # Quote the ID so stray characters cannot alter the query string.
    query_url = f"http://export.arxiv.org/api/query?id_list={urllib.parse.quote(paper_id)}"
    hdr = {"Content-Type": "application/atom+xml"}
    req = urllib.request.Request(query_url, headers=hdr)

    # Close the connection deterministically and do not hang forever on a
    # stalled request.
    with urllib.request.urlopen(req, timeout=10) as response:
        tree = ElementTree.fromstring(response.read().decode("utf-8"))

    entry = tree.find(f"{atom}entry")
    title_node = entry.find(f"{atom}title") if entry is not None else None
    if title_node is None or not title_node.text:
        raise ValueError(f"No arXiv paper found for ID {paper_id!r}.")

    # Titles can be wrapped over several lines in the feed; collapse every
    # run of whitespace (including newlines) into one space.
    return " ".join(title_node.text.split())
# Top-level Gradio container; the UI is assembled inside `with block:` below.
block = gr.Blocks()
# Rows for the examples widget. Each row is
# [arXiv link, number of images, guidance scale].
examples = [
    [f"https://arxiv.org/abs/{arxiv_id}", 2, 7.5]
    for arxiv_id in (
        "1706.03762",
        "1404.5997v2",
        "2010.11929",
        "1810.04805v2",
    )
]
def resolve_poke_type(pok_type: str):
    """Turn the radio selection into the type text used in the prompt.

    ``"None"`` yields an empty string, ``"Random"`` picks one entry of
    ``pokemon_types`` at random, and any other value is returned unchanged.
    """
    if pok_type == "None":
        return ""
    if pok_type == "Random":
        return random.choice(pokemon_types)
    return pok_type


def build_prompt_text(paper_name, pok_type, add_ideas):
    """Assemble the prompt from the paper title, optional type and extra ideas.

    The parts are joined with single spaces, so no trailing space is left
    behind when ``add_ideas`` is empty.
    """
    parts = [f"{paper_name} as {pok_type} type" if pok_type != "" else f"{paper_name}"]
    parts.extend(add_ideas or [])
    return " ".join(parts)


@functools.lru_cache(maxsize=128)
def _lookup_paper_name(link: str) -> str:
    """Resolve ``link`` to a paper title, remembering recent successful lookups.

    Keying the cache by link keeps the title in step with the text box, and
    the bound keeps memory use fixed. Failed lookups raise and are not cached.
    """
    return get_paper_name(link)


def update_prompt_link(new_link: str, pok_type: str, prompt_ideas: List[str]):
    """Rebuild the prompt after the paper link changed or was submitted."""
    title = _lookup_paper_name((new_link or "").strip())
    return build_prompt_text(title, resolve_poke_type(pok_type), prompt_ideas)


def update_prompt_type(paper_link: str, pok_type: str, prompt_ideas: List[str]):
    """Rebuild the prompt after the type or the extra ideas changed."""
    title = _lookup_paper_name((paper_link or "").strip())
    return build_prompt_text(title, resolve_poke_type(pok_type), prompt_ideas)


# NOTE(review): the container nesting below was reconstructed from a source
# whose indentation had been flattened -- confirm the layout.
with block:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 650px; margin: 50px auto;">
        <div>
        <h1 style="font-weight: 900; font-size: 3rem;">
        Paper to Pokémon
        </h1>
        </div>
        <p style="margin-bottom: 10px; margin-top: 30px; font-size: 94%">
        Generate new Pokémon from an arXiv link. Just paste the link to the overview, the pdf or just give the ID of the paper.
        It will create a prompt with the paper title, which you can then modify as you like or submit as it is.
        For general better quality increase the step size. (This will also increase the processing time)
        </p>
        </div>
        """
    )
    with gr.Group():
        with gr.Box():
            with gr.Row().style(mobile_collapse=False, equal_height=True):
                text = gr.Textbox(
                    label="Link or ID for paper",
                    show_label=False,
                    max_lines=1,
                    placeholder="Give arXiv link or ID for the paper",
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                btn = gr.Button("Generate image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
        poke_type = gr.Radio(choices=type_choices, value="None", label="Pokemon Type")
        prompt_ideas = gr.CheckboxGroup(
            choices=[
                "as a bird",
                "with four legs",
                "with wings",
                "as a koala",
                "with a beak",
                "looking like a llama",
            ],
            label="Additional prompt ideas",
        )
        prompt_box = gr.Textbox(placeholder="Your prompt appears here", interactive=True, label="Prompt")
        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2], height="auto")
        with gr.Row(elem_id="advanced-options"):
            samples = gr.Slider(label="Images", minimum=1, maximum=4, value=2, step=1)
            steps = gr.Slider(label="Steps", minimum=5, maximum=50, value=25, step=5)
            scale = gr.Slider(
                label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
            )
        # NOTE(review): `infer` takes four arguments but only three inputs are
        # listed here; this presumably only matters if the examples are ever
        # executed or cached -- confirm before enabling `cache_examples`.
        ex = gr.Examples(examples=examples, fn=infer, inputs=[text, samples, scale], outputs=gallery, cache_examples=False)
        ex.dataset.headers = [""]

        # Any change to the link, the type or the idea list rewrites the prompt box.
        text.change(update_prompt_link, inputs=[text, poke_type, prompt_ideas], outputs=prompt_box)
        text.submit(update_prompt_link, inputs=[text, poke_type, prompt_ideas], outputs=prompt_box)
        poke_type.change(update_prompt_type, inputs=[text, poke_type, prompt_ideas], outputs=prompt_box)
        prompt_ideas.change(update_prompt_type, inputs=[text, poke_type, prompt_ideas], outputs=prompt_box)
        # Generation uses whatever is currently in the (editable) prompt box.
        btn.click(infer, inputs=[prompt_box, samples, steps, scale], outputs=gallery)
    gr.HTML(
        """
        <div class="footer" style="text-align: center; max-width: 650px; margin: 50px auto;">
        <p>Inspired by and cloned from the great <a href="https://huggingface.co/spaces/lambdalabs/text-to-pokemon">
        Text-to-Pokémon</a> space by Lambda labs</p>
        <p> Gradio Demo by johko</p>
        </div>
        """
    )

block.launch()