baixintech_zhangyiming_prod commited on
Commit
53a3db7
1 Parent(s): ad6f6d7

output with softmax

Browse files
Files changed (2) hide show
  1. app.py +8 -5
  2. wmdetection/pipelines/predictor.py +6 -0
app.py CHANGED
@@ -12,13 +12,16 @@ model, transforms = get_watermarks_detection_model(
12
  predictor = WatermarksPredictor(model, transforms, 'cpu')
13
 
14
 
15
- def predict(image):
16
- result = predictor.predict_image(image)
17
- return 'watermarked' if result else 'clean' # prints "watermarked"
 
 
18
 
19
 
20
  examples = glob.glob(os.path.join('images', 'clean', '*'))
21
  examples.extend(glob.glob(os.path.join('images', 'watermark', '*')))
22
- iface = gr.Interface(fn=predict, inputs=[gr.inputs.Image(type="pil")],
23
- examples=examples, outputs="text")
 
24
  iface.launch()
 
12
  predictor = WatermarksPredictor(model, transforms, 'cpu')
13
 
14
 
15
+ def predict(image, threshold=0.5):
16
+ result = predictor.predict_image_confidence(image)
17
+ values = result.tolist()
18
+ wm_flag = 1 if values[1] >= threshold else 0
19
+ return 'watermarked' if wm_flag else 'clean', "%.4f" % values[1] # prints "watermarked"
20
 
21
 
22
  examples = glob.glob(os.path.join('images', 'clean', '*'))
23
  examples.extend(glob.glob(os.path.join('images', 'watermark', '*')))
24
+ examples = [[e, 0.5] for e in examples]
25
+ iface = gr.Interface(fn=predict, inputs=[gr.inputs.Image(type="pil"), gr.inputs.Number(label="threshold", default=0.5), ],
26
+ examples=examples, outputs=[gr.outputs.Textbox(label="class"), gr.outputs.Textbox(label="wm_confidence")])
27
  iface.launch()
wmdetection/pipelines/predictor.py CHANGED
@@ -51,6 +51,12 @@ class WatermarksPredictor:
51
  outputs = self.wm_model(input_img.to(self.device))
52
  result = torch.max(outputs, 1)[1].cpu().reshape(-1).tolist()[0]
53
  return result
 
 
 
 
 
 
54
 
55
  def run(self, files, num_workers=8, bs=8, pbar=True):
56
  eval_dataset = ImageDataset(files, self.classifier_transforms)
 
51
  outputs = self.wm_model(input_img.to(self.device))
52
  result = torch.max(outputs, 1)[1].cpu().reshape(-1).tolist()[0]
53
  return result
54
+
55
+ def predict_image_confidence(self, pil_image):
56
+ pil_image = pil_image.convert("RGB")
57
+ input_img = self.classifier_transforms(pil_image).float().unsqueeze(0)
58
+ outputs = self.wm_model(input_img.to(self.device))
59
+ return torch.nn.functional.softmax(outputs, dim=1).cpu().reshape(-1)
60
 
61
  def run(self, files, num_workers=8, bs=8, pbar=True):
62
  eval_dataset = ImageDataset(files, self.classifier_transforms)