An-619 zxairdeep commited on
Commit
a2b0f6f
·
verified ·
1 Parent(s): 1b74b5e

Upload 161 files (#5)

Browse files

- Upload 161 files (e5dd7056b4109c4492019ba3f6da55d8d5fb72f7)


Co-authored-by: zhaoxu <[email protected]>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. ultralytics/.pre-commit-config.yaml +73 -0
  2. ultralytics/__init__.py +12 -0
  3. ultralytics/assets/bus.jpg +3 -0
  4. ultralytics/assets/zidane.jpg +3 -0
  5. ultralytics/datasets/Argoverse.yaml +73 -0
  6. ultralytics/datasets/GlobalWheat2020.yaml +54 -0
  7. ultralytics/datasets/ImageNet.yaml +2025 -0
  8. ultralytics/datasets/Objects365.yaml +443 -0
  9. ultralytics/datasets/SKU-110K.yaml +58 -0
  10. ultralytics/datasets/VOC.yaml +100 -0
  11. ultralytics/datasets/VisDrone.yaml +73 -0
  12. ultralytics/datasets/coco-pose.yaml +38 -0
  13. ultralytics/datasets/coco.yaml +115 -0
  14. ultralytics/datasets/coco128-seg.yaml +101 -0
  15. ultralytics/datasets/coco128.yaml +101 -0
  16. ultralytics/datasets/coco8-pose.yaml +25 -0
  17. ultralytics/datasets/coco8-seg.yaml +101 -0
  18. ultralytics/datasets/coco8.yaml +101 -0
  19. ultralytics/datasets/xView.yaml +153 -0
  20. ultralytics/hub/__init__.py +117 -0
  21. ultralytics/hub/auth.py +139 -0
  22. ultralytics/hub/session.py +189 -0
  23. ultralytics/hub/utils.py +217 -0
  24. ultralytics/models/README.md +45 -0
  25. ultralytics/models/rt-detr/rtdetr-l.yaml +50 -0
  26. ultralytics/models/rt-detr/rtdetr-x.yaml +54 -0
  27. ultralytics/models/v3/yolov3-spp.yaml +48 -0
  28. ultralytics/models/v3/yolov3-tiny.yaml +39 -0
  29. ultralytics/models/v3/yolov3.yaml +48 -0
  30. ultralytics/models/v5/yolov5-p6.yaml +61 -0
  31. ultralytics/models/v5/yolov5.yaml +50 -0
  32. ultralytics/models/v6/yolov6.yaml +53 -0
  33. ultralytics/models/v8/yolov8-cls.yaml +29 -0
  34. ultralytics/models/v8/yolov8-p2.yaml +54 -0
  35. ultralytics/models/v8/yolov8-p6.yaml +56 -0
  36. ultralytics/models/v8/yolov8-pose-p6.yaml +57 -0
  37. ultralytics/models/v8/yolov8-pose.yaml +47 -0
  38. ultralytics/models/v8/yolov8-rtdetr.yaml +46 -0
  39. ultralytics/models/v8/yolov8-seg.yaml +46 -0
  40. ultralytics/models/v8/yolov8.yaml +46 -0
  41. ultralytics/nn/__init__.py +9 -0
  42. ultralytics/nn/autobackend.py +455 -0
  43. ultralytics/nn/autoshape.py +244 -0
  44. ultralytics/nn/modules/__init__.py +29 -0
  45. ultralytics/nn/modules/block.py +304 -0
  46. ultralytics/nn/modules/conv.py +297 -0
  47. ultralytics/nn/modules/head.py +349 -0
  48. ultralytics/nn/modules/transformer.py +378 -0
  49. ultralytics/nn/modules/utils.py +78 -0
  50. ultralytics/nn/tasks.py +780 -0
ultralytics/.pre-commit-config.yaml ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # Pre-commit hooks. For more information see https://github.com/pre-commit/pre-commit-hooks/blob/main/README.md
3
+
4
+ exclude: 'docs/'
5
+ # Define bot property if installed via https://github.com/marketplace/pre-commit-ci
6
+ ci:
7
+ autofix_prs: true
8
+ autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
9
+ autoupdate_schedule: monthly
10
+ # submodules: true
11
+
12
+ repos:
13
+ - repo: https://github.com/pre-commit/pre-commit-hooks
14
+ rev: v4.4.0
15
+ hooks:
16
+ - id: end-of-file-fixer
17
+ - id: trailing-whitespace
18
+ - id: check-case-conflict
19
+ # - id: check-yaml
20
+ - id: check-docstring-first
21
+ - id: double-quote-string-fixer
22
+ - id: detect-private-key
23
+
24
+ - repo: https://github.com/asottile/pyupgrade
25
+ rev: v3.4.0
26
+ hooks:
27
+ - id: pyupgrade
28
+ name: Upgrade code
29
+
30
+ - repo: https://github.com/PyCQA/isort
31
+ rev: 5.12.0
32
+ hooks:
33
+ - id: isort
34
+ name: Sort imports
35
+
36
+ - repo: https://github.com/google/yapf
37
+ rev: v0.33.0
38
+ hooks:
39
+ - id: yapf
40
+ name: YAPF formatting
41
+
42
+ - repo: https://github.com/executablebooks/mdformat
43
+ rev: 0.7.16
44
+ hooks:
45
+ - id: mdformat
46
+ name: MD formatting
47
+ additional_dependencies:
48
+ - mdformat-gfm
49
+ - mdformat-black
50
+ # exclude: "README.md|README.zh-CN.md|CONTRIBUTING.md"
51
+
52
+ - repo: https://github.com/PyCQA/flake8
53
+ rev: 6.0.0
54
+ hooks:
55
+ - id: flake8
56
+ name: PEP8
57
+
58
+ - repo: https://github.com/codespell-project/codespell
59
+ rev: v2.2.4
60
+ hooks:
61
+ - id: codespell
62
+ args:
63
+ - --ignore-words-list=crate,nd,strack,dota
64
+
65
+ # - repo: https://github.com/asottile/yesqa
66
+ # rev: v1.4.0
67
+ # hooks:
68
+ # - id: yesqa
69
+
70
+ # - repo: https://github.com/asottile/dead
71
+ # rev: v1.5.0
72
+ # hooks:
73
+ # - id: dead
ultralytics/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ __version__ = '8.0.120'
4
+
5
+ from ultralytics.hub import start
6
+ from ultralytics.vit.rtdetr import RTDETR
7
+ from ultralytics.vit.sam import SAM
8
+ from ultralytics.yolo.engine.model import YOLO
9
+ from ultralytics.yolo.nas import NAS
10
+ from ultralytics.yolo.utils.checks import check_yolo as checks
11
+
12
+ __all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'RTDETR', 'checks', 'start' # allow simpler import
ultralytics/assets/bus.jpg ADDED

Git LFS Details

  • SHA256: c02019c4979c191eb739ddd944445ef408dad5679acab6fd520ef9d434bfbc63
  • Pointer size: 131 Bytes
  • Size of remote file: 137 kB
ultralytics/assets/zidane.jpg ADDED

Git LFS Details

  • SHA256: 16d73869e3267a7d4ed00de8e860833bd1657c1b252e94c0c348277adc7b6edb
  • Pointer size: 130 Bytes
  • Size of remote file: 50.4 kB
ultralytics/datasets/Argoverse.yaml ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # Argoverse-HD dataset (ring-front-center camera) http://www.cs.cmu.edu/~mengtial/proj/streaming/ by Argo AI
3
+ # Example usage: yolo train data=Argoverse.yaml
4
+ # parent
5
+ # ├── ultralytics
6
+ # └── datasets
7
+ # └── Argoverse ← downloads here (31.3 GB)
8
+
9
+
10
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11
+ path: ../datasets/Argoverse # dataset root dir
12
+ train: Argoverse-1.1/images/train/ # train images (relative to 'path') 39384 images
13
+ val: Argoverse-1.1/images/val/ # val images (relative to 'path') 15062 images
14
+ test: Argoverse-1.1/images/test/ # test images (optional) https://eval.ai/web/challenges/challenge-page/800/overview
15
+
16
+ # Classes
17
+ names:
18
+ 0: person
19
+ 1: bicycle
20
+ 2: car
21
+ 3: motorcycle
22
+ 4: bus
23
+ 5: truck
24
+ 6: traffic_light
25
+ 7: stop_sign
26
+
27
+
28
+ # Download script/URL (optional) ---------------------------------------------------------------------------------------
29
+ download: |
30
+ import json
31
+ from tqdm import tqdm
32
+ from ultralytics.yolo.utils.downloads import download
33
+ from pathlib import Path
34
+
35
+ def argoverse2yolo(set):
36
+ labels = {}
37
+ a = json.load(open(set, "rb"))
38
+ for annot in tqdm(a['annotations'], desc=f"Converting {set} to YOLOv5 format..."):
39
+ img_id = annot['image_id']
40
+ img_name = a['images'][img_id]['name']
41
+ img_label_name = f'{img_name[:-3]}txt'
42
+
43
+ cls = annot['category_id'] # instance class id
44
+ x_center, y_center, width, height = annot['bbox']
45
+ x_center = (x_center + width / 2) / 1920.0 # offset and scale
46
+ y_center = (y_center + height / 2) / 1200.0 # offset and scale
47
+ width /= 1920.0 # scale
48
+ height /= 1200.0 # scale
49
+
50
+ img_dir = set.parents[2] / 'Argoverse-1.1' / 'labels' / a['seq_dirs'][a['images'][annot['image_id']]['sid']]
51
+ if not img_dir.exists():
52
+ img_dir.mkdir(parents=True, exist_ok=True)
53
+
54
+ k = str(img_dir / img_label_name)
55
+ if k not in labels:
56
+ labels[k] = []
57
+ labels[k].append(f"{cls} {x_center} {y_center} {width} {height}\n")
58
+
59
+ for k in labels:
60
+ with open(k, "w") as f:
61
+ f.writelines(labels[k])
62
+
63
+
64
+ # Download
65
+ dir = Path(yaml['path']) # dataset root dir
66
+ urls = ['https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip']
67
+ download(urls, dir=dir)
68
+
69
+ # Convert
70
+ annotations_dir = 'Argoverse-HD/annotations/'
71
+ (dir / 'Argoverse-1.1' / 'tracking').rename(dir / 'Argoverse-1.1' / 'images') # rename 'tracking' to 'images'
72
+ for d in "train.json", "val.json":
73
+ argoverse2yolo(dir / annotations_dir / d) # convert VisDrone annotations to YOLO labels
ultralytics/datasets/GlobalWheat2020.yaml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # Global Wheat 2020 dataset http://www.global-wheat.com/ by University of Saskatchewan
3
+ # Example usage: yolo train data=GlobalWheat2020.yaml
4
+ # parent
5
+ # ├── ultralytics
6
+ # └── datasets
7
+ # └── GlobalWheat2020 ← downloads here (7.0 GB)
8
+
9
+
10
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11
+ path: ../datasets/GlobalWheat2020 # dataset root dir
12
+ train: # train images (relative to 'path') 3422 images
13
+ - images/arvalis_1
14
+ - images/arvalis_2
15
+ - images/arvalis_3
16
+ - images/ethz_1
17
+ - images/rres_1
18
+ - images/inrae_1
19
+ - images/usask_1
20
+ val: # val images (relative to 'path') 748 images (WARNING: train set contains ethz_1)
21
+ - images/ethz_1
22
+ test: # test images (optional) 1276 images
23
+ - images/utokyo_1
24
+ - images/utokyo_2
25
+ - images/nau_1
26
+ - images/uq_1
27
+
28
+ # Classes
29
+ names:
30
+ 0: wheat_head
31
+
32
+
33
+ # Download script/URL (optional) ---------------------------------------------------------------------------------------
34
+ download: |
35
+ from ultralytics.yolo.utils.downloads import download
36
+ from pathlib import Path
37
+
38
+ # Download
39
+ dir = Path(yaml['path']) # dataset root dir
40
+ urls = ['https://zenodo.org/record/4298502/files/global-wheat-codalab-official.zip',
41
+ 'https://github.com/ultralytics/yolov5/releases/download/v1.0/GlobalWheat2020_labels.zip']
42
+ download(urls, dir=dir)
43
+
44
+ # Make Directories
45
+ for p in 'annotations', 'images', 'labels':
46
+ (dir / p).mkdir(parents=True, exist_ok=True)
47
+
48
+ # Move
49
+ for p in 'arvalis_1', 'arvalis_2', 'arvalis_3', 'ethz_1', 'rres_1', 'inrae_1', 'usask_1', \
50
+ 'utokyo_1', 'utokyo_2', 'nau_1', 'uq_1':
51
+ (dir / 'global-wheat-codalab-official' / p).rename(dir / 'images' / p) # move to /images
52
+ f = (dir / 'global-wheat-codalab-official' / p).with_suffix('.json') # json file
53
+ if f.exists():
54
+ f.rename((dir / 'annotations' / p).with_suffix('.json')) # move to /annotations
ultralytics/datasets/ImageNet.yaml ADDED
@@ -0,0 +1,2025 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # ImageNet-1k dataset https://www.image-net.org/index.php by Stanford University
3
+ # Simplified class names from https://github.com/anishathalye/imagenet-simple-labels
4
+ # Example usage: yolo train task=classify data=imagenet
5
+ # parent
6
+ # ├── ultralytics
7
+ # └── datasets
8
+ # └── imagenet ← downloads here (144 GB)
9
+
10
+
11
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
12
+ path: ../datasets/imagenet # dataset root dir
13
+ train: train # train images (relative to 'path') 1281167 images
14
+ val: val # val images (relative to 'path') 50000 images
15
+ test: # test images (optional)
16
+
17
+ # Classes
18
+ names:
19
+ 0: tench
20
+ 1: goldfish
21
+ 2: great white shark
22
+ 3: tiger shark
23
+ 4: hammerhead shark
24
+ 5: electric ray
25
+ 6: stingray
26
+ 7: cock
27
+ 8: hen
28
+ 9: ostrich
29
+ 10: brambling
30
+ 11: goldfinch
31
+ 12: house finch
32
+ 13: junco
33
+ 14: indigo bunting
34
+ 15: American robin
35
+ 16: bulbul
36
+ 17: jay
37
+ 18: magpie
38
+ 19: chickadee
39
+ 20: American dipper
40
+ 21: kite
41
+ 22: bald eagle
42
+ 23: vulture
43
+ 24: great grey owl
44
+ 25: fire salamander
45
+ 26: smooth newt
46
+ 27: newt
47
+ 28: spotted salamander
48
+ 29: axolotl
49
+ 30: American bullfrog
50
+ 31: tree frog
51
+ 32: tailed frog
52
+ 33: loggerhead sea turtle
53
+ 34: leatherback sea turtle
54
+ 35: mud turtle
55
+ 36: terrapin
56
+ 37: box turtle
57
+ 38: banded gecko
58
+ 39: green iguana
59
+ 40: Carolina anole
60
+ 41: desert grassland whiptail lizard
61
+ 42: agama
62
+ 43: frilled-necked lizard
63
+ 44: alligator lizard
64
+ 45: Gila monster
65
+ 46: European green lizard
66
+ 47: chameleon
67
+ 48: Komodo dragon
68
+ 49: Nile crocodile
69
+ 50: American alligator
70
+ 51: triceratops
71
+ 52: worm snake
72
+ 53: ring-necked snake
73
+ 54: eastern hog-nosed snake
74
+ 55: smooth green snake
75
+ 56: kingsnake
76
+ 57: garter snake
77
+ 58: water snake
78
+ 59: vine snake
79
+ 60: night snake
80
+ 61: boa constrictor
81
+ 62: African rock python
82
+ 63: Indian cobra
83
+ 64: green mamba
84
+ 65: sea snake
85
+ 66: Saharan horned viper
86
+ 67: eastern diamondback rattlesnake
87
+ 68: sidewinder
88
+ 69: trilobite
89
+ 70: harvestman
90
+ 71: scorpion
91
+ 72: yellow garden spider
92
+ 73: barn spider
93
+ 74: European garden spider
94
+ 75: southern black widow
95
+ 76: tarantula
96
+ 77: wolf spider
97
+ 78: tick
98
+ 79: centipede
99
+ 80: black grouse
100
+ 81: ptarmigan
101
+ 82: ruffed grouse
102
+ 83: prairie grouse
103
+ 84: peacock
104
+ 85: quail
105
+ 86: partridge
106
+ 87: grey parrot
107
+ 88: macaw
108
+ 89: sulphur-crested cockatoo
109
+ 90: lorikeet
110
+ 91: coucal
111
+ 92: bee eater
112
+ 93: hornbill
113
+ 94: hummingbird
114
+ 95: jacamar
115
+ 96: toucan
116
+ 97: duck
117
+ 98: red-breasted merganser
118
+ 99: goose
119
+ 100: black swan
120
+ 101: tusker
121
+ 102: echidna
122
+ 103: platypus
123
+ 104: wallaby
124
+ 105: koala
125
+ 106: wombat
126
+ 107: jellyfish
127
+ 108: sea anemone
128
+ 109: brain coral
129
+ 110: flatworm
130
+ 111: nematode
131
+ 112: conch
132
+ 113: snail
133
+ 114: slug
134
+ 115: sea slug
135
+ 116: chiton
136
+ 117: chambered nautilus
137
+ 118: Dungeness crab
138
+ 119: rock crab
139
+ 120: fiddler crab
140
+ 121: red king crab
141
+ 122: American lobster
142
+ 123: spiny lobster
143
+ 124: crayfish
144
+ 125: hermit crab
145
+ 126: isopod
146
+ 127: white stork
147
+ 128: black stork
148
+ 129: spoonbill
149
+ 130: flamingo
150
+ 131: little blue heron
151
+ 132: great egret
152
+ 133: bittern
153
+ 134: crane (bird)
154
+ 135: limpkin
155
+ 136: common gallinule
156
+ 137: American coot
157
+ 138: bustard
158
+ 139: ruddy turnstone
159
+ 140: dunlin
160
+ 141: common redshank
161
+ 142: dowitcher
162
+ 143: oystercatcher
163
+ 144: pelican
164
+ 145: king penguin
165
+ 146: albatross
166
+ 147: grey whale
167
+ 148: killer whale
168
+ 149: dugong
169
+ 150: sea lion
170
+ 151: Chihuahua
171
+ 152: Japanese Chin
172
+ 153: Maltese
173
+ 154: Pekingese
174
+ 155: Shih Tzu
175
+ 156: King Charles Spaniel
176
+ 157: Papillon
177
+ 158: toy terrier
178
+ 159: Rhodesian Ridgeback
179
+ 160: Afghan Hound
180
+ 161: Basset Hound
181
+ 162: Beagle
182
+ 163: Bloodhound
183
+ 164: Bluetick Coonhound
184
+ 165: Black and Tan Coonhound
185
+ 166: Treeing Walker Coonhound
186
+ 167: English foxhound
187
+ 168: Redbone Coonhound
188
+ 169: borzoi
189
+ 170: Irish Wolfhound
190
+ 171: Italian Greyhound
191
+ 172: Whippet
192
+ 173: Ibizan Hound
193
+ 174: Norwegian Elkhound
194
+ 175: Otterhound
195
+ 176: Saluki
196
+ 177: Scottish Deerhound
197
+ 178: Weimaraner
198
+ 179: Staffordshire Bull Terrier
199
+ 180: American Staffordshire Terrier
200
+ 181: Bedlington Terrier
201
+ 182: Border Terrier
202
+ 183: Kerry Blue Terrier
203
+ 184: Irish Terrier
204
+ 185: Norfolk Terrier
205
+ 186: Norwich Terrier
206
+ 187: Yorkshire Terrier
207
+ 188: Wire Fox Terrier
208
+ 189: Lakeland Terrier
209
+ 190: Sealyham Terrier
210
+ 191: Airedale Terrier
211
+ 192: Cairn Terrier
212
+ 193: Australian Terrier
213
+ 194: Dandie Dinmont Terrier
214
+ 195: Boston Terrier
215
+ 196: Miniature Schnauzer
216
+ 197: Giant Schnauzer
217
+ 198: Standard Schnauzer
218
+ 199: Scottish Terrier
219
+ 200: Tibetan Terrier
220
+ 201: Australian Silky Terrier
221
+ 202: Soft-coated Wheaten Terrier
222
+ 203: West Highland White Terrier
223
+ 204: Lhasa Apso
224
+ 205: Flat-Coated Retriever
225
+ 206: Curly-coated Retriever
226
+ 207: Golden Retriever
227
+ 208: Labrador Retriever
228
+ 209: Chesapeake Bay Retriever
229
+ 210: German Shorthaired Pointer
230
+ 211: Vizsla
231
+ 212: English Setter
232
+ 213: Irish Setter
233
+ 214: Gordon Setter
234
+ 215: Brittany
235
+ 216: Clumber Spaniel
236
+ 217: English Springer Spaniel
237
+ 218: Welsh Springer Spaniel
238
+ 219: Cocker Spaniels
239
+ 220: Sussex Spaniel
240
+ 221: Irish Water Spaniel
241
+ 222: Kuvasz
242
+ 223: Schipperke
243
+ 224: Groenendael
244
+ 225: Malinois
245
+ 226: Briard
246
+ 227: Australian Kelpie
247
+ 228: Komondor
248
+ 229: Old English Sheepdog
249
+ 230: Shetland Sheepdog
250
+ 231: collie
251
+ 232: Border Collie
252
+ 233: Bouvier des Flandres
253
+ 234: Rottweiler
254
+ 235: German Shepherd Dog
255
+ 236: Dobermann
256
+ 237: Miniature Pinscher
257
+ 238: Greater Swiss Mountain Dog
258
+ 239: Bernese Mountain Dog
259
+ 240: Appenzeller Sennenhund
260
+ 241: Entlebucher Sennenhund
261
+ 242: Boxer
262
+ 243: Bullmastiff
263
+ 244: Tibetan Mastiff
264
+ 245: French Bulldog
265
+ 246: Great Dane
266
+ 247: St. Bernard
267
+ 248: husky
268
+ 249: Alaskan Malamute
269
+ 250: Siberian Husky
270
+ 251: Dalmatian
271
+ 252: Affenpinscher
272
+ 253: Basenji
273
+ 254: pug
274
+ 255: Leonberger
275
+ 256: Newfoundland
276
+ 257: Pyrenean Mountain Dog
277
+ 258: Samoyed
278
+ 259: Pomeranian
279
+ 260: Chow Chow
280
+ 261: Keeshond
281
+ 262: Griffon Bruxellois
282
+ 263: Pembroke Welsh Corgi
283
+ 264: Cardigan Welsh Corgi
284
+ 265: Toy Poodle
285
+ 266: Miniature Poodle
286
+ 267: Standard Poodle
287
+ 268: Mexican hairless dog
288
+ 269: grey wolf
289
+ 270: Alaskan tundra wolf
290
+ 271: red wolf
291
+ 272: coyote
292
+ 273: dingo
293
+ 274: dhole
294
+ 275: African wild dog
295
+ 276: hyena
296
+ 277: red fox
297
+ 278: kit fox
298
+ 279: Arctic fox
299
+ 280: grey fox
300
+ 281: tabby cat
301
+ 282: tiger cat
302
+ 283: Persian cat
303
+ 284: Siamese cat
304
+ 285: Egyptian Mau
305
+ 286: cougar
306
+ 287: lynx
307
+ 288: leopard
308
+ 289: snow leopard
309
+ 290: jaguar
310
+ 291: lion
311
+ 292: tiger
312
+ 293: cheetah
313
+ 294: brown bear
314
+ 295: American black bear
315
+ 296: polar bear
316
+ 297: sloth bear
317
+ 298: mongoose
318
+ 299: meerkat
319
+ 300: tiger beetle
320
+ 301: ladybug
321
+ 302: ground beetle
322
+ 303: longhorn beetle
323
+ 304: leaf beetle
324
+ 305: dung beetle
325
+ 306: rhinoceros beetle
326
+ 307: weevil
327
+ 308: fly
328
+ 309: bee
329
+ 310: ant
330
+ 311: grasshopper
331
+ 312: cricket
332
+ 313: stick insect
333
+ 314: cockroach
334
+ 315: mantis
335
+ 316: cicada
336
+ 317: leafhopper
337
+ 318: lacewing
338
+ 319: dragonfly
339
+ 320: damselfly
340
+ 321: red admiral
341
+ 322: ringlet
342
+ 323: monarch butterfly
343
+ 324: small white
344
+ 325: sulphur butterfly
345
+ 326: gossamer-winged butterfly
346
+ 327: starfish
347
+ 328: sea urchin
348
+ 329: sea cucumber
349
+ 330: cottontail rabbit
350
+ 331: hare
351
+ 332: Angora rabbit
352
+ 333: hamster
353
+ 334: porcupine
354
+ 335: fox squirrel
355
+ 336: marmot
356
+ 337: beaver
357
+ 338: guinea pig
358
+ 339: common sorrel
359
+ 340: zebra
360
+ 341: pig
361
+ 342: wild boar
362
+ 343: warthog
363
+ 344: hippopotamus
364
+ 345: ox
365
+ 346: water buffalo
366
+ 347: bison
367
+ 348: ram
368
+ 349: bighorn sheep
369
+ 350: Alpine ibex
370
+ 351: hartebeest
371
+ 352: impala
372
+ 353: gazelle
373
+ 354: dromedary
374
+ 355: llama
375
+ 356: weasel
376
+ 357: mink
377
+ 358: European polecat
378
+ 359: black-footed ferret
379
+ 360: otter
380
+ 361: skunk
381
+ 362: badger
382
+ 363: armadillo
383
+ 364: three-toed sloth
384
+ 365: orangutan
385
+ 366: gorilla
386
+ 367: chimpanzee
387
+ 368: gibbon
388
+ 369: siamang
389
+ 370: guenon
390
+ 371: patas monkey
391
+ 372: baboon
392
+ 373: macaque
393
+ 374: langur
394
+ 375: black-and-white colobus
395
+ 376: proboscis monkey
396
+ 377: marmoset
397
+ 378: white-headed capuchin
398
+ 379: howler monkey
399
+ 380: titi
400
+ 381: Geoffroy's spider monkey
401
+ 382: common squirrel monkey
402
+ 383: ring-tailed lemur
403
+ 384: indri
404
+ 385: Asian elephant
405
+ 386: African bush elephant
406
+ 387: red panda
407
+ 388: giant panda
408
+ 389: snoek
409
+ 390: eel
410
+ 391: coho salmon
411
+ 392: rock beauty
412
+ 393: clownfish
413
+ 394: sturgeon
414
+ 395: garfish
415
+ 396: lionfish
416
+ 397: pufferfish
417
+ 398: abacus
418
+ 399: abaya
419
+ 400: academic gown
420
+ 401: accordion
421
+ 402: acoustic guitar
422
+ 403: aircraft carrier
423
+ 404: airliner
424
+ 405: airship
425
+ 406: altar
426
+ 407: ambulance
427
+ 408: amphibious vehicle
428
+ 409: analog clock
429
+ 410: apiary
430
+ 411: apron
431
+ 412: waste container
432
+ 413: assault rifle
433
+ 414: backpack
434
+ 415: bakery
435
+ 416: balance beam
436
+ 417: balloon
437
+ 418: ballpoint pen
438
+ 419: Band-Aid
439
+ 420: banjo
440
+ 421: baluster
441
+ 422: barbell
442
+ 423: barber chair
443
+ 424: barbershop
444
+ 425: barn
445
+ 426: barometer
446
+ 427: barrel
447
+ 428: wheelbarrow
448
+ 429: baseball
449
+ 430: basketball
450
+ 431: bassinet
451
+ 432: bassoon
452
+ 433: swimming cap
453
+ 434: bath towel
454
+ 435: bathtub
455
+ 436: station wagon
456
+ 437: lighthouse
457
+ 438: beaker
458
+ 439: military cap
459
+ 440: beer bottle
460
+ 441: beer glass
461
+ 442: bell-cot
462
+ 443: bib
463
+ 444: tandem bicycle
464
+ 445: bikini
465
+ 446: ring binder
466
+ 447: binoculars
467
+ 448: birdhouse
468
+ 449: boathouse
469
+ 450: bobsleigh
470
+ 451: bolo tie
471
+ 452: poke bonnet
472
+ 453: bookcase
473
+ 454: bookstore
474
+ 455: bottle cap
475
+ 456: bow
476
+ 457: bow tie
477
+ 458: brass
478
+ 459: bra
479
+ 460: breakwater
480
+ 461: breastplate
481
+ 462: broom
482
+ 463: bucket
483
+ 464: buckle
484
+ 465: bulletproof vest
485
+ 466: high-speed train
486
+ 467: butcher shop
487
+ 468: taxicab
488
+ 469: cauldron
489
+ 470: candle
490
+ 471: cannon
491
+ 472: canoe
492
+ 473: can opener
493
+ 474: cardigan
494
+ 475: car mirror
495
+ 476: carousel
496
+ 477: tool kit
497
+ 478: carton
498
+ 479: car wheel
499
+ 480: automated teller machine
500
+ 481: cassette
501
+ 482: cassette player
502
+ 483: castle
503
+ 484: catamaran
504
+ 485: CD player
505
+ 486: cello
506
+ 487: mobile phone
507
+ 488: chain
508
+ 489: chain-link fence
509
+ 490: chain mail
510
+ 491: chainsaw
511
+ 492: chest
512
+ 493: chiffonier
513
+ 494: chime
514
+ 495: china cabinet
515
+ 496: Christmas stocking
516
+ 497: church
517
+ 498: movie theater
518
+ 499: cleaver
519
+ 500: cliff dwelling
520
+ 501: cloak
521
+ 502: clogs
522
+ 503: cocktail shaker
523
+ 504: coffee mug
524
+ 505: coffeemaker
525
+ 506: coil
526
+ 507: combination lock
527
+ 508: computer keyboard
528
+ 509: confectionery store
529
+ 510: container ship
530
+ 511: convertible
531
+ 512: corkscrew
532
+ 513: cornet
533
+ 514: cowboy boot
534
+ 515: cowboy hat
535
+ 516: cradle
536
+ 517: crane (machine)
537
+ 518: crash helmet
538
+ 519: crate
539
+ 520: infant bed
540
+ 521: Crock Pot
541
+ 522: croquet ball
542
+ 523: crutch
543
+ 524: cuirass
544
+ 525: dam
545
+ 526: desk
546
+ 527: desktop computer
547
+ 528: rotary dial telephone
548
+ 529: diaper
549
+ 530: digital clock
550
+ 531: digital watch
551
+ 532: dining table
552
+ 533: dishcloth
553
+ 534: dishwasher
554
+ 535: disc brake
555
+ 536: dock
556
+ 537: dog sled
557
+ 538: dome
558
+ 539: doormat
559
+ 540: drilling rig
560
+ 541: drum
561
+ 542: drumstick
562
+ 543: dumbbell
563
+ 544: Dutch oven
564
+ 545: electric fan
565
+ 546: electric guitar
566
+ 547: electric locomotive
567
+ 548: entertainment center
568
+ 549: envelope
569
+ 550: espresso machine
570
+ 551: face powder
571
+ 552: feather boa
572
+ 553: filing cabinet
573
+ 554: fireboat
574
+ 555: fire engine
575
+ 556: fire screen sheet
576
+ 557: flagpole
577
+ 558: flute
578
+ 559: folding chair
579
+ 560: football helmet
580
+ 561: forklift
581
+ 562: fountain
582
+ 563: fountain pen
583
+ 564: four-poster bed
584
+ 565: freight car
585
+ 566: French horn
586
+ 567: frying pan
587
+ 568: fur coat
588
+ 569: garbage truck
589
+ 570: gas mask
590
+ 571: gas pump
591
+ 572: goblet
592
+ 573: go-kart
593
+ 574: golf ball
594
+ 575: golf cart
595
+ 576: gondola
596
+ 577: gong
597
+ 578: gown
598
+ 579: grand piano
599
+ 580: greenhouse
600
+ 581: grille
601
+ 582: grocery store
602
+ 583: guillotine
603
+ 584: barrette
604
+ 585: hair spray
605
+ 586: half-track
606
+ 587: hammer
607
+ 588: hamper
608
+ 589: hair dryer
609
+ 590: hand-held computer
610
+ 591: handkerchief
611
+ 592: hard disk drive
612
+ 593: harmonica
613
+ 594: harp
614
+ 595: harvester
615
+ 596: hatchet
616
+ 597: holster
617
+ 598: home theater
618
+ 599: honeycomb
619
+ 600: hook
620
+ 601: hoop skirt
621
+ 602: horizontal bar
622
+ 603: horse-drawn vehicle
623
+ 604: hourglass
624
+ 605: iPod
625
+ 606: clothes iron
626
+ 607: jack-o'-lantern
627
+ 608: jeans
628
+ 609: jeep
629
+ 610: T-shirt
630
+ 611: jigsaw puzzle
631
+ 612: pulled rickshaw
632
+ 613: joystick
633
+ 614: kimono
634
+ 615: knee pad
635
+ 616: knot
636
+ 617: lab coat
637
+ 618: ladle
638
+ 619: lampshade
639
+ 620: laptop computer
640
+ 621: lawn mower
641
+ 622: lens cap
642
+ 623: paper knife
643
+ 624: library
644
+ 625: lifeboat
645
+ 626: lighter
646
+ 627: limousine
647
+ 628: ocean liner
648
+ 629: lipstick
649
+ 630: slip-on shoe
650
+ 631: lotion
651
+ 632: speaker
652
+ 633: loupe
653
+ 634: sawmill
654
+ 635: magnetic compass
655
+ 636: mail bag
656
+ 637: mailbox
657
+ 638: tights
658
+ 639: tank suit
659
+ 640: manhole cover
660
+ 641: maraca
661
+ 642: marimba
662
+ 643: mask
663
+ 644: match
664
+ 645: maypole
665
+ 646: maze
666
+ 647: measuring cup
667
+ 648: medicine chest
668
+ 649: megalith
669
+ 650: microphone
670
+ 651: microwave oven
671
+ 652: military uniform
672
+ 653: milk can
673
+ 654: minibus
674
+ 655: miniskirt
675
+ 656: minivan
676
+ 657: missile
677
+ 658: mitten
678
+ 659: mixing bowl
679
+ 660: mobile home
680
+ 661: Model T
681
+ 662: modem
682
+ 663: monastery
683
+ 664: monitor
684
+ 665: moped
685
+ 666: mortar
686
+ 667: square academic cap
687
+ 668: mosque
688
+ 669: mosquito net
689
+ 670: scooter
690
+ 671: mountain bike
691
+ 672: tent
692
+ 673: computer mouse
693
+ 674: mousetrap
694
+ 675: moving van
695
+ 676: muzzle
696
+ 677: nail
697
+ 678: neck brace
698
+ 679: necklace
699
+ 680: nipple
700
+ 681: notebook computer
701
+ 682: obelisk
702
+ 683: oboe
703
+ 684: ocarina
704
+ 685: odometer
705
+ 686: oil filter
706
+ 687: organ
707
+ 688: oscilloscope
708
+ 689: overskirt
709
+ 690: bullock cart
710
+ 691: oxygen mask
711
+ 692: packet
712
+ 693: paddle
713
+ 694: paddle wheel
714
+ 695: padlock
715
+ 696: paintbrush
716
+ 697: pajamas
717
+ 698: palace
718
+ 699: pan flute
719
+ 700: paper towel
720
+ 701: parachute
721
+ 702: parallel bars
722
+ 703: park bench
723
+ 704: parking meter
724
+ 705: passenger car
725
+ 706: patio
726
+ 707: payphone
727
+ 708: pedestal
728
+ 709: pencil case
729
+ 710: pencil sharpener
730
+ 711: perfume
731
+ 712: Petri dish
732
+ 713: photocopier
733
+ 714: plectrum
734
+ 715: Pickelhaube
735
+ 716: picket fence
736
+ 717: pickup truck
737
+ 718: pier
738
+ 719: piggy bank
739
+ 720: pill bottle
740
+ 721: pillow
741
+ 722: ping-pong ball
742
+ 723: pinwheel
743
+ 724: pirate ship
744
+ 725: pitcher
745
+ 726: hand plane
746
+ 727: planetarium
747
+ 728: plastic bag
748
+ 729: plate rack
749
+ 730: plow
750
+ 731: plunger
751
+ 732: Polaroid camera
752
+ 733: pole
753
+ 734: police van
754
+ 735: poncho
755
+ 736: billiard table
756
+ 737: soda bottle
757
+ 738: pot
758
+ 739: potter's wheel
759
+ 740: power drill
760
+ 741: prayer rug
761
+ 742: printer
762
+ 743: prison
763
+ 744: projectile
764
+ 745: projector
765
+ 746: hockey puck
766
+ 747: punching bag
767
+ 748: purse
768
+ 749: quill
769
+ 750: quilt
770
+ 751: race car
771
+ 752: racket
772
+ 753: radiator
773
+ 754: radio
774
+ 755: radio telescope
775
+ 756: rain barrel
776
+ 757: recreational vehicle
777
+ 758: reel
778
+ 759: reflex camera
779
+ 760: refrigerator
780
+ 761: remote control
781
+ 762: restaurant
782
+ 763: revolver
783
+ 764: rifle
784
+ 765: rocking chair
785
+ 766: rotisserie
786
+ 767: eraser
787
+ 768: rugby ball
788
+ 769: ruler
789
+ 770: running shoe
790
+ 771: safe
791
+ 772: safety pin
792
+ 773: salt shaker
793
+ 774: sandal
794
+ 775: sarong
795
+ 776: saxophone
796
+ 777: scabbard
797
+ 778: weighing scale
798
+ 779: school bus
799
+ 780: schooner
800
+ 781: scoreboard
801
+ 782: CRT screen
802
+ 783: screw
803
+ 784: screwdriver
804
+ 785: seat belt
805
+ 786: sewing machine
806
+ 787: shield
807
+ 788: shoe store
808
+ 789: shoji
809
+ 790: shopping basket
810
+ 791: shopping cart
811
+ 792: shovel
812
+ 793: shower cap
813
+ 794: shower curtain
814
+ 795: ski
815
+ 796: ski mask
816
+ 797: sleeping bag
817
+ 798: slide rule
818
+ 799: sliding door
819
+ 800: slot machine
820
+ 801: snorkel
821
+ 802: snowmobile
822
+ 803: snowplow
823
+ 804: soap dispenser
824
+ 805: soccer ball
825
+ 806: sock
826
+ 807: solar thermal collector
827
+ 808: sombrero
828
+ 809: soup bowl
829
+ 810: space bar
830
+ 811: space heater
831
+ 812: space shuttle
832
+ 813: spatula
833
+ 814: motorboat
834
+ 815: spider web
835
+ 816: spindle
836
+ 817: sports car
837
+ 818: spotlight
838
+ 819: stage
839
+ 820: steam locomotive
840
+ 821: through arch bridge
841
+ 822: steel drum
842
+ 823: stethoscope
843
+ 824: scarf
844
+ 825: stone wall
845
+ 826: stopwatch
846
+ 827: stove
847
+ 828: strainer
848
+ 829: tram
849
+ 830: stretcher
850
+ 831: couch
851
+ 832: stupa
852
+ 833: submarine
853
+ 834: suit
854
+ 835: sundial
855
+ 836: sunglass
856
+ 837: sunglasses
857
+ 838: sunscreen
858
+ 839: suspension bridge
859
+ 840: mop
860
+ 841: sweatshirt
861
+ 842: swimsuit
862
+ 843: swing
863
+ 844: switch
864
+ 845: syringe
865
+ 846: table lamp
866
+ 847: tank
867
+ 848: tape player
868
+ 849: teapot
869
+ 850: teddy bear
870
+ 851: television
871
+ 852: tennis ball
872
+ 853: thatched roof
873
+ 854: front curtain
874
+ 855: thimble
875
+ 856: threshing machine
876
+ 857: throne
877
+ 858: tile roof
878
+ 859: toaster
879
+ 860: tobacco shop
880
+ 861: toilet seat
881
+ 862: torch
882
+ 863: totem pole
883
+ 864: tow truck
884
+ 865: toy store
885
+ 866: tractor
886
+ 867: semi-trailer truck
887
+ 868: tray
888
+ 869: trench coat
889
+ 870: tricycle
890
+ 871: trimaran
891
+ 872: tripod
892
+ 873: triumphal arch
893
+ 874: trolleybus
894
+ 875: trombone
895
+ 876: tub
896
+ 877: turnstile
897
+ 878: typewriter keyboard
898
+ 879: umbrella
899
+ 880: unicycle
900
+ 881: upright piano
901
+ 882: vacuum cleaner
902
+ 883: vase
903
+ 884: vault
904
+ 885: velvet
905
+ 886: vending machine
906
+ 887: vestment
907
+ 888: viaduct
908
+ 889: violin
909
+ 890: volleyball
910
+ 891: waffle iron
911
+ 892: wall clock
912
+ 893: wallet
913
+ 894: wardrobe
914
+ 895: military aircraft
915
+ 896: sink
916
+ 897: washing machine
917
+ 898: water bottle
918
+ 899: water jug
919
+ 900: water tower
920
+ 901: whiskey jug
921
+ 902: whistle
922
+ 903: wig
923
+ 904: window screen
924
+ 905: window shade
925
+ 906: Windsor tie
926
+ 907: wine bottle
927
+ 908: wing
928
+ 909: wok
929
+ 910: wooden spoon
930
+ 911: wool
931
+ 912: split-rail fence
932
+ 913: shipwreck
933
+ 914: yawl
934
+ 915: yurt
935
+ 916: website
936
+ 917: comic book
937
+ 918: crossword
938
+ 919: traffic sign
939
+ 920: traffic light
940
+ 921: dust jacket
941
+ 922: menu
942
+ 923: plate
943
+ 924: guacamole
944
+ 925: consomme
945
+ 926: hot pot
946
+ 927: trifle
947
+ 928: ice cream
948
+ 929: ice pop
949
+ 930: baguette
950
+ 931: bagel
951
+ 932: pretzel
952
+ 933: cheeseburger
953
+ 934: hot dog
954
+ 935: mashed potato
955
+ 936: cabbage
956
+ 937: broccoli
957
+ 938: cauliflower
958
+ 939: zucchini
959
+ 940: spaghetti squash
960
+ 941: acorn squash
961
+ 942: butternut squash
962
+ 943: cucumber
963
+ 944: artichoke
964
+ 945: bell pepper
965
+ 946: cardoon
966
+ 947: mushroom
967
+ 948: Granny Smith
968
+ 949: strawberry
969
+ 950: orange
970
+ 951: lemon
971
+ 952: fig
972
+ 953: pineapple
973
+ 954: banana
974
+ 955: jackfruit
975
+ 956: custard apple
976
+ 957: pomegranate
977
+ 958: hay
978
+ 959: carbonara
979
+ 960: chocolate syrup
980
+ 961: dough
981
+ 962: meatloaf
982
+ 963: pizza
983
+ 964: pot pie
984
+ 965: burrito
985
+ 966: red wine
986
+ 967: espresso
987
+ 968: cup
988
+ 969: eggnog
989
+ 970: alp
990
+ 971: bubble
991
+ 972: cliff
992
+ 973: coral reef
993
+ 974: geyser
994
+ 975: lakeshore
995
+ 976: promontory
996
+ 977: shoal
997
+ 978: seashore
998
+ 979: valley
999
+ 980: volcano
1000
+ 981: baseball player
1001
+ 982: bridegroom
1002
+ 983: scuba diver
1003
+ 984: rapeseed
1004
+ 985: daisy
1005
+ 986: yellow lady's slipper
1006
+ 987: corn
1007
+ 988: acorn
1008
+ 989: rose hip
1009
+ 990: horse chestnut seed
1010
+ 991: coral fungus
1011
+ 992: agaric
1012
+ 993: gyromitra
1013
+ 994: stinkhorn mushroom
1014
+ 995: earth star
1015
+ 996: hen-of-the-woods
1016
+ 997: bolete
1017
+ 998: ear
1018
+ 999: toilet paper
1019
+
1020
+ # Imagenet class codes to human-readable names
1021
+ map:
1022
+ n01440764: tench
1023
+ n01443537: goldfish
1024
+ n01484850: great_white_shark
1025
+ n01491361: tiger_shark
1026
+ n01494475: hammerhead
1027
+ n01496331: electric_ray
1028
+ n01498041: stingray
1029
+ n01514668: cock
1030
+ n01514859: hen
1031
+ n01518878: ostrich
1032
+ n01530575: brambling
1033
+ n01531178: goldfinch
1034
+ n01532829: house_finch
1035
+ n01534433: junco
1036
+ n01537544: indigo_bunting
1037
+ n01558993: robin
1038
+ n01560419: bulbul
1039
+ n01580077: jay
1040
+ n01582220: magpie
1041
+ n01592084: chickadee
1042
+ n01601694: water_ouzel
1043
+ n01608432: kite
1044
+ n01614925: bald_eagle
1045
+ n01616318: vulture
1046
+ n01622779: great_grey_owl
1047
+ n01629819: European_fire_salamander
1048
+ n01630670: common_newt
1049
+ n01631663: eft
1050
+ n01632458: spotted_salamander
1051
+ n01632777: axolotl
1052
+ n01641577: bullfrog
1053
+ n01644373: tree_frog
1054
+ n01644900: tailed_frog
1055
+ n01664065: loggerhead
1056
+ n01665541: leatherback_turtle
1057
+ n01667114: mud_turtle
1058
+ n01667778: terrapin
1059
+ n01669191: box_turtle
1060
+ n01675722: banded_gecko
1061
+ n01677366: common_iguana
1062
+ n01682714: American_chameleon
1063
+ n01685808: whiptail
1064
+ n01687978: agama
1065
+ n01688243: frilled_lizard
1066
+ n01689811: alligator_lizard
1067
+ n01692333: Gila_monster
1068
+ n01693334: green_lizard
1069
+ n01694178: African_chameleon
1070
+ n01695060: Komodo_dragon
1071
+ n01697457: African_crocodile
1072
+ n01698640: American_alligator
1073
+ n01704323: triceratops
1074
+ n01728572: thunder_snake
1075
+ n01728920: ringneck_snake
1076
+ n01729322: hognose_snake
1077
+ n01729977: green_snake
1078
+ n01734418: king_snake
1079
+ n01735189: garter_snake
1080
+ n01737021: water_snake
1081
+ n01739381: vine_snake
1082
+ n01740131: night_snake
1083
+ n01742172: boa_constrictor
1084
+ n01744401: rock_python
1085
+ n01748264: Indian_cobra
1086
+ n01749939: green_mamba
1087
+ n01751748: sea_snake
1088
+ n01753488: horned_viper
1089
+ n01755581: diamondback
1090
+ n01756291: sidewinder
1091
+ n01768244: trilobite
1092
+ n01770081: harvestman
1093
+ n01770393: scorpion
1094
+ n01773157: black_and_gold_garden_spider
1095
+ n01773549: barn_spider
1096
+ n01773797: garden_spider
1097
+ n01774384: black_widow
1098
+ n01774750: tarantula
1099
+ n01775062: wolf_spider
1100
+ n01776313: tick
1101
+ n01784675: centipede
1102
+ n01795545: black_grouse
1103
+ n01796340: ptarmigan
1104
+ n01797886: ruffed_grouse
1105
+ n01798484: prairie_chicken
1106
+ n01806143: peacock
1107
+ n01806567: quail
1108
+ n01807496: partridge
1109
+ n01817953: African_grey
1110
+ n01818515: macaw
1111
+ n01819313: sulphur-crested_cockatoo
1112
+ n01820546: lorikeet
1113
+ n01824575: coucal
1114
+ n01828970: bee_eater
1115
+ n01829413: hornbill
1116
+ n01833805: hummingbird
1117
+ n01843065: jacamar
1118
+ n01843383: toucan
1119
+ n01847000: drake
1120
+ n01855032: red-breasted_merganser
1121
+ n01855672: goose
1122
+ n01860187: black_swan
1123
+ n01871265: tusker
1124
+ n01872401: echidna
1125
+ n01873310: platypus
1126
+ n01877812: wallaby
1127
+ n01882714: koala
1128
+ n01883070: wombat
1129
+ n01910747: jellyfish
1130
+ n01914609: sea_anemone
1131
+ n01917289: brain_coral
1132
+ n01924916: flatworm
1133
+ n01930112: nematode
1134
+ n01943899: conch
1135
+ n01944390: snail
1136
+ n01945685: slug
1137
+ n01950731: sea_slug
1138
+ n01955084: chiton
1139
+ n01968897: chambered_nautilus
1140
+ n01978287: Dungeness_crab
1141
+ n01978455: rock_crab
1142
+ n01980166: fiddler_crab
1143
+ n01981276: king_crab
1144
+ n01983481: American_lobster
1145
+ n01984695: spiny_lobster
1146
+ n01985128: crayfish
1147
+ n01986214: hermit_crab
1148
+ n01990800: isopod
1149
+ n02002556: white_stork
1150
+ n02002724: black_stork
1151
+ n02006656: spoonbill
1152
+ n02007558: flamingo
1153
+ n02009229: little_blue_heron
1154
+ n02009912: American_egret
1155
+ n02011460: bittern
1156
+ n02012849: crane_(bird)
1157
+ n02013706: limpkin
1158
+ n02017213: European_gallinule
1159
+ n02018207: American_coot
1160
+ n02018795: bustard
1161
+ n02025239: ruddy_turnstone
1162
+ n02027492: red-backed_sandpiper
1163
+ n02028035: redshank
1164
+ n02033041: dowitcher
1165
+ n02037110: oystercatcher
1166
+ n02051845: pelican
1167
+ n02056570: king_penguin
1168
+ n02058221: albatross
1169
+ n02066245: grey_whale
1170
+ n02071294: killer_whale
1171
+ n02074367: dugong
1172
+ n02077923: sea_lion
1173
+ n02085620: Chihuahua
1174
+ n02085782: Japanese_spaniel
1175
+ n02085936: Maltese_dog
1176
+ n02086079: Pekinese
1177
+ n02086240: Shih-Tzu
1178
+ n02086646: Blenheim_spaniel
1179
+ n02086910: papillon
1180
+ n02087046: toy_terrier
1181
+ n02087394: Rhodesian_ridgeback
1182
+ n02088094: Afghan_hound
1183
+ n02088238: basset
1184
+ n02088364: beagle
1185
+ n02088466: bloodhound
1186
+ n02088632: bluetick
1187
+ n02089078: black-and-tan_coonhound
1188
+ n02089867: Walker_hound
1189
+ n02089973: English_foxhound
1190
+ n02090379: redbone
1191
+ n02090622: borzoi
1192
+ n02090721: Irish_wolfhound
1193
+ n02091032: Italian_greyhound
1194
+ n02091134: whippet
1195
+ n02091244: Ibizan_hound
1196
+ n02091467: Norwegian_elkhound
1197
+ n02091635: otterhound
1198
+ n02091831: Saluki
1199
+ n02092002: Scottish_deerhound
1200
+ n02092339: Weimaraner
1201
+ n02093256: Staffordshire_bullterrier
1202
+ n02093428: American_Staffordshire_terrier
1203
+ n02093647: Bedlington_terrier
1204
+ n02093754: Border_terrier
1205
+ n02093859: Kerry_blue_terrier
1206
+ n02093991: Irish_terrier
1207
+ n02094114: Norfolk_terrier
1208
+ n02094258: Norwich_terrier
1209
+ n02094433: Yorkshire_terrier
1210
+ n02095314: wire-haired_fox_terrier
1211
+ n02095570: Lakeland_terrier
1212
+ n02095889: Sealyham_terrier
1213
+ n02096051: Airedale
1214
+ n02096177: cairn
1215
+ n02096294: Australian_terrier
1216
+ n02096437: Dandie_Dinmont
1217
+ n02096585: Boston_bull
1218
+ n02097047: miniature_schnauzer
1219
+ n02097130: giant_schnauzer
1220
+ n02097209: standard_schnauzer
1221
+ n02097298: Scotch_terrier
1222
+ n02097474: Tibetan_terrier
1223
+ n02097658: silky_terrier
1224
+ n02098105: soft-coated_wheaten_terrier
1225
+ n02098286: West_Highland_white_terrier
1226
+ n02098413: Lhasa
1227
+ n02099267: flat-coated_retriever
1228
+ n02099429: curly-coated_retriever
1229
+ n02099601: golden_retriever
1230
+ n02099712: Labrador_retriever
1231
+ n02099849: Chesapeake_Bay_retriever
1232
+ n02100236: German_short-haired_pointer
1233
+ n02100583: vizsla
1234
+ n02100735: English_setter
1235
+ n02100877: Irish_setter
1236
+ n02101006: Gordon_setter
1237
+ n02101388: Brittany_spaniel
1238
+ n02101556: clumber
1239
+ n02102040: English_springer
1240
+ n02102177: Welsh_springer_spaniel
1241
+ n02102318: cocker_spaniel
1242
+ n02102480: Sussex_spaniel
1243
+ n02102973: Irish_water_spaniel
1244
+ n02104029: kuvasz
1245
+ n02104365: schipperke
1246
+ n02105056: groenendael
1247
+ n02105162: malinois
1248
+ n02105251: briard
1249
+ n02105412: kelpie
1250
+ n02105505: komondor
1251
+ n02105641: Old_English_sheepdog
1252
+ n02105855: Shetland_sheepdog
1253
+ n02106030: collie
1254
+ n02106166: Border_collie
1255
+ n02106382: Bouvier_des_Flandres
1256
+ n02106550: Rottweiler
1257
+ n02106662: German_shepherd
1258
+ n02107142: Doberman
1259
+ n02107312: miniature_pinscher
1260
+ n02107574: Greater_Swiss_Mountain_dog
1261
+ n02107683: Bernese_mountain_dog
1262
+ n02107908: Appenzeller
1263
+ n02108000: EntleBucher
1264
+ n02108089: boxer
1265
+ n02108422: bull_mastiff
1266
+ n02108551: Tibetan_mastiff
1267
+ n02108915: French_bulldog
1268
+ n02109047: Great_Dane
1269
+ n02109525: Saint_Bernard
1270
+ n02109961: Eskimo_dog
1271
+ n02110063: malamute
1272
+ n02110185: Siberian_husky
1273
+ n02110341: dalmatian
1274
+ n02110627: affenpinscher
1275
+ n02110806: basenji
1276
+ n02110958: pug
1277
+ n02111129: Leonberg
1278
+ n02111277: Newfoundland
1279
+ n02111500: Great_Pyrenees
1280
+ n02111889: Samoyed
1281
+ n02112018: Pomeranian
1282
+ n02112137: chow
1283
+ n02112350: keeshond
1284
+ n02112706: Brabancon_griffon
1285
+ n02113023: Pembroke
1286
+ n02113186: Cardigan
1287
+ n02113624: toy_poodle
1288
+ n02113712: miniature_poodle
1289
+ n02113799: standard_poodle
1290
+ n02113978: Mexican_hairless
1291
+ n02114367: timber_wolf
1292
+ n02114548: white_wolf
1293
+ n02114712: red_wolf
1294
+ n02114855: coyote
1295
+ n02115641: dingo
1296
+ n02115913: dhole
1297
+ n02116738: African_hunting_dog
1298
+ n02117135: hyena
1299
+ n02119022: red_fox
1300
+ n02119789: kit_fox
1301
+ n02120079: Arctic_fox
1302
+ n02120505: grey_fox
1303
+ n02123045: tabby
1304
+ n02123159: tiger_cat
1305
+ n02123394: Persian_cat
1306
+ n02123597: Siamese_cat
1307
+ n02124075: Egyptian_cat
1308
+ n02125311: cougar
1309
+ n02127052: lynx
1310
+ n02128385: leopard
1311
+ n02128757: snow_leopard
1312
+ n02128925: jaguar
1313
+ n02129165: lion
1314
+ n02129604: tiger
1315
+ n02130308: cheetah
1316
+ n02132136: brown_bear
1317
+ n02133161: American_black_bear
1318
+ n02134084: ice_bear
1319
+ n02134418: sloth_bear
1320
+ n02137549: mongoose
1321
+ n02138441: meerkat
1322
+ n02165105: tiger_beetle
1323
+ n02165456: ladybug
1324
+ n02167151: ground_beetle
1325
+ n02168699: long-horned_beetle
1326
+ n02169497: leaf_beetle
1327
+ n02172182: dung_beetle
1328
+ n02174001: rhinoceros_beetle
1329
+ n02177972: weevil
1330
+ n02190166: fly
1331
+ n02206856: bee
1332
+ n02219486: ant
1333
+ n02226429: grasshopper
1334
+ n02229544: cricket
1335
+ n02231487: walking_stick
1336
+ n02233338: cockroach
1337
+ n02236044: mantis
1338
+ n02256656: cicada
1339
+ n02259212: leafhopper
1340
+ n02264363: lacewing
1341
+ n02268443: dragonfly
1342
+ n02268853: damselfly
1343
+ n02276258: admiral
1344
+ n02277742: ringlet
1345
+ n02279972: monarch
1346
+ n02280649: cabbage_butterfly
1347
+ n02281406: sulphur_butterfly
1348
+ n02281787: lycaenid
1349
+ n02317335: starfish
1350
+ n02319095: sea_urchin
1351
+ n02321529: sea_cucumber
1352
+ n02325366: wood_rabbit
1353
+ n02326432: hare
1354
+ n02328150: Angora
1355
+ n02342885: hamster
1356
+ n02346627: porcupine
1357
+ n02356798: fox_squirrel
1358
+ n02361337: marmot
1359
+ n02363005: beaver
1360
+ n02364673: guinea_pig
1361
+ n02389026: sorrel
1362
+ n02391049: zebra
1363
+ n02395406: hog
1364
+ n02396427: wild_boar
1365
+ n02397096: warthog
1366
+ n02398521: hippopotamus
1367
+ n02403003: ox
1368
+ n02408429: water_buffalo
1369
+ n02410509: bison
1370
+ n02412080: ram
1371
+ n02415577: bighorn
1372
+ n02417914: ibex
1373
+ n02422106: hartebeest
1374
+ n02422699: impala
1375
+ n02423022: gazelle
1376
+ n02437312: Arabian_camel
1377
+ n02437616: llama
1378
+ n02441942: weasel
1379
+ n02442845: mink
1380
+ n02443114: polecat
1381
+ n02443484: black-footed_ferret
1382
+ n02444819: otter
1383
+ n02445715: skunk
1384
+ n02447366: badger
1385
+ n02454379: armadillo
1386
+ n02457408: three-toed_sloth
1387
+ n02480495: orangutan
1388
+ n02480855: gorilla
1389
+ n02481823: chimpanzee
1390
+ n02483362: gibbon
1391
+ n02483708: siamang
1392
+ n02484975: guenon
1393
+ n02486261: patas
1394
+ n02486410: baboon
1395
+ n02487347: macaque
1396
+ n02488291: langur
1397
+ n02488702: colobus
1398
+ n02489166: proboscis_monkey
1399
+ n02490219: marmoset
1400
+ n02492035: capuchin
1401
+ n02492660: howler_monkey
1402
+ n02493509: titi
1403
+ n02493793: spider_monkey
1404
+ n02494079: squirrel_monkey
1405
+ n02497673: Madagascar_cat
1406
+ n02500267: indri
1407
+ n02504013: Indian_elephant
1408
+ n02504458: African_elephant
1409
+ n02509815: lesser_panda
1410
+ n02510455: giant_panda
1411
+ n02514041: barracouta
1412
+ n02526121: eel
1413
+ n02536864: coho
1414
+ n02606052: rock_beauty
1415
+ n02607072: anemone_fish
1416
+ n02640242: sturgeon
1417
+ n02641379: gar
1418
+ n02643566: lionfish
1419
+ n02655020: puffer
1420
+ n02666196: abacus
1421
+ n02667093: abaya
1422
+ n02669723: academic_gown
1423
+ n02672831: accordion
1424
+ n02676566: acoustic_guitar
1425
+ n02687172: aircraft_carrier
1426
+ n02690373: airliner
1427
+ n02692877: airship
1428
+ n02699494: altar
1429
+ n02701002: ambulance
1430
+ n02704792: amphibian
1431
+ n02708093: analog_clock
1432
+ n02727426: apiary
1433
+ n02730930: apron
1434
+ n02747177: ashcan
1435
+ n02749479: assault_rifle
1436
+ n02769748: backpack
1437
+ n02776631: bakery
1438
+ n02777292: balance_beam
1439
+ n02782093: balloon
1440
+ n02783161: ballpoint
1441
+ n02786058: Band_Aid
1442
+ n02787622: banjo
1443
+ n02788148: bannister
1444
+ n02790996: barbell
1445
+ n02791124: barber_chair
1446
+ n02791270: barbershop
1447
+ n02793495: barn
1448
+ n02794156: barometer
1449
+ n02795169: barrel
1450
+ n02797295: barrow
1451
+ n02799071: baseball
1452
+ n02802426: basketball
1453
+ n02804414: bassinet
1454
+ n02804610: bassoon
1455
+ n02807133: bathing_cap
1456
+ n02808304: bath_towel
1457
+ n02808440: bathtub
1458
+ n02814533: beach_wagon
1459
+ n02814860: beacon
1460
+ n02815834: beaker
1461
+ n02817516: bearskin
1462
+ n02823428: beer_bottle
1463
+ n02823750: beer_glass
1464
+ n02825657: bell_cote
1465
+ n02834397: bib
1466
+ n02835271: bicycle-built-for-two
1467
+ n02837789: bikini
1468
+ n02840245: binder
1469
+ n02841315: binoculars
1470
+ n02843684: birdhouse
1471
+ n02859443: boathouse
1472
+ n02860847: bobsled
1473
+ n02865351: bolo_tie
1474
+ n02869837: bonnet
1475
+ n02870880: bookcase
1476
+ n02871525: bookshop
1477
+ n02877765: bottlecap
1478
+ n02879718: bow
1479
+ n02883205: bow_tie
1480
+ n02892201: brass
1481
+ n02892767: brassiere
1482
+ n02894605: breakwater
1483
+ n02895154: breastplate
1484
+ n02906734: broom
1485
+ n02909870: bucket
1486
+ n02910353: buckle
1487
+ n02916936: bulletproof_vest
1488
+ n02917067: bullet_train
1489
+ n02927161: butcher_shop
1490
+ n02930766: cab
1491
+ n02939185: caldron
1492
+ n02948072: candle
1493
+ n02950826: cannon
1494
+ n02951358: canoe
1495
+ n02951585: can_opener
1496
+ n02963159: cardigan
1497
+ n02965783: car_mirror
1498
+ n02966193: carousel
1499
+ n02966687: carpenter's_kit
1500
+ n02971356: carton
1501
+ n02974003: car_wheel
1502
+ n02977058: cash_machine
1503
+ n02978881: cassette
1504
+ n02979186: cassette_player
1505
+ n02980441: castle
1506
+ n02981792: catamaran
1507
+ n02988304: CD_player
1508
+ n02992211: cello
1509
+ n02992529: cellular_telephone
1510
+ n02999410: chain
1511
+ n03000134: chainlink_fence
1512
+ n03000247: chain_mail
1513
+ n03000684: chain_saw
1514
+ n03014705: chest
1515
+ n03016953: chiffonier
1516
+ n03017168: chime
1517
+ n03018349: china_cabinet
1518
+ n03026506: Christmas_stocking
1519
+ n03028079: church
1520
+ n03032252: cinema
1521
+ n03041632: cleaver
1522
+ n03042490: cliff_dwelling
1523
+ n03045698: cloak
1524
+ n03047690: clog
1525
+ n03062245: cocktail_shaker
1526
+ n03063599: coffee_mug
1527
+ n03063689: coffeepot
1528
+ n03065424: coil
1529
+ n03075370: combination_lock
1530
+ n03085013: computer_keyboard
1531
+ n03089624: confectionery
1532
+ n03095699: container_ship
1533
+ n03100240: convertible
1534
+ n03109150: corkscrew
1535
+ n03110669: cornet
1536
+ n03124043: cowboy_boot
1537
+ n03124170: cowboy_hat
1538
+ n03125729: cradle
1539
+ n03126707: crane_(machine)
1540
+ n03127747: crash_helmet
1541
+ n03127925: crate
1542
+ n03131574: crib
1543
+ n03133878: Crock_Pot
1544
+ n03134739: croquet_ball
1545
+ n03141823: crutch
1546
+ n03146219: cuirass
1547
+ n03160309: dam
1548
+ n03179701: desk
1549
+ n03180011: desktop_computer
1550
+ n03187595: dial_telephone
1551
+ n03188531: diaper
1552
+ n03196217: digital_clock
1553
+ n03197337: digital_watch
1554
+ n03201208: dining_table
1555
+ n03207743: dishrag
1556
+ n03207941: dishwasher
1557
+ n03208938: disk_brake
1558
+ n03216828: dock
1559
+ n03218198: dogsled
1560
+ n03220513: dome
1561
+ n03223299: doormat
1562
+ n03240683: drilling_platform
1563
+ n03249569: drum
1564
+ n03250847: drumstick
1565
+ n03255030: dumbbell
1566
+ n03259280: Dutch_oven
1567
+ n03271574: electric_fan
1568
+ n03272010: electric_guitar
1569
+ n03272562: electric_locomotive
1570
+ n03290653: entertainment_center
1571
+ n03291819: envelope
1572
+ n03297495: espresso_maker
1573
+ n03314780: face_powder
1574
+ n03325584: feather_boa
1575
+ n03337140: file
1576
+ n03344393: fireboat
1577
+ n03345487: fire_engine
1578
+ n03347037: fire_screen
1579
+ n03355925: flagpole
1580
+ n03372029: flute
1581
+ n03376595: folding_chair
1582
+ n03379051: football_helmet
1583
+ n03384352: forklift
1584
+ n03388043: fountain
1585
+ n03388183: fountain_pen
1586
+ n03388549: four-poster
1587
+ n03393912: freight_car
1588
+ n03394916: French_horn
1589
+ n03400231: frying_pan
1590
+ n03404251: fur_coat
1591
+ n03417042: garbage_truck
1592
+ n03424325: gasmask
1593
+ n03425413: gas_pump
1594
+ n03443371: goblet
1595
+ n03444034: go-kart
1596
+ n03445777: golf_ball
1597
+ n03445924: golfcart
1598
+ n03447447: gondola
1599
+ n03447721: gong
1600
+ n03450230: gown
1601
+ n03452741: grand_piano
1602
+ n03457902: greenhouse
1603
+ n03459775: grille
1604
+ n03461385: grocery_store
1605
+ n03467068: guillotine
1606
+ n03476684: hair_slide
1607
+ n03476991: hair_spray
1608
+ n03478589: half_track
1609
+ n03481172: hammer
1610
+ n03482405: hamper
1611
+ n03483316: hand_blower
1612
+ n03485407: hand-held_computer
1613
+ n03485794: handkerchief
1614
+ n03492542: hard_disc
1615
+ n03494278: harmonica
1616
+ n03495258: harp
1617
+ n03496892: harvester
1618
+ n03498962: hatchet
1619
+ n03527444: holster
1620
+ n03529860: home_theater
1621
+ n03530642: honeycomb
1622
+ n03532672: hook
1623
+ n03534580: hoopskirt
1624
+ n03535780: horizontal_bar
1625
+ n03538406: horse_cart
1626
+ n03544143: hourglass
1627
+ n03584254: iPod
1628
+ n03584829: iron
1629
+ n03590841: jack-o'-lantern
1630
+ n03594734: jean
1631
+ n03594945: jeep
1632
+ n03595614: jersey
1633
+ n03598930: jigsaw_puzzle
1634
+ n03599486: jinrikisha
1635
+ n03602883: joystick
1636
+ n03617480: kimono
1637
+ n03623198: knee_pad
1638
+ n03627232: knot
1639
+ n03630383: lab_coat
1640
+ n03633091: ladle
1641
+ n03637318: lampshade
1642
+ n03642806: laptop
1643
+ n03649909: lawn_mower
1644
+ n03657121: lens_cap
1645
+ n03658185: letter_opener
1646
+ n03661043: library
1647
+ n03662601: lifeboat
1648
+ n03666591: lighter
1649
+ n03670208: limousine
1650
+ n03673027: liner
1651
+ n03676483: lipstick
1652
+ n03680355: Loafer
1653
+ n03690938: lotion
1654
+ n03691459: loudspeaker
1655
+ n03692522: loupe
1656
+ n03697007: lumbermill
1657
+ n03706229: magnetic_compass
1658
+ n03709823: mailbag
1659
+ n03710193: mailbox
1660
+ n03710637: maillot_(tights)
1661
+ n03710721: maillot_(tank_suit)
1662
+ n03717622: manhole_cover
1663
+ n03720891: maraca
1664
+ n03721384: marimba
1665
+ n03724870: mask
1666
+ n03729826: matchstick
1667
+ n03733131: maypole
1668
+ n03733281: maze
1669
+ n03733805: measuring_cup
1670
+ n03742115: medicine_chest
1671
+ n03743016: megalith
1672
+ n03759954: microphone
1673
+ n03761084: microwave
1674
+ n03763968: military_uniform
1675
+ n03764736: milk_can
1676
+ n03769881: minibus
1677
+ n03770439: miniskirt
1678
+ n03770679: minivan
1679
+ n03773504: missile
1680
+ n03775071: mitten
1681
+ n03775546: mixing_bowl
1682
+ n03776460: mobile_home
1683
+ n03777568: Model_T
1684
+ n03777754: modem
1685
+ n03781244: monastery
1686
+ n03782006: monitor
1687
+ n03785016: moped
1688
+ n03786901: mortar
1689
+ n03787032: mortarboard
1690
+ n03788195: mosque
1691
+ n03788365: mosquito_net
1692
+ n03791053: motor_scooter
1693
+ n03792782: mountain_bike
1694
+ n03792972: mountain_tent
1695
+ n03793489: mouse
1696
+ n03794056: mousetrap
1697
+ n03796401: moving_van
1698
+ n03803284: muzzle
1699
+ n03804744: nail
1700
+ n03814639: neck_brace
1701
+ n03814906: necklace
1702
+ n03825788: nipple
1703
+ n03832673: notebook
1704
+ n03837869: obelisk
1705
+ n03838899: oboe
1706
+ n03840681: ocarina
1707
+ n03841143: odometer
1708
+ n03843555: oil_filter
1709
+ n03854065: organ
1710
+ n03857828: oscilloscope
1711
+ n03866082: overskirt
1712
+ n03868242: oxcart
1713
+ n03868863: oxygen_mask
1714
+ n03871628: packet
1715
+ n03873416: paddle
1716
+ n03874293: paddlewheel
1717
+ n03874599: padlock
1718
+ n03876231: paintbrush
1719
+ n03877472: pajama
1720
+ n03877845: palace
1721
+ n03884397: panpipe
1722
+ n03887697: paper_towel
1723
+ n03888257: parachute
1724
+ n03888605: parallel_bars
1725
+ n03891251: park_bench
1726
+ n03891332: parking_meter
1727
+ n03895866: passenger_car
1728
+ n03899768: patio
1729
+ n03902125: pay-phone
1730
+ n03903868: pedestal
1731
+ n03908618: pencil_box
1732
+ n03908714: pencil_sharpener
1733
+ n03916031: perfume
1734
+ n03920288: Petri_dish
1735
+ n03924679: photocopier
1736
+ n03929660: pick
1737
+ n03929855: pickelhaube
1738
+ n03930313: picket_fence
1739
+ n03930630: pickup
1740
+ n03933933: pier
1741
+ n03935335: piggy_bank
1742
+ n03937543: pill_bottle
1743
+ n03938244: pillow
1744
+ n03942813: ping-pong_ball
1745
+ n03944341: pinwheel
1746
+ n03947888: pirate
1747
+ n03950228: pitcher
1748
+ n03954731: plane
1749
+ n03956157: planetarium
1750
+ n03958227: plastic_bag
1751
+ n03961711: plate_rack
1752
+ n03967562: plow
1753
+ n03970156: plunger
1754
+ n03976467: Polaroid_camera
1755
+ n03976657: pole
1756
+ n03977966: police_van
1757
+ n03980874: poncho
1758
+ n03982430: pool_table
1759
+ n03983396: pop_bottle
1760
+ n03991062: pot
1761
+ n03992509: potter's_wheel
1762
+ n03995372: power_drill
1763
+ n03998194: prayer_rug
1764
+ n04004767: printer
1765
+ n04005630: prison
1766
+ n04008634: projectile
1767
+ n04009552: projector
1768
+ n04019541: puck
1769
+ n04023962: punching_bag
1770
+ n04026417: purse
1771
+ n04033901: quill
1772
+ n04033995: quilt
1773
+ n04037443: racer
1774
+ n04039381: racket
1775
+ n04040759: radiator
1776
+ n04041544: radio
1777
+ n04044716: radio_telescope
1778
+ n04049303: rain_barrel
1779
+ n04065272: recreational_vehicle
1780
+ n04067472: reel
1781
+ n04069434: reflex_camera
1782
+ n04070727: refrigerator
1783
+ n04074963: remote_control
1784
+ n04081281: restaurant
1785
+ n04086273: revolver
1786
+ n04090263: rifle
1787
+ n04099969: rocking_chair
1788
+ n04111531: rotisserie
1789
+ n04116512: rubber_eraser
1790
+ n04118538: rugby_ball
1791
+ n04118776: rule
1792
+ n04120489: running_shoe
1793
+ n04125021: safe
1794
+ n04127249: safety_pin
1795
+ n04131690: saltshaker
1796
+ n04133789: sandal
1797
+ n04136333: sarong
1798
+ n04141076: sax
1799
+ n04141327: scabbard
1800
+ n04141975: scale
1801
+ n04146614: school_bus
1802
+ n04147183: schooner
1803
+ n04149813: scoreboard
1804
+ n04152593: screen
1805
+ n04153751: screw
1806
+ n04154565: screwdriver
1807
+ n04162706: seat_belt
1808
+ n04179913: sewing_machine
1809
+ n04192698: shield
1810
+ n04200800: shoe_shop
1811
+ n04201297: shoji
1812
+ n04204238: shopping_basket
1813
+ n04204347: shopping_cart
1814
+ n04208210: shovel
1815
+ n04209133: shower_cap
1816
+ n04209239: shower_curtain
1817
+ n04228054: ski
1818
+ n04229816: ski_mask
1819
+ n04235860: sleeping_bag
1820
+ n04238763: slide_rule
1821
+ n04239074: sliding_door
1822
+ n04243546: slot
1823
+ n04251144: snorkel
1824
+ n04252077: snowmobile
1825
+ n04252225: snowplow
1826
+ n04254120: soap_dispenser
1827
+ n04254680: soccer_ball
1828
+ n04254777: sock
1829
+ n04258138: solar_dish
1830
+ n04259630: sombrero
1831
+ n04263257: soup_bowl
1832
+ n04264628: space_bar
1833
+ n04265275: space_heater
1834
+ n04266014: space_shuttle
1835
+ n04270147: spatula
1836
+ n04273569: speedboat
1837
+ n04275548: spider_web
1838
+ n04277352: spindle
1839
+ n04285008: sports_car
1840
+ n04286575: spotlight
1841
+ n04296562: stage
1842
+ n04310018: steam_locomotive
1843
+ n04311004: steel_arch_bridge
1844
+ n04311174: steel_drum
1845
+ n04317175: stethoscope
1846
+ n04325704: stole
1847
+ n04326547: stone_wall
1848
+ n04328186: stopwatch
1849
+ n04330267: stove
1850
+ n04332243: strainer
1851
+ n04335435: streetcar
1852
+ n04336792: stretcher
1853
+ n04344873: studio_couch
1854
+ n04346328: stupa
1855
+ n04347754: submarine
1856
+ n04350905: suit
1857
+ n04355338: sundial
1858
+ n04355933: sunglass
1859
+ n04356056: sunglasses
1860
+ n04357314: sunscreen
1861
+ n04366367: suspension_bridge
1862
+ n04367480: swab
1863
+ n04370456: sweatshirt
1864
+ n04371430: swimming_trunks
1865
+ n04371774: swing
1866
+ n04372370: switch
1867
+ n04376876: syringe
1868
+ n04380533: table_lamp
1869
+ n04389033: tank
1870
+ n04392985: tape_player
1871
+ n04398044: teapot
1872
+ n04399382: teddy
1873
+ n04404412: television
1874
+ n04409515: tennis_ball
1875
+ n04417672: thatch
1876
+ n04418357: theater_curtain
1877
+ n04423845: thimble
1878
+ n04428191: thresher
1879
+ n04429376: throne
1880
+ n04435653: tile_roof
1881
+ n04442312: toaster
1882
+ n04443257: tobacco_shop
1883
+ n04447861: toilet_seat
1884
+ n04456115: torch
1885
+ n04458633: totem_pole
1886
+ n04461696: tow_truck
1887
+ n04462240: toyshop
1888
+ n04465501: tractor
1889
+ n04467665: trailer_truck
1890
+ n04476259: tray
1891
+ n04479046: trench_coat
1892
+ n04482393: tricycle
1893
+ n04483307: trimaran
1894
+ n04485082: tripod
1895
+ n04486054: triumphal_arch
1896
+ n04487081: trolleybus
1897
+ n04487394: trombone
1898
+ n04493381: tub
1899
+ n04501370: turnstile
1900
+ n04505470: typewriter_keyboard
1901
+ n04507155: umbrella
1902
+ n04509417: unicycle
1903
+ n04515003: upright
1904
+ n04517823: vacuum
1905
+ n04522168: vase
1906
+ n04523525: vault
1907
+ n04525038: velvet
1908
+ n04525305: vending_machine
1909
+ n04532106: vestment
1910
+ n04532670: viaduct
1911
+ n04536866: violin
1912
+ n04540053: volleyball
1913
+ n04542943: waffle_iron
1914
+ n04548280: wall_clock
1915
+ n04548362: wallet
1916
+ n04550184: wardrobe
1917
+ n04552348: warplane
1918
+ n04553703: washbasin
1919
+ n04554684: washer
1920
+ n04557648: water_bottle
1921
+ n04560804: water_jug
1922
+ n04562935: water_tower
1923
+ n04579145: whiskey_jug
1924
+ n04579432: whistle
1925
+ n04584207: wig
1926
+ n04589890: window_screen
1927
+ n04590129: window_shade
1928
+ n04591157: Windsor_tie
1929
+ n04591713: wine_bottle
1930
+ n04592741: wing
1931
+ n04596742: wok
1932
+ n04597913: wooden_spoon
1933
+ n04599235: wool
1934
+ n04604644: worm_fence
1935
+ n04606251: wreck
1936
+ n04612504: yawl
1937
+ n04613696: yurt
1938
+ n06359193: web_site
1939
+ n06596364: comic_book
1940
+ n06785654: crossword_puzzle
1941
+ n06794110: street_sign
1942
+ n06874185: traffic_light
1943
+ n07248320: book_jacket
1944
+ n07565083: menu
1945
+ n07579787: plate
1946
+ n07583066: guacamole
1947
+ n07584110: consomme
1948
+ n07590611: hot_pot
1949
+ n07613480: trifle
1950
+ n07614500: ice_cream
1951
+ n07615774: ice_lolly
1952
+ n07684084: French_loaf
1953
+ n07693725: bagel
1954
+ n07695742: pretzel
1955
+ n07697313: cheeseburger
1956
+ n07697537: hotdog
1957
+ n07711569: mashed_potato
1958
+ n07714571: head_cabbage
1959
+ n07714990: broccoli
1960
+ n07715103: cauliflower
1961
+ n07716358: zucchini
1962
+ n07716906: spaghetti_squash
1963
+ n07717410: acorn_squash
1964
+ n07717556: butternut_squash
1965
+ n07718472: cucumber
1966
+ n07718747: artichoke
1967
+ n07720875: bell_pepper
1968
+ n07730033: cardoon
1969
+ n07734744: mushroom
1970
+ n07742313: Granny_Smith
1971
+ n07745940: strawberry
1972
+ n07747607: orange
1973
+ n07749582: lemon
1974
+ n07753113: fig
1975
+ n07753275: pineapple
1976
+ n07753592: banana
1977
+ n07754684: jackfruit
1978
+ n07760859: custard_apple
1979
+ n07768694: pomegranate
1980
+ n07802026: hay
1981
+ n07831146: carbonara
1982
+ n07836838: chocolate_sauce
1983
+ n07860988: dough
1984
+ n07871810: meat_loaf
1985
+ n07873807: pizza
1986
+ n07875152: potpie
1987
+ n07880968: burrito
1988
+ n07892512: red_wine
1989
+ n07920052: espresso
1990
+ n07930864: cup
1991
+ n07932039: eggnog
1992
+ n09193705: alp
1993
+ n09229709: bubble
1994
+ n09246464: cliff
1995
+ n09256479: coral_reef
1996
+ n09288635: geyser
1997
+ n09332890: lakeside
1998
+ n09399592: promontory
1999
+ n09421951: sandbar
2000
+ n09428293: seashore
2001
+ n09468604: valley
2002
+ n09472597: volcano
2003
+ n09835506: ballplayer
2004
+ n10148035: groom
2005
+ n10565667: scuba_diver
2006
+ n11879895: rapeseed
2007
+ n11939491: daisy
2008
+ n12057211: yellow_lady's_slipper
2009
+ n12144580: corn
2010
+ n12267677: acorn
2011
+ n12620546: hip
2012
+ n12768682: buckeye
2013
+ n12985857: coral_fungus
2014
+ n12998815: agaric
2015
+ n13037406: gyromitra
2016
+ n13040303: stinkhorn
2017
+ n13044778: earthstar
2018
+ n13052670: hen-of-the-woods
2019
+ n13054560: bolete
2020
+ n13133613: ear
2021
+ n15075141: toilet_tissue
2022
+
2023
+
2024
+ # Download script/URL (optional)
2025
+ download: yolo/data/scripts/get_imagenet.sh
ultralytics/datasets/Objects365.yaml ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # Objects365 dataset https://www.objects365.org/ by Megvii
3
+ # Example usage: yolo train data=Objects365.yaml
4
+ # parent
5
+ # ├── ultralytics
6
+ # └── datasets
7
+ # └── Objects365 ← downloads here (712 GB = 367G data + 345G zips)
8
+
9
+
10
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11
+ path: ../datasets/Objects365 # dataset root dir
12
+ train: images/train # train images (relative to 'path') 1742289 images
13
+ val: images/val # val images (relative to 'path') 80000 images
14
+ test: # test images (optional)
15
+
16
+ # Classes
17
+ names:
18
+ 0: Person
19
+ 1: Sneakers
20
+ 2: Chair
21
+ 3: Other Shoes
22
+ 4: Hat
23
+ 5: Car
24
+ 6: Lamp
25
+ 7: Glasses
26
+ 8: Bottle
27
+ 9: Desk
28
+ 10: Cup
29
+ 11: Street Lights
30
+ 12: Cabinet/shelf
31
+ 13: Handbag/Satchel
32
+ 14: Bracelet
33
+ 15: Plate
34
+ 16: Picture/Frame
35
+ 17: Helmet
36
+ 18: Book
37
+ 19: Gloves
38
+ 20: Storage box
39
+ 21: Boat
40
+ 22: Leather Shoes
41
+ 23: Flower
42
+ 24: Bench
43
+ 25: Potted Plant
44
+ 26: Bowl/Basin
45
+ 27: Flag
46
+ 28: Pillow
47
+ 29: Boots
48
+ 30: Vase
49
+ 31: Microphone
50
+ 32: Necklace
51
+ 33: Ring
52
+ 34: SUV
53
+ 35: Wine Glass
54
+ 36: Belt
55
+ 37: Monitor/TV
56
+ 38: Backpack
57
+ 39: Umbrella
58
+ 40: Traffic Light
59
+ 41: Speaker
60
+ 42: Watch
61
+ 43: Tie
62
+ 44: Trash bin Can
63
+ 45: Slippers
64
+ 46: Bicycle
65
+ 47: Stool
66
+ 48: Barrel/bucket
67
+ 49: Van
68
+ 50: Couch
69
+ 51: Sandals
70
+ 52: Basket
71
+ 53: Drum
72
+ 54: Pen/Pencil
73
+ 55: Bus
74
+ 56: Wild Bird
75
+ 57: High Heels
76
+ 58: Motorcycle
77
+ 59: Guitar
78
+ 60: Carpet
79
+ 61: Cell Phone
80
+ 62: Bread
81
+ 63: Camera
82
+ 64: Canned
83
+ 65: Truck
84
+ 66: Traffic cone
85
+ 67: Cymbal
86
+ 68: Lifesaver
87
+ 69: Towel
88
+ 70: Stuffed Toy
89
+ 71: Candle
90
+ 72: Sailboat
91
+ 73: Laptop
92
+ 74: Awning
93
+ 75: Bed
94
+ 76: Faucet
95
+ 77: Tent
96
+ 78: Horse
97
+ 79: Mirror
98
+ 80: Power outlet
99
+ 81: Sink
100
+ 82: Apple
101
+ 83: Air Conditioner
102
+ 84: Knife
103
+ 85: Hockey Stick
104
+ 86: Paddle
105
+ 87: Pickup Truck
106
+ 88: Fork
107
+ 89: Traffic Sign
108
+ 90: Balloon
109
+ 91: Tripod
110
+ 92: Dog
111
+ 93: Spoon
112
+ 94: Clock
113
+ 95: Pot
114
+ 96: Cow
115
+ 97: Cake
116
+ 98: Dinning Table
117
+ 99: Sheep
118
+ 100: Hanger
119
+ 101: Blackboard/Whiteboard
120
+ 102: Napkin
121
+ 103: Other Fish
122
+ 104: Orange/Tangerine
123
+ 105: Toiletry
124
+ 106: Keyboard
125
+ 107: Tomato
126
+ 108: Lantern
127
+ 109: Machinery Vehicle
128
+ 110: Fan
129
+ 111: Green Vegetables
130
+ 112: Banana
131
+ 113: Baseball Glove
132
+ 114: Airplane
133
+ 115: Mouse
134
+ 116: Train
135
+ 117: Pumpkin
136
+ 118: Soccer
137
+ 119: Skiboard
138
+ 120: Luggage
139
+ 121: Nightstand
140
+ 122: Tea pot
141
+ 123: Telephone
142
+ 124: Trolley
143
+ 125: Head Phone
144
+ 126: Sports Car
145
+ 127: Stop Sign
146
+ 128: Dessert
147
+ 129: Scooter
148
+ 130: Stroller
149
+ 131: Crane
150
+ 132: Remote
151
+ 133: Refrigerator
152
+ 134: Oven
153
+ 135: Lemon
154
+ 136: Duck
155
+ 137: Baseball Bat
156
+ 138: Surveillance Camera
157
+ 139: Cat
158
+ 140: Jug
159
+ 141: Broccoli
160
+ 142: Piano
161
+ 143: Pizza
162
+ 144: Elephant
163
+ 145: Skateboard
164
+ 146: Surfboard
165
+ 147: Gun
166
+ 148: Skating and Skiing shoes
167
+ 149: Gas stove
168
+ 150: Donut
169
+ 151: Bow Tie
170
+ 152: Carrot
171
+ 153: Toilet
172
+ 154: Kite
173
+ 155: Strawberry
174
+ 156: Other Balls
175
+ 157: Shovel
176
+ 158: Pepper
177
+ 159: Computer Box
178
+ 160: Toilet Paper
179
+ 161: Cleaning Products
180
+ 162: Chopsticks
181
+ 163: Microwave
182
+ 164: Pigeon
183
+ 165: Baseball
184
+ 166: Cutting/chopping Board
185
+ 167: Coffee Table
186
+ 168: Side Table
187
+ 169: Scissors
188
+ 170: Marker
189
+ 171: Pie
190
+ 172: Ladder
191
+ 173: Snowboard
192
+ 174: Cookies
193
+ 175: Radiator
194
+ 176: Fire Hydrant
195
+ 177: Basketball
196
+ 178: Zebra
197
+ 179: Grape
198
+ 180: Giraffe
199
+ 181: Potato
200
+ 182: Sausage
201
+ 183: Tricycle
202
+ 184: Violin
203
+ 185: Egg
204
+ 186: Fire Extinguisher
205
+ 187: Candy
206
+ 188: Fire Truck
207
+ 189: Billiards
208
+ 190: Converter
209
+ 191: Bathtub
210
+ 192: Wheelchair
211
+ 193: Golf Club
212
+ 194: Briefcase
213
+ 195: Cucumber
214
+ 196: Cigar/Cigarette
215
+ 197: Paint Brush
216
+ 198: Pear
217
+ 199: Heavy Truck
218
+ 200: Hamburger
219
+ 201: Extractor
220
+ 202: Extension Cord
221
+ 203: Tong
222
+ 204: Tennis Racket
223
+ 205: Folder
224
+ 206: American Football
225
+ 207: earphone
226
+ 208: Mask
227
+ 209: Kettle
228
+ 210: Tennis
229
+ 211: Ship
230
+ 212: Swing
231
+ 213: Coffee Machine
232
+ 214: Slide
233
+ 215: Carriage
234
+ 216: Onion
235
+ 217: Green beans
236
+ 218: Projector
237
+ 219: Frisbee
238
+ 220: Washing Machine/Drying Machine
239
+ 221: Chicken
240
+ 222: Printer
241
+ 223: Watermelon
242
+ 224: Saxophone
243
+ 225: Tissue
244
+ 226: Toothbrush
245
+ 227: Ice cream
246
+ 228: Hot-air balloon
247
+ 229: Cello
248
+ 230: French Fries
249
+ 231: Scale
250
+ 232: Trophy
251
+ 233: Cabbage
252
+ 234: Hot dog
253
+ 235: Blender
254
+ 236: Peach
255
+ 237: Rice
256
+ 238: Wallet/Purse
257
+ 239: Volleyball
258
+ 240: Deer
259
+ 241: Goose
260
+ 242: Tape
261
+ 243: Tablet
262
+ 244: Cosmetics
263
+ 245: Trumpet
264
+ 246: Pineapple
265
+ 247: Golf Ball
266
+ 248: Ambulance
267
+ 249: Parking meter
268
+ 250: Mango
269
+ 251: Key
270
+ 252: Hurdle
271
+ 253: Fishing Rod
272
+ 254: Medal
273
+ 255: Flute
274
+ 256: Brush
275
+ 257: Penguin
276
+ 258: Megaphone
277
+ 259: Corn
278
+ 260: Lettuce
279
+ 261: Garlic
280
+ 262: Swan
281
+ 263: Helicopter
282
+ 264: Green Onion
283
+ 265: Sandwich
284
+ 266: Nuts
285
+ 267: Speed Limit Sign
286
+ 268: Induction Cooker
287
+ 269: Broom
288
+ 270: Trombone
289
+ 271: Plum
290
+ 272: Rickshaw
291
+ 273: Goldfish
292
+ 274: Kiwi fruit
293
+ 275: Router/modem
294
+ 276: Poker Card
295
+ 277: Toaster
296
+ 278: Shrimp
297
+ 279: Sushi
298
+ 280: Cheese
299
+ 281: Notepaper
300
+ 282: Cherry
301
+ 283: Pliers
302
+ 284: CD
303
+ 285: Pasta
304
+ 286: Hammer
305
+ 287: Cue
306
+ 288: Avocado
307
+ 289: Hamimelon
308
+ 290: Flask
309
+ 291: Mushroom
310
+ 292: Screwdriver
311
+ 293: Soap
312
+ 294: Recorder
313
+ 295: Bear
314
+ 296: Eggplant
315
+ 297: Board Eraser
316
+ 298: Coconut
317
+ 299: Tape Measure/Ruler
318
+ 300: Pig
319
+ 301: Showerhead
320
+ 302: Globe
321
+ 303: Chips
322
+ 304: Steak
323
+ 305: Crosswalk Sign
324
+ 306: Stapler
325
+ 307: Camel
326
+ 308: Formula 1
327
+ 309: Pomegranate
328
+ 310: Dishwasher
329
+ 311: Crab
330
+ 312: Hoverboard
331
+ 313: Meat ball
332
+ 314: Rice Cooker
333
+ 315: Tuba
334
+ 316: Calculator
335
+ 317: Papaya
336
+ 318: Antelope
337
+ 319: Parrot
338
+ 320: Seal
339
+ 321: Butterfly
340
+ 322: Dumbbell
341
+ 323: Donkey
342
+ 324: Lion
343
+ 325: Urinal
344
+ 326: Dolphin
345
+ 327: Electric Drill
346
+ 328: Hair Dryer
347
+ 329: Egg tart
348
+ 330: Jellyfish
349
+ 331: Treadmill
350
+ 332: Lighter
351
+ 333: Grapefruit
352
+ 334: Game board
353
+ 335: Mop
354
+ 336: Radish
355
+ 337: Baozi
356
+ 338: Target
357
+ 339: French
358
+ 340: Spring Rolls
359
+ 341: Monkey
360
+ 342: Rabbit
361
+ 343: Pencil Case
362
+ 344: Yak
363
+ 345: Red Cabbage
364
+ 346: Binoculars
365
+ 347: Asparagus
366
+ 348: Barbell
367
+ 349: Scallop
368
+ 350: Noddles
369
+ 351: Comb
370
+ 352: Dumpling
371
+ 353: Oyster
372
+ 354: Table Tennis paddle
373
+ 355: Cosmetics Brush/Eyeliner Pencil
374
+ 356: Chainsaw
375
+ 357: Eraser
376
+ 358: Lobster
377
+ 359: Durian
378
+ 360: Okra
379
+ 361: Lipstick
380
+ 362: Cosmetics Mirror
381
+ 363: Curling
382
+ 364: Table Tennis
383
+
384
+
385
+ # Download script/URL (optional) ---------------------------------------------------------------------------------------
386
+ download: |
387
+ from tqdm import tqdm
388
+
389
+ from ultralytics.yolo.utils.checks import check_requirements
390
+ from ultralytics.yolo.utils.downloads import download
391
+ from ultralytics.yolo.utils.ops import xyxy2xywhn
392
+
393
+ import numpy as np
394
+ from pathlib import Path
395
+
396
+ check_requirements(('pycocotools>=2.0',))
397
+ from pycocotools.coco import COCO
398
+
399
+ # Make Directories
400
+ dir = Path(yaml['path']) # dataset root dir
401
+ for p in 'images', 'labels':
402
+ (dir / p).mkdir(parents=True, exist_ok=True)
403
+ for q in 'train', 'val':
404
+ (dir / p / q).mkdir(parents=True, exist_ok=True)
405
+
406
+ # Train, Val Splits
407
+ for split, patches in [('train', 50 + 1), ('val', 43 + 1)]:
408
+ print(f"Processing {split} in {patches} patches ...")
409
+ images, labels = dir / 'images' / split, dir / 'labels' / split
410
+
411
+ # Download
412
+ url = f"https://dorc.ks3-cn-beijing.ksyun.com/data-set/2020Objects365%E6%95%B0%E6%8D%AE%E9%9B%86/{split}/"
413
+ if split == 'train':
414
+ download([f'{url}zhiyuan_objv2_{split}.tar.gz'], dir=dir) # annotations json
415
+ download([f'{url}patch{i}.tar.gz' for i in range(patches)], dir=images, curl=True, threads=8)
416
+ elif split == 'val':
417
+ download([f'{url}zhiyuan_objv2_{split}.json'], dir=dir) # annotations json
418
+ download([f'{url}images/v1/patch{i}.tar.gz' for i in range(15 + 1)], dir=images, curl=True, threads=8)
419
+ download([f'{url}images/v2/patch{i}.tar.gz' for i in range(16, patches)], dir=images, curl=True, threads=8)
420
+
421
+ # Move
422
+ for f in tqdm(images.rglob('*.jpg'), desc=f'Moving {split} images'):
423
+ f.rename(images / f.name) # move to /images/{split}
424
+
425
+ # Labels
426
+ coco = COCO(dir / f'zhiyuan_objv2_{split}.json')
427
+ names = [x["name"] for x in coco.loadCats(coco.getCatIds())]
428
+ for cid, cat in enumerate(names):
429
+ catIds = coco.getCatIds(catNms=[cat])
430
+ imgIds = coco.getImgIds(catIds=catIds)
431
+ for im in tqdm(coco.loadImgs(imgIds), desc=f'Class {cid + 1}/{len(names)} {cat}'):
432
+ width, height = im["width"], im["height"]
433
+ path = Path(im["file_name"]) # image filename
434
+ try:
435
+ with open(labels / path.with_suffix('.txt').name, 'a') as file:
436
+ annIds = coco.getAnnIds(imgIds=im["id"], catIds=catIds, iscrowd=None)
437
+ for a in coco.loadAnns(annIds):
438
+ x, y, w, h = a['bbox'] # bounding box in xywh (xy top-left corner)
439
+ xyxy = np.array([x, y, x + w, y + h])[None] # pixels(1,4)
440
+ x, y, w, h = xyxy2xywhn(xyxy, w=width, h=height, clip=True)[0] # normalized and clipped
441
+ file.write(f"{cid} {x:.5f} {y:.5f} {w:.5f} {h:.5f}\n")
442
+ except Exception as e:
443
+ print(e)
ultralytics/datasets/SKU-110K.yaml ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # SKU-110K retail items dataset https://github.com/eg4000/SKU110K_CVPR19 by Trax Retail
3
+ # Example usage: yolo train data=SKU-110K.yaml
4
+ # parent
5
+ # ├── ultralytics
6
+ # └── datasets
7
+ # └── SKU-110K ← downloads here (13.6 GB)
8
+
9
+
10
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11
+ path: ../datasets/SKU-110K # dataset root dir
12
+ train: train.txt # train images (relative to 'path') 8219 images
13
+ val: val.txt # val images (relative to 'path') 588 images
14
+ test: test.txt # test images (optional) 2936 images
15
+
16
+ # Classes
17
+ names:
18
+ 0: object
19
+
20
+
21
+ # Download script/URL (optional) ---------------------------------------------------------------------------------------
22
+ download: |
23
+ import shutil
24
+ from pathlib import Path
25
+
26
+ import numpy as np
27
+ import pandas as pd
28
+ from tqdm import tqdm
29
+
30
+ from ultralytics.yolo.utils.downloads import download
31
+ from ultralytics.yolo.utils.ops import xyxy2xywh
32
+
33
+ # Download
34
+ dir = Path(yaml['path']) # dataset root dir
35
+ parent = Path(dir.parent) # download dir
36
+ urls = ['http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz']
37
+ download(urls, dir=parent)
38
+
39
+ # Rename directories
40
+ if dir.exists():
41
+ shutil.rmtree(dir)
42
+ (parent / 'SKU110K_fixed').rename(dir) # rename dir
43
+ (dir / 'labels').mkdir(parents=True, exist_ok=True) # create labels dir
44
+
45
+ # Convert labels
46
+ names = 'image', 'x1', 'y1', 'x2', 'y2', 'class', 'image_width', 'image_height' # column names
47
+ for d in 'annotations_train.csv', 'annotations_val.csv', 'annotations_test.csv':
48
+ x = pd.read_csv(dir / 'annotations' / d, names=names).values # annotations
49
+ images, unique_images = x[:, 0], np.unique(x[:, 0])
50
+ with open((dir / d).with_suffix('.txt').__str__().replace('annotations_', ''), 'w') as f:
51
+ f.writelines(f'./images/{s}\n' for s in unique_images)
52
+ for im in tqdm(unique_images, desc=f'Converting {dir / d}'):
53
+ cls = 0 # single-class dataset
54
+ with open((dir / 'labels' / im).with_suffix('.txt'), 'a') as f:
55
+ for r in x[images == im]:
56
+ w, h = r[6], r[7] # image width, height
57
+ xywh = xyxy2xywh(np.array([[r[1] / w, r[2] / h, r[3] / w, r[4] / h]]))[0] # instance
58
+ f.write(f"{cls} {xywh[0]:.5f} {xywh[1]:.5f} {xywh[2]:.5f} {xywh[3]:.5f}\n") # write label
ultralytics/datasets/VOC.yaml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC by University of Oxford
3
+ # Example usage: yolo train data=VOC.yaml
4
+ # parent
5
+ # ├── ultralytics
6
+ # └── datasets
7
+ # └── VOC ← downloads here (2.8 GB)
8
+
9
+
10
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11
+ path: ../datasets/VOC
12
+ train: # train images (relative to 'path') 16551 images
13
+ - images/train2012
14
+ - images/train2007
15
+ - images/val2012
16
+ - images/val2007
17
+ val: # val images (relative to 'path') 4952 images
18
+ - images/test2007
19
+ test: # test images (optional)
20
+ - images/test2007
21
+
22
+ # Classes
23
+ names:
24
+ 0: aeroplane
25
+ 1: bicycle
26
+ 2: bird
27
+ 3: boat
28
+ 4: bottle
29
+ 5: bus
30
+ 6: car
31
+ 7: cat
32
+ 8: chair
33
+ 9: cow
34
+ 10: diningtable
35
+ 11: dog
36
+ 12: horse
37
+ 13: motorbike
38
+ 14: person
39
+ 15: pottedplant
40
+ 16: sheep
41
+ 17: sofa
42
+ 18: train
43
+ 19: tvmonitor
44
+
45
+
46
+ # Download script/URL (optional) ---------------------------------------------------------------------------------------
47
+ download: |
48
+ import xml.etree.ElementTree as ET
49
+
50
+ from tqdm import tqdm
51
+ from ultralytics.yolo.utils.downloads import download
52
+ from pathlib import Path
53
+
54
+ def convert_label(path, lb_path, year, image_id):
55
+ def convert_box(size, box):
56
+ dw, dh = 1. / size[0], 1. / size[1]
57
+ x, y, w, h = (box[0] + box[1]) / 2.0 - 1, (box[2] + box[3]) / 2.0 - 1, box[1] - box[0], box[3] - box[2]
58
+ return x * dw, y * dh, w * dw, h * dh
59
+
60
+ in_file = open(path / f'VOC{year}/Annotations/{image_id}.xml')
61
+ out_file = open(lb_path, 'w')
62
+ tree = ET.parse(in_file)
63
+ root = tree.getroot()
64
+ size = root.find('size')
65
+ w = int(size.find('width').text)
66
+ h = int(size.find('height').text)
67
+
68
+ names = list(yaml['names'].values()) # names list
69
+ for obj in root.iter('object'):
70
+ cls = obj.find('name').text
71
+ if cls in names and int(obj.find('difficult').text) != 1:
72
+ xmlbox = obj.find('bndbox')
73
+ bb = convert_box((w, h), [float(xmlbox.find(x).text) for x in ('xmin', 'xmax', 'ymin', 'ymax')])
74
+ cls_id = names.index(cls) # class id
75
+ out_file.write(" ".join([str(a) for a in (cls_id, *bb)]) + '\n')
76
+
77
+
78
+ # Download
79
+ dir = Path(yaml['path']) # dataset root dir
80
+ url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
81
+ urls = [f'{url}VOCtrainval_06-Nov-2007.zip', # 446MB, 5012 images
82
+ f'{url}VOCtest_06-Nov-2007.zip', # 438MB, 4953 images
83
+ f'{url}VOCtrainval_11-May-2012.zip'] # 1.95GB, 17126 images
84
+ download(urls, dir=dir / 'images', curl=True, threads=3)
85
+
86
+ # Convert
87
+ path = dir / 'images/VOCdevkit'
88
+ for year, image_set in ('2012', 'train'), ('2012', 'val'), ('2007', 'train'), ('2007', 'val'), ('2007', 'test'):
89
+ imgs_path = dir / 'images' / f'{image_set}{year}'
90
+ lbs_path = dir / 'labels' / f'{image_set}{year}'
91
+ imgs_path.mkdir(exist_ok=True, parents=True)
92
+ lbs_path.mkdir(exist_ok=True, parents=True)
93
+
94
+ with open(path / f'VOC{year}/ImageSets/Main/{image_set}.txt') as f:
95
+ image_ids = f.read().strip().split()
96
+ for id in tqdm(image_ids, desc=f'{image_set}{year}'):
97
+ f = path / f'VOC{year}/JPEGImages/{id}.jpg' # old img path
98
+ lb_path = (lbs_path / f.name).with_suffix('.txt') # new label path
99
+ f.rename(imgs_path / f.name) # move image
100
+ convert_label(path, lb_path, year, id) # convert labels to YOLO format
ultralytics/datasets/VisDrone.yaml ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # VisDrone2019-DET dataset https://github.com/VisDrone/VisDrone-Dataset by Tianjin University
3
+ # Example usage: yolo train data=VisDrone.yaml
4
+ # parent
5
+ # ├── ultralytics
6
+ # └── datasets
7
+ # └── VisDrone ← downloads here (2.3 GB)
8
+
9
+
10
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11
+ path: ../datasets/VisDrone # dataset root dir
12
+ train: VisDrone2019-DET-train/images # train images (relative to 'path') 6471 images
13
+ val: VisDrone2019-DET-val/images # val images (relative to 'path') 548 images
14
+ test: VisDrone2019-DET-test-dev/images # test images (optional) 1610 images
15
+
16
+ # Classes
17
+ names:
18
+ 0: pedestrian
19
+ 1: people
20
+ 2: bicycle
21
+ 3: car
22
+ 4: van
23
+ 5: truck
24
+ 6: tricycle
25
+ 7: awning-tricycle
26
+ 8: bus
27
+ 9: motor
28
+
29
+
30
+ # Download script/URL (optional) ---------------------------------------------------------------------------------------
31
+ download: |
32
+ import os
33
+ from pathlib import Path
34
+
35
+ from ultralytics.yolo.utils.downloads import download
36
+
37
+ def visdrone2yolo(dir):
38
+ from PIL import Image
39
+ from tqdm import tqdm
40
+
41
+ def convert_box(size, box):
42
+ # Convert VisDrone box to YOLO xywh box
43
+ dw = 1. / size[0]
44
+ dh = 1. / size[1]
45
+ return (box[0] + box[2] / 2) * dw, (box[1] + box[3] / 2) * dh, box[2] * dw, box[3] * dh
46
+
47
+ (dir / 'labels').mkdir(parents=True, exist_ok=True) # make labels directory
48
+ pbar = tqdm((dir / 'annotations').glob('*.txt'), desc=f'Converting {dir}')
49
+ for f in pbar:
50
+ img_size = Image.open((dir / 'images' / f.name).with_suffix('.jpg')).size
51
+ lines = []
52
+ with open(f, 'r') as file: # read annotation.txt
53
+ for row in [x.split(',') for x in file.read().strip().splitlines()]:
54
+ if row[4] == '0': # VisDrone 'ignored regions' class 0
55
+ continue
56
+ cls = int(row[5]) - 1
57
+ box = convert_box(img_size, tuple(map(int, row[:4])))
58
+ lines.append(f"{cls} {' '.join(f'{x:.6f}' for x in box)}\n")
59
+ with open(str(f).replace(f'{os.sep}annotations{os.sep}', f'{os.sep}labels{os.sep}'), 'w') as fl:
60
+ fl.writelines(lines) # write label.txt
61
+
62
+
63
+ # Download
64
+ dir = Path(yaml['path']) # dataset root dir
65
+ urls = ['https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-train.zip',
66
+ 'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-val.zip',
67
+ 'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-test-dev.zip',
68
+ 'https://github.com/ultralytics/yolov5/releases/download/v1.0/VisDrone2019-DET-test-challenge.zip']
69
+ download(urls, dir=dir, curl=True, threads=4)
70
+
71
+ # Convert
72
+ for d in 'VisDrone2019-DET-train', 'VisDrone2019-DET-val', 'VisDrone2019-DET-test-dev':
73
+ visdrone2yolo(dir / d) # convert VisDrone annotations to YOLO labels
ultralytics/datasets/coco-pose.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # COCO 2017 dataset http://cocodataset.org by Microsoft
3
+ # Example usage: yolo train data=coco-pose.yaml
4
+ # parent
5
+ # ├── ultralytics
6
+ # └── datasets
7
+ # └── coco-pose ← downloads here (20.1 GB)
8
+
9
+
10
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11
+ path: ../datasets/coco-pose # dataset root dir
12
+ train: train2017.txt # train images (relative to 'path') 118287 images
13
+ val: val2017.txt # val images (relative to 'path') 5000 images
14
+ test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
15
+
16
+ # Keypoints
17
+ kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
18
+ flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
19
+
20
+ # Classes
21
+ names:
22
+ 0: person
23
+
24
+ # Download script/URL (optional)
25
+ download: |
26
+ from ultralytics.yolo.utils.downloads import download
27
+ from pathlib import Path
28
+
29
+ # Download labels
30
+ dir = Path(yaml['path']) # dataset root dir
31
+ url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
32
+ urls = [url + 'coco2017labels-pose.zip'] # labels
33
+ download(urls, dir=dir.parent)
34
+ # Download data
35
+ urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images
36
+ 'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images
37
+ 'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional)
38
+ download(urls, dir=dir / 'images', threads=3)
ultralytics/datasets/coco.yaml ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # COCO 2017 dataset http://cocodataset.org by Microsoft
3
+ # Example usage: yolo train data=coco.yaml
4
+ # parent
5
+ # ├── ultralytics
6
+ # └── datasets
7
+ # └── coco ← downloads here (20.1 GB)
8
+
9
+
10
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11
+ path: ../datasets/coco # dataset root dir
12
+ train: train2017.txt # train images (relative to 'path') 118287 images
13
+ val: val2017.txt # val images (relative to 'path') 5000 images
14
+ test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
15
+
16
+ # Classes
17
+ names:
18
+ 0: person
19
+ 1: bicycle
20
+ 2: car
21
+ 3: motorcycle
22
+ 4: airplane
23
+ 5: bus
24
+ 6: train
25
+ 7: truck
26
+ 8: boat
27
+ 9: traffic light
28
+ 10: fire hydrant
29
+ 11: stop sign
30
+ 12: parking meter
31
+ 13: bench
32
+ 14: bird
33
+ 15: cat
34
+ 16: dog
35
+ 17: horse
36
+ 18: sheep
37
+ 19: cow
38
+ 20: elephant
39
+ 21: bear
40
+ 22: zebra
41
+ 23: giraffe
42
+ 24: backpack
43
+ 25: umbrella
44
+ 26: handbag
45
+ 27: tie
46
+ 28: suitcase
47
+ 29: frisbee
48
+ 30: skis
49
+ 31: snowboard
50
+ 32: sports ball
51
+ 33: kite
52
+ 34: baseball bat
53
+ 35: baseball glove
54
+ 36: skateboard
55
+ 37: surfboard
56
+ 38: tennis racket
57
+ 39: bottle
58
+ 40: wine glass
59
+ 41: cup
60
+ 42: fork
61
+ 43: knife
62
+ 44: spoon
63
+ 45: bowl
64
+ 46: banana
65
+ 47: apple
66
+ 48: sandwich
67
+ 49: orange
68
+ 50: broccoli
69
+ 51: carrot
70
+ 52: hot dog
71
+ 53: pizza
72
+ 54: donut
73
+ 55: cake
74
+ 56: chair
75
+ 57: couch
76
+ 58: potted plant
77
+ 59: bed
78
+ 60: dining table
79
+ 61: toilet
80
+ 62: tv
81
+ 63: laptop
82
+ 64: mouse
83
+ 65: remote
84
+ 66: keyboard
85
+ 67: cell phone
86
+ 68: microwave
87
+ 69: oven
88
+ 70: toaster
89
+ 71: sink
90
+ 72: refrigerator
91
+ 73: book
92
+ 74: clock
93
+ 75: vase
94
+ 76: scissors
95
+ 77: teddy bear
96
+ 78: hair drier
97
+ 79: toothbrush
98
+
99
+
100
+ # Download script/URL (optional)
101
+ download: |
102
+ from ultralytics.yolo.utils.downloads import download
103
+ from pathlib import Path
104
+
105
+ # Download labels
106
+ segments = True # segment or box labels
107
+ dir = Path(yaml['path']) # dataset root dir
108
+ url = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/'
109
+ urls = [url + ('coco2017labels-segments.zip' if segments else 'coco2017labels.zip')] # labels
110
+ download(urls, dir=dir.parent)
111
+ # Download data
112
+ urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images
113
+ 'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images
114
+ 'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional)
115
+ download(urls, dir=dir / 'images', threads=3)
ultralytics/datasets/coco128-seg.yaml ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # COCO128-seg dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics
3
+ # Example usage: yolo train data=coco128.yaml
4
+ # parent
5
+ # ├── ultralytics
6
+ # └── datasets
7
+ # └── coco128-seg ← downloads here (7 MB)
8
+
9
+
10
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11
+ path: ../datasets/coco128-seg # dataset root dir
12
+ train: images/train2017 # train images (relative to 'path') 128 images
13
+ val: images/train2017 # val images (relative to 'path') 128 images
14
+ test: # test images (optional)
15
+
16
+ # Classes
17
+ names:
18
+ 0: person
19
+ 1: bicycle
20
+ 2: car
21
+ 3: motorcycle
22
+ 4: airplane
23
+ 5: bus
24
+ 6: train
25
+ 7: truck
26
+ 8: boat
27
+ 9: traffic light
28
+ 10: fire hydrant
29
+ 11: stop sign
30
+ 12: parking meter
31
+ 13: bench
32
+ 14: bird
33
+ 15: cat
34
+ 16: dog
35
+ 17: horse
36
+ 18: sheep
37
+ 19: cow
38
+ 20: elephant
39
+ 21: bear
40
+ 22: zebra
41
+ 23: giraffe
42
+ 24: backpack
43
+ 25: umbrella
44
+ 26: handbag
45
+ 27: tie
46
+ 28: suitcase
47
+ 29: frisbee
48
+ 30: skis
49
+ 31: snowboard
50
+ 32: sports ball
51
+ 33: kite
52
+ 34: baseball bat
53
+ 35: baseball glove
54
+ 36: skateboard
55
+ 37: surfboard
56
+ 38: tennis racket
57
+ 39: bottle
58
+ 40: wine glass
59
+ 41: cup
60
+ 42: fork
61
+ 43: knife
62
+ 44: spoon
63
+ 45: bowl
64
+ 46: banana
65
+ 47: apple
66
+ 48: sandwich
67
+ 49: orange
68
+ 50: broccoli
69
+ 51: carrot
70
+ 52: hot dog
71
+ 53: pizza
72
+ 54: donut
73
+ 55: cake
74
+ 56: chair
75
+ 57: couch
76
+ 58: potted plant
77
+ 59: bed
78
+ 60: dining table
79
+ 61: toilet
80
+ 62: tv
81
+ 63: laptop
82
+ 64: mouse
83
+ 65: remote
84
+ 66: keyboard
85
+ 67: cell phone
86
+ 68: microwave
87
+ 69: oven
88
+ 70: toaster
89
+ 71: sink
90
+ 72: refrigerator
91
+ 73: book
92
+ 74: clock
93
+ 75: vase
94
+ 76: scissors
95
+ 77: teddy bear
96
+ 78: hair drier
97
+ 79: toothbrush
98
+
99
+
100
+ # Download script/URL (optional)
101
+ download: https://ultralytics.com/assets/coco128-seg.zip
ultralytics/datasets/coco128.yaml ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics
3
+ # Example usage: yolo train data=coco128.yaml
4
+ # parent
5
+ # ├── ultralytics
6
+ # └── datasets
7
+ # └── coco128 ← downloads here (7 MB)
8
+
9
+
10
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11
+ path: ../datasets/coco128 # dataset root dir
12
+ train: images/train2017 # train images (relative to 'path') 128 images
13
+ val: images/train2017 # val images (relative to 'path') 128 images
14
+ test: # test images (optional)
15
+
16
+ # Classes
17
+ names:
18
+ 0: person
19
+ 1: bicycle
20
+ 2: car
21
+ 3: motorcycle
22
+ 4: airplane
23
+ 5: bus
24
+ 6: train
25
+ 7: truck
26
+ 8: boat
27
+ 9: traffic light
28
+ 10: fire hydrant
29
+ 11: stop sign
30
+ 12: parking meter
31
+ 13: bench
32
+ 14: bird
33
+ 15: cat
34
+ 16: dog
35
+ 17: horse
36
+ 18: sheep
37
+ 19: cow
38
+ 20: elephant
39
+ 21: bear
40
+ 22: zebra
41
+ 23: giraffe
42
+ 24: backpack
43
+ 25: umbrella
44
+ 26: handbag
45
+ 27: tie
46
+ 28: suitcase
47
+ 29: frisbee
48
+ 30: skis
49
+ 31: snowboard
50
+ 32: sports ball
51
+ 33: kite
52
+ 34: baseball bat
53
+ 35: baseball glove
54
+ 36: skateboard
55
+ 37: surfboard
56
+ 38: tennis racket
57
+ 39: bottle
58
+ 40: wine glass
59
+ 41: cup
60
+ 42: fork
61
+ 43: knife
62
+ 44: spoon
63
+ 45: bowl
64
+ 46: banana
65
+ 47: apple
66
+ 48: sandwich
67
+ 49: orange
68
+ 50: broccoli
69
+ 51: carrot
70
+ 52: hot dog
71
+ 53: pizza
72
+ 54: donut
73
+ 55: cake
74
+ 56: chair
75
+ 57: couch
76
+ 58: potted plant
77
+ 59: bed
78
+ 60: dining table
79
+ 61: toilet
80
+ 62: tv
81
+ 63: laptop
82
+ 64: mouse
83
+ 65: remote
84
+ 66: keyboard
85
+ 67: cell phone
86
+ 68: microwave
87
+ 69: oven
88
+ 70: toaster
89
+ 71: sink
90
+ 72: refrigerator
91
+ 73: book
92
+ 74: clock
93
+ 75: vase
94
+ 76: scissors
95
+ 77: teddy bear
96
+ 78: hair drier
97
+ 79: toothbrush
98
+
99
+
100
+ # Download script/URL (optional)
101
+ download: https://ultralytics.com/assets/coco128.zip
ultralytics/datasets/coco8-pose.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # COCO8-pose dataset (first 8 images from COCO train2017) by Ultralytics
3
+ # Example usage: yolo train data=coco8-pose.yaml
4
+ # parent
5
+ # ├── ultralytics
6
+ # └── datasets
7
+ # └── coco8-pose ← downloads here (1 MB)
8
+
9
+
10
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11
+ path: ../datasets/coco8-pose # dataset root dir
12
+ train: images/train # train images (relative to 'path') 4 images
13
+ val: images/val # val images (relative to 'path') 4 images
14
+ test: # test images (optional)
15
+
16
+ # Keypoints
17
+ kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
18
+ flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
19
+
20
+ # Classes
21
+ names:
22
+ 0: person
23
+
24
+ # Download script/URL (optional)
25
+ download: https://ultralytics.com/assets/coco8-pose.zip
ultralytics/datasets/coco8-seg.yaml ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # COCO8-seg dataset (first 8 images from COCO train2017) by Ultralytics
3
+ # Example usage: yolo train data=coco8-seg.yaml
4
+ # parent
5
+ # ├── ultralytics
6
+ # └── datasets
7
+ # └── coco8-seg ← downloads here (1 MB)
8
+
9
+
10
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11
+ path: ../datasets/coco8-seg # dataset root dir
12
+ train: images/train # train images (relative to 'path') 4 images
13
+ val: images/val # val images (relative to 'path') 4 images
14
+ test: # test images (optional)
15
+
16
+ # Classes
17
+ names:
18
+ 0: person
19
+ 1: bicycle
20
+ 2: car
21
+ 3: motorcycle
22
+ 4: airplane
23
+ 5: bus
24
+ 6: train
25
+ 7: truck
26
+ 8: boat
27
+ 9: traffic light
28
+ 10: fire hydrant
29
+ 11: stop sign
30
+ 12: parking meter
31
+ 13: bench
32
+ 14: bird
33
+ 15: cat
34
+ 16: dog
35
+ 17: horse
36
+ 18: sheep
37
+ 19: cow
38
+ 20: elephant
39
+ 21: bear
40
+ 22: zebra
41
+ 23: giraffe
42
+ 24: backpack
43
+ 25: umbrella
44
+ 26: handbag
45
+ 27: tie
46
+ 28: suitcase
47
+ 29: frisbee
48
+ 30: skis
49
+ 31: snowboard
50
+ 32: sports ball
51
+ 33: kite
52
+ 34: baseball bat
53
+ 35: baseball glove
54
+ 36: skateboard
55
+ 37: surfboard
56
+ 38: tennis racket
57
+ 39: bottle
58
+ 40: wine glass
59
+ 41: cup
60
+ 42: fork
61
+ 43: knife
62
+ 44: spoon
63
+ 45: bowl
64
+ 46: banana
65
+ 47: apple
66
+ 48: sandwich
67
+ 49: orange
68
+ 50: broccoli
69
+ 51: carrot
70
+ 52: hot dog
71
+ 53: pizza
72
+ 54: donut
73
+ 55: cake
74
+ 56: chair
75
+ 57: couch
76
+ 58: potted plant
77
+ 59: bed
78
+ 60: dining table
79
+ 61: toilet
80
+ 62: tv
81
+ 63: laptop
82
+ 64: mouse
83
+ 65: remote
84
+ 66: keyboard
85
+ 67: cell phone
86
+ 68: microwave
87
+ 69: oven
88
+ 70: toaster
89
+ 71: sink
90
+ 72: refrigerator
91
+ 73: book
92
+ 74: clock
93
+ 75: vase
94
+ 76: scissors
95
+ 77: teddy bear
96
+ 78: hair drier
97
+ 79: toothbrush
98
+
99
+
100
+ # Download script/URL (optional)
101
+ download: https://ultralytics.com/assets/coco8-seg.zip
ultralytics/datasets/coco8.yaml ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # COCO8 dataset (first 8 images from COCO train2017) by Ultralytics
3
+ # Example usage: yolo train data=coco8.yaml
4
+ # parent
5
+ # ├── ultralytics
6
+ # └── datasets
7
+ # └── coco8 ← downloads here (1 MB)
8
+
9
+
10
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11
+ path: ../datasets/coco8 # dataset root dir
12
+ train: images/train # train images (relative to 'path') 4 images
13
+ val: images/val # val images (relative to 'path') 4 images
14
+ test: # test images (optional)
15
+
16
+ # Classes
17
+ names:
18
+ 0: person
19
+ 1: bicycle
20
+ 2: car
21
+ 3: motorcycle
22
+ 4: airplane
23
+ 5: bus
24
+ 6: train
25
+ 7: truck
26
+ 8: boat
27
+ 9: traffic light
28
+ 10: fire hydrant
29
+ 11: stop sign
30
+ 12: parking meter
31
+ 13: bench
32
+ 14: bird
33
+ 15: cat
34
+ 16: dog
35
+ 17: horse
36
+ 18: sheep
37
+ 19: cow
38
+ 20: elephant
39
+ 21: bear
40
+ 22: zebra
41
+ 23: giraffe
42
+ 24: backpack
43
+ 25: umbrella
44
+ 26: handbag
45
+ 27: tie
46
+ 28: suitcase
47
+ 29: frisbee
48
+ 30: skis
49
+ 31: snowboard
50
+ 32: sports ball
51
+ 33: kite
52
+ 34: baseball bat
53
+ 35: baseball glove
54
+ 36: skateboard
55
+ 37: surfboard
56
+ 38: tennis racket
57
+ 39: bottle
58
+ 40: wine glass
59
+ 41: cup
60
+ 42: fork
61
+ 43: knife
62
+ 44: spoon
63
+ 45: bowl
64
+ 46: banana
65
+ 47: apple
66
+ 48: sandwich
67
+ 49: orange
68
+ 50: broccoli
69
+ 51: carrot
70
+ 52: hot dog
71
+ 53: pizza
72
+ 54: donut
73
+ 55: cake
74
+ 56: chair
75
+ 57: couch
76
+ 58: potted plant
77
+ 59: bed
78
+ 60: dining table
79
+ 61: toilet
80
+ 62: tv
81
+ 63: laptop
82
+ 64: mouse
83
+ 65: remote
84
+ 66: keyboard
85
+ 67: cell phone
86
+ 68: microwave
87
+ 69: oven
88
+ 70: toaster
89
+ 71: sink
90
+ 72: refrigerator
91
+ 73: book
92
+ 74: clock
93
+ 75: vase
94
+ 76: scissors
95
+ 77: teddy bear
96
+ 78: hair drier
97
+ 79: toothbrush
98
+
99
+
100
+ # Download script/URL (optional)
101
+ download: https://ultralytics.com/assets/coco8.zip
ultralytics/datasets/xView.yaml ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # DIUx xView 2018 Challenge https://challenge.xviewdataset.org by U.S. National Geospatial-Intelligence Agency (NGA)
3
+ # -------- DOWNLOAD DATA MANUALLY and jar xf val_images.zip to 'datasets/xView' before running train command! --------
4
+ # Example usage: yolo train data=xView.yaml
5
+ # parent
6
+ # ├── ultralytics
7
+ # └── datasets
8
+ # └── xView ← downloads here (20.7 GB)
9
+
10
+
11
+ # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
12
+ path: ../datasets/xView # dataset root dir
13
+ train: images/autosplit_train.txt # train images (relative to 'path') 90% of 847 train images
14
+ val: images/autosplit_val.txt # train images (relative to 'path') 10% of 847 train images
15
+
16
+ # Classes
17
+ names:
18
+ 0: Fixed-wing Aircraft
19
+ 1: Small Aircraft
20
+ 2: Cargo Plane
21
+ 3: Helicopter
22
+ 4: Passenger Vehicle
23
+ 5: Small Car
24
+ 6: Bus
25
+ 7: Pickup Truck
26
+ 8: Utility Truck
27
+ 9: Truck
28
+ 10: Cargo Truck
29
+ 11: Truck w/Box
30
+ 12: Truck Tractor
31
+ 13: Trailer
32
+ 14: Truck w/Flatbed
33
+ 15: Truck w/Liquid
34
+ 16: Crane Truck
35
+ 17: Railway Vehicle
36
+ 18: Passenger Car
37
+ 19: Cargo Car
38
+ 20: Flat Car
39
+ 21: Tank car
40
+ 22: Locomotive
41
+ 23: Maritime Vessel
42
+ 24: Motorboat
43
+ 25: Sailboat
44
+ 26: Tugboat
45
+ 27: Barge
46
+ 28: Fishing Vessel
47
+ 29: Ferry
48
+ 30: Yacht
49
+ 31: Container Ship
50
+ 32: Oil Tanker
51
+ 33: Engineering Vehicle
52
+ 34: Tower crane
53
+ 35: Container Crane
54
+ 36: Reach Stacker
55
+ 37: Straddle Carrier
56
+ 38: Mobile Crane
57
+ 39: Dump Truck
58
+ 40: Haul Truck
59
+ 41: Scraper/Tractor
60
+ 42: Front loader/Bulldozer
61
+ 43: Excavator
62
+ 44: Cement Mixer
63
+ 45: Ground Grader
64
+ 46: Hut/Tent
65
+ 47: Shed
66
+ 48: Building
67
+ 49: Aircraft Hangar
68
+ 50: Damaged Building
69
+ 51: Facility
70
+ 52: Construction Site
71
+ 53: Vehicle Lot
72
+ 54: Helipad
73
+ 55: Storage Tank
74
+ 56: Shipping container lot
75
+ 57: Shipping Container
76
+ 58: Pylon
77
+ 59: Tower
78
+
79
+
80
+ # Download script/URL (optional) ---------------------------------------------------------------------------------------
81
+ download: |
82
+ import json
83
+ import os
84
+ from pathlib import Path
85
+
86
+ import numpy as np
87
+ from PIL import Image
88
+ from tqdm import tqdm
89
+
90
+ from ultralytics.yolo.data.dataloaders.v5loader import autosplit
91
+ from ultralytics.yolo.utils.ops import xyxy2xywhn
92
+
93
+
94
+ def convert_labels(fname=Path('xView/xView_train.geojson')):
95
+ # Convert xView geoJSON labels to YOLO format
96
+ path = fname.parent
97
+ with open(fname) as f:
98
+ print(f'Loading {fname}...')
99
+ data = json.load(f)
100
+
101
+ # Make dirs
102
+ labels = Path(path / 'labels' / 'train')
103
+ os.system(f'rm -rf {labels}')
104
+ labels.mkdir(parents=True, exist_ok=True)
105
+
106
+ # xView classes 11-94 to 0-59
107
+ xview_class2index = [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, -1, 3, -1, 4, 5, 6, 7, 8, -1, 9, 10, 11,
108
+ 12, 13, 14, 15, -1, -1, 16, 17, 18, 19, 20, 21, 22, -1, 23, 24, 25, -1, 26, 27, -1, 28, -1,
109
+ 29, 30, 31, 32, 33, 34, 35, 36, 37, -1, 38, 39, 40, 41, 42, 43, 44, 45, -1, -1, -1, -1, 46,
110
+ 47, 48, 49, -1, 50, 51, -1, 52, -1, -1, -1, 53, 54, -1, 55, -1, -1, 56, -1, 57, -1, 58, 59]
111
+
112
+ shapes = {}
113
+ for feature in tqdm(data['features'], desc=f'Converting {fname}'):
114
+ p = feature['properties']
115
+ if p['bounds_imcoords']:
116
+ id = p['image_id']
117
+ file = path / 'train_images' / id
118
+ if file.exists(): # 1395.tif missing
119
+ try:
120
+ box = np.array([int(num) for num in p['bounds_imcoords'].split(",")])
121
+ assert box.shape[0] == 4, f'incorrect box shape {box.shape[0]}'
122
+ cls = p['type_id']
123
+ cls = xview_class2index[int(cls)] # xView class to 0-60
124
+ assert 59 >= cls >= 0, f'incorrect class index {cls}'
125
+
126
+ # Write YOLO label
127
+ if id not in shapes:
128
+ shapes[id] = Image.open(file).size
129
+ box = xyxy2xywhn(box[None].astype(np.float), w=shapes[id][0], h=shapes[id][1], clip=True)
130
+ with open((labels / id).with_suffix('.txt'), 'a') as f:
131
+ f.write(f"{cls} {' '.join(f'{x:.6f}' for x in box[0])}\n") # write label.txt
132
+ except Exception as e:
133
+ print(f'WARNING: skipping one label for {file}: {e}')
134
+
135
+
136
+ # Download manually from https://challenge.xviewdataset.org
137
+ dir = Path(yaml['path']) # dataset root dir
138
+ # urls = ['https://d307kc0mrhucc3.cloudfront.net/train_labels.zip', # train labels
139
+ # 'https://d307kc0mrhucc3.cloudfront.net/train_images.zip', # 15G, 847 train images
140
+ # 'https://d307kc0mrhucc3.cloudfront.net/val_images.zip'] # 5G, 282 val images (no labels)
141
+ # download(urls, dir=dir)
142
+
143
+ # Convert labels
144
+ convert_labels(dir / 'xView_train.geojson')
145
+
146
+ # Move images
147
+ images = Path(dir / 'images')
148
+ images.mkdir(parents=True, exist_ok=True)
149
+ Path(dir / 'train_images').rename(dir / 'images' / 'train')
150
+ Path(dir / 'val_images').rename(dir / 'images' / 'val')
151
+
152
+ # Split
153
+ autosplit(dir / 'images' / 'train')
ultralytics/hub/__init__.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import requests
4
+
5
+ from ultralytics.hub.auth import Auth
6
+ from ultralytics.hub.utils import PREFIX
7
+ from ultralytics.yolo.data.utils import HUBDatasetStats
8
+ from ultralytics.yolo.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
9
+
10
+
11
+ def login(api_key=''):
12
+ """
13
+ Log in to the Ultralytics HUB API using the provided API key.
14
+
15
+ Args:
16
+ api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
17
+
18
+ Example:
19
+ from ultralytics import hub
20
+ hub.login('API_KEY')
21
+ """
22
+ Auth(api_key, verbose=True)
23
+
24
+
25
+ def logout():
26
+ """
27
+ Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo hub login'.
28
+
29
+ Example:
30
+ from ultralytics import hub
31
+ hub.logout()
32
+ """
33
+ SETTINGS['api_key'] = ''
34
+ yaml_save(USER_CONFIG_DIR / 'settings.yaml', SETTINGS)
35
+ LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
36
+
37
+
38
+ def start(key=''):
39
+ """
40
+ Start training models with Ultralytics HUB (DEPRECATED).
41
+
42
+ Args:
43
+ key (str, optional): A string containing either the API key and model ID combination (apikey_modelid),
44
+ or the full model URL (https://hub.ultralytics.com/models/apikey_modelid).
45
+ """
46
+ api_key, model_id = key.split('_')
47
+ LOGGER.warning(f"""
48
+ WARNING ⚠️ ultralytics.start() is deprecated after 8.0.60. Updated usage to train Ultralytics HUB models is:
49
+
50
+ from ultralytics import YOLO, hub
51
+
52
+ hub.login('{api_key}')
53
+ model = YOLO('https://hub.ultralytics.com/models/{model_id}')
54
+ model.train()""")
55
+
56
+
57
+ def reset_model(model_id=''):
58
+ """Reset a trained model to an untrained state."""
59
+ r = requests.post('https://api.ultralytics.com/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
60
+ if r.status_code == 200:
61
+ LOGGER.info(f'{PREFIX}Model reset successfully')
62
+ return
63
+ LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}')
64
+
65
+
66
+ def export_fmts_hub():
67
+ """Returns a list of HUB-supported export formats."""
68
+ from ultralytics.yolo.engine.exporter import export_formats
69
+ return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml']
70
+
71
+
72
+ def export_model(model_id='', format='torchscript'):
73
+ """Export a model to all formats."""
74
+ assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
75
+ r = requests.post(f'https://api.ultralytics.com/v1/models/{model_id}/export',
76
+ json={'format': format},
77
+ headers={'x-api-key': Auth().api_key})
78
+ assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
79
+ LOGGER.info(f'{PREFIX}{format} export started ✅')
80
+
81
+
82
+ def get_export(model_id='', format='torchscript'):
83
+ """Get an exported model dictionary with download URL."""
84
+ assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
85
+ r = requests.post('https://api.ultralytics.com/get-export',
86
+ json={
87
+ 'apiKey': Auth().api_key,
88
+ 'modelId': model_id,
89
+ 'format': format})
90
+ assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
91
+ return r.json()
92
+
93
+
94
+ def check_dataset(path='', task='detect'):
95
+ """
96
+ Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is
97
+ uploaded to the HUB. Usage examples are given below.
98
+
99
+ Args:
100
+ path (str, optional): Path to data.zip (with data.yaml inside data.zip). Defaults to ''.
101
+ task (str, optional): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Defaults to 'detect'.
102
+
103
+ Example:
104
+ ```python
105
+ from ultralytics.hub import check_dataset
106
+
107
+ check_dataset('path/to/coco8.zip', task='detect') # detect dataset
108
+ check_dataset('path/to/coco8-seg.zip', task='segment') # segment dataset
109
+ check_dataset('path/to/coco8-pose.zip', task='pose') # pose dataset
110
+ ```
111
+ """
112
+ HUBDatasetStats(path=path, task=task).get_json()
113
+ LOGGER.info('Checks completed correctly ✅. Upload this dataset to https://hub.ultralytics.com/datasets/.')
114
+
115
+
116
+ if __name__ == '__main__':
117
+ start()
ultralytics/hub/auth.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import requests
4
+
5
+ from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, request_with_credentials
6
+ from ultralytics.yolo.utils import LOGGER, SETTINGS, emojis, is_colab, set_settings
7
+
8
+ API_KEY_URL = 'https://hub.ultralytics.com/settings?tab=api+keys'
9
+
10
+
11
+ class Auth:
12
+ id_token = api_key = model_key = False
13
+
14
+ def __init__(self, api_key='', verbose=False):
15
+ """
16
+ Initialize the Auth class with an optional API key.
17
+
18
+ Args:
19
+ api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
20
+ """
21
+ # Split the input API key in case it contains a combined key_model and keep only the API key part
22
+ api_key = api_key.split('_')[0]
23
+
24
+ # Set API key attribute as value passed or SETTINGS API key if none passed
25
+ self.api_key = api_key or SETTINGS.get('api_key', '')
26
+
27
+ # If an API key is provided
28
+ if self.api_key:
29
+ # If the provided API key matches the API key in the SETTINGS
30
+ if self.api_key == SETTINGS.get('api_key'):
31
+ # Log that the user is already logged in
32
+ if verbose:
33
+ LOGGER.info(f'{PREFIX}Authenticated ✅')
34
+ return
35
+ else:
36
+ # Attempt to authenticate with the provided API key
37
+ success = self.authenticate()
38
+ # If the API key is not provided and the environment is a Google Colab notebook
39
+ elif is_colab():
40
+ # Attempt to authenticate using browser cookies
41
+ success = self.auth_with_cookies()
42
+ else:
43
+ # Request an API key
44
+ success = self.request_api_key()
45
+
46
+ # Update SETTINGS with the new API key after successful authentication
47
+ if success:
48
+ set_settings({'api_key': self.api_key})
49
+ # Log that the new login was successful
50
+ if verbose:
51
+ LOGGER.info(f'{PREFIX}New authentication successful ✅')
52
+ elif verbose:
53
+ LOGGER.info(f'{PREFIX}Retrieve API key from {API_KEY_URL}')
54
+
55
+ def request_api_key(self, max_attempts=3):
56
+ """
57
+ Prompt the user to input their API key. Returns the model ID.
58
+ """
59
+ import getpass
60
+ for attempts in range(max_attempts):
61
+ LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
62
+ input_key = getpass.getpass(f'Enter API key from {API_KEY_URL} ')
63
+ self.api_key = input_key.split('_')[0] # remove model id if present
64
+ if self.authenticate():
65
+ return True
66
+ raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌'))
67
+
68
+ def authenticate(self) -> bool:
69
+ """
70
+ Attempt to authenticate with the server using either id_token or API key.
71
+
72
+ Returns:
73
+ bool: True if authentication is successful, False otherwise.
74
+ """
75
+ try:
76
+ header = self.get_auth_header()
77
+ if header:
78
+ r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
79
+ if not r.json().get('success', False):
80
+ raise ConnectionError('Unable to authenticate.')
81
+ return True
82
+ raise ConnectionError('User has not authenticated locally.')
83
+ except ConnectionError:
84
+ self.id_token = self.api_key = False # reset invalid
85
+ LOGGER.warning(f'{PREFIX}Invalid API key ⚠️')
86
+ return False
87
+
88
+ def auth_with_cookies(self) -> bool:
89
+ """
90
+ Attempt to fetch authentication via cookies and set id_token.
91
+ User must be logged in to HUB and running in a supported browser.
92
+
93
+ Returns:
94
+ bool: True if authentication is successful, False otherwise.
95
+ """
96
+ if not is_colab():
97
+ return False # Currently only works with Colab
98
+ try:
99
+ authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto')
100
+ if authn.get('success', False):
101
+ self.id_token = authn.get('data', {}).get('idToken', None)
102
+ self.authenticate()
103
+ return True
104
+ raise ConnectionError('Unable to fetch browser authentication details.')
105
+ except ConnectionError:
106
+ self.id_token = False # reset invalid
107
+ return False
108
+
109
+ def get_auth_header(self):
110
+ """
111
+ Get the authentication header for making API requests.
112
+
113
+ Returns:
114
+ (dict): The authentication header if id_token or API key is set, None otherwise.
115
+ """
116
+ if self.id_token:
117
+ return {'authorization': f'Bearer {self.id_token}'}
118
+ elif self.api_key:
119
+ return {'x-api-key': self.api_key}
120
+ else:
121
+ return None
122
+
123
+ def get_state(self) -> bool:
124
+ """
125
+ Get the authentication state.
126
+
127
+ Returns:
128
+ bool: True if either id_token or API key is set, False otherwise.
129
+ """
130
+ return self.id_token or self.api_key
131
+
132
+ def set_api_key(self, key: str):
133
+ """
134
+ Set the API key for authentication.
135
+
136
+ Args:
137
+ key (str): The API key string.
138
+ """
139
+ self.api_key = key
ultralytics/hub/session.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ import signal
3
+ import sys
4
+ from pathlib import Path
5
+ from time import sleep
6
+
7
+ import requests
8
+
9
+ from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, smart_request
10
+ from ultralytics.yolo.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
11
+ from ultralytics.yolo.utils.errors import HUBModelError
12
+
13
+ AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
14
+
15
+
16
+ class HUBTrainingSession:
17
+ """
18
+ HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
19
+
20
+ Args:
21
+ url (str): Model identifier used to initialize the HUB training session.
22
+
23
+ Attributes:
24
+ agent_id (str): Identifier for the instance communicating with the server.
25
+ model_id (str): Identifier for the YOLOv5 model being trained.
26
+ model_url (str): URL for the model in Ultralytics HUB.
27
+ api_url (str): API URL for the model in Ultralytics HUB.
28
+ auth_header (Dict): Authentication header for the Ultralytics HUB API requests.
29
+ rate_limits (Dict): Rate limits for different API calls (in seconds).
30
+ timers (Dict): Timers for rate limiting.
31
+ metrics_queue (Dict): Queue for the model's metrics.
32
+ model (Dict): Model data fetched from Ultralytics HUB.
33
+ alive (bool): Indicates if the heartbeat loop is active.
34
+ """
35
+
36
+ def __init__(self, url):
37
+ """
38
+ Initialize the HUBTrainingSession with the provided model identifier.
39
+
40
+ Args:
41
+ url (str): Model identifier used to initialize the HUB training session.
42
+ It can be a URL string or a model key with specific format.
43
+
44
+ Raises:
45
+ ValueError: If the provided model identifier is invalid.
46
+ ConnectionError: If connecting with global API key is not supported.
47
+ """
48
+
49
+ from ultralytics.hub.auth import Auth
50
+
51
+ # Parse input
52
+ if url.startswith('https://hub.ultralytics.com/models/'):
53
+ url = url.split('https://hub.ultralytics.com/models/')[-1]
54
+ if [len(x) for x in url.split('_')] == [42, 20]:
55
+ key, model_id = url.split('_')
56
+ elif len(url) == 20:
57
+ key, model_id = '', url
58
+ else:
59
+ raise HUBModelError(f"model='{url}' not found. Check format is correct, i.e. "
60
+ f"model='https://hub.ultralytics.com/models/MODEL_ID' and try again.")
61
+
62
+ # Authorize
63
+ auth = Auth(key)
64
+ self.agent_id = None # identifies which instance is communicating with server
65
+ self.model_id = model_id
66
+ self.model_url = f'https://hub.ultralytics.com/models/{model_id}'
67
+ self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
68
+ self.auth_header = auth.get_auth_header()
69
+ self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
70
+ self.timers = {} # rate limit timers (seconds)
71
+ self.metrics_queue = {} # metrics queue
72
+ self.model = self._get_model()
73
+ self.alive = True
74
+ self._start_heartbeat() # start heartbeats
75
+ self._register_signal_handlers()
76
+ LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
77
+
78
+ def _register_signal_handlers(self):
79
+ """Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination."""
80
+ signal.signal(signal.SIGTERM, self._handle_signal)
81
+ signal.signal(signal.SIGINT, self._handle_signal)
82
+
83
+ def _handle_signal(self, signum, frame):
84
+ """
85
+ Handle kill signals and prevent heartbeats from being sent on Colab after termination.
86
+ This method does not use frame, it is included as it is passed by signal.
87
+ """
88
+ if self.alive is True:
89
+ LOGGER.info(f'{PREFIX}Kill signal received! ❌')
90
+ self._stop_heartbeat()
91
+ sys.exit(signum)
92
+
93
+ def _stop_heartbeat(self):
94
+ """Terminate the heartbeat loop."""
95
+ self.alive = False
96
+
97
+ def upload_metrics(self):
98
+ """Upload model metrics to Ultralytics HUB."""
99
+ payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
100
+ smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)
101
+
102
+ def _get_model(self):
103
+ """Fetch and return model data from Ultralytics HUB."""
104
+ api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
105
+
106
+ try:
107
+ response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0)
108
+ data = response.json().get('data', None)
109
+
110
+ if data.get('status', None) == 'trained':
111
+ raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
112
+
113
+ if not data.get('data', None):
114
+ raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
115
+ self.model_id = data['id']
116
+
117
+ if data['status'] == 'new': # new model to start training
118
+ self.train_args = {
119
+ # TODO: deprecate 'batch_size' key for 'batch' in 3Q23
120
+ 'batch': data['batch' if ('batch' in data) else 'batch_size'],
121
+ 'epochs': data['epochs'],
122
+ 'imgsz': data['imgsz'],
123
+ 'patience': data['patience'],
124
+ 'device': data['device'],
125
+ 'cache': data['cache'],
126
+ 'data': data['data']}
127
+ self.model_file = data.get('cfg') or data.get('weights') # cfg for pretrained=False
128
+ self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
129
+ elif data['status'] == 'training': # existing model to resume training
130
+ self.train_args = {'data': data['data'], 'resume': True}
131
+ self.model_file = data['resume']
132
+
133
+ return data
134
+ except requests.exceptions.ConnectionError as e:
135
+ raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
136
+ except Exception:
137
+ raise
138
+
139
+ def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
140
+ """
141
+ Upload a model checkpoint to Ultralytics HUB.
142
+
143
+ Args:
144
+ epoch (int): The current training epoch.
145
+ weights (str): Path to the model weights file.
146
+ is_best (bool): Indicates if the current model is the best one so far.
147
+ map (float): Mean average precision of the model.
148
+ final (bool): Indicates if the model is the final model after training.
149
+ """
150
+ if Path(weights).is_file():
151
+ with open(weights, 'rb') as f:
152
+ file = f.read()
153
+ else:
154
+ LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
155
+ file = None
156
+ url = f'{self.api_url}/upload'
157
+ # url = 'http://httpbin.org/post' # for debug
158
+ data = {'epoch': epoch}
159
+ if final:
160
+ data.update({'type': 'final', 'map': map})
161
+ smart_request('post',
162
+ url,
163
+ data=data,
164
+ files={'best.pt': file},
165
+ headers=self.auth_header,
166
+ retry=10,
167
+ timeout=3600,
168
+ thread=False,
169
+ progress=True,
170
+ code=4)
171
+ else:
172
+ data.update({'type': 'epoch', 'isBest': bool(is_best)})
173
+ smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3)
174
+
175
+ @threaded
176
+ def _start_heartbeat(self):
177
+ """Begin a threaded heartbeat loop to report the agent's status to Ultralytics HUB."""
178
+ while self.alive:
179
+ r = smart_request('post',
180
+ f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
181
+ json={
182
+ 'agent': AGENT_NAME,
183
+ 'agentId': self.agent_id},
184
+ headers=self.auth_header,
185
+ retry=0,
186
+ code=5,
187
+ thread=False) # already in a thread
188
+ self.agent_id = r.json().get('data', {}).get('agentId', None)
189
+ sleep(self.rate_limits['heartbeat'])
ultralytics/hub/utils.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import os
4
+ import platform
5
+ import random
6
+ import sys
7
+ import threading
8
+ import time
9
+ from pathlib import Path
10
+
11
+ import requests
12
+ from tqdm import tqdm
13
+
14
+ from ultralytics.yolo.utils import (ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING, TQDM_BAR_FORMAT,
15
+ TryExcept, __version__, colorstr, get_git_origin_url, is_colab, is_git_dir,
16
+ is_pip_package)
17
+
18
+ PREFIX = colorstr('Ultralytics HUB: ')
19
+ HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
20
+ HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com')
21
+
22
+
23
+ def request_with_credentials(url: str) -> any:
24
+ """
25
+ Make an AJAX request with cookies attached in a Google Colab environment.
26
+
27
+ Args:
28
+ url (str): The URL to make the request to.
29
+
30
+ Returns:
31
+ (any): The response data from the AJAX request.
32
+
33
+ Raises:
34
+ OSError: If the function is not run in a Google Colab environment.
35
+ """
36
+ if not is_colab():
37
+ raise OSError('request_with_credentials() must run in a Colab environment')
38
+ from google.colab import output # noqa
39
+ from IPython import display # noqa
40
+ display.display(
41
+ display.Javascript("""
42
+ window._hub_tmp = new Promise((resolve, reject) => {
43
+ const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000)
44
+ fetch("%s", {
45
+ method: 'POST',
46
+ credentials: 'include'
47
+ })
48
+ .then((response) => resolve(response.json()))
49
+ .then((json) => {
50
+ clearTimeout(timeout);
51
+ }).catch((err) => {
52
+ clearTimeout(timeout);
53
+ reject(err);
54
+ });
55
+ });
56
+ """ % url))
57
+ return output.eval_js('_hub_tmp')
58
+
59
+
60
+ def requests_with_progress(method, url, **kwargs):
61
+ """
62
+ Make an HTTP request using the specified method and URL, with an optional progress bar.
63
+
64
+ Args:
65
+ method (str): The HTTP method to use (e.g. 'GET', 'POST').
66
+ url (str): The URL to send the request to.
67
+ **kwargs (dict): Additional keyword arguments to pass to the underlying `requests.request` function.
68
+
69
+ Returns:
70
+ (requests.Response): The response object from the HTTP request.
71
+
72
+ Note:
73
+ If 'progress' is set to True, the progress bar will display the download progress
74
+ for responses with a known content length.
75
+ """
76
+ progress = kwargs.pop('progress', False)
77
+ if not progress:
78
+ return requests.request(method, url, **kwargs)
79
+ response = requests.request(method, url, stream=True, **kwargs)
80
+ total = int(response.headers.get('content-length', 0)) # total size
81
+ pbar = tqdm(total=total, unit='B', unit_scale=True, unit_divisor=1024, bar_format=TQDM_BAR_FORMAT)
82
+ for data in response.iter_content(chunk_size=1024):
83
+ pbar.update(len(data))
84
+ pbar.close()
85
+ return response
86
+
87
+
88
+ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, progress=False, **kwargs):
89
+ """
90
+ Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
91
+
92
+ Args:
93
+ method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.
94
+ url (str): The URL to make the request to.
95
+ retry (int, optional): Number of retries to attempt before giving up. Default is 3.
96
+ timeout (int, optional): Timeout in seconds after which the function will give up retrying. Default is 30.
97
+ thread (bool, optional): Whether to execute the request in a separate daemon thread. Default is True.
98
+ code (int, optional): An identifier for the request, used for logging purposes. Default is -1.
99
+ verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True.
100
+ progress (bool, optional): Whether to show a progress bar during the request. Default is False.
101
+ **kwargs (dict): Keyword arguments to be passed to the requests function specified in method.
102
+
103
+ Returns:
104
+ (requests.Response): The HTTP response object. If the request is executed in a separate thread, returns None.
105
+ """
106
+ retry_codes = (408, 500) # retry only these codes
107
+
108
+ @TryExcept(verbose=verbose)
109
+ def func(func_method, func_url, **func_kwargs):
110
+ """Make HTTP requests with retries and timeouts, with optional progress tracking."""
111
+ r = None # response
112
+ t0 = time.time() # initial time for timer
113
+ for i in range(retry + 1):
114
+ if (time.time() - t0) > timeout:
115
+ break
116
+ r = requests_with_progress(func_method, func_url, **func_kwargs) # i.e. get(url, data, json, files)
117
+ if r.status_code < 300: # return codes in the 2xx range are generally considered "good" or "successful"
118
+ break
119
+ try:
120
+ m = r.json().get('message', 'No JSON message.')
121
+ except AttributeError:
122
+ m = 'Unable to read JSON.'
123
+ if i == 0:
124
+ if r.status_code in retry_codes:
125
+ m += f' Retrying {retry}x for {timeout}s.' if retry else ''
126
+ elif r.status_code == 429: # rate limit
127
+ h = r.headers # response headers
128
+ m = f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " \
129
+ f"Please retry after {h['Retry-After']}s."
130
+ if verbose:
131
+ LOGGER.warning(f'{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})')
132
+ if r.status_code not in retry_codes:
133
+ return r
134
+ time.sleep(2 ** i) # exponential standoff
135
+ return r
136
+
137
+ args = method, url
138
+ kwargs['progress'] = progress
139
+ if thread:
140
+ threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
141
+ else:
142
+ return func(*args, **kwargs)
143
+
144
+
145
+ class Events:
146
+ """
147
+ A class for collecting anonymous event analytics. Event analytics are enabled when sync=True in settings and
148
+ disabled when sync=False. Run 'yolo settings' to see and update settings YAML file.
149
+
150
+ Attributes:
151
+ url (str): The URL to send anonymous events.
152
+ rate_limit (float): The rate limit in seconds for sending events.
153
+ metadata (dict): A dictionary containing metadata about the environment.
154
+ enabled (bool): A flag to enable or disable Events based on certain conditions.
155
+ """
156
+
157
+ url = 'https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw'
158
+
159
+ def __init__(self):
160
+ """
161
+ Initializes the Events object with default values for events, rate_limit, and metadata.
162
+ """
163
+ self.events = [] # events list
164
+ self.rate_limit = 60.0 # rate limit (seconds)
165
+ self.t = 0.0 # rate limit timer (seconds)
166
+ self.metadata = {
167
+ 'cli': Path(sys.argv[0]).name == 'yolo',
168
+ 'install': 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
169
+ 'python': '.'.join(platform.python_version_tuple()[:2]), # i.e. 3.10
170
+ 'version': __version__,
171
+ 'env': ENVIRONMENT,
172
+ 'session_id': round(random.random() * 1E15),
173
+ 'engagement_time_msec': 1000}
174
+ self.enabled = \
175
+ SETTINGS['sync'] and \
176
+ RANK in (-1, 0) and \
177
+ not TESTS_RUNNING and \
178
+ ONLINE and \
179
+ (is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
180
+
181
+ def __call__(self, cfg):
182
+ """
183
+ Attempts to add a new event to the events list and send events if the rate limit is reached.
184
+
185
+ Args:
186
+ cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
187
+ """
188
+ if not self.enabled:
189
+ # Events disabled, do nothing
190
+ return
191
+
192
+ # Attempt to add to events
193
+ if len(self.events) < 25: # Events list limited to 25 events (drop any events past this)
194
+ params = {**self.metadata, **{'task': cfg.task}}
195
+ if cfg.mode == 'export':
196
+ params['format'] = cfg.format
197
+ self.events.append({'name': cfg.mode, 'params': params})
198
+
199
+ # Check rate limit
200
+ t = time.time()
201
+ if (t - self.t) < self.rate_limit:
202
+ # Time is under rate limiter, wait to send
203
+ return
204
+
205
+ # Time is over rate limiter, send now
206
+ data = {'client_id': SETTINGS['uuid'], 'events': self.events} # SHA-256 anonymized UUID hash and events list
207
+
208
+ # POST equivalent to requests.post(self.url, json=data)
209
+ smart_request('post', self.url, json=data, retry=0, verbose=False)
210
+
211
+ # Reset events and rate limit timer
212
+ self.events = []
213
+ self.t = t
214
+
215
+
216
+ # Run below code on hub/utils init -------------------------------------------------------------------------------------
217
+ events = Events()
ultralytics/models/README.md ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Models
2
+
3
+ Welcome to the Ultralytics Models directory! Here you will find a wide variety of pre-configured model configuration
4
+ files (`*.yaml`s) that can be used to create custom YOLO models. The models in this directory have been expertly crafted
5
+ and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image
6
+ segmentation tasks.
7
+
8
+ These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like
9
+ instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms,
10
+ from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this
11
+ directory provides a great starting point for your custom model development needs.
12
+
13
+ To get started, simply browse through the models in this directory and find one that best suits your needs. Once you've
14
+ selected a model, you can use the provided `*.yaml` file to train and deploy your custom YOLO model with ease. See full
15
+ details at the Ultralytics [Docs](https://docs.ultralytics.com/models), and if you need help or have any questions, feel free
16
+ to reach out to the Ultralytics team for support. So, don't wait, start creating your custom YOLO model now!
17
+
18
+ ### Usage
19
+
20
+ Model `*.yaml` files may be used directly in the Command Line Interface (CLI) with a `yolo` command:
21
+
22
+ ```bash
23
+ yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100
24
+ ```
25
+
26
+ They may also be used directly in a Python environment, and accepts the same
27
+ [arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above:
28
+
29
+ ```python
30
+ from ultralytics import YOLO
31
+
32
+ model = YOLO("model.yaml") # build a YOLOv8n model from scratch
33
+ # YOLO("model.pt") use pre-trained model if available
34
+ model.info() # display model information
35
+ model.train(data="coco128.yaml", epochs=100) # train the model
36
+ ```
37
+
38
+ ## Pre-trained Model Architectures
39
+
40
+ Ultralytics supports many model architectures. Visit https://docs.ultralytics.com/models to view detailed information
41
+ and usage. Any of these models can be used by loading their configs or pretrained checkpoints if available.
42
+
43
+ ## Contributing New Models
44
+
45
+ If you've developed a new model architecture or have improvements for existing models that you'd like to contribute to the Ultralytics community, please submit your contribution in a new Pull Request. For more details, visit our [Contributing Guide](https://docs.ultralytics.com/help/contributing).
ultralytics/models/rt-detr/rtdetr-l.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
3
+
4
+ # Parameters
5
+ nc: 80 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ l: [1.00, 1.00, 1024]
9
+
10
+ backbone:
11
+ # [from, repeats, module, args]
12
+ - [-1, 1, HGStem, [32, 48]] # 0-P2/4
13
+ - [-1, 6, HGBlock, [48, 128, 3]] # stage 1
14
+
15
+ - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
16
+ - [-1, 6, HGBlock, [96, 512, 3]] # stage 2
17
+
18
+ - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16
19
+ - [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut
20
+ - [-1, 6, HGBlock, [192, 1024, 5, True, True]]
21
+ - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3
22
+
23
+ - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32
24
+ - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4
25
+
26
+ head:
27
+ - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2
28
+ - [-1, 1, AIFI, [1024, 8]]
29
+ - [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0
30
+
31
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
32
+ - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1
33
+ - [[-2, -1], 1, Concat, [1]]
34
+ - [-1, 3, RepC3, [256]] # 16, fpn_blocks.0
35
+ - [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1
36
+
37
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
38
+ - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0
39
+ - [[-2, -1], 1, Concat, [1]] # cat backbone P4
40
+ - [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1
41
+
42
+ - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0
43
+ - [[-1, 17], 1, Concat, [1]] # cat Y4
44
+ - [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0
45
+
46
+ - [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1
47
+ - [[-1, 12], 1, Concat, [1]] # cat Y5
48
+ - [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1
49
+
50
+ - [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
ultralytics/models/rt-detr/rtdetr-x.yaml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # RT-DETR-x object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
3
+
4
+ # Parameters
5
+ nc: 80 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ x: [1.00, 1.00, 2048]
9
+
10
+ backbone:
11
+ # [from, repeats, module, args]
12
+ - [-1, 1, HGStem, [32, 64]] # 0-P2/4
13
+ - [-1, 6, HGBlock, [64, 128, 3]] # stage 1
14
+
15
+ - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
16
+ - [-1, 6, HGBlock, [128, 512, 3]]
17
+ - [-1, 6, HGBlock, [128, 512, 3, False, True]] # 4-stage 2
18
+
19
+ - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 5-P3/16
20
+ - [-1, 6, HGBlock, [256, 1024, 5, True, False]] # cm, c2, k, light, shortcut
21
+ - [-1, 6, HGBlock, [256, 1024, 5, True, True]]
22
+ - [-1, 6, HGBlock, [256, 1024, 5, True, True]]
23
+ - [-1, 6, HGBlock, [256, 1024, 5, True, True]]
24
+ - [-1, 6, HGBlock, [256, 1024, 5, True, True]] # 10-stage 3
25
+
26
+ - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 11-P4/32
27
+ - [-1, 6, HGBlock, [512, 2048, 5, True, False]]
28
+ - [-1, 6, HGBlock, [512, 2048, 5, True, True]] # 13-stage 4
29
+
30
+ head:
31
+ - [-1, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 14 input_proj.2
32
+ - [-1, 1, AIFI, [2048, 8]]
33
+ - [-1, 1, Conv, [384, 1, 1]] # 16, Y5, lateral_convs.0
34
+
35
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
36
+ - [10, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 18 input_proj.1
37
+ - [[-2, -1], 1, Concat, [1]]
38
+ - [-1, 3, RepC3, [384]] # 20, fpn_blocks.0
39
+ - [-1, 1, Conv, [384, 1, 1]] # 21, Y4, lateral_convs.1
40
+
41
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
42
+ - [4, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 23 input_proj.0
43
+ - [[-2, -1], 1, Concat, [1]] # cat backbone P4
44
+ - [-1, 3, RepC3, [384]] # X3 (25), fpn_blocks.1
45
+
46
+ - [-1, 1, Conv, [384, 3, 2]] # 26, downsample_convs.0
47
+ - [[-1, 21], 1, Concat, [1]] # cat Y4
48
+ - [-1, 3, RepC3, [384]] # F4 (28), pan_blocks.0
49
+
50
+ - [-1, 1, Conv, [384, 3, 2]] # 29, downsample_convs.1
51
+ - [[-1, 16], 1, Concat, [1]] # cat Y5
52
+ - [-1, 3, RepC3, [384]] # F5 (31), pan_blocks.1
53
+
54
+ - [[25, 28, 31], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
ultralytics/models/v3/yolov3-spp.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv3-SPP object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
3
+
4
+ # Parameters
5
+ nc: 80 # number of classes
6
+ depth_multiple: 1.0 # model depth multiple
7
+ width_multiple: 1.0 # layer channel multiple
8
+
9
+ # darknet53 backbone
10
+ backbone:
11
+ # [from, number, module, args]
12
+ [[-1, 1, Conv, [32, 3, 1]], # 0
13
+ [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
14
+ [-1, 1, Bottleneck, [64]],
15
+ [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
16
+ [-1, 2, Bottleneck, [128]],
17
+ [-1, 1, Conv, [256, 3, 2]], # 5-P3/8
18
+ [-1, 8, Bottleneck, [256]],
19
+ [-1, 1, Conv, [512, 3, 2]], # 7-P4/16
20
+ [-1, 8, Bottleneck, [512]],
21
+ [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
22
+ [-1, 4, Bottleneck, [1024]], # 10
23
+ ]
24
+
25
+ # YOLOv3-SPP head
26
+ head:
27
+ [[-1, 1, Bottleneck, [1024, False]],
28
+ [-1, 1, SPP, [512, [5, 9, 13]]],
29
+ [-1, 1, Conv, [1024, 3, 1]],
30
+ [-1, 1, Conv, [512, 1, 1]],
31
+ [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)
32
+
33
+ [-2, 1, Conv, [256, 1, 1]],
34
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
35
+ [[-1, 8], 1, Concat, [1]], # cat backbone P4
36
+ [-1, 1, Bottleneck, [512, False]],
37
+ [-1, 1, Bottleneck, [512, False]],
38
+ [-1, 1, Conv, [256, 1, 1]],
39
+ [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)
40
+
41
+ [-2, 1, Conv, [128, 1, 1]],
42
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
43
+ [[-1, 6], 1, Concat, [1]], # cat backbone P3
44
+ [-1, 1, Bottleneck, [256, False]],
45
+ [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)
46
+
47
+ [[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5)
48
+ ]
ultralytics/models/v3/yolov3-tiny.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv3-tiny object detection model with P4-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
3
+
4
+ # Parameters
5
+ nc: 80 # number of classes
6
+ depth_multiple: 1.0 # model depth multiple
7
+ width_multiple: 1.0 # layer channel multiple
8
+
9
+ # YOLOv3-tiny backbone
10
+ backbone:
11
+ # [from, number, module, args]
12
+ [[-1, 1, Conv, [16, 3, 1]], # 0
13
+ [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 1-P1/2
14
+ [-1, 1, Conv, [32, 3, 1]],
15
+ [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 3-P2/4
16
+ [-1, 1, Conv, [64, 3, 1]],
17
+ [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 5-P3/8
18
+ [-1, 1, Conv, [128, 3, 1]],
19
+ [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 7-P4/16
20
+ [-1, 1, Conv, [256, 3, 1]],
21
+ [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 9-P5/32
22
+ [-1, 1, Conv, [512, 3, 1]],
23
+ [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]], # 11
24
+ [-1, 1, nn.MaxPool2d, [2, 1, 0]], # 12
25
+ ]
26
+
27
+ # YOLOv3-tiny head
28
+ head:
29
+ [[-1, 1, Conv, [1024, 3, 1]],
30
+ [-1, 1, Conv, [256, 1, 1]],
31
+ [-1, 1, Conv, [512, 3, 1]], # 15 (P5/32-large)
32
+
33
+ [-2, 1, Conv, [128, 1, 1]],
34
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
35
+ [[-1, 8], 1, Concat, [1]], # cat backbone P4
36
+ [-1, 1, Conv, [256, 3, 1]], # 19 (P4/16-medium)
37
+
38
+ [[19, 15], 1, Detect, [nc]], # Detect(P4, P5)
39
+ ]
ultralytics/models/v3/yolov3.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv3 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
3
+
4
+ # Parameters
5
+ nc: 80 # number of classes
6
+ depth_multiple: 1.0 # model depth multiple
7
+ width_multiple: 1.0 # layer channel multiple
8
+
9
+ # darknet53 backbone
10
+ backbone:
11
+ # [from, number, module, args]
12
+ [[-1, 1, Conv, [32, 3, 1]], # 0
13
+ [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
14
+ [-1, 1, Bottleneck, [64]],
15
+ [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
16
+ [-1, 2, Bottleneck, [128]],
17
+ [-1, 1, Conv, [256, 3, 2]], # 5-P3/8
18
+ [-1, 8, Bottleneck, [256]],
19
+ [-1, 1, Conv, [512, 3, 2]], # 7-P4/16
20
+ [-1, 8, Bottleneck, [512]],
21
+ [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
22
+ [-1, 4, Bottleneck, [1024]], # 10
23
+ ]
24
+
25
+ # YOLOv3 head
26
+ head:
27
+ [[-1, 1, Bottleneck, [1024, False]],
28
+ [-1, 1, Conv, [512, 1, 1]],
29
+ [-1, 1, Conv, [1024, 3, 1]],
30
+ [-1, 1, Conv, [512, 1, 1]],
31
+ [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)
32
+
33
+ [-2, 1, Conv, [256, 1, 1]],
34
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
35
+ [[-1, 8], 1, Concat, [1]], # cat backbone P4
36
+ [-1, 1, Bottleneck, [512, False]],
37
+ [-1, 1, Bottleneck, [512, False]],
38
+ [-1, 1, Conv, [256, 1, 1]],
39
+ [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)
40
+
41
+ [-2, 1, Conv, [128, 1, 1]],
42
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
43
+ [[-1, 6], 1, Concat, [1]], # cat backbone P3
44
+ [-1, 1, Bottleneck, [256, False]],
45
+ [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)
46
+
47
+ [[27, 22, 15], 1, Detect, [nc]], # Detect(P3, P4, P5)
48
+ ]
ultralytics/models/v5/yolov5-p6.yaml ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv5 object detection model with P3-P6 outputs. For details see https://docs.ultralytics.com/models/yolov5
3
+
4
+ # Parameters
5
+ nc: 80 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov5n-p6.yaml' will call yolov5-p6.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024]
9
+ s: [0.33, 0.50, 1024]
10
+ m: [0.67, 0.75, 1024]
11
+ l: [1.00, 1.00, 1024]
12
+ x: [1.33, 1.25, 1024]
13
+
14
+ # YOLOv5 v6.0 backbone
15
+ backbone:
16
+ # [from, number, module, args]
17
+ [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
18
+ [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
19
+ [-1, 3, C3, [128]],
20
+ [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
21
+ [-1, 6, C3, [256]],
22
+ [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
23
+ [-1, 9, C3, [512]],
24
+ [-1, 1, Conv, [768, 3, 2]], # 7-P5/32
25
+ [-1, 3, C3, [768]],
26
+ [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64
27
+ [-1, 3, C3, [1024]],
28
+ [-1, 1, SPPF, [1024, 5]], # 11
29
+ ]
30
+
31
+ # YOLOv5 v6.0 head
32
+ head:
33
+ [[-1, 1, Conv, [768, 1, 1]],
34
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
35
+ [[-1, 8], 1, Concat, [1]], # cat backbone P5
36
+ [-1, 3, C3, [768, False]], # 15
37
+
38
+ [-1, 1, Conv, [512, 1, 1]],
39
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
40
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
41
+ [-1, 3, C3, [512, False]], # 19
42
+
43
+ [-1, 1, Conv, [256, 1, 1]],
44
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
45
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
46
+ [-1, 3, C3, [256, False]], # 23 (P3/8-small)
47
+
48
+ [-1, 1, Conv, [256, 3, 2]],
49
+ [[-1, 20], 1, Concat, [1]], # cat head P4
50
+ [-1, 3, C3, [512, False]], # 26 (P4/16-medium)
51
+
52
+ [-1, 1, Conv, [512, 3, 2]],
53
+ [[-1, 16], 1, Concat, [1]], # cat head P5
54
+ [-1, 3, C3, [768, False]], # 29 (P5/32-large)
55
+
56
+ [-1, 1, Conv, [768, 3, 2]],
57
+ [[-1, 12], 1, Concat, [1]], # cat head P6
58
+ [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge)
59
+
60
+ [[23, 26, 29, 32], 1, Detect, [nc]], # Detect(P3, P4, P5, P6)
61
+ ]
ultralytics/models/v5/yolov5.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv5 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov5
3
+
4
+ # Parameters
5
+ nc: 80 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov5n.yaml' will call yolov5.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024]
9
+ s: [0.33, 0.50, 1024]
10
+ m: [0.67, 0.75, 1024]
11
+ l: [1.00, 1.00, 1024]
12
+ x: [1.33, 1.25, 1024]
13
+
14
+ # YOLOv5 v6.0 backbone
15
+ backbone:
16
+ # [from, number, module, args]
17
+ [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
18
+ [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
19
+ [-1, 3, C3, [128]],
20
+ [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
21
+ [-1, 6, C3, [256]],
22
+ [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
23
+ [-1, 9, C3, [512]],
24
+ [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
25
+ [-1, 3, C3, [1024]],
26
+ [-1, 1, SPPF, [1024, 5]], # 9
27
+ ]
28
+
29
+ # YOLOv5 v6.0 head
30
+ head:
31
+ [[-1, 1, Conv, [512, 1, 1]],
32
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
33
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
34
+ [-1, 3, C3, [512, False]], # 13
35
+
36
+ [-1, 1, Conv, [256, 1, 1]],
37
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
38
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
39
+ [-1, 3, C3, [256, False]], # 17 (P3/8-small)
40
+
41
+ [-1, 1, Conv, [256, 3, 2]],
42
+ [[-1, 14], 1, Concat, [1]], # cat head P4
43
+ [-1, 3, C3, [512, False]], # 20 (P4/16-medium)
44
+
45
+ [-1, 1, Conv, [512, 3, 2]],
46
+ [[-1, 10], 1, Concat, [1]], # cat head P5
47
+ [-1, 3, C3, [1024, False]], # 23 (P5/32-large)
48
+
49
+ [[17, 20, 23], 1, Detect, [nc]], # Detect(P3, P4, P5)
50
+ ]
ultralytics/models/v6/yolov6.yaml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv6 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/models/yolov6
3
+
4
+ # Parameters
5
+ nc: 80 # number of classes
6
+ activation: nn.ReLU() # (optional) model default activation function
7
+ scales: # model compound scaling constants, i.e. 'model=yolov6n.yaml' will call yolov8.yaml with scale 'n'
8
+ # [depth, width, max_channels]
9
+ n: [0.33, 0.25, 1024]
10
+ s: [0.33, 0.50, 1024]
11
+ m: [0.67, 0.75, 768]
12
+ l: [1.00, 1.00, 512]
13
+ x: [1.00, 1.25, 512]
14
+
15
+ # YOLOv6-3.0s backbone
16
+ backbone:
17
+ # [from, repeats, module, args]
18
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
19
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
20
+ - [-1, 6, Conv, [128, 3, 1]]
21
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
22
+ - [-1, 12, Conv, [256, 3, 1]]
23
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
24
+ - [-1, 18, Conv, [512, 3, 1]]
25
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
26
+ - [-1, 6, Conv, [1024, 3, 1]]
27
+ - [-1, 1, SPPF, [1024, 5]] # 9
28
+
29
+ # YOLOv6-3.0s head
30
+ head:
31
+ - [-1, 1, Conv, [256, 1, 1]]
32
+ - [-1, 1, nn.ConvTranspose2d, [256, 2, 2, 0]]
33
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
34
+ - [-1, 1, Conv, [256, 3, 1]]
35
+ - [-1, 9, Conv, [256, 3, 1]] # 14
36
+
37
+ - [-1, 1, Conv, [128, 1, 1]]
38
+ - [-1, 1, nn.ConvTranspose2d, [128, 2, 2, 0]]
39
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
40
+ - [-1, 1, Conv, [128, 3, 1]]
41
+ - [-1, 9, Conv, [128, 3, 1]] # 19
42
+
43
+ - [-1, 1, Conv, [128, 3, 2]]
44
+ - [[-1, 15], 1, Concat, [1]] # cat head P4
45
+ - [-1, 1, Conv, [256, 3, 1]]
46
+ - [-1, 9, Conv, [256, 3, 1]] # 23
47
+
48
+ - [-1, 1, Conv, [256, 3, 2]]
49
+ - [[-1, 10], 1, Concat, [1]] # cat head P5
50
+ - [-1, 1, Conv, [512, 3, 1]]
51
+ - [-1, 9, Conv, [512, 3, 1]] # 27
52
+
53
+ - [[19, 23, 27], 1, Detect, [nc]] # Detect(P3, P4, P5)
ultralytics/models/v8/yolov8-cls.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8-cls image classification model. For Usage examples see https://docs.ultralytics.com/tasks/classify
3
+
4
+ # Parameters
5
+ nc: 1000 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024]
9
+ s: [0.33, 0.50, 1024]
10
+ m: [0.67, 0.75, 1024]
11
+ l: [1.00, 1.00, 1024]
12
+ x: [1.00, 1.25, 1024]
13
+
14
+ # YOLOv8.0n backbone
15
+ backbone:
16
+ # [from, repeats, module, args]
17
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19
+ - [-1, 3, C2f, [128, True]]
20
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21
+ - [-1, 6, C2f, [256, True]]
22
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23
+ - [-1, 6, C2f, [512, True]]
24
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25
+ - [-1, 3, C2f, [1024, True]]
26
+
27
+ # YOLOv8.0n head
28
+ head:
29
+ - [-1, 1, Classify, [nc]] # Classify
ultralytics/models/v8/yolov8-p2.yaml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8 object detection model with P2-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3
+
4
+ # Parameters
5
+ nc: 80 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024]
9
+ s: [0.33, 0.50, 1024]
10
+ m: [0.67, 0.75, 768]
11
+ l: [1.00, 1.00, 512]
12
+ x: [1.00, 1.25, 512]
13
+
14
+ # YOLOv8.0 backbone
15
+ backbone:
16
+ # [from, repeats, module, args]
17
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19
+ - [-1, 3, C2f, [128, True]]
20
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21
+ - [-1, 6, C2f, [256, True]]
22
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23
+ - [-1, 6, C2f, [512, True]]
24
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25
+ - [-1, 3, C2f, [1024, True]]
26
+ - [-1, 1, SPPF, [1024, 5]] # 9
27
+
28
+ # YOLOv8.0-p2 head
29
+ head:
30
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32
+ - [-1, 3, C2f, [512]] # 12
33
+
34
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
35
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
36
+ - [-1, 3, C2f, [256]] # 15 (P3/8-small)
37
+
38
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
39
+ - [[-1, 2], 1, Concat, [1]] # cat backbone P2
40
+ - [-1, 3, C2f, [128]] # 18 (P2/4-xsmall)
41
+
42
+ - [-1, 1, Conv, [128, 3, 2]]
43
+ - [[-1, 15], 1, Concat, [1]] # cat head P3
44
+ - [-1, 3, C2f, [256]] # 21 (P3/8-small)
45
+
46
+ - [-1, 1, Conv, [256, 3, 2]]
47
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
48
+ - [-1, 3, C2f, [512]] # 24 (P4/16-medium)
49
+
50
+ - [-1, 1, Conv, [512, 3, 2]]
51
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
52
+ - [-1, 3, C2f, [1024]] # 27 (P5/32-large)
53
+
54
+ - [[18, 21, 24, 27], 1, Detect, [nc]] # Detect(P2, P3, P4, P5)
ultralytics/models/v8/yolov8-p6.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8 object detection model with P3-P6 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3
+
4
+ # Parameters
5
+ nc: 80 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024]
9
+ s: [0.33, 0.50, 1024]
10
+ m: [0.67, 0.75, 768]
11
+ l: [1.00, 1.00, 512]
12
+ x: [1.00, 1.25, 512]
13
+
14
+ # YOLOv8.0x6 backbone
15
+ backbone:
16
+ # [from, repeats, module, args]
17
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19
+ - [-1, 3, C2f, [128, True]]
20
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21
+ - [-1, 6, C2f, [256, True]]
22
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23
+ - [-1, 6, C2f, [512, True]]
24
+ - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32
25
+ - [-1, 3, C2f, [768, True]]
26
+ - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64
27
+ - [-1, 3, C2f, [1024, True]]
28
+ - [-1, 1, SPPF, [1024, 5]] # 11
29
+
30
+ # YOLOv8.0x6 head
31
+ head:
32
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
33
+ - [[-1, 8], 1, Concat, [1]] # cat backbone P5
34
+ - [-1, 3, C2, [768, False]] # 14
35
+
36
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
37
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
38
+ - [-1, 3, C2, [512, False]] # 17
39
+
40
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
41
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
42
+ - [-1, 3, C2, [256, False]] # 20 (P3/8-small)
43
+
44
+ - [-1, 1, Conv, [256, 3, 2]]
45
+ - [[-1, 17], 1, Concat, [1]] # cat head P4
46
+ - [-1, 3, C2, [512, False]] # 23 (P4/16-medium)
47
+
48
+ - [-1, 1, Conv, [512, 3, 2]]
49
+ - [[-1, 14], 1, Concat, [1]] # cat head P5
50
+ - [-1, 3, C2, [768, False]] # 26 (P5/32-large)
51
+
52
+ - [-1, 1, Conv, [768, 3, 2]]
53
+ - [[-1, 11], 1, Concat, [1]] # cat head P6
54
+ - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge)
55
+
56
+ - [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6)
ultralytics/models/v8/yolov8-pose-p6.yaml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose
3
+
4
+ # Parameters
5
+ nc: 1 # number of classes
6
+ kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
7
+ scales: # model compound scaling constants, i.e. 'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n'
8
+ # [depth, width, max_channels]
9
+ n: [0.33, 0.25, 1024]
10
+ s: [0.33, 0.50, 1024]
11
+ m: [0.67, 0.75, 768]
12
+ l: [1.00, 1.00, 512]
13
+ x: [1.00, 1.25, 512]
14
+
15
+ # YOLOv8.0x6 backbone
16
+ backbone:
17
+ # [from, repeats, module, args]
18
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
19
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
20
+ - [-1, 3, C2f, [128, True]]
21
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
22
+ - [-1, 6, C2f, [256, True]]
23
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
24
+ - [-1, 6, C2f, [512, True]]
25
+ - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32
26
+ - [-1, 3, C2f, [768, True]]
27
+ - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64
28
+ - [-1, 3, C2f, [1024, True]]
29
+ - [-1, 1, SPPF, [1024, 5]] # 11
30
+
31
+ # YOLOv8.0x6 head
32
+ head:
33
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
34
+ - [[-1, 8], 1, Concat, [1]] # cat backbone P5
35
+ - [-1, 3, C2, [768, False]] # 14
36
+
37
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
38
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
39
+ - [-1, 3, C2, [512, False]] # 17
40
+
41
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
42
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
43
+ - [-1, 3, C2, [256, False]] # 20 (P3/8-small)
44
+
45
+ - [-1, 1, Conv, [256, 3, 2]]
46
+ - [[-1, 17], 1, Concat, [1]] # cat head P4
47
+ - [-1, 3, C2, [512, False]] # 23 (P4/16-medium)
48
+
49
+ - [-1, 1, Conv, [512, 3, 2]]
50
+ - [[-1, 14], 1, Concat, [1]] # cat head P5
51
+ - [-1, 3, C2, [768, False]] # 26 (P5/32-large)
52
+
53
+ - [-1, 1, Conv, [768, 3, 2]]
54
+ - [[-1, 11], 1, Concat, [1]] # cat head P6
55
+ - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge)
56
+
57
+ - [[20, 23, 26, 29], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5, P6)
ultralytics/models/v8/yolov8-pose.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose
3
+
4
+ # Parameters
5
+ nc: 1 # number of classes
6
+ kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
7
+ scales: # model compound scaling constants, i.e. 'model=yolov8n-pose.yaml' will call yolov8-pose.yaml with scale 'n'
8
+ # [depth, width, max_channels]
9
+ n: [0.33, 0.25, 1024]
10
+ s: [0.33, 0.50, 1024]
11
+ m: [0.67, 0.75, 768]
12
+ l: [1.00, 1.00, 512]
13
+ x: [1.00, 1.25, 512]
14
+
15
+ # YOLOv8.0n backbone
16
+ backbone:
17
+ # [from, repeats, module, args]
18
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
19
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
20
+ - [-1, 3, C2f, [128, True]]
21
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
22
+ - [-1, 6, C2f, [256, True]]
23
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
24
+ - [-1, 6, C2f, [512, True]]
25
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
26
+ - [-1, 3, C2f, [1024, True]]
27
+ - [-1, 1, SPPF, [1024, 5]] # 9
28
+
29
+ # YOLOv8.0n head
30
+ head:
31
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
32
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
33
+ - [-1, 3, C2f, [512]] # 12
34
+
35
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
36
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
37
+ - [-1, 3, C2f, [256]] # 15 (P3/8-small)
38
+
39
+ - [-1, 1, Conv, [256, 3, 2]]
40
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
41
+ - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
42
+
43
+ - [-1, 1, Conv, [512, 3, 2]]
44
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
45
+ - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
46
+
47
+ - [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5)
ultralytics/models/v8/yolov8-rtdetr.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3
+
4
+ # Parameters
5
+ nc: 80 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
9
+ s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
10
+ m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
11
+ l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
12
+ x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
13
+
14
+ # YOLOv8.0n backbone
15
+ backbone:
16
+ # [from, repeats, module, args]
17
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19
+ - [-1, 3, C2f, [128, True]]
20
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21
+ - [-1, 6, C2f, [256, True]]
22
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23
+ - [-1, 6, C2f, [512, True]]
24
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25
+ - [-1, 3, C2f, [1024, True]]
26
+ - [-1, 1, SPPF, [1024, 5]] # 9
27
+
28
+ # YOLOv8.0n head
29
+ head:
30
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32
+ - [-1, 3, C2f, [512]] # 12
33
+
34
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
35
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
36
+ - [-1, 3, C2f, [256]] # 15 (P3/8-small)
37
+
38
+ - [-1, 1, Conv, [256, 3, 2]]
39
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
40
+ - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
41
+
42
+ - [-1, 1, Conv, [512, 3, 2]]
43
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
44
+ - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
45
+
46
+ - [[15, 18, 21], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)
ultralytics/models/v8/yolov8-seg.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment
3
+
4
+ # Parameters
5
+ nc: 80 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n-seg.yaml' will call yolov8-seg.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024]
9
+ s: [0.33, 0.50, 1024]
10
+ m: [0.67, 0.75, 768]
11
+ l: [1.00, 1.00, 512]
12
+ x: [1.00, 1.25, 512]
13
+
14
+ # YOLOv8.0n backbone
15
+ backbone:
16
+ # [from, repeats, module, args]
17
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19
+ - [-1, 3, C2f, [128, True]]
20
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21
+ - [-1, 6, C2f, [256, True]]
22
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23
+ - [-1, 6, C2f, [512, True]]
24
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25
+ - [-1, 3, C2f, [1024, True]]
26
+ - [-1, 1, SPPF, [1024, 5]] # 9
27
+
28
+ # YOLOv8.0n head
29
+ head:
30
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32
+ - [-1, 3, C2f, [512]] # 12
33
+
34
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
35
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
36
+ - [-1, 3, C2f, [256]] # 15 (P3/8-small)
37
+
38
+ - [-1, 1, Conv, [256, 3, 2]]
39
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
40
+ - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
41
+
42
+ - [-1, 1, Conv, [512, 3, 2]]
43
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
44
+ - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
45
+
46
+ - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5)
ultralytics/models/v8/yolov8.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ # YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
3
+
4
+ # Parameters
5
+ nc: 80 # number of classes
6
+ scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
7
+ # [depth, width, max_channels]
8
+ n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
9
+ s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
10
+ m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
11
+ l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
12
+ x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
13
+
14
+ # YOLOv8.0n backbone
15
+ backbone:
16
+ # [from, repeats, module, args]
17
+ - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
18
+ - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
19
+ - [-1, 3, C2f, [128, True]]
20
+ - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
21
+ - [-1, 6, C2f, [256, True]]
22
+ - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
23
+ - [-1, 6, C2f, [512, True]]
24
+ - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
25
+ - [-1, 3, C2f, [1024, True]]
26
+ - [-1, 1, SPPF, [1024, 5]] # 9
27
+
28
+ # YOLOv8.0n head
29
+ head:
30
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
31
+ - [[-1, 6], 1, Concat, [1]] # cat backbone P4
32
+ - [-1, 3, C2f, [512]] # 12
33
+
34
+ - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
35
+ - [[-1, 4], 1, Concat, [1]] # cat backbone P3
36
+ - [-1, 3, C2f, [256]] # 15 (P3/8-small)
37
+
38
+ - [-1, 1, Conv, [256, 3, 2]]
39
+ - [[-1, 12], 1, Concat, [1]] # cat head P4
40
+ - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
41
+
42
+ - [-1, 1, Conv, [512, 3, 2]]
43
+ - [[-1, 9], 1, Concat, [1]] # cat head P5
44
+ - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
45
+
46
+ - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)
ultralytics/nn/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ from .tasks import (BaseModel, ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
4
+ attempt_load_weights, guess_model_scale, guess_model_task, parse_model, torch_safe_load,
5
+ yaml_model_load)
6
+
7
+ __all__ = ('attempt_load_one_weight', 'attempt_load_weights', 'parse_model', 'yaml_model_load', 'guess_model_task',
8
+ 'guess_model_scale', 'torch_safe_load', 'DetectionModel', 'SegmentationModel', 'ClassificationModel',
9
+ 'BaseModel')
ultralytics/nn/autobackend.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import ast
4
+ import contextlib
5
+ import json
6
+ import platform
7
+ import zipfile
8
+ from collections import OrderedDict, namedtuple
9
+ from pathlib import Path
10
+ from urllib.parse import urlparse
11
+
12
+ import cv2
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn as nn
16
+ from PIL import Image
17
+
18
+ from ultralytics.yolo.utils import LINUX, LOGGER, ROOT, yaml_load
19
+ from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version, check_yaml
20
+ from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
21
+ from ultralytics.yolo.utils.ops import xywh2xyxy
22
+
23
+
24
+ def check_class_names(names):
25
+ """Check class names. Map imagenet class codes to human-readable names if required. Convert lists to dicts."""
26
+ if isinstance(names, list): # names is a list
27
+ names = dict(enumerate(names)) # convert to dict
28
+ if isinstance(names, dict):
29
+ # Convert 1) string keys to int, i.e. '0' to 0, and non-string values to strings, i.e. True to 'True'
30
+ names = {int(k): str(v) for k, v in names.items()}
31
+ n = len(names)
32
+ if max(names.keys()) >= n:
33
+ raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices '
34
+ f'{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.')
35
+ if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764'
36
+ map = yaml_load(ROOT / 'datasets/ImageNet.yaml')['map'] # human-readable names
37
+ names = {k: map[v] for k, v in names.items()}
38
+ return names
39
+
40
+
41
+ class AutoBackend(nn.Module):
42
+
43
+ def __init__(self,
44
+ weights='yolov8n.pt',
45
+ device=torch.device('cpu'),
46
+ dnn=False,
47
+ data=None,
48
+ fp16=False,
49
+ fuse=True,
50
+ verbose=True):
51
+ """
52
+ MultiBackend class for python inference on various platforms using Ultralytics YOLO.
53
+
54
+ Args:
55
+ weights (str): The path to the weights file. Default: 'yolov8n.pt'
56
+ device (torch.device): The device to run the model on.
57
+ dnn (bool): Use OpenCV DNN module for inference if True, defaults to False.
58
+ data (str | Path | optional): Additional data.yaml file for class names.
59
+ fp16 (bool): If True, use half precision. Default: False
60
+ fuse (bool): Whether to fuse the model or not. Default: True
61
+ verbose (bool): Whether to run in verbose mode or not. Default: True
62
+
63
+ Supported formats and their naming conventions:
64
+ | Format | Suffix |
65
+ |-----------------------|------------------|
66
+ | PyTorch | *.pt |
67
+ | TorchScript | *.torchscript |
68
+ | ONNX Runtime | *.onnx |
69
+ | ONNX OpenCV DNN | *.onnx dnn=True |
70
+ | OpenVINO | *.xml |
71
+ | CoreML | *.mlmodel |
72
+ | TensorRT | *.engine |
73
+ | TensorFlow SavedModel | *_saved_model |
74
+ | TensorFlow GraphDef | *.pb |
75
+ | TensorFlow Lite | *.tflite |
76
+ | TensorFlow Edge TPU | *_edgetpu.tflite |
77
+ | PaddlePaddle | *_paddle_model |
78
+ """
79
+ super().__init__()
80
+ w = str(weights[0] if isinstance(weights, list) else weights)
81
+ nn_module = isinstance(weights, torch.nn.Module)
82
+ pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
83
+ fp16 &= pt or jit or onnx or engine or nn_module or triton # FP16
84
+ nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
85
+ stride = 32 # default stride
86
+ model, metadata = None, None
87
+ cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
88
+ if not (pt or triton or nn_module):
89
+ w = attempt_download_asset(w) # download if not local
90
+
91
+ # NOTE: special case: in-memory pytorch model
92
+ if nn_module:
93
+ model = weights.to(device)
94
+ model = model.fuse(verbose=verbose) if fuse else model
95
+ if hasattr(model, 'kpt_shape'):
96
+ kpt_shape = model.kpt_shape # pose-only
97
+ stride = max(int(model.stride.max()), 32) # model stride
98
+ names = model.module.names if hasattr(model, 'module') else model.names # get class names
99
+ model.half() if fp16 else model.float()
100
+ self.model = model # explicitly assign for to(), cpu(), cuda(), half()
101
+ pt = True
102
+ elif pt: # PyTorch
103
+ from ultralytics.nn.tasks import attempt_load_weights
104
+ model = attempt_load_weights(weights if isinstance(weights, list) else w,
105
+ device=device,
106
+ inplace=True,
107
+ fuse=fuse)
108
+ if hasattr(model, 'kpt_shape'):
109
+ kpt_shape = model.kpt_shape # pose-only
110
+ stride = max(int(model.stride.max()), 32) # model stride
111
+ names = model.module.names if hasattr(model, 'module') else model.names # get class names
112
+ model.half() if fp16 else model.float()
113
+ self.model = model # explicitly assign for to(), cpu(), cuda(), half()
114
+ elif jit: # TorchScript
115
+ LOGGER.info(f'Loading {w} for TorchScript inference...')
116
+ extra_files = {'config.txt': ''} # model metadata
117
+ model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
118
+ model.half() if fp16 else model.float()
119
+ if extra_files['config.txt']: # load metadata dict
120
+ metadata = json.loads(extra_files['config.txt'], object_hook=lambda x: dict(x.items()))
121
+ elif dnn: # ONNX OpenCV DNN
122
+ LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
123
+ check_requirements('opencv-python>=4.5.4')
124
+ net = cv2.dnn.readNetFromONNX(w)
125
+ elif onnx: # ONNX Runtime
126
+ LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
127
+ check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
128
+ import onnxruntime
129
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
130
+ session = onnxruntime.InferenceSession(w, providers=providers)
131
+ output_names = [x.name for x in session.get_outputs()]
132
+ metadata = session.get_modelmeta().custom_metadata_map # metadata
133
+ elif xml: # OpenVINO
134
+ LOGGER.info(f'Loading {w} for OpenVINO inference...')
135
+ check_requirements('openvino') # requires openvino-dev: https://pypi.org/project/openvino-dev/
136
+ from openvino.runtime import Core, Layout, get_batch # noqa
137
+ ie = Core()
138
+ w = Path(w)
139
+ if not w.is_file(): # if not *.xml
140
+ w = next(w.glob('*.xml')) # get *.xml file from *_openvino_model dir
141
+ network = ie.read_model(model=str(w), weights=w.with_suffix('.bin'))
142
+ if network.get_parameters()[0].get_layout().empty:
143
+ network.get_parameters()[0].set_layout(Layout('NCHW'))
144
+ batch_dim = get_batch(network)
145
+ if batch_dim.is_static:
146
+ batch_size = batch_dim.get_length()
147
+ executable_network = ie.compile_model(network, device_name='CPU') # device_name="MYRIAD" for NCS2
148
+ metadata = w.parent / 'metadata.yaml'
149
+ elif engine: # TensorRT
150
+ LOGGER.info(f'Loading {w} for TensorRT inference...')
151
+ try:
152
+ import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download
153
+ except ImportError:
154
+ if LINUX:
155
+ check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
156
+ import tensorrt as trt # noqa
157
+ check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
158
+ if device.type == 'cpu':
159
+ device = torch.device('cuda:0')
160
+ Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
161
+ logger = trt.Logger(trt.Logger.INFO)
162
+ # Read file
163
+ with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
164
+ meta_len = int.from_bytes(f.read(4), byteorder='little') # read metadata length
165
+ metadata = json.loads(f.read(meta_len).decode('utf-8')) # read metadata
166
+ model = runtime.deserialize_cuda_engine(f.read()) # read engine
167
+ context = model.create_execution_context()
168
+ bindings = OrderedDict()
169
+ output_names = []
170
+ fp16 = False # default updated below
171
+ dynamic = False
172
+ for i in range(model.num_bindings):
173
+ name = model.get_binding_name(i)
174
+ dtype = trt.nptype(model.get_binding_dtype(i))
175
+ if model.binding_is_input(i):
176
+ if -1 in tuple(model.get_binding_shape(i)): # dynamic
177
+ dynamic = True
178
+ context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
179
+ if dtype == np.float16:
180
+ fp16 = True
181
+ else: # output
182
+ output_names.append(name)
183
+ shape = tuple(context.get_binding_shape(i))
184
+ im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
185
+ bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
186
+ binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
187
+ batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
188
+ elif coreml: # CoreML
189
+ LOGGER.info(f'Loading {w} for CoreML inference...')
190
+ import coremltools as ct
191
+ model = ct.models.MLModel(w)
192
+ metadata = dict(model.user_defined_metadata)
193
+ elif saved_model: # TF SavedModel
194
+ LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
195
+ import tensorflow as tf
196
+ keras = False # assume TF1 saved_model
197
+ model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
198
+ metadata = Path(w) / 'metadata.yaml'
199
+ elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
200
+ LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
201
+ import tensorflow as tf
202
+
203
+ from ultralytics.yolo.engine.exporter import gd_outputs
204
+
205
+ def wrap_frozen_graph(gd, inputs, outputs):
206
+ """Wrap frozen graphs for deployment."""
207
+ x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), []) # wrapped
208
+ ge = x.graph.as_graph_element
209
+ return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
210
+
211
+ gd = tf.Graph().as_graph_def() # TF GraphDef
212
+ with open(w, 'rb') as f:
213
+ gd.ParseFromString(f.read())
214
+ frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
215
+ elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
216
+ try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
217
+ from tflite_runtime.interpreter import Interpreter, load_delegate
218
+ except ImportError:
219
+ import tensorflow as tf
220
+ Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
221
+ if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
222
+ LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
223
+ delegate = {
224
+ 'Linux': 'libedgetpu.so.1',
225
+ 'Darwin': 'libedgetpu.1.dylib',
226
+ 'Windows': 'edgetpu.dll'}[platform.system()]
227
+ interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
228
+ else: # TFLite
229
+ LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
230
+ interpreter = Interpreter(model_path=w) # load TFLite model
231
+ interpreter.allocate_tensors() # allocate
232
+ input_details = interpreter.get_input_details() # inputs
233
+ output_details = interpreter.get_output_details() # outputs
234
+ # Load metadata
235
+ with contextlib.suppress(zipfile.BadZipFile):
236
+ with zipfile.ZipFile(w, 'r') as model:
237
+ meta_file = model.namelist()[0]
238
+ metadata = ast.literal_eval(model.read(meta_file).decode('utf-8'))
239
+ elif tfjs: # TF.js
240
+ raise NotImplementedError('YOLOv8 TF.js inference is not supported')
241
+ elif paddle: # PaddlePaddle
242
+ LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
243
+ check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
244
+ import paddle.inference as pdi # noqa
245
+ w = Path(w)
246
+ if not w.is_file(): # if not *.pdmodel
247
+ w = next(w.rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir
248
+ config = pdi.Config(str(w), str(w.with_suffix('.pdiparams')))
249
+ if cuda:
250
+ config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
251
+ predictor = pdi.create_predictor(config)
252
+ input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
253
+ output_names = predictor.get_output_names()
254
+ metadata = w.parents[1] / 'metadata.yaml'
255
+ elif triton: # NVIDIA Triton Inference Server
256
+ LOGGER.info('Triton Inference Server not supported...')
257
+ '''
258
+ TODO:
259
+ check_requirements('tritonclient[all]')
260
+ from utils.triton import TritonRemoteModel
261
+ model = TritonRemoteModel(url=w)
262
+ nhwc = model.runtime.startswith("tensorflow")
263
+ '''
264
+ else:
265
+ from ultralytics.yolo.engine.exporter import export_formats
266
+ raise TypeError(f"model='{w}' is not a supported model format. "
267
+ 'See https://docs.ultralytics.com/modes/predict for help.'
268
+ f'\n\n{export_formats()}')
269
+
270
+ # Load external metadata YAML
271
+ if isinstance(metadata, (str, Path)) and Path(metadata).exists():
272
+ metadata = yaml_load(metadata)
273
+ if metadata:
274
+ for k, v in metadata.items():
275
+ if k in ('stride', 'batch'):
276
+ metadata[k] = int(v)
277
+ elif k in ('imgsz', 'names', 'kpt_shape') and isinstance(v, str):
278
+ metadata[k] = eval(v)
279
+ stride = metadata['stride']
280
+ task = metadata['task']
281
+ batch = metadata['batch']
282
+ imgsz = metadata['imgsz']
283
+ names = metadata['names']
284
+ kpt_shape = metadata.get('kpt_shape')
285
+ elif not (pt or triton or nn_module):
286
+ LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'")
287
+
288
+ # Check names
289
+ if 'names' not in locals(): # names missing
290
+ names = self._apply_default_class_names(data)
291
+ names = check_class_names(names)
292
+
293
+ self.__dict__.update(locals()) # assign all variables to self
294
+
295
+ def forward(self, im, augment=False, visualize=False):
296
+ """
297
+ Runs inference on the YOLOv8 MultiBackend model.
298
+
299
+ Args:
300
+ im (torch.Tensor): The image tensor to perform inference on.
301
+ augment (bool): whether to perform data augmentation during inference, defaults to False
302
+ visualize (bool): whether to visualize the output predictions, defaults to False
303
+
304
+ Returns:
305
+ (tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)
306
+ """
307
+ b, ch, h, w = im.shape # batch, channel, height, width
308
+ if self.fp16 and im.dtype != torch.float16:
309
+ im = im.half() # to FP16
310
+ if self.nhwc:
311
+ im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
312
+
313
+ if self.pt or self.nn_module: # PyTorch
314
+ y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
315
+ elif self.jit: # TorchScript
316
+ y = self.model(im)
317
+ elif self.dnn: # ONNX OpenCV DNN
318
+ im = im.cpu().numpy() # torch to numpy
319
+ self.net.setInput(im)
320
+ y = self.net.forward()
321
+ elif self.onnx: # ONNX Runtime
322
+ im = im.cpu().numpy() # torch to numpy
323
+ y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
324
+ elif self.xml: # OpenVINO
325
+ im = im.cpu().numpy() # FP32
326
+ y = list(self.executable_network([im]).values())
327
+ elif self.engine: # TensorRT
328
+ if self.dynamic and im.shape != self.bindings['images'].shape:
329
+ i = self.model.get_binding_index('images')
330
+ self.context.set_binding_shape(i, im.shape) # reshape if dynamic
331
+ self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
332
+ for name in self.output_names:
333
+ i = self.model.get_binding_index(name)
334
+ self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
335
+ s = self.bindings['images'].shape
336
+ assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
337
+ self.binding_addrs['images'] = int(im.data_ptr())
338
+ self.context.execute_v2(list(self.binding_addrs.values()))
339
+ y = [self.bindings[x].data for x in sorted(self.output_names)]
340
+ elif self.coreml: # CoreML
341
+ im = im[0].cpu().numpy()
342
+ im_pil = Image.fromarray((im * 255).astype('uint8'))
343
+ # im = im.resize((192, 320), Image.ANTIALIAS)
344
+ y = self.model.predict({'image': im_pil}) # coordinates are xywh normalized
345
+ if 'confidence' in y:
346
+ box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
347
+ conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
348
+ y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
349
+ elif len(y) == 1: # classification model
350
+ y = list(y.values())
351
+ elif len(y) == 2: # segmentation model
352
+ y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
353
+ elif self.paddle: # PaddlePaddle
354
+ im = im.cpu().numpy().astype(np.float32)
355
+ self.input_handle.copy_from_cpu(im)
356
+ self.predictor.run()
357
+ y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
358
+ elif self.triton: # NVIDIA Triton Inference Server
359
+ y = self.model(im)
360
+ else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
361
+ im = im.cpu().numpy()
362
+ if self.saved_model: # SavedModel
363
+ y = self.model(im, training=False) if self.keras else self.model(im)
364
+ if not isinstance(y, list):
365
+ y = [y]
366
+ elif self.pb: # GraphDef
367
+ y = self.frozen_func(x=self.tf.constant(im))
368
+ if len(y) == 2 and len(self.names) == 999: # segments and names not defined
369
+ ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0) # index of protos, boxes
370
+ nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400)
371
+ self.names = {i: f'class{i}' for i in range(nc)}
372
+ else: # Lite or Edge TPU
373
+ input = self.input_details[0]
374
+ int8 = input['dtype'] == np.int8 # is TFLite quantized int8 model
375
+ if int8:
376
+ scale, zero_point = input['quantization']
377
+ im = (im / scale + zero_point).astype(np.int8) # de-scale
378
+ self.interpreter.set_tensor(input['index'], im)
379
+ self.interpreter.invoke()
380
+ y = []
381
+ for output in self.output_details:
382
+ x = self.interpreter.get_tensor(output['index'])
383
+ if int8:
384
+ scale, zero_point = output['quantization']
385
+ x = (x.astype(np.float32) - zero_point) * scale # re-scale
386
+ y.append(x)
387
+ # TF segment fixes: export is reversed vs ONNX export and protos are transposed
388
+ if len(y) == 2: # segment with (det, proto) output order reversed
389
+ if len(y[1].shape) != 4:
390
+ y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32)
391
+ y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160)
392
+ y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
393
+ # y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels
394
+
395
+ # for x in y:
396
+ # print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape) # debug shapes
397
+ if isinstance(y, (list, tuple)):
398
+ return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
399
+ else:
400
+ return self.from_numpy(y)
401
+
402
+ def from_numpy(self, x):
403
+ """
404
+ Convert a numpy array to a tensor.
405
+
406
+ Args:
407
+ x (np.ndarray): The array to be converted.
408
+
409
+ Returns:
410
+ (torch.Tensor): The converted tensor
411
+ """
412
+ return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x
413
+
414
+ def warmup(self, imgsz=(1, 3, 640, 640)):
415
+ """
416
+ Warm up the model by running one forward pass with a dummy input.
417
+
418
+ Args:
419
+ imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)
420
+
421
+ Returns:
422
+ (None): This method runs the forward pass and don't return any value
423
+ """
424
+ warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
425
+ if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
426
+ im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
427
+ for _ in range(2 if self.jit else 1): #
428
+ self.forward(im) # warmup
429
+
430
+ @staticmethod
431
+ def _apply_default_class_names(data):
432
+ """Applies default class names to an input YAML file or returns numerical class names."""
433
+ with contextlib.suppress(Exception):
434
+ return yaml_load(check_yaml(data))['names']
435
+ return {i: f'class{i}' for i in range(999)} # return default if above errors
436
+
437
+ @staticmethod
438
+ def _model_type(p='path/to/model.pt'):
439
+ """
440
+ This function takes a path to a model file and returns the model type
441
+
442
+ Args:
443
+ p: path to the model file. Defaults to path/to/model.pt
444
+ """
445
+ # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
446
+ # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
447
+ from ultralytics.yolo.engine.exporter import export_formats
448
+ sf = list(export_formats().Suffix) # export suffixes
449
+ if not is_url(p, check=False) and not isinstance(p, str):
450
+ check_suffix(p, sf) # checks
451
+ url = urlparse(p) # if url may be Triton inference server
452
+ types = [s in Path(p).name for s in sf]
453
+ types[8] &= not types[9] # tflite &= not edgetpu
454
+ triton = not any(types) and all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
455
+ return types + [triton]
ultralytics/nn/autoshape.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Common modules
4
+ """
5
+
6
+ from copy import copy
7
+ from pathlib import Path
8
+
9
+ import cv2
10
+ import numpy as np
11
+ import requests
12
+ import torch
13
+ import torch.nn as nn
14
+ from PIL import Image, ImageOps
15
+ from torch.cuda import amp
16
+
17
+ from ultralytics.nn.autobackend import AutoBackend
18
+ from ultralytics.yolo.data.augment import LetterBox
19
+ from ultralytics.yolo.utils import LOGGER, colorstr
20
+ from ultralytics.yolo.utils.files import increment_path
21
+ from ultralytics.yolo.utils.ops import Profile, make_divisible, non_max_suppression, scale_boxes, xyxy2xywh
22
+ from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
23
+ from ultralytics.yolo.utils.torch_utils import copy_attr, smart_inference_mode
24
+
25
+
26
+ class AutoShape(nn.Module):
27
+ """YOLOv8 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS."""
28
+ conf = 0.25 # NMS confidence threshold
29
+ iou = 0.45 # NMS IoU threshold
30
+ agnostic = False # NMS class-agnostic
31
+ multi_label = False # NMS multiple labels per box
32
+ classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
33
+ max_det = 1000 # maximum number of detections per image
34
+ amp = False # Automatic Mixed Precision (AMP) inference
35
+
36
+ def __init__(self, model, verbose=True):
37
+ """Initializes object and copies attributes from model object."""
38
+ super().__init__()
39
+ if verbose:
40
+ LOGGER.info('Adding AutoShape... ')
41
+ copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
42
+ self.dmb = isinstance(model, AutoBackend) # DetectMultiBackend() instance
43
+ self.pt = not self.dmb or model.pt # PyTorch model
44
+ self.model = model.eval()
45
+ if self.pt:
46
+ m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
47
+ m.inplace = False # Detect.inplace=False for safe multithread inference
48
+ m.export = True # do not output loss values
49
+
50
+ def _apply(self, fn):
51
+ """Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers."""
52
+ self = super()._apply(fn)
53
+ if self.pt:
54
+ m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
55
+ m.stride = fn(m.stride)
56
+ m.grid = list(map(fn, m.grid))
57
+ if isinstance(m.anchor_grid, list):
58
+ m.anchor_grid = list(map(fn, m.anchor_grid))
59
+ return self
60
+
61
+ @smart_inference_mode()
62
+ def forward(self, ims, size=640, augment=False, profile=False):
63
+ """Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:."""
64
+ # file: ims = 'data/images/zidane.jpg' # str or PosixPath
65
+ # URI: = 'https://ultralytics.com/images/zidane.jpg'
66
+ # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
67
+ # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
68
+ # numpy: = np.zeros((640,1280,3)) # HWC
69
+ # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
70
+ # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
71
+
72
+ dt = (Profile(), Profile(), Profile())
73
+ with dt[0]:
74
+ if isinstance(size, int): # expand
75
+ size = (size, size)
76
+ p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
77
+ autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
78
+ if isinstance(ims, torch.Tensor): # torch
79
+ with amp.autocast(autocast):
80
+ return self.model(ims.to(p.device).type_as(p), augment=augment) # inference
81
+
82
+ # Preprocess
83
+ n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
84
+ shape0, shape1, files = [], [], [] # image and inference shapes, filenames
85
+ for i, im in enumerate(ims):
86
+ f = f'image{i}' # filename
87
+ if isinstance(im, (str, Path)): # filename or uri
88
+ im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
89
+ im = np.asarray(ImageOps.exif_transpose(im))
90
+ elif isinstance(im, Image.Image): # PIL Image
91
+ im, f = np.asarray(ImageOps.exif_transpose(im)), getattr(im, 'filename', f) or f
92
+ files.append(Path(f).with_suffix('.jpg').name)
93
+ if im.shape[0] < 5: # image in CHW
94
+ im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
95
+ im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
96
+ s = im.shape[:2] # HWC
97
+ shape0.append(s) # image shape
98
+ g = max(size) / max(s) # gain
99
+ shape1.append([y * g for y in s])
100
+ ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
101
+ shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size # inf shape
102
+ x = [LetterBox(shape1, auto=False)(image=im)['img'] for im in ims] # pad
103
+ x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
104
+ x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
105
+
106
+ with amp.autocast(autocast):
107
+ # Inference
108
+ with dt[1]:
109
+ y = self.model(x, augment=augment) # forward
110
+
111
+ # Postprocess
112
+ with dt[2]:
113
+ y = non_max_suppression(y if self.dmb else y[0],
114
+ self.conf,
115
+ self.iou,
116
+ self.classes,
117
+ self.agnostic,
118
+ self.multi_label,
119
+ max_det=self.max_det) # NMS
120
+ for i in range(n):
121
+ scale_boxes(shape1, y[i][:, :4], shape0[i])
122
+
123
+ return Detections(ims, y, files, dt, self.names, x.shape)
124
+
125
+
126
+ class Detections:
127
+ """ YOLOv8 detections class for inference results"""
128
+
129
+ def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
130
+ """Initialize object attributes for YOLO detection results."""
131
+ super().__init__()
132
+ d = pred[0].device # device
133
+ gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
134
+ self.ims = ims # list of images as numpy arrays
135
+ self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
136
+ self.names = names # class names
137
+ self.files = files # image filenames
138
+ self.times = times # profiling times
139
+ self.xyxy = pred # xyxy pixels
140
+ self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
141
+ self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
142
+ self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
143
+ self.n = len(self.pred) # number of images (batch size)
144
+ self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
145
+ self.s = tuple(shape) # inference BCHW shape
146
+
147
+ def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
148
+ """Return performance metrics and optionally cropped/save images or results."""
149
+ s, crops = '', []
150
+ for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
151
+ s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
152
+ if pred.shape[0]:
153
+ for c in pred[:, -1].unique():
154
+ n = (pred[:, -1] == c).sum() # detections per class
155
+ s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
156
+ s = s.rstrip(', ')
157
+ if show or save or render or crop:
158
+ annotator = Annotator(im, example=str(self.names))
159
+ for *box, conf, cls in reversed(pred): # xyxy, confidence, class
160
+ label = f'{self.names[int(cls)]} {conf:.2f}'
161
+ if crop:
162
+ file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
163
+ crops.append({
164
+ 'box': box,
165
+ 'conf': conf,
166
+ 'cls': cls,
167
+ 'label': label,
168
+ 'im': save_one_box(box, im, file=file, save=save)})
169
+ else: # all others
170
+ annotator.box_label(box, label if labels else '', color=colors(cls))
171
+ im = annotator.im
172
+ else:
173
+ s += '(no detections)'
174
+
175
+ im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
176
+ if show:
177
+ im.show(self.files[i]) # show
178
+ if save:
179
+ f = self.files[i]
180
+ im.save(save_dir / f) # save
181
+ if i == self.n - 1:
182
+ LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
183
+ if render:
184
+ self.ims[i] = np.asarray(im)
185
+ if pprint:
186
+ s = s.lstrip('\n')
187
+ return f'{s}\nSpeed: %.1fms preprocess, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t
188
+ if crop:
189
+ if save:
190
+ LOGGER.info(f'Saved results to {save_dir}\n')
191
+ return crops
192
+
193
+ def show(self, labels=True):
194
+ """Displays YOLO results with detected bounding boxes."""
195
+ self._run(show=True, labels=labels) # show results
196
+
197
+ def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False):
198
+ """Save detection results with optional labels to specified directory."""
199
+ save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir
200
+ self._run(save=True, labels=labels, save_dir=save_dir) # save results
201
+
202
+ def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False):
203
+ """Crops images into detections and saves them if 'save' is True."""
204
+ save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None
205
+ return self._run(crop=True, save=save, save_dir=save_dir) # crop results
206
+
207
+ def render(self, labels=True):
208
+ """Renders detected objects and returns images."""
209
+ self._run(render=True, labels=labels) # render results
210
+ return self.ims
211
+
212
+ def pandas(self):
213
+ """Return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])."""
214
+ import pandas
215
+ new = copy(self) # return copy
216
+ ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
217
+ cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
218
+ for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
219
+ a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
220
+ setattr(new, k, [pandas.DataFrame(x, columns=c) for x in a])
221
+ return new
222
+
223
+ def tolist(self):
224
+ """Return a list of Detections objects, i.e. 'for result in results.tolist():'."""
225
+ r = range(self.n) # iterable
226
+ x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
227
+ # for d in x:
228
+ # for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
229
+ # setattr(d, k, getattr(d, k)[0]) # pop out of list
230
+ return x
231
+
232
+ def print(self):
233
+ """Print the results of the `self._run()` function."""
234
+ LOGGER.info(self.__str__())
235
+
236
+ def __len__(self): # override len(results)
237
+ return self.n
238
+
239
+ def __str__(self): # override print(results)
240
+ return self._run(pprint=True) # print results
241
+
242
+ def __repr__(self):
243
+ """Returns a printable representation of the object."""
244
+ return f'YOLOv8 {self.__class__} instance\n' + self.__str__()
ultralytics/nn/modules/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Ultralytics modules. Visualize with:
4
+
5
+ from ultralytics.nn.modules import *
6
+ import torch
7
+ import os
8
+
9
+ x = torch.ones(1, 128, 40, 40)
10
+ m = Conv(128, 128)
11
+ f = f'{m._get_name()}.onnx'
12
+ torch.onnx.export(m, x, f)
13
+ os.system(f'onnxsim {f} {f} && open {f}')
14
+ """
15
+
16
+ from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck,
17
+ HGBlock, HGStem, Proto, RepC3)
18
+ from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus,
19
+ GhostConv, LightConv, RepConv, SpatialAttention)
20
+ from .head import Classify, Detect, Pose, RTDETRDecoder, Segment
21
+ from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d,
22
+ MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer)
23
+
24
+ __all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus',
25
+ 'GhostConv', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer',
26
+ 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3',
27
+ 'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect',
28
+ 'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI',
29
+ 'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP')
ultralytics/nn/modules/block.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Block modules
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .conv import Conv, DWConv, GhostConv, LightConv, RepConv
11
+ from .transformer import TransformerBlock
12
+
13
+ __all__ = ('DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost',
14
+ 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'RepC3')
15
+
16
+
17
+ class DFL(nn.Module):
18
+ """
19
+ Integral module of Distribution Focal Loss (DFL).
20
+ Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
21
+ """
22
+
23
+ def __init__(self, c1=16):
24
+ """Initialize a convolutional layer with a given number of input channels."""
25
+ super().__init__()
26
+ self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
27
+ x = torch.arange(c1, dtype=torch.float)
28
+ self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
29
+ self.c1 = c1
30
+
31
+ def forward(self, x):
32
+ """Applies a transformer layer on input tensor 'x' and returns a tensor."""
33
+ b, c, a = x.shape # batch, channels, anchors
34
+ return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
35
+ # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
36
+
37
+
38
+ class Proto(nn.Module):
39
+ """YOLOv8 mask Proto module for segmentation models."""
40
+
41
+ def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks
42
+ super().__init__()
43
+ self.cv1 = Conv(c1, c_, k=3)
44
+ self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest')
45
+ self.cv2 = Conv(c_, c_, k=3)
46
+ self.cv3 = Conv(c_, c2)
47
+
48
+ def forward(self, x):
49
+ """Performs a forward pass through layers using an upsampled input image."""
50
+ return self.cv3(self.cv2(self.upsample(self.cv1(x))))
51
+
52
+
53
+ class HGStem(nn.Module):
54
+ """StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d.
55
+ https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
56
+ """
57
+
58
+ def __init__(self, c1, cm, c2):
59
+ super().__init__()
60
+ self.stem1 = Conv(c1, cm, 3, 2, act=nn.ReLU())
61
+ self.stem2a = Conv(cm, cm // 2, 2, 1, 0, act=nn.ReLU())
62
+ self.stem2b = Conv(cm // 2, cm, 2, 1, 0, act=nn.ReLU())
63
+ self.stem3 = Conv(cm * 2, cm, 3, 2, act=nn.ReLU())
64
+ self.stem4 = Conv(cm, c2, 1, 1, act=nn.ReLU())
65
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True)
66
+
67
+ def forward(self, x):
68
+ """Forward pass of a PPHGNetV2 backbone layer."""
69
+ x = self.stem1(x)
70
+ x = F.pad(x, [0, 1, 0, 1])
71
+ x2 = self.stem2a(x)
72
+ x2 = F.pad(x2, [0, 1, 0, 1])
73
+ x2 = self.stem2b(x2)
74
+ x1 = self.pool(x)
75
+ x = torch.cat([x1, x2], dim=1)
76
+ x = self.stem3(x)
77
+ x = self.stem4(x)
78
+ return x
79
+
80
+
81
+ class HGBlock(nn.Module):
82
+ """HG_Block of PPHGNetV2 with 2 convolutions and LightConv.
83
+ https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
84
+ """
85
+
86
+ def __init__(self, c1, cm, c2, k=3, n=6, lightconv=False, shortcut=False, act=nn.ReLU()):
87
+ super().__init__()
88
+ block = LightConv if lightconv else Conv
89
+ self.m = nn.ModuleList(block(c1 if i == 0 else cm, cm, k=k, act=act) for i in range(n))
90
+ self.sc = Conv(c1 + n * cm, c2 // 2, 1, 1, act=act) # squeeze conv
91
+ self.ec = Conv(c2 // 2, c2, 1, 1, act=act) # excitation conv
92
+ self.add = shortcut and c1 == c2
93
+
94
+ def forward(self, x):
95
+ """Forward pass of a PPHGNetV2 backbone layer."""
96
+ y = [x]
97
+ y.extend(m(y[-1]) for m in self.m)
98
+ y = self.ec(self.sc(torch.cat(y, 1)))
99
+ return y + x if self.add else y
100
+
101
+
102
+ class SPP(nn.Module):
103
+ """Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729."""
104
+
105
+ def __init__(self, c1, c2, k=(5, 9, 13)):
106
+ """Initialize the SPP layer with input/output channels and pooling kernel sizes."""
107
+ super().__init__()
108
+ c_ = c1 // 2 # hidden channels
109
+ self.cv1 = Conv(c1, c_, 1, 1)
110
+ self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
111
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
112
+
113
+ def forward(self, x):
114
+ """Forward pass of the SPP layer, performing spatial pyramid pooling."""
115
+ x = self.cv1(x)
116
+ return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
117
+
118
+
119
+ class SPPF(nn.Module):
120
+ """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""
121
+
122
+ def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
123
+ super().__init__()
124
+ c_ = c1 // 2 # hidden channels
125
+ self.cv1 = Conv(c1, c_, 1, 1)
126
+ self.cv2 = Conv(c_ * 4, c2, 1, 1)
127
+ self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
128
+
129
+ def forward(self, x):
130
+ """Forward pass through Ghost Convolution block."""
131
+ x = self.cv1(x)
132
+ y1 = self.m(x)
133
+ y2 = self.m(y1)
134
+ return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
135
+
136
+
137
+ class C1(nn.Module):
138
+ """CSP Bottleneck with 1 convolution."""
139
+
140
+ def __init__(self, c1, c2, n=1): # ch_in, ch_out, number
141
+ super().__init__()
142
+ self.cv1 = Conv(c1, c2, 1, 1)
143
+ self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n)))
144
+
145
+ def forward(self, x):
146
+ """Applies cross-convolutions to input in the C3 module."""
147
+ y = self.cv1(x)
148
+ return self.m(y) + y
149
+
150
+
151
+ class C2(nn.Module):
152
+ """CSP Bottleneck with 2 convolutions."""
153
+
154
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
155
+ super().__init__()
156
+ self.c = int(c2 * e) # hidden channels
157
+ self.cv1 = Conv(c1, 2 * self.c, 1, 1)
158
+ self.cv2 = Conv(2 * self.c, c2, 1) # optional act=FReLU(c2)
159
+ # self.attention = ChannelAttention(2 * self.c) # or SpatialAttention()
160
+ self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)))
161
+
162
+ def forward(self, x):
163
+ """Forward pass through the CSP bottleneck with 2 convolutions."""
164
+ a, b = self.cv1(x).chunk(2, 1)
165
+ return self.cv2(torch.cat((self.m(a), b), 1))
166
+
167
+
168
+ class C2f(nn.Module):
169
+ """CSP Bottleneck with 2 convolutions."""
170
+
171
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
172
+ super().__init__()
173
+ self.c = int(c2 * e) # hidden channels
174
+ self.cv1 = Conv(c1, 2 * self.c, 1, 1)
175
+ self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)
176
+ self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
177
+
178
+ def forward(self, x):
179
+ """Forward pass through C2f layer."""
180
+ y = list(self.cv1(x).chunk(2, 1))
181
+ y.extend(m(y[-1]) for m in self.m)
182
+ return self.cv2(torch.cat(y, 1))
183
+
184
+ def forward_split(self, x):
185
+ """Forward pass using split() instead of chunk()."""
186
+ y = list(self.cv1(x).split((self.c, self.c), 1))
187
+ y.extend(m(y[-1]) for m in self.m)
188
+ return self.cv2(torch.cat(y, 1))
189
+
190
+
191
+ class C3(nn.Module):
192
+ """CSP Bottleneck with 3 convolutions."""
193
+
194
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
195
+ super().__init__()
196
+ c_ = int(c2 * e) # hidden channels
197
+ self.cv1 = Conv(c1, c_, 1, 1)
198
+ self.cv2 = Conv(c1, c_, 1, 1)
199
+ self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
200
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)))
201
+
202
+ def forward(self, x):
203
+ """Forward pass through the CSP bottleneck with 2 convolutions."""
204
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
205
+
206
+
207
+ class C3x(C3):
208
+ """C3 module with cross-convolutions."""
209
+
210
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
211
+ """Initialize C3TR instance and set default parameters."""
212
+ super().__init__(c1, c2, n, shortcut, g, e)
213
+ self.c_ = int(c2 * e)
214
+ self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, k=((1, 3), (3, 1)), e=1) for _ in range(n)))
215
+
216
+
217
+ class RepC3(nn.Module):
218
+ """Rep C3."""
219
+
220
+ def __init__(self, c1, c2, n=3, e=1.0):
221
+ super().__init__()
222
+ c_ = int(c2 * e) # hidden channels
223
+ self.cv1 = Conv(c1, c2, 1, 1)
224
+ self.cv2 = Conv(c1, c2, 1, 1)
225
+ self.m = nn.Sequential(*[RepConv(c_, c_) for _ in range(n)])
226
+ self.cv3 = Conv(c_, c2, 1, 1) if c_ != c2 else nn.Identity()
227
+
228
+ def forward(self, x):
229
+ """Forward pass of RT-DETR neck layer."""
230
+ return self.cv3(self.m(self.cv1(x)) + self.cv2(x))
231
+
232
+
233
+ class C3TR(C3):
234
+ """C3 module with TransformerBlock()."""
235
+
236
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
237
+ """Initialize C3Ghost module with GhostBottleneck()."""
238
+ super().__init__(c1, c2, n, shortcut, g, e)
239
+ c_ = int(c2 * e)
240
+ self.m = TransformerBlock(c_, c_, 4, n)
241
+
242
+
243
+ class C3Ghost(C3):
244
+ """C3 module with GhostBottleneck()."""
245
+
246
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
247
+ """Initialize 'SPP' module with various pooling sizes for spatial pyramid pooling."""
248
+ super().__init__(c1, c2, n, shortcut, g, e)
249
+ c_ = int(c2 * e) # hidden channels
250
+ self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
251
+
252
+
253
+ class GhostBottleneck(nn.Module):
254
+ """Ghost Bottleneck https://github.com/huawei-noah/ghostnet."""
255
+
256
+ def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
257
+ super().__init__()
258
+ c_ = c2 // 2
259
+ self.conv = nn.Sequential(
260
+ GhostConv(c1, c_, 1, 1), # pw
261
+ DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
262
+ GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
263
+ self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1,
264
+ act=False)) if s == 2 else nn.Identity()
265
+
266
+ def forward(self, x):
267
+ """Applies skip connection and concatenation to input tensor."""
268
+ return self.conv(x) + self.shortcut(x)
269
+
270
+
271
+ class Bottleneck(nn.Module):
272
+ """Standard bottleneck."""
273
+
274
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, groups, kernels, expand
275
+ super().__init__()
276
+ c_ = int(c2 * e) # hidden channels
277
+ self.cv1 = Conv(c1, c_, k[0], 1)
278
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
279
+ self.add = shortcut and c1 == c2
280
+
281
+ def forward(self, x):
282
+ """'forward()' applies the YOLOv5 FPN to input data."""
283
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
284
+
285
+
286
+ class BottleneckCSP(nn.Module):
287
+ """CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks."""
288
+
289
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
290
+ super().__init__()
291
+ c_ = int(c2 * e) # hidden channels
292
+ self.cv1 = Conv(c1, c_, 1, 1)
293
+ self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
294
+ self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
295
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
296
+ self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
297
+ self.act = nn.SiLU()
298
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
299
+
300
+ def forward(self, x):
301
+ """Applies a CSP bottleneck with 3 convolutions."""
302
+ y1 = self.cv3(self.m(self.cv1(x)))
303
+ y2 = self.cv2(x)
304
+ return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
ultralytics/nn/modules/conv.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Convolution modules
4
+ """
5
+
6
+ import math
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ __all__ = ('Conv', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
13
+ 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv')
14
+
15
+
16
+ def autopad(k, p=None, d=1): # kernel, padding, dilation
17
+ """Pad to 'same' shape outputs."""
18
+ if d > 1:
19
+ k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
20
+ if p is None:
21
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
22
+ return p
23
+
24
+
25
+ class Conv(nn.Module):
26
+ """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
27
+ default_act = nn.SiLU() # default activation
28
+
29
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
30
+ """Initialize Conv layer with given arguments including activation."""
31
+ super().__init__()
32
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
33
+ self.bn = nn.BatchNorm2d(c2)
34
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
35
+
36
+ def forward(self, x):
37
+ """Apply convolution, batch normalization and activation to input tensor."""
38
+ return self.act(self.bn(self.conv(x)))
39
+
40
+ def forward_fuse(self, x):
41
+ """Perform transposed convolution of 2D data."""
42
+ return self.act(self.conv(x))
43
+
44
+
45
+ class Conv2(Conv):
46
+ """Simplified RepConv module with Conv fusing."""
47
+
48
+ def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):
49
+ """Initialize Conv layer with given arguments including activation."""
50
+ super().__init__(c1, c2, k, s, p, g=g, d=d, act=act)
51
+ self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False) # add 1x1 conv
52
+
53
+ def forward(self, x):
54
+ """Apply convolution, batch normalization and activation to input tensor."""
55
+ return self.act(self.bn(self.conv(x) + self.cv2(x)))
56
+
57
+ def fuse_convs(self):
58
+ """Fuse parallel convolutions."""
59
+ w = torch.zeros_like(self.conv.weight.data)
60
+ i = [x // 2 for x in w.shape[2:]]
61
+ w[:, :, i[0]:i[0] + 1, i[1]:i[1] + 1] = self.cv2.weight.data.clone()
62
+ self.conv.weight.data += w
63
+ self.__delattr__('cv2')
64
+
65
+
66
+ class LightConv(nn.Module):
67
+ """Light convolution with args(ch_in, ch_out, kernel).
68
+ https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
69
+ """
70
+
71
+ def __init__(self, c1, c2, k=1, act=nn.ReLU()):
72
+ """Initialize Conv layer with given arguments including activation."""
73
+ super().__init__()
74
+ self.conv1 = Conv(c1, c2, 1, act=False)
75
+ self.conv2 = DWConv(c2, c2, k, act=act)
76
+
77
+ def forward(self, x):
78
+ """Apply 2 convolutions to input tensor."""
79
+ return self.conv2(self.conv1(x))
80
+
81
+
82
+ class DWConv(Conv):
83
+ """Depth-wise convolution."""
84
+
85
+ def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
86
+ super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
87
+
88
+
89
+ class DWConvTranspose2d(nn.ConvTranspose2d):
90
+ """Depth-wise transpose convolution."""
91
+
92
+ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
93
+ super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
94
+
95
+
96
+ class ConvTranspose(nn.Module):
97
+ """Convolution transpose 2d layer."""
98
+ default_act = nn.SiLU() # default activation
99
+
100
+ def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
101
+ """Initialize ConvTranspose2d layer with batch normalization and activation function."""
102
+ super().__init__()
103
+ self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
104
+ self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
105
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
106
+
107
+ def forward(self, x):
108
+ """Applies transposed convolutions, batch normalization and activation to input."""
109
+ return self.act(self.bn(self.conv_transpose(x)))
110
+
111
+ def forward_fuse(self, x):
112
+ """Applies activation and convolution transpose operation to input."""
113
+ return self.act(self.conv_transpose(x))
114
+
115
+
116
+ class Focus(nn.Module):
117
+ """Focus wh information into c-space."""
118
+
119
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
120
+ super().__init__()
121
+ self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act)
122
+ # self.contract = Contract(gain=2)
123
+
124
+ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
125
+ return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1))
126
+ # return self.conv(self.contract(x))
127
+
128
+
129
+ class GhostConv(nn.Module):
130
+ """Ghost Convolution https://github.com/huawei-noah/ghostnet."""
131
+
132
+ def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
133
+ super().__init__()
134
+ c_ = c2 // 2 # hidden channels
135
+ self.cv1 = Conv(c1, c_, k, s, None, g, act=act)
136
+ self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)
137
+
138
+ def forward(self, x):
139
+ """Forward propagation through a Ghost Bottleneck layer with skip connection."""
140
+ y = self.cv1(x)
141
+ return torch.cat((y, self.cv2(y)), 1)
142
+
143
+
144
+ class RepConv(nn.Module):
145
+ """RepConv is a basic rep-style block, including training and deploy status
146
+ This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
147
+ """
148
+ default_act = nn.SiLU() # default activation
149
+
150
+ def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
151
+ super().__init__()
152
+ assert k == 3 and p == 1
153
+ self.g = g
154
+ self.c1 = c1
155
+ self.c2 = c2
156
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
157
+
158
+ self.bn = nn.BatchNorm2d(num_features=c1) if bn and c2 == c1 and s == 1 else None
159
+ self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)
160
+ self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)
161
+
162
+ def forward_fuse(self, x):
163
+ """Forward process"""
164
+ return self.act(self.conv(x))
165
+
166
+ def forward(self, x):
167
+ """Forward process"""
168
+ id_out = 0 if self.bn is None else self.bn(x)
169
+ return self.act(self.conv1(x) + self.conv2(x) + id_out)
170
+
171
+ def get_equivalent_kernel_bias(self):
172
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
173
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
174
+ kernelid, biasid = self._fuse_bn_tensor(self.bn)
175
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
176
+
177
+ def _avg_to_3x3_tensor(self, avgp):
178
+ channels = self.c1
179
+ groups = self.g
180
+ kernel_size = avgp.kernel_size
181
+ input_dim = channels // groups
182
+ k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
183
+ k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
184
+ return k
185
+
186
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
187
+ if kernel1x1 is None:
188
+ return 0
189
+ else:
190
+ return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
191
+
192
+ def _fuse_bn_tensor(self, branch):
193
+ if branch is None:
194
+ return 0, 0
195
+ if isinstance(branch, Conv):
196
+ kernel = branch.conv.weight
197
+ running_mean = branch.bn.running_mean
198
+ running_var = branch.bn.running_var
199
+ gamma = branch.bn.weight
200
+ beta = branch.bn.bias
201
+ eps = branch.bn.eps
202
+ elif isinstance(branch, nn.BatchNorm2d):
203
+ if not hasattr(self, 'id_tensor'):
204
+ input_dim = self.c1 // self.g
205
+ kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
206
+ for i in range(self.c1):
207
+ kernel_value[i, i % input_dim, 1, 1] = 1
208
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
209
+ kernel = self.id_tensor
210
+ running_mean = branch.running_mean
211
+ running_var = branch.running_var
212
+ gamma = branch.weight
213
+ beta = branch.bias
214
+ eps = branch.eps
215
+ std = (running_var + eps).sqrt()
216
+ t = (gamma / std).reshape(-1, 1, 1, 1)
217
+ return kernel * t, beta - running_mean * gamma / std
218
+
219
+ def fuse_convs(self):
220
+ if hasattr(self, 'conv'):
221
+ return
222
+ kernel, bias = self.get_equivalent_kernel_bias()
223
+ self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels,
224
+ out_channels=self.conv1.conv.out_channels,
225
+ kernel_size=self.conv1.conv.kernel_size,
226
+ stride=self.conv1.conv.stride,
227
+ padding=self.conv1.conv.padding,
228
+ dilation=self.conv1.conv.dilation,
229
+ groups=self.conv1.conv.groups,
230
+ bias=True).requires_grad_(False)
231
+ self.conv.weight.data = kernel
232
+ self.conv.bias.data = bias
233
+ for para in self.parameters():
234
+ para.detach_()
235
+ self.__delattr__('conv1')
236
+ self.__delattr__('conv2')
237
+ if hasattr(self, 'nm'):
238
+ self.__delattr__('nm')
239
+ if hasattr(self, 'bn'):
240
+ self.__delattr__('bn')
241
+ if hasattr(self, 'id_tensor'):
242
+ self.__delattr__('id_tensor')
243
+
244
+
245
+ class ChannelAttention(nn.Module):
246
+ """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""
247
+
248
+ def __init__(self, channels: int) -> None:
249
+ super().__init__()
250
+ self.pool = nn.AdaptiveAvgPool2d(1)
251
+ self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
252
+ self.act = nn.Sigmoid()
253
+
254
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
255
+ return x * self.act(self.fc(self.pool(x)))
256
+
257
+
258
+ class SpatialAttention(nn.Module):
259
+ """Spatial-attention module."""
260
+
261
+ def __init__(self, kernel_size=7):
262
+ """Initialize Spatial-attention module with kernel size argument."""
263
+ super().__init__()
264
+ assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
265
+ padding = 3 if kernel_size == 7 else 1
266
+ self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
267
+ self.act = nn.Sigmoid()
268
+
269
+ def forward(self, x):
270
+ """Apply channel and spatial attention on input for feature recalibration."""
271
+ return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
272
+
273
+
274
+ class CBAM(nn.Module):
275
+ """Convolutional Block Attention Module."""
276
+
277
+ def __init__(self, c1, kernel_size=7): # ch_in, kernels
278
+ super().__init__()
279
+ self.channel_attention = ChannelAttention(c1)
280
+ self.spatial_attention = SpatialAttention(kernel_size)
281
+
282
+ def forward(self, x):
283
+ """Applies the forward pass through C1 module."""
284
+ return self.spatial_attention(self.channel_attention(x))
285
+
286
+
287
+ class Concat(nn.Module):
288
+ """Concatenate a list of tensors along dimension."""
289
+
290
+ def __init__(self, dimension=1):
291
+ """Concatenates a list of tensors along a specified dimension."""
292
+ super().__init__()
293
+ self.d = dimension
294
+
295
+ def forward(self, x):
296
+ """Forward pass for the YOLOv8 mask Proto module."""
297
+ return torch.cat(x, self.d)
ultralytics/nn/modules/head.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Model head modules
4
+ """
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn.init import constant_, xavier_uniform_
11
+
12
+ from ultralytics.yolo.utils.tal import dist2bbox, make_anchors
13
+
14
+ from .block import DFL, Proto
15
+ from .conv import Conv
16
+ from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
17
+ from .utils import bias_init_with_prob, linear_init_
18
+
19
+ __all__ = 'Detect', 'Segment', 'Pose', 'Classify', 'RTDETRDecoder'
20
+
21
+
22
+ class Detect(nn.Module):
23
+ """YOLOv8 Detect head for detection models."""
24
+ dynamic = False # force grid reconstruction
25
+ export = False # export mode
26
+ shape = None
27
+ anchors = torch.empty(0) # init
28
+ strides = torch.empty(0) # init
29
+
30
+ def __init__(self, nc=80, ch=()): # detection layer
31
+ super().__init__()
32
+ self.nc = nc # number of classes
33
+ self.nl = len(ch) # number of detection layers
34
+ self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
35
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
36
+ self.stride = torch.zeros(self.nl) # strides computed during build
37
+ c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], self.nc) # channels
38
+ self.cv2 = nn.ModuleList(
39
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
40
+ self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
41
+ self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
42
+
43
+ def forward(self, x):
44
+ """Concatenates and returns predicted bounding boxes and class probabilities."""
45
+ shape = x[0].shape # BCHW
46
+ for i in range(self.nl):
47
+ x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
48
+ if self.training:
49
+ return x
50
+ elif self.dynamic or self.shape != shape:
51
+ self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
52
+ self.shape = shape
53
+
54
+ x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
55
+ if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'): # avoid TF FlexSplitV ops
56
+ box = x_cat[:, :self.reg_max * 4]
57
+ cls = x_cat[:, self.reg_max * 4:]
58
+ else:
59
+ box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
60
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
61
+ y = torch.cat((dbox, cls.sigmoid()), 1)
62
+ return y if self.export else (y, x)
63
+
64
+ def bias_init(self):
65
+ """Initialize Detect() biases, WARNING: requires stride availability."""
66
+ m = self # self.model[-1] # Detect() module
67
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
68
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
69
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
70
+ a[-1].bias.data[:] = 1.0 # box
71
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
72
+
73
+
74
+ class Segment(Detect):
75
+ """YOLOv8 Segment head for segmentation models."""
76
+
77
+ def __init__(self, nc=80, nm=32, npr=256, ch=()):
78
+ """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers."""
79
+ super().__init__(nc, ch)
80
+ self.nm = nm # number of masks
81
+ self.npr = npr # number of protos
82
+ self.proto = Proto(ch[0], self.npr, self.nm) # protos
83
+ self.detect = Detect.forward
84
+
85
+ c4 = max(ch[0] // 4, self.nm)
86
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
87
+
88
+ def forward(self, x):
89
+ """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
90
+ p = self.proto(x[0]) # mask protos
91
+ bs = p.shape[0] # batch size
92
+
93
+ mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
94
+ x = self.detect(self, x)
95
+ if self.training:
96
+ return x, mc, p
97
+ return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
98
+
99
+
100
+ class Pose(Detect):
101
+ """YOLOv8 Pose head for keypoints models."""
102
+
103
+ def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
104
+ """Initialize YOLO network with default parameters and Convolutional Layers."""
105
+ super().__init__(nc, ch)
106
+ self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
107
+ self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total
108
+ self.detect = Detect.forward
109
+
110
+ c4 = max(ch[0] // 4, self.nk)
111
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)
112
+
113
+ def forward(self, x):
114
+ """Perform forward pass through YOLO model and return predictions."""
115
+ bs = x[0].shape[0] # batch size
116
+ kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
117
+ x = self.detect(self, x)
118
+ if self.training:
119
+ return x, kpt
120
+ pred_kpt = self.kpts_decode(bs, kpt)
121
+ return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))
122
+
123
+ def kpts_decode(self, bs, kpts):
124
+ """Decodes keypoints."""
125
+ ndim = self.kpt_shape[1]
126
+ if self.export: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug
127
+ y = kpts.view(bs, *self.kpt_shape, -1)
128
+ a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
129
+ if ndim == 3:
130
+ a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
131
+ return a.view(bs, self.nk, -1)
132
+ else:
133
+ y = kpts.clone()
134
+ if ndim == 3:
135
+ y[:, 2::3].sigmoid_() # inplace sigmoid
136
+ y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
137
+ y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
138
+ return y
139
+
140
+
141
+ class Classify(nn.Module):
142
+ """YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)."""
143
+
144
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
145
+ super().__init__()
146
+ c_ = 1280 # efficientnet_b0 size
147
+ self.conv = Conv(c1, c_, k, s, p, g)
148
+ self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1)
149
+ self.drop = nn.Dropout(p=0.0, inplace=True)
150
+ self.linear = nn.Linear(c_, c2) # to x(b,c2)
151
+
152
+ def forward(self, x):
153
+ """Performs a forward pass of the YOLO model on input image data."""
154
+ if isinstance(x, list):
155
+ x = torch.cat(x, 1)
156
+ x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
157
+ return x if self.training else x.softmax(1)
158
+
159
+
160
+ class RTDETRDecoder(nn.Module):
161
+
162
+ def __init__(
163
+ self,
164
+ nc=80,
165
+ ch=(512, 1024, 2048),
166
+ hd=256, # hidden dim
167
+ nq=300, # num queries
168
+ ndp=4, # num decoder points
169
+ nh=8, # num head
170
+ ndl=6, # num decoder layers
171
+ d_ffn=1024, # dim of feedforward
172
+ dropout=0.,
173
+ act=nn.ReLU(),
174
+ eval_idx=-1,
175
+ # training args
176
+ nd=100, # num denoising
177
+ label_noise_ratio=0.5,
178
+ box_noise_scale=1.0,
179
+ learnt_init_query=False):
180
+ super().__init__()
181
+ self.hidden_dim = hd
182
+ self.nhead = nh
183
+ self.nl = len(ch) # num level
184
+ self.nc = nc
185
+ self.num_queries = nq
186
+ self.num_decoder_layers = ndl
187
+
188
+ # backbone feature projection
189
+ self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)
190
+ # NOTE: simplified version but it's not consistent with .pt weights.
191
+ # self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch)
192
+
193
+ # Transformer module
194
+ decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)
195
+ self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)
196
+
197
+ # denoising part
198
+ self.denoising_class_embed = nn.Embedding(nc, hd)
199
+ self.num_denoising = nd
200
+ self.label_noise_ratio = label_noise_ratio
201
+ self.box_noise_scale = box_noise_scale
202
+
203
+ # decoder embedding
204
+ self.learnt_init_query = learnt_init_query
205
+ if learnt_init_query:
206
+ self.tgt_embed = nn.Embedding(nq, hd)
207
+ self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)
208
+
209
+ # encoder head
210
+ self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))
211
+ self.enc_score_head = nn.Linear(hd, nc)
212
+ self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)
213
+
214
+ # decoder head
215
+ self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])
216
+ self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])
217
+
218
+ self._reset_parameters()
219
+
220
+ def forward(self, x, batch=None):
221
+ from ultralytics.vit.utils.ops import get_cdn_group
222
+
223
+ # input projection and embedding
224
+ feats, shapes = self._get_encoder_input(x)
225
+
226
+ # prepare denoising training
227
+ dn_embed, dn_bbox, attn_mask, dn_meta = \
228
+ get_cdn_group(batch,
229
+ self.nc,
230
+ self.num_queries,
231
+ self.denoising_class_embed.weight,
232
+ self.num_denoising,
233
+ self.label_noise_ratio,
234
+ self.box_noise_scale,
235
+ self.training)
236
+
237
+ embed, refer_bbox, enc_bboxes, enc_scores = \
238
+ self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)
239
+
240
+ # decoder
241
+ dec_bboxes, dec_scores = self.decoder(embed,
242
+ refer_bbox,
243
+ feats,
244
+ shapes,
245
+ self.dec_bbox_head,
246
+ self.dec_score_head,
247
+ self.query_pos_head,
248
+ attn_mask=attn_mask)
249
+ if not self.training:
250
+ dec_scores = dec_scores.sigmoid_()
251
+ return dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
252
+
253
+ def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
254
+ anchors = []
255
+ for i, (h, w) in enumerate(shapes):
256
+ grid_y, grid_x = torch.meshgrid(torch.arange(end=h, dtype=dtype, device=device),
257
+ torch.arange(end=w, dtype=dtype, device=device),
258
+ indexing='ij')
259
+ grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2)
260
+
261
+ valid_WH = torch.tensor([h, w], dtype=dtype, device=device)
262
+ grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH # (1, h, w, 2)
263
+ wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0 ** i)
264
+ anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4)) # (1, h*w, 4)
265
+
266
+ anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4)
267
+ valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1
268
+ anchors = torch.log(anchors / (1 - anchors))
269
+ anchors = torch.where(valid_mask, anchors, torch.inf)
270
+ return anchors, valid_mask
271
+
272
+ def _get_encoder_input(self, x):
273
+ # get projection features
274
+ x = [self.input_proj[i](feat) for i, feat in enumerate(x)]
275
+ # get encoder inputs
276
+ feats = []
277
+ shapes = []
278
+ for feat in x:
279
+ h, w = feat.shape[2:]
280
+ # [b, c, h, w] -> [b, h*w, c]
281
+ feats.append(feat.flatten(2).permute(0, 2, 1))
282
+ # [nl, 2]
283
+ shapes.append([h, w])
284
+
285
+ # [b, h*w, c]
286
+ feats = torch.cat(feats, 1)
287
+ return feats, shapes
288
+
289
+ def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
290
+ bs = len(feats)
291
+ # prepare input for decoder
292
+ anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
293
+ features = self.enc_output(torch.where(valid_mask, feats, 0)) # bs, h*w, 256
294
+
295
+ enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc)
296
+ # dynamic anchors + static content
297
+ enc_outputs_bboxes = self.enc_bbox_head(features) + anchors # (bs, h*w, 4)
298
+
299
+ # query selection
300
+ # (bs, num_queries)
301
+ topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
302
+ # (bs, num_queries)
303
+ batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)
304
+
305
+ # Unsigmoided
306
+ refer_bbox = enc_outputs_bboxes[batch_ind, topk_ind].view(bs, self.num_queries, -1)
307
+ # refer_bbox = torch.gather(enc_outputs_bboxes, 1, topk_ind.reshape(bs, self.num_queries).unsqueeze(-1).repeat(1, 1, 4))
308
+
309
+ enc_bboxes = refer_bbox.sigmoid()
310
+ if dn_bbox is not None:
311
+ refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)
312
+ if self.training:
313
+ refer_bbox = refer_bbox.detach()
314
+ enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)
315
+
316
+ if self.learnt_init_query:
317
+ embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
318
+ else:
319
+ embeddings = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
320
+ if self.training:
321
+ embeddings = embeddings.detach()
322
+ if dn_embed is not None:
323
+ embeddings = torch.cat([dn_embed, embeddings], 1)
324
+
325
+ return embeddings, refer_bbox, enc_bboxes, enc_scores
326
+
327
+ # TODO
328
+ def _reset_parameters(self):
329
+ # class and bbox head init
330
+ bias_cls = bias_init_with_prob(0.01) / 80 * self.nc
331
+ # NOTE: the weight initialization in `linear_init_` would cause NaN when training with custom datasets.
332
+ # linear_init_(self.enc_score_head)
333
+ constant_(self.enc_score_head.bias, bias_cls)
334
+ constant_(self.enc_bbox_head.layers[-1].weight, 0.)
335
+ constant_(self.enc_bbox_head.layers[-1].bias, 0.)
336
+ for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
337
+ # linear_init_(cls_)
338
+ constant_(cls_.bias, bias_cls)
339
+ constant_(reg_.layers[-1].weight, 0.)
340
+ constant_(reg_.layers[-1].bias, 0.)
341
+
342
+ linear_init_(self.enc_output[0])
343
+ xavier_uniform_(self.enc_output[0].weight)
344
+ if self.learnt_init_query:
345
+ xavier_uniform_(self.tgt_embed.weight)
346
+ xavier_uniform_(self.query_pos_head.layers[0].weight)
347
+ xavier_uniform_(self.query_pos_head.layers[1].weight)
348
+ for layer in self.input_proj:
349
+ xavier_uniform_(layer[0].weight)
ultralytics/nn/modules/transformer.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Transformer modules
4
+ """
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.nn.init import constant_, xavier_uniform_
12
+
13
+ from .conv import Conv
14
+ from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch
15
+
16
+ __all__ = ('TransformerEncoderLayer', 'TransformerLayer', 'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'AIFI',
17
+ 'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP')
18
+
19
+
20
+ class TransformerEncoderLayer(nn.Module):
21
+ """Transformer Encoder."""
22
+
23
+ def __init__(self, c1, cm=2048, num_heads=8, dropout=0.0, act=nn.GELU(), normalize_before=False):
24
+ super().__init__()
25
+ self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
26
+ # Implementation of Feedforward model
27
+ self.fc1 = nn.Linear(c1, cm)
28
+ self.fc2 = nn.Linear(cm, c1)
29
+
30
+ self.norm1 = nn.LayerNorm(c1)
31
+ self.norm2 = nn.LayerNorm(c1)
32
+ self.dropout = nn.Dropout(dropout)
33
+ self.dropout1 = nn.Dropout(dropout)
34
+ self.dropout2 = nn.Dropout(dropout)
35
+
36
+ self.act = act
37
+ self.normalize_before = normalize_before
38
+
39
+ def with_pos_embed(self, tensor, pos=None):
40
+ """Add position embeddings if given."""
41
+ return tensor if pos is None else tensor + pos
42
+
43
+ def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
44
+ q = k = self.with_pos_embed(src, pos)
45
+ src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
46
+ src = src + self.dropout1(src2)
47
+ src = self.norm1(src)
48
+ src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
49
+ src = src + self.dropout2(src2)
50
+ src = self.norm2(src)
51
+ return src
52
+
53
+ def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
54
+ src2 = self.norm1(src)
55
+ q = k = self.with_pos_embed(src2, pos)
56
+ src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
57
+ src = src + self.dropout1(src2)
58
+ src2 = self.norm2(src)
59
+ src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
60
+ src = src + self.dropout2(src2)
61
+ return src
62
+
63
+ def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
64
+ """Forward propagates the input through the encoder module."""
65
+ if self.normalize_before:
66
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
67
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
68
+
69
+
70
+ class AIFI(TransformerEncoderLayer):
71
+
72
+ def __init__(self, c1, cm=2048, num_heads=8, dropout=0, act=nn.GELU(), normalize_before=False):
73
+ super().__init__(c1, cm, num_heads, dropout, act, normalize_before)
74
+
75
+ def forward(self, x):
76
+ c, h, w = x.shape[1:]
77
+ pos_embed = self.build_2d_sincos_position_embedding(w, h, c)
78
+ # flatten [B, C, H, W] to [B, HxW, C]
79
+ x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype))
80
+ return x.permute(0, 2, 1).view([-1, c, h, w]).contiguous()
81
+
82
+ @staticmethod
83
+ def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.):
84
+ grid_w = torch.arange(int(w), dtype=torch.float32)
85
+ grid_h = torch.arange(int(h), dtype=torch.float32)
86
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')
87
+ assert embed_dim % 4 == 0, \
88
+ 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
89
+ pos_dim = embed_dim // 4
90
+ omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
91
+ omega = 1. / (temperature ** omega)
92
+
93
+ out_w = grid_w.flatten()[..., None] @ omega[None]
94
+ out_h = grid_h.flatten()[..., None] @ omega[None]
95
+
96
+ return torch.concat([torch.sin(out_w), torch.cos(out_w),
97
+ torch.sin(out_h), torch.cos(out_h)], axis=1)[None, :, :]
98
+
99
+
100
+ class TransformerLayer(nn.Module):
101
+ """Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)."""
102
+
103
+ def __init__(self, c, num_heads):
104
+ """Initializes a self-attention mechanism using linear transformations and multi-head attention."""
105
+ super().__init__()
106
+ self.q = nn.Linear(c, c, bias=False)
107
+ self.k = nn.Linear(c, c, bias=False)
108
+ self.v = nn.Linear(c, c, bias=False)
109
+ self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
110
+ self.fc1 = nn.Linear(c, c, bias=False)
111
+ self.fc2 = nn.Linear(c, c, bias=False)
112
+
113
+ def forward(self, x):
114
+ """Apply a transformer block to the input x and return the output."""
115
+ x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
116
+ x = self.fc2(self.fc1(x)) + x
117
+ return x
118
+
119
+
120
+ class TransformerBlock(nn.Module):
121
+ """Vision Transformer https://arxiv.org/abs/2010.11929."""
122
+
123
+ def __init__(self, c1, c2, num_heads, num_layers):
124
+ """Initialize a Transformer module with position embedding and specified number of heads and layers."""
125
+ super().__init__()
126
+ self.conv = None
127
+ if c1 != c2:
128
+ self.conv = Conv(c1, c2)
129
+ self.linear = nn.Linear(c2, c2) # learnable position embedding
130
+ self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
131
+ self.c2 = c2
132
+
133
+ def forward(self, x):
134
+ """Forward propagates the input through the bottleneck module."""
135
+ if self.conv is not None:
136
+ x = self.conv(x)
137
+ b, _, w, h = x.shape
138
+ p = x.flatten(2).permute(2, 0, 1)
139
+ return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
140
+
141
+
142
+ class MLPBlock(nn.Module):
143
+
144
+ def __init__(self, embedding_dim, mlp_dim, act=nn.GELU):
145
+ super().__init__()
146
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
147
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
148
+ self.act = act()
149
+
150
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
151
+ return self.lin2(self.act(self.lin1(x)))
152
+
153
+
154
+ class MLP(nn.Module):
155
+ """ Very simple multi-layer perceptron (also called FFN)"""
156
+
157
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
158
+ super().__init__()
159
+ self.num_layers = num_layers
160
+ h = [hidden_dim] * (num_layers - 1)
161
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
162
+
163
+ def forward(self, x):
164
+ for i, layer in enumerate(self.layers):
165
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
166
+ return x
167
+
168
+
169
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
170
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
171
+ class LayerNorm2d(nn.Module):
172
+
173
+ def __init__(self, num_channels, eps=1e-6):
174
+ super().__init__()
175
+ self.weight = nn.Parameter(torch.ones(num_channels))
176
+ self.bias = nn.Parameter(torch.zeros(num_channels))
177
+ self.eps = eps
178
+
179
+ def forward(self, x):
180
+ u = x.mean(1, keepdim=True)
181
+ s = (x - u).pow(2).mean(1, keepdim=True)
182
+ x = (x - u) / torch.sqrt(s + self.eps)
183
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
184
+ return x
185
+
186
+
187
+ class MSDeformAttn(nn.Module):
188
+ """
189
+ Original Multi-Scale Deformable Attention Module.
190
+ https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py
191
+ """
192
+
193
+ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
194
+ super().__init__()
195
+ if d_model % n_heads != 0:
196
+ raise ValueError(f'd_model must be divisible by n_heads, but got {d_model} and {n_heads}')
197
+ _d_per_head = d_model // n_heads
198
+ # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
199
+ assert _d_per_head * n_heads == d_model, '`d_model` must be divisible by `n_heads`'
200
+
201
+ self.im2col_step = 64
202
+
203
+ self.d_model = d_model
204
+ self.n_levels = n_levels
205
+ self.n_heads = n_heads
206
+ self.n_points = n_points
207
+
208
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
209
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
210
+ self.value_proj = nn.Linear(d_model, d_model)
211
+ self.output_proj = nn.Linear(d_model, d_model)
212
+
213
+ self._reset_parameters()
214
+
215
+ def _reset_parameters(self):
216
+ constant_(self.sampling_offsets.weight.data, 0.)
217
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
218
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
219
+ grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(
220
+ 1, self.n_levels, self.n_points, 1)
221
+ for i in range(self.n_points):
222
+ grid_init[:, :, i, :] *= i + 1
223
+ with torch.no_grad():
224
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
225
+ constant_(self.attention_weights.weight.data, 0.)
226
+ constant_(self.attention_weights.bias.data, 0.)
227
+ xavier_uniform_(self.value_proj.weight.data)
228
+ constant_(self.value_proj.bias.data, 0.)
229
+ xavier_uniform_(self.output_proj.weight.data)
230
+ constant_(self.output_proj.bias.data, 0.)
231
+
232
+ def forward(self, query, refer_bbox, value, value_shapes, value_mask=None):
233
+ """
234
+ https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
235
+ Args:
236
+ query (torch.Tensor): [bs, query_length, C]
237
+ refer_bbox (torch.Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
238
+ bottom-right (1, 1), including padding area
239
+ value (torch.Tensor): [bs, value_length, C]
240
+ value_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
241
+ value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
242
+
243
+ Returns:
244
+ output (Tensor): [bs, Length_{query}, C]
245
+ """
246
+ bs, len_q = query.shape[:2]
247
+ len_v = value.shape[1]
248
+ assert sum(s[0] * s[1] for s in value_shapes) == len_v
249
+
250
+ value = self.value_proj(value)
251
+ if value_mask is not None:
252
+ value = value.masked_fill(value_mask[..., None], float(0))
253
+ value = value.view(bs, len_v, self.n_heads, self.d_model // self.n_heads)
254
+ sampling_offsets = self.sampling_offsets(query).view(bs, len_q, self.n_heads, self.n_levels, self.n_points, 2)
255
+ attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, self.n_levels * self.n_points)
256
+ attention_weights = F.softmax(attention_weights, -1).view(bs, len_q, self.n_heads, self.n_levels, self.n_points)
257
+ # N, Len_q, n_heads, n_levels, n_points, 2
258
+ num_points = refer_bbox.shape[-1]
259
+ if num_points == 2:
260
+ offset_normalizer = torch.as_tensor(value_shapes, dtype=query.dtype, device=query.device).flip(-1)
261
+ add = sampling_offsets / offset_normalizer[None, None, None, :, None, :]
262
+ sampling_locations = refer_bbox[:, :, None, :, None, :] + add
263
+ elif num_points == 4:
264
+ add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5
265
+ sampling_locations = refer_bbox[:, :, None, :, None, :2] + add
266
+ else:
267
+ raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {num_points}.')
268
+ output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
269
+ output = self.output_proj(output)
270
+ return output
271
+
272
+
273
+ class DeformableTransformerDecoderLayer(nn.Module):
274
+ """
275
+ https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
276
+ https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py
277
+ """
278
+
279
+ def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0., act=nn.ReLU(), n_levels=4, n_points=4):
280
+ super().__init__()
281
+
282
+ # self attention
283
+ self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
284
+ self.dropout1 = nn.Dropout(dropout)
285
+ self.norm1 = nn.LayerNorm(d_model)
286
+
287
+ # cross attention
288
+ self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
289
+ self.dropout2 = nn.Dropout(dropout)
290
+ self.norm2 = nn.LayerNorm(d_model)
291
+
292
+ # ffn
293
+ self.linear1 = nn.Linear(d_model, d_ffn)
294
+ self.act = act
295
+ self.dropout3 = nn.Dropout(dropout)
296
+ self.linear2 = nn.Linear(d_ffn, d_model)
297
+ self.dropout4 = nn.Dropout(dropout)
298
+ self.norm3 = nn.LayerNorm(d_model)
299
+
300
+ @staticmethod
301
+ def with_pos_embed(tensor, pos):
302
+ return tensor if pos is None else tensor + pos
303
+
304
+ def forward_ffn(self, tgt):
305
+ tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt))))
306
+ tgt = tgt + self.dropout4(tgt2)
307
+ tgt = self.norm3(tgt)
308
+ return tgt
309
+
310
+ def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None):
311
+ # self attention
312
+ q = k = self.with_pos_embed(embed, query_pos)
313
+ tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1),
314
+ attn_mask=attn_mask)[0].transpose(0, 1)
315
+ embed = embed + self.dropout1(tgt)
316
+ embed = self.norm1(embed)
317
+
318
+ # cross attention
319
+ tgt = self.cross_attn(self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes,
320
+ padding_mask)
321
+ embed = embed + self.dropout2(tgt)
322
+ embed = self.norm2(embed)
323
+
324
+ # ffn
325
+ embed = self.forward_ffn(embed)
326
+
327
+ return embed
328
+
329
+
330
+ class DeformableTransformerDecoder(nn.Module):
331
+ """
332
+ https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
333
+ """
334
+
335
+ def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1):
336
+ super().__init__()
337
+ self.layers = _get_clones(decoder_layer, num_layers)
338
+ self.num_layers = num_layers
339
+ self.hidden_dim = hidden_dim
340
+ self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
341
+
342
+ def forward(
343
+ self,
344
+ embed, # decoder embeddings
345
+ refer_bbox, # anchor
346
+ feats, # image features
347
+ shapes, # feature shapes
348
+ bbox_head,
349
+ score_head,
350
+ pos_mlp,
351
+ attn_mask=None,
352
+ padding_mask=None):
353
+ output = embed
354
+ dec_bboxes = []
355
+ dec_cls = []
356
+ last_refined_bbox = None
357
+ refer_bbox = refer_bbox.sigmoid()
358
+ for i, layer in enumerate(self.layers):
359
+ output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox))
360
+
361
+ # refine bboxes, (bs, num_queries+num_denoising, 4)
362
+ refined_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(refer_bbox))
363
+
364
+ if self.training:
365
+ dec_cls.append(score_head[i](output))
366
+ if i == 0:
367
+ dec_bboxes.append(refined_bbox)
368
+ else:
369
+ dec_bboxes.append(torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(last_refined_bbox)))
370
+ elif i == self.eval_idx:
371
+ dec_cls.append(score_head[i](output))
372
+ dec_bboxes.append(refined_bbox)
373
+ break
374
+
375
+ last_refined_bbox = refined_bbox
376
+ refer_bbox = refined_bbox.detach() if self.training else refined_bbox
377
+
378
+ return torch.stack(dec_bboxes), torch.stack(dec_cls)
ultralytics/nn/modules/utils.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+ """
3
+ Module utils
4
+ """
5
+
6
+ import copy
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.nn.init import uniform_
14
+
15
+ __all__ = 'multi_scale_deformable_attn_pytorch', 'inverse_sigmoid'
16
+
17
+
18
+ def _get_clones(module, n):
19
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
20
+
21
+
22
+ def bias_init_with_prob(prior_prob=0.01):
23
+ """initialize conv/fc bias value according to a given probability value."""
24
+ return float(-np.log((1 - prior_prob) / prior_prob)) # return bias_init
25
+
26
+
27
+ def linear_init_(module):
28
+ bound = 1 / math.sqrt(module.weight.shape[0])
29
+ uniform_(module.weight, -bound, bound)
30
+ if hasattr(module, 'bias') and module.bias is not None:
31
+ uniform_(module.bias, -bound, bound)
32
+
33
+
34
+ def inverse_sigmoid(x, eps=1e-5):
35
+ x = x.clamp(min=0, max=1)
36
+ x1 = x.clamp(min=eps)
37
+ x2 = (1 - x).clamp(min=eps)
38
+ return torch.log(x1 / x2)
39
+
40
+
41
+ def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shapes: torch.Tensor,
42
+ sampling_locations: torch.Tensor,
43
+ attention_weights: torch.Tensor) -> torch.Tensor:
44
+ """
45
+ Multi-scale deformable attention.
46
+ https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py
47
+ """
48
+
49
+ bs, _, num_heads, embed_dims = value.shape
50
+ _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
51
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
52
+ sampling_grids = 2 * sampling_locations - 1
53
+ sampling_value_list = []
54
+ for level, (H_, W_) in enumerate(value_spatial_shapes):
55
+ # bs, H_*W_, num_heads, embed_dims ->
56
+ # bs, H_*W_, num_heads*embed_dims ->
57
+ # bs, num_heads*embed_dims, H_*W_ ->
58
+ # bs*num_heads, embed_dims, H_, W_
59
+ value_l_ = (value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_))
60
+ # bs, num_queries, num_heads, num_points, 2 ->
61
+ # bs, num_heads, num_queries, num_points, 2 ->
62
+ # bs*num_heads, num_queries, num_points, 2
63
+ sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
64
+ # bs*num_heads, embed_dims, num_queries, num_points
65
+ sampling_value_l_ = F.grid_sample(value_l_,
66
+ sampling_grid_l_,
67
+ mode='bilinear',
68
+ padding_mode='zeros',
69
+ align_corners=False)
70
+ sampling_value_list.append(sampling_value_l_)
71
+ # (bs, num_queries, num_heads, num_levels, num_points) ->
72
+ # (bs, num_heads, num_queries, num_levels, num_points) ->
73
+ # (bs, num_heads, 1, num_queries, num_levels*num_points)
74
+ attention_weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries,
75
+ num_levels * num_points)
76
+ output = ((torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(
77
+ bs, num_heads * embed_dims, num_queries))
78
+ return output.transpose(1, 2).contiguous()
ultralytics/nn/tasks.py ADDED
@@ -0,0 +1,780 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import contextlib
4
+ from copy import deepcopy
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
11
+ Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
12
+ Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
13
+ RTDETRDecoder, Segment)
14
+ from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
15
+ from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
16
+ from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
17
+ from ultralytics.yolo.utils.plotting import feature_visualization
18
+ from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
19
+ intersect_dicts, make_divisible, model_info, scale_img, time_sync)
20
+
21
+ try:
22
+ import thop
23
+ except ImportError:
24
+ thop = None
25
+
26
+
27
+ class BaseModel(nn.Module):
28
+ """
29
+ The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.
30
+ """
31
+
32
+ def forward(self, x, *args, **kwargs):
33
+ """
34
+ Forward pass of the model on a single scale.
35
+ Wrapper for `_forward_once` method.
36
+
37
+ Args:
38
+ x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.
39
+
40
+ Returns:
41
+ (torch.Tensor): The output of the network.
42
+ """
43
+ if isinstance(x, dict): # for cases of training and validating while training.
44
+ return self.loss(x, *args, **kwargs)
45
+ return self.predict(x, *args, **kwargs)
46
+
47
+ def predict(self, x, profile=False, visualize=False, augment=False):
48
+ """
49
+ Perform a forward pass through the network.
50
+
51
+ Args:
52
+ x (torch.Tensor): The input tensor to the model.
53
+ profile (bool): Print the computation time of each layer if True, defaults to False.
54
+ visualize (bool): Save the feature maps of the model if True, defaults to False.
55
+ augment (bool): Augment image during prediction, defaults to False.
56
+
57
+ Returns:
58
+ (torch.Tensor): The last output of the model.
59
+ """
60
+ if augment:
61
+ return self._predict_augment(x)
62
+ return self._predict_once(x, profile, visualize)
63
+
64
+ def _predict_once(self, x, profile=False, visualize=False):
65
+ """
66
+ Perform a forward pass through the network.
67
+
68
+ Args:
69
+ x (torch.Tensor): The input tensor to the model.
70
+ profile (bool): Print the computation time of each layer if True, defaults to False.
71
+ visualize (bool): Save the feature maps of the model if True, defaults to False.
72
+
73
+ Returns:
74
+ (torch.Tensor): The last output of the model.
75
+ """
76
+ y, dt = [], [] # outputs
77
+ for m in self.model:
78
+ if m.f != -1: # if not from previous layer
79
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
80
+ if profile:
81
+ self._profile_one_layer(m, x, dt)
82
+ x = m(x) # run
83
+ y.append(x if m.i in self.save else None) # save output
84
+ if visualize:
85
+ feature_visualization(x, m.type, m.i, save_dir=visualize)
86
+ return x
87
+
88
+ def _predict_augment(self, x):
89
+ """Perform augmentations on input image x and return augmented inference."""
90
+ LOGGER.warning(
91
+ f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
92
+ )
93
+ return self._predict_once(x)
94
+
95
+ def _profile_one_layer(self, m, x, dt):
96
+ """
97
+ Profile the computation time and FLOPs of a single layer of the model on a given input.
98
+ Appends the results to the provided list.
99
+
100
+ Args:
101
+ m (nn.Module): The layer to be profiled.
102
+ x (torch.Tensor): The input data to the layer.
103
+ dt (list): A list to store the computation time of the layer.
104
+
105
+ Returns:
106
+ None
107
+ """
108
+ c = m == self.model[-1] # is final layer, copy input as inplace fix
109
+ o = thop.profile(m, inputs=[x.clone() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
110
+ t = time_sync()
111
+ for _ in range(10):
112
+ m(x.clone() if c else x)
113
+ dt.append((time_sync() - t) * 100)
114
+ if m == self.model[0]:
115
+ LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
116
+ LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
117
+ if c:
118
+ LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
119
+
120
+ def fuse(self, verbose=True):
121
+ """
122
+ Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the
123
+ computation efficiency.
124
+
125
+ Returns:
126
+ (nn.Module): The fused model is returned.
127
+ """
128
+ if not self.is_fused():
129
+ for m in self.model.modules():
130
+ if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'):
131
+ if isinstance(m, Conv2):
132
+ m.fuse_convs()
133
+ m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
134
+ delattr(m, 'bn') # remove batchnorm
135
+ m.forward = m.forward_fuse # update forward
136
+ if isinstance(m, ConvTranspose) and hasattr(m, 'bn'):
137
+ m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
138
+ delattr(m, 'bn') # remove batchnorm
139
+ m.forward = m.forward_fuse # update forward
140
+ if isinstance(m, RepConv):
141
+ m.fuse_convs()
142
+ m.forward = m.forward_fuse # update forward
143
+ self.info(verbose=verbose)
144
+
145
+ return self
146
+
147
+ def is_fused(self, thresh=10):
148
+ """
149
+ Check if the model has less than a certain threshold of BatchNorm layers.
150
+
151
+ Args:
152
+ thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
153
+
154
+ Returns:
155
+ (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
156
+ """
157
+ bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
158
+ return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
159
+
160
+ def info(self, detailed=False, verbose=True, imgsz=640):
161
+ """
162
+ Prints model information
163
+
164
+ Args:
165
+ verbose (bool): if True, prints out the model information. Defaults to False
166
+ imgsz (int): the size of the image that the model will be trained on. Defaults to 640
167
+ """
168
+ return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz)
169
+
170
+ def _apply(self, fn):
171
+ """
172
+ `_apply()` is a function that applies a function to all the tensors in the model that are not
173
+ parameters or registered buffers
174
+
175
+ Args:
176
+ fn: the function to apply to the model
177
+
178
+ Returns:
179
+ A model that is a Detect() object.
180
+ """
181
+ self = super()._apply(fn)
182
+ m = self.model[-1] # Detect()
183
+ if isinstance(m, (Detect, Segment)):
184
+ m.stride = fn(m.stride)
185
+ m.anchors = fn(m.anchors)
186
+ m.strides = fn(m.strides)
187
+ return self
188
+
189
+ def load(self, weights, verbose=True):
190
+ """Load the weights into the model.
191
+
192
+ Args:
193
+ weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
194
+ verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
195
+ """
196
+ model = weights['model'] if isinstance(weights, dict) else weights # torchvision models are not dicts
197
+ csd = model.float().state_dict() # checkpoint state_dict as FP32
198
+ csd = intersect_dicts(csd, self.state_dict()) # intersect
199
+ self.load_state_dict(csd, strict=False) # load
200
+ if verbose:
201
+ LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
202
+
203
+ def loss(self, batch, preds=None):
204
+ """
205
+ Compute loss
206
+
207
+ Args:
208
+ batch (dict): Batch to compute loss on
209
+ preds (torch.Tensor | List[torch.Tensor]): Predictions.
210
+ """
211
+ if not hasattr(self, 'criterion'):
212
+ self.criterion = self.init_criterion()
213
+
214
+ preds = self.forward(batch['img']) if preds is None else preds
215
+ return self.criterion(preds, batch)
216
+
217
+ def init_criterion(self):
218
+ raise NotImplementedError('compute_loss() needs to be implemented by task heads')
219
+
220
+
221
+ class DetectionModel(BaseModel):
222
+ """YOLOv8 detection model."""
223
+
224
+ def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes
225
+ super().__init__()
226
+ self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
227
+
228
+ # Define model
229
+ ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
230
+ if nc and nc != self.yaml['nc']:
231
+ LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
232
+ self.yaml['nc'] = nc # override yaml value
233
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
234
+ self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
235
+ self.inplace = self.yaml.get('inplace', True)
236
+
237
+ # Build strides
238
+ m = self.model[-1] # Detect()
239
+ if isinstance(m, (Detect, Segment, Pose)):
240
+ s = 256 # 2x min stride
241
+ m.inplace = self.inplace
242
+ forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose)) else self.forward(x)
243
+ m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
244
+ self.stride = m.stride
245
+ m.bias_init() # only run once
246
+ else:
247
+ self.stride = torch.Tensor([32]) # default stride for i.e. RTDETR
248
+
249
+ # Init weights, biases
250
+ initialize_weights(self)
251
+ if verbose:
252
+ self.info()
253
+ LOGGER.info('')
254
+
255
+ def _predict_augment(self, x):
256
+ """Perform augmentations on input image x and return augmented inference and train outputs."""
257
+ img_size = x.shape[-2:] # height, width
258
+ s = [1, 0.83, 0.67] # scales
259
+ f = [None, 3, None] # flips (2-ud, 3-lr)
260
+ y = [] # outputs
261
+ for si, fi in zip(s, f):
262
+ xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
263
+ yi = super().predict(xi)[0] # forward
264
+ # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
265
+ yi = self._descale_pred(yi, fi, si, img_size)
266
+ y.append(yi)
267
+ y = self._clip_augmented(y) # clip augmented tails
268
+ return torch.cat(y, -1), None # augmented inference, train
269
+
270
+ @staticmethod
271
+ def _descale_pred(p, flips, scale, img_size, dim=1):
272
+ """De-scale predictions following augmented inference (inverse operation)."""
273
+ p[:, :4] /= scale # de-scale
274
+ x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)
275
+ if flips == 2:
276
+ y = img_size[0] - y # de-flip ud
277
+ elif flips == 3:
278
+ x = img_size[1] - x # de-flip lr
279
+ return torch.cat((x, y, wh, cls), dim)
280
+
281
+ def _clip_augmented(self, y):
282
+ """Clip YOLOv5 augmented inference tails."""
283
+ nl = self.model[-1].nl # number of detection layers (P3-P5)
284
+ g = sum(4 ** x for x in range(nl)) # grid points
285
+ e = 1 # exclude layer count
286
+ i = (y[0].shape[-1] // g) * sum(4 ** x for x in range(e)) # indices
287
+ y[0] = y[0][..., :-i] # large
288
+ i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
289
+ y[-1] = y[-1][..., i:] # small
290
+ return y
291
+
292
+ def init_criterion(self):
293
+ return v8DetectionLoss(self)
294
+
295
+
296
+ class SegmentationModel(DetectionModel):
297
+ """YOLOv8 segmentation model."""
298
+
299
+ def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
300
+ """Initialize YOLOv8 segmentation model with given config and parameters."""
301
+ super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
302
+
303
+ def init_criterion(self):
304
+ return v8SegmentationLoss(self)
305
+
306
+ def _predict_augment(self, x):
307
+ """Perform augmentations on input image x and return augmented inference."""
308
+ LOGGER.warning(
309
+ f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
310
+ )
311
+ return self._predict_once(x)
312
+
313
+
314
+ class PoseModel(DetectionModel):
315
+ """YOLOv8 pose model."""
316
+
317
+ def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
318
+ """Initialize YOLOv8 Pose model."""
319
+ if not isinstance(cfg, dict):
320
+ cfg = yaml_model_load(cfg) # load model YAML
321
+ if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg['kpt_shape']):
322
+ LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
323
+ cfg['kpt_shape'] = data_kpt_shape
324
+ super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
325
+
326
+ def init_criterion(self):
327
+ return v8PoseLoss(self)
328
+
329
+ def _predict_augment(self, x):
330
+ """Perform augmentations on input image x and return augmented inference."""
331
+ LOGGER.warning(
332
+ f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
333
+ )
334
+ return self._predict_once(x)
335
+
336
+
337
+ class ClassificationModel(BaseModel):
338
+ """YOLOv8 classification model."""
339
+
340
+ def __init__(self,
341
+ cfg=None,
342
+ model=None,
343
+ ch=3,
344
+ nc=None,
345
+ cutoff=10,
346
+ verbose=True): # yaml, model, channels, number of classes, cutoff index, verbose flag
347
+ super().__init__()
348
+ self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
349
+
350
+ def _from_detection_model(self, model, nc=1000, cutoff=10):
351
+ """Create a YOLOv5 classification model from a YOLOv5 detection model."""
352
+ from ultralytics.nn.autobackend import AutoBackend
353
+ if isinstance(model, AutoBackend):
354
+ model = model.model # unwrap DetectMultiBackend
355
+ model.model = model.model[:cutoff] # backbone
356
+ m = model.model[-1] # last layer
357
+ ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels # ch into module
358
+ c = Classify(ch, nc) # Classify()
359
+ c.i, c.f, c.type = m.i, m.f, 'models.common.Classify' # index, from, type
360
+ model.model[-1] = c # replace
361
+ self.model = model.model
362
+ self.stride = model.stride
363
+ self.save = []
364
+ self.nc = nc
365
+
366
+ def _from_yaml(self, cfg, ch, nc, verbose):
367
+ """Set YOLOv8 model configurations and define the model architecture."""
368
+ self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
369
+
370
+ # Define model
371
+ ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
372
+ if nc and nc != self.yaml['nc']:
373
+ LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
374
+ self.yaml['nc'] = nc # override yaml value
375
+ elif not nc and not self.yaml.get('nc', None):
376
+ raise ValueError('nc not specified. Must specify nc in model.yaml or function arguments.')
377
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
378
+ self.stride = torch.Tensor([1]) # no stride constraints
379
+ self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
380
+ self.info()
381
+
382
+ @staticmethod
383
+ def reshape_outputs(model, nc):
384
+ """Update a TorchVision classification model to class count 'n' if required."""
385
+ name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
386
+ if isinstance(m, Classify): # YOLO Classify() head
387
+ if m.linear.out_features != nc:
388
+ m.linear = nn.Linear(m.linear.in_features, nc)
389
+ elif isinstance(m, nn.Linear): # ResNet, EfficientNet
390
+ if m.out_features != nc:
391
+ setattr(model, name, nn.Linear(m.in_features, nc))
392
+ elif isinstance(m, nn.Sequential):
393
+ types = [type(x) for x in m]
394
+ if nn.Linear in types:
395
+ i = types.index(nn.Linear) # nn.Linear index
396
+ if m[i].out_features != nc:
397
+ m[i] = nn.Linear(m[i].in_features, nc)
398
+ elif nn.Conv2d in types:
399
+ i = types.index(nn.Conv2d) # nn.Conv2d index
400
+ if m[i].out_channels != nc:
401
+ m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
402
+
403
+ def init_criterion(self):
404
+ """Compute the classification loss between predictions and true labels."""
405
+ return v8ClassificationLoss()
406
+
407
+
408
+ class RTDETRDetectionModel(DetectionModel):
409
+
410
+ def __init__(self, cfg='rtdetr-l.yaml', ch=3, nc=None, verbose=True):
411
+ super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
412
+
413
+ def init_criterion(self):
414
+ """Compute the classification loss between predictions and true labels."""
415
+ from ultralytics.vit.utils.loss import RTDETRDetectionLoss
416
+
417
+ return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
418
+
419
+ def loss(self, batch, preds=None):
420
+ if not hasattr(self, 'criterion'):
421
+ self.criterion = self.init_criterion()
422
+
423
+ img = batch['img']
424
+ # NOTE: preprocess gt_bbox and gt_labels to list.
425
+ bs = len(img)
426
+ batch_idx = batch['batch_idx']
427
+ gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
428
+ targets = {
429
+ 'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1),
430
+ 'bboxes': batch['bboxes'].to(device=img.device),
431
+ 'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1),
432
+ 'gt_groups': gt_groups}
433
+
434
+ preds = self.predict(img, batch=targets) if preds is None else preds
435
+ dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds
436
+ if dn_meta is None:
437
+ dn_bboxes, dn_scores = None, None
438
+ else:
439
+ dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2)
440
+ dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2)
441
+
442
+ dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4)
443
+ dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])
444
+
445
+ loss = self.criterion((dec_bboxes, dec_scores),
446
+ targets,
447
+ dn_bboxes=dn_bboxes,
448
+ dn_scores=dn_scores,
449
+ dn_meta=dn_meta)
450
+ # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.
451
+ return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
452
+ device=img.device)
453
+
454
+ def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
455
+ """
456
+ Perform a forward pass through the network.
457
+
458
+ Args:
459
+ x (torch.Tensor): The input tensor to the model
460
+ profile (bool): Print the computation time of each layer if True, defaults to False.
461
+ visualize (bool): Save the feature maps of the model if True, defaults to False
462
+ batch (dict): A dict including gt boxes and labels from dataloader.
463
+
464
+ Returns:
465
+ (torch.Tensor): The last output of the model.
466
+ """
467
+ y, dt = [], [] # outputs
468
+ for m in self.model[:-1]: # except the head part
469
+ if m.f != -1: # if not from previous layer
470
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
471
+ if profile:
472
+ self._profile_one_layer(m, x, dt)
473
+ x = m(x) # run
474
+ y.append(x if m.i in self.save else None) # save output
475
+ if visualize:
476
+ feature_visualization(x, m.type, m.i, save_dir=visualize)
477
+ head = self.model[-1]
478
+ x = head([y[j] for j in head.f], batch) # head inference
479
+ return x
480
+
481
+
482
+ class Ensemble(nn.ModuleList):
483
+ """Ensemble of models."""
484
+
485
+ def __init__(self):
486
+ """Initialize an ensemble of models."""
487
+ super().__init__()
488
+
489
+ def forward(self, x, augment=False, profile=False, visualize=False):
490
+ """Function generates the YOLOv5 network's final layer."""
491
+ y = [module(x, augment, profile, visualize)[0] for module in self]
492
+ # y = torch.stack(y).max(0)[0] # max ensemble
493
+ # y = torch.stack(y).mean(0) # mean ensemble
494
+ y = torch.cat(y, 2) # nms ensemble, y shape(B, HW, C)
495
+ return y, None # inference, train output
496
+
497
+
498
+ # Functions ------------------------------------------------------------------------------------------------------------
499
+
500
+
501
+ def torch_safe_load(weight):
502
+ """
503
+ This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised,
504
+ it catches the error, logs a warning message, and attempts to install the missing module via the
505
+ check_requirements() function. After installation, the function again attempts to load the model using torch.load().
506
+
507
+ Args:
508
+ weight (str): The file path of the PyTorch model.
509
+
510
+ Returns:
511
+ (dict): The loaded PyTorch model.
512
+ """
513
+ from ultralytics.yolo.utils.downloads import attempt_download_asset
514
+
515
+ check_suffix(file=weight, suffix='.pt')
516
+ file = attempt_download_asset(weight) # search online if missing locally
517
+ try:
518
+ return torch.load(file, map_location='cpu'), file # load
519
+ except ModuleNotFoundError as e: # e.name is missing module name
520
+ if e.name == 'models':
521
+ raise TypeError(
522
+ emojis(f'ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained '
523
+ f'with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with '
524
+ f'YOLOv8 at https://github.com/ultralytics/ultralytics.'
525
+ f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
526
+ f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")) from e
527
+ LOGGER.warning(f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
528
+ f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
529
+ f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
530
+ f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")
531
+ check_requirements(e.name) # install missing module
532
+
533
+ return torch.load(file, map_location='cpu'), file # load
534
+
535
+
536
+ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
537
+ """Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
538
+
539
+ ensemble = Ensemble()
540
+ for w in weights if isinstance(weights, list) else [weights]:
541
+ ckpt, w = torch_safe_load(w) # load ckpt
542
+ args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} if 'train_args' in ckpt else None # combined args
543
+ model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
544
+
545
+ # Model compatibility updates
546
+ model.args = args # attach args to model
547
+ model.pt_path = w # attach *.pt file path to model
548
+ model.task = guess_model_task(model)
549
+ if not hasattr(model, 'stride'):
550
+ model.stride = torch.tensor([32.])
551
+
552
+ # Append
553
+ ensemble.append(model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval()) # model in eval mode
554
+
555
+ # Module compatibility updates
556
+ for m in ensemble.modules():
557
+ t = type(m)
558
+ if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
559
+ m.inplace = inplace # torch 1.7.0 compatibility
560
+ elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
561
+ m.recompute_scale_factor = None # torch 1.11.0 compatibility
562
+
563
+ # Return model
564
+ if len(ensemble) == 1:
565
+ return ensemble[-1]
566
+
567
+ # Return ensemble
568
+ LOGGER.info(f'Ensemble created with {weights}\n')
569
+ for k in 'names', 'nc', 'yaml':
570
+ setattr(ensemble, k, getattr(ensemble[0], k))
571
+ ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
572
+ assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts {[m.nc for m in ensemble]}'
573
+ return ensemble
574
+
575
+
576
+ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
577
+ """Loads a single model weights."""
578
+ ckpt, weight = torch_safe_load(weight) # load ckpt
579
+ args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))} # combine model and default args, preferring model args
580
+ model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
581
+
582
+ # Model compatibility updates
583
+ model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
584
+ model.pt_path = weight # attach *.pt file path to model
585
+ model.task = guess_model_task(model)
586
+ if not hasattr(model, 'stride'):
587
+ model.stride = torch.tensor([32.])
588
+
589
+ model = model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval() # model in eval mode
590
+
591
+ # Module compatibility updates
592
+ for m in model.modules():
593
+ t = type(m)
594
+ if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
595
+ m.inplace = inplace # torch 1.7.0 compatibility
596
+ elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
597
+ m.recompute_scale_factor = None # torch 1.11.0 compatibility
598
+
599
+ # Return model and ckpt
600
+ return model, ckpt
601
+
602
+
603
+ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
604
+ # Parse a YOLO model.yaml dictionary into a PyTorch model
605
+ import ast
606
+
607
+ # Args
608
+ max_channels = float('inf')
609
+ nc, act, scales = (d.get(x) for x in ('nc', 'activation', 'scales'))
610
+ depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape'))
611
+ if scales:
612
+ scale = d.get('scale')
613
+ if not scale:
614
+ scale = tuple(scales.keys())[0]
615
+ LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
616
+ depth, width, max_channels = scales[scale]
617
+
618
+ if act:
619
+ Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
620
+ if verbose:
621
+ LOGGER.info(f"{colorstr('activation:')} {act}") # print
622
+
623
+ if verbose:
624
+ LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
625
+ ch = [ch]
626
+ layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
627
+ for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
628
+ m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m] # get module
629
+ for j, a in enumerate(args):
630
+ if isinstance(a, str):
631
+ with contextlib.suppress(ValueError):
632
+ args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
633
+
634
+ n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
635
+ if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
636
+ BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3):
637
+ c1, c2 = ch[f], args[0]
638
+ if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
639
+ c2 = make_divisible(min(c2, max_channels) * width, 8)
640
+
641
+ args = [c1, c2, *args[1:]]
642
+ if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x, RepC3):
643
+ args.insert(2, n) # number of repeats
644
+ n = 1
645
+ elif m is AIFI:
646
+ args = [ch[f], *args]
647
+ elif m in (HGStem, HGBlock):
648
+ c1, cm, c2 = ch[f], args[0], args[1]
649
+ args = [c1, cm, c2, *args[2:]]
650
+ if m is HGBlock:
651
+ args.insert(4, n) # number of repeats
652
+ n = 1
653
+
654
+ elif m is nn.BatchNorm2d:
655
+ args = [ch[f]]
656
+ elif m is Concat:
657
+ c2 = sum(ch[x] for x in f)
658
+ elif m in (Detect, Segment, Pose, RTDETRDecoder):
659
+ args.append([ch[x] for x in f])
660
+ if m is Segment:
661
+ args[2] = make_divisible(min(args[2], max_channels) * width, 8)
662
+ else:
663
+ c2 = ch[f]
664
+
665
+ m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
666
+ t = str(m)[8:-2].replace('__main__.', '') # module type
667
+ m.np = sum(x.numel() for x in m_.parameters()) # number params
668
+ m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
669
+ if verbose:
670
+ LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}') # print
671
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
672
+ layers.append(m_)
673
+ if i == 0:
674
+ ch = []
675
+ ch.append(c2)
676
+ return nn.Sequential(*layers), sorted(save)
677
+
678
+
679
+ def yaml_model_load(path):
680
+ """Load a YOLOv8 model from a YAML file."""
681
+ import re
682
+
683
+ path = Path(path)
684
+ if path.stem in (f'yolov{d}{x}6' for x in 'nsmlx' for d in (5, 8)):
685
+ new_stem = re.sub(r'(\d+)([nslmx])6(.+)?$', r'\1\2-p6\3', path.stem)
686
+ LOGGER.warning(f'WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.')
687
+ path = path.with_stem(new_stem)
688
+
689
+ unified_path = re.sub(r'(\d+)([nslmx])(.+)?$', r'\1\3', str(path)) # i.e. yolov8x.yaml -> yolov8.yaml
690
+ yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
691
+ d = yaml_load(yaml_file) # model dict
692
+ d['scale'] = guess_model_scale(path)
693
+ d['yaml_file'] = str(path)
694
+ return d
695
+
696
+
697
+ def guess_model_scale(model_path):
698
+ """
699
+ Takes a path to a YOLO model's YAML file as input and extracts the size character of the model's scale.
700
+ The function uses regular expression matching to find the pattern of the model scale in the YAML file name,
701
+ which is denoted by n, s, m, l, or x. The function returns the size character of the model scale as a string.
702
+
703
+ Args:
704
+ model_path (str | Path): The path to the YOLO model's YAML file.
705
+
706
+ Returns:
707
+ (str): The size character of the model's scale, which can be n, s, m, l, or x.
708
+ """
709
+ with contextlib.suppress(AttributeError):
710
+ import re
711
+ return re.search(r'yolov\d+([nslmx])', Path(model_path).stem).group(1) # n, s, m, l, or x
712
+ return ''
713
+
714
+
715
+ def guess_model_task(model):
716
+ """
717
+ Guess the task of a PyTorch model from its architecture or configuration.
718
+
719
+ Args:
720
+ model (nn.Module | dict): PyTorch model or model configuration in YAML format.
721
+
722
+ Returns:
723
+ (str): Task of the model ('detect', 'segment', 'classify', 'pose').
724
+
725
+ Raises:
726
+ SyntaxError: If the task of the model could not be determined.
727
+ """
728
+
729
+ def cfg2task(cfg):
730
+ """Guess from YAML dictionary."""
731
+ m = cfg['head'][-1][-2].lower() # output module name
732
+ if m in ('classify', 'classifier', 'cls', 'fc'):
733
+ return 'classify'
734
+ if m == 'detect':
735
+ return 'detect'
736
+ if m == 'segment':
737
+ return 'segment'
738
+ if m == 'pose':
739
+ return 'pose'
740
+
741
+ # Guess from model cfg
742
+ if isinstance(model, dict):
743
+ with contextlib.suppress(Exception):
744
+ return cfg2task(model)
745
+
746
+ # Guess from PyTorch model
747
+ if isinstance(model, nn.Module): # PyTorch model
748
+ for x in 'model.args', 'model.model.args', 'model.model.model.args':
749
+ with contextlib.suppress(Exception):
750
+ return eval(x)['task']
751
+ for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
752
+ with contextlib.suppress(Exception):
753
+ return cfg2task(eval(x))
754
+
755
+ for m in model.modules():
756
+ if isinstance(m, Detect):
757
+ return 'detect'
758
+ elif isinstance(m, Segment):
759
+ return 'segment'
760
+ elif isinstance(m, Classify):
761
+ return 'classify'
762
+ elif isinstance(m, Pose):
763
+ return 'pose'
764
+
765
+ # Guess from model filename
766
+ if isinstance(model, (str, Path)):
767
+ model = Path(model)
768
+ if '-seg' in model.stem or 'segment' in model.parts:
769
+ return 'segment'
770
+ elif '-cls' in model.stem or 'classify' in model.parts:
771
+ return 'classify'
772
+ elif '-pose' in model.stem or 'pose' in model.parts:
773
+ return 'pose'
774
+ elif 'detect' in model.parts:
775
+ return 'detect'
776
+
777
+ # Unable to determine task from model
778
+ LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
779
+ "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify', or 'pose'.")
780
+ return 'detect' # assume detect