hysts HF staff commited on
Commit
aa3f835
1 Parent(s): c570d14
Files changed (1) hide show
  1. app.py +39 -39
app.py CHANGED
@@ -8,6 +8,7 @@ import json
8
  import os
9
  import pathlib
10
  import subprocess
 
11
  from typing import Callable
12
 
13
  # workaround for https://github.com/gradio-app/gradio/issues/483
@@ -22,6 +23,10 @@ import torchvision.transforms as T
22
 
23
  TOKEN = os.environ['TOKEN']
24
 
 
 
 
 
25
 
26
  def parse_args() -> argparse.Namespace:
27
  parser = argparse.ArgumentParser()
@@ -40,21 +45,40 @@ def parse_args() -> argparse.Namespace:
40
  return parser.parse_args()
41
 
42
 
43
- def download_sample_images() -> list[pathlib.Path]:
44
- image_dir = pathlib.Path('samples')
45
- image_dir.mkdir(exist_ok=True)
46
-
47
- dataset_repo = 'hysts/sample-images-TADNE'
48
- n_images = 36
49
- paths = []
50
- for index in range(n_images):
51
  path = huggingface_hub.hf_hub_download(dataset_repo,
52
- f'{index:02d}.jpg',
53
  repo_type='dataset',
54
- cache_dir=image_dir.as_posix(),
55
  use_auth_token=TOKEN)
56
- paths.append(pathlib.Path(path))
57
- return paths
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
 
60
  @torch.inference_mode()
@@ -75,40 +99,18 @@ def predict(image: PIL.Image.Image, score_threshold: float,
75
  return res
76
 
77
 
78
- def load_labels() -> list[str]:
79
- label_path = pathlib.Path('class_names_6000.json')
80
- label_url = 'https://raw.githubusercontent.com/RF5/danbooru-pretrained/master/config/class_names_6000.json'
81
- if not label_path.exists():
82
- torch.hub.download_url_to_file(label_url, label_path.as_posix())
83
- with open(label_path) as f:
84
- labels = json.load(f)
85
- return labels
86
-
87
-
88
  def main():
89
  gr.close_all()
90
 
91
  args = parse_args()
92
  device = torch.device(args.device)
93
 
94
- image_paths = download_sample_images()
95
  examples = [[path.as_posix(), args.score_threshold]
96
  for path in image_paths]
97
 
98
- if device.type == 'cpu':
99
- model_path = pathlib.Path('resnet50-13306192.pth')
100
- model_url = 'https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet50-13306192.pth'
101
- if not model_path.exists():
102
- torch.hub.download_url_to_file(model_url, model_path.as_posix())
103
- model = torch.hub.load('RF5/danbooru-pretrained',
104
- 'resnet50',
105
- pretrained=False)
106
- state_dict = torch.load(model_path, map_location=device)
107
- model.load_state_dict(state_dict)
108
- else:
109
- model = torch.hub.load('RF5/danbooru-pretrained', 'resnet50')
110
- model.to(device)
111
- model.eval()
112
 
113
  transform = T.Compose([
114
  T.Resize(360),
@@ -117,8 +119,6 @@ def main():
117
  std=[0.2970, 0.3017, 0.2979]),
118
  ])
119
 
120
- labels = load_labels()
121
-
122
  func = functools.partial(predict,
123
  transform=transform,
124
  device=device,
 
8
  import os
9
  import pathlib
10
  import subprocess
11
+ import tarfile
12
  from typing import Callable
13
 
14
  # workaround for https://github.com/gradio-app/gradio/issues/483
 
23
 
24
  TOKEN = os.environ['TOKEN']
25
 
26
+ MODEL_REPO = 'hysts/danbooru-pretrained'
27
+ MODEL_FILENAME = 'resnet50-13306192.pth'
28
+ LABEL_FILENAME = 'class_names_6000.json'
29
+
30
 
31
  def parse_args() -> argparse.Namespace:
32
  parser = argparse.ArgumentParser()
 
45
  return parser.parse_args()
46
 
47
 
48
+ def load_sample_image_paths() -> list[pathlib.Path]:
49
+ image_dir = pathlib.Path('images')
50
+ if not image_dir.exists():
51
+ dataset_repo = 'hysts/sample-images-TADNE'
 
 
 
 
52
  path = huggingface_hub.hf_hub_download(dataset_repo,
53
+ 'images.tar.gz',
54
  repo_type='dataset',
 
55
  use_auth_token=TOKEN)
56
+ with tarfile.open(path) as f:
57
+ f.extractall()
58
+ return sorted(image_dir.glob('*'))
59
+
60
+
61
+ def load_model(device: torch.device) -> torch.nn.Module:
62
+ path = huggingface_hub.hf_hub_download(MODEL_REPO,
63
+ MODEL_FILENAME,
64
+ use_auth_token=TOKEN)
65
+ state_dict = torch.load(path)
66
+ model = torch.hub.load('RF5/danbooru-pretrained',
67
+ 'resnet50',
68
+ pretrained=False)
69
+ model.load_state_dict(state_dict)
70
+ model.to(device)
71
+ model.eval()
72
+ return model
73
+
74
+
75
+ def load_labels() -> list[str]:
76
+ path = huggingface_hub.hf_hub_download(MODEL_REPO,
77
+ LABEL_FILENAME,
78
+ use_auth_token=TOKEN)
79
+ with open(path) as f:
80
+ labels = json.load(f)
81
+ return labels
82
 
83
 
84
  @torch.inference_mode()
 
99
  return res
100
 
101
 
 
 
 
 
 
 
 
 
 
 
102
  def main():
103
  gr.close_all()
104
 
105
  args = parse_args()
106
  device = torch.device(args.device)
107
 
108
+ image_paths = load_sample_image_paths()
109
  examples = [[path.as_posix(), args.score_threshold]
110
  for path in image_paths]
111
 
112
+ model = load_model(device)
113
+ labels = load_labels()
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  transform = T.Compose([
116
  T.Resize(360),
 
119
  std=[0.2970, 0.3017, 0.2979]),
120
  ])
121
 
 
 
122
  func = functools.partial(predict,
123
  transform=transform,
124
  device=device,