from PIL import Image
import numpy as np
import torch
import torchvision.transforms.functional as TF

def tensor2pil(image):
    # Convert a float tensor in [0, 1] to a PIL image (batch dim squeezed away).
    return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))

def pil2tensor(image):
    # Convert a PIL image back to a float tensor in [0, 1] with a leading batch dim.
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
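
# Round-trip shape sketch (illustrative only): pil2tensor(tensor2pil(t)) maps a
# (1, H, W, 3) IMAGE tensor to PIL and back to (1, H, W, 3); a (H, W) MASK
# tensor comes back as (1, H, W) because of the added batch dimension.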

def tensor2mask(t: torch.Tensor) -> torch.Tensor:
    # Reduce a (B, H, W, C) image tensor to a (B, H, W) mask.
    size = t.size()
    if len(size) < 4:
        # Already a mask (no channel dimension)
        return t
    if size[3] == 1:
        return t[:, :, :, 0]
    elif size[3] == 4:
        # Use the alpha channel if it carries any information
        if torch.min(t[:, :, :, 3]).item() != 1.:
            return t[:, :, :, 3]
    # Otherwise convert the RGB channels to grayscale
    # (slice to :3 so an all-opaque RGBA tensor doesn't feed four channels in)
    return TF.rgb_to_grayscale(t[:, :, :, :3].permute(0, 3, 1, 2), num_output_channels=1)[:, 0, :, :]
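
# A quick sanity sketch for tensor2mask (illustrative; the random tensors here
# are assumptions for demonstration, not part of the node's runtime path):
#
#   rgb = torch.rand(1, 64, 64, 3)                      # (B, H, W, C) IMAGE
#   assert tensor2mask(rgb).shape == (1, 64, 64)        # grayscale reduction
#   rgba = torch.cat([rgb, torch.zeros(1, 64, 64, 1)], dim=-1)
#   assert tensor2mask(rgba).shape == (1, 64, 64)       # alpha channel is used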

class image_concat_mask:
    """ComfyUI node: concatenate two images side by side and emit a mask
    covering the right-hand (second) image area."""
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image1": ("IMAGE",),
            },
            "optional": {
                "image2": ("IMAGE",),
                "mask": ("MASK",),
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK",)
    RETURN_NAMES = ("image", "mask")
    FUNCTION = "image_concat_mask"
    CATEGORY = "hhy"

    def image_concat_mask(self, image1, image2=None, mask=None):
        processed_images = []
        masks = []
        
        for idx, img1 in enumerate(image1):
            # Convert tensor to PIL
            pil_image1 = tensor2pil(img1)
            
            # Get first image dimensions
            width1, height1 = pil_image1.size
            
            if image2 is not None and idx < len(image2):
                # Use provided second image
                pil_image2 = tensor2pil(image2[idx])
                width2, height2 = pil_image2.size
                
                # Resize image2 to match height of image1
                new_width2 = int(width2 * (height1 / height2))
                pil_image2 = pil_image2.resize((new_width2, height1), Image.Resampling.LANCZOS)
            else:
                # Create white image with same dimensions as image1
                pil_image2 = Image.new('RGB', (width1, height1), 'white')
                new_width2 = width1
            
            # Create new image to hold both images side by side
            combined_image = Image.new('RGB', (width1 + new_width2, height1))
            
            # Paste both images
            combined_image.paste(pil_image1, (0, 0))
            combined_image.paste(pil_image2, (width1, 0))
            
            # Convert combined image to tensor
            combined_tensor = pil2tensor(combined_image)
            processed_images.append(combined_tensor)
            
            # Build the output mask: 0 over image1's area, 1 over image2's area
            final_mask = torch.zeros((1, height1, width1 + new_width2))
            final_mask[:, :, width1:] = 1.0  # right-hand (image2) region
            
            # If a mask is provided, carve its regions out of the right side
            if mask is not None and idx < len(mask):
                input_mask = mask[idx]
                # Resize the input mask to the right-hand area's dimensions
                pil_input_mask = tensor2pil(input_mask)
                pil_input_mask = pil_input_mask.resize((new_width2, height1), Image.Resampling.LANCZOS)
                resized_input_mask = pil2tensor(pil_input_mask)
                
                # Multiply by (1 - mask): masked regions drop to 0, the rest stays 1
                final_mask[:, :, width1:] *= (1.0 - resized_input_mask)
            
            masks.append(final_mask)
            
        processed_images = torch.cat(processed_images, dim=0)
        masks = torch.cat(masks, dim=0)
        
        return (processed_images, masks)

NODE_CLASS_MAPPINGS = {
    "image concat mask": image_concat_mask
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "image concat mask": "Image Concat with Mask"
}
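
# Minimal standalone sketch of the node outside ComfyUI (illustrative only:
# the random tensors stand in for real IMAGE/MASK inputs, which arrive as
# float tensors in [0, 1] with layouts (B, H, W, C) and (B, H, W)):
if __name__ == "__main__":
    node = image_concat_mask()
    image1 = torch.rand(1, 128, 96, 3)   # one 96x128 RGB image
    image2 = torch.rand(1, 64, 64, 3)    # rescaled to image1's height below
    out_image, out_mask = node.image_concat_mask(image1, image2=image2)
    # image2 is scaled to height 128 (width becomes 64 * 128 / 64 = 128),
    # so the combined canvas is 96 + 128 = 224 pixels wide.
    print(out_image.shape)  # torch.Size([1, 128, 224, 3])
    print(out_mask.shape)   # torch.Size([1, 128, 224])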