File size: 14,160 Bytes
89022d9
6dfcb0f
 
 
 
8aafeea
 
 
6dfcb0f
 
 
 
 
 
 
 
 
 
 
4d601e2
6dfcb0f
 
 
 
 
 
 
 
 
8e8833a
6dfcb0f
 
 
 
7ea6ce9
6dfcb0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e8833a
89022d9
110d56f
89022d9
 
 
 
6dfcb0f
 
5637cf2
6dfcb0f
 
 
 
 
 
 
 
8e8833a
6dfcb0f
 
 
 
 
 
5637cf2
6dfcb0f
 
 
 
 
 
 
 
 
 
 
8e8833a
6dfcb0f
 
 
19af009
8e8833a
6dfcb0f
 
 
 
 
 
 
 
 
4d601e2
 
 
 
 
 
6dfcb0f
 
4d601e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6dfcb0f
 
 
 
 
8e8833a
6dfcb0f
 
 
 
4d601e2
6dfcb0f
 
 
 
 
 
 
 
 
 
 
8e8833a
6dfcb0f
8e8833a
6dfcb0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d601e2
 
6dfcb0f
 
 
 
4d601e2
6dfcb0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e8833a
6dfcb0f
 
 
 
 
 
 
 
8e8833a
 
6dfcb0f
 
 
 
 
 
 
 
 
 
8e8833a
6dfcb0f
 
a45652e
89022d9
 
 
6dfcb0f
 
 
 
 
 
 
 
 
 
 
8e8833a
6dfcb0f
 
 
 
8e8833a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
import spaces
import cv2
import numpy as np
import gradio as gr
import cwm.utils as utils
import os
# Force a specific gradio release by reinstalling it with pip every time this
# module is imported. The exit status of both commands is discarded, so a
# failed uninstall/install goes unnoticed.
# NOTE(review): gradio has already been imported above (`import gradio as gr`),
# so this presumably only affects processes started after this one -- confirm
# that is the intent, or pin the version in the environment's requirements.
os.system("pip uninstall -y gradio")
os.system("pip install gradio==4.31.0")

# Shared styling for the OpenCV annotations (cv2.circle / cv2.arrowedLine)
# drawn by the click, undo and gallery handlers further down.
# Colours are 3-tuples written into the image channels in order; the handlers
# draw on arrays coming from PIL/Gradio, so they read as RGB there.
arrow_color = (0, 255, 0)  # Green: arrow shaft and tip
dot_color = (0, 255, 0)  # Green: dots at the start and end of an arrow
dot_color_fixed = (255, 0, 0)  # Red: dot for a zero-length vector (patch kept fixed)
thickness = 3  # Line thickness passed to cv2.arrowedLine
tip_length = 0.3  # Arrow tip length as a fraction of the arrow length
dot_radius = 7  # Dot radius passed to cv2.circle
dot_thickness = -1  # Negative thickness makes cv2.circle draw a filled disc
from PIL import Image
import torch
import json
#load model
from cwm.model.model_factory import model_factory

from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

# Pick the first CUDA device when one is visible, otherwise the CPU.
# NOTE(review): `device` is not applied anywhere in the visible code (every
# `.to(device)` call is commented out), so the model and inputs stay wherever
# they are created -- confirm this is intentional.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Load the CWM model through the project's model factory. The identifier says
# "2frames", which matches the 2-frame input built in run_model_on_points; the
# previous comment called it a 3-frame model.
# NOTE(review): the original comment states the pre-trained checkpoint is
# downloaded automatically by load_model -- not verifiable from this file.
model = model_factory.load_model('vitb_8x8patch_2frames_encoder_mask_token')#.to(device)

# Inference only: turn off gradient tracking for all parameters and switch the
# module to evaluation mode.
model.requires_grad_(False)
model.eval()

# No-op rebinding; the half-precision cast is currently disabled.
model = model#.to(torch.float16)


import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
from PIL import Image
import numpy as np

from torchvision import transforms

