Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -51,7 +51,8 @@ transform = transforms.Compose([
|
|
51 |
def predict(radio, dict, word_mask, prompt=""):
|
52 |
if(radio == "draw a mask above"):
|
53 |
#with autocast("cuda"):
|
54 |
-
with autocast(device): #enable=(False if device=='cpu' else True)):
|
|
|
55 |
init_image = dict["image"].convert("RGB").resize((imgRes, imgRes))
|
56 |
mask = dict["mask"].convert("RGB").resize((imgRes, imgRes))
|
57 |
else:
|
@@ -69,7 +70,8 @@ def predict(radio, dict, word_mask, prompt=""):
|
|
69 |
mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
|
70 |
os.remove(filename)
|
71 |
#with autocast("cuda"):
|
72 |
-
with autocast(device): #enable=(False if device=='cpu' else True)):
|
|
|
73 |
images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
|
74 |
return images[0]
|
75 |
|
|
|
51 |
def predict(radio, dict, word_mask, prompt=""):
|
52 |
if(radio == "draw a mask above"):
|
53 |
#with autocast("cuda"):
|
54 |
+
#with autocast(device): #enable=(False if device=='cpu' else True)):
|
55 |
+
with autocast(enabled=True, dtype=torch.bfloat16, device='cpu'):
|
56 |
init_image = dict["image"].convert("RGB").resize((imgRes, imgRes))
|
57 |
mask = dict["mask"].convert("RGB").resize((imgRes, imgRes))
|
58 |
else:
|
|
|
70 |
mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
|
71 |
os.remove(filename)
|
72 |
#with autocast("cuda"):
|
73 |
+
#with autocast(device): #enable=(False if device=='cpu' else True)):
|
74 |
+
with autocast(enabled=True, dtype=torch.bfloat16, device='cpu'):
|
75 |
images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
|
76 |
return images[0]
|
77 |
|