Spaces:
Runtime error
Runtime error
Upload app.py
Browse filesInitial commit of app
app.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#install huggingface_hub["fastai"] gradio timm
|
2 |
+
from huggingface_hub import from_pretrained_fastai
|
3 |
+
from gradio import Interface, inputs, outputs
|
4 |
+
from fastai.learner import Learner
|
5 |
+
import fastai
|
6 |
+
|
7 |
+
repo_id = "Kieranm/britishmus_plate_material_classifier"
|
8 |
+
|
9 |
+
learner = from_pretrained_fastai(repo_id)
|
10 |
+
|
11 |
+
mappings = {
|
12 |
+
fastai.torch_core.TensorImage: {
|
13 |
+
"type": inputs.IMage(type='file', label='input'),
|
14 |
+
"process": lambda inp : inp.name
|
15 |
+
},
|
16 |
+
fastai.torch_core.TensorCategory: {
|
17 |
+
"type": outputs.Label(num_top_classes=3, label = 'output'),
|
18 |
+
"process": lambda dls, out: {dls.vocab[i]: float(out[2][i]) for i in range(len(dls.vocab))}
|
19 |
+
|
20 |
+
}
|
21 |
+
}
|
22 |
+
|
23 |
+
#Taken from fastgradio library
|
24 |
+
|
25 |
+
class Demo:
|
26 |
+
def __init__(self, learner):
|
27 |
+
|
28 |
+
self.learner = learner
|
29 |
+
self.types = getattr(self.learner.dls, '_types')[tuple]
|
30 |
+
|
31 |
+
def learner_predict(self, inp):
|
32 |
+
inp = mappings[self.types[0]]["process"](inp)
|
33 |
+
prediction = self.learner.predict(inp)
|
34 |
+
output = mappings[self.types[1]]["process"](self.learner.dls, prediction)
|
35 |
+
return output
|
36 |
+
|
37 |
+
def launch(self, share=True, debug=False, auth=None, **kwargs):
|
38 |
+
inputs = mappings[self.types[0]]["type"]
|
39 |
+
|
40 |
+
outputs = mappings[self.types[1]]["type"]
|
41 |
+
|
42 |
+
Interface(fn=self.learner_predict, inputs=inputs, outputs=outputs,
|
43 |
+
examples = ["examples/earthen1.jpg", "examples/earthen2.png", "porcelain1.png", "porcelain2.png"],
|
44 |
+
**kwargs).launch(share=share, debug=debug, auth=auth)
|
45 |
+
|
46 |
+
|
47 |
+
Demo(learner).launch()
|