Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -30,25 +30,22 @@ def load_openshape(name, to_cpu=False):
|
|
30 |
pce = pce.cpu()
|
31 |
return pce
|
32 |
|
33 |
-
def retrieval_filter_expand(
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
anim_n
|
45 |
-
face_n
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
and (tag_n or tag in x['tags'])
|
50 |
-
)
|
51 |
-
return sim_th, filter_fn
|
52 |
|
53 |
def retrieval_results(results):
|
54 |
st.caption("Click the link to view the 3D shape")
|
@@ -148,32 +145,125 @@ def demo_retrieval():
|
|
148 |
|
149 |
prog.progress(1.0, "Idle")
|
150 |
|
151 |
-
st.title("TripletMix Demo")
|
152 |
-
st.caption("For faster inference without waiting in queue, you may clone the space and run it yourself.")
|
153 |
-
prog = st.progress(0.0, "Idle")
|
154 |
-
tab_cls, tab_pc, tab_img, tab_text, tab_sd, tab_cap = st.tabs([
|
155 |
-
"Classification",
|
156 |
-
"Retrieval w/ 3D",
|
157 |
-
"Retrieval w/ Image",
|
158 |
-
"Retrieval w/ Text",
|
159 |
-
"Image Generation",
|
160 |
-
"Captioning",
|
161 |
-
])
|
162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
|
169 |
try:
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
except Exception:
|
178 |
import traceback
|
179 |
st.error(traceback.format_exc().replace("\n", " \n"))
|
|
|
30 |
pce = pce.cpu()
|
31 |
return pce
|
32 |
|
33 |
+
# Upper bounds used to decide whether a face/animation range is "unrestricted".
# NOTE(review): presumably the largest values present in the retrieval index —
# confirm against the dataset metadata.
_FACE_COUNT_LIMIT = 34985808
_ANIM_COUNT_LIMIT = 563


def retrieval_filter_expand(
    *,
    tag="",
    face_min=0,
    face_max=_FACE_COUNT_LIMIT,
    anim_min=0,
    anim_max=_ANIM_COUNT_LIMIT,
):
    """Render the similarity-threshold slider and build a record filter.

    With the default arguments every criterion is inactive, so the returned
    predicate accepts every record (this matches the previous hard-coded
    behaviour).

    Args:
        tag: Substring/element that must be contained in ``x['tags']``;
            blank (after stripping) disables the tag criterion.
        face_min: Inclusive lower bound for ``x['faces']``.
        face_max: Inclusive upper bound for ``x['faces']``.
        anim_min: Inclusive lower bound for ``x['anims']``.
        anim_max: Inclusive upper bound for ``x['anims']``.

    Returns:
        A ``(sim_th, filter_fn)`` tuple: the value chosen on the sidebar
        slider (min 0.05, max 0.5, initial 0.1) and a one-argument predicate
        over records exposing ``'anims'``, ``'faces'`` and ``'tags'`` keys.
    """
    sim_th = st.sidebar.slider("Similarity Threshold", 0.05, 0.5, 0.1, key='rsimth')

    # Each "*_n" flag is True when the corresponding criterion is a no-op,
    # which lets the predicate short-circuit without touching the record.
    tag_n = not bool(tag.strip())
    anim_n = not (anim_min > 0 or anim_max < _ANIM_COUNT_LIMIT)
    face_n = not (face_min > 0 or face_max < _FACE_COUNT_LIMIT)

    def filter_fn(x):
        """Return whether record ``x`` satisfies every active criterion."""
        return (
            (anim_n or anim_min <= x['anims'] <= anim_max)
            and (face_n or face_min <= x['faces'] <= face_max)
            and (tag_n or tag in x['tags'])
        )

    return sim_th, filter_fn
