jhj0517
rollback pytoshop because it doesn't work in huggingface
607e627
raw
history blame
2.7 kB
import os
import traceback

import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

from modules.mask_utils import *
from modules.model_downloader import *
class SamInference:
    """Run Segment Anything (SAM) automatic mask generation.

    The checkpoint at ``self.model_path`` is downloaded if missing, loaded once
    onto ``self.device`` and reused. Only the ``SamAutomaticMaskGenerator`` is
    rebuilt when ``self.tunable_params`` change.
    """

    def __init__(self):
        self.model = None
        self.model_path = "models/sam_vit_h_4b8939.pth"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.mask_generator = None
        # Tunable parameters forwarded to SamAutomaticMaskGenerator.
        # All values start at the generator's default values.
        self.tunable_params = {
            'points_per_side': 32,
            'pred_iou_thresh': 0.88,
            'stability_score_thresh': 0.95,
            'crop_n_layers': 0,
            'crop_n_points_downscale_factor': 1,
            'min_mask_region_area': 0
        }

    def _ensure_model_loaded(self):
        """Load the SAM checkpoint onto ``self.device`` unless a model is already loaded.

        Downloads the checkpoint first when ``self.model_path`` does not exist.
        Set ``self.model = None`` to force a reload.
        """
        if self.model is not None:
            # The weights do not depend on the tunable params, so reloading the
            # (large) checkpoint on every config change is unnecessary.
            return
        if not os.path.exists(self.model_path):
            print("No needed SAM model detected. downloading VIT H SAM model....")
            download_sam_model_url()
        model = sam_model_registry["default"](checkpoint=self.model_path)
        model.to(device=self.device)
        # Publish only after a successful load/move so a failure leaves model None.
        self.model = model

    def set_mask_generator(self):
        """(Re)build ``self.mask_generator`` from ``self.tunable_params``.

        Loads the model first if needed. Masks are requested in ``"coco_rle"``
        output mode.
        """
        print("applying configs to model..")
        # Drop the previous generator up front: if the rebuild below raises,
        # callers see None and retry instead of silently reusing a generator
        # built with outdated parameters.
        self.mask_generator = None
        self._ensure_model_loaded()
        self.mask_generator = SamAutomaticMaskGenerator(
            self.model,
            points_per_side=self.tunable_params['points_per_side'],
            pred_iou_thresh=self.tunable_params['pred_iou_thresh'],
            stability_score_thresh=self.tunable_params['stability_score_thresh'],
            crop_n_layers=self.tunable_params['crop_n_layers'],
            crop_n_points_downscale_factor=self.tunable_params['crop_n_points_downscale_factor'],
            min_mask_region_area=self.tunable_params['min_mask_region_area'],
            output_mode="coco_rle",
        )

    def generate_mask(self, image):
        """Return ``[masks]``, the generator's output for *image* wrapped in a list.

        *image* is passed straight to ``SamAutomaticMaskGenerator.generate``
        (per segment_anything, an HWC uint8 array). The generator is built
        lazily with the current ``tunable_params`` if it does not exist yet.
        """
        if self.mask_generator is None:
            self.set_mask_generator()
        return [self.mask_generator.generate(image)]

    @staticmethod
    def _parse_app_params(params):
        """Convert the six positional UI values into a tunable-params dict.

        Raises IndexError if fewer than six values are given, and
        ValueError/TypeError if a value cannot be converted.
        """
        return {
            'points_per_side': int(params[0]),
            'pred_iou_thresh': float(params[1]),
            'stability_score_thresh': float(params[2]),
            'crop_n_layers': int(params[3]),
            'crop_n_points_downscale_factor': int(params[4]),
            'min_mask_region_area': int(params[5]),
        }

    def generate_mask_app(self, image, *params):
        """Generate masks for *image* and build UI outputs.

        *params* are, in order: points_per_side, pred_iou_thresh,
        stability_score_thresh, crop_n_layers, crop_n_points_downscale_factor,
        min_mask_region_area. The generator is rebuilt only when they differ
        from the current ``tunable_params``.

        Returns ``[combined_image] + gallery`` on success, or None if mask
        generation fails (the traceback is printed). Parameter conversion
        errors propagate to the caller.
        """
        tunable_params = self._parse_app_params(params)
        try:
            if (self.model is None or self.mask_generator is None
                    or self.tunable_params != tunable_params):
                self.tunable_params = tunable_params
                self.set_mask_generator()
            masks = self.mask_generator.generate(image)
            combined_image = create_mask_combined_images(image, masks)
            gallery = create_mask_gallery(image, masks)
            return [combined_image] + gallery
        except Exception:
            # Best-effort for the UI: keep returning None rather than raising,
            # but report the full traceback instead of only the message.
            traceback.print_exc()
            return None