tight-inversion
Initial commit
4ebc565
raw
history blame
7.39 kB
from PIL import Image, ImageDraw, ImageFont
import os
import torch
import glob
import matplotlib.pyplot as plt
def read_images_in_path(path, size = (512,512)):
    """
    Loads every PNG/JPEG image located directly inside a directory.

    Files are matched by extension (case-insensitively, so ``photo.JPG`` is
    picked up as well), ordered by their full path, converted to RGB and
    resized to ``size``. Sub-directories are not searched, and directory
    entries whose name happens to end with an image extension are skipped.

    Args:
        path (str): Directory to scan.
        size (tuple): Target size handed to ``Image.resize`` for every image.
    Returns:
        list: One resized RGB ``PIL.Image`` per matching file, sorted by path.
    """
    extensions = (".png", ".jpg", ".jpeg")
    image_paths = sorted(
        candidate
        for candidate in (os.path.join(path, filename) for filename in os.listdir(path))
        if candidate.lower().endswith(extensions) and os.path.isfile(candidate)
    )
    images = []
    for image_path in image_paths:
        # Image.open is lazy; the context manager guarantees the underlying
        # file is released once the converted/resized copy has been made.
        with Image.open(image_path) as source:
            images.append(source.convert("RGB").resize(size))
    return images
def concatenate_images(image_lists, return_list = False):
    """
    Tiles images into a grid where each inner list supplies one column.

    ``image_lists[j][i]`` is pasted at column ``j`` and row ``i``. The cell
    size is taken from the very first image, and the number of rows from the
    length of the first inner list.

    Args:
        image_lists (list): List of columns, each a list of PIL images.
        return_list (bool): If True, return one horizontal strip per row
            instead of a single combined grid.
    Returns:
        PIL.Image or list: The full grid, or a list of row strips when
        ``return_list`` is True.
    """
    first_column = image_lists[0]
    n_rows = len(first_column)
    n_cols = len(image_lists)
    cell_w = first_column[0].width
    cell_h = first_column[0].height
    canvas_w = n_cols * cell_w

    if return_list:
        # One independent strip per row; every paste lands on y = 0.
        strips = [Image.new('RGB', (canvas_w, cell_h)) for _ in range(n_rows)]
        for row in range(n_rows):
            for col in range(n_cols):
                strips[row].paste(image_lists[col][row], (col * cell_w, 0))
        return strips

    # Single canvas tall enough for all rows.
    grid = Image.new('RGB', (canvas_w, n_rows * cell_h))
    for row in range(n_rows):
        for col in range(n_cols):
            grid.paste(image_lists[col][row], (col * cell_w, row * cell_h))
    return grid
def concatenate_images_single(image_lists):
    """
    Joins a flat list of images side by side into one horizontal strip.

    The tile size is taken from the first image; each image is pasted at a
    horizontal offset of its index times that width.

    Args:
        image_lists (list): PIL images to place left to right.
    Returns:
        PIL.Image: A new RGB image containing all inputs in a single row.
    """
    tile_count = len(image_lists)
    tile_w = image_lists[0].width
    tile_h = image_lists[0].height

    strip = Image.new('RGB', (tile_count * tile_w, tile_h))
    for index, tile in enumerate(image_lists):
        strip.paste(tile, (index * tile_w, 0))
    return strip
def get_captions_for_images(images, device):
    """
    Generates one text caption per image with the BLIP-2 OPT-2.7b checkpoint.

    The processor and model are loaded on every call and dropped again before
    returning.

    Args:
        images (iterable): Images accepted by ``Blip2Processor``.
        device: Device the processed inputs are moved to (as float16).
    Returns:
        list: Stripped caption strings, in the same order as ``images``.
    """
    # NOTE(review): the model is placed via device_map={"": 0} regardless of
    # `device`, while inputs go to `device` - confirm callers pass a matching device.
    from transformers import Blip2ForConditionalGeneration, Blip2Processor

    checkpoint = "Salesforce/blip2-opt-2.7b"
    processor = Blip2Processor.from_pretrained(checkpoint)
    model = Blip2ForConditionalGeneration.from_pretrained(
        checkpoint,
        load_in_8bit=True,
        device_map={"": 0},
        torch_dtype=torch.float16,
    )

    captions = []
    for image in images:
        batch = processor(images=image, return_tensors="pt").to(device, torch.float16)
        token_ids = model.generate(**batch)
        decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
        captions.append(decoded[0].strip())

    # Drop the local references to the heavy objects before returning.
    del processor
    del model
    return captions
