File size: 4,500 Bytes
2f43921
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
809ed8d
2f43921
 
809ed8d
 
 
2f43921
 
 
 
 
 
 
 
 
 
 
809ed8d
2f43921
809ed8d
2f43921
809ed8d
 
 
2f43921
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d27c2b8
2f43921
 
 
 
 
3fef0ca
8850909
a452def
8da63a9
 
 
 
 
 
2f43921
 
 
2430d13
2f43921
 
 
 
 
 
3fef0ca
 
2f43921
 
 
2430d13
2f43921
 
 
 
809ed8d
 
2f43921
d79eded
 
d27c2b8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import gradio as gr

import torch
import open_clip
import mediapy as media
from optim_utils import *

import argparse

# Runtime settings: start from an empty namespace and fill it from the JSON config.
args = argparse.Namespace()
vars(args).update(read_json("sample_config.json"))
# Assigned after the config is applied, so these values win over any
# same-named keys coming from the file.
args.print_step = None
# Running count of handled requests; incremented and printed by the handlers.
args.counter = 0

# Use CUDA when torch reports it as available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the CLIP model named in the config together with its preprocessing
# transform; the second element of the returned triple is not used here.
model, _, preprocess = open_clip.create_model_and_transforms(
    args.clip_model,
    pretrained=args.clip_pretrain,
    device=device,
)

def inference(target_image, prompt_len, iter):
    """Run prompt optimization against a single target image.

    Args:
        target_image: Image forwarded to ``optimize_prompt`` as the only
            entry of ``target_images``.
        prompt_len: Requested prompt length. ``None`` selects 8; any other
            value is converted with ``int`` and clamped to the range 1..75,
            the same upper limit ``inference_text`` applies and the UI
            label advertises.
        iter: Requested number of optimization steps. ``None`` selects
            1000; any other value is converted with ``int`` and clamped to
            the range 1..3000.

    Returns:
        The value returned by ``optimize_prompt``.
    """
    args.counter += 1
    print(args.counter)

    # The request queue runs several handlers concurrently, so per-request
    # settings go on a shallow copy rather than the shared module-level
    # namespace, where one request could overwrite another's values.
    run_args = argparse.Namespace(**vars(args))

    if prompt_len is None:
        run_args.prompt_len = 8
    else:
        run_args.prompt_len = max(1, min(int(prompt_len), 75))

    if iter is None:
        run_args.iter = 1000
    else:
        run_args.iter = max(1, min(int(iter), 3000))

    learned_prompt = optimize_prompt(
        model, preprocess, run_args, device, target_images=[target_image]
    )

    return learned_prompt
    
def inference_text(target_prompt, prompt_len, iter):
    """Run prompt optimization against a single target text prompt.

    Args:
        target_prompt: Text forwarded to ``optimize_prompt`` as the only
            entry of ``target_prompts``.
        prompt_len: Requested prompt length. ``None`` selects 8; any other
            value is converted with ``int`` and clamped to the range 1..75.
        iter: Requested number of optimization steps. ``None`` selects
            1000; any other value is converted with ``int`` and clamped to
            the range 1..3000.

    Returns:
        The value returned by ``optimize_prompt``.
    """
    args.counter += 1
    print(args.counter)

    # The request queue runs several handlers concurrently, so per-request
    # settings go on a shallow copy rather than the shared module-level
    # namespace, where one request could overwrite another's values.
    run_args = argparse.Namespace(**vars(args))

    if prompt_len is None:
        run_args.prompt_len = 8
    else:
        run_args.prompt_len = max(1, min(int(prompt_len), 75))

    if iter is None:
        run_args.iter = 1000
    else:
        run_args.iter = max(1, min(int(iter), 3000))

    learned_prompt = optimize_prompt(
        model, preprocess, run_args, device, target_prompts=[target_prompt]
    )

    return learned_prompt


# NOTE(review): this Progress instance is created and immediately discarded.
# Gradio's documentation describes enabling tqdm tracking by declaring
# `progress=gr.Progress(track_tqdm=True)` as a default argument of the event
# handler, so this statement presumably has no effect — confirm, and move it
# into the handlers if progress bars are wanted.
gr.Progress(track_tqdm=True)

# Enable the request queue with a default concurrency limit of 5.
demo = gr.Blocks().queue(default_concurrency_limit=5)

with demo:
    # Page header and introductory text.
    gr.Markdown("# PEZ Dispenser")
    gr.Markdown("## Hard Prompts Made Easy (PEZ)")
    gr.Markdown("*Want to generate a text prompt for your image that is useful for Stable Diffusion?*")
    gr.Markdown("This space can either generate a text fragment that describes your image, or it can shorten an existing text prompt. This space is using OpenCLIP-ViT/H, the same text encoder used by Stable Diffusion V2. After you generate a prompt, try it out on Stable Diffusion [here](https://huggingface.co/stabilityai/stable-diffusion-2-1-base), [here](https://huggingface.co/spaces/stabilityai/stable-diffusion) or on [Midjourney](https://docs.midjourney.com/). For a quick PEZ demo, try clicking on one of the examples at the bottom of this page.")
    gr.Markdown("For additional details, you can check out the [paper](https://arxiv.org/abs/2302.03668) and the code on [Github](https://github.com/YuxinWenRick/hard-prompts-made-easy).")
    gr.Markdown("Note: Generation with 1000 steps takes ~60 seconds with a T4. Don't want to wait? You can also run on [Google Colab](https://colab.research.google.com/drive/1VSFps4siwASXDwhK_o29dKA9COvTnG8A?usp=sharing). Or, you can reduce the number of steps.")
    # "Duplicate this Space" badge; the paragraph is closed with a proper </p>.
    gr.HTML("""
<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
<br/>
<a href="https://huggingface.co/spaces/tomg-group-umd/pez-dispenser?duplicate=true">
<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
</p>""")
    with gr.Row():
        # Left column: both input modes plus the settings they share.
        with gr.Column():
            gr.Markdown("### Image to Prompt")
            input_image = gr.Image(type="pil", label="Target Image")
            image_button = gr.Button("Generate Prompt")

            gr.Markdown("### Long Prompt to Short Prompt")
            input_prompt = gr.Textbox(label="Target Prompt")
            prompt_button = gr.Button("Distill Prompt")

            prompt_len_field = gr.Number(label="Prompt Length (max 75, recommend 8-16)", value=8)
            num_step_field = gr.Number(label="Optimization Steps (max 3000 because of limited resources)", value=1000)

        # Right column: the result of whichever button was pressed.
        with gr.Column():
            gr.Markdown("### Learned Prompt")
            output_prompt = gr.Textbox(label="Learned Prompt")

    # Both buttons write to the same output textbox.
    image_button.click(inference, inputs=[input_image, prompt_len_field, num_step_field], outputs=output_prompt)
    prompt_button.click(inference_text, inputs=[input_prompt, prompt_len_field, num_step_field], outputs=output_prompt)

    # One example per mode, each with example caching enabled.
    gr.Examples([["sample.jpeg", 8, 1000]], inputs=[input_image, prompt_len_field, num_step_field], fn=inference, outputs=output_prompt, cache_examples=True)
    gr.Examples([["digital concept art of old wooden cabin in florida swamp, trending on artstation", 3, 1000]], inputs=[input_prompt, prompt_len_field, num_step_field], fn=inference_text, outputs=output_prompt, cache_examples=True)

    gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=tomg-group-umd_pez-dispenser)")

demo.launch()