# Captured from a Hugging Face Space page (Space status at capture time: "Build error").
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor | |
import os | |
import torch | |
from modules.mask_utils import * | |
from modules.model_downloader import * | |
class SamInference:
    """Run Segment Anything (SAM) automatic mask generation.

    The SAM checkpoint is loaded lazily from ``model_path`` and moved to CUDA
    when ``torch.cuda.is_available()`` is true, otherwise to the CPU.  The mask
    generator is rebuilt whenever the tunable parameters change; the model
    itself is kept across rebuilds.
    """

    def __init__(self):
        # Loaded SAM model; stays None until the first set_mask_generator().
        self.model = None
        self.model_path = "models/sam_vit_h_4b8939.pth"
        # Checkpoint path that ``self.model`` was loaded from, so the model is
        # reloaded only when ``model_path`` is changed afterwards.
        self._loaded_model_path = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # SamAutomaticMaskGenerator built from ``tunable_params``.
        self.mask_generator = None
        # Tunable parameters, all default values
        self.tunable_params = {
            'points_per_side': 32,
            'pred_iou_thresh': 0.88,
            'stability_score_thresh': 0.95,
            'crop_n_layers': 0,
            'crop_n_points_downscale_factor': 1,
            'min_mask_region_area': 0
        }

    def set_mask_generator(self):
        """(Re)build ``self.mask_generator`` from ``self.tunable_params``.

        Loads the SAM model first when none is loaded yet or when
        ``model_path`` differs from the path the current model came from.  If
        the checkpoint file is missing, ``download_sam_model_url()`` is called
        beforehand (presumably it stores the file at ``model_path`` - confirm
        against modules.model_downloader).

        Any exception raised while loading or building propagates to the
        caller; in that case ``self.mask_generator`` is left as ``None``.
        """
        print("applying configs to model..")
        # Invalidate first: if anything below fails, no generator built from
        # outdated parameters must survive and be reused by a later call.
        self.mask_generator = None

        # Loading the checkpoint is by far the most expensive step and does
        # not depend on the tunable parameters, so do it only when required.
        if self.model is None or self._loaded_model_path != self.model_path:
            if not os.path.exists(self.model_path):
                print("No needed SAM model detected. downloading VIT H SAM model....")
                download_sam_model_url()
            model = sam_model_registry["default"](checkpoint=self.model_path)
            model.to(device=self.device)
            self.model = model
            self._loaded_model_path = self.model_path

        cfg = self.tunable_params
        self.mask_generator = SamAutomaticMaskGenerator(
            self.model,
            points_per_side=cfg['points_per_side'],
            pred_iou_thresh=cfg['pred_iou_thresh'],
            stability_score_thresh=cfg['stability_score_thresh'],
            crop_n_layers=cfg['crop_n_layers'],
            crop_n_points_downscale_factor=cfg['crop_n_points_downscale_factor'],
            min_mask_region_area=cfg['min_mask_region_area'],
            output_mode="coco_rle",
        )

    def generate_mask(self, image):
        """Return ``[masks]`` for ``image`` using the current parameters.

        The mask generator is built on first use, so this method can be called
        directly on a fresh instance.  The result of
        ``mask_generator.generate(image)`` is wrapped in a one-element list.
        """
        if self.mask_generator is None:
            self.set_mask_generator()
        return [self.mask_generator.generate(image)]

    def generate_mask_app(self, image, *params):
        """Generate masks for ``image`` with UI-supplied parameters.

        ``params`` must hold at least six values, in this order:
        points_per_side, pred_iou_thresh, stability_score_thresh,
        crop_n_layers, crop_n_points_downscale_factor, min_mask_region_area.
        They are converted to ``int``/``float`` before use; conversion errors
        and a too-short ``params`` raise to the caller.

        Returns ``[combined_image] + gallery`` as produced by
        ``create_mask_combined_images`` and ``create_mask_gallery``.  Any
        exception during model setup or mask generation is printed and
        ``None`` is returned instead (best-effort behaviour).
        """
        tunable_params = {
            'points_per_side': int(params[0]),
            'pred_iou_thresh': float(params[1]),
            'stability_score_thresh': float(params[2]),
            'crop_n_layers': int(params[3]),
            'crop_n_points_downscale_factor': int(params[4]),
            'min_mask_region_area': int(params[5]),
        }

        try:
            needs_rebuild = (
                self.model is None
                or self.mask_generator is None
                or self.tunable_params != tunable_params
            )
            if needs_rebuild:
                self.tunable_params = tunable_params
                # On failure this leaves mask_generator as None, so the next
                # call retries instead of reusing an outdated generator.
                self.set_mask_generator()

            masks = self.mask_generator.generate(image)
            combined_image = create_mask_combined_images(image, masks)
            gallery = create_mask_gallery(image, masks)
            return [combined_image] + gallery
        except Exception as e:
            print(e)
            return None