rahulvenkk
annot
2d342bf
import spaces
import cv2
import numpy as np
import gradio as gr
import cwm.utils as utils
import os
# Pin gradio to 4.31.0 by reinstalling it at startup.
# NOTE(review): gradio has already been imported earlier in this file, so this
# running process keeps the version it already loaded; the pinned version
# presumably only takes effect for processes started afterwards (e.g. a Space
# restart) — confirm intent. `pip` here is whichever pip is on PATH, which may
# not belong to the interpreter running this script.
os.system("pip uninstall -y gradio")
os.system("pip install gradio==4.31.0")
# Annotation drawing style shared by the OpenCV drawing helpers. Colors are
# (R, G, B) tuples: (0, 255, 0) is green and (255, 0, 0) is red on RGB arrays.
arrow_color = dot_color = (0, 255, 0)  # green: arrows and the dots at their start/end points
dot_color_fixed = (255, 0, 0)  # red: points marked as "kept fixed" (zero-length vectors)
thickness, tip_length = 3, 0.3  # arrow line width (px); tip size as a fraction of arrow length
dot_radius, dot_thickness = 7, -1  # dot radius (px); -1 asks cv2.circle to fill the circle
from PIL import Image
import torch
import json
#load model
from cwm.model.model_factory import model_factory
from timm.data.constants import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
# Pick the first CUDA device when available. NOTE(review): nothing in this file
# moves the model or its inputs to `device` (those `.to(device)` calls are
# commented out), so inference presumably runs wherever the factory puts it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load the 2-frame CWM model (the demo feeds it T=2 copies of one frame); the
# model factory downloads the pre-trained checkpoint automatically.
model = model_factory.load_model('vitb_8x8patch_2frames_encoder_mask_token')
# Inference only: freeze the weights and switch layers to eval mode.
model.requires_grad_(False)
model.eval()
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
from PIL import Image
import numpy as np
from torchvision import transforms
def draw_arrows_matplotlib(img, selected_points, zero_length):
    """Render point annotations on ``img`` with matplotlib and return the figure as RGB.

    Args:
        img: Image passed to ``Axes.imshow``.
        selected_points: Flat list of ``[x, y]`` points; consecutive pairs are
            (start, end). A trailing point without a partner (only the start
            of an arrow was clicked) is drawn as a single dot.
        zero_length: When True, every pair is drawn as a red dot at its start
            point instead of an arrow.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` holding the rendered
        figure canvas (axes and margins included, so its size is the figure's,
        not ``img``'s).
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)
        # zip over even/odd slices yields complete (start, end) pairs only, so an
        # unpaired trailing point cannot cause an IndexError.
        for start_point, end_point in zip(selected_points[0::2], selected_points[1::2]):
            if start_point == end_point or zero_length:
                # Zero-length vector (patch kept fixed): a single red dot.
                ax.scatter(start_point[0], start_point[1], color='red', s=100)
            else:
                arrow = FancyArrowPatch((start_point[0], start_point[1]), (end_point[0], end_point[1]),
                                        color='green', linewidth=2, arrowstyle='->', mutation_scale=15)
                ax.add_patch(arrow)
                # Dots mark both ends of the arrow.
                ax.scatter(start_point[0], start_point[1], color='green', s=100)
                ax.scatter(end_point[0], end_point[1], color='green', s=100)
        if len(selected_points) % 2 == 1:
            # Only one point of the last annotation was clicked: show it as a dot.
            last_point = selected_points[-1]
            ax.scatter(last_point[0], last_point[1], color='red' if zero_length else 'green', s=100)
        fig.canvas.draw()
        # `tostring_rgb` is deprecated since Matplotlib 3.8 and removed in later
        # releases; `buffer_rgba` gives an (H, W, 4) view of the rendered canvas.
        # Copy the RGB planes before the figure (and its buffer) is closed.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        img_array = rgba[..., :3].copy()
    finally:
        # Always release the figure, even if drawing failed, to avoid leaking figures.
        plt.close(fig)
    return img_array
import os
# def load_preuploaded_images():
# image_folder = "assets"
# images = []
# for img_file in os.listdir(image_folder):
# img_path = os.path.join(image_folder, img_file)
# if img_file.endswith(('png', 'jpg', 'jpeg')):
# images.append(Image.open(img_path))
# return images
#
# # Function to transfer image from gallery to the input image section
#
# preloaded_images = load_preuploaded_images()
#
# print("Preloaded images:", preloaded_images)
# @spaces.GPU(duration=110)
def get_c(x, points):
    """Predict the counterfactual frame for ``x`` under the given point edits.

    ``x`` is standardized with ``utils.imagenet_normalize`` and handed, along
    with ``points``, to ``model.get_counterfactual`` while autograd is
    disabled. Nothing is moved to ``device`` here.
    """
    with torch.no_grad():
        normalized = utils.imagenet_normalize(x)
        return model.get_counterfactual(normalized, points)
# Gradio UI: an "Image" tab for uploading and annotating an image with point
# motions, and an "Intervention" tab that shows the model's output.
with gr.Blocks() as demo:
with gr.Row():
gr.Markdown('''# Scene editing interventions with Counterfactual World Models!
''')
# Annotating arrows on an image
with gr.Tab(label='Image'):
with gr.Row():
with gr.Column():
# Input image
original_image = gr.State(value=None) # square-resized image without annotations (redraw/clear base)
original_image_high_res = gr.State(value=None) # same image at its original resolution (fed to the model)
input_image = gr.Image(type="numpy", label="Upload Image")
# Annotate arrows
selected_points = gr.State([]) # flat list of clicked points; consecutive pairs are (start, end)
zero_length_toggle = gr.Checkbox(label="Select patches to be kept fixed", value=False) # Toggle for zero-length vectors
with gr.Row():
gr.Markdown('1. **Click on the image** to specify patch motion by selecting a start and end point. \n 2. After selecting the points to move, enable the **"Select patches to be kept fixed"** checkbox to choose a few points to keep fixed. \n 3. **Click "Run Model"** to visualize the result of the edit.')
undo_button = gr.Button('Undo last action')
clear_button = gr.Button('Clear All')
# Run model button
run_model_button = gr.Button('Run Model')
# Tab showing the model's counterfactual prediction (written by run_model_on_points)
with gr.Tab(label='Intervention'):
output_image = gr.Image(type='numpy')
# Square-resize uploaded/example images before they are displayed and annotated.
def resize_to_square(img, size=512):
    """Stretch an image array to ``size`` x ``size`` pixels.

    The aspect ratio is not preserved (there is no crop); torchvision's
    ``Resize`` is applied to a PIL copy of ``img``.

    Args:
        img: Image as a numpy array accepted by ``PIL.Image.fromarray``.
        size: Side length of the output square, in pixels.

    Returns:
        The resized image as a numpy array.
    """
    print("Resizing image to square")
    square = transforms.Resize((size, size))(Image.fromarray(img))
    return np.array(square)
def load_img(evt: gr.SelectData):
    """Load a gallery example together with any annotations stored for it.

    Args:
        evt: Gallery select event; ``evt.value['image']['path']`` is the path
            of the chosen example file.

    Returns:
        ``(annotated, resized, full_res, points)`` for the outputs
        ``[input_image, original_image, original_image_high_res,
        selected_points]``: the square-resized image with the stored
        annotations drawn on it, the same image without annotations, the
        image at its original resolution, and the stored flat list of points
        (``[]`` when ``annot.json`` has no entry for this file name).
    """
    img_path = evt.value['image']['path']
    # Read eagerly inside `with` so the file handle is closed, and force
    # 3-channel RGB: palette/greyscale/RGBA files would otherwise clash with the
    # 3-value drawing colors and the model's 3-channel input.
    with Image.open(img_path) as pil_img:
        img = np.array(pil_img.convert("RGB"))
    # Stored annotations are keyed by the example's file name.
    with open('./assets/intervention_test_images/annot.json', 'r', encoding='utf-8') as f:
        points_json = json.load(f)
    resized_img = resize_to_square(img)
    image_name = os.path.basename(img_path)
    if image_name not in points_json:
        return resized_img, resized_img, img, []
    # NOTE(review): presumably stored in resize_to_square's 512x512 pixel
    # coordinates, as integer [x, y] pairs — confirm against annot.json.
    points = points_json[image_name]
    temp = resized_img.copy()
    # Draw every complete (start, end) pair; zip ignores an unpaired last point.
    for start_point, end_point in zip(points[0::2], points[1::2]):
        if start_point == end_point:
            # Zero-length vector (patch kept fixed): dots only, in the "fixed" color.
            color = dot_color_fixed
        else:
            cv2.arrowedLine(temp, start_point, end_point, arrow_color, thickness, tipLength=tip_length,
                            line_type=cv2.LINE_AA)
            color = arrow_color
        # Dots at both ends of the annotation.
        cv2.circle(temp, start_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
        cv2.circle(temp, end_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
    # A trailing start point without an end point gets a single dot.
    if len(points) % 2 == 1:
        cv2.circle(temp, points[-1], dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)
    return temp, resized_img, img, points
def store_img(img):
    """Handle a user upload: square-resize it and reset the annotations.

    Returns ``(resized, resized, img, [])`` for the outputs ``[input_image,
    original_image, original_image_high_res, selected_points]``. The resized
    copy serves as both the annotation canvas and its clean backup, ``img``
    is kept at full resolution, and the point list starts out empty.
    """
    square = resize_to_square(img)
    print(f"Image uploaded with shape: {square.shape}")
    return square, square, img, []
# Example gallery. Selecting an entry runs load_img, which also restores any
# annotations stored for that file in annot.json.
with gr.Row():
with gr.Column():
gallery = gr.Gallery( ["./assets/ducks.jpg", "./assets/robot_arm.jpg", "./assets/bread.jpg", "./assets/bird.jpg", "./assets/desk_1.jpg", "./assets/glasses.jpg", "./assets/watering_pot.jpg"], columns=5, allow_preview=False, label="Select an example image to test")
# Disabled gr.Examples alternative to the gallery, kept for reference:
# examples = gr.Examples(
# examples=[
# ["./assets/desk_1.jpg", "./assets/desk_1.jpg"],
# ],
# inputs=[input_image, original_image],
# # fn=load_img,
# # outputs=[input_image, original_image],
# # cache_examples=True,
# # run_on_click=True,
# # label="Select an example image to test"
# )
gallery.select(load_img, outputs=[input_image, original_image, original_image_high_res, selected_points])
# A user upload square-resizes the image and resets selected_points via store_img.
input_image.upload(store_img, [input_image], [input_image, original_image, original_image_high_res, selected_points])
# Click handler: record a point and draw arrows or fixed-point dots per the toggle.
def get_point(img, sel_pix, zero_length, evt: gr.SelectData):
    """Record one click on the annotation image and draw feedback for it.

    Args:
        img: Image currently shown in ``input_image``; drawn on in place.
        sel_pix: Flat list of clicked points (the ``selected_points`` state),
            appended to in place. Consecutive pairs are (start, end).
        zero_length: State of the "keep fixed" checkbox. When set, the click
            is stored twice (start == end, a zero-length vector) and marked
            with a red dot; otherwise clicks alternate arrow start / end.
        evt: Gradio select event; ``evt.index`` holds the clicked pixel
            coordinates (presumably ``[x, y]`` — confirm against Gradio).

    Returns:
        The annotated image as a numpy array.

    NOTE(review): enabling the checkbox while an arrow start is still pending
    pairs that start with the new fixed point in ``sel_pix``, so later
    redraws show an arrow to it and leave the duplicate entry unpaired.
    """
    clicked = evt.index
    sel_pix.append(clicked)
    if zero_length:
        # Fixed patch: red dot now, then the duplicate entry closes the pair.
        cv2.circle(img, clicked, dot_radius, dot_color_fixed, dot_thickness, lineType=cv2.LINE_AA)
        sel_pix.append(clicked)
    elif len(sel_pix) % 2:
        # Odd count: this click starts a new arrow, so just mark it.
        cv2.circle(img, clicked, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)
    else:
        # Even count: this click ends the arrow begun by the previous click.
        origin = sel_pix[-2]
        cv2.arrowedLine(img, origin, clicked, arrow_color, thickness, tipLength=tip_length, line_type=cv2.LINE_AA)
        cv2.circle(img, clicked, dot_radius, dot_color, dot_thickness, lineType=cv2.LINE_AA)
    return img if isinstance(img, np.ndarray) else np.array(img)
# Every click on the input image goes to get_point; selected_points is mutated in place, not returned.
input_image.select(get_point, [input_image, selected_points, zero_length_toggle], [input_image])
# Undo the most recent click/annotation and redraw what remains.
def undo_arrows(orig_img, sel_pix, zero_length):
    """Remove the most recent annotation and redraw the remaining ones.

    Args:
        orig_img: Clean (unannotated) image used as the drawing base; it is
            copied, never drawn on. ``None`` when no image is loaded yet.
        sel_pix: Flat list of clicked points (the ``selected_points`` state);
            consecutive pairs are (start, end). Modified in place.
        zero_length: State of the "keep fixed" checkbox. Unused: a fixed point
            is stored as two identical entries, so it is undone like an arrow.

    Returns:
        A copy of ``orig_img`` with the remaining annotations drawn on it, or
        ``None`` when no image is loaded.
    """
    if orig_img is None:
        return None
    temp = orig_img.copy()
    if len(sel_pix) % 2 == 1:
        # The last click started an arrow that was never finished: undo only
        # that click (popping two would also break up the previous pair).
        sel_pix.pop()
    elif sel_pix:
        # Drop the last complete annotation: an arrow's start/end, or the
        # duplicated entry of a fixed point.
        del sel_pix[-2:]
    # The list now has even length, so every entry belongs to a (start, end) pair.
    for start_point, end_point in zip(sel_pix[0::2], sel_pix[1::2]):
        if start_point == end_point:
            # Zero-length vector (patch kept fixed): dots only, in the "fixed" color.
            color = dot_color_fixed
        else:
            cv2.arrowedLine(temp, start_point, end_point, arrow_color, thickness, tipLength=tip_length,
                            line_type=cv2.LINE_AA)
            color = arrow_color
        # Dots at both ends of the annotation.
        cv2.circle(temp, start_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
        cv2.circle(temp, end_point, dot_radius, color, dot_thickness, lineType=cv2.LINE_AA)
    return temp if isinstance(temp, np.ndarray) else np.array(temp)
# Undo redraws onto the clean original_image; selected_points is updated in place rather than as an output.
undo_button.click(undo_arrows, [original_image, selected_points, zero_length_toggle], [input_image])
# Forget every annotation and show the clean image again.
def clear_all_points(orig_img, sel_pix):
    """Drop every recorded point and return the unannotated image.

    ``sel_pix`` (the ``selected_points`` state) is emptied in place; the clean
    ``orig_img`` is returned unchanged so it replaces the annotated view.
    """
    del sel_pix[:]
    return orig_img
# Clearing empties selected_points in place and shows the clean original_image again.
clear_button.click(clear_all_points, [original_image, selected_points], [input_image])
# Run the CWM model on the annotated point motions and return its counterfactual frame.
def run_model_on_points(points, input_image, original_image):
    """Run the counterfactual world model for the annotated patch motions.

    Args:
        points: Flat list of clicked ``[x, y]`` points on the displayed image;
            consecutive pairs are (start, end), and start == end marks a patch
            to keep fixed. A trailing start point without an end is ignored.
        input_image: Image currently displayed for annotation; only its height
            and width are used, to rescale the click coordinates.
        original_image: Full-resolution source image array; this is what the
            model sees, after conversion to RGB and resizing to 256 x 256.

    Returns:
        The model's counterfactual prediction as a channel-last numpy array
        clipped to [0, 1].
    """
    # Only the displayed image's size is needed, to map clicks onto the
    # model's 256 x 256 input grid.
    H, W = input_image.shape[:2]
    # Ignore a trailing start point whose end point was never clicked;
    # otherwise the reshape into (start, end) rows below raises ValueError.
    points = points[: len(points) // 2 * 2]
    # One row (x0, y0, x1, y1) per annotation; scale x by 256 / W and y by
    # 256 / H (the same factor for the square images from resize_to_square).
    scale = torch.tensor([256 / W, 256 / H, 256 / W, 256 / H])
    points = torch.from_numpy(np.array(points).reshape(-1, 4)) * scale
    # Reorder each row to (y0, x0, y1, x1) — presumably the (row, col) order
    # model.get_counterfactual expects.
    points = points[:, [1, 0, 3, 2]]
    # Force 3-channel RGB (greyscale/RGBA arrays would break the channel
    # permute below), then resize to the model's 256 x 256 input.
    img = Image.fromarray(original_image).convert("RGB")
    img = np.array(transforms.Resize((256, 256))(img))
    img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
    # [C, H, W] -> [B=1, C, T=2, H, W]: the same frame repeated for both time
    # steps of the 2-frame model (expand creates a view, no copy).
    x = img[None, :, None].expand(-1, -1, 2, -1, -1)
    # get_c applies ImageNet normalization before calling the model.
    counterfactual = get_c(x, points)
    # Drop singleton dims, clip to the displayable range and move channels
    # last (assumes the squeezed output is [C, H, W] — TODO confirm).
    counterfactual = counterfactual.squeeze().clamp(0, 1).permute(1, 2, 0)
    return counterfactual.detach().cpu().numpy()
# Run the model when the button is clicked. The pixels come from the
# full-resolution original (original_image_high_res); input_image only supplies
# the display size used to rescale the clicked coordinates.
run_model_button.click(run_model_on_points, [selected_points, input_image, original_image_high_res], [output_image])
# Enable Gradio's request queue, then launch the app.
demo.queue().launch(inbrowser=True, share=True)