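# Gradio demo app for PAIR Diffusion (SDv1.5 fine-tuned on COCO): structure-
# and appearance-aware image editing, with OneFormer providing the panoptic
# and semantic segmentation that drives object selection.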
import cv2
import einops
import gradio as gr
import numpy as np
import torch
import random
import os
import json
import datetime

from huggingface_hub import hf_hub_download
from pytorch_lightning import seed_everything

from annotator.util import resize_image, HWC3
from annotator.OneFormer import OneformerSegmenter
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSamplerSpaCFG
from ldm.models.autoencoder import DiagonalGaussianDistribution
SEGMENT_MODEL_DICT = {
    'Oneformer': OneformerSegmenter,
}

MASK_MODEL_DICT = {
    'Oneformer': OneformerSegmenter,
}

urls = {
    'shi-labs/oneformer_coco_swin_large': ['150_16_swin_l_oneformer_coco_100ep.pth'],
    'PAIR/PAIR-diffusion-sdv15-coco-finetune': ['model_e91.ckpt'],
}

WTS_DICT = {}
os.makedirs('checkpoints', exist_ok=True)

for repo in urls:
    for file in urls[repo]:
        WTS_DICT[repo] = hf_hub_download(repo_id=repo, filename=file)
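# WTS_DICT maps each repo id to the local path of its downloaded checkpoint.
# NOTE: with several files per repo, only the last download would be kept;
# this works here because each repo lists exactly one checkpoint.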
# Main model
model = create_model('configs/pair_diff.yaml').cpu()
model.load_state_dict(load_state_dict(WTS_DICT['PAIR/PAIR-diffusion-sdv15-coco-finetune'], location='cuda'))
model = model.cuda()
ddim_sampler = DDIMSamplerSpaCFG(model)

save_dir = 'results/'
save_memory = False
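# ImageComp holds the state of one editing session: the input and reference
# images, their panoptic/semantic masks, the user's click points, and the
# selected edit operation. The Gradio callbacks below mutate this state, and
# `process` runs the actual PAIR Diffusion edit.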
class ImageComp:
    def __init__(self, edit_operation):
        self.input_img = None
        self.input_pmask = None
        self.input_segmask = None
        self.input_mask = None
        self.input_points = []
        self.input_scale = 1

        self.ref_img = None
        self.ref_pmask = None
        self.ref_segmask = None
        self.ref_mask = None
        self.ref_points = []
        self.ref_scale = 1

        self.multi_modal = False
        self.H = None
        self.W = None
        self.kernel = np.ones((5, 5), np.uint8)
        self.edit_operation = edit_operation

        self.init_segmentation_model()
        os.makedirs(save_dir, exist_ok=True)
        self.base_prompt = 'A picture of {}'
    def init_segmentation_model(self, mask_model='Oneformer', segment_model='Oneformer'):
        self.segment_model_name = segment_model
        self.mask_model_name = mask_model
        self.segment_model = SEGMENT_MODEL_DICT[segment_model](WTS_DICT['shi-labs/oneformer_coco_swin_large'])

        if mask_model == 'Oneformer' and segment_model == 'Oneformer':
            self.mask_model_inp = self.segment_model
            self.mask_model_ref = self.segment_model
        else:
            self.mask_model_inp = MASK_MODEL_DICT[mask_model]()
            self.mask_model_ref = MASK_MODEL_DICT[mask_model]()

        print(f"Segmentation models initialized with {mask_model} for masking and {segment_model} for segmentation")
    def init_input_canvas(self, img):
        img = HWC3(img)
        img = resize_image(img, 512)

        if self.segment_model_name == 'Oneformer':
            detected_seg = self.segment_model(img, 'semantic')
        elif self.segment_model_name == 'SAM':
            raise NotImplementedError

        if self.mask_model_name == 'Oneformer':
            detected_mask = self.mask_model_inp(img, 'panoptic')[0]
        elif self.mask_model_name == 'SAM':
            detected_mask = self.mask_model_inp(img)

        self.input_points = []
        self.input_img = img
        self.input_pmask = detected_mask
        self.input_segmask = detected_seg
        self.H = img.shape[0]
        self.W = img.shape[1]
        return img
    def init_ref_canvas(self, img):
        img = HWC3(img)
        img = resize_image(img, 512)

        if self.segment_model_name == 'Oneformer':
            detected_seg = self.segment_model(img, 'semantic')
        elif self.segment_model_name == 'SAM':
            raise NotImplementedError

        if self.mask_model_name == 'Oneformer':
            detected_mask = self.mask_model_ref(img, 'panoptic')[0]
        elif self.mask_model_name == 'SAM':
            detected_mask = self.mask_model_ref(img)

        self.ref_points = []
        print("Initialized ref", img.shape)
        self.ref_img = img
        self.ref_pmask = detected_mask
        self.ref_segmask = detected_seg
        return img
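    # Click handlers: each click on a canvas is recorded (scaled back to image
    # coordinates), the panoptic segment the clicked points fall on is picked
    # by majority vote, and the selection is previewed by dimming the rest of
    # the image to 20% brightness.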
    def select_input_object(self, evt: gr.SelectData):
        idx = list(np.array(evt.index) * self.input_scale)
        self.input_points.append(idx)

        if self.mask_model_name == 'Oneformer':
            mask = self._get_mask_from_panoptic(np.array(self.input_points), self.input_pmask)
        else:
            mask = self.mask_model_inp(self.input_img, self.input_points)

        c_ids = self.input_segmask[np.array(self.input_points)[:, 1], np.array(self.input_points)[:, 0]]
        unique_ids, counts = torch.unique(c_ids, return_counts=True)
        c_id = int(unique_ids[torch.argmax(counts)].cpu().detach().numpy())
        category = self.segment_model.metadata.stuff_classes[c_id]

        self.input_mask = mask
        mask = mask.cpu().numpy()
        output = mask[:, :, None] * self.input_img + (1 - mask[:, :, None]) * self.input_img * 0.2
        return output.astype(np.uint8), self.base_prompt.format(category)
    def select_ref_object(self, evt: gr.SelectData):
        idx = list(np.array(evt.index) * self.ref_scale)
        self.ref_points.append(idx)

        if self.mask_model_name == 'Oneformer':
            mask = self._get_mask_from_panoptic(np.array(self.ref_points), self.ref_pmask)
        else:
            mask = self.mask_model_ref(self.ref_img, self.ref_points)

        c_ids = self.ref_segmask[np.array(self.ref_points)[:, 1], np.array(self.ref_points)[:, 0]]
        unique_ids, counts = torch.unique(c_ids, return_counts=True)
        c_id = int(unique_ids[torch.argmax(counts)].cpu().detach().numpy())
        category = self.segment_model.metadata.stuff_classes[c_id]
        print("Category of reference object is:", category)

        self.ref_mask = mask
        mask = mask.cpu().numpy()
        output = mask[:, :, None] * self.ref_img + (1 - mask[:, :, None]) * self.ref_img * 0.2
        return output.astype(np.uint8)
    def clear_points(self):
        self.input_points = []
        self.ref_points = []
        zeros_inp = np.zeros(self.input_img.shape)
        zeros_ref = np.zeros(self.ref_img.shape)
        return zeros_inp, zeros_ref

    def return_input_img(self):
        return self.input_img
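    # Mask helpers. Panoptic and semantic ids are shifted by +1 so that 0 can
    # stand for "background" when a binary mask is multiplied in, and the most
    # frequent id across the clicked points decides which segment is selected.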
    def _get_mask_from_panoptic(self, points, panoptic_mask):
        panoptic_mask_ = panoptic_mask + 1
        ids = panoptic_mask_[points[:, 1], points[:, 0]]
        unique_ids, counts = torch.unique(ids, return_counts=True)
        mask_id = unique_ids[torch.argmax(counts)]
        final_mask = torch.zeros(panoptic_mask.shape).cuda()
        final_mask[panoptic_mask_ == mask_id] = 1
        return final_mask

    def _process_mask(self, mask, panoptic_mask, segmask):
        # counts[0] corresponds to the zeroed-out background, so the argmax is
        # taken over the remaining (foreground) classes only.
        obj_class = mask * (segmask + 1)
        unique_ids, counts = torch.unique(obj_class, return_counts=True)
        obj_class = unique_ids[torch.argmax(counts[1:]) + 1] - 1
        return mask, obj_class
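    # Editing strategy shared by the methods below: the edited region is
    # relabelled with a fresh panoptic id (ma + 1, one past the current
    # maximum), so that slot ma + 1 of the per-segment appearance matrix can
    # later be overwritten with the reference segment's appearance (slot 1).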
    def _edit_app(self, whole_ref):
        """
        Manipulates the panoptic mask of the input image to change appearance.
        """
        input_pmask = self.input_pmask
        input_segmask = self.input_segmask

        if whole_ref:
            reference_mask = torch.ones(self.ref_pmask.shape).cuda()
        else:
            reference_mask, _ = self._process_mask(self.ref_mask, self.ref_pmask, self.ref_segmask)

        edit_mask, _ = self._process_mask(self.input_mask, self.input_pmask, self.input_segmask)
        region_mask = edit_mask

        ma = torch.max(input_pmask)
        input_pmask[edit_mask == 1] = ma + 1
        return reference_mask, input_pmask, input_segmask, region_mask, ma
    def _add_object(self, input_mask, dilation_fac):
        """
        Manipulates the panoptic mask of the input image for adding objects.

        Args:
            input_mask (numpy array): Region where the new object needs to be added
            dilation_fac (float): Controls the edge-merging region when adding objects
        """
        input_pmask = self.input_pmask
        input_segmask = self.input_segmask
        reference_mask, obj_class = self._process_mask(self.ref_mask, self.ref_pmask, self.ref_segmask)

        # Dilate the drawn region so the new object blends into its surroundings.
        tmp = cv2.dilate(input_mask['mask'][:, :, 0], self.kernel, iterations=int(dilation_fac))
        region = torch.tensor(tmp)
        region_mask = torch.zeros_like(region).cuda()
        region_mask[region > 127] = 1

        mask_ = torch.tensor(input_mask['mask'][:, :, 0])
        edit_mask = torch.zeros_like(mask_).cuda()
        edit_mask[mask_ > 127] = 1

        ma = torch.max(input_pmask)
        input_pmask[edit_mask == 1] = ma + 1
        print(obj_class)
        input_segmask[edit_mask == 1] = obj_class.long()
        return reference_mask, input_pmask, input_segmask, region_mask, ma
    def _edit(self, input_mask, ref_mask, dilation_fac=1, whole_ref=False, inter=1):
        """
        Entry point for all appearance-editing and add-object operations. Manipulates
        the appearance vectors and the structure based on the user input.

        Args:
            input_mask (numpy array): Region of the input image that needs to be edited
            ref_mask (numpy array): Unused here; the reference region comes from self.ref_mask
            dilation_fac (float): Controls the edge-merging region when adding objects
            whole_ref (bool): Whether the complete reference image should be used
            inter (float): Interpolation factor between the reference appearance and the input appearance
        """
        input_img = (self.input_img / 127.5 - 1)
        input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0, 3, 1, 2)
        reference_img = (self.ref_img / 127.5 - 1)
        reference_img = torch.from_numpy(reference_img.astype(np.float32)).cuda().unsqueeze(0).permute(0, 3, 1, 2)

        if self.edit_operation == 'add_obj':
            reference_mask, input_pmask, input_segmask, region_mask, ma = self._add_object(input_mask, dilation_fac)
        elif self.edit_operation == 'edit_app':
            reference_mask, input_pmask, input_segmask, region_mask, ma = self._edit_app(whole_ref)
        # Concat features
        input_pmask = input_pmask.float().cuda().unsqueeze(0).unsqueeze(1)
        _, mean_feat_inpt_conc, one_hot_inpt_conc, _ = model.get_appearance(model.appearance_net_conc, model.app_layer_conc, input_img, input_pmask, return_all=True)

        reference_mask = reference_mask.float().cuda().unsqueeze(0).unsqueeze(1)
        _, mean_feat_ref_conc, _, _ = model.get_appearance(model.appearance_net_conc, model.app_layer_conc, reference_img, reference_mask, return_all=True)

        if isinstance(mean_feat_inpt_conc, list):
            appearance_conc = []
            for i in range(len(mean_feat_inpt_conc)):
                # Blend the edited region's appearance (slot ma + 1) toward the
                # reference segment's appearance (slot 1), then splat the
                # per-segment features back onto the pixel grid.
                mean_feat_inpt_conc[i][:, ma + 1] = (1 - inter) * mean_feat_inpt_conc[i][:, ma + 1] + inter * mean_feat_ref_conc[i][:, 1]
                splatted_feat_conc = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_conc[i], one_hot_inpt_conc)
                splatted_feat_conc = torch.nn.functional.normalize(splatted_feat_conc)
                splatted_feat_conc = torch.nn.functional.interpolate(splatted_feat_conc, (self.H // 8, self.W // 8))
                appearance_conc.append(splatted_feat_conc)
            appearance_conc = torch.cat(appearance_conc, dim=1)
        else:
            print("manipulating")
            mean_feat_inpt_conc[:, ma + 1] = (1 - inter) * mean_feat_inpt_conc[:, ma + 1] + inter * mean_feat_ref_conc[:, 1]
            splatted_feat_conc = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_conc, one_hot_inpt_conc)
            appearance_conc = torch.nn.functional.normalize(splatted_feat_conc)  # L2-normalize
            appearance_conc = torch.nn.functional.interpolate(appearance_conc, (self.H // 8, self.W // 8))
        # Cross-attention features
        _, mean_feat_inpt_ca, one_hot_inpt_ca, _ = model.get_appearance(model.appearance_net_ca, model.app_layer_ca, input_img, input_pmask, return_all=True)
        _, mean_feat_ref_ca, _, _ = model.get_appearance(model.appearance_net_ca, model.app_layer_ca, reference_img, reference_mask, return_all=True)

        if isinstance(mean_feat_inpt_ca, list):
            appearance_ca = []
            for i in range(len(mean_feat_inpt_ca)):
                mean_feat_inpt_ca[i][:, ma + 1] = (1 - inter) * mean_feat_inpt_ca[i][:, ma + 1] + inter * mean_feat_ref_ca[i][:, 1]
                splatted_feat_ca = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_ca[i], one_hot_inpt_ca)
                splatted_feat_ca = torch.nn.functional.normalize(splatted_feat_ca)
                splatted_feat_ca = torch.nn.functional.interpolate(splatted_feat_ca, (self.H // 8, self.W // 8))
                appearance_ca.append(splatted_feat_ca)
        else:
            print("manipulating")
            mean_feat_inpt_ca[:, ma + 1] = (1 - inter) * mean_feat_inpt_ca[:, ma + 1] + inter * mean_feat_ref_ca[:, 1]
            splatted_feat_ca = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_ca, one_hot_inpt_ca)
            appearance_ca = torch.nn.functional.normalize(splatted_feat_ca)  # L2-normalize
            appearance_ca = torch.nn.functional.interpolate(appearance_ca, (self.H // 8, self.W // 8))
        input_segmask = ((input_segmask + 1) / 127.5 - 1.0).cuda().unsqueeze(0).unsqueeze(1)
        structure = torch.nn.functional.interpolate(input_segmask, (self.H // 8, self.W // 8))
        return structure, appearance_conc, appearance_ca, region_mask, input_img
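    # Object variation: the drawn region is relabelled with fresh panoptic ids
    # (or merged into a single segment when the structure is ignored), so the
    # model resamples the appearance there instead of copying it from a
    # reference image.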
    def _edit_obj_var(self, input_mask, ignore_structure):
        """
        Manipulates the panoptic mask of the input image to generate object variations.

        Args:
            input_mask (numpy array): Region of the input image to vary
            ignore_structure (bool): Whether to treat the drawn region as a single segment
        """
        input_img = (self.input_img / 127.5 - 1)
        input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0, 3, 1, 2)
        input_pmask = self.input_pmask
        input_segmask = self.input_segmask
        ma = torch.max(input_pmask)

        mask_ = torch.tensor(input_mask['mask'][:, :, 0])
        edit_mask = torch.zeros_like(mask_).cuda()
        edit_mask[mask_ > 127] = 1

        tmp = edit_mask * (input_pmask + ma + 1)
        if ignore_structure:
            tmp = edit_mask
        input_pmask = tmp * edit_mask + (1 - edit_mask) * input_pmask
        input_pmask = input_pmask.float().cuda().unsqueeze(0).unsqueeze(1)

        mask_ca_feat = self.input_pmask.float().cuda().unsqueeze(0).unsqueeze(1) if ignore_structure else input_pmask
        print(torch.unique(mask_ca_feat))

        appearance_conc, _, _, _ = model.get_appearance(model.appearance_net_conc, model.app_layer_conc, input_img, input_pmask, return_all=True)
        appearance_ca = model.get_appearance(model.appearance_net_ca, model.app_layer_ca, input_img, mask_ca_feat)

        appearance_conc = torch.nn.functional.interpolate(appearance_conc, (self.H // 8, self.W // 8))
        appearance_ca = [torch.nn.functional.interpolate(ap, (self.H // 8, self.W // 8)) for ap in appearance_ca]

        input_segmask = ((input_segmask + 1) / 127.5 - 1.0).cuda().unsqueeze(0).unsqueeze(1)
        structure = torch.nn.functional.interpolate(input_segmask, (self.H // 8, self.W // 8))

        tmp = input_mask['mask'][:, :, 0]
        region = torch.tensor(tmp)
        mask = torch.zeros_like(region).cuda()
        mask[region > 127] = 1
        return structure, appearance_conc, appearance_ca, mask, input_img
    def get_caption(self, mask):
        """
        Generates a caption from a fixed template.

        Args:
            mask (numpy array): Region of the image the caption should describe
        """
        mask = mask['mask'][:, :, 0]
        region = torch.tensor(mask).cuda()
        mask = torch.zeros_like(region)
        mask[region > 127] = 1
        if torch.sum(mask) == 0:
            return ""

        c_ids = self.input_segmask * mask
        unique_ids, counts = torch.unique(c_ids, return_counts=True)
        c_id = int(unique_ids[torch.argmax(counts[1:]) + 1].cpu().detach().numpy())
        category = self.segment_model.metadata.stuff_classes[c_id]
        return self.base_prompt.format(category)
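    # Each run is saved under a timestamped directory in `results/` together
    # with a meta.json recording the prompts, seed, and guidance settings, so
    # an edit can be reproduced later.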
    def save_result(self, input_mask, prompt, a_prompt, n_prompt,
                    ddim_steps, scale_s, scale_f, scale_t, seed, dilation_fac=1, inter=1,
                    free_form_obj_var=False, ignore_structure=False):
        """
        Saves the current results with all their metadata.
        """
        meta_data = {}
        meta_data['prompt'] = prompt
        meta_data['a_prompt'] = a_prompt
        meta_data['n_prompt'] = n_prompt
        meta_data['seed'] = seed
        meta_data['ddim_steps'] = ddim_steps
        meta_data['scale_s'] = scale_s
        meta_data['scale_f'] = scale_f
        meta_data['scale_t'] = scale_t
        meta_data['inter'] = inter
        meta_data['dilation_fac'] = dilation_fac
        meta_data['edit_operation'] = self.edit_operation

        # A second-resolution timestamp is unique enough to serve as a run id here.
        uuid = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
        os.makedirs(f'{save_dir}/{uuid}')
        with open(f'{save_dir}/{uuid}/meta.json', "w") as outfile:
            json.dump(meta_data, outfile)

        cv2.imwrite(f'{save_dir}/{uuid}/input.png', self.input_img[:, :, ::-1])
        cv2.imwrite(f'{save_dir}/{uuid}/ref.png', self.ref_img[:, :, ::-1])
        if self.ref_mask is not None:
            cv2.imwrite(f'{save_dir}/{uuid}/ref_mask.png', self.ref_mask.cpu().squeeze().numpy() * 200)
        for i in range(len(self.results)):
            cv2.imwrite(f'{save_dir}/{uuid}/edit{i}.png', self.results[i][:, :, ::-1])

        if self.edit_operation == 'add_obj' or free_form_obj_var:
            cv2.imwrite(f'{save_dir}/{uuid}/input_mask.png', input_mask['mask'] * 200)
        else:
            cv2.imwrite(f'{save_dir}/{uuid}/input_mask.png', self.input_mask.cpu().squeeze().numpy() * 200)
        print("Saved results at", f'{save_dir}/{uuid}')
    def process(self, input_mask, ref_mask, prompt, a_prompt, n_prompt,
                num_samples, ddim_steps, guess_mode, strength,
                scale_s, scale_f, scale_t, seed, eta, dilation_fac=1, masking=True, whole_ref=False, inter=1,
                free_form_obj_var=False, ignore_structure=False):
        print(prompt)
        if free_form_obj_var:
            print("Free form")
            structure, appearance_conc, appearance_ca, mask, img = self._edit_obj_var(input_mask, ignore_structure)
        else:
            structure, appearance_conc, appearance_ca, mask, img = self._edit(input_mask, ref_mask, dilation_fac=dilation_fac,
                                                                              whole_ref=whole_ref, inter=inter)

        input_pmask = torch.nn.functional.interpolate(self.input_pmask.cuda().unsqueeze(0).unsqueeze(1).float(), (self.H // 8, self.W // 8))
        input_pmask = input_pmask.to(memory_format=torch.contiguous_format)
        # Assemble the three control tensors for each appearance level:
        # null (no structure/appearance), structure-only, and full control.
        if isinstance(appearance_ca, list):
            null_appearance_ca = [torch.zeros(a.shape).cuda() for a in appearance_ca]
            null_appearance_conc = torch.zeros(appearance_conc.shape).cuda()
            null_structure = torch.zeros(structure.shape).cuda() - 1

            null_control = [torch.cat([null_structure, napp, input_pmask * 0], dim=1) for napp in null_appearance_ca]
            structure_control = [torch.cat([structure, napp, input_pmask], dim=1) for napp in null_appearance_ca]
            full_control = [torch.cat([structure, napp, input_pmask], dim=1) for napp in appearance_ca]

            null_control.append(torch.cat([null_structure, null_appearance_conc, null_structure * 0], dim=1))
            structure_control.append(torch.cat([structure, null_appearance_conc, null_structure], dim=1))
            full_control.append(torch.cat([structure, appearance_conc, input_pmask], dim=1))

            null_control = [torch.cat([nc for _ in range(num_samples)], dim=0) for nc in null_control]
            structure_control = [torch.cat([sc for _ in range(num_samples)], dim=0) for sc in structure_control]
            full_control = [torch.cat([fc for _ in range(num_samples)], dim=0) for fc in full_control]
        # Masking for local edits: keep the latent of the unedited region fixed
        # and only denoise inside the edited region.
        if not masking:
            mask, x0 = None, None
        else:
            x0 = model.encode_first_stage(img)
            x0 = x0.sample() if isinstance(x0, DiagonalGaussianDistribution) else x0  # TODO: check if we can set random number
            x0 = x0 * model.scale_factor
            mask = 1 - torch.tensor(mask).unsqueeze(0).unsqueeze(1).cuda()
            mask = torch.nn.functional.interpolate(mask.float(), x0.shape[2:]).float()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        scale = [scale_s, scale_f, scale_t]
        print(scale)
        if save_memory:
            model.low_vram_shift(is_diffusing=False)

        uc_cross = model.get_learned_conditioning([n_prompt] * num_samples)
        c_cross = model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)

        cond = {"c_concat": [null_control], "c_crossattn": [c_cross]}
        un_cond = {"c_concat": None if guess_mode else [null_control], "c_crossattn": [uc_cross]}
        un_cond_struct = {"c_concat": None if guess_mode else [structure_control], "c_crossattn": [uc_cross]}
        un_cond_struct_app = {"c_concat": None if guess_mode else [full_control], "c_crossattn": [uc_cross]}
        shape = (4, self.H // 8, self.W // 8)

        if save_memory:
            model.low_vram_shift(is_diffusing=True)

        model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12 < 0.01 but 0.826**12 > 0.01
        samples, _ = ddim_sampler.sample(ddim_steps, num_samples,
                                         shape, cond, verbose=False, eta=eta,
                                         unconditional_guidance_scale=scale, mask=mask, x0=x0,
                                         unconditional_conditioning=[un_cond, un_cond_struct, un_cond_struct_app])

        if save_memory:
            model.low_vram_shift(is_diffusing=False)

        x_samples = (model.decode_first_stage(samples) + 1) * 127.5
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c')).cpu().numpy().clip(0, 255).astype(np.uint8)

        results = [x_samples[i] for i in range(num_samples)]
        self.results = results
        return results