def draw_arrows_matplotlib(img, selected_points, zero_length):
    """
    Render the motion annotations on top of ``img`` with matplotlib.

    ``selected_points`` is read as consecutive (start, end) pairs. A pair whose
    two points are equal, or any pair when ``zero_length`` is truthy, is drawn
    as a single red dot; every other pair is drawn as a green arrow with a
    green dot at both ends. A trailing point without a partner (an arrow whose
    end has not been chosen yet) is drawn as a lone dot instead of raising
    ``IndexError``.

    Args:
        img: Image accepted by ``Axes.imshow``.
        selected_points: Flat sequence of ``(x, y)`` points.
        zero_length: Draw every pair as a fixed (red) dot when truthy.

    Returns:
        ``uint8`` RGB array of shape ``(height, width, 3)`` holding the
        rendered figure canvas. Its size is that of the matplotlib figure
        (axes and margins included), not that of ``img``.
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)

        # Only complete (start, end) pairs can be drawn as vectors.
        paired_count = len(selected_points) - len(selected_points) % 2
        for i in range(0, paired_count, 2):
            start_point = selected_points[i]
            end_point = selected_points[i + 1]

            if start_point == end_point or zero_length:
                # Zero-length vector: a single red dot marks the fixed patch.
                ax.scatter(start_point[0], start_point[1], color='red', s=100)
            else:
                arrow = FancyArrowPatch((start_point[0], start_point[1]), (end_point[0], end_point[1]),
                                        color='green', linewidth=2, arrowstyle='->', mutation_scale=15)
                ax.add_patch(arrow)

                # Dots at both ends make short arrows easier to see.
                ax.scatter(start_point[0], start_point[1], color='green', s=100)
                ax.scatter(end_point[0], end_point[1], color='green', s=100)

        if paired_count < len(selected_points):
            # Unpaired trailing point: show it so the pending click is visible.
            pending = selected_points[-1]
            ax.scatter(pending[0], pending[1], color='red' if zero_length else 'green', s=100)

        # Rasterize the figure and copy the pixels out before it is closed.
        # buffer_rgba() replaces tostring_rgb(), which newer Matplotlib
        # releases deprecated and then removed; the alpha channel is dropped
        # to keep the RGB layout callers received before.
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        img_array = rgba[..., :3].copy()
    finally:
        # Always release the figure, even when drawing fails.
        plt.close(fig)
    return img_array

import os
# def load_preuploaded_images():
#     image_folder = "assets"
#     images = []
#     for img_file in os.listdir(image_folder):
#         img_path = os.path.join(image_folder, img_file)
#         if img_file.endswith(('png', 'jpg', 'jpeg')):
#             images.append(Image.open(img_path))
#     return images
#
# # Function to transfer image from gallery to the input image section


#
# preloaded_images = load_preuploaded_images()
#
# print("Preloaded images:", preloaded_images)
# @spaces.GPU(duration=110)
def get_c(x, points):
    """
    Run the counterfactual model on an un-normalized input.

    ``x`` is first passed through ``utils.imagenet_normalize`` and the result,
    together with ``points``, is handed to ``model.get_counterfactual`` with
    gradient tracking disabled. Whatever the model returns is returned as is.
    """
    normalized = utils.imagenet_normalize(x)#.to(device)
    with torch.no_grad():
        return model.get_counterfactual(normalized, points)

# Build the Gradio UI. Components are declared here; the event handlers and
# their wiring follow further down inside the same Blocks context.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown('''# Scene editing interventions with Counterfactual World Models!
        ''')

    # Annotating arrows on an image
    with gr.Tab(label='Image'):
        with gr.Row():
            with gr.Column():
                # Input image
                # original_image: the square-resized image without any annotations;
                # used to redraw from scratch on undo / clear.
                original_image = gr.State(value=None)  # store original image without arrows
                # original_image_high_res: the image as uploaded/selected, before
                # resizing; this is what the model run reads its pixels from.
                original_image_high_res = gr.State(value=None)  # store the un-resized original image
                # Displayed (and clicked) image; handlers draw annotations onto it.
                input_image = gr.Image(type="numpy", label="Upload Image")

                # Annotate arrows
                # Flat list of clicked points, read as consecutive (start, end) pairs.
                selected_points = gr.State([])  # store points
                zero_length_toggle = gr.Checkbox(label="Select patches to be kept fixed", value=False)  # Toggle for zero-length vectors
                with gr.Row():
                    gr.Markdown('1. **Click on the image** to specify patch motion by selecting a start and end point. \n 2. After selecting the points to move, enable the **"Select patches to be kept fixed"** checkbox to choose a few points to keep fixed. \n 3. **Click "Run Model"** to visualize the result of the edit.')
                    undo_button = gr.Button('Undo last action')
                    clear_button = gr.Button('Clear All')

                # Run model button
                run_model_button = gr.Button('Run Model')

            # Shows the model output (the counterfactual image) next to the input column
            with gr.Tab(label='Intervention'):
                output_image = gr.Image(type='numpy')

        # Store the original image and resize to square size once uploaded
        def resize_to_square(img, size=512):
            """
            Resize an image array to ``size`` x ``size`` pixels.

            The array is wrapped in a PIL image, resized with torchvision's
            ``Resize`` (no cropping, so the aspect ratio is not preserved) and
            converted back to a numpy array.
            """
            print("Resizing image to square")
            to_square = transforms.Resize((size, size))
            return np.array(to_square(Image.fromarray(img)))


        def load_img(evt: gr.SelectData):
            """
            Load a gallery example together with its stored annotations.

            The selected file is read as RGB, resized to the square working
            size, and the points stored for it in
            ``./assets/intervention_test_images/annot.json`` (keyed by the
            file's base name) are drawn on a copy of the resized image.

            Args:
                evt: Gallery selection event; ``evt.value['image']['path']``
                    is the path of the chosen image.

            Returns:
                Tuple ``(annotated, resized, original, points)``: the resized
                image with arrows/dots drawn, the clean resized image, the
                full-resolution RGB image, and the list of annotation points
                (empty when the file has no entry in the annotation file).
            """
            img_path = evt.value['image']['path']
            # Close the file handle deterministically and normalize to three
            # channels so drawing and the model always see an RGB array.
            with Image.open(img_path) as pil_img:
                img = np.array(pil_img.convert('RGB'))

            with open('./assets/intervention_test_images/annot.json', 'r') as f:
                annotations = json.load(f)

            # An example without stored annotations is still usable: start
            # with no points instead of failing with KeyError.
            points_json = annotations.get(os.path.basename(img_path), [])

            resized_img = resize_to_square(img)

            temp = resized_img.copy()

            # Draw every complete (start, end) pair.
            paired_count = len(points_json) - len(points_json) % 2
            for i in range(0, paired_count, 2):
                start_point = points_json[i]
                end_point = points_json[i + 1]
                if start_point == end_point:
                    # Zero-length vector: only a dot, in the "fixed" colour
                    color = dot_color_fixed
                else:
                    cv2.arrowedLine(temp, start_point, end_point, arrow_color, thickness, tipLength=tip_length,
                                    line_type=cv2.LINE_AA)
                    color = arrow_color

                # Draw dots at start and end points
                cv2.circle(temp, start_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
                cv2.circle(temp, end_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)

            # A trailing point without a partner is a pending arrow start:
            # show it as a dot rather than indexing past the end of the list.
            if paired_count < len(points_json):
                cv2.circle(temp, points_json[-1], dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)

            return temp, resized_img, img, points_json


        def store_img(img):
            """
            Prepare a freshly uploaded image for annotation.

            Returns the square-resized image twice (once for display, once as
            the clean copy used for redraws), the upload as received, and a
            new empty point list.
            """
            square = resize_to_square(img)  # working copy at the square display size
            print(f"Image uploaded with shape: {square.shape}")
            fresh_points = []
            return square, square, img, fresh_points


        # Example gallery shown below the annotation area. Selecting a
        # thumbnail runs load_img, which looks up stored annotation points for
        # the chosen file by its base name.
        with gr.Row():
            with gr.Column():
                gallery = gr.Gallery( ["./assets/ducks.jpg", "./assets/robot_arm.jpg", "./assets/bread.jpg",  "./assets/bird.jpg", "./assets/desk_1.jpg", "./assets/glasses.jpg", "./assets/watering_pot.jpg"], columns=5, allow_preview=False, label="Select an example image to test")
                # examples = gr.Examples(
                #     examples=[
                #         ["./assets/desk_1.jpg", "./assets/desk_1.jpg"],
                #     ],
                #     inputs=[input_image, original_image],
                #     # fn=load_img,
                #     # outputs=[input_image, original_image],
                #     # cache_examples=True,
                #     # run_on_click=True,
                #     # label="Select an example image to test"
                # )
        # Both handlers fill the same four targets, in this order: displayed
        # image, clean resized image, full-resolution image, point list.
        gallery.select(load_img, outputs=[input_image, original_image, original_image_high_res, selected_points])

        # A new upload replaces the displayed image with its resized version
        # and resets the point list.
        input_image.upload(store_img, [input_image], [input_image, original_image, original_image_high_res, selected_points])

        # Get points and draw arrows or zero-length vectors based on the toggle
        def get_point(img, sel_pix, zero_length, evt: gr.SelectData):
            """
            Record a click on the image and draw the matching annotation.

            ``sel_pix`` is a flat list read as consecutive (start, end) pairs
            and is updated in place.

            * Arrow mode (``zero_length`` falsy): the first click of a pair
              stores the start point and draws a green dot; the second click
              stores the end point and draws the arrow plus a dot at its end.
            * Fixed mode (``zero_length`` truthy): the click is stored twice,
              as a zero-length (start == end) pair, and drawn as a red dot.
              If an arrow start is still waiting for its end point, that start
              is first turned into a fixed point as well (stored as its own
              zero-length pair and repainted red). Without this the list would
              end up with an odd number of entries and every later pair would
              be misaligned.

            Args:
                img: Image currently displayed; it is drawn on directly.
                sel_pix: Point list to extend.
                zero_length: State of the "keep fixed" checkbox.
                evt: Click event; ``evt.index`` is the clicked location.

            Returns:
                The annotated image as a numpy array.
            """
            point = evt.index
            has_pending_start = len(sel_pix) % 2 == 1

            if zero_length:
                if has_pending_start:
                    # Close the dangling start as a fixed point so the list
                    # keeps its (start, end) pairing and the picture matches
                    # what is stored.
                    pending = sel_pix[-1]
                    sel_pix.append(pending)
                    cv2.circle(img, pending, dot_radius, dot_color_fixed, dot_thickness, lineType=cv2.LINE_AA)

                # Zero-length vector: same point as start and end.
                sel_pix.append(point)
                sel_pix.append(point)
                cv2.circle(img, point, dot_radius, dot_color_fixed, dot_thickness, lineType=cv2.LINE_AA)
            elif has_pending_start:
                # Second click of a pair: finish the arrow.
                start_point = sel_pix[-1]
                sel_pix.append(point)
                cv2.arrowedLine(img, start_point, point, arrow_color, thickness, tipLength=tip_length, line_type=cv2.LINE_AA)
                cv2.circle(img, point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)
            else:
                # First click of a pair: mark the start so the user gets feedback.
                sel_pix.append(point)
                cv2.circle(img, point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)

            return img if isinstance(img, np.ndarray) else np.array(img)

        # Clicking the image records a point. selected_points is only an input
        # here: get_point appends to the list in place, so it is not listed
        # among the outputs.
        input_image.select(get_point, [input_image, selected_points, zero_length_toggle], [input_image])

        # Undo the last selected action
        def undo_arrows(orig_img, sel_pix, zero_length):
            """
            Remove the most recent annotation and redraw the remaining ones.

            ``sel_pix`` is a flat list read as consecutive (start, end) pairs
            and is updated in place. If the list ends with an arrow start that
            has no end point yet, only that single point is removed;
            otherwise the last complete pair (arrow or fixed point) is
            removed. Previously a pending start made the redraw loop index
            past the end of the list.

            Args:
                orig_img: Clean image without annotations, or ``None`` when no
                    image has been loaded yet.
                sel_pix: Point list to shrink.
                zero_length: Unused; kept because the click wiring passes the
                    checkbox state.

            Returns:
                A fresh copy of ``orig_img`` with the remaining annotations
                drawn, or ``None`` when there is no image.
            """
            if orig_img is None:
                # Nothing loaded yet, so there is nothing to redraw.
                return None

            if len(sel_pix) % 2 == 1:
                sel_pix.pop()  # Remove the pending start point only
            elif sel_pix:
                sel_pix.pop()  # Remove the last end point
                sel_pix.pop()  # Remove the last start point

            temp = orig_img.copy()

            # The list now holds complete pairs only; redraw each of them.
            for i in range(0, len(sel_pix), 2):
                start_point = sel_pix[i]
                end_point = sel_pix[i + 1]
                if start_point == end_point:
                    # Zero-length vector: only a dot, in the "fixed" colour
                    color = dot_color_fixed
                else:
                    # Anti-aliased like the arrows drawn on click, so the
                    # picture does not change look after an undo.
                    cv2.arrowedLine(temp, start_point, end_point, arrow_color, thickness, tipLength=tip_length,
                                    line_type=cv2.LINE_AA)
                    color = arrow_color

                # Draw dots at start and end points
                cv2.circle(temp, start_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
                cv2.circle(temp, end_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)

            return temp if isinstance(temp, np.ndarray) else np.array(temp)

        # Undo redraws from the clean resized image (original_image) rather
        # than from the displayed one; the point list is edited in place by
        # undo_arrows, so only the image is an output.
        undo_button.click(undo_arrows, [original_image, selected_points, zero_length_toggle], [input_image])


        # Clear all points and reset the image
        def clear_all_points(orig_img, sel_pix):
            """
            Drop every recorded point and hand back the clean image.

            ``sel_pix`` is emptied in place; ``orig_img`` is returned
            untouched so the display shows the image without annotations.
            """
            del sel_pix[:]  # empty the caller's list without rebinding it
            return orig_img

        # Clear All restores the clean resized image; the point list is
        # emptied in place by clear_all_points, so only the image is an output.
        clear_button.click(clear_all_points, [original_image, selected_points], [input_image])

        # Dummy model function to simulate running a model
        def run_model_on_points(points, input_image, original_image):
            """
            Run the counterfactual model for the annotated motions.

            Args:
                points: Flat list of clicked ``(x, y)`` points in the
                    coordinate frame of ``input_image``, read as consecutive
                    (start, end) pairs. A trailing point without a partner
                    (an arrow that was never finished) is ignored.
                input_image: The displayed, annotated image. Only its height
                    and width are used, to rescale the points.
                original_image: Un-annotated image whose pixels are fed to
                    the model.

            Returns:
                The model's counterfactual image as a float numpy array of
                shape ``(H, W, 3)`` with values clamped to ``[0, 1]``.

            Raises:
                gr.Error: If no image has been uploaded or selected yet.
            """
            if input_image is None or original_image is None:
                raise gr.Error("Please upload or select an image before running the model.")

            H = input_image.shape[0]
            W = input_image.shape[1]

            # Keep complete (start, end) pairs only; an odd point count would
            # otherwise make the reshape below fail.
            paired_points = points[:len(points) - len(points) % 2]

            # One row per vector: (x0, y0, x1, y1), rescaled from the displayed
            # size to the 256 x 256 model resolution. x is scaled by the width
            # and y by the height, which equals the previous single factor
            # 256 / H whenever the displayed image is square.
            scale = torch.tensor([256 / W, 256 / H, 256 / W, 256 / H])
            points = torch.from_numpy(np.array(paired_points).reshape(-1, 4)) * scale

            # Reorder each row to (y0, x0, y1, x1).
            # NOTE(review): presumably the model expects row/column order -- confirm.
            points = points[:, [1, 0, 3, 2]]

            # Force three channels so the tensor layout below always holds.
            img = Image.fromarray(original_image).convert("RGB")

            transform = transforms.Compose([
                transforms.Resize((256, 256)),
                # transforms.CenterCrop(256)
            ])
            img = np.array(transform(img))

            # HWC uint8 -> CHW float in [0, 1], plus a batch dimension.
            img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0

            img = img[None]

            # Repeat the frame along a new time axis: [B, C, T, H, W] with
            # B = 1, C = 3, T = 2 (two-frame model input), H = W = 256.
            x = img[:, :, None].expand(-1, -1, 2, -1, -1)#.to(torch.float16)

            # get_c applies the ImageNet normalization before calling the model.
            counterfactual = get_c(x, points)

            counterfactual = counterfactual.squeeze()

            # CHW -> HWC numpy image in [0, 1] for display.
            counterfactual = counterfactual.clamp(0, 1).permute(1,2,0).detach().cpu().numpy()

            return counterfactual

        # Run model when the button is clicked
        # Run model when the button is clicked. The displayed image is passed
        # only so the handler can read its size; the pixels fed to the model
        # come from the full-resolution original.
        run_model_button.click(run_model_on_points, [selected_points, input_image, original_image_high_res], [output_image])



    # Launch the app
# Start the server at import time with request queuing enabled.
# Per the Gradio launch() documentation, inbrowser=True opens the app in a
# browser tab and share=True asks for a public share link.
# NOTE(review): there is no `if __name__ == "__main__"` guard, so importing
# this module also launches the app -- confirm that is intended.
demo.queue().launch(inbrowser=True, share=True)