|
|
|
|
|
|
|
49 |
|
50 |
def retrieval_results(results):
|
51 |
st.caption("Click the link to view the 3D shape")
|
|
|
145 |
|
146 |
prog.progress(1.0, "Idle")
|
147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
149 |
+
def retrieval_pc(load_data, k, sim_th, filter_fn):
    """Embed a point cloud, show its top category matches, then run retrieval.

    Args:
        load_data: Callable invoked with the module-level progress bar; its
            result is used as the point-cloud array.
        k: Forwarded to ``retrieval.retrieve`` as the result count.
        sim_th: Forwarded to ``retrieval.retrieve`` as the similarity threshold.
        filter_fn: Forwarded to ``retrieval.retrieve`` as the record predicate.
    """
    points = load_data(prog)
    prog.progress(0.49, "Computing Embeddings")
    side_col = utils.render_pc(points)

    # Run the encoder on whatever device its weights currently live on.
    model_device = next(model_g14.parameters()).device
    # Pick six columns with columns 1 and 2 swapped, transpose, and prepend a
    # leading axis. NOTE(review): presumably xyz+rgb with y/z swapped — confirm.
    batch = torch.tensor(points[:, [0, 2, 1, 3, 4, 5]].T[None], device=model_device)
    enc = model_g14(batch).cpu()

    # Similarity between the normalized category features and the normalized
    # query embedding.
    category_feats = torch.nn.functional.normalize(lvis.feats, dim=-1)
    query = torch.nn.functional.normalize(enc, dim=-1).squeeze()
    scores = torch.matmul(category_feats, query)
    order = torch.argsort(scores, descending=True)

    ranked = OrderedDict()
    for idx in order:
        if idx < len(lvis.categories):
            ranked[lvis.categories[idx]] = scores[idx]

    with side_col:
        # Only the five best-scoring categories are displayed.
        for label, score in list(ranked.items())[:5]:
            st.text(label)
            st.caption("Similarity %.4f" % score)

    prog.progress(0.7, "Running Retrieval")
    hits = retrieval.retrieve(enc, k, sim_th, filter_fn)
    retrieval_results(hits)

    prog.progress(1.0, "Idle")
|
168 |
+
|
169 |
+
def retrieval_img(pic, k, sim_th, filter_fn):
    """Embed an uploaded image with the CLIP model and run retrieval.

    Args:
        pic: The value returned by the image uploader widget; ``None`` when
            nothing has been uploaded yet.
        k: Forwarded to ``retrieval.retrieve`` as the result count.
        sim_th: Forwarded to ``retrieval.retrieve`` as the similarity threshold.
        filter_fn: Forwarded to ``retrieval.retrieve`` as the record predicate.
    """
    if pic is None:
        # The uploader yields None until a file is chosen; opening it would
        # raise and surface a raw traceback, so report the problem directly.
        st.error("Please upload an image before submitting.")
        return

    img = Image.open(pic)
    prog.progress(0.49, "Computing Embeddings")
    st.image(img)

    device = clip_model.device
    tn = clip_prep(images=[img], return_tensors="pt").to(device)
    # Pixel values are cast to the module-level `half` dtype before encoding;
    # the resulting features are brought back to float32 on the CPU.
    pixel_values = tn['pixel_values'].type(half)
    enc = clip_model.get_image_features(pixel_values=pixel_values).float().cpu()

    prog.progress(0.7, "Running Retrieval")
    retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))

    prog.progress(1.0, "Idle")
|
181 |
+
|
182 |
+
def retrieval_text(text, k, sim_th, filter_fn):
    """Embed a text query with the CLIP model and run retrieval.

    Args:
        text: The query string to encode.
        k: Forwarded to ``retrieval.retrieve`` as the result count.
        sim_th: Forwarded to ``retrieval.retrieve`` as the similarity threshold.
        filter_fn: Forwarded to ``retrieval.retrieve`` as the record predicate.
    """
    prog.progress(0.49, "Computing Embeddings")

    target_device = clip_model.device
    # The query is truncated to at most 76 tokens by the processor.
    tokens = clip_prep(
        text=[text],
        return_tensors='pt',
        truncation=True,
        max_length=76,
    ).to(target_device)
    text_features = clip_model.get_text_features(**tokens)
    enc = text_features.float().cpu()

    prog.progress(0.7, "Running Retrieval")
    hits = retrieval.retrieve(enc, k, sim_th, filter_fn)
    retrieval_results(hits)

    prog.progress(1.0, "Idle")
