Samuel Stevens commited on
Commit
2cfb891
·
1 Parent(s): 290c238
Files changed (5) hide show
  1. README.md +2 -2
  2. app.py +169 -31
  3. lib.py +11 -7
  4. make_txt_embedding.py +46 -16
  5. txt_emb.npy +3 -0
README.md CHANGED
@@ -1,11 +1,11 @@
1
  ---
2
  title: Bioclip Demo
3
- emoji: 👀
4
  colorFrom: indigo
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.7.1
8
  app_file: app.py
9
- pinned: false
10
  license: mit
11
  ---
 
1
  ---
2
  title: Bioclip Demo
3
+ emoji: 🐘
4
  colorFrom: indigo
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.7.1
8
  app_file: app.py
9
+ pinned: true
10
  license: mit
11
  ---
app.py CHANGED
@@ -1,24 +1,29 @@
 
1
  import os
2
 
3
  import gradio as gr
 
4
  import torch
5
  import torch.nn.functional as F
6
  from open_clip import create_model, get_tokenizer
7
  from torchvision import transforms
8
 
 
9
  from templates import openai_imagenet_template
10
 
11
  hf_token = os.getenv("HF_TOKEN")
12
- hf_writer = gr.HuggingFaceDatasetSaver(hf_token, "bioclip-demo")
13
 
14
  model_str = "hf-hub:imageomics/bioclip"
15
  tokenizer_str = "ViT-B-16"
 
 
16
 
17
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
 
