SkalskiP's picture
more tasks
fc7652c
from typing import Tuple, Dict, Any, List
from unittest.mock import patch
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from utils.imports import fixed_get_imports
# Florence-2 checkpoint identifiers loaded by `load_models`. For each model
# size the fine-tuned ("-ft") variant is listed before the plain one, giving:
# large-ft, large, base-ft, base.
CHECKPOINTS = [
    f"microsoft/Florence-2-{size}{suffix}"
    for size in ("large", "base")
    for suffix in ("-ft", "")
]
def load_models(device: torch.device) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Load every checkpoint in ``CHECKPOINTS`` together with its processor.

    Each model is moved to ``device`` and switched to eval mode. All loading
    happens while ``transformers.dynamic_module_utils.get_imports`` is patched
    with ``fixed_get_imports`` (from ``utils.imports``). Presumably this works
    around an import required by the checkpoints' remote code; confirm in
    ``utils.imports``.

    Args:
        device: Device every model is moved to.

    Returns:
        ``(models, processors)``: two dicts keyed by checkpoint name.
    """
    loaded_models: Dict[str, Any] = {}
    loaded_processors: Dict[str, Any] = {}
    import_patch = patch(
        "transformers.dynamic_module_utils.get_imports", fixed_get_imports
    )
    with import_patch:
        for name in CHECKPOINTS:
            # Model first, then its processor, one checkpoint at a time.
            model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True)
            loaded_models[name] = model.to(device).eval()
            loaded_processors[name] = AutoProcessor.from_pretrained(
                name, trust_remote_code=True
            )
    return loaded_models, loaded_processors
def run_inference(
    model: Any,
    processor: Any,
    device: torch.device,
    image: Image.Image,
    task: str,
    text: str = "",
    *,
    max_new_tokens: int = 1024,
    num_beams: int = 3,
) -> Tuple[str, Dict]:
    """Run one Florence-2 task on a single image.

    The prompt is ``task`` with ``text`` appended. It is encoded together with
    ``image`` by ``processor``, moved to ``device`` and passed to
    ``model.generate``.

    Args:
        model: Loaded model exposing ``generate``.
        processor: Matching processor, used for encoding, decoding and
            ``post_process_generation``.
        device: Device the encoded inputs are moved to.
        image: Input image. Its ``size`` is forwarded to post-processing.
        task: Task token prefix of the prompt; also passed to post-processing.
        text: Optional text appended to ``task`` in the prompt.
        max_new_tokens: Generation length limit (default 1024, as before).
        num_beams: Beam-search width (default 3, as before).

    Returns:
        ``(generated_text, response)``. ``generated_text`` is the raw decoded
        output with special tokens kept. ``response`` is the result of
        ``processor.post_process_generation`` for ``task``.
    """
    prompt = task + text
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
    )
    # Special tokens are deliberately kept. Presumably post_process_generation
    # needs them (e.g. <loc_N> tokens) to parse the output; confirm.
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=False)[0]
    response = processor.post_process_generation(
        generated_text, task=task, image_size=image.size)
    return generated_text, response
def pre_process_region_task_input(
    prompt: List[float],
    resolution_wh: Tuple[int, int]
) -> str:
    """Encode a pixel-space box as four ``<loc_N>`` tokens.

    ``prompt`` is a six-element sequence. Elements 0, 1, 3 and 4 are taken as
    ``x1, y1, x2, y2`` in pixels; elements 2 and 5 are ignored. Each coordinate
    is divided by the matching image dimension, scaled by 1000, floored and
    clamped to ``0..999``.

    Args:
        prompt: Six values; positions 0, 1, 3, 4 hold the box corners.
            Integers and floats are both accepted.
        resolution_wh: ``(width, height)`` of the image in pixels.

    Returns:
        A string like ``"<loc_250><loc_250><loc_500><loc_500>"``.

    Raises:
        ValueError: If width or height is not positive.
    """
    x1, y1, _, x2, y2, _ = prompt
    w, h = resolution_wh
    if w <= 0 or h <= 0:
        raise ValueError(f"resolution_wh must be positive, got {resolution_wh}")
    # Force a float dtype. With integer pixel coordinates np.array would infer
    # an int dtype, and the in-place true division below would raise a casting
    # error.
    box = np.array([x1, y1, x2, y2], dtype=np.float64)
    box /= np.array([w, h, w, h], dtype=np.float64)
    box *= 1000
    # Location tokens presumably span <loc_0>..<loc_999> (1000 bins, matching
    # the /1000 in the output post-processing). Clamping keeps a point on the
    # right/bottom edge (x == w) from producing a nonexistent <loc_1000>, and
    # keeps points outside the image from producing negative indices.
    bins = np.clip(np.floor(box), 0, 999).astype(int)
    return "".join(f"<loc_{coordinate}>" for coordinate in bins.tolist())
def post_process_region_output(
    detections: sv.Detections,
    resolution_wh: Tuple[int, int]
) -> sv.Detections:
    """Rescale ``detections.xyxy`` from a 0-1000 grid to pixel coordinates.

    Boxes are divided by 1000, multiplied by the image width/height and cast
    to ``int32``. ``detections`` is modified in place and also returned.

    Args:
        detections: Detections whose ``xyxy`` boxes are on the 0-1000 grid.
        resolution_wh: ``(width, height)`` of the target image in pixels.

    Returns:
        The same ``detections`` object, now with integer pixel-space boxes.
    """
    width, height = resolution_wh
    wh_scale = np.array([width, height, width, height])
    # Keep the divide-then-multiply order so results match exactly before
    # the int32 truncation.
    normalized_boxes = detections.xyxy / 1000
    detections.xyxy = (normalized_boxes * wh_scale).astype(np.int32)
    return detections