DucHaiten committed
Commit 0e1330a
1 Parent(s): fc2e84c

Upload 6 files

wdv3-timm-main/.gitignore ADDED
@@ -0,0 +1,253 @@
+ # Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
+ # Edit at https://www.toptal.com/developers/gitignore?templates=linux,windows,macos,visualstudiocode,python
+
+ ### Linux ###
+ *~
+
+ # temporary files which can be created if a process still has a handle open of a deleted file
+ .fuse_hidden*
+
+ # KDE directory preferences
+ .directory
+
+ # Linux trash folder which might appear on any partition or disk
+ .Trash-*
+
+ # .nfs files are created when an open file is removed but is still being accessed
+ .nfs*
+
+ ### macOS ###
+ # General
+ .DS_Store
+ .AppleDouble
+ .LSOverride
+
+ # Icon must end with two \r
+ Icon
+
+
+ # Thumbnails
+ ._*
+
+ # Files that might appear in the root of a volume
+ .DocumentRevisions-V100
+ .fseventsd
+ .Spotlight-V100
+ .TemporaryItems
+ .Trashes
+ .VolumeIcon.icns
+ .com.apple.timemachine.donotpresent
+
+ # Directories potentially created on remote AFP share
+ .AppleDB
+ .AppleDesktop
+ Network Trash Folder
+ Temporary Items
+ .apdisk
+
+ ### Python ###
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ ### VisualStudioCode ###
+ .vscode/*
+ !.vscode/settings.json
+ !.vscode/tasks.json
+ !.vscode/launch.json
+ !.vscode/extensions.json
+ *.code-workspace
+
+ # Local History for Visual Studio Code
+ .history/
+
+ ### VisualStudioCode Patch ###
+ # Ignore all local history of files
+ .history
+ .ionide
+
+ ### Windows ###
+ # Windows thumbnail cache files
+ Thumbs.db
+ Thumbs.db:encryptable
+ ehthumbs.db
+ ehthumbs_vista.db
+
+ # Dump file
+ *.stackdump
+
+ # Folder config file
+ [Dd]esktop.ini
+
+ # Recycle Bin used on file shares
+ $RECYCLE.BIN/
+
+ # Windows Installer files
+ *.cab
+ *.msi
+ *.msix
+ *.msm
+ *.msp
+
+ # Windows shortcuts
+ *.lnk
+
+ # End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
+
+ # temp and misc
+ /misc/
+ /temp/
+
+ # direnv
+ .envrc
+ .envrc.*
+
+ # dotenv
+ .env
+ .env.*
+
+ # temp files
+ **/tmp_*.*
+ **/*.tmp.*
+
+ # but keep examples
+ !*.example
+
+ # input images and heatmap outputs
+ /images/
+ /heatmaps/
wdv3-timm-main/.vscode/settings.json ADDED
@@ -0,0 +1,94 @@
+ {
+   "editor.insertSpaces": true,
+   "editor.tabSize": 4,
+   "files.trimTrailingWhitespace": true,
+   "editor.rulers": [100, 120],
+
+   "files.associations": {
+     "*.yaml": "yaml"
+   },
+   "files.exclude": {
+     "**/.git": true,
+     "**/.svn": true,
+     "**/.hg": true,
+     "**/CVS": true,
+     "**/.DS_Store": true,
+     "**/Thumbs.db": true,
+     "**/.ruff_cache": true,
+     "**/__pycache__": true,
+     "**/*.egg-info": true
+   },
+
+   "[shellscript]": {
+     "files.eol": "\n",
+     "editor.tabSize": 4,
+     "editor.detectIndentation": false
+   },
+
+   "[python]": {
+     "editor.wordBasedSuggestions": "off",
+     "editor.formatOnSave": true,
+     "editor.defaultFormatter": "charliermarsh.ruff",
+     "editor.codeActionsOnSave": {
+       "source.organizeImports": "always"
+     }
+   },
+   "python.analysis.include": ["./src", "./scripts", "./tests"],
+
+   "[json]": {
+     "editor.defaultFormatter": "esbenp.prettier-vscode",
+     "editor.detectIndentation": false,
+     "editor.formatOnSaveMode": "file",
+     "editor.formatOnSave": true,
+     "editor.tabSize": 2
+   },
+   "[jsonc]": {
+     "editor.defaultFormatter": "esbenp.prettier-vscode",
+     "editor.detectIndentation": false,
+     "editor.formatOnSaveMode": "file",
+     "editor.formatOnSave": true,
+     "editor.tabSize": 2
+   },
+
+   "[toml]": {
+     "editor.tabSize": 2,
+     "editor.detectIndentation": false,
+     "editor.formatOnSave": true,
+     "editor.formatOnSaveMode": "file",
+     "editor.defaultFormatter": "tamasfe.even-better-toml",
+     "editor.rulers": [80, 100]
+   },
+   "evenBetterToml.formatter.columnWidth": 88,
+
+   "[yaml]": {
+     "editor.detectIndentation": false,
+     "editor.tabSize": 2,
+     "editor.formatOnSave": true,
+     "editor.formatOnSaveMode": "file",
+     "diffEditor.ignoreTrimWhitespace": false,
+     "editor.defaultFormatter": "redhat.vscode-yaml"
+   },
+   "yaml.format.bracketSpacing": true,
+   "yaml.format.proseWrap": "preserve",
+   "yaml.format.singleQuote": false,
+   "yaml.format.printWidth": 110,
+
+   "[hcl]": {
+     "editor.detectIndentation": false,
+     "editor.formatOnSave": true,
+     "editor.formatOnSaveMode": "file",
+     "editor.defaultFormatter": "fredwangwang.vscode-hcl-format"
+   },
+
+   "[markdown]": {
+     "files.trimTrailingWhitespace": false
+   },
+
+   "css.lint.validProperties": ["dock", "content-align", "content-justify"],
+   "[css]": {
+     "editor.formatOnSave": true
+   },
+
+   "remote.autoForwardPorts": false,
+   "remote.autoForwardPortsSource": "process"
+ }
wdv3-timm-main/README.md ADDED
@@ -0,0 +1,84 @@
+ # wdv3-timm
+
+ A small example showing how to use `timm` to run the WD Tagger V3 models.
+
+ ## How To Use
+
+ 1. Clone the repository and enter the directory:
+    ```sh
+    git clone https://github.com/neggles/wdv3-timm.git
+    cd wdv3-timm
+    ```
+
+ 2. Create a virtual environment and install the Python requirements.
+
+    If you're using Linux, you can use the provided script:
+    ```sh
+    bash setup.sh
+    ```
+
+    Or if you're on Windows (or just want to do it manually), you can do the following:
+    ```sh
+    # Create virtual environment
+    python3.10 -m venv .venv
+    # Activate it (on Windows: .venv\Scripts\activate)
+    source .venv/bin/activate
+    # Upgrade pip/setuptools/wheel
+    python -m pip install -U pip setuptools wheel
+    # At this point, you can optionally install PyTorch manually (e.g. if you are not using an nVidia GPU)
+    python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+    # Install requirements
+    python -m pip install -r requirements.txt
+    ```
+
+ 3. Run the example script, picking one of the 3 models to use:
+    ```sh
+    python wdv3_timm.py <swinv2|convnext|vit> path/to/image.png
+    ```
+
+ Example output from `python wdv3_timm.py vit a_picture_of_ganyu.png`:
+ ```sh
+ Loading model 'vit' from 'SmilingWolf/wd-vit-tagger-v3'...
+ Loading tag list...
+ Creating data transform...
+ Loading image and preprocessing...
+ Running inference...
+ Processing results...
+ --------
+ Caption: 1girl, horns, solo, bell, ahoge, colored_skin, blue_skin, neck_bell, looking_at_viewer, purple_eyes, upper_body, blonde_hair, long_hair, goat_horns, blue_hair, off_shoulder, sidelocks, bare_shoulders, alternate_costume, shirt, black_shirt, cowbell, ganyu_(genshin_impact)
+ --------
+ Tags: 1girl, horns, solo, bell, ahoge, colored skin, blue skin, neck bell, looking at viewer, purple eyes, upper body, blonde hair, long hair, goat horns, blue hair, off shoulder, sidelocks, bare shoulders, alternate costume, shirt, black shirt, cowbell, ganyu \(genshin impact\)
+ --------
+ Ratings:
+   general: 0.827
+   sensitive: 0.199
+   questionable: 0.001
+   explicit: 0.001
+ --------
+ Character tags (threshold=0.75):
+   ganyu_(genshin_impact): 0.991
+ --------
+ General tags (threshold=0.35):
+   1girl: 0.996
+   horns: 0.950
+   solo: 0.947
+   bell: 0.918
+   ahoge: 0.897
+   colored_skin: 0.881
+   blue_skin: 0.872
+   neck_bell: 0.854
+   looking_at_viewer: 0.817
+   purple_eyes: 0.734
+   upper_body: 0.615
+   blonde_hair: 0.609
+   long_hair: 0.607
+   goat_horns: 0.524
+   blue_hair: 0.496
+   off_shoulder: 0.472
+   sidelocks: 0.470
+   bare_shoulders: 0.464
+   alternate_costume: 0.437
+   shirt: 0.427
+   black_shirt: 0.417
+   cowbell: 0.415
+ ```
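
The README only covers tagging a single image. Below is a minimal batch-tagging sketch, assuming it is run from the repository root with the virtual environment active so `ScriptOptions` and `main` can be imported from `wdv3_timm.py`; the `images/` folder name is just a hypothetical example, and note that `main()` reloads the model for every image, so this is only convenient for small batches:

```python
# a minimal sketch: tag every PNG in ./images by reusing the script's entry point
# (assumes the repo root as working directory; "images/" is a hypothetical folder)
from pathlib import Path

from wdv3_timm import ScriptOptions, main

for img in sorted(Path("images").glob("*.png")):
    main(ScriptOptions(image_file=img, model="vit"))
```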
wdv3-timm-main/requirements.txt ADDED
@@ -0,0 +1,11 @@
+ diffusers
+ huggingface-hub
+ numpy
+ pandas
+ pillow >= 9.5.0
+ simple-parsing >= 0.1.5
+ timm @ git+https://github.com/huggingface/pytorch-image-models@main#egg=timm
+ tokenizers
+ torch >= 2.0.0
+ torchvision
+ transformers
wdv3-timm-main/setup.sh ADDED
@@ -0,0 +1,24 @@
+ #!/usr/bin/env bash
+ set -euo pipefail
+
+ # get the folder this script is in and make sure we're in it
+ script_dir=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd -P)
+ cd "${script_dir}"
+
+ # make venv if not exist
+ if [[ ! -d .venv ]]; then
+     echo "Creating virtual environment..."
+     python3.10 -m venv .venv
+ fi
+
+ # activate the venv
+ source .venv/bin/activate
+
+ # upgrade pip
+ python -m pip install -U pip setuptools wheel
+
+ # install requirements
+ python -m pip install -r requirements.txt
+
+ echo "Setup complete. Run 'source .venv/bin/activate' to enter the virtual environment."
+ exit 0
wdv3-timm-main/wdv3_timm.py ADDED
@@ -0,0 +1,203 @@
+ import json
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Optional
+
+ import numpy as np
+ import pandas as pd
+ import timm
+ import torch
+ from huggingface_hub import hf_hub_download
+ from huggingface_hub.utils import HfHubHTTPError
+ from PIL import Image
+ from simple_parsing import field, parse_known_args
+ from timm.data import create_transform, resolve_data_config
+ from torch import Tensor, nn
+ from torch.nn import functional as F
+
+ torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ MODEL_REPO_MAP = {
+     "vit": "SmilingWolf/wd-vit-tagger-v3",
+     "swinv2": "SmilingWolf/wd-swinv2-tagger-v3",
+     "convnext": "SmilingWolf/wd-convnext-tagger-v3",
+ }
+
+ def pil_ensure_rgb(image: Image.Image) -> Image.Image:
+     # convert to RGB/RGBA if not already (deals with palette images etc.)
+     if image.mode not in ["RGB", "RGBA"]:
+         image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
+     # convert RGBA to RGB with white background
+     if image.mode == "RGBA":
+         canvas = Image.new("RGBA", image.size, (255, 255, 255))
+         canvas.alpha_composite(image)
+         image = canvas.convert("RGB")
+     return image
+
+ def pil_pad_square(image: Image.Image) -> Image.Image:
+     w, h = image.size
+     # get the largest dimension so we can pad to a square
+     px = max(image.size)
+     # pad to square with white background
+     canvas = Image.new("RGB", (px, px), (255, 255, 255))
+     canvas.paste(image, ((px - w) // 2, (px - h) // 2))
+     return canvas
+
+ @dataclass
+ class LabelData:
+     names: list[str]
+     rating: list[np.int64]
+     general: list[np.int64]
+     character: list[np.int64]
+
+ def load_labels_hf(
+     repo_id: str,
+     revision: Optional[str] = None,
+     token: Optional[str] = None,
+ ) -> LabelData:
+     try:
+         csv_path = hf_hub_download(
+             repo_id=repo_id, filename="selected_tags.csv", revision=revision, token=token
+         )
+         csv_path = Path(csv_path).resolve()
+     except HfHubHTTPError as e:
+         raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e
+
+     df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
+     tag_data = LabelData(
+         names=df["name"].tolist(),
+         rating=list(np.where(df["category"] == 9)[0]),
+         general=list(np.where(df["category"] == 0)[0]),
+         character=list(np.where(df["category"] == 4)[0]),
+     )
+
+     return tag_data
+
+ def get_tags(
+     probs: Tensor,
+     labels: LabelData,
+     gen_threshold: float,
+     char_threshold: float,
+ ):
+     # Convert indices+probs to labels
+     probs = list(zip(labels.names, probs.numpy()))
+
+     # First 4 labels are actually ratings
+     rating_labels = dict([probs[i] for i in labels.rating])
+
+     # General labels, pick any where prediction confidence > threshold
+     gen_labels = [probs[i] for i in labels.general]
+     gen_labels = dict([x for x in gen_labels if x[1] > gen_threshold])
+     gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))
+
+     # Character labels, pick any where prediction confidence > threshold
+     char_labels = [probs[i] for i in labels.character]
+     char_labels = dict([x for x in char_labels if x[1] > char_threshold])
+     char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))
+
+     # Combine general and character labels, sort by confidence
+     combined_names = [x for x in gen_labels]
+     combined_names.extend([x for x in char_labels])
+
+     # Convert to a string suitable for use as a training caption
+     caption = ", ".join(combined_names)
+     taglist = caption.replace("_", " ").replace("(", "\\(").replace(")", "\\)")
+
+     return caption, taglist, rating_labels, char_labels, gen_labels
+
+ @dataclass
+ class ScriptOptions:
+     image_file: Path = field(positional=True)
+     model: str = field(default="vit")
+     gen_threshold: float = field(default=0.35)
+     char_threshold: float = field(default=0.75)
+
+ def main(opts: ScriptOptions):
+     repo_id = MODEL_REPO_MAP.get(opts.model)
+     image_path = Path(opts.image_file).resolve()
+     if not image_path.is_file():
+         raise FileNotFoundError(f"Image file not found: {image_path}")
+
+     print(f"Loading model '{opts.model}' from '{repo_id}'...")
+     model: nn.Module = timm.create_model("hf-hub:" + repo_id).eval()
+     state_dict = timm.models.load_state_dict_from_hf(repo_id)
+     model.load_state_dict(state_dict)
+
+     print("Loading tag list...")
+     labels: LabelData = load_labels_hf(repo_id=repo_id)
+
+     print("Creating data transform...")
+     transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
+
+     print("Loading image and preprocessing...")
+     # get image
+     img_input: Image.Image = Image.open(image_path)
+     # ensure image is RGB
+     img_input = pil_ensure_rgb(img_input)
+     # pad to square with white background
+     img_input = pil_pad_square(img_input)
+     # run the model's input transform to convert to tensor and rescale
+     inputs: Tensor = transform(img_input).unsqueeze(0)
+     # NCHW image RGB to BGR
+     inputs = inputs[:, [2, 1, 0]]
+
+     print("Running inference...")
+     with torch.inference_mode():
+         # move model to GPU, if available
+         if torch_device.type != "cpu":
+             model = model.to(torch_device)
+             inputs = inputs.to(torch_device)
+         # run the model
+         outputs = model.forward(inputs)
+         # apply the final activation function (timm doesn't support doing this internally)
+         outputs = F.sigmoid(outputs)
+         # move inputs, outputs, and model back to the CPU if they were on the GPU
+         if torch_device.type != "cpu":
+             inputs = inputs.to("cpu")
+             outputs = outputs.to("cpu")
+             model = model.to("cpu")
+
+     print("Processing results...")
+     # read threshold overrides from config.json if it exists, otherwise fall back to the CLI options
+     config_path = Path("config.json")
+     if config_path.is_file():
+         with config_path.open("r") as config_file:
+             config_data = json.load(config_file)
+     else:
+         config_data = {}
+
+     gen_threshold = config_data.get("general_threshold", opts.gen_threshold)
+     char_threshold = config_data.get("character_threshold", opts.char_threshold)
+
+     caption, taglist, ratings, character, general = get_tags(
+         probs=outputs.squeeze(0),
+         labels=labels,
+         gen_threshold=gen_threshold,
+         char_threshold=char_threshold,
+     )
+
+     print("--------")
+     print(f"Caption: {caption}")
+     print("--------")
+     print(f"Tags: {taglist}")
+
+     print("--------")
+     print("Ratings:")
+     for k, v in ratings.items():
+         print(f"  {k}: {v:.3f}")
+
+     print("--------")
+     print(f"Character tags (threshold={char_threshold}):")
+     for k, v in character.items():
+         print(f"  {k}: {v:.3f}")
+
+     print("--------")
+     print(f"General tags (threshold={gen_threshold}):")
+     for k, v in general.items():
+         print(f"  {k}: {v:.3f}")
+
+     print("Done!")
+
+
+ if __name__ == "__main__":
+     opts, _ = parse_known_args(ScriptOptions)
+     if opts.model not in MODEL_REPO_MAP:
+         print(f"Available models: {list(MODEL_REPO_MAP.keys())}")
+         raise ValueError(f"Unknown model name '{opts.model}'")
+     main(opts)
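
Note that `main()` looks for an optional `config.json` in the working directory to override the tag thresholds; that file is not part of this upload. Below is a minimal sketch of writing one with the two keys the script reads (`general_threshold` and `character_threshold`); the values shown are simply the script's defaults:

```python
# a minimal sketch: write the optional config.json that wdv3_timm.py reads
# from the current working directory (keys and defaults taken from the script)
import json

config = {
    "general_threshold": 0.35,     # confidence cutoff for general tags
    "character_threshold": 0.75,   # confidence cutoff for character tags
}
with open("config.json", "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)
```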