Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -35,27 +35,31 @@ pipe = StableDiffusionInpaintingPipeline.from_pretrained(
|
|
35 |
#model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
|
36 |
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
|
37 |
|
|
|
38 |
model.eval()
|
39 |
model.load_state_dict(torch.load('./clipseg/weights/rd64-uni.pth', map_location=torch.device(device)), strict=False)
|
40 |
|
|
|
|
|
|
|
41 |
transform = transforms.Compose([
|
42 |
transforms.ToTensor(),
|
43 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
44 |
-
transforms.Resize((
|
45 |
])
|
46 |
|
47 |
def predict(radio, dict, word_mask, prompt=""):
|
48 |
if(radio == "draw a mask above"):
|
49 |
#with autocast("cuda"):
|
50 |
-
with autocast(enable=(False if device=='cpu' else True)):
|
51 |
-
init_image = dict["image"].convert("RGB").resize((
|
52 |
-
mask = dict["mask"].convert("RGB").resize((
|
53 |
else:
|
54 |
img = transform(dict["image"]).unsqueeze(0)
|
55 |
word_masks = [word_mask]
|
56 |
with torch.no_grad():
|
57 |
preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
58 |
-
init_image = dict['image'].convert('RGB').resize((
|
59 |
filename = f"{uuid.uuid4()}.png"
|
60 |
plt.imsave(filename,torch.sigmoid(preds[0][0]))
|
61 |
img2 = cv2.imread(filename)
|
@@ -65,7 +69,7 @@ def predict(radio, dict, word_mask, prompt=""):
|
|
65 |
mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
|
66 |
os.remove(filename)
|
67 |
#with autocast("cuda"):
|
68 |
-
with autocast(enable=(False if device=='cpu' else True)):
|
69 |
images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
|
70 |
return images[0]
|
71 |
|
|
|
35 |
#model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
|
36 |
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
|
37 |
|
38 |
+
model = model.to(torch.device(device))
|
39 |
model.eval()
|
40 |
model.load_state_dict(torch.load('./clipseg/weights/rd64-uni.pth', map_location=torch.device(device)), strict=False)
|
41 |
|
42 |
+
print ("Torch load(model) : ", model)
|
43 |
+
imgRes = 256 #512
|
44 |
+
|
45 |
transform = transforms.Compose([
|
46 |
transforms.ToTensor(),
|
47 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
48 |
+
transforms.Resize((imgRes, imgRes)),
|
49 |
])
|
50 |
|
51 |
def predict(radio, dict, word_mask, prompt=""):
|
52 |
if(radio == "draw a mask above"):
|
53 |
#with autocast("cuda"):
|
54 |
+
with autocast(device): #enable=(False if device=='cpu' else True)):
|
55 |
+
init_image = dict["image"].convert("RGB").resize((imgRes, imgRes))
|
56 |
+
mask = dict["mask"].convert("RGB").resize((imgRes, imgRes))
|
57 |
else:
|
58 |
img = transform(dict["image"]).unsqueeze(0)
|
59 |
word_masks = [word_mask]
|
60 |
with torch.no_grad():
|
61 |
preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
62 |
+
init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
|
63 |
filename = f"{uuid.uuid4()}.png"
|
64 |
plt.imsave(filename,torch.sigmoid(preds[0][0]))
|
65 |
img2 = cv2.imread(filename)
|
|
|
69 |
mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
|
70 |
os.remove(filename)
|
71 |
#with autocast("cuda"):
|
72 |
+
with autocast(device): #enable=(False if device=='cpu' else True)):
|
73 |
images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
|
74 |
return images[0]
|
75 |
|