baixintech_zhangyiming_prod
convnext v2
33ea2c8
raw
history blame contribute delete
No virus
3.3 kB
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from huggingface_hub import hf_hub_url, hf_hub_download
from .convnext import ConvNeXt
from wmdetection.utils import FP16Module
def get_convnext_model(name):
if name == 'convnext-tiny' or name == 'convnext-wm_1102' or name == 'convnext-wm_1102_v2':
model_ft = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
model_ft.head = nn.Sequential(
nn.Linear(in_features=768, out_features=512),
nn.GELU(),
nn.Linear(in_features=512, out_features=256),
nn.GELU(),
nn.Linear(in_features=256, out_features=2),
)
detector_transforms = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
return model_ft, detector_transforms
def get_resnext_model(name):
if name == 'resnext50_32x4d-small':
model_ft = models.resnext50_32x4d(pretrained=False)
elif name == 'resnext101_32x8d-large':
model_ft = models.resnext101_32x8d(pretrained=False)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
detector_transforms = transforms.Compose([
transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
return model_ft, detector_transforms
def get_watermarks_detection_model(name, device='cpu', fp16=True, pretrained=True, cache_dir='/tmp/watermark-detection'):
assert name in MODELS, f"Unknown model name: {name}"
assert not (fp16 and name.startswith('convnext')), "Can`t use fp16 mode with convnext models"
config = MODELS[name]
model_ft, detector_transforms = config['constructor'](name)
if pretrained:
hf_hub_download(repo_id=config['repo_id'], filename=config['filename'],
cache_dir=cache_dir, force_filename=config['filename'])
weights = torch.load(os.path.join(cache_dir, config['filename']), device)
model_ft.load_state_dict(weights)
if fp16:
model_ft = FP16Module(model_ft)
model_ft.eval()
model_ft = model_ft.to(device)
return model_ft, detector_transforms
MODELS = {
'convnext-tiny': dict(
constructor=get_convnext_model,
repo_id='boomb0om/watermark-detectors',
filename='convnext-tiny_watermarks_detector.pth',
),
'convnext-wm_1102': dict(
constructor=get_convnext_model,
repo_id='Inf009/wm_1102',
filename='convnext_v1_9.pth',
),
'convnext-wm_1102_v2': dict(
constructor=get_convnext_model,
repo_id='Inf009/wm_1102',
filename='convnext_v2.pth',
),
'resnext101_32x8d-large': dict(
constructor=get_resnext_model,
repo_id='boomb0om/watermark-detectors',
filename='watermark_classifier-resnext101_32x8d-input_size320-4epochs_c097_w082.pth',
),
'resnext50_32x4d-small': dict(
constructor=get_resnext_model,
repo_id='boomb0om/watermark-detectors',
filename='watermark_classifier-resnext50_32x4d-input_size320-4epochs_c082_w078.pth',
)
}