OriLib commited on
Commit
5ed1996
1 Parent(s): 4cb2cd7

Delete Inference.py

Browse files
Files changed (1) hide show
  1. Inference.py +0 -53
Inference.py DELETED
@@ -1,53 +0,0 @@
1
- import os
2
- import numpy as np
3
- from skimage import io
4
- from glob import glob
5
- from tqdm import tqdm
6
- import cv2
7
- import torch
8
- import torch.nn.functional as F
9
- from torchvision.transforms.functional import normalize
10
- from models import ISNetDIS
11
-
12
-
13
- if __name__ == "__main__":
14
- dataset_path="input_images" #Your dataset path
15
- model_path="model.pth"
16
- result_path="output_results" #The folder path that you want to save the results
17
-
18
- if not os.path.exists(result_path):
19
- os.makedirs(result_path)
20
-
21
- input_size=[1024,1024]
22
- net=ISNetDIS()
23
-
24
- if torch.cuda.is_available():
25
- net.load_state_dict(torch.load(model_path))
26
- net=net.cuda()
27
- else:
28
- net.load_state_dict(torch.load(model_path,map_location="cpu"))
29
- net.eval()
30
-
31
- im_list = glob(dataset_path+"/*.jpg")+glob(dataset_path+"/*.JPG")+glob(dataset_path+"/*.jpeg")+glob(dataset_path+"/*.JPEG")+glob(dataset_path+"/*.png")+glob(dataset_path+"/*.PNG")+glob(dataset_path+"/*.bmp")+glob(dataset_path+"/*.BMP")+glob(dataset_path+"/*.tiff")+glob(dataset_path+"/*.TIFF")
32
- with torch.no_grad():
33
- for i, im_path in tqdm(enumerate(im_list), total=len(im_list)):
34
- print("im_path: ", im_path)
35
- im = io.imread(im_path)
36
- if len(im.shape) < 3:
37
- im = im[:, :, np.newaxis]
38
- im_shp=im.shape[0:2]
39
- im_tensor = torch.tensor(im, dtype=torch.float32).permute(2,0,1)
40
- im_tensor = F.upsample(torch.unsqueeze(im_tensor,0), input_size, mode="bilinear").type(torch.uint8)
41
- image = torch.divide(im_tensor,255.0)
42
- image = normalize(image,[0.5,0.5,0.5],[1.0,1.0,1.0])
43
-
44
- if torch.cuda.is_available():
45
- image=image.cuda()
46
-
47
- result=net(image)
48
- result=torch.squeeze(F.upsample(result[0][0],im_shp,mode='bilinear'),0)
49
- ma = torch.max(result)
50
- mi = torch.min(result)
51
- result = (result-mi)/(ma-mi)
52
- im_name=im_path.split('/')[-1].split('.')[0]
53
- cv2.imwrite(os.path.join(result_path,im_name+".png"),(result*255).permute(1,2,0).cpu().data.numpy().astype(np.uint8))