import torch
import numpy as np
from torchvision import transforms

from fbrs.inference import clicker
from fbrs.inference.predictors import get_predictor
from fbrs.utils.vis import draw_with_blend_and_clicks


class InteractiveController:
    """Drive click-based segmentation of one image, object by object.

    The controller owns a ``clicker.Clicker`` and a predictor built by
    ``get_predictor``. Every click produces a new prediction; finished
    objects are written into an integer label mask.

    Internal state:
        states: one ``{'clicker': ..., 'predictor': ...}`` snapshot per
            history entry, taken *before* the corresponding click, so that
            ``undo_click`` can restore it.
        probs_history: ``(total, additive)`` pairs of probability tensors,
            kept the same length as ``states``. The current object
            probability is the element-wise maximum of the last pair.
        _result_mask: label mask; 0 where no finished object has been
            written, otherwise the 1-based index of the finished object.
    """

    def __init__(self, net, device, predictor_params, prob_thresh=0.5):
        """Move ``net`` to ``device`` and build the initial predictor.

        Args:
            net: model handed to ``get_predictor``; must support ``.to()``.
            device: device the model is moved to and the predictor runs on.
            predictor_params: keyword arguments forwarded to
                ``get_predictor``.
            prob_thresh: probabilities strictly above this value count as
                object pixels in ``finish_object``.
        """
        self.net = net.to(device)
        self.prob_thresh = prob_thresh
        self.clicker = clicker.Clicker()
        self.states = []
        self.probs_history = []
        self.object_count = 0
        self._result_mask = None

        self.image = None
        self.predictor = None
        self.device = device
        self.predictor_params = predictor_params
        self.reset_predictor()

    def set_image(self, image):
        """Start working on ``image`` and discard all previous results.

        The label mask takes its size from the last two dimensions of
        ``image``.
        """
        self.image = image
        # NOTE: labels are stored as uint8, so at most 255 distinct object
        # ids fit into the mask.
        self._result_mask = torch.zeros(image.shape[-2:], dtype=torch.uint8)
        self.object_count = 0
        self.reset_last_object()

    def add_click(self, x, y, is_positive):
        """Register a click at ``(x, y)`` and recompute the prediction.

        Args:
            x: horizontal click coordinate.
            y: vertical click coordinate.
            is_positive: whether the click is positive, forwarded to
                ``clicker.Click``.
        """
        # Snapshot the state before the click so undo_click can restore it.
        self.states.append({
            'clicker': self.clicker.get_state(),
            'predictor': self.predictor.get_states(),
        })

        # The clicker API takes coordinates in (row, column) order.
        click = clicker.Click(is_positive=is_positive, coords=(y, x))
        self.clicker.add_click(click)
        pred = self.predictor.get_prediction(self.clicker)
        # Hand cached, currently unused GPU memory back after inference.
        torch.cuda.empty_cache()

        if self.probs_history:
            # Keep the accumulated "total" part; only the additive part is
            # replaced by the fresh prediction.
            self.probs_history.append((self.probs_history[-1][0], pred))
        else:
            self.probs_history.append((torch.zeros_like(pred), pred))

    def undo_click(self):
        """Revert the most recent history entry; no-op if there is none."""
        if not self.states:
            return

        prev_state = self.states.pop()
        self.clicker.set_state(prev_state['clicker'])
        self.predictor.set_states(prev_state['predictor'])
        self.probs_history.pop()

    def partially_finish_object(self):
        """Freeze the current probability and restart clicking on top of it.

        The current object probability becomes the new "total" part and the
        additive part restarts from zeros. Clicks and the predictor are
        reset. Does nothing when there is no current object.
        """
        object_prob = self.current_object_prob
        if object_prob is None:
            return

        self.probs_history.append((object_prob, torch.zeros_like(object_prob)))
        # Duplicate the last snapshot so states and probs_history stay the
        # same length.
        self.states.append(self.states[-1])

        self.clicker.reset_clicks()
        self.reset_predictor()

    def finish_object(self):
        """Write the current object into the label mask under a new id.

        Pixels whose probability exceeds ``prob_thresh`` are labelled with
        the incremented ``object_count``; the per-object state is then
        cleared. Does nothing when there is no current object.
        """
        object_prob = self.current_object_prob
        if object_prob is None:
            return

        self.object_count += 1
        object_mask = object_prob > self.prob_thresh
        # set_image allocates _result_mask without a device argument, while
        # the probability comes from the predictor. Align the devices so the
        # boolean-mask assignment is valid even when they differ (this is a
        # no-op when they already match).
        object_mask = object_mask.to(self._result_mask.device)
        self._result_mask[object_mask] = self.object_count
        self.reset_last_object()

    def reset_last_object(self):
        """Drop all clicks, snapshots and probabilities of the current object."""
        self.states = []
        self.probs_history = []
        self.clicker.reset_clicks()
        self.reset_predictor()

    def reset_predictor(self, predictor_params=None):
        """Rebuild the predictor, optionally with new parameters.

        Args:
            predictor_params: if not ``None``, replaces the stored keyword
                arguments passed to ``get_predictor``.
        """
        if predictor_params is not None:
            self.predictor_params = predictor_params
        self.predictor = get_predictor(self.net, device=self.device,
                                       **self.predictor_params)
        # A fresh predictor has no image yet; re-attach the current one.
        if self.image is not None:
            self.predictor.set_input_image(self.image)

    @property
    def current_object_prob(self):
        """Element-wise max of the latest (total, additive) pair, or ``None``."""
        if self.probs_history:
            current_prob_total, current_prob_additive = self.probs_history[-1]
            return torch.maximum(current_prob_total, current_prob_additive)
        else:
            return None

    @property
    def is_incomplete_mask(self):
        """``True`` while the current object has at least one history entry."""
        return len(self.probs_history) > 0

    @property
    def result_mask(self):
        """Copy of the label mask, or ``None`` if no image has been set."""
        if self._result_mask is None:
            return None
        return self._result_mask.clone()