Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -59,7 +59,7 @@ model_id_or_path = "CompVis/stable-diffusion-v1-4"
|
|
59 |
pipe = StableDiffusionInpaintingPipeline.from_pretrained(
|
60 |
model_id_or_path,
|
61 |
revision="fp16",
|
62 |
-
torch_dtype=torch.
|
63 |
use_auth_token=auth_token
|
64 |
)
|
65 |
|
@@ -69,23 +69,25 @@ model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
|
|
69 |
model.eval()
|
70 |
model.load_state_dict(torch.load('./clipseg/weights/rd64-uni.pth', map_location=torch.device(device)), strict=False)
|
71 |
|
|
|
|
|
72 |
transform = transforms.Compose([
|
73 |
transforms.ToTensor(),
|
74 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
75 |
-
transforms.Resize((
|
76 |
])
|
77 |
|
78 |
def predict(radio, dict, word_mask, prompt=""):
|
79 |
if(radio == "draw a mask above"):
|
80 |
with autocast(device): #"cuda"
|
81 |
-
init_image = dict["image"].convert("RGB").resize((
|
82 |
-
mask = dict["mask"].convert("RGB").resize((
|
83 |
elif(radio == "type what to keep"):
|
84 |
img = transform(dict["image"]).squeeze(0)
|
85 |
word_masks = [word_mask]
|
86 |
with torch.no_grad():
|
87 |
preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
88 |
-
init_image = dict['image'].convert('RGB').resize((
|
89 |
filename = f"{uuid.uuid4()}.png"
|
90 |
plt.imsave(filename,torch.sigmoid(preds[0][0]))
|
91 |
img2 = cv2.imread(filename)
|
@@ -99,7 +101,7 @@ def predict(radio, dict, word_mask, prompt=""):
|
|
99 |
word_masks = [word_mask]
|
100 |
with torch.no_grad():
|
101 |
preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
102 |
-
init_image = dict['image'].convert('RGB').resize((
|
103 |
filename = f"{uuid.uuid4()}.png"
|
104 |
plt.imsave(filename,torch.sigmoid(preds[0][0]))
|
105 |
img2 = cv2.imread(filename)
|
@@ -127,68 +129,6 @@ css = '''
|
|
127 |
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
|
128 |
#image_upload .touch-none{display: flex}
|
129 |
|
130 |
-
.markdown-body {
|
131 |
-
font-family: -apple-system,BlinkMacSystemFont,"Segoe UI",Helvetica,Arial,sans-serif,"Apple Color Emoji","Segoe UI Emoji";
|
132 |
-
font-size: 16px;
|
133 |
-
line-height: 1.5;
|
134 |
-
word-wrap: break-word;
|
135 |
-
}
|
136 |
-
.container-lg {
|
137 |
-
max-width: 1012px;
|
138 |
-
margin-right: auto;
|
139 |
-
margin-left: auto;
|
140 |
-
}
|
141 |
-
[data-color-mode="auto"][data-light-theme*="light"] {
|
142 |
-
--color-workflow-card-connector: var(--color-scale-gray-3);
|
143 |
-
--color-workflow-card-connector-bg: var(--color-scale-gray-3);
|
144 |
-
--color-workflow-card-connector-inactive: var(--color-border-default);
|
145 |
-
--color-workflow-card-connector-inactive-bg: var(--color-border-default);
|
146 |
-
--color-workflow-card-connector-highlight: var(--color-scale-blue-4);
|
147 |
-
--color-workflow-card-connector-highlight-bg: var(--color-scale-blue-4);
|
148 |
-
--color-workflow-card-bg: var(--color-scale-white);
|
149 |
-
--color-workflow-card-inactive-bg: var(--color-canvas-inset);
|
150 |
-
--color-workflow-card-header-shadow: rgba(0, 0, 0, 0);
|
151 |
-
--color-workflow-card-progress-complete-bg: var(--color-scale-blue-4);
|
152 |
-
--color-workflow-card-progress-incomplete-bg: var(--color-scale-gray-2);
|
153 |
-
--color-discussions-state-answered-icon: var(--color-scale-white);
|
154 |
-
--color-bg-discussions-row-emoji-box: rgba(209, 213, 218, 0.5);
|
155 |
-
--color-notifications-button-text: var(--color-fg-muted);
|
156 |
-
--color-notifications-button-hover-text: var(--color-fg-default);
|
157 |
-
--color-notifications-button-hover-bg: var(--color-scale-gray-2);
|
158 |
-
--color-notifications-row-read-bg: var(--color-canvas-subtle);
|
159 |
-
--color-notifications-row-bg: var(--color-scale-white);
|
160 |
-
--color-icon-directory: var(--color-scale-blue-3);
|
161 |
-
--color-checks-step-error-icon: var(--color-scale-red-4);
|
162 |
-
--color-calendar-halloween-graph-day-L1-bg: #ffee4a;
|
163 |
-
--color-calendar-halloween-graph-day-L2-bg: #ffc501;
|
164 |
-
--color-calendar-halloween-graph-day-L3-bg: #fe9600;
|
165 |
-
--color-calendar-halloween-graph-day-L4-bg: #03001c;
|
166 |
-
--color-calendar-graph-day-bg: #ebedf0;
|
167 |
-
--color-calendar-graph-day-border: rgba(27, 31, 35, 0.06);
|
168 |
-
--color-calendar-graph-day-L1-bg: #9be9a8;
|
169 |
-
--color-calendar-graph-day-L2-bg: #40c463;
|
170 |
-
--color-calendar-graph-day-L3-bg: #30a14e;
|
171 |
-
--color-calendar-graph-day-L4-bg: #216e39;
|
172 |
-
--color-calendar-graph-day-L1-border: rgba(27, 31, 35, 0.06);
|
173 |
-
--color-calendar-graph-day-L2-border: rgba(27, 31, 35, 0.06);
|
174 |
-
--color-calendar-graph-day-L3-border: rgba(27, 31, 35, 0.06);
|
175 |
-
--color-calendar-graph-day-L4-border: rgba(27, 31, 35, 0.06);
|
176 |
-
--color-user-mention-fg: var(--color-fg-default);
|
177 |
-
--color-user-mention-bg: var(--color-attention-subtle);
|
178 |
-
--color-text-white: var(--color-scale-white);
|
179 |
-
}
|
180 |
-
:root {
|
181 |
-
--Layout-pane-width: 220px;
|
182 |
-
--Layout-content-width: 100%;
|
183 |
-
--Layout-template-columns: 1fr var(--Layout-pane-width);
|
184 |
-
--Layout-template-areas: "content pane";
|
185 |
-
--Layout-column-gap: 16px;
|
186 |
-
--Layout-row-gap: 16px;
|
187 |
-
--Layout-outer-spacing-x: 0px;
|
188 |
-
--Layout-outer-spacing-y: 0px;
|
189 |
-
--Layout-inner-spacing-min: 0px;
|
190 |
-
--Layout-inner-spacing-max: 0px;
|
191 |
-
}
|
192 |
'''
|
193 |
def swap_word_mask(radio_option):
|
194 |
if(radio_option == "draw a mask above"):
|
@@ -226,18 +166,18 @@ with image_blocks as demo:
|
|
226 |
<rect x="69" y="69" width="23" height="23" fill="black"></rect>
|
227 |
<rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
|
228 |
<rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
|
229 |
-
<rect x="115" y="46" width="23" height="23" fill="
|
230 |
-
<rect x="115" y="115" width="23" height="23" fill="
|
231 |
<rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
|
232 |
<rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
|
233 |
<rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
|
234 |
<rect x="92" y="69" width="23" height="23" fill="white"></rect>
|
235 |
-
<rect x="69" y="46" width="23" height="23" fill="
|
236 |
<rect x="69" y="115" width="23" height="23" fill="white"></rect>
|
237 |
<rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
|
238 |
-
<rect x="46" y="46" width="23" height="23" fill="
|
239 |
<rect x="46" y="115" width="23" height="23" fill="black"></rect>
|
240 |
-
<rect x="46" y="69" width="23" height="23" fill="
|
241 |
<rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
|
242 |
<rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
|
243 |
<rect x="23" y="69" width="23" height="23" fill="black"></rect>
|
@@ -258,7 +198,7 @@ with image_blocks as demo:
|
|
258 |
with gr.Box(elem_id="mask_radio").style(border=False):
|
259 |
radio = gr.Radio(["draw a mask above", "type what to mask below", "type what to keep"], value="draw a mask above", show_label=False, interactive=True).style(container=False)
|
260 |
word_mask = gr.Textbox(label = "What to find in your image", interactive=False, elem_id="word_mask", placeholder="Disabled").style(container=False)
|
261 |
-
img_res = gr.inputs.Dropdown("512*512", "256*256")
|
262 |
prompt = gr.Textbox(label = 'Your prompt (what you want to add in place of what you are removing)')
|
263 |
radio.change(fn=swap_word_mask, inputs=radio, outputs=word_mask,show_progress=False)
|
264 |
radio.change(None, inputs=[], outputs=image_blocks, _js = """
|
|
|
59 |
pipe = StableDiffusionInpaintingPipeline.from_pretrained(
|
60 |
model_id_or_path,
|
61 |
revision="fp16",
|
62 |
+
torch_dtype=torch.float16, #float16
|
63 |
use_auth_token=auth_token
|
64 |
)
|
65 |
|
|
|
69 |
model.eval()
|
70 |
model.load_state_dict(torch.load('./clipseg/weights/rd64-uni.pth', map_location=torch.device(device)), strict=False)
|
71 |
|
72 |
+
imgRes = 256
|
73 |
+
|
74 |
transform = transforms.Compose([
|
75 |
transforms.ToTensor(),
|
76 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
77 |
+
transforms.Resize((imgRes, imgRes)),
|
78 |
])
|
79 |
|
80 |
def predict(radio, dict, word_mask, prompt=""):
|
81 |
if(radio == "draw a mask above"):
|
82 |
with autocast(device): #"cuda"
|
83 |
+
init_image = dict["image"].convert("RGB").resize((imgRes, imgRes))
|
84 |
+
mask = dict["mask"].convert("RGB").resize((imgRes, imgRes))
|
85 |
elif(radio == "type what to keep"):
|
86 |
img = transform(dict["image"]).squeeze(0)
|
87 |
word_masks = [word_mask]
|
88 |
with torch.no_grad():
|
89 |
preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
90 |
+
init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
|
91 |
filename = f"{uuid.uuid4()}.png"
|
92 |
plt.imsave(filename,torch.sigmoid(preds[0][0]))
|
93 |
img2 = cv2.imread(filename)
|
|
|
101 |
word_masks = [word_mask]
|
102 |
with torch.no_grad():
|
103 |
preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
104 |
+
init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
|
105 |
filename = f"{uuid.uuid4()}.png"
|
106 |
plt.imsave(filename,torch.sigmoid(preds[0][0]))
|
107 |
img2 = cv2.imread(filename)
|
|
|
129 |
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
|
130 |
#image_upload .touch-none{display: flex}
|
131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
'''
|
133 |
def swap_word_mask(radio_option):
|
134 |
if(radio_option == "draw a mask above"):
|
|
|
166 |
<rect x="69" y="69" width="23" height="23" fill="black"></rect>
|
167 |
<rect x="92" width="23" height="23" fill="#D9D9D9"></rect>
|
168 |
<rect x="92" y="69" width="23" height="23" fill="#AEAEAE"></rect>
|
169 |
+
<rect x="115" y="46" width="23" height="23" fill="black"></rect>
|
170 |
+
<rect x="115" y="115" width="23" height="23" fill="black"></rect>
|
171 |
<rect x="115" y="69" width="23" height="23" fill="#D9D9D9"></rect>
|
172 |
<rect x="92" y="46" width="23" height="23" fill="#AEAEAE"></rect>
|
173 |
<rect x="92" y="115" width="23" height="23" fill="#AEAEAE"></rect>
|
174 |
<rect x="92" y="69" width="23" height="23" fill="white"></rect>
|
175 |
+
<rect x="69" y="46" width="23" height="23" fill="black"></rect>
|
176 |
<rect x="69" y="115" width="23" height="23" fill="white"></rect>
|
177 |
<rect x="69" y="69" width="23" height="23" fill="#D9D9D9"></rect>
|
178 |
+
<rect x="46" y="46" width="23" height="23" fill="white"></rect>
|
179 |
<rect x="46" y="115" width="23" height="23" fill="black"></rect>
|
180 |
+
<rect x="46" y="69" width="23" height="23" fill="white"></rect>
|
181 |
<rect x="23" y="46" width="23" height="23" fill="#D9D9D9"></rect>
|
182 |
<rect x="23" y="115" width="23" height="23" fill="#AEAEAE"></rect>
|
183 |
<rect x="23" y="69" width="23" height="23" fill="black"></rect>
|
|
|
198 |
with gr.Box(elem_id="mask_radio").style(border=False):
|
199 |
radio = gr.Radio(["draw a mask above", "type what to mask below", "type what to keep"], value="draw a mask above", show_label=False, interactive=True).style(container=False)
|
200 |
word_mask = gr.Textbox(label = "What to find in your image", interactive=False, elem_id="word_mask", placeholder="Disabled").style(container=False)
|
201 |
+
img_res = gr.inputs.Dropdown("512*512", "256*256").style(container=True)
|
202 |
prompt = gr.Textbox(label = 'Your prompt (what you want to add in place of what you are removing)')
|
203 |
radio.change(fn=swap_word_mask, inputs=radio, outputs=word_mask,show_progress=False)
|
204 |
radio.change(None, inputs=[], outputs=image_blocks, _js = """
|