def find_and_plot_images(directory, output_file, recursive=True, figsize=(15, 15), image_formats=("*.png", "*.jpg", "*.jpeg", "*.bmp", "*.tiff"), max_cols=None):
    """
    Finds all images in the specified directory (optionally recursively)
    and saves them in a single figure with their filenames.

    Paths containing "noise.jpg" or "results.jpg" are ignored. Images are
    ordered with "original.jpg" first, "reconstruction.jpg" second and the
    remaining paths alphabetically.

    Parameters:
        directory (str): Path to the directory.
        output_file (str): Path to save the resulting figure (e.g., 'output.png').
        recursive (bool): Whether to search directories recursively.
        figsize (tuple): Size of the resulting figure.
        image_formats (tuple): Glob patterns of the image files to look for.
        max_cols (int or None): Maximum number of images per row. ``None``
            (the default) keeps every image in a single row.
    Returns:
        None
    Raises:
        ValueError: If ``max_cols`` is given and is smaller than 1.
    """
    if max_cols is not None and max_cols < 1:
        raise ValueError("max_cols must be a positive integer or None")

    # Gather all image file paths
    pattern = "**/" if recursive else ""
    images = []
    for fmt in image_formats:
        images.extend(glob.glob(os.path.join(directory, pattern + fmt), recursive=recursive))

    # Filter out noise and result images
    images = [image for image in images if "noise.jpg" not in image and "results.jpg" not in image]

    # Move "original" to the front, followed by "reconstruction" and then the rest
    images = sorted(
        images,
        key=lambda x: (not x.endswith("original.jpg"), not x.endswith("reconstruction.jpg"), x)
    )

    if not images:
        print("No images found!")
        return

    # Lay the images out on a grid: one row unless max_cols limits the width
    num_images = len(images)
    cols = num_images if max_cols is None else min(max_cols, num_images)
    rows = (num_images + cols - 1) // cols  # Ceiling division

    # squeeze=False always yields a 2-D array of axes, so no special-casing
    # is needed for a single image or a single row.
    fig, axs = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    axs = axs.flatten()
    try:
        for ax, image_path in zip(axs, images):
            # imshow copies the pixel data, so the file can be closed right away
            with Image.open(image_path) as img:
                ax.imshow(img)
            ax.axis('off')  # Remove axes
            ax.set_title(os.path.basename(image_path), fontsize=8)  # Add filename

        # Hide any remaining empty axes
        for ax in axs[num_images:]:
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(output_file, bbox_inches='tight', dpi=300)  # Save the figure to the file
    finally:
        plt.close(fig)  # Always release the figure, even if plotting or saving fails
    print(f"Figure saved to {output_file}")
def add_label_to_image(image, label):
    """
    Adds a label to the lower-right corner of an image.

    The image is modified in place and the same object is returned. The label
    is drawn in white on a black background box; on RGB images the box is
    alpha-blended so the picture stays partly visible underneath.

    Args:
        image (PIL.Image): Image to add the label to.
        label (str): Text to add as a label.
    Returns:
        PIL.Image: Image with the added label (the same object as ``image``).
    """
    # Create a drawing context. Pillow only blends the alpha channel of a fill
    # into an RGB image when the context is opened in "RGBA" mode; any other
    # image mode has to be drawn in its native mode.
    draw_mode = "RGBA" if image.mode == "RGB" else None
    draw = ImageDraw.Draw(image, draw_mode)

    # Scale the font with the image, never letting it drop to 0 on tiny images
    font_size = max(1, int(min(image.size) * 0.05))
    try:
        font = ImageFont.truetype("fonts/arial.ttf", font_size)  # Replace with a font path if needed
    except OSError:
        font = ImageFont.load_default()  # Fallback to default font if arial.ttf is not found

    # Measure text size using textbbox
    left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
    text_width = right - left
    text_height = bottom - top

    # Position the text in the lower-right corner with some padding
    padding = 10
    position = (image.width - text_width - padding, image.height - text_height - padding)

    # Add a semi-transparent background for the label
    draw.rectangle(
        [
            (position[0] - padding, position[1] - padding),
            (position[0] + text_width + padding, position[1] + text_height + padding)
        ],
        fill=(0, 0, 0, 150)  # Black with transparency
    )

    # Draw the label. The glyphs start at the bounding box's (left, top) offset
    # from the anchor, so subtract it to keep the ink centred inside the box.
    draw.text((position[0] - left, position[1] - top), label, fill="white", font=font)
    return image
def crop_center_square_and_resize(img, size, output_path=None):
    """
    Cuts the largest centered square out of an image and resizes it.

    Args:
        img (PIL.Image): Image to crop.
        size (tuple): Target size passed to ``resize``, e.g. ``(width, height)``.
        output_path (str, optional): Where to save the result. Nothing is
            saved when this is None or empty.
    Returns:
        Image: The cropped and resized square image.
    """
    width, height = img.size

    # The square's edge is the shorter image side; center it on both axes.
    edge = min(width, height)
    x0 = (width - edge) // 2
    y0 = (height - edge) // 2
    crop_box = (x0, y0, x0 + edge, y0 + edge)

    square = img.crop(crop_box).resize(size)

    if output_path:
        square.save(output_path)
    return square