19
  preprocess_img = transforms.Compose(
20
  [
21
  transforms.ToTensor(),
 
22
  transforms.Normalize(
23
  mean=(0.48145466, 0.4578275, 0.40821073),
24
  std=(0.26862954, 0.26130258, 0.27577711),
@@ -26,6 +31,28 @@ preprocess_img = transforms.Compose(
26
  ]
27
  )
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  @torch.no_grad()
31
  def get_txt_features(classnames, templates):
@@ -42,8 +69,8 @@ def get_txt_features(classnames, templates):
42
 
43
 
44
  @torch.no_grad()
45
- def predict(img, classes: list[str]) -> dict[str, float]:
46
- classes = [cls.strip() for cls in classes if cls.strip()]
47
  txt_features = get_txt_features(classes, openai_imagenet_template)
48
 
49
  img = preprocess_img(img).to(device)
@@ -55,7 +82,8 @@ def predict(img, classes: list[str]) -> dict[str, float]:
55
  return {cls: prob for cls, prob in zip(classes, probs)}
56
 
57
 
58
- def hierarchical_predict(img) -> list[str]:
 
59
  """
60
  Predicts from the top of the tree of life down to the species.
61
  """
@@ -63,16 +91,44 @@ def hierarchical_predict(img) -> list[str]:
63
  img_features = model.encode_image(img.unsqueeze(0))
64
  img_features = F.normalize(img_features, dim=-1)
65
 
66
- breakpoint()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
 
69
- def run(img, cls_str: str) -> dict[str, float]:
70
- breakpoint()
71
- if cls_str:
72
- classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
73
- return predict(img, classes)
74
- else:
75
- return hierarchical_predict(img)
76
 
77
 
78
  if __name__ == "__main__":
@@ -86,22 +142,104 @@ if __name__ == "__main__":
86
 
87
  tokenizer = get_tokenizer(tokenizer_str)
88
 
89
- demo = gr.Interface(
90
- fn=run,
91
- inputs=[
92
- gr.Image(shape=(224, 224)),
93
- gr.Textbox(
94
- placeholder="dog\ncat\n...",
95
- lines=3,
96
- label="Classes",
97
- show_label=True,
98
- info="If empty, will predict from the entire tree of life.",
99
- ),
100
- ],
101
- outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
102
- allow_flagging="manual",
103
- flagging_options=["Incorrect", "Other"],
104
- flagging_callback=hf_writer,
105
- )
106
-
107
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
  import os
3
 
4
  import gradio as gr
5
+ import numpy as np
6
  import torch
7
  import torch.nn.functional as F
8
  from open_clip import create_model, get_tokenizer
9
  from torchvision import transforms
10
 
11
+ import lib
12
  from templates import openai_imagenet_template
13
 
14
  hf_token = os.getenv("HF_TOKEN")
 
15
 
16
  model_str = "hf-hub:imageomics/bioclip"
17
  tokenizer_str = "ViT-B-16"
18
+ name_lookup_json = "name_lookup.json"
19
+ txt_emb_npy = "txt_emb.npy"
20
 
21
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
 
23
  preprocess_img = transforms.Compose(
24
  [
25
  transforms.ToTensor(),
26
+ transforms.Resize((224, 224), antialias=True),
27
  transforms.Normalize(
28
  mean=(0.48145466, 0.4578275, 0.40821073),
29
  std=(0.26862954, 0.26130258, 0.27577711),
 
31
  ]
32
  )
33
 
34
+ ranks = ("Kingdom", "Phylum", "Class", "Order", "Family", "Genus", "Species")
35
+
36
+ open_domain_examples = [
37
+ ["examples/Ursus-arctos.jpeg", "Species"],
38
+ ["examples/Phoca-vitulina.png", "Species"],
39
+ ["examples/Felis-catus.jpeg", "Genus"],
40
+ ]
41
+ zero_shot_examples = [
42
+ [
43
+ "examples/Carnegiea-gigantea.png",
44
+ "Carnegiea gigantea\nSchlumbergera opuntioides\nMammillaria albicoma",
45
+ ],
46
+ [
47
+ "examples/Amanita-muscaria.jpeg",
48
+ "Amanita fulva\nAmanita vaginata (grisette)\nAmanita calyptrata (coccoli)\nAmanita crocea\nAmanita rubescens (blusher)\nAmanita caesarea (Caesar's mushroom)\nAmanita jacksonii (American Caesar's mushroom)\nAmanita muscaria (fly agaric)\nAmanita pantherina (panther cap)",
49
+ ],
50
+ [
51
+ "examples/Actinostola-abyssorum.png",
52
+ "Animalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola abyssorum\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola bulbosa\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola callosa\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola capensis\nAnimalia Cnidaria Hexacorallia Actiniaria Actinostolidae Actinostola carlgreni",
53
+ ],
54
+ ]
55
+
56
 
57
  @torch.no_grad()
58
  def get_txt_features(classnames, templates):
 
69
 
70
 
71
  @torch.no_grad()
72
+ def zero_shot_classification(img, cls_str: str) -> dict[str, float]:
73
+ classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
74
  txt_features = get_txt_features(classes, openai_imagenet_template)
75
 
76
  img = preprocess_img(img).to(device)
 
82
  return {cls: prob for cls, prob in zip(classes, probs)}
83
 
84
 
85
+ @torch.no_grad()
86
+ def open_domain_classification(img, rank: int) -> list[dict[str, float]]:
87
  """
88
  Predicts from the top of the tree of life down to the species.
89
  """
 
91
  img_features = model.encode_image(img.unsqueeze(0))
92
  img_features = F.normalize(img_features, dim=-1)
93
 
94
+ outputs = []
95
+
96
+ name = []
97
+ for _ in range(rank + 1):
98
+ children = tuple(zip(*name_lookup.children(name)))
99
+ if not children:
100
+ break
101
+ values, indices = children
102
+ txt_features = txt_emb[:, indices].to(device)
103
+ logits = (model.logit_scale.exp() * img_features @ txt_features).view(-1)
104
+
105
+ probs = F.softmax(logits, dim=0).to("cpu").tolist()
106
+ parent = " ".join(name)
107
+ outputs.append(
108
+ {f"{parent} {value}": prob for value, prob in zip(values, probs)}
109
+ )
110
+
111
+ top = values[logits.argmax()]
112
+ name.append(top)
113
+
114
+ while len(outputs) < 7:
115
+ outputs.append({})
116
+
117
+ return list(reversed(outputs))
118
+
119
+
120
+ def change_output(choice):
121
+ return [
122
+ gr.Label(
123
+ num_top_classes=5, label=rank, show_label=True, visible=(6 - i <= choice)
124
+ )
125
+ for i, rank in enumerate(reversed(ranks))
126
+ ]
127
 
128
 
129
+ def get_name_lookup(path):
130
+ with open(path) as fd:
131
+ return lib.TaxonomicTree.from_dict(json.load(fd))
 
 
 
 
132
 
133
 
134
  if __name__ == "__main__":
 
142
 
143
  tokenizer = get_tokenizer(tokenizer_str)
144
 
145
+ name_lookup = get_name_lookup(name_lookup_json)
146
+ txt_emb = torch.from_numpy(np.load(txt_emb_npy, mmap_mode="r"))
147
+
148
+ done = txt_emb.any(axis=0).sum().item()
149
+ total = txt_emb.shape[1]
150
+ status_msg = ""
151
+ if done != total:
152
+ status_msg = f"{done}/{total} ({done / total * 100:.1f}%) indexed"
153
+
154
+ with gr.Blocks() as app:
155
+ img_input = gr.Image()
156
+
157
+ with gr.Tab("Open-Ended"):
158
+ with gr.Row():
159
+ with gr.Column():
160
+ rank_dropdown = gr.Dropdown(
161
+ label="Taxonomic Rank",
162
+ info="Which taxonomic rank to predict. Fine-grained ranks (genus, species) are more challenging.",
163
+ choices=ranks,
164
+ value="Species",
165
+ type="index",
166
+ )
167
+ open_domain_btn = gr.Button("Submit", variant="primary")
168
+ gr.Examples(
169
+ examples=open_domain_examples,
170
+ inputs=[img_input, rank_dropdown],
171
+ )
172
+
173
+ with gr.Column():
174
+ open_domain_outputs = [
175
+ gr.Label(num_top_classes=5, label=rank, show_label=True)
176
+ for rank in reversed(ranks)
177
+ ]
178
+ open_domain_flag_btn = gr.Button("Flag Mistake", variant="primary")
179
+
180
+ open_domain_callback = gr.HuggingFaceDatasetSaver(
181
+ hf_token, "imageomics/bioclip-demo-open-domain-mistakes", private=True
182
+ )
183
+ open_domain_callback.setup(
184
+ [img_input, *open_domain_outputs], flagging_dir="logs/flagged"
185
+ )
186
+ open_domain_flag_btn.click(
187
+ lambda *args: open_domain_callback.flag(args),
188
+ [img_input, *open_domain_outputs],
189
+ None,
190
+ preprocess=False,
191
+ )
192
+
193
+ with gr.Tab("Zero-Shot"):
194
+ with gr.Row():
195
+ with gr.Column():
196
+ classes_txt = gr.Textbox(
197
+ placeholder="Canis familiaris (dog)\nFelis catus (cat)\n...",
198
+ lines=3,
199
+ label="Classes",
200
+ show_label=True,
201
+ info="Use taxonomic names where possible; include common names if possible.",
202
+ )
203
+ zero_shot_btn = gr.Button("Submit", variant="primary")
204
+ gr.Examples(
205
+ examples=zero_shot_examples,
206
+ inputs=[img_input, classes_txt],
207
+ )
208
+
209
+ with gr.Column():
210
+ zero_shot_output = gr.Label(
211
+ num_top_classes=5, label="Prediction", show_label=True
212
+ )
213
+ zero_shot_flag_btn = gr.Button("Flag Mistake", variant="primary")
214
+
215
+ zero_shot_callback = gr.HuggingFaceDatasetSaver(
216
+ hf_token, "imageomics/bioclip-demo-zero-shot-mistakes", private=True
217
+ )
218
+ zero_shot_callback.setup(
219
+ [img_input, zero_shot_output], flagging_dir="logs/flagged"
220
+ )
221
+ zero_shot_flag_btn.click(
222
+ lambda *args: zero_shot_callback.flag(args),
223
+ [img_input, zero_shot_output],
224
+ None,
225
+ preprocess=False,
226
+ )
227
+
228
+ rank_dropdown.change(
229
+ fn=change_output, inputs=rank_dropdown, outputs=open_domain_outputs
230
+ )
231
+
232
+ open_domain_btn.click(
233
+ fn=open_domain_classification,
234
+ inputs=[img_input, rank_dropdown],
235
+ outputs=open_domain_outputs,
236
+ )
237
+
238
+ zero_shot_btn.click(
239
+ fn=zero_shot_classification,
240
+ inputs=[img_input, classes_txt],
241
+ outputs=zero_shot_output,
242
+ )
243
+
244
+ app.queue(max_size=20)
245
+ app.launch()
lib.py CHANGED
@@ -1,5 +1,5 @@
1
- import json
2
  import itertools
 
3
 
4
 
5
  class TaxonomicNode:
@@ -43,11 +43,12 @@ class TaxonomicNode:
43
  @classmethod
44
  def from_dict(cls, dct, root):
45
  node = cls(dct["name"], dct["index"], root)
46
- node._children = {child["name"]: cls.from_dict(child, root) for child in dct["children"]}
 
 
47
  return node
48
 
49
 
50
-
51
  class TaxonomicTree:
52
  """
53
  Efficient structure for finding taxonomic names and their descendants.
@@ -85,11 +86,15 @@ class TaxonomicTree:
85
  for kingdom in self.kingdoms.values():
86
  yield from kingdom
87
 
 
 
 
88
  @classmethod
89
  def from_dict(cls, dct):
90
  tree = cls()
91
  tree.kingdoms = {
92
- kingdom["name"]: TaxonomicNode.from_dict(kingdom, tree) for kingdom in dct["kingdoms"]
 
93
  }
94
  tree.size = dct["size"]
95
  return tree
@@ -112,11 +117,10 @@ class TaxonomicJsonEncoder(json.JSONEncoder):
112
  super().default(self, obj)
113
 
114
 
115
-
116
  def batched(iterable, n):
117
  # batched('ABCDEFG', 3) --> ABC DEF G
118
  if n < 1:
119
- raise ValueError('n must be at least one')
120
  it = iter(iterable)
121
  while batch := tuple(itertools.islice(it, n)):
122
- yield zip(*batch)
 
 
1
  import itertools
2
+ import json
3
 
4
 
5
  class TaxonomicNode:
 
43
  @classmethod
44
  def from_dict(cls, dct, root):
45
  node = cls(dct["name"], dct["index"], root)
46
+ node._children = {
47
+ child["name"]: cls.from_dict(child, root) for child in dct["children"]
48
+ }
49
  return node
50
 
51
 
 
52
  class TaxonomicTree:
53
  """
54
  Efficient structure for finding taxonomic names and their descendants.
 
86
  for kingdom in self.kingdoms.values():
87
  yield from kingdom
88
 
89
+ def __len__(self):
90
+ return self.size
91
+
92
  @classmethod
93
  def from_dict(cls, dct):
94
  tree = cls()
95
  tree.kingdoms = {
96
+ kingdom["name"]: TaxonomicNode.from_dict(kingdom, tree)
97
+ for kingdom in dct["kingdoms"]
98
  }
99
  tree.size = dct["size"]
100
  return tree
 
117
  super().default(self, obj)
118
 
119
 
 
120
  def batched(iterable, n):
121
  # batched('ABCDEFG', 3) --> ABC DEF G
122
  if n < 1:
123
+ raise ValueError("n must be at least one")
124
  it = iter(iterable)
125
  while batch := tuple(itertools.islice(it, n)):
126
+ yield zip(*batch)
make_txt_embedding.py CHANGED
@@ -5,6 +5,7 @@ Uses the catalog.csv file from TreeOfLife-10M.
5
  import argparse
6
  import csv
7
  import json
 
8
 
9
  import numpy as np
10
  import torch
@@ -22,29 +23,53 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
22
 
23
  @torch.no_grad()
24
  def write_txt_features(name_lookup):
25
- all_features = np.memmap(
26
- args.out_path, dtype=np.float32, mode="w+", shape=(512, name_lookup.size)
27
- )
 
28
 
29
  batch_size = args.batch_size // len(openai_imagenet_template)
30
- for names, indices in tqdm(lib.batched(name_lookup, batch_size)):
31
- txts = [template(name) for name in names for template in openai_imagenet_template]
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  txts = tokenizer(txts).to(device)
33
  txt_features = model.encode_text(txts)
34
- txt_features = torch.reshape(txt_features, (batch_size, len(openai_imagenet_template), 512))
 
 
35
  txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
36
  txt_features /= txt_features.norm(dim=1, keepdim=True)
37
- all_features[:, indices] = txt_features.cpu().numpy().T
 
 
 
38
 
39
- all_features.flush()
40
 
41
 
42
- def get_name_lookup(catalog_path):
 
 
 
 
 
43
  lookup = lib.TaxonomicTree()
44
 
45
  with open(catalog_path) as fd:
46
  reader = csv.DictReader(fd)
47
- for row in tqdm(reader):
48
  name = [
49
  row["kingdom"],
50
  row["phylum"],
@@ -58,6 +83,9 @@ def get_name_lookup(catalog_path):
58
  name = name[: name.index("")]
59
  lookup.add(name)
60
 
 
 
 
61
  return lookup
62
 
63
 
@@ -69,15 +97,17 @@ if __name__ == "__main__":
69
  required=True,
70
  )
71
  parser.add_argument("--out-path", help="Path to the output file.", required=True)
72
- parser.add_argument("--name-cache-path", help="Path to the name cache file.", default=".name_lookup_cache.json")
73
- parser.add_argument("--batch-size", help="Batch size.", default=2 ** 15, type=int)
 
 
 
 
74
  args = parser.parse_args()
75
 
76
- name_lookup = get_name_lookup(args.catalog_path)
77
- with open(args.name_cache_path, "w") as fd:
78
- json.dump(name_lookup, fd, cls=lib.TaxonomicJsonEncoder)
79
 
80
- print("Starting.")
81
  model = create_model(model_str, output_dict=True, require_pretrained=True)
82
  model = model.to(device)
83
  print("Created model.")
 
5
  import argparse
6
  import csv
7
  import json
8
+ import os
9
 
10
  import numpy as np
11
  import torch
 
23
 
24
  @torch.no_grad()
25
  def write_txt_features(name_lookup):
26
+ if os.path.isfile(args.out_path):
27
+ all_features = np.load(args.out_path)
28
+ else:
29
+ all_features = np.zeros((512, len(name_lookup)), dtype=np.float32)
30
 
31
  batch_size = args.batch_size // len(openai_imagenet_template)
32
+ for batch, (names, indices) in enumerate(
33
+ tqdm(
34
+ lib.batched(name_lookup, batch_size),
35
+ desc="txt feats",
36
+ total=len(name_lookup) // batch_size,
37
+ )
38
+ ):
39
+ # Skip if any non-zero elements
40
+ if all_features[:, indices].any():
41
+ print(f"Skipping batch {batch}")
42
+ continue
43
+
44
+ txts = [
45
+ template(name) for name in names for template in openai_imagenet_template
46
+ ]
47
  txts = tokenizer(txts).to(device)
48
  txt_features = model.encode_text(txts)
49
+ txt_features = torch.reshape(
50
+ txt_features, (len(names), len(openai_imagenet_template), 512)
51
+ )
52
  txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
53
  txt_features /= txt_features.norm(dim=1, keepdim=True)
54
+ all_features[:, indices] = txt_features.T.cpu().numpy()
55
+
56
+ if batch % 100 == 0:
57
+ np.save(args.out_path, all_features)
58
 
59
+ np.save(args.out_path, all_features)
60
 
61
 
62
+ def get_name_lookup(catalog_path, cache_path):
63
+ if os.path.isfile(cache_path):
64
+ with open(cache_path) as fd:
65
+ lookup = lib.TaxonomicTree.from_dict(json.load(fd))
66
+ return lookup
67
+
68
  lookup = lib.TaxonomicTree()
69
 
70
  with open(catalog_path) as fd:
71
  reader = csv.DictReader(fd)
72
+ for row in tqdm(reader, desc="catalog"):
73
  name = [
74
  row["kingdom"],
75
  row["phylum"],
 
83
  name = name[: name.index("")]
84
  lookup.add(name)
85
 
86
+ with open(args.name_cache_path, "w") as fd:
87
+ json.dump(lookup, fd, cls=lib.TaxonomicJsonEncoder)
88
+
89
  return lookup
90
 
91
 
 
97
  required=True,
98
  )
99
  parser.add_argument("--out-path", help="Path to the output file.", required=True)
100
+ parser.add_argument(
101
+ "--name-cache-path",
102
+ help="Path to the name cache file.",
103
+ default="name_lookup.json",
104
+ )
105
+ parser.add_argument("--batch-size", help="Batch size.", default=2**15, type=int)
106
  args = parser.parse_args()
107
 
108
+ name_lookup = get_name_lookup(args.catalog_path, cache_path=args.name_cache_path)
109
+ print("Got name lookup.")
 
110
 
 
111
  model = create_model(model_str, output_dict=True, require_pretrained=True)
112
  model = model.to(device)
113
  print("Created model.")
txt_emb.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4a3c3412c3dae49cf92cc760aba5ee84227362adf1eb08f04dd50ee2a756e43
3
+ size 969818240