File size: 2,167 Bytes
41b65c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3781df4
 
41b65c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import inspect
import os
from io import BytesIO
from typing import List, Optional, Union

import gradio as gr
import numpy as np
import PIL
import PIL.Image
import PIL.ImageOps
import requests
import torch
from diffusers import StableDiffusionInpaintPipeline
from huggingface_hub import login
from rembg import remove

# Authenticate against the Hugging Face Hub using the value of the
# WRITE_TOKEN environment variable.
# NOTE(review): os.getenv returns None when WRITE_TOKEN is unset; confirm
# that login(None, ...) behaves acceptably in the deployment environment.
token = os.getenv("WRITE_TOKEN")
# The flag is passed by keyword so its meaning is visible at the call site
# and the call keeps working where the argument is keyword-only.
login(token, add_to_git_credential=True)

def image_grid(imgs, rows, cols):
    """Tile images into a single ``rows`` x ``cols`` RGB image.

    The size of the first image is used as the cell size, and images are
    pasted in row-major order (left to right, then top to bottom).

    Args:
        imgs: Sequence of images exposing ``.size`` and accepted by
            ``PIL.Image.Image.paste``; must hold exactly ``rows * cols`` items.
        rows: Number of rows in the grid.
        cols: Number of columns in the grid.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)`` where ``(w, h)`` is
        the size of the first image.

    Raises:
        ValueError: If ``imgs`` is empty or its length is not ``rows * cols``.
    """
    # Explicit exceptions rather than ``assert``: assertions are stripped when
    # Python runs with -O, which would let a bad layout through silently.
    count = len(imgs)
    if count == 0:
        raise ValueError("imgs must contain at least one image")
    if count != rows * cols:
        raise ValueError(
            f"expected rows*cols = {rows * cols} images, got {count}"
        )

    w, h = imgs[0].size
    grid = PIL.Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        # divmod maps the flat index to its (row, column) cell.
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
    

def predict(dict, prompt):
    """Run the inpainting pipeline on an image/mask pair and return one result.

    Args:
        dict: Mapping with an ``'image'`` entry and a ``'mask'`` entry; both
            values are converted to RGB and resized to 512x512 before use.
            (The name shadows the builtin but is kept as part of the
            existing signature.)
        prompt: Text passed to the pipeline as ``prompt``.

    Returns:
        The first image in the pipeline output's ``images`` list.
    """
    target_size = (512, 512)
    source = dict['image'].convert("RGB").resize(target_size)
    mask = dict['mask'].convert("RGB").resize(target_size)
    result = pipe(prompt=prompt, image=source, mask_image=mask)
    return result.images[0]

def download_image(url, timeout=30):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds passed to ``requests.get`` as its ``timeout``;
            without one the request could block indefinitely.

    Returns:
        The downloaded image, opened with PIL and converted to RGB.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: For connection failures or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on 4xx/5xx here instead of handing an error page to the image decoder.
    response.raise_for_status()
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")

# Identifier of the inpainting checkpoint handed to from_pretrained.
model_path = "runwayml/stable-diffusion-inpainting"

# Load the pipeline once at import time; predict() and the demo run below
# both use this shared instance.
# The half-precision options (revision="fp16", torch_dtype=torch.float16)
# are left disabled here.
# NOTE(review): use_auth_token=True presumably makes the download use the
# credentials stored by login() -- confirm against the diffusers version in use.
pipe = StableDiffusionInpaintPipeline.from_pretrained(model_path, use_auth_token=True)

# One-off demo run executed at import time: download a sample image, derive a
# mask from it, and generate num_samples inpainted variants.
img_url = "https://cdn.faire.com/fastly/893b071985d70819da5f0d485f1b1bb97ee4f16a6e14ef1bdd4a086b3588be58.png" # wino
image = download_image(img_url).resize((512, 512))
# NOTE(review): presumably rembg's mask marks the foreground subject, so
# inverting it selects the background for repainting -- TODO confirm.
inverted_mask_image = remove(data=image, only_mask=True)
mask_image = PIL.ImageOps.invert(inverted_mask_image)
prompt = "crazy portal universe"

guidance_scale = 7.5
num_samples = 3
# Fixed seed so the demo output is repeatable; change it to get different results.
generator = torch.Generator(device="cpu").manual_seed(0)
images = pipe(
    prompt=prompt,
    image=image,
    mask_image=mask_image,
    guidance_scale=guidance_scale,
    generator=generator,
    num_images_per_prompt=num_samples,
).images
# Put the source image first so the grid shows it next to the generated ones.
images.insert(0, image)
# Keep the composed grid; as a bare expression in a script the result was
# silently discarded.
demo_grid = image_grid(images, 1, num_samples + 1)

# Build and start the web UI. predict() receives the value of the image input
# (from which it reads the 'image' and 'mask' entries) and the prompt text,
# and its return value is shown in the single image output.
gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(source="upload", tool="sketch", type="pil"),
        gr.Textbox(label="prompt"),
    ],
    outputs=[gr.Image()],
    title="Stable Diffusion In-Painting",
).launch(debug=True)