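"""Text2Clip plugin for ChainImgProcessor.

Runs CLIPSeg (CLIPDensePredT) on the original frame to segment regions that
match a text prompt, then blends those regions of the original frame back
over the processed frame, so only the prompted areas keep the source pixels.
"""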
import os
import cv2
import numpy as np
import torch
import threading
from chain_img_processor import ChainImgProcessor, ChainImgPlugin
from torchvision import transforms
from clip.clipseg import CLIPDensePredT


THREAD_LOCK_CLIP = threading.Lock()

modname = os.path.basename(__file__)[:-3]  # module name derived from the file name

model_clip = None


# plugin entry point: returns the manifest that registers this plugin's processor
def start(core: ChainImgProcessor):
    manifest = { # plugin settings
        "name": "Text2Clip", # name
        "version": "1.0", # version

        "default_options": {
        },
        "img_processor": {
            "txt2clip": Text2Clip
        }
    }
    return manifest

def start_with_options(core: ChainImgProcessor, manifest: dict):
    pass



class Text2Clip(ChainImgPlugin):

    def load_clip_model(self):
        global model_clip

        if model_clip is None:
            device = torch.device(super().device)
            # CLIPSeg text-prompted segmentation model (ViT-B/16 backbone, refined rd64 weights)
            model_clip = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
            model_clip.load_state_dict(torch.load('models/CLIP/rd64-uni-refined.pth', map_location=torch.device('cpu')), strict=False)
            model_clip.eval()
            model_clip.to(device)


    def init_plugin(self):
        self.load_clip_model()

    def process(self, frame, params: dict):
        # expected params keys: "face_detected", "original_frame", "clip_prompt"
        if not params.get("face_detected", True):
            # nothing was swapped in this frame, pass it through untouched
            return frame

        return self.mask_original(params["original_frame"], frame, params["clip_prompt"])

    def mask_original(self, img1, img2, keywords):
        global model_clip

        source_image_small = cv2.resize(img1, (256, 256))

        # border mask: white rectangle inset by 1 px, blurred and normalized to [0, 1]
        img_mask = np.zeros((source_image_small.shape[0], source_image_small.shape[1]), dtype=np.float32)
        mask_border = 1
        l, t, r, b = 0, 0, 1, 1

        mask_blur = 5
        clip_blur = 5

        img_mask = cv2.rectangle(img_mask, (mask_border + int(l), mask_border + int(t)),
                                 (256 - mask_border - int(r), 256 - mask_border - int(b)), (255, 255, 255), -1)
        img_mask = cv2.GaussianBlur(img_mask, (mask_blur * 2 + 1, mask_blur * 2 + 1), 0)
        img_mask /= 255

        
        input_image = source_image_small

        # CLIPSeg preprocessing: tensor conversion, ImageNet mean/std normalization, 256x256 input
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((256, 256)),
        ])
        img = transform(input_image).unsqueeze(0)

        thresh = 0.5
        prompts = keywords.split(',')
        # run CLIPSeg once with the image repeated per prompt; the lock serializes access to the shared model
        with THREAD_LOCK_CLIP:
            with torch.no_grad():
                preds = model_clip(img.repeat(len(prompts), 1, 1, 1), prompts)[0]
        # merge the per-prompt heatmaps by summing their sigmoid activations
        clip_mask = torch.sigmoid(preds[0][0])
        for i in range(len(prompts) - 1):
            clip_mask += torch.sigmoid(preds[i + 1][0])
           
        clip_mask = clip_mask.data.cpu().numpy()
        clip_mask = np.clip(clip_mask, 0, 1)

        # binarize at the threshold, grow the region slightly, then soften its edges
        clip_mask[clip_mask > thresh] = 1.0
        clip_mask[clip_mask <= thresh] = 0.0
        kernel = np.ones((5, 5), np.float32)
        clip_mask = cv2.dilate(clip_mask, kernel, iterations=1)
        clip_mask = cv2.GaussianBlur(clip_mask, (clip_blur * 2 + 1, clip_blur * 2 + 1), 0)

        # keep only CLIP-detected regions inside the blurred border mask
        img_mask *= clip_mask
        img_mask[img_mask < 0.0] = 0.0
       
        # resize the mask to the output resolution and blend: mask=1 restores the original pixels
        img_mask = cv2.resize(img_mask, (img2.shape[1], img2.shape[0]))
        img_mask = np.reshape(img_mask, [img_mask.shape[0], img_mask.shape[1], 1])

        target = img2.astype(np.float32)
        result = (1 - img_mask) * target
        result += img_mask * img1.astype(np.float32)
        return np.uint8(result)
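

# ----------------------------------------------------------------------------
# Illustrative sketch (assumption: not part of the original plugin). It shows
# the final blending step of mask_original() in isolation, using synthetic
# frames and a hand-made soft mask; it does not load the CLIP model or the
# chain_img_processor pipeline.
if __name__ == "__main__":
    processed = np.full((128, 128, 3), 255, dtype=np.uint8)  # stand-in for the processed frame
    original = np.zeros((128, 128, 3), dtype=np.uint8)       # stand-in for the original frame

    # soft circular mask in [0, 1]; where the mask is 1, the original pixels are restored
    demo_mask = np.zeros((128, 128), dtype=np.float32)
    demo_mask = cv2.circle(demo_mask, (64, 64), 32, 1.0, -1)
    demo_mask = cv2.GaussianBlur(demo_mask, (11, 11), 0)
    demo_mask = demo_mask[:, :, None]

    blended = (1 - demo_mask) * processed.astype(np.float32) + demo_mask * original.astype(np.float32)
    print("blended frame:", blended.shape, np.uint8(blended).dtype)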