File size: 3,295 Bytes
7dd7207
 
 
 
 
 
 
 
 
 
 
33ea2c8
7dd7207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33b3901
 
 
 
 
33ea2c8
 
 
 
 
7dd7207
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_url, hf_hub_download

from .convnext import ConvNeXt
from wmdetection.utils import FP16Module


def get_convnext_model(name):
    """Build an untrained ConvNeXt watermark classifier and its preprocessing.

    All supported names share one architecture: ConvNeXt with
    ``depths=[3, 3, 9, 3]`` and ``dims=[96, 192, 384, 768]``. Its ``head`` is
    replaced by an MLP 768 -> 512 -> 256 -> 2 with GELU activations.

    Args:
        name: Model key; one of ``'convnext-tiny'``, ``'convnext-wm_1102'``
            or ``'convnext-wm_1102_v2'``.

    Returns:
        A ``(model, transforms)`` tuple:
        - the ConvNeXt module;
        - a torchvision pipeline that resizes to 256x256, converts to a
          tensor and normalizes with the standard ImageNet channel mean/std.

    Raises:
        ValueError: If ``name`` is not one of the supported ConvNeXt variants.
    """
    # Fail fast with a clear message; otherwise ``model_ft`` would be unbound
    # below and the caller would see an unrelated UnboundLocalError.
    if name not in ('convnext-tiny', 'convnext-wm_1102', 'convnext-wm_1102_v2'):
        raise ValueError(f"Unknown ConvNeXt model name: {name}")

    model_ft = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
    # Replace the classification head with a small MLP for the 2-class task.
    model_ft.head = nn.Sequential(
        nn.Linear(in_features=768, out_features=512),
        nn.GELU(),
        nn.Linear(in_features=512, out_features=256),
        nn.GELU(),
        nn.Linear(in_features=256, out_features=2),
    )

    detector_transforms = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    return model_ft, detector_transforms


def get_resnext_model(name):
    """Build an untrained ResNeXt watermark classifier and its preprocessing.

    Args:
        name: ``'resnext50_32x4d-small'`` (torchvision ``resnext50_32x4d``)
            or ``'resnext101_32x8d-large'`` (torchvision ``resnext101_32x8d``).

    Returns:
        A ``(model, transforms)`` tuple:
        - the ResNeXt module, whose final ``fc`` layer is replaced by a
          2-output linear layer;
        - a torchvision pipeline that resizes to 320x320, converts to a
          tensor and normalizes with the standard ImageNet channel mean/std.

    Raises:
        ValueError: If ``name`` is not one of the supported ResNeXt variants.
    """
    # ``pretrained=False``: no backbone weights are fetched here. The
    # fine-tuned state dict is loaded by get_watermarks_detection_model
    # when pretrained=True.
    # NOTE(review): newer torchvision deprecates ``pretrained`` in favour of
    # ``weights=None``; it is kept so older torchvision versions keep working.
    if name == 'resnext50_32x4d-small':
        model_ft = models.resnext50_32x4d(pretrained=False)
    elif name == 'resnext101_32x8d-large':
        model_ft = models.resnext101_32x8d(pretrained=False)
    else:
        # Without this branch ``model_ft`` would be unbound below.
        raise ValueError(f"Unknown ResNeXt model name: {name}")

    # Swap the classifier for a 2-class output, keeping the backbone's width.
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 2)

    detector_transforms = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    return model_ft, detector_transforms


def get_watermarks_detection_model(name, device='cpu', fp16=True, pretrained=True, cache_dir='/tmp/watermark-detection'):
    """Construct a watermark-detection model by name, optionally with weights.

    Args:
        name: Key into ``MODELS``.
        device: Passed to ``torch.load`` as ``map_location`` and to
            ``Module.to``.
        fp16: Wrap the model in ``FP16Module``. Not allowed for ConvNeXt
            models (names starting with ``'convnext'``).
        pretrained: Download ``MODELS[name]['filename']`` from the Hugging
            Face Hub repo ``MODELS[name]['repo_id']`` and load it as the
            model's state dict.
        cache_dir: Cache directory passed to ``hf_hub_download``.

    Returns:
        A ``(model, transforms)`` tuple: the model, switched to eval mode and
        moved to ``device``, and its preprocessing pipeline from the
        model's constructor.

    Raises:
        AssertionError: If ``name`` is unknown, or if ``fp16`` is requested
            for a ConvNeXt model.
    """
    # Raised explicitly rather than via ``assert`` so the checks still run
    # under ``python -O``. AssertionError is kept so existing callers that
    # catch it keep working.
    if name not in MODELS:
        raise AssertionError(f"Unknown model name: {name}")
    if fp16 and name.startswith('convnext'):
        raise AssertionError("Can`t use fp16 mode with convnext models")
    config = MODELS[name]

    model_ft, detector_transforms = config['constructor'](name)

    if pretrained:
        # Use the path hf_hub_download returns instead of rebuilding
        # ``cache_dir/filename`` by hand. Where the file lands under
        # cache_dir depends on the hub's cache layout, and the
        # ``force_filename`` override is deprecated.
        weights_path = hf_hub_download(
            repo_id=config['repo_id'],
            filename=config['filename'],
            cache_dir=cache_dir,
        )
        # NOTE(review): this unpickles a downloaded file; consider
        # ``weights_only=True`` once the minimum supported torch allows it.
        weights = torch.load(weights_path, map_location=device)
        model_ft.load_state_dict(weights)

    if fp16:
        model_ft = FP16Module(model_ft)

    model_ft.eval()
    model_ft = model_ft.to(device)

    return model_ft, detector_transforms


# Registry of the supported watermark detectors, keyed by public model name.
# Every entry carries the same three fields:
#   'constructor' -- builder invoked as constructor(name) -> (model, transforms)
#   'repo_id'     -- Hugging Face Hub repository the weights are downloaded from
#   'filename'    -- weights file inside that repository
MODELS = {
    'convnext-tiny': {
        'constructor': get_convnext_model,
        'repo_id': 'boomb0om/watermark-detectors',
        'filename': 'convnext-tiny_watermarks_detector.pth',
    },
    'convnext-wm_1102': {
        'constructor': get_convnext_model,
        'repo_id': 'Inf009/wm_1102',
        'filename': 'convnext_v1_9.pth',
    },
    'convnext-wm_1102_v2': {
        'constructor': get_convnext_model,
        'repo_id': 'Inf009/wm_1102',
        'filename': 'convnext_v2.pth',
    },
    'resnext101_32x8d-large': {
        'constructor': get_resnext_model,
        'repo_id': 'boomb0om/watermark-detectors',
        'filename': 'watermark_classifier-resnext101_32x8d-input_size320-4epochs_c097_w082.pth',
    },
    'resnext50_32x4d-small': {
        'constructor': get_resnext_model,
        'repo_id': 'boomb0om/watermark-detectors',
        'filename': 'watermark_classifier-resnext50_32x4d-input_size320-4epochs_c082_w078.pth',
    },
}