Spaces:
Running
on
Zero
Running
on
Zero
from PIL import Image, ImageDraw, ImageFont | |
import os | |
import torch | |
import glob | |
import matplotlib.pyplot as plt | |
def read_images_in_path(path, size = (512,512)): | |
image_paths = [] | |
for filename in os.listdir(path): | |
if filename.endswith(".png") or filename.endswith(".jpg") or filename.endswith(".jpeg"): | |
image_path = os.path.join(path, filename) | |
image_paths.append(image_path) | |
image_paths = sorted(image_paths) | |
return [Image.open(image_path).convert("RGB").resize(size) for image_path in image_paths] | |
def concatenate_images(image_lists, return_list = False): | |
num_rows = len(image_lists[0]) | |
num_columns = len(image_lists) | |
image_width = image_lists[0][0].width | |
image_height = image_lists[0][0].height | |
grid_width = num_columns * image_width | |
grid_height = num_rows * image_height if not return_list else image_height | |
if not return_list: | |
grid_image = [Image.new('RGB', (grid_width, grid_height))] | |
else: | |
grid_image = [Image.new('RGB', (grid_width, grid_height)) for i in range(num_rows)] | |
for i in range(num_rows): | |
row_index = i if return_list else 0 | |
for j in range(num_columns): | |
image = image_lists[j][i] | |
x_offset = j * image_width | |
y_offset = i * image_height if not return_list else 0 | |
grid_image[row_index].paste(image, (x_offset, y_offset)) | |
return grid_image if return_list else grid_image[0] | |
def concatenate_images_single(image_lists): | |
num_columns = len(image_lists) | |
image_width = image_lists[0].width | |
image_height = image_lists[0].height | |
grid_width = num_columns * image_width | |
grid_height = image_height | |
grid_image = Image.new('RGB', (grid_width, grid_height)) | |
for j in range(num_columns): | |
image = image_lists[j] | |
x_offset = j * image_width | |
y_offset = 0 | |
grid_image.paste(image, (x_offset, y_offset)) | |
return grid_image | |
def get_captions_for_images(images, device): | |
from transformers import Blip2Processor, Blip2ForConditionalGeneration | |
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b") | |
model = Blip2ForConditionalGeneration.from_pretrained( | |
"Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16 | |
) # doctest: +IGNORE_RESULT | |
res = [] | |
for image in images: | |
inputs = processor(images=image, return_tensors="pt").to(device, torch.float16) | |
generated_ids = model.generate(**inputs) | |
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() | |
res.append(generated_text) | |
del processor | |
del model | |
return res | |
def find_and_plot_images(directory, output_file, recursive=True, figsize=(15, 15), image_formats=("*.png", "*.jpg", "*.jpeg", "*.bmp", "*.tiff")): | |
""" | |
Finds all images in the specified directory (optionally recursively) | |
and saves them in a single figure with their filenames. | |
Parameters: | |
directory (str): Path to the directory. | |
output_file (str): Path to save the resulting figure (e.g., 'output.png'). | |
recursive (bool): Whether to search directories recursively. | |
figsize (tuple): Size of the resulting figure. | |
image_formats (tuple): Image file formats to look for. | |
Returns: | |
None | |
""" | |
# Gather all image file paths | |
pattern = "**/" if recursive else "" | |
images = [] | |
for fmt in image_formats: | |
images.extend(glob.glob(os.path.join(directory, pattern + fmt), recursive=recursive)) | |
images = [image for image in images if "noise.jpg" not in image and "results.jpg" not in image] # Filter out noise and result images | |
# move "original" to the front, followed by "reconstruction" and then the rest | |
images = sorted( | |
images, | |
key=lambda x: (not x.endswith("original.jpg"), not x.endswith("reconstruction.jpg"), x) | |
) | |
if not images: | |
print("No images found!") | |
return | |
# Create a figure | |
num_images = len(images) | |
cols = num_images # Max 5 images per row | |
rows = (num_images + cols - 1) // cols # Calculate number of rows | |
fig, axs = plt.subplots(rows, cols, figsize=figsize) | |
axs = axs.flatten() if num_images > 1 else [axs] # Flatten axes for single image case | |
for i, image_path in enumerate(images): | |
# Open and plot image | |
img = Image.open(image_path) | |
axs[i].imshow(img) | |
axs[i].axis('off') # Remove axes | |
axs[i].set_title(os.path.basename(image_path), fontsize=8) # Add filename | |
# Hide any remaining empty axes | |
for j in range(i + 1, len(axs)): | |
axs[j].axis('off') | |
plt.tight_layout() | |
plt.savefig(output_file, bbox_inches='tight', dpi=300) # Save the figure to the file | |
plt.close(fig) # Close the figure to free up memory | |
print(f"Figure saved to {output_file}") | |
def add_label_to_image(image, label): | |
""" | |
Adds a label to the lower-right corner of an image. | |
Args: | |
image (PIL.Image): Image to add the label to. | |
label (str): Text to add as a label. | |
Returns: | |
PIL.Image: Image with the added label. | |
""" | |
# Create a drawing context | |
draw = ImageDraw.Draw(image) | |
# Create a drawing context | |
draw = ImageDraw.Draw(image) | |
# Define font and size | |
font_size = int(min(image.size) * 0.05) # Adjust font size based on image dimensions | |
try: | |
font = ImageFont.truetype("fonts/arial.ttf", font_size) # Replace with a font path if needed | |
except IOError: | |
font = ImageFont.load_default() # Fallback to default font if arial.ttf is not found | |
# Measure text size using textbbox | |
text_bbox = draw.textbbox((0, 0), label, font=font) # (left, top, right, bottom) | |
text_width = text_bbox[2] - text_bbox[0] | |
text_height = text_bbox[3] - text_bbox[1] | |
# Position the text in the lower-right corner with some padding | |
padding = 10 | |
position = (image.width - text_width - padding, image.height - text_height - padding) | |
# Add a semi-transparent background for the label | |
draw.rectangle( | |
[ | |
(position[0] - padding, position[1] - padding), | |
(position[0] + text_width + padding, position[1] + text_height + padding) | |
], | |
fill=(0, 0, 0, 150) # Black with transparency | |
) | |
# Draw the label | |
draw.text(position, label, fill="white", font=font) | |
return image | |
def crop_center_square_and_resize(img, size, output_path=None): | |
""" | |
Crops the center of an image to make it square. | |
Args: | |
img (PIL.Image): Image to crop. | |
output_path (str, optional): Path to save the cropped image. If None, the cropped image is not saved. | |
Returns: | |
Image: The cropped square image. | |
""" | |
width, height = img.size | |
# Determine the shorter side | |
side_length = min(width, height) | |
# Calculate the cropping box | |
left = (width - side_length) // 2 | |
top = (height - side_length) // 2 | |
right = left + side_length | |
bottom = top + side_length | |
# Crop the image | |
cropped_img = img.crop((left, top, right, bottom)) | |
# Resize the image | |
cropped_img = cropped_img.resize(size) | |
# Save the cropped image if output path is specified | |
if output_path: | |
cropped_img.save(output_path) | |
return cropped_img | |