nightfury commited on
Commit
986ef15
1 Parent(s): 02df5d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -36,7 +36,7 @@ pipe = StableDiffusionInpaintingPipeline.from_pretrained(
36
  model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
37
 
38
  model.eval()
39
- model.load_state_dict(torch.load('./clipseg/weights/rd64-uni.pth', map_location=torch.device('cuda')), strict=False)
40
 
41
  transform = transforms.Compose([
42
  transforms.ToTensor(),
@@ -46,7 +46,8 @@ transform = transforms.Compose([
46
 
47
  def predict(radio, dict, word_mask, prompt=""):
48
  if(radio == "draw a mask above"):
49
- with autocast("cuda"):
 
50
  init_image = dict["image"].convert("RGB").resize((512, 512))
51
  mask = dict["mask"].convert("RGB").resize((512, 512))
52
  else:
@@ -63,7 +64,8 @@ def predict(radio, dict, word_mask, prompt=""):
63
  cv2.cvtColor(bw_image, cv2.COLOR_BGR2RGB)
64
  mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
65
  os.remove(filename)
66
- with autocast("cuda"):
 
67
  images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
68
  return images[0]
69
 
 
36
  model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
37
 
38
  model.eval()
39
+ model.load_state_dict(torch.load('./clipseg/weights/rd64-uni.pth', map_location=torch.device(device)), strict=False)
40
 
41
  transform = transforms.Compose([
42
  transforms.ToTensor(),
 
46
 
47
  def predict(radio, dict, word_mask, prompt=""):
48
  if(radio == "draw a mask above"):
49
+ #with autocast("cuda"):
50
+ with autocast(enable=(False if device=='cpu' else True)):
51
  init_image = dict["image"].convert("RGB").resize((512, 512))
52
  mask = dict["mask"].convert("RGB").resize((512, 512))
53
  else:
 
64
  cv2.cvtColor(bw_image, cv2.COLOR_BGR2RGB)
65
  mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
66
  os.remove(filename)
67
+ #with autocast("cuda"):
68
+ with autocast(enable=(False if device=='cpu' else True)):
69
  images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
70
  return images[0]
71