"""Gradio demo for the openskyml/open-diffusion-v1 text-to-image model.

Prompts are screened against a bad-words list, then sent to the Hugging Face
Inference API endpoint in ``API_URL``; the returned images are shown in a
gallery.
"""

import io
import os
import random
import re

import gradio as gr
import requests
from datasets import load_dataset
from huggingface_hub import login
from PIL import Image

login(token=os.getenv("HF_READ_TOKEN"))

API_URL = "https://api-inference.huggingface.co/models/openskyml/open-diffusion-v1"
API_TOKEN = os.getenv("HF_READ_TOKEN")  # it is free
headers = {"Authorization": f"Bearer {API_TOKEN}"}

# Seconds to wait for one inference request. Without a timeout, ``requests``
# can block a worker indefinitely on a stalled connection.
REQUEST_TIMEOUT = 120

# NOTE(review): newer `datasets` releases deprecate `use_auth_token` in favour
# of `token`; kept as-is for compatibility with the installed version.
word_list_dataset = load_dataset(
    "openskyml/bad-words-prompt-list", data_files="en.txt", use_auth_token=True
)
word_list = word_list_dataset["train"]["text"]


def _build_banned_pattern(words):
    """Compile one case-insensitive, whole-word regex matching any of *words*.

    Entries are stripped and regex-escaped, so stray whitespace or regex
    metacharacters in the list can neither prevent a match nor raise
    ``re.error``. Blank entries are skipped because an empty alternative
    would match every prompt. Returns ``None`` when nothing remains to match.
    """
    cleaned = [w.strip() for w in words if w and w.strip()]
    if not cleaned:
        return None
    alternation = "|".join(re.escape(w) for w in cleaned)
    return re.compile(rf"\b(?:{alternation})\b", re.IGNORECASE)


# Built once at start-up instead of compiling one regex per word per request.
_BANNED_WORDS_RE = _build_banned_pattern(word_list)


def _contains_banned_word(text):
    """Return True if *text* contains any word from the bad-words list."""
    return _BANNED_WORDS_RE is not None and _BANNED_WORDS_RE.search(text) is not None


def _request_image(payload):
    """POST *payload* to ``API_URL`` and decode the response body as an image.

    Raises:
        gr.Error: on a network failure, a non-200 response, or a response
            body that PIL cannot decode as an image (e.g. a JSON error body).
    """
    try:
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=REQUEST_TIMEOUT
        )
    except requests.RequestException as err:
        raise gr.Error(f"Request to the inference API failed: {err}") from err

    if response.status_code != 200:
        # Surface the server's message instead of letting PIL fail on it.
        raise gr.Error(
            f"Inference API returned HTTP {response.status_code}: {response.text[:200]}"
        )

    try:
        return Image.open(io.BytesIO(response.content))
    except OSError as err:  # PIL.UnidentifiedImageError is an OSError subclass
        raise gr.Error("Inference API did not return a valid image.") from err


def query(prompt, is_negative=False, steps=5, cfg_scale=7, seed=None, num_images=4):
    """Generate ``num_images`` images for *prompt* via the inference API.

    Args:
        prompt: Text prompt. It is rejected if it contains a listed bad word
            (whole-word, case-insensitive). ", 8k" is appended before sending.
        is_negative: Sent as the payload's ``"is_negative"`` value. The UI
            below wires the negative-prompt textbox to this argument, so in
            practice it carries the negative prompt text.
        steps: Forwarded unchanged as ``"steps"``.
        cfg_scale: Forwarded unchanged as ``"cfg_scale"``.
        seed: Seed sent with every request. When ``None``, each request draws
            its own seed from ``random.randint(-1, 2147483647)``. A fixed seed
            sends identical payloads for every image.
        num_images: Number of sequential requests, one image each.

    Returns:
        A list of PIL images, one per request.

    Raises:
        gr.Error: if the prompt is unsafe or any request fails.
    """
    if _contains_banned_word(prompt):
        raise gr.Error("Unsafe content found. Please try again with different prompts.")

    images = []
    for _ in range(num_images):
        # NOTE(review): these keys are sent top-level as-is; the HF Inference
        # API text-to-image docs describe options nested under "parameters"
        # (e.g. "negative_prompt") - confirm this endpoint honours them.
        payload = {
            "inputs": prompt + ", 8k",
            "is_negative": is_negative,
            "steps": steps,
            "cfg_scale": cfg_scale,
            "seed": seed if seed is not None else random.randint(-1, 2147483647),
        }
        images.append(_request_image(payload))
    return images


css = (
    " .gradio-container { font-family: 'IBM Plex Sans', sans-serif; }"
    " #gallery { min-height: 22rem; margin-bottom: 15px; margin-left: auto;"
    " margin-right: auto; border-bottom-right-radius: .5rem !important;"
    " border-bottom-left-radius: .5rem !important; }"
    " #gallery>div>.h-full { min-height: 20rem; }"
    " #prompt-text-input, #negative-prompt-text-input{padding: .45rem 0.625rem}"
    " #component-16{border-top-width: 1px!important;margin-top: 1em}"
    " .image_duplication{position: absolute; width: 100px; left: 50px} "
)

# NOTE(review): `.style(grid=...)` and `gr.Box` are Gradio 3.x APIs - confirm
# the pinned Gradio version before upgrading.
with gr.Blocks(css=css) as demo:
    gr.HTML("\n\nOpen Diffusion 1.0 Demo\n\n")
    with gr.Row():
        gallery_output = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2])
    with gr.Row():
        with gr.Box():
            text_prompt = gr.Textbox(
                show_label=False, placeholder="Enter your prompt", max_lines=1
            )
            negative_prompt = gr.Textbox(
                show_label=False, placeholder="Enter a negative", max_lines=1
            )
            text_button = gr.Button(
                "Generate",
                icon="https://www.gstatic.com/android/keyboard/emojikitchen/20210521/u1fa84/u1fa84_u1fa84.png",
            )
    # The negative-prompt text is passed positionally into query's `is_negative`.
    text_button.click(query, inputs=[text_prompt, negative_prompt], outputs=gallery_output)

demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)