baixintech_zhangyiming_prod commited on
Commit
33ea2c8
1 Parent(s): 33b3901

convnext v2

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. wmdetection/models/__init__.py +6 -1
app.py CHANGED
@@ -6,7 +6,7 @@ import os, glob
6
 
7
 
8
  model, transforms = get_watermarks_detection_model(
9
- 'convnext-wm_1102',
10
  fp16=False,
11
  cache_dir='model_files'
12
  )
 
6
 
7
 
8
  model, transforms = get_watermarks_detection_model(
9
+ 'convnext-wm_1102_v2',
10
  fp16=False,
11
  cache_dir='model_files'
12
  )
wmdetection/models/__init__.py CHANGED
@@ -9,7 +9,7 @@ from wmdetection.utils import FP16Module
9
 
10
 
11
  def get_convnext_model(name):
12
- if name == 'convnext-tiny' or name == 'convnext-wm_1102':
13
  model_ft = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
14
  model_ft.head = nn.Sequential(
15
  nn.Linear(in_features=768, out_features=512),
@@ -78,6 +78,11 @@ MODELS = {
78
  repo_id='Inf009/wm_1102',
79
  filename='convnext_v1_9.pth',
80
  ),
 
 
 
 
 
81
  'resnext101_32x8d-large': dict(
82
  constructor=get_resnext_model,
83
  repo_id='boomb0om/watermark-detectors',
 
9
 
10
 
11
  def get_convnext_model(name):
12
+ if name == 'convnext-tiny' or name == 'convnext-wm_1102' or name == 'convnext-wm_1102_v2':
13
  model_ft = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768])
14
  model_ft.head = nn.Sequential(
15
  nn.Linear(in_features=768, out_features=512),
 
78
  repo_id='Inf009/wm_1102',
79
  filename='convnext_v1_9.pth',
80
  ),
81
+ 'convnext-wm_1102_v2': dict(
82
+ constructor=get_convnext_model,
83
+ repo_id='Inf009/wm_1102',
84
+ filename='convnext_v2.pth',
85
+ ),
86
  'resnext101_32x8d-large': dict(
87
  constructor=get_resnext_model,
88
  repo_id='boomb0om/watermark-detectors',