# tripletmix-demo / app.py
import sys
import threading
import streamlit as st
import numpy
import torch
import openshape
import transformers
from PIL import Image
from huggingface_hub import HfFolder, snapshot_download
from demo_support import retrieval, generation, utils, lvis
from collections import OrderedDict
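# Load the CLIP ViT-bigG/14 model and processor once per session (cached by
# Streamlit). Relies on the global `half` dtype defined in the main block below,
# which is set before load_openclip() is first called.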
@st.cache_resource
def load_openclip():
sys.clip_move_lock = threading.Lock()
    clip_model = transformers.CLIPModel.from_pretrained(
        "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
        low_cpu_mem_usage=True, torch_dtype=half,
        offload_state_dict=True
    )
    clip_prep = transformers.CLIPProcessor.from_pretrained(
        "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
    )
if torch.cuda.is_available():
with sys.clip_move_lock:
clip_model.cuda()
return clip_model, clip_prep
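# Load and cache an OpenShape point-cloud encoder by checkpoint name.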
@st.cache_resource
def load_openshape(name, to_cpu=False):
pce = openshape.load_pc_encoder(name)
if to_cpu:
pce = pce.cpu()
return pce
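# Load a TripletMix point-cloud encoder by checkpoint name.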
@st.cache_resource
def load_tripletmix(name, to_cpu=False):
pce = openshape.load_pc_encoder_mix(name)
if to_cpu:
pce = pce.cpu()
return pce
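# Sidebar controls for retrieval. Only the similarity threshold is exposed;
# the tag / face-count / animation-count bounds are fixed to their full ranges,
# so the returned filter_fn currently accepts every entry.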
def retrieval_filter_expand():
sim_th = st.sidebar.slider("Similarity Threshold", 0.05, 0.5, 0.1, key='rsimth')
tag = ""
face_min = 0
face_max = 34985808
anim_min = 0
anim_max = 563
tag_n = not bool(tag.strip())
anim_n = not (anim_min > 0 or anim_max < 563)
face_n = not (face_min > 0 or face_max < 34985808)
filter_fn = lambda x: (
(anim_n or anim_min <= x['anims'] <= anim_max)
and (face_n or face_min <= x['faces'] <= face_max)
and (tag_n or tag in x['tags'])
)
return sim_th, filter_fn
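# Render retrieved shapes in rows of four thumbnails, each linking to its
# Objaverse explorer page.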
def retrieval_results(results):
st.caption("Click the link to view the 3D shape")
    # Ceil-divide so a partial final row of results is still rendered.
    for i in range((len(results) + 3) // 4):
cols = st.columns(4)
for j in range(4):
idx = i * 4 + j
if idx >= len(results):
continue
entry = results[idx]
with cols[j]:
ext_link = f"https://objaverse.allenai.org/explore/?query={entry['u']}"
st.image(entry['img'])
# st.markdown(f"[![thumbnail {entry['desc'].replace('\n', ' ')}]({entry['img']})]({ext_link})")
# st.text(entry['name'])
quote_name = entry['name'].replace('[', '\\[').replace(']', '\\]').replace('\n', ' ')
st.markdown(f"[{quote_name}]({ext_link})")
def classification_lvis(load_data):
pc = load_data(prog)
col2 = utils.render_pc(pc)
prog.progress(0.5, "Running Classification")
ref_dev = next(model_classification.parameters()).device
enc = model_classification(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
sim = torch.matmul(torch.nn.functional.normalize(lvis.feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
argsort = torch.argsort(sim, descending=True)
pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
with col2:
for i, (cat, sim) in zip(range(5), pred.items()):
st.text(cat)
st.caption("Similarity %.4f" % sim)
prog.progress(1.0, "Idle")
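# Zero-shot classification against user-supplied category names: the names are
# embedded with the CLIP text encoder and compared to the shape embedding.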
def classification_custom(load_data, cats):
pc = load_data(prog)
col2 = utils.render_pc(pc)
prog.progress(0.5, "Computing Category Embeddings")
device = clip_model.device
tn = clip_prep(text=cats, return_tensors='pt', truncation=True, max_length=76, padding=True).to(device)
feats = clip_model.get_text_features(**tn).float().cpu()
prog.progress(0.5, "Running Classification")
ref_dev = next(model_classification.parameters()).device
enc = model_classification(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
sim = torch.matmul(torch.nn.functional.normalize(feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
argsort = torch.argsort(sim, descending=True)
pred = OrderedDict((cats[i], sim[i]) for i in argsort if i < len(cats))
with col2:
for i, (cat, sim) in zip(range(5), pred.items()):
st.text(cat)
st.caption("Similarity %.4f" % sim)
prog.progress(1.0, "Idle")
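# Point-cloud query: encode the input cloud, show its top LVIS categories,
# then retrieve the most similar shapes from the database.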
def retrieval_pc(load_data, k, sim_th, filter_fn):
pc = load_data(prog)
prog.progress(0.5, "Computing Embeddings")
col2 = utils.render_pc(pc)
ref_dev = next(model_retrieval.parameters()).device
enc = model_retrieval(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev))
sim = torch.matmul(torch.nn.functional.normalize(lvis.feats, dim=-1), torch.nn.functional.normalize(enc.cpu(), dim=-1).squeeze())
argsort = torch.argsort(sim, descending=True)
pred = OrderedDict((lvis.categories[i], sim[i]) for i in argsort if i < len(lvis.categories))
with col2:
for i, (cat, sim) in zip(range(5), pred.items()):
st.text(cat)
st.caption("Similarity %.4f" % sim)
prog.progress(0.7, "Running Retrieval")
retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
prog.progress(1.0, "Idle")
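# Image query: embed the uploaded image with CLIP and retrieve similar shapes.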
def retrieval_img(pic, k, sim_th, filter_fn):
img = Image.open(pic)
prog.progress(0.5, "Computing Embeddings")
st.image(img)
device = clip_model.device
tn = clip_prep(images=[img], return_tensors="pt").to(device)
enc = clip_model.get_image_features(pixel_values=tn['pixel_values'].type(half)).float().cpu()
prog.progress(0.7, "Running Retrieval")
retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
prog.progress(1.0, "Idle")
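# Text query: embed the query text with CLIP and retrieve similar shapes.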
def retrieval_text(text, k, sim_th, filter_fn):
prog.progress(0.5, "Computing Embeddings")
device = clip_model.device
tn = clip_prep(text=[text], return_tensors='pt', truncation=True, max_length=76).to(device)
enc = clip_model.get_text_features(**tn).float().cpu()
prog.progress(0.7, "Running Retrieval")
retrieval_results(retrieval.retrieve(enc, k, sim_th, filter_fn))
prog.progress(1.0, "Idle")
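# Point-cloud-to-image generation. CLIP is temporarily moved to the CPU to free
# GPU memory for the diffusion step, then moved back afterwards.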
def generation_img(load_data, prompt, noise_scale, cfg_scale, steps):
pc = load_data(prog)
prog.progress(0.5, "Running Generation")
col2 = utils.render_pc(pc)
if torch.cuda.is_available():
with sys.clip_move_lock:
clip_model.cpu()
width = 640
height = 640
img = generation.pc_to_image(
model_g14, pc, prompt, noise_scale, width, height, cfg_scale, steps,
lambda i, t, _: prog.progress(0.49 + i / (steps + 1) / 2, "Running Diffusion Step %d" % i)
)
if torch.cuda.is_available():
with sys.clip_move_lock:
clip_model.cuda()
with col2:
st.image(img)
prog.progress(1.0, "Idle")
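# Point-cloud-to-text generation: produce a caption for the shape.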
def generation_text(load_data, cond_scale):
pc = load_data(prog)
prog.progress(0.5, "Running Generation")
col2 = utils.render_pc(pc)
cap = generation.pc_to_text(model_g14, pc, cond_scale)
st.text(cap)
prog.progress(1.0, "Idle")
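# Main Streamlit page. Everything runs inside a try block so that any exception
# is rendered on the page instead of crashing the app.
# Run locally with: streamlit run app.py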
try:
f32 = numpy.float32
half = torch.float16 if torch.cuda.is_available() else torch.bfloat16
clip_model, clip_prep = load_openclip()
#model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')
#model_g14 = load_tripletmix('tripletmix-spconv-all')
st.caption("This demo presents three tasks: 3D classification, cross-modal retrieval, and cross-modal generation. Examples are provided for demonstration purposes. You're encouraged to fine-tune task parameters and upload files for customized testing as required.")
st.sidebar.title("TripletMix Demo Configuration Panel")
task = st.sidebar.selectbox(
'Task Selection',
("3D Classification", "Cross-modal retrieval", "Cross-modal generation")
)
if task == "3D Classification":
cls_mode = st.sidebar.selectbox(
'Choose the source of categories',
("LVIS Categories", "Custom Categories")
)
model_name = st.sidebar.selectbox(
'Model Selection',
("pb-Mix", "pb")
)
if model_name == "pb-Mix":
model_classification = load_tripletmix('tripletmix-pointbert-all-modelnet40')
elif model_name == "pb":
model_classification = load_openshape('openshape-pointbert-vitg14-rgb')
load_data = utils.input_3d_shape('rpcinput')
if cls_mode == "LVIS Categories":
st.title("Classification with LVIS Categories")
prog = st.progress(0.0, "Idle")
if st.sidebar.button("submit"):
classification_lvis(load_data)
elif cls_mode == "Custom Categories":
st.title("Classification with Custom Categories")
prog = st.progress(0.0, "Idle")
            cats = st.sidebar.text_input("Custom Categories (64 max, separated by commas)")
            # Drop empty entries produced by stray commas.
            cats = [a.strip() for a in cats.split(',') if a.strip()]
            if len(cats) > 64:
                st.error('Maximum 64 custom categories supported in the demo')
            elif st.sidebar.button("submit"):
                classification_custom(load_data, cats)
elif task == "Cross-modal retrieval":
#model_retrieval = load_tripletmix('tripletmix-pointbert-all-objaverse')
model_name = st.sidebar.selectbox(
'Model Selection',
("pb-Mix", "pb")
)
if model_name == "pb-Mix":
model_retrieval = load_tripletmix('tripletmix-pointbert-all-objaverse')
elif model_name == "pb":
model_retrieval = load_openshape('openshape-pointbert-vitg14-rgb')
input_mode = st.sidebar.selectbox(
'Choose an input modality',
("Point Cloud", "Image", "Text")
)
k = st.sidebar.slider("Number of items to retrieve", 1, 100, 16, key='rnum')
sim_th, filter_fn = retrieval_filter_expand()
if input_mode == "Point Cloud":
st.title("Retrieval with Point Cloud")
prog = st.progress(0.0, "Idle")
load_data = utils.input_3d_shape('rpcinput')
if st.sidebar.button("submit"):
retrieval_pc(load_data, k, sim_th, filter_fn)
elif input_mode == "Image":
st.title("Retrieval with Image")
prog = st.progress(0.0, "Idle")
            pic = st.sidebar.file_uploader("Upload an Image", key='rimageinput')
            if st.sidebar.button("submit"):
                if pic is None:
                    st.error("Please upload an image first.")
                else:
                    retrieval_img(pic, k, sim_th, filter_fn)
elif input_mode == "Text":
st.title("Retrieval with Text")
prog = st.progress(0.0, "Idle")
text = st.sidebar.text_input("Input Text", key='rtextinput')
if st.sidebar.button("submit"):
retrieval_text(text, k, sim_th, filter_fn)
elif task == "Cross-modal generation":
generation_mode = st.sidebar.selectbox(
'Choose the mode of generation',
("PointCloud-to-Image", "PointCloud-to-Text")
)
load_data = utils.input_3d_shape('rpcinput')
if generation_mode == "PointCloud-to-Image":
st.title("Image Generation")
prog = st.progress(0.0, "Idle")
prompt = st.sidebar.text_input("Prompt (Optional)", key='gprompt')
noise_scale = st.sidebar.slider('Variation Level', 0, 5, 1)
cfg_scale = st.sidebar.slider('Guidance Scale', 0.0, 30.0, 10.0)
steps = st.sidebar.slider('Diffusion Steps', 8, 50, 25)
if st.sidebar.button("submit"):
generation_img(load_data, prompt, noise_scale, cfg_scale, steps)
elif generation_mode == "PointCloud-to-Text":
st.title("Text Generation")
prog = st.progress(0.0, "Idle")
cond_scale = st.sidebar.slider('Conditioning Scale', 0.0, 4.0, 2.0, 0.1, key='gcond')
if st.sidebar.button("submit"):
generation_text(load_data, cond_scale)
except Exception:
import traceback
    # Two spaces before the newline force a Markdown line break in st.error.
    st.error(traceback.format_exc().replace("\n", "  \n"))