|
192 |
|
193 |
# Top-level Streamlit script: loads the models, draws the sidebar, and
# dispatches to the handler for the selected task. `prog`, `half`,
# `clip_model`, `clip_prep` and `model_g14` are assigned here as module-level
# names because the retrieval_* helpers read them as globals.
try:
    f32 = numpy.float32
    # float16 when CUDA is available, bfloat16 otherwise.
    half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
    clip_model, clip_prep = load_openclip()
    model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')

    st.caption("This demo presents three tasks: 3D classification, cross-modal retrieval, and cross-modal generation. Examples are provided for demonstration purposes. You're encouraged to fine-tune task parameters and upload files for customized testing as required.")
    st.sidebar.title("TripletMix Demo Configuration Panel")
    task = st.sidebar.selectbox(
        'Task Selection',
        ("3D Classification", "Cross-modal retrieval", "Cross-modal generation")
    )

    if task == "3D Classification":
        cls_mode = st.sidebar.selectbox(
            'Choose the source of categories',
            ("LVIS Categories", "Custom Categories")
        )
        # NOTE(review): the key 'rtextinput' is also used by the "Input Text"
        # and generation "Input pc" widgets in the other task branches —
        # confirm that sharing it is intended.
        pc = st.sidebar.text_input("Input pc", key='rtextinput')
        if cls_mode == "LVIS Categories":
            if st.sidebar.button("submit"):
                st.title("Classification with LVIS Categories")
                prog = st.progress(0.0, "Idle")
                # NOTE(review): no classification routine is invoked here in
                # this revision.

        elif cls_mode == "Custom Categories":
            cats = st.sidebar.text_input("Custom Categories (64 max, separated with comma)")
            cats = [a.strip() for a in cats.split(',')]
            too_many_cats = len(cats) > 64
            if too_many_cats:
                st.error('Maximum 64 custom categories supported in the demo')
            # The button is always drawn, but a submission is only honoured
            # when the category list respects the advertised limit.
            submitted = st.sidebar.button("submit")
            if submitted and not too_many_cats:
                st.title("Classification with Custom Categories")
                prog = st.progress(0.0, "Idle")

    elif task == "Cross-modal retrieval":
        input_mode = st.sidebar.selectbox(
            'Choose an input modality',
            ("Point Cloud", "Image", "Text")
        )
        k = st.sidebar.slider("Number of items to retrieve", 1, 100, 16, key='rnum')
        sim_th, filter_fn = retrieval_filter_expand()
        if input_mode == "Point Cloud":
            load_data = utils.input_3d_shape('rpcinput')
            if st.sidebar.button("submit"):
                st.title("Retrieval with Point Cloud")
                prog = st.progress(0.0, "Idle")
                retrieval_pc(load_data, k, sim_th, filter_fn)
        elif input_mode == "Image":
            pic = st.sidebar.file_uploader("Upload an Image", key='rimageinput')
            if st.sidebar.button("submit"):
                st.title("Retrieval with Image")
                prog = st.progress(0.0, "Idle")
                retrieval_img(pic, k, sim_th, filter_fn)
        elif input_mode == "Text":
            text = st.sidebar.text_input("Input Text", key='rtextinput')
            if st.sidebar.button("submit"):
                st.title("Retrieval with Text")
                prog = st.progress(0.0, "Idle")
                retrieval_text(text, k, sim_th, filter_fn)

    elif task == "Cross-modal generation":
        generation_mode = st.sidebar.selectbox(
            'Choose the mode of generation',
            ("PointCloud-to-Image", "PointCloud-to-Text")
        )
        pc = st.sidebar.text_input("Input pc", key='rtextinput')
        if generation_mode == "PointCloud-to-Image":
            if st.sidebar.button("submit"):
                st.title("Image Generation")
                prog = st.progress(0.0, "Idle")
                # NOTE(review): no generation routine is invoked here in this
                # revision.

        elif generation_mode == "PointCloud-to-Text":
            if st.sidebar.button("submit"):
                st.title("Text Generation")
                prog = st.progress(0.0, "Idle")
                # NOTE(review): no captioning routine is invoked here in this
                # revision.

except Exception:
    # Top-level boundary: show the failure in the page instead of crashing.
    import traceback
    # Two trailing spaces before each newline form a Markdown hard line break,
    # so the traceback keeps one frame per line when rendered.
    st.error(traceback.format_exc().replace("\n", "  \n"))
|