from tqdm import tqdm
from PIL import Image
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader

from wmdetection.utils import read_image_rgb


class ImageDataset(Dataset):
    """Wraps a list of images (file paths, numpy arrays, or PIL images)
    and applies the classifier transforms lazily in __getitem__."""

    def __init__(self, objects, classifier_transforms):
        self.objects = objects
        self.classifier_transforms = classifier_transforms

    def __len__(self):
        return len(self.objects)

    def __getitem__(self, idx):
        obj = self.objects[idx]

        # Accept a file path, a numpy array, or a ready PIL image.
        if isinstance(obj, str):
            pil_img = read_image_rgb(obj)
        elif isinstance(obj, np.ndarray):
            pil_img = Image.fromarray(obj)
        elif isinstance(obj, Image.Image):
            pil_img = obj
        else:
            raise TypeError(f"Unsupported image type: {type(obj)}")

        return self.classifier_transforms(pil_img).float()
    
    
class WatermarksPredictor:
    """Runs a watermark-classification model on single images or lists of images."""

    def __init__(self, wm_model, classifier_transforms, device):
        self.wm_model = wm_model
        self.wm_model.eval()  # inference mode: freeze dropout/batch-norm behaviour
        self.classifier_transforms = classifier_transforms
        self.device = device

    def predict_image(self, pil_image):
        """Return the predicted class index for a single PIL image."""
        pil_image = pil_image.convert("RGB")
        input_img = self.classifier_transforms(pil_image).float().unsqueeze(0)
        with torch.no_grad():  # no gradients needed at inference time
            outputs = self.wm_model(input_img.to(self.device))
        return torch.argmax(outputs, dim=1).cpu().item()

    def predict_image_confidence(self, pil_image):
        """Return per-class probabilities (softmax) as a 1-D tensor."""
        pil_image = pil_image.convert("RGB")
        input_img = self.classifier_transforms(pil_image).float().unsqueeze(0)
        with torch.no_grad():
            outputs = self.wm_model(input_img.to(self.device))
        return torch.nn.functional.softmax(outputs, dim=1).cpu().reshape(-1)
        
    def run(self, files, num_workers=8, bs=8, pbar=True):
        """Classify a list of images (paths, arrays, or PIL images) in batches."""
        eval_dataset = ImageDataset(files, self.classifier_transforms)
        loader = DataLoader(
            eval_dataset,
            sampler=torch.utils.data.SequentialSampler(eval_dataset),
            batch_size=bs,
            drop_last=False,
            num_workers=num_workers,
        )
        if pbar:
            loader = tqdm(loader)

        result = []
        with torch.no_grad():  # hoisted out of the loop; inference only
            for batch in loader:
                outputs = self.wm_model(batch.to(self.device))
                result.extend(torch.argmax(outputs, dim=1).cpu().tolist())

        return result
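

if __name__ == "__main__":
    # Usage sketch (illustrative): wiring the predictor to a torchvision
    # classifier. The backbone, checkpoint path, transform sizes, and file
    # names below are assumptions, not fixed by this module.
    from torchvision import models, transforms

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = models.resnet18(num_classes=2)  # hypothetical backbone
    # model.load_state_dict(torch.load("wm_classifier.pt", map_location=device))  # hypothetical checkpoint
    model.to(device)

    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])

    predictor = WatermarksPredictor(model, preprocess, device)
    print(predictor.run(["example.jpg"], num_workers=0, bs=1))  # hypothetical input file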