File size: 2,583 Bytes
3ae84a3
 
 
78cdedf
3e5aff4
3ae84a3
 
 
dc86583
9b28e54
3ae84a3
 
0d69242
 
3ae84a3
 
0d69242
78cdedf
3ae84a3
 
0d69242
9b28e54
 
78cdedf
0d69242
c45624f
19cb4eb
c45624f
8fd2365
 
19cb4eb
 
3e5aff4
 
 
 
 
 
 
 
 
 
 
 
3ae84a3
 
19cb4eb
 
 
 
2e6b143
 
19cb4eb
3e5aff4
 
2e6b143
a6893ba
19cb4eb
3e5aff4
c3eae7d
 
 
 
 
 
 
 
3e5aff4
19cb4eb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import pickle
import gradio as gr
from datasets import load_dataset
from transformers import AutoModel, AutoFeatureExtractor
import wikipedia


# Module-level setup: executed a single time, when the script is first run.

# Nearest-neighbour index over precomputed image embeddings; `query` calls
# `index.query(embedding, k=...)` on it.
# NOTE(review): pickle.load can execute arbitrary code — this file must stay a
# trusted artifact shipped with the app, never user-supplied.
with open("butts_1024_new.pickle", "rb") as index_file:
    index = pickle.load(index_file)

# Feature extractor + model pair used by `query` to embed an uploaded image.
_CHECKPOINT = "sasha/autotrain-butterfly-similarity-2490576840"
feature_extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
model = AutoModel.from_pretrained(_CHECKPOINT)

# Candidate images: `query` maps index hits back to rows of the "train" split
# to fetch each hit's "image" and "name" columns.
dataset = load_dataset("sasha/butterflies_10k_names_multiple")
ds = dataset["train"]


def query(image, top_k=4):
    """Find the indexed butterflies most similar to *image*.

    Args:
        image: the uploaded picture, passed unchanged to ``feature_extractor``.
        top_k: number of nearest neighbours requested from ``index``.

    Returns:
        A 3-tuple ``(images_with_captions, labels_with_probs, description)``:

        * ``images_with_captions`` – ``(image, name)`` pairs for ``gr.Gallery``,
          in the order returned by ``index.query``.
        * ``labels_with_probs`` – ``{name: 1 - distance}`` for ``gr.Label``;
          when a name occurs more than once among the hits, its best
          (highest) score is kept.
        * ``description`` – Markdown text for the top hit taken from Wikipedia,
          or a generic butterfly blurb when that lookup fails for any reason.
    """
    import torch  # local import: only needed to disable autograd below

    inputs = feature_extractor(image, return_tensors="pt")
    # Inference only — don't build an autograd graph for the forward pass.
    with torch.no_grad():
        model_output = model(**inputs)
    embedding = model_output.pooler_output.detach()

    results = index.query(embedding, k=top_k)
    # Row 0 of each result array corresponds to our single query embedding.
    neighbour_ids = results[0][0].tolist()
    distances = results[1][0].tolist()

    neighbours = ds.select(neighbour_ids)
    images = neighbours["image"]
    captions = neighbours["name"]
    images_with_captions = list(zip(images, captions))

    # Several neighbours can share a species name; keep the best score for
    # each name instead of letting a later (farther) hit overwrite it.
    labels_with_probs = {}
    for caption, distance in zip(captions, distances):
        score = 1 - distance
        labels_with_probs[caption] = max(score, labels_with_probs.get(caption, score))

    try:
        top_name = captions[0]
        summary = wikipedia.summary(top_name, sentences=1)
        page_url = wikipedia.page(top_name).url
    except Exception:
        # Best effort: disambiguation pages, missing articles, network errors
        # or an empty hit list all fall back to a generic description.
        description = (
            "### Butterflies are insects in the order Lepidoptera, which also includes moths. "
            "Adult butterflies have large, often brightly coloured wings."
            " You can learn more about butterflies [here](https://en.wikipedia.org/wiki/Butterfly)!"
        )
    else:
        description = (
            f"### {summary}"
            f" You can learn more about your butterfly [here]({page_url})!"
        )
    return images_with_captions, labels_with_probs, description


# Page layout: image upload, search button and description on the left;
# result gallery and per-name scores on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Find my Butterfly 🦋")
    gr.Markdown("## Use this Space to find your butterfly, based on the [iNaturalist butterfly dataset](https://huggingface.co/datasets/huggan/inat_butterflies_top10k)!")
    with gr.Row():
        with gr.Column(scale=1):
            inputs = gr.Image()
            btn = gr.Button("Find my butterfly!")
            description = gr.Markdown()

        with gr.Column(scale=2):
            outputs = gr.Gallery(rows=2)
            labels = gr.Label()

    gr.Markdown("### Image Examples")
    # `query` returns (gallery items, label scores, description), so cached
    # examples must target all three components, exactly like the button does.
    gr.Examples(
        examples=["elton.jpg", "ken.jpg", "gaga.jpg", "taylor.jpg"],
        inputs=inputs,
        outputs=[outputs, labels, description],
        fn=query,
        cache_examples=True,
    )
    btn.click(query, inputs, [outputs, labels, description])

demo.launch()