fffiloni's picture
Update app.py
ad26302
raw
history blame
1.65 kB
from diffusers import StableDiffusionInpaintPipeline
import gradio as gr
import numpy as np
import imageio
from PIL import Image
from io import BytesIO
import os
# Access token read from the HF_TOKEN_SD environment variable (None when it is
# unset); it is passed as ``use_auth_token`` to ``from_pretrained`` below.
MY_SECRET_TOKEN = YOUR_TOKEN = os.environ.get('HF_TOKEN_SD')
print("hello sylvain")

# All inference runs on the CPU.
device = "cpu"

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=YOUR_TOKEN,
)
pipe.to(device)

# Input component: an uploaded image drawn on with the sketch tool. predict()
# reads the "image" and "mask" entries of the value this component produces.
source_img = gr.Image(
    source="upload",
    type="numpy",
    tool="sketch",
    elem_id="source_container",
)

# Output component: a two-column gallery of generated images.
gallery = gr.Gallery(
    label="Generated images",
    show_label=False,
    elem_id="gallery",
).style(grid=[2], height="auto")
def resize(height, img):
    """Load an image and scale it to ``height`` pixels tall, keeping its aspect ratio.

    Args:
        height: Target height in pixels.
        img: Anything ``PIL.Image.open`` accepts (callers pass a file path).

    Returns:
        A new PIL image of size ``(width, height)``. ``width`` is the source
        width scaled by the same factor as the height, rounded to the nearest
        pixel and never less than 1.
    """
    # The context manager closes the file handle deterministically. resize()
    # returns a new, fully loaded image, so closing the source here is safe.
    with Image.open(img) as source:
        src_width, src_height = source.size
        scale = height / float(src_height)
        # round() (not int() truncation) keeps the aspect ratio as close as
        # possible; max(1, ...) prevents a zero-width resize for very tall inputs.
        new_width = max(1, int(round(src_width * scale)))
        return source.resize((new_width, height), Image.Resampling.LANCZOS)
def predict(prompt, source_img):
    """Inpaint the masked region of an uploaded image, guided by a text prompt.

    Args:
        prompt: Text prompt passed to the inpainting pipeline.
        source_img: Value of the sketch-enabled ``gr.Image`` input; its
            ``"image"`` and ``"mask"`` entries are written out with imageio.

    Returns:
        A list of images for the gallery. Any output the pipeline flags in
        ``"nsfw_content_detected"`` is replaced by the ``unsafe.png`` placeholder.
    """
    # resize() takes a path, so both arrays round-trip through PNG files.
    imageio.imwrite("data.png", source_img["image"])
    imageio.imwrite("data_mask.png", source_img["mask"])
    src = resize(512, "data.png")
    # NOTE(review): src.png / mask.png are written but not read in this function;
    # kept to preserve the existing on-disk side effects.
    src.save("src.png")
    mask = resize(512, "data_mask.png")
    mask.save("mask.png")

    # Use the module-level pipeline ``pipe``. The previous code called an
    # undefined ``img_pipe``, which raised NameError on every request.
    output = pipe([prompt], init_image=src, mask_image=mask, strength=0.75)
    samples = output["sample"]
    # NOTE(review): a None flag list is treated as "nothing flagged"; confirm
    # against the installed diffusers version's output format.
    flags = output["nsfw_content_detected"] or [False] * len(samples)

    images = []
    safe_image = None
    for i, image in enumerate(samples):
        if flags[i]:
            # Load the placeholder only when it is actually needed. Copying it
            # inside a `with` block releases the file handle immediately.
            if safe_image is None:
                with Image.open("unsafe.png") as placeholder:
                    safe_image = placeholder.copy()
            images.append(safe_image)
        else:
            images.append(image)
    return images
# Stylesheet handed to gr.Interface through its ``css`` argument.
custom_css = "style.css"

# Wire predict() to a text box plus the sketch image input, rendering results
# into the gallery component, then start the app with queueing enabled.
demo = gr.Interface(
    fn=predict,
    inputs=["text", source_img],
    outputs=gallery,
    css=custom_css,
)
demo.launch(enable_queue=True)