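"""Flask backend for an interactive Point-SAM point cloud segmentation demo.

The server encodes an uploaded point cloud once, caches the embeddings per user id,
and then predicts segmentation masks from point prompts sent by the web frontend.
"""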
import argparse
import dataclasses
import os

import hydra
import numpy as np
import torch
from flask import Flask, jsonify, request, render_template
from flask_cors import CORS
from omegaconf import OmegaConf
from safetensors.torch import load_model
from scipy.spatial.transform import Rotation

from point_sam import build_point_sam

app = Flask(__name__, static_folder="static")
CORS(app)
# Per-upload state: each uploaded point cloud gets a slot identified by its user_id.
MAX_POINT_ID = 100
point_info_id = 0
point_info_list = [None for _ in range(MAX_POINT_ID)]
@dataclasses.dataclass
class AuxInputs:
    coords: torch.Tensor
    features: torch.Tensor
    centers: torch.Tensor
    interp_index: torch.Tensor = None
    interp_weight: torch.Tensor = None
def repeat_interleave(x: torch.Tensor, repeats: int, dim: int):
    """Repeat each slice of `x` along `dim` `repeats` consecutive times."""
    if repeats == 1:
        return x
    shape = list(x.shape)
    shape.insert(dim + 1, repeats)
    x = x.unsqueeze(dim + 1).expand(shape).flatten(dim, dim + 1)
    return x
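# A minimal illustration of repeat_interleave (hypothetical shapes):
#   repeat_interleave(torch.zeros(2, 5), repeats=3, dim=0)  # -> shape [6, 5]
# i.e. every slice along `dim` is duplicated 3 consecutive times, mirroring torch.repeat_interleave.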
class PointCloudProcessor:
    """Normalizes a point cloud to the unit sphere and its colors to [-1, 1]."""

    def __init__(self, device="cuda", batch=True, return_tensors="pt"):
        self.device = device
        self.batch = batch
        self.return_tensors = return_tensors
        self.center = None
        self.scale = None

    def __call__(self, xyz: np.ndarray, rgb: np.ndarray):
        # # The original data is z-up. Make it y-up.
        # rot = Rotation.from_euler("x", -90, degrees=True)
        # xyz = rot.apply(xyz)
        if self.center is None or self.scale is None:
            self.center = xyz.mean(0)
            self.scale = np.max(np.linalg.norm(xyz - self.center, axis=-1))
        xyz = (xyz - self.center) / self.scale
        rgb = ((rgb / 255.0) - 0.5) * 2
        if self.return_tensors == "np":
            coords = np.float32(xyz)
            feats = np.float32(rgb)
            if self.batch:
                coords = np.expand_dims(coords, 0)
                feats = np.expand_dims(feats, 0)
        elif self.return_tensors == "pt":
            coords = torch.tensor(xyz, dtype=torch.float32, device=self.device)
            feats = torch.tensor(rgb, dtype=torch.float32, device=self.device)
            if self.batch:
                coords = coords.unsqueeze(0)
                feats = feats.unsqueeze(0)
        else:
            raise ValueError(self.return_tensors)
        return coords, feats

    def normalize(self, xyz):
        return (xyz - self.center) / self.scale
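# A minimal usage sketch (assumed random inputs, CPU tensors for illustration):
#   proc = PointCloudProcessor(device="cpu")
#   coords, feats = proc(np.random.rand(1000, 3), np.random.randint(0, 256, (1000, 3)))
#   # coords: [1, 1000, 3] float32 inside the unit sphere; feats: [1, 1000, 3] in [-1, 1]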
class PointCloudSAMPredictor:
    input_xyz: np.ndarray
    input_rgb: np.ndarray
    prompt_coords: list[tuple[float, float, float]]
    prompt_labels: list[int]
    coords: torch.Tensor
    feats: torch.Tensor
    pc_embedding: torch.Tensor
    patches: dict[str, torch.Tensor]
    prompt_mask: torch.Tensor

    def __init__(self):
        print("Created model")
        model = build_point_sam("./model-2.safetensors")
        model.pc_encoder.patch_embed.grouper.num_groups = 1024
        model.pc_encoder.patch_embed.grouper.group_size = 128
        if torch.cuda.is_available():
            model = model.cuda()
        model.eval()
        self.model = model
        self.input_rgb = None
        self.input_xyz = None
        self.input_processor = None
        self.coords = None
        self.feats = None
        self.pc_embedding = None
        self.patches = None
        self.prompt_coords = None
        self.prompt_labels = None
        self.prompt_mask = None
        self.candidate_index = 0
    def set_pointcloud(self, xyz, rgb):
        self.input_xyz = xyz
        self.input_rgb = rgb
        self.input_processor = PointCloudProcessor()
        coords, feats = self.input_processor(xyz, rgb)
        self.coords = coords
        self.feats = feats
        pc_embedding, patches = self.model.pc_encoder(self.coords, self.feats)
        self.pc_embedding = pc_embedding
        self.patches = patches
        self.prompt_mask = None

    def set_prompts(self, prompt_coords, prompt_labels):
        self.prompt_coords = prompt_coords
        self.prompt_labels = prompt_labels
    def predict_mask(self):
        normalized_prompt_coords = self.input_processor.normalize(
            np.array(self.prompt_coords)
        )
        prompt_coords = torch.tensor(
            normalized_prompt_coords, dtype=torch.float32, device="cuda"
        )
        prompt_labels = torch.tensor(
            self.prompt_labels, dtype=torch.bool, device="cuda"
        )
        prompt_coords = prompt_coords.reshape(1, -1, 3)
        prompt_labels = prompt_labels.reshape(1, -1)
        multimask_output = prompt_coords.shape[1] == 1

        # Returns [B * M, num_outputs, num_points] logits and [B * M, num_outputs] IoU scores.
        def decode_masks(pc_embeddings, patches, prompt_coords, prompt_labels, prompt_masks, multimask_output):
            centers = patches["centers"]
            knn_idx = patches["knn_idx"]
            coords = patches["coords"]
            feats = patches["feats"]
            aux_inputs = AuxInputs(coords=coords, features=feats, centers=centers)
            pc_pe = self.model.point_encoder.pe_layer(centers)
            sparse_embeddings = self.model.point_encoder(prompt_coords, prompt_labels)
            dense_embeddings = self.model.mask_encoder(prompt_masks, coords, centers, knn_idx)
            dense_embeddings = repeat_interleave(
                dense_embeddings, sparse_embeddings.shape[0] // dense_embeddings.shape[0], 0
            )
            logits, iou_preds = self.model.mask_decoder(
                pc_embeddings,
                pc_pe,
                sparse_embeddings,
                dense_embeddings,
                aux_inputs=aux_inputs,
                multimask_output=multimask_output,
            )
            return logits, iou_preds

        logits, scores = decode_masks(
            self.pc_embedding,
            self.patches,
            prompt_coords,
            prompt_labels,
            self.prompt_mask[self.candidate_index].unsqueeze(0) if self.prompt_mask is not None else None,
            multimask_output,
        )
        logits = logits.squeeze(0)
        scores = scores.squeeze(0)
        # if multimask_output:
        #     index = scores.argmax(0).item()
        #     logit = logits[index]
        # else:
        #     logit = logits.squeeze(0)
        # self.prompt_mask = logit.unsqueeze(0)
        # pred_mask = logit > 0
        # return pred_mask.cpu().numpy()
        # Sort candidate masks by predicted IoU score (best first).
        _, indices = scores.sort(descending=True)
        logits = logits[indices]
        self.prompt_mask = logits  # [num_outputs, num_points]
        self.candidate_index = 0
        return (logits > 0).cpu().numpy()

    def set_candidate(self, index):
        self.candidate_index = index
predictor = PointCloudSAMPredictor()

# Route paths are assumed from the function names; adjust them if the frontend expects different endpoints.
@app.route("/")
def index():
    return app.send_static_file("index.html")

@app.route("/assets/<path:path>")
def assets_route(path):
    print(path)
    return app.send_static_file(f"assets/{path}")

@app.route("/hello")
def hello_world():
    return "Hello, World!"
@app.route("/set_pointcloud", methods=["POST"])
def set_pointcloud():
    request_data = request.get_json()
    # print(request_data)
    # print(type(request_data["points"]))
    # print(type(request_data["colors"]))
    xyz = request_data["points"]
    xyz = np.array(xyz).reshape(-1, 3)
    rgb = request_data["colors"]
    rgb = np.array(list(rgb)).reshape(-1, 3)
    predictor.set_pointcloud(xyz, rgb)
    # Cache the encoder outputs on CPU so several uploads can share the single predictor.
    pc_embedding = predictor.pc_embedding.cpu()
    patches = {
        "centers": predictor.patches["centers"].cpu(),
        "knn_idx": predictor.patches["knn_idx"].cpu(),
        "coords": predictor.coords.cpu(),
        "feats": predictor.feats.cpu(),
    }
    center = predictor.input_processor.center
    scale = predictor.input_processor.scale
    global point_info_id
    global point_info_list
    point_info_list[point_info_id] = {
        "pc_embedding": pc_embedding,
        "patches": patches,
        "center": center,
        "scale": scale,
        "prompt_mask": None,
    }
    return_msg = {"user_id": point_info_id}
    # Wrap the id so it never runs past the fixed number of slots.
    point_info_id = (point_info_id + 1) % MAX_POINT_ID
    return jsonify(return_msg)
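# Illustrative client call (assumes the "/set_pointcloud" route above and a locally running server):
#   import requests
#   payload = {"points": xyz.flatten().tolist(), "colors": rgb.flatten().tolist()}
#   user_id = requests.post("http://localhost:7860/set_pointcloud", json=payload).json()["user_id"]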
@app.route("/set_candidate", methods=["POST"])
def set_candidate():
    request_data = request.get_json()
    candidate_index = request_data["index"]
    predictor.set_candidate(candidate_index)
    return "success"
def visualize_pcd_with_prompts(xyz, rgb, prompt_coords, prompt_labels):
    """Debug helper: positive prompts are drawn as red spheres, negative prompts as green."""
    import trimesh

    pcd = trimesh.PointCloud(xyz, rgb)
    prompt_spheres = []
    for i, coord in enumerate(prompt_coords):
        sphere = trimesh.creation.icosphere()
        sphere.apply_scale(0.02)
        sphere.apply_translation(coord)
        sphere.visual.vertex_colors = [255, 0, 0] if prompt_labels[i] else [0, 255, 0]
        prompt_spheres.append(sphere)
    return trimesh.Scene([pcd] + prompt_spheres)
@app.route("/set_prompts", methods=["POST"])
def set_prompts():
    global point_info_list
    request_data = request.get_json()
    print(request_data.keys())
    # [n_prompts, 3]
    prompt_coords = request_data["prompt_coords"]
    # [n_prompts]. 0 for negative, 1 for positive
    prompt_labels = request_data["prompt_labels"]
    user_id = request_data["user_id"]
    print(user_id)
    # Restore the cached state for this user's point cloud onto the GPU.
    point_info = point_info_list[user_id]
    predictor.pc_embedding = point_info["pc_embedding"].cuda()
    patches = point_info["patches"]
    predictor.patches = {
        "centers": patches["centers"].cuda(),
        "knn_idx": patches["knn_idx"].cuda(),
        "coords": patches["coords"].cuda(),
        "feats": patches["feats"].cuda(),
    }
    predictor.input_processor.center = point_info["center"]
    predictor.input_processor.scale = point_info["scale"]
    if point_info["prompt_mask"] is not None:
        predictor.prompt_mask = point_info["prompt_mask"].cuda()
    else:
        predictor.prompt_mask = None
    # instance_id = request_data["instance_id"]  # int
    if len(prompt_coords) == 0:
        # No prompts: clear any previous mask and return an empty prediction.
        predictor.prompt_mask = None
        pred_mask = np.zeros([len(prompt_coords)], dtype=np.bool_)
        return jsonify({"mask": pred_mask.tolist()})
    predictor.set_prompts(prompt_coords, prompt_labels)
    pred_mask = predictor.predict_mask()
    point_info_list[user_id]["prompt_mask"] = predictor.prompt_mask.cpu()
    # # Visualize
    # xyz = predictor.coords.cpu().numpy()[0]
    # rgb = predictor.feats.cpu().numpy()[0] * 0.5 + 0.5
    # prompt_coords = predictor.input_processor.normalize(np.array(predictor.prompt_coords))
    # scene = visualize_pcd_with_prompts(xyz, rgb, prompt_coords, predictor.prompt_labels)
    # scene.show()
    return jsonify({"mask": pred_mask.tolist()})
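# Illustrative follow-up call (assumes the "/set_prompts" route above; coordinates are world-space clicks):
#   payload = {
#       "user_id": user_id,
#       "prompt_coords": [[0.1, 0.2, 0.3]],  # [n_prompts, 3]
#       "prompt_labels": [1],                # 1 = positive, 0 = negative
#   }
#   mask = requests.post("http://localhost:7860/set_prompts", json=payload).json()["mask"]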
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    args = parser.parse_args()
    app.run(host=args.host, port=args.port, debug=True)