import spaces
import cv2
import numpy as np
import gradio as gr
import cwm.utils as utils
import os

os.system("pip uninstall -y gradio")
os.system("pip install gradio==4.31.0")

# Points color and arrow properties
arrow_color = (0, 255, 0)       # Green color for all arrows
dot_color = (0, 255, 0)         # Green color for the dots at start and end
dot_color_fixed = (255, 0, 0)   # Red color for zero-length vectors
thickness = 3                   # Thickness of the arrow
tip_length = 0.3                # The length of the arrow tip relative to the arrow length
dot_radius = 7                  # Radius for the dots
dot_thickness = -1              # Thickness for solid circle (-1 fills the circle)

from PIL import Image
import torch
import json

# Load model
from cwm.model.model_factory import model_factory
from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the CWM two-frame model (automatically downloads the pre-trained checkpoint)
model = model_factory.load_model('vitb_8x8patch_2frames_encoder_mask_token')  # .to(device)
model.requires_grad_(False)
model.eval()
model = model  # .to(torch.float16)

import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
from torchvision import transforms


def draw_arrows_matplotlib(img, selected_points, zero_length):
    """
    Draw arrows on the image using matplotlib for better quality arrows and dots.
    """
    fig, ax = plt.subplots()
    ax.imshow(img)

    for i in range(0, len(selected_points), 2):
        start_point = selected_points[i]
        end_point = selected_points[i + 1]

        if start_point == end_point or zero_length:
            # Draw a dot for zero-length vectors or if only one point is clicked
            ax.scatter(start_point[0], start_point[1], color='red', s=100)  # Red dot for zero-length vector
        else:
            # Draw arrows
            arrow = FancyArrowPatch((start_point[0], start_point[1]), (end_point[0], end_point[1]),
                                    color='green', linewidth=2, arrowstyle='->', mutation_scale=15)
            ax.add_patch(arrow)

            # Optionally, draw a small circle (dot) at the start and end points
            ax.scatter(start_point[0], start_point[1], color='green', s=100)  # Green dot at start
            ax.scatter(end_point[0], end_point[1], color='green', s=100)      # Green dot at end

    # Save the figure to a numpy array
    fig.canvas.draw()
    img_array = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    img_array = img_array.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)

    return img_array


# def load_preuploaded_images():
#     image_folder = "assets"
#     images = []
#     for img_file in os.listdir(image_folder):
#         img_path = os.path.join(image_folder, img_file)
#         if img_file.endswith(('png', 'jpg', 'jpeg')):
#             images.append(Image.open(img_path))
#     return images
#
# # Function to transfer image from gallery to the input image section
# # preloaded_images = load_preuploaded_images()
# # print("Preloaded images:", preloaded_images)

# @spaces.GPU(duration=110)
def get_c(x, points):
    x = utils.imagenet_normalize(x)  # .to(device)
    with torch.no_grad():
        counterfactual = model.get_counterfactual(x, points)
    return counterfactual
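
# A minimal sketch (not wired into the UI) of how the counterfactual call above can be
# exercised programmatically. It mirrors the preprocessing in run_model_on_points further
# down: the model expects a [B, C, T, H, W] tensor with T = 2 identical frames at 256x256
# resolution, and an (N, 4) tensor of (y0, x0, y1, x1) patch motions in that 256x256 frame.
# The function name and its arguments are illustrative and not part of the demo.
def example_counterfactual(image_path, points_xyxy):
    """points_xyxy: list of (x0, y0, x1, y1) pixel coordinates in the 256x256 resized frame."""
    img = Image.open(image_path).convert("RGB").resize((256, 256))
    x = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
    x = x[None, :, None].expand(-1, -1, 2, -1, -1)  # [1, 3, 2, 256, 256]
    pts = torch.tensor(points_xyxy, dtype=torch.float32)[:, [1, 0, 3, 2]]  # (x, y) -> (y, x)
    out = get_c(x, pts)
    # Same post-processing as run_model_on_points below
    return out.squeeze().clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy()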

with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown('''# Scene editing interventions with Counterfactual World Models!
        ''')

    # Annotating arrows on an image
    with gr.Tab(label='Image'):
        with gr.Row():
            with gr.Column():
                # Input image
                original_image = gr.State(value=None)           # store original image without arrows
                original_image_high_res = gr.State(value=None)  # store original image without arrows
                input_image = gr.Image(type="numpy", label="Upload Image")

                # Annotate arrows
                selected_points = gr.State([])  # store points
                zero_length_toggle = gr.Checkbox(label="Select patches to be kept fixed", value=False)  # Toggle for zero-length vectors

                with gr.Row():
                    gr.Markdown('1. **Click on the image** to specify patch motion by selecting a start and end point. \n 2. After selecting the points to move, enable the **"Select patches to be kept fixed"** checkbox to choose a few points to keep fixed. \n 3. **Click "Run Model"** to visualize the result of the edit.')
                    undo_button = gr.Button('Undo last action')
                    clear_button = gr.Button('Clear All')

                # Run model button
                run_model_button = gr.Button('Run Model')

            # Show the result of the intervention
            with gr.Tab(label='Intervention'):
                output_image = gr.Image(type='numpy')

    # Store the original image and resize it to a square once uploaded
    def resize_to_square(img, size=512):
        print("Resizing image to square")
        img = Image.fromarray(img)
        transform = transforms.Compose([
            transforms.Resize((size, size)),
            # transforms.CenterCrop(size)
        ])
        img = transform(img)  # .transpose(1, 2, 0)
        return np.array(img)

    def load_img(evt: gr.SelectData):
        img_path = evt.value['image']['path']
        img = np.array(Image.open(img_path))
        # print(f"Image uploaded with shape: {input.shape}")

        with open('./assets/intervention_test_images/annot.json', 'r') as f:
            points_json = json.load(f)

        resized_img = resize_to_square(img)

        if os.path.basename(img_path) not in points_json:
            return resized_img, resized_img, img, []

        points_json = points_json[os.path.basename(img_path)]

        temp = resized_img.copy()

        # Redraw all annotated arrows and dots
        for i in range(0, len(points_json), 2):
            start_point = points_json[i]
            end_point = points_json[i + 1]
            if start_point == end_point:
                # Zero-length vector: draw a dot
                color = dot_color_fixed
            else:
                # Draw arrow
                cv2.arrowedLine(temp, start_point, end_point, arrow_color, thickness, tipLength=tip_length, line_type=cv2.LINE_AA)
                color = arrow_color

            # Draw dots at start and end points
            cv2.circle(temp, start_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
            cv2.circle(temp, end_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)

        # If there is an odd number of points (e.g., only a start point), draw a dot for it
        if len(points_json) == 1:
            start_point = points_json[0]
            cv2.circle(temp, start_point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)

        return temp, resized_img, img, points_json

    def store_img(img):
        resized_img = resize_to_square(img)  # Resize the uploaded image to a square
        print(f"Image uploaded with shape: {resized_img.shape}")
        return resized_img, resized_img, img, []
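
    # The per-image annotations that load_img reads from ./assets/intervention_test_images/annot.json
    # are assumed to be keyed by image file name, each value being a flat list of [x, y] points
    # (in the 512x512 resized frame) grouped into consecutive start/end pairs, where a pair with
    # identical start and end marks a patch to keep fixed. The helper below is an illustrative
    # sketch for assembling such an entry; it is not used by the demo.
    def build_annotation_entry(arrows, fixed_points):
        """arrows: list of ((x0, y0), (x1, y1)) motions; fixed_points: list of (x, y)."""
        points = []
        for start, end in arrows:
            points += [list(start), list(end)]
        for p in fixed_points:
            points += [list(p), list(p)]  # zero-length vector = patch kept fixed
        return points
    # e.g. annot["ducks.jpg"] = build_annotation_entry([((120, 200), (180, 200))], [(400, 300)])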

    with gr.Row():
        with gr.Column():
            gallery = gr.Gallery(
                ["./assets/ducks.jpg", "./assets/robot_arm.jpg", "./assets/bread.jpg", "./assets/bird.jpg",
                 "./assets/desk_1.jpg", "./assets/glasses.jpg", "./assets/watering_pot.jpg"],
                columns=5, allow_preview=False, label="Select an example image to test")

            # examples = gr.Examples(
            #     examples=[
            #         ["./assets/desk_1.jpg", "./assets/desk_1.jpg"],
            #     ],
            #     inputs=[input_image, original_image],
            #     # fn=load_img,
            #     # outputs=[input_image, original_image],
            #     # cache_examples=True,
            #     # run_on_click=True,
            #     # label="Select an example image to test"
            # )

    gallery.select(load_img, outputs=[input_image, original_image, original_image_high_res, selected_points])
    input_image.upload(store_img, [input_image], [input_image, original_image, original_image_high_res, selected_points])

    # Get points and draw arrows or zero-length vectors based on the toggle
    def get_point(img, sel_pix, zero_length, evt: gr.SelectData):
        sel_pix.append(evt.index)  # Append the clicked point's coordinates

        if zero_length:
            # Zero-length vector case: draw a single dot at the clicked point
            point = sel_pix[-1]  # Last point clicked
            cv2.circle(img, point, dot_radius, dot_color_fixed, dot_thickness, lineType=cv2.LINE_AA)  # Draw a dot at the point
            sel_pix.append(evt.index)  # Duplicate the point so it forms a zero-length pair
        else:
            # Regular case: two clicks for an arrow
            # Check if this is the first point (start point for the arrow)
            if len(sel_pix) % 2 == 1:
                # Draw a dot at the start point to give feedback
                start_point = sel_pix[-1]  # Last point is the start
                cv2.circle(img, start_point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)

            # Check if two points have been selected (start and end points for an arrow)
            if len(sel_pix) % 2 == 0:
                # Draw an arrow between the last two points
                start_point = sel_pix[-2]  # Second-to-last point is the start
                end_point = sel_pix[-1]    # Last point is the end

                # Draw arrow
                cv2.arrowedLine(img, start_point, end_point, arrow_color, thickness, tipLength=tip_length, line_type=cv2.LINE_AA)

                # Draw a dot at the end point
                cv2.circle(img, end_point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)

        return img if isinstance(img, np.ndarray) else np.array(img)

    input_image.select(get_point, [input_image, selected_points, zero_length_toggle], [input_image])

    # Undo the last selected action
    def undo_arrows(orig_img, sel_pix, zero_length):
        temp = orig_img.copy()

        # if zero_length:
        #     # Undo the last zero-length vector (just the last dot)
        #     if len(sel_pix) >= 1:
        #         sel_pix.pop()  # Remove the last point
        # else:
        if len(sel_pix) >= 2:
            sel_pix.pop()  # Remove the last end point
            sel_pix.pop()  # Remove the last start point

        # Redraw all remaining arrows and dots
        for i in range(0, len(sel_pix), 2):
            start_point = sel_pix[i]
            end_point = sel_pix[i + 1]
            if start_point == end_point:
                # Zero-length vector: draw a dot
                color = dot_color_fixed
            else:
                # Draw arrow
                cv2.arrowedLine(temp, start_point, end_point, arrow_color, thickness, tipLength=tip_length)
                color = arrow_color

            # Draw dots at start and end points
            cv2.circle(temp, start_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
            cv2.circle(temp, end_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)

        # If there is an odd number of points (e.g., only a start point), draw a dot for it
        if len(sel_pix) == 1:
            start_point = sel_pix[0]
            cv2.circle(temp, start_point, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)

        return temp if isinstance(temp, np.ndarray) else np.array(temp)

    undo_button.click(undo_arrows, [original_image, selected_points, zero_length_toggle], [input_image])

    # Clear all points and reset the image
    def clear_all_points(orig_img, sel_pix):
        sel_pix.clear()   # Clear all points
        return orig_img   # Reset image to original

    clear_button.click(clear_all_points, [original_image, selected_points], [input_image])

    # Run the model on the annotated points and produce the counterfactual image
    def run_model_on_points(points, input_image, original_image):
        H = input_image.shape[0]
        W = input_image.shape[1]
        factor = 256 / H

        # Scale points from the display resolution to the model's 256x256 input
        # and switch from (x, y) to the (y, x) ordering the model expects
        points = torch.from_numpy(np.array(points).reshape(-1, 4)) * factor
        points = points[:, [1, 0, 3, 2]]

        img = Image.fromarray(original_image)

        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            # transforms.CenterCrop(256)
        ])

        img = np.array(transform(img))
        # np.save("img.npy", original_image)

        img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        img = img[None]

        # Reshape image to [B, C, T, H, W] with C = 3, T = 2 (two-frame model), H = W = 256
        x = img[:, :, None].expand(-1, -1, 2, -1, -1)  # .to(torch.float16)

        # Run the counterfactual prediction (inputs are ImageNet-normalized inside get_c)
        counterfactual = get_c(x, points)

        counterfactual = counterfactual.squeeze()
        counterfactual = counterfactual.clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy()

        # for i in range(0, len(points), 2):
        #     # Draw rectangles on the points as model output example
        #     cv2.rectangle(processed_image, points[i], points[i + 1], (255, 0, 0), 3)

        return counterfactual

    # Run the model when the button is clicked
    run_model_button.click(run_model_on_points, [selected_points, input_image, original_image_high_res], [output_image])

# Launch the app
demo.queue().launch(inbrowser=True, share=True)