Samuel Stevens committed on
Commit
290c238
·
1 Parent(s): 5cfebb1

wip: hierarchical prediction

Browse files
Files changed (7) hide show
  1. README.md +0 -2
  2. app.py +42 -8
  3. embed_texts.sh +12 -0
  4. lib.py +122 -0
  5. make_txt_embedding.py +89 -0
  6. templates.py +80 -81
  7. test_lib.py +424 -0
README.md CHANGED
@@ -9,5 +9,3 @@ app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
9
  pinned: false
10
  license: mit
11
  ---
 
 
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import gradio as gr
2
  import torch
3
  import torch.nn.functional as F
@@ -6,9 +8,13 @@ from torchvision import transforms
6
 
7
  from templates import openai_imagenet_template
8
 
 
 
 
9
  model_str = "hf-hub:imageomics/bioclip"
10
  tokenizer_str = "ViT-B-16"
11
 
 
12
 
13
  preprocess_img = transforms.Compose(
14
  [
@@ -26,7 +32,7 @@ def get_txt_features(classnames, templates):
26
  all_features = []
27
  for classname in classnames:
28
  txts = [template(classname) for template in templates]
29
- txts = tokenizer(txts)
30
  txt_features = model.encode_text(txts)
31
  txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
32
  txt_features /= txt_features.norm()
@@ -36,22 +42,43 @@ def get_txt_features(classnames, templates):
36
 
37
 
38
  @torch.no_grad()
39
- def predict(img, cls_str: str) -> dict[str, float]:
40
- classes = [cls.strip() for cls in cls_str.split("\n") if cls.strip()]
41
  txt_features = get_txt_features(classes, openai_imagenet_template)
42
 
43
- img = preprocess_img(img)
44
-
45
  img_features = model.encode_image(img.unsqueeze(0))
46
  img_features = F.normalize(img_features, dim=-1)
 
47
  logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze()
48
- probs = F.softmax(logits, dim=0).tolist()
49
  return {cls: prob for cls, prob in zip(classes, probs)}
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  if __name__ == "__main__":
53
  print("Starting.")
54
  model = create_model(model_str, output_dict=True, require_pretrained=True)
 
55
  print("Created model.")
56
 
57
  model = torch.compile(model)
@@ -60,14 +87,21 @@ if __name__ == "__main__":
60
  tokenizer = get_tokenizer(tokenizer_str)
61
 
62
  demo = gr.Interface(
63
- fn=predict,
64
  inputs=[
65
  gr.Image(shape=(224, 224)),
66
  gr.Textbox(
67
- placeholder="dog\ncat\n...", lines=3, label="Classes", show_label=True
 
 
 
 
68
  ),
69
  ],
70
  outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
 
 
 
71
  )
72
 
73
  demo.launch()
 
1
+ import os
2
+
3
  import gradio as gr
4
  import torch
5
  import torch.nn.functional as F
 
8
 
9
  from templates import openai_imagenet_template
10
 
11
+ hf_token = os.getenv("HF_TOKEN")
12
+ hf_writer = gr.HuggingFaceDatasetSaver(hf_token, "bioclip-demo")
13
+
14
  model_str = "hf-hub:imageomics/bioclip"
15
  tokenizer_str = "ViT-B-16"
16
 
17
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
 
19
  preprocess_img = transforms.Compose(
20
  [
 
32
  all_features = []
33
  for classname in classnames:
34
  txts = [template(classname) for template in templates]
35
+ txts = tokenizer(txts).to(device)
36
  txt_features = model.encode_text(txts)
37
  txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
38
  txt_features /= txt_features.norm()
 
42
 
43
 
44
  @torch.no_grad()
45
def predict(img, classes: list[str]) -> dict[str, float]:
    """
    Scores an image against a user-supplied list of class names.

    Blank and whitespace-only entries in `classes` are dropped; the remaining
    names are stripped. Returns a dict mapping each remaining class name to
    its softmax probability.
    """
    classes = [cls.strip() for cls in classes if cls.strip()]
    txt_features = get_txt_features(classes, openai_imagenet_template)

    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    logits = model.logit_scale.exp() * img_features @ txt_features
    # Flatten instead of squeeze(): with exactly one class, squeeze() yields a
    # 0-d tensor, .tolist() then returns a bare float and zip() below fails.
    # NOTE(review): assumes the product holds one logit per class -- confirm
    # against get_txt_features.
    logits = logits.reshape(-1)
    probs = F.softmax(logits, dim=0).to("cpu").tolist()
    return dict(zip(classes, probs))
56
 
57
 
58
@torch.no_grad()
def hierarchical_predict(img) -> list[str]:
    """
    Predicts from the top of the tree of life down to the species.

    Work in progress: only the normalized image embedding is computed so far.

    Raises:
        NotImplementedError: always, until the descent through the taxonomy
            is written. This replaces a leftover breakpoint() that would stop
            the serving process in a debugger on every request.
    """
    img = preprocess_img(img).to(device)
    img_features = model.encode_image(img.unsqueeze(0))
    img_features = F.normalize(img_features, dim=-1)

    # TODO: compare img_features with the text embeddings, one rank at a time.
    raise NotImplementedError("Hierarchical prediction is not implemented yet.")
67
+
68
+
69
def run(img, cls_str: str) -> dict[str, float]:
    """
    Entry point for the demo: picks the prediction mode from the textbox.

    `cls_str` holds one class name per line. If it contains at least one
    non-blank line, the image is scored against those classes with predict();
    otherwise (empty, whitespace-only or None) hierarchical_predict() is used.
    """
    # Parse first so that input made only of blank lines is treated as empty
    # instead of calling predict() with no classes.
    classes = [cls.strip() for cls in (cls_str or "").split("\n") if cls.strip()]
    if classes:
        return predict(img, classes)
    return hierarchical_predict(img)
76
+
77
+
78
  if __name__ == "__main__":
79
  print("Starting.")
80
  model = create_model(model_str, output_dict=True, require_pretrained=True)
81
+ model = model.to(device)
82
  print("Created model.")
83
 
84
  model = torch.compile(model)
 
87
  tokenizer = get_tokenizer(tokenizer_str)
88
 
89
  demo = gr.Interface(
90
+ fn=run,
91
  inputs=[
92
  gr.Image(shape=(224, 224)),
93
  gr.Textbox(
94
+ placeholder="dog\ncat\n...",
95
+ lines=3,
96
+ label="Classes",
97
+ show_label=True,
98
+ info="If empty, will predict from the entire tree of life.",
99
  ),
100
  ],
101
  outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
102
+ allow_flagging="manual",
103
+ flagging_options=["Incorrect", "Other"],
104
+ flagging_callback=hf_writer,
105
  )
106
 
107
  demo.launch()
embed_texts.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Slurm batch job that runs make_txt_embedding.py over the catalog CSV below
# and writes the result to text_emb.bin.
#
# Job identity and accounting.
#SBATCH --job-name=embed-treeoflife
#SBATCH --account=PAS2136
#SBATCH --partition=gpu
# Resources: one node, one GPU, ten tasks, twelve-hour limit.
#SBATCH --nodes=1
#SBATCH --gpus-per-node=1
#SBATCH --ntasks-per-node=10
#SBATCH --time=12:00:00

python make_txt_embedding.py \
  --catalog-path /fs/ess/PAS2136/open_clip/data/evobio10m-v3.3/predicted-statistics.csv \
  --out-path text_emb.bin
lib.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import itertools
3
+
4
+
5
+ class TaxonomicNode:
6
+ __slots__ = ("name", "index", "root", "_children")
7
+
8
+ def __init__(self, name, index, root):
9
+ self.name = name
10
+ self.index = index
11
+ self.root = root
12
+ self._children = {}
13
+
14
+ def add(self, name):
15
+ added = 0
16
+ if not name:
17
+ return added
18
+
19
+ first, rest = name[0], name[1:]
20
+ if first not in self._children:
21
+ self._children[first] = TaxonomicNode(first, self.root.size, self.root)
22
+ self.root.size += 1
23
+
24
+ self._children[first].add(rest)
25
+
26
+ def children(self, name):
27
+ if not name:
28
+ return set((child.name, child.index) for child in self._children.values())
29
+
30
+ first, rest = name[0], name[1:]
31
+ if first not in self._children:
32
+ return set()
33
+
34
+ return self._children[first].children(rest)
35
+
36
+ def __iter__(self):
37
+ yield self.name, self.index
38
+
39
+ for child in self._children.values():
40
+ for name, index in child:
41
+ yield f"{self.name} {name}", index
42
+
43
+ @classmethod
44
+ def from_dict(cls, dct, root):
45
+ node = cls(dct["name"], dct["index"], root)
46
+ node._children = {child["name"]: cls.from_dict(child, root) for child in dct["children"]}
47
+ return node
48
+
49
+
50
+
51
+ class TaxonomicTree:
52
+ """
53
+ Efficient structure for finding taxonomic names and their descendants.
54
+ Also returns an integer index i for each possible name.
55
+ """
56
+
57
+ def __init__(self):
58
+ self.kingdoms = {}
59
+ self.size = 0
60
+
61
+ def add(self, name: list[str]):
62
+ if not name:
63
+ return
64
+
65
+ first, rest = name[0], name[1:]
66
+ if first not in self.kingdoms:
67
+ self.kingdoms[first] = TaxonomicNode(first, self.size, self)
68
+ self.size += 1
69
+
70
+ self.kingdoms[first].add(rest)
71
+
72
+ def children(self, name=None):
73
+ if not name:
74
+ return set(
75
+ (kingdom.name, kingdom.index) for kingdom in self.kingdoms.values()
76
+ )
77
+
78
+ first, rest = name[0], name[1:]
79
+ if first not in self.kingdoms:
80
+ return set()
81
+
82
+ return self.kingdoms[first].children(rest)
83
+
84
+ def __iter__(self):
85
+ for kingdom in self.kingdoms.values():
86
+ yield from kingdom
87
+
88
+ @classmethod
89
+ def from_dict(cls, dct):
90
+ tree = cls()
91
+ tree.kingdoms = {
92
+ kingdom["name"]: TaxonomicNode.from_dict(kingdom, tree) for kingdom in dct["kingdoms"]
93
+ }
94
+ tree.size = dct["size"]
95
+ return tree
96
+
97
+
98
+ class TaxonomicJsonEncoder(json.JSONEncoder):
99
+ def default(self, obj):
100
+ if isinstance(obj, TaxonomicNode):
101
+ return {
102
+ "name": obj.name,
103
+ "index": obj.index,
104
+ "children": list(obj._children.values()),
105
+ }
106
+ elif isinstance(obj, TaxonomicTree):
107
+ return {
108
+ "kingdoms": list(obj.kingdoms.values()),
109
+ "size": obj.size,
110
+ }
111
+ else:
112
+ super().default(self, obj)
113
+
114
+
115
+
116
+ def batched(iterable, n):
117
+ # batched('ABCDEFG', 3) --> ABC DEF G
118
+ if n < 1:
119
+ raise ValueError('n must be at least one')
120
+ it = iter(iterable)
121
+ while batch := tuple(itertools.islice(it, n)):
122
+ yield zip(*batch)
make_txt_embedding.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Makes the entire set of text embeddings for all possible names in the tree of life.
3
+ Uses the catalog.csv file from TreeOfLife-10M.
4
+ """
5
+ import argparse
6
+ import csv
7
+ import json
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from open_clip import create_model, get_tokenizer
13
+ from tqdm import tqdm
14
+
15
+ import lib
16
+ from templates import openai_imagenet_template
17
+
18
+ model_str = "hf-hub:imageomics/bioclip"
19
+ tokenizer_str = "ViT-B-16"
20
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
+
22
+
23
@torch.no_grad()
def write_txt_features(name_lookup):
    """
    Embeds every name in `name_lookup` and writes the result to args.out_path.

    The output is a float32 memmap of shape (512, name_lookup.size); column i
    holds the embedding of the name whose index is i. Each name is rendered
    with every template in openai_imagenet_template; the normalized template
    embeddings are averaged and the mean is normalized again.

    Reads the module-level globals `args`, `model` and `tokenizer`.
    """
    n_templates = len(openai_imagenet_template)
    all_features = np.memmap(
        args.out_path, dtype=np.float32, mode="w+", shape=(512, name_lookup.size)
    )

    # --batch-size counts texts, not names. Keep at least one name per batch,
    # because lib.batched rejects n < 1.
    batch_size = max(1, args.batch_size // n_templates)
    for names, indices in tqdm(lib.batched(name_lookup, batch_size)):
        txts = [template(name) for name in names for template in openai_imagenet_template]
        txts = tokenizer(txts).to(device)
        txt_features = model.encode_text(txts)
        # Use len(names) rather than batch_size: the last batch is shorter
        # whenever the number of names is not a multiple of batch_size.
        txt_features = torch.reshape(txt_features, (len(names), n_templates, 512))
        txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
        txt_features /= txt_features.norm(dim=1, keepdim=True)
        all_features[:, list(indices)] = txt_features.cpu().numpy().T

    all_features.flush()
40
+
41
+
42
def get_name_lookup(catalog_path):
    """
    Builds a lib.TaxonomicTree from the catalog CSV at `catalog_path`.

    Each row contributes its lineage from kingdom down to species. If a rank
    is missing, the lineage is cut just before the first missing rank.
    """
    ranks = ("kingdom", "phylum", "class", "order", "family", "genus", "species")
    lookup = lib.TaxonomicTree()

    # newline="" is the mode the csv module documents for files it reads.
    with open(catalog_path, newline="") as fd:
        reader = csv.DictReader(fd)
        for row in tqdm(reader):
            values = [row[rank] for rank in ranks]
            # Cut at the first falsy value. This covers "" as well as None,
            # which csv.DictReader uses to fill rows that are too short.
            name = []
            for value in values:
                if not value:
                    break
                name.append(value)
            lookup.add(name)

    return lookup
62
+
63
+
64
if __name__ == "__main__":
    # Command-line options.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--catalog-path",
        required=True,
        help="Path to the catalog.csv file from TreeOfLife-10M.",
    )
    cli.add_argument("--out-path", required=True, help="Path to the output file.")
    cli.add_argument(
        "--name-cache-path",
        default=".name_lookup_cache.json",
        help="Path to the name cache file.",
    )
    cli.add_argument("--batch-size", type=int, default=2**15, help="Batch size.")
    # `args`, `model` and `tokenizer` are read as globals by write_txt_features.
    args = cli.parse_args()

    # Build the name tree and store it as JSON before the model is loaded.
    name_lookup = get_name_lookup(args.catalog_path)
    with open(args.name_cache_path, "w") as cache_file:
        json.dump(name_lookup, cache_file, cls=lib.TaxonomicJsonEncoder)

    print("Starting.")
    model = create_model(model_str, output_dict=True, require_pretrained=True).to(device)
    print("Created model.")

    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)
    write_txt_features(name_lookup)
templates.py CHANGED
@@ -1,83 +1,82 @@
1
  openai_imagenet_template = [
2
- lambda c: f'a bad photo of a {c}.',
3
- lambda c: f'a photo of many {c}.',
4
- lambda c: f'a sculpture of a {c}.',
5
- lambda c: f'a photo of the hard to see {c}.',
6
- lambda c: f'a low resolution photo of the {c}.',
7
- lambda c: f'a rendering of a {c}.',
8
- lambda c: f'graffiti of a {c}.',
9
- lambda c: f'a bad photo of the {c}.',
10
- lambda c: f'a cropped photo of the {c}.',
11
- lambda c: f'a tattoo of a {c}.',
12
- lambda c: f'the embroidered {c}.',
13
- lambda c: f'a photo of a hard to see {c}.',
14
- lambda c: f'a bright photo of a {c}.',
15
- lambda c: f'a photo of a clean {c}.',
16
- lambda c: f'a photo of a dirty {c}.',
17
- lambda c: f'a dark photo of the {c}.',
18
- lambda c: f'a drawing of a {c}.',
19
- lambda c: f'a photo of my {c}.',
20
- lambda c: f'the plastic {c}.',
21
- lambda c: f'a photo of the cool {c}.',
22
- lambda c: f'a close-up photo of a {c}.',
23
- lambda c: f'a black and white photo of the {c}.',
24
- lambda c: f'a painting of the {c}.',
25
- lambda c: f'a painting of a {c}.',
26
- lambda c: f'a pixelated photo of the {c}.',
27
- lambda c: f'a sculpture of the {c}.',
28
- lambda c: f'a bright photo of the {c}.',
29
- lambda c: f'a cropped photo of a {c}.',
30
- lambda c: f'a plastic {c}.',
31
- lambda c: f'a photo of the dirty {c}.',
32
- lambda c: f'a jpeg corrupted photo of a {c}.',
33
- lambda c: f'a blurry photo of the {c}.',
34
- lambda c: f'a photo of the {c}.',
35
- lambda c: f'a good photo of the {c}.',
36
- lambda c: f'a rendering of the {c}.',
37
- lambda c: f'a {c} in a video game.',
38
- lambda c: f'a photo of one {c}.',
39
- lambda c: f'a doodle of a {c}.',
40
- lambda c: f'a close-up photo of the {c}.',
41
- lambda c: f'a photo of a {c}.',
42
- lambda c: f'the origami {c}.',
43
- lambda c: f'the {c} in a video game.',
44
- lambda c: f'a sketch of a {c}.',
45
- lambda c: f'a doodle of the {c}.',
46
- lambda c: f'a origami {c}.',
47
- lambda c: f'a low resolution photo of a {c}.',
48
- lambda c: f'the toy {c}.',
49
- lambda c: f'a rendition of the {c}.',
50
- lambda c: f'a photo of the clean {c}.',
51
- lambda c: f'a photo of a large {c}.',
52
- lambda c: f'a rendition of a {c}.',
53
- lambda c: f'a photo of a nice {c}.',
54
- lambda c: f'a photo of a weird {c}.',
55
- lambda c: f'a blurry photo of a {c}.',
56
- lambda c: f'a cartoon {c}.',
57
- lambda c: f'art of a {c}.',
58
- lambda c: f'a sketch of the {c}.',
59
- lambda c: f'a embroidered {c}.',
60
- lambda c: f'a pixelated photo of a {c}.',
61
- lambda c: f'itap of the {c}.',
62
- lambda c: f'a jpeg corrupted photo of the {c}.',
63
- lambda c: f'a good photo of a {c}.',
64
- lambda c: f'a plushie {c}.',
65
- lambda c: f'a photo of the nice {c}.',
66
- lambda c: f'a photo of the small {c}.',
67
- lambda c: f'a photo of the weird {c}.',
68
- lambda c: f'the cartoon {c}.',
69
- lambda c: f'art of the {c}.',
70
- lambda c: f'a drawing of the {c}.',
71
- lambda c: f'a photo of the large {c}.',
72
- lambda c: f'a black and white photo of a {c}.',
73
- lambda c: f'the plushie {c}.',
74
- lambda c: f'a dark photo of a {c}.',
75
- lambda c: f'itap of a {c}.',
76
- lambda c: f'graffiti of the {c}.',
77
- lambda c: f'a toy {c}.',
78
- lambda c: f'itap of my {c}.',
79
- lambda c: f'a photo of a cool {c}.',
80
- lambda c: f'a photo of a small {c}.',
81
- lambda c: f'a tattoo of the {c}.',
82
  ]
83
-
 
1
# Prompt patterns; "{c}" marks where the class name is inserted.
_PROMPT_PATTERNS = (
    "a bad photo of a {c}.",
    "a photo of many {c}.",
    "a sculpture of a {c}.",
    "a photo of the hard to see {c}.",
    "a low resolution photo of the {c}.",
    "a rendering of a {c}.",
    "graffiti of a {c}.",
    "a bad photo of the {c}.",
    "a cropped photo of the {c}.",
    "a tattoo of a {c}.",
    "the embroidered {c}.",
    "a photo of a hard to see {c}.",
    "a bright photo of a {c}.",
    "a photo of a clean {c}.",
    "a photo of a dirty {c}.",
    "a dark photo of the {c}.",
    "a drawing of a {c}.",
    "a photo of my {c}.",
    "the plastic {c}.",
    "a photo of the cool {c}.",
    "a close-up photo of a {c}.",
    "a black and white photo of the {c}.",
    "a painting of the {c}.",
    "a painting of a {c}.",
    "a pixelated photo of the {c}.",
    "a sculpture of the {c}.",
    "a bright photo of the {c}.",
    "a cropped photo of a {c}.",
    "a plastic {c}.",
    "a photo of the dirty {c}.",
    "a jpeg corrupted photo of a {c}.",
    "a blurry photo of the {c}.",
    "a photo of the {c}.",
    "a good photo of the {c}.",
    "a rendering of the {c}.",
    "a {c} in a video game.",
    "a photo of one {c}.",
    "a doodle of a {c}.",
    "a close-up photo of the {c}.",
    "a photo of a {c}.",
    "the origami {c}.",
    "the {c} in a video game.",
    "a sketch of a {c}.",
    "a doodle of the {c}.",
    "a origami {c}.",
    "a low resolution photo of a {c}.",
    "the toy {c}.",
    "a rendition of the {c}.",
    "a photo of the clean {c}.",
    "a photo of a large {c}.",
    "a rendition of a {c}.",
    "a photo of a nice {c}.",
    "a photo of a weird {c}.",
    "a blurry photo of a {c}.",
    "a cartoon {c}.",
    "art of a {c}.",
    "a sketch of the {c}.",
    "a embroidered {c}.",
    "a pixelated photo of a {c}.",
    "itap of the {c}.",
    "a jpeg corrupted photo of the {c}.",
    "a good photo of a {c}.",
    "a plushie {c}.",
    "a photo of the nice {c}.",
    "a photo of the small {c}.",
    "a photo of the weird {c}.",
    "the cartoon {c}.",
    "art of the {c}.",
    "a drawing of the {c}.",
    "a photo of the large {c}.",
    "a black and white photo of a {c}.",
    "the plushie {c}.",
    "a dark photo of a {c}.",
    "itap of a {c}.",
    "graffiti of the {c}.",
    "a toy {c}.",
    "itap of my {c}.",
    "a photo of a cool {c}.",
    "a photo of a small {c}.",
    "a tattoo of the {c}.",
)


def _as_template(pattern):
    """Returns a one-argument callable that fills `pattern` with a class name."""
    return lambda c: pattern.format(c=c)


# One callable per pattern, in the order listed above.
openai_imagenet_template = [_as_template(pattern) for pattern in _PROMPT_PATTERNS]
 
test_lib.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import lib
2
+
3
+
4
def test_taxonomiclookup_empty():
    """A freshly created tree has handed out no indices."""
    tree = lib.TaxonomicTree()
    assert tree.size == 0
7
+
8
+
9
def test_taxonomiclookup_kingdom_size():
    """Adding a lone kingdom registers exactly one name."""
    tree = lib.TaxonomicTree()
    tree.add(("Animalia",))
    assert tree.size == 1
15
+
16
+
17
def test_taxonomiclookup_genus_size():
    """A lineage of six ranks registers six names."""
    lineage = (
        "Animalia", "Chordata", "Aves",
        "Accipitriformes", "Accipitridae", "Halieaeetus",
    )
    tree = lib.TaxonomicTree()
    tree.add(lineage)
    assert tree.size == 6
32
+
33
+
34
def test_taxonomictree_kingdom_children():
    """children() without a lineage lists the kingdoms and their indices."""
    lineage = (
        "Animalia", "Chordata", "Aves",
        "Accipitriformes", "Accipitridae", "Halieaeetus",
    )
    tree = lib.TaxonomicTree()
    tree.add(lineage)

    assert tree.children() == {("Animalia", 0)}
51
+
52
+
53
def test_taxonomiclookup_children_of_animal_only_birds():
    """With only bird lineages, Animalia has the single child Chordata."""
    lineages = [
        ("Animalia", "Chordata", "Aves", "Accipitriformes",
         "Accipitridae", "Halieaeetus", "leucocephalus"),
        ("Animalia", "Chordata", "Aves", "Strigiformes", "Strigidae", "Ninox", "scutulata"),
        ("Animalia", "Chordata", "Aves", "Strigiformes", "Strigidae", "Ninox", "plesseni"),
    ]
    tree = lib.TaxonomicTree()
    for lineage in lineages:
        tree.add(lineage)

    assert tree.children(("Animalia",)) == {("Chordata", 1)}
93
+
94
+
95
def test_taxonomiclookup_children_of_animal():
    """Animalia lists both phyla, indexed in order of first insertion."""
    lineages = [
        ("Animalia", "Chordata", "Aves", "Accipitriformes",
         "Accipitridae", "Halieaeetus", "leucocephalus"),
        ("Animalia", "Chordata", "Aves", "Strigiformes", "Strigidae", "Ninox", "scutulata"),
        ("Animalia", "Chordata", "Aves", "Strigiformes", "Strigidae", "Ninox", "plesseni"),
        ("Animalia", "Chordata", "Mammalia", "Primates", "Hominidae", "Gorilla", "gorilla"),
        ("Animalia", "Arthropoda", "Insecta", "Hymenoptera", "Apidae", "Bombus", "balteatus"),
    ]
    tree = lib.TaxonomicTree()
    for lineage in lineages:
        tree.add(lineage)

    assert tree.children(("Animalia",)) == {("Chordata", 1), ("Arthropoda", 17)}
157
+
158
+
159
def test_taxonomiclookup_children_of_chordata():
    """Chordata lists its two classes, Aves and Mammalia."""
    lineages = [
        ("Animalia", "Chordata", "Aves", "Accipitriformes",
         "Accipitridae", "Halieaeetus", "leucocephalus"),
        ("Animalia", "Chordata", "Aves", "Strigiformes", "Strigidae", "Ninox", "scutulata"),
        ("Animalia", "Chordata", "Aves", "Strigiformes", "Strigidae", "Ninox", "plesseni"),
        ("Animalia", "Chordata", "Mammalia", "Primates", "Hominidae", "Gorilla", "gorilla"),
        ("Animalia", "Arthropoda", "Insecta", "Hymenoptera", "Apidae", "Bombus", "balteatus"),
    ]
    tree = lib.TaxonomicTree()
    for lineage in lineages:
        tree.add(lineage)

    assert tree.children(("Animalia", "Chordata")) == {("Aves", 2), ("Mammalia", 12)}
221
+
222
+
223
def test_taxonomiclookup_children_of_strigiformes():
    """Strigiformes has the single family Strigidae."""
    lineages = [
        ("Animalia", "Chordata", "Aves", "Accipitriformes",
         "Accipitridae", "Halieaeetus", "leucocephalus"),
        ("Animalia", "Chordata", "Aves", "Strigiformes", "Strigidae", "Ninox", "scutulata"),
        ("Animalia", "Chordata", "Aves", "Strigiformes", "Strigidae", "Ninox", "plesseni"),
        ("Animalia", "Chordata", "Mammalia", "Primates", "Hominidae", "Gorilla", "gorilla"),
        ("Animalia", "Arthropoda", "Insecta", "Hymenoptera", "Apidae", "Bombus", "balteatus"),
    ]
    tree = lib.TaxonomicTree()
    for lineage in lineages:
        tree.add(lineage)

    found = tree.children(("Animalia", "Chordata", "Aves", "Strigiformes"))
    assert found == {("Strigidae", 8)}
285
+
286
+
287
def test_taxonomiclookup_children_of_ninox():
    """The genus Ninox lists both of its species."""
    lineages = [
        ("Animalia", "Chordata", "Aves", "Accipitriformes",
         "Accipitridae", "Halieaeetus", "leucocephalus"),
        ("Animalia", "Chordata", "Aves", "Strigiformes", "Strigidae", "Ninox", "scutulata"),
        ("Animalia", "Chordata", "Aves", "Strigiformes", "Strigidae", "Ninox", "plesseni"),
        ("Animalia", "Chordata", "Mammalia", "Primates", "Hominidae", "Gorilla", "gorilla"),
        ("Animalia", "Arthropoda", "Insecta", "Hymenoptera", "Apidae", "Bombus", "balteatus"),
    ]
    tree = lib.TaxonomicTree()
    for lineage in lineages:
        tree.add(lineage)

    genus = ("Animalia", "Chordata", "Aves", "Strigiformes", "Strigidae", "Ninox")
    assert tree.children(genus) == {("scutulata", 10), ("plesseni", 11)}
351
+
352
+
353
def test_taxonomiclookup_children_of_gorilla():
    """A species sits at the bottom of the tree and has no children."""
    lineages = [
        ("Animalia", "Chordata", "Aves", "Accipitriformes",
         "Accipitridae", "Halieaeetus", "leucocephalus"),
        ("Animalia", "Chordata", "Aves", "Strigiformes", "Strigidae", "Ninox", "scutulata"),
        ("Animalia", "Chordata", "Aves", "Strigiformes", "Strigidae", "Ninox", "plesseni"),
        ("Animalia", "Chordata", "Mammalia", "Primates", "Hominidae", "Gorilla", "gorilla"),
        ("Animalia", "Arthropoda", "Insecta", "Hymenoptera", "Apidae", "Bombus", "balteatus"),
    ]
    tree = lib.TaxonomicTree()
    for lineage in lineages:
        tree.add(lineage)

    species = ("Animalia", "Chordata", "Mammalia", "Primates", "Hominidae", "Gorilla", "gorilla")
    assert tree.children(species) == set()