Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -36,7 +36,7 @@ pipe = StableDiffusionInpaintingPipeline.from_pretrained(
|
|
36 |
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
|
37 |
|
38 |
model.eval()
|
39 |
-
model.load_state_dict(torch.load('./clipseg/weights/rd64-uni.pth', map_location=torch.device(
|
40 |
|
41 |
transform = transforms.Compose([
|
42 |
transforms.ToTensor(),
|
@@ -46,7 +46,8 @@ transform = transforms.Compose([
|
|
46 |
|
47 |
def predict(radio, dict, word_mask, prompt=""):
|
48 |
if(radio == "draw a mask above"):
|
49 |
-
with autocast("cuda"):
|
|
|
50 |
init_image = dict["image"].convert("RGB").resize((512, 512))
|
51 |
mask = dict["mask"].convert("RGB").resize((512, 512))
|
52 |
else:
|
@@ -63,7 +64,8 @@ def predict(radio, dict, word_mask, prompt=""):
|
|
63 |
cv2.cvtColor(bw_image, cv2.COLOR_BGR2RGB)
|
64 |
mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
|
65 |
os.remove(filename)
|
66 |
-
with autocast("cuda"):
|
|
|
67 |
images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
|
68 |
return images[0]
|
69 |
|
|
|
36 |
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
|
37 |
|
38 |
model.eval()
|
39 |
+
model.load_state_dict(torch.load('./clipseg/weights/rd64-uni.pth', map_location=torch.device(device)), strict=False)
|
40 |
|
41 |
transform = transforms.Compose([
|
42 |
transforms.ToTensor(),
|
|
|
46 |
|
47 |
def predict(radio, dict, word_mask, prompt=""):
|
48 |
if(radio == "draw a mask above"):
|
49 |
+
#with autocast("cuda"):
|
50 |
+
with autocast(enable=(False if device=='cpu' else True)):
|
51 |
init_image = dict["image"].convert("RGB").resize((512, 512))
|
52 |
mask = dict["mask"].convert("RGB").resize((512, 512))
|
53 |
else:
|
|
|
64 |
cv2.cvtColor(bw_image, cv2.COLOR_BGR2RGB)
|
65 |
mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
|
66 |
os.remove(filename)
|
67 |
+
#with autocast("cuda"):
|
68 |
+
with autocast(enable=(False if device=='cpu' else True)):
|
69 |
images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
|
70 |
return images[0]
|
71 |
|