Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import argparse | |
import numpy as np | |
import torch.nn.functional as F | |
import glob | |
import cv2 | |
from tqdm import tqdm | |
import time | |
import os | |
from model.deep_lab_model.deeplab import * | |
from MBD import mask_base_dewarper | |
import time | |
from utils import cvimg2torch,torch2cvimg | |
def _sibling_path(img_path, tag):
    """Return *img_path* with the last '_origin' in its FILE NAME replaced by *tag*.

    Only the base name is rewritten, so a directory whose name happens to contain
    '_origin' is left intact. If the file name has no '_origin', *img_path* is
    returned unchanged (the caller then sees an existing path and skips it).
    """
    folder, name = os.path.split(img_path)
    head, sep, tail = name.rpartition('_origin')
    if not sep:
        return img_path
    return os.path.join(folder, head + tag + tail)


def _predict_mask(seg_model, img_org, size=(448, 448)):
    """Run the segmentation net on a BGR image and return a 0/255 uint8 mask.

    Args:
        seg_model: segmentation network; channel 0 of its output is used as the mask.
        img_org: BGR image as returned by cv2.imread, shape (H, W, 3).
        size: (width, height) the image is resized to before inference
            (cv2.resize argument order).

    Returns:
        uint8 array of shape (H, W) containing only 0 and 255.
    """
    h_org, w_org = img_org.shape[:2]
    img = cv2.resize(img_org, size)
    img = cv2.GaussianBlur(img, (15, 15), 0, 0)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Feed the input to whatever device the model's weights live on instead of
    # assuming CUDA.
    device = next(seg_model.parameters()).device
    with torch.no_grad():
        # cvimg2torch (project helper) is assumed to return a 1xCxHxW tensor.
        pred = seg_model(cvimg2torch(img).to(device))
    mask_pred = pred[:, 0, :, :].unsqueeze(1)
    # Back to the original resolution (F.interpolate defaults to nearest).
    mask_pred = F.interpolate(mask_pred, (h_org, w_org))
    mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
    # pred is presumably a probability in [0, 1]; clip so out-of-range values
    # saturate instead of wrapping around in the uint8 cast.
    mask_pred = (np.clip(mask_pred, 0, 1) * 255).astype(np.uint8)
    # Dilate then erode with the same kernel (a morphological closing) to fill
    # small holes and gaps in the mask.
    kernel = np.ones((3, 3))
    mask_pred = cv2.dilate(mask_pred, kernel, iterations=3)
    mask_pred = cv2.erode(mask_pred, kernel, iterations=3)
    # Binarize. The original two-step threshold left pixels equal to 100 at 100.
    return np.where(mask_pred > 100, 255, 0).astype(np.uint8)


def _identity_dewarp(img_org):
    """Fallback when mask_base_dewarper fails: resample *img_org* through a
    near-identity grid.

    Returns:
        (dewarp, grid) where grid is a float32 array of shape (H, W, 2) in
        grid_sample's normalized [-1, 1] coordinates.
    """
    h_org, w_org = img_org.shape[:2]
    grid = np.meshgrid(np.arange(w_org), np.arange(h_org)) / np.array([w_org, h_org]).reshape(2, 1, 1)
    grid = torch.from_numpy((grid - 0.5) * 2).float().unsqueeze(0).permute(0, 2, 3, 1)
    # NOTE(review): with align_corners=False, x_norm = 2i/W - 1 samples input
    # pixel i - 0.5. The result is therefore a half-pixel-shifted (slightly
    # blurred) copy, not an exact identity. It is kept as-is so the saved grid
    # keeps its original convention.
    dewarp = torch2cvimg(F.grid_sample(cvimg2torch(img_org), grid, align_corners=False))[0]
    return dewarp, grid[0].numpy()


