"""Utilities for loading Florence-2 checkpoints and running task-prompted inference."""
from typing import Tuple, Dict, Any, List
from unittest.mock import patch

import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

from utils.imports import fixed_get_imports

# Florence-2 checkpoints to load (fine-tuned and base variants, large and base sizes).
CHECKPOINTS = [
    "microsoft/Florence-2-large-ft",
    "microsoft/Florence-2-large",
    "microsoft/Florence-2-base-ft",
    "microsoft/Florence-2-base",
]


def load_models(device: torch.device) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    # Patch get_imports so the remote Florence-2 code can load in environments
    # without optional dependencies such as flash_attn.
    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
        models = {}
        processors = {}
        for checkpoint in CHECKPOINTS:
            models[checkpoint] = AutoModelForCausalLM.from_pretrained(
                checkpoint, trust_remote_code=True).to(device).eval()
            processors[checkpoint] = AutoProcessor.from_pretrained(
                checkpoint, trust_remote_code=True)
    return models, processors


def run_inference(
    model: Any,
    processor: Any,
    device: torch.device,
    image: Image.Image,
    task: str,
    text: str = ""
) -> Tuple[str, Dict]:
    # Florence-2 prompts are a task token optionally followed by free-form text.
    prompt = task + text
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3
    )
    # Keep special tokens: the task-specific parser below relies on them.
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=False)[0]
    response = processor.post_process_generation(
        generated_text, task=task, image_size=image.size)
    return generated_text, response


def pre_process_region_task_input(
    prompt: List[float],
    resolution_wh: Tuple[int, int]
) -> str:
    # The prompt carries six values; only the box corners (x1, y1, x2, y2) are used.
    x1, y1, _, x2, y2, _ = prompt
    w, h = resolution_wh
    # Normalize the box to the 0-999 location-token grid Florence-2 expects.
    box = np.array([x1, y1, x2, y2], dtype=np.float64)
    box /= np.array([w, h, w, h])
    box *= 1000
    return "".join([f"<loc_{int(coordinate)}>" for coordinate in box])


def post_process_region_output(
    detections: sv.Detections,
    resolution_wh: Tuple[int, int]
) -> sv.Detections:
    # Map boxes from the 0-999 location-token grid back to pixel coordinates.
    w, h = resolution_wh
    detections.xyxy = (detections.xyxy / 1000 * np.array([w, h, w, h])).astype(np.int32)
    return detections
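

# Minimal usage sketch (illustrative, not part of the original module):
# "image.jpg" and the task tokens below are assumptions for demonstration only.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    models, processors = load_models(device)
    checkpoint = CHECKPOINTS[0]
    image = Image.open("image.jpg")  # hypothetical sample image

    # Plain captioning: the task token alone is the prompt.
    _, caption = run_inference(
        models[checkpoint], processors[checkpoint], device, image, "<CAPTION>")
    print(caption)

    # Region task: encode a box prompt as <loc_*> tokens appended to the task token.
    box_prompt = [50.0, 60.0, 0.0, 300.0, 400.0, 0.0]  # x1, y1, _, x2, y2, _
    text = pre_process_region_task_input(box_prompt, image.size)
    _, response = run_inference(
        models[checkpoint], processors[checkpoint], device, image,
        "<REGION_TO_SEGMENTATION>", text)
    print(response)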