Ceyda Cinarel commited on
Commit
c8f27cf
Β·
1 Parent(s): 9a639cc

cached demo

Browse files
.gitattributes CHANGED
@@ -29,3 +29,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
29
  *.zip filter=lfs diff=lfs merge=lfs -text
30
  *.zstandard filter=lfs diff=lfs merge=lfs -text
31
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
29
  *.zip filter=lfs diff=lfs merge=lfs -text
30
  *.zstandard filter=lfs diff=lfs merge=lfs -text
31
  *tfevents* filter=lfs diff=lfs merge=lfs -text
32
+ *.faiss filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
  title: Fashion Classification
3
- emoji: πŸ“Š
4
- colorFrom: pink
5
- colorTo: green
6
  sdk: streamlit
7
  sdk_version: 1.10.0
8
- app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Fashion Classification
3
+ emoji: πŸ‘
4
+ colorFrom: gray
5
+ colorTo: pink
6
  sdk: streamlit
7
  sdk_version: 1.10.0
8
+ app_file: data_analysis_app.py
9
  pinned: false
10
+ license: apache-2.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
data_analysis_app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ import streamlit as st
3
+
4
+ from data_utils import get_embedding
5
+
6
+ from bokeh.plotting import figure,show
7
+ from bokeh.io import push_notebook, output_notebook
8
+ # output_notebook()
9
+ from bokeh.palettes import d3
10
+
11
+ from bokeh.models import ColumnDataSource, Grid, LinearAxis, Plot, Scatter
12
+ from bokeh.transform import factor_cmap, factor_mark
13
+ import base64
14
+ from io import BytesIO
15
+
16
+ label_columns=["gender","subCategory","masterCategory"]
17
+
18
+ model_interest=['facebook/deit-tiny-patch16-224', # very small model 5M param model
19
+ 'microsoft/beit-base-patch16-224', # big model
20
+ "facebook/dino-vits8",
21
+ "facebook/levit-128S"]
22
+
23
+ def convert_base64(img):
24
+ buffered = BytesIO()
25
+ img.save(buffered, format="JPEG")
26
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
27
+ return "data:image/jpeg;base64,"+img_str
28
+
29
+ @st.experimental_singleton
30
+ def cache_embedding(model_name):
31
+ dataset=load_dataset("ceyda/fashion-products-small", split="train")
32
+ dataset=dataset.shuffle(seed=100) #pick a random seed
33
+ viz_dat=dataset.train_test_split(0.1,shuffle=False) #일뢀λ₯Ό visualizationμœ„ν•΄μ„œ λ½‘μ‹œλ‹¨
34
+ viz_dat=viz_dat["test"]
35
+ embedding = get_embedding(model_name,viz_dat)
36
+ embedding["image"]=embedding["image"].apply(convert_base64)
37
+ labels = {label:viz_dat.unique(label) for label in label_columns}
38
+ return embedding,labels
39
+
40
+ @st.experimental_singleton
41
+ def cache_graph(model_name,color_column):
42
+ embedding,labels=cache_embedding(model_name)
43
+
44
+ color_palette = (d3['Category20'][20]+d3['Category20b'][20]+d3['Category20c'][20])[:len(labels[color_column])]
45
+ source = ColumnDataSource(data=embedding)
46
+ # colors = factor_cmap('gender', palette=["purple","navy","green","blue","pink"], factors=embedding["gender"].unique())
47
+
48
+
49
+ TOOLS="hover,crosshair,pan,wheel_zoom,zoom_in,zoom_out,box_zoom,reset,tap,save,box_select,lasso_select,"
50
+ TOOLTIPS = """
51
+ <div>
52
+ <div>
53
+ <img
54
+ src="@image" height="42" alt="@image" width="42"
55
+ style="float: left; margin: 0px 15px 15px 0px;"
56
+ border="2"
57
+ ></img>
58
+ </div>
59
+ """
60
+ p = figure(tools=TOOLS,tooltips=TOOLTIPS)
61
+
62
+ p.scatter(x="x", y="y", source=source,
63
+ # marker=factor_mark('gender', ['circle', 'circle_cross', 'circle_dot','circle_x','circle_y'], labels["gender"]),
64
+ color=factor_cmap(color_column, color_palette, labels[color_column])
65
+ )
66
+
67
+ return p
68
+
69
+
70
+ model_name=st.sidebar.selectbox("Model",model_interest)
71
+ color_column=st.selectbox("Color by",label_columns)
72
+ p=cache_graph(model_name,color_column)
73
+ st.bokeh_chart(p, use_container_width=False)
data_utils.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from PIL import Image
3
+ import os
4
+ import pandas as pd
5
+ from transformers import AutoFeatureExtractor,AutoModel
6
+ from faiss.contrib.inspect_tools import get_flat_data
7
+ import pymde
8
+ import numpy as np
9
+
10
+ def get_embedding(model_name,viz_dat):
11
+
12
+ index_file=f"./indexes/{model_name.split('/')[1]}.faiss"
13
+
14
+ if os.path.exists(index_file):
15
+ viz_dat.load_faiss_index('embeddings', index_file)
16
+ else:
17
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
18
+ model = AutoModel.from_pretrained(model_name)
19
+ # model.to("cuda:0")
20
+ def embed(x):
21
+ images=x["image"]
22
+ inputs = feature_extractor(images=images, return_tensors="pt")
23
+ # inputs.to("cuda:0")
24
+ outputs = model(**inputs,output_hidden_states= True)
25
+ final_emb=outputs.pooler_output.detach().cpu().numpy() # this line depends on the model you are using
26
+ x["embeddings"]=final_emb
27
+ return x
28
+ # Add embeddings to dataset
29
+ viz_dat = viz_dat.map(embed,batched=True,batch_size=20)
30
+ viz_dat.add_faiss_index(column='embeddings')
31
+ viz_dat.save_faiss_index('embeddings',index_file)
32
+
33
+ embedding_file=f"./indexes/{model_name.split('/')[1]}.npy"
34
+ if os.path.exists(embedding_file):
35
+ embedding = np.load(embedding_file) # load
36
+ else:
37
+ index=viz_dat.get_index("embeddings").faiss_index
38
+ embeddings=get_flat_data(index)
39
+ embedding=pymde.preserve_neighbors(embeddings, verbose=True).embed().numpy()
40
+ np.save(embedding_file, embedding) # save
41
+
42
+ embedding=pd.DataFrame(embedding,columns=["x","y"])
43
+ embedding["image"]=viz_dat["image"]
44
+ embedding["gender"]=viz_dat["gender"]
45
+ embedding["masterCategory"]=viz_dat["masterCategory"]
46
+ embedding["subCategory"]=viz_dat["subCategory"]
47
+
48
+ return embedding
indexes/beit-base-patch16-224.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eef8b38b055d02f2dbfd6ac5c2cbba0cb670ef225285ba375e4baf63034ef589
3
+ size 13117485
indexes/beit-base-patch16-224.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b42f5b1e9a149b412a120a63999ea7a5d89438aef5651e200acbdcb226e75418
3
+ size 34288
indexes/deit-tiny-patch16-224.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5df0a6edbaf0b746bffc2ff740d783a4fa6be961dedece7ec069a37e1688341a
3
+ size 3279405
indexes/deit-tiny-patch16-224.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5bb0e5186ea3125d548838806af53cffc3fbb53ed752e06a7af9b2487c997df
3
+ size 34288
indexes/dino-vits8.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81965f420ca2c66ad55ad0e3300bbce2d5fd366b2c9e5e3ae48ecc88e4a00d97
3
+ size 6558765
indexes/dino-vits8.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1499f7501e455d088c5621512fd94a819fd26940271782cbf047db9e49359df0
3
+ size 34288
indexes/levit-128S.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1d9fec9e282812e613e312c036f8998b7177f6d19f787b61a5cd9810916ba8e
3
+ size 6558765
indexes/levit-128S.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8757326676da484d4aa542eacc54f74fb8c8c743697fde4d3dd0dd44e2eeb1f9
3
+ size 34288
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # didn't pin versions
2
+ datasets
3
+ transformers
4
+ pymde
5
+ bokeh==2.4.1
6
+ faiss-cpu