def net1_net2_infer(model, img_paths, args):
    """Segment and dewarp each '*_origin.*' image, writing the results next to it.

    For every input path ``<name>_origin.<ext>`` this writes:
      * ``<name>_capture.<ext>``  - the dewarped image,
      * ``<name>_mask_new.<ext>`` - the 0/255 segmentation mask,
      * ``<name>_grid1.<ext>.npy`` - the sampling grid resized to 128x128x2
        (np.save appends '.npy').
    Images whose ``_capture`` output already exists are skipped, so the loop
    can be resumed. Unreadable images are reported and skipped.

    Args:
        model: segmentation network; put into eval mode here.
        img_paths: iterable of image paths (the script globs '*_origin.*').
        args: namespace; ``img_rows`` / ``img_cols`` (if present and set) give
            the network input height / width, defaulting to 448.
    """
    seg_model = model
    seg_model.eval()
    size = (getattr(args, 'img_cols', None) or 448,
            getattr(args, 'img_rows', None) or 448)
    for img_path in tqdm(img_paths):
        capture_path = _sibling_path(img_path, '_capture')
        if os.path.exists(capture_path):
            continue  # already processed
        img_org = cv2.imread(img_path)
        if img_org is None:
            print(f'skip unreadable image: {img_path}')
            continue
        mask_pred = _predict_mask(seg_model, img_org, size)
        # TPS-style dewarping driven by the mask. This is best effort: on any
        # failure, fall back to a (near) identity warp so the batch continues.
        try:
            dewarp, grid = mask_base_dewarper(img_org, mask_pred)
        except Exception as err:
            print(f'fail: {img_path}: {err!r}')
            dewarp, grid = _identity_dewarp(img_org)
        if not cv2.imwrite(capture_path, dewarp):
            print(f'could not write {capture_path}')
        cv2.imwrite(_sibling_path(img_path, '_mask_new'), mask_pred)
        # Store a compact 128x128 version of the x/y grid channels.
        grid = np.stack([cv2.resize(grid[:, :, c], (128, 128)) for c in range(2)], axis=-1)
        np.save(_sibling_path(img_path, '_grid1'), grid)
def net1_net2_infer_single_im(img, model_path):
    """Predict a binary page mask for a single BGR image.

    The DeepLab segmentation model is rebuilt and its checkpoint reloaded from
    *model_path* on every call.

    Args:
        img: BGR image (e.g. from cv2.imread), shape (H, W, 3).
        model_path: checkpoint file holding the weights under 'model_state'.

    Returns:
        uint8 array of shape (H, W) containing only 0 and 255.

    Raises:
        ValueError: if *img* is None (e.g. a failed cv2.imread).
    """
    if img is None:
        raise ValueError('net1_net2_infer_single_im: img is None')

    seg_model = DeepLab(num_classes=1,
                        backbone='resnet',
                        output_stride=16,
                        sync_bn=None,
                        freeze_bn=False)
    seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
    seg_model.cuda()
    # Load tensors on CPU; load_state_dict copies them onto the model's device.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location='cpu')
    seg_model.load_state_dict(checkpoint['model_state'])
    seg_model.eval()

    h_org, w_org = img.shape[:2]
    net_in = cv2.resize(img, (448, 448))
    net_in = cv2.GaussianBlur(net_in, (15, 15), 0, 0)
    net_in = cv2.cvtColor(net_in, cv2.COLOR_BGR2RGB)
    with torch.no_grad():
        # Author's torchtoolbox summary for a 1x3x448x448 input: 59.4M params, 135.6G ops.
        pred = seg_model(cvimg2torch(net_in).cuda())
    mask_pred = pred[:, 0, :, :].unsqueeze(1)
    # Back to the original resolution (F.interpolate defaults to nearest).
    mask_pred = F.interpolate(mask_pred, (h_org, w_org))
    mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
    # pred is presumably a probability in [0, 1]; clip so out-of-range values
    # saturate instead of wrapping around in the uint8 cast.
    mask_pred = (np.clip(mask_pred, 0, 1) * 255).astype(np.uint8)
    # Dilate then erode with the same kernel (a morphological closing).
    kernel = np.ones((3, 3))
    mask_pred = cv2.dilate(mask_pred, kernel, iterations=3)
    mask_pred = cv2.erode(mask_pred, kernel, iterations=3)
    # Binarize. The original two-step threshold left pixels equal to 100 at 100.
    return np.where(mask_pred > 100, 255, 0).astype(np.uint8)
if __name__ == '__main__':
    # Batch mode: segment + dewarp every '*_origin.*' image in --img_folder.
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--img_folder', nargs='?', type=str, default='./all_data', help='Data path to load data')
    parser.add_argument('--img_rows', nargs='?', type=int, default=448,
                        help='Height of the input image')
    parser.add_argument('--img_cols', nargs='?', type=int, default=448,
                        help='Width of the input image')
    parser.add_argument('--seg_model_path', nargs='?', type=str, default='checkpoints/mbd.pkl',
                        help='Path to previous saved model to restart from')
    args = parser.parse_args()

    seg_model = DeepLab(num_classes=1,
                        backbone='resnet',
                        output_stride=16,
                        sync_bn=None,
                        freeze_bn=False)
    seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
    seg_model.cuda()
    # Load tensors on CPU; load_state_dict copies them onto the model's device.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(args.seg_model_path, map_location='cpu')
    seg_model.load_state_dict(checkpoint['model_state'])

    # glob order depends on the filesystem; sort for a reproducible run order.
    im_paths = sorted(glob.glob(os.path.join(args.img_folder, '*_origin.*')))
    if not im_paths:
        print(f'no *_origin.* images found in {args.img_folder}')
    net1_net2_infer(seg_model, im_paths, args)