sky24h's picture
gradio demo for ZeroGPU, HF
a9d25c7
raw
history blame
2.73 kB
import os
import torch
from PIL import Image
from .env_utils import get_device, low_vram_mode
device = get_device()
pretrained_models_folder = os.path.join(os.path.dirname(__file__), "../pretrained-models")
# RAM++
from ram.models import ram_plus
from ram import get_transform, inference_ram
ram_ckpt = os.path.join(pretrained_models_folder, "checkpoints/ram_plus_swin_large_14m.pth")
ram_precision = torch.bfloat16
def ram_init():
image_size = 384
transform = get_transform(image_size=image_size)
#######load model#######
model = ram_plus(pretrained=ram_ckpt, image_size=image_size, vit="swin_l")
model = model.to(device=device, dtype=ram_precision)
model.eval()
print("RAM++ model loaded")
return model, transform
# Initialize the model when importing the module
ram_model, ram_transform = ram_init()
def _inference(image_pil):
image = ram_transform(image_pil).unsqueeze(0)
image = image.to(device=device, dtype=ram_precision)
res = inference_ram(image, ram_model)
result = res[0].replace(" | ", ", ")
return result
def _split_large_image(image_pil):
size = image_pil.size
print("Image size is too large, split into smaller patches")
# Split the image into 4 patches
patches = []
patch_size = (size[0] // 2, size[1] // 2)
for i in range(2):
for j in range(2):
left = i * patch_size[0]
top = j * patch_size[1]
right = left + patch_size[0]
bottom = top + patch_size[1]
patch = image_pil.crop((left, top, right, bottom))
patches.append(patch)
return patches
def ram_inference(image_pil: Image.Image):
size = image_pil.size
if size[0] > 640 or size[1] > 640:
# split only once in the online demo version.
patches = _split_large_image(image_pil)
# while any(patch.size[0] > 640 or patch.size[1] > 640 for patch in patches):
# patches = [_split_large_image(patch) for patch in patches]
# patches = [patch for sublist in patches for patch in sublist]
# Inference on each patch
results = []
for patch in patches:
result = _inference(patch)
results.extend(result.split(", "))
results = list(set(results))
# Combine the results
final_result = ", ".join(results)
return final_result
else:
print("Image size is small enough for inference")
return _inference(image_pil)
if __name__ == "__main__":
# Test the RAM++ model
image_path = os.path.join(os.path.dirname(__file__), "../sources/test_imgs/1.jpg")
image = Image.open(image_path)
result = ram_inference(image)
print(result)