Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -14,19 +14,13 @@ from matplotlib import pyplot as plt
|
|
14 |
from torchvision import transforms
|
15 |
from diffusers import DiffusionPipeline
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
transform = transforms.Compose([
|
26 |
-
transforms.ToTensor(),
|
27 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
28 |
-
transforms.Resize((512, 512)),
|
29 |
-
])
|
30 |
|
31 |
def read_content(file_path: str) -> str:
|
32 |
"""read the content of target file
|
@@ -36,11 +30,11 @@ def read_content(file_path: str) -> str:
|
|
36 |
|
37 |
return content
|
38 |
|
39 |
-
def predict(dict,
|
40 |
-
init_image = dict["image"].convert("RGB")
|
41 |
-
mask = dict["mask"].convert("RGB")
|
42 |
-
|
43 |
-
return
|
44 |
|
45 |
|
46 |
css = '''
|
@@ -89,9 +83,9 @@ with image_blocks as demo:
|
|
89 |
with gr.Box():
|
90 |
with gr.Row():
|
91 |
with gr.Column():
|
92 |
-
image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload")
|
|
|
93 |
with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
|
94 |
-
prompt = gr.Textbox(placeholder = 'Your prompt (what you want in place of what is erased)', show_label=False, elem_id="input-text")
|
95 |
btn = gr.Button("Inpaint!").style(
|
96 |
margin=False,
|
97 |
rounded=(False, True, True, False),
|
@@ -105,7 +99,7 @@ with image_blocks as demo:
|
|
105 |
share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
|
106 |
|
107 |
|
108 |
-
btn.click(fn=predict, inputs=[image,
|
109 |
share_button.click(None, [], [], _js=share_js)
|
110 |
|
111 |
|
|
|
14 |
from torchvision import transforms
|
15 |
from diffusers import DiffusionPipeline
|
16 |
|
17 |
+
pipe = DiffusionPipeline.from_pretrained(
|
18 |
+
"patrickvonplaten/new_inpaint_test",
|
19 |
+
torch_dtype=torch.float16,
|
20 |
+
)
|
21 |
+
pipe = pipe.to("cuda")
|
22 |
|
23 |
+
from share_btn import community_icon_html, loading_icon_html, share_js
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
def read_content(file_path: str) -> str:
|
26 |
"""read the content of target file
|
|
|
30 |
|
31 |
return content
|
32 |
|
33 |
+
def predict(dict, example_image):
|
34 |
+
init_image = dict["image"].convert("RGB")
|
35 |
+
mask = dict["mask"].convert("RGB")
|
36 |
+
image = pipe(image=init_image, mask_image=mask, example_image=example_image).images[0]
|
37 |
+
return image, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
|
38 |
|
39 |
|
40 |
css = '''
|
|
|
83 |
with gr.Box():
|
84 |
with gr.Row():
|
85 |
with gr.Column():
|
86 |
+
image = gr.Image(source='upload', tool='sketch', elem_id="image_upload", type="pil", label="Upload")
|
87 |
+
example = gr.Image(source='upload', elem_id="image_upload", type="pil", label="Upload")
|
88 |
with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
|
|
|
89 |
btn = gr.Button("Inpaint!").style(
|
90 |
margin=False,
|
91 |
rounded=(False, True, True, False),
|
|
|
99 |
share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
|
100 |
|
101 |
|
102 |
+
btn.click(fn=predict, inputs=[image, example], outputs=[image_out, community_icon, loading_icon, share_button])
|
103 |
share_button.click(None, [], [], _js=share_js)
|
104 |
|
105 |
|