File size: 1,883 Bytes
f1c8651
 
 
 
 
 
e8ce44e
f1c8651
 
 
e8ce44e
f1c8651
 
 
 
 
 
 
 
 
 
 
 
f335e44
 
f1c8651
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import asyncio
import io
import os
import threading
from typing import Union
from unittest.mock import patch

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Patch to remove flash-attn dependency
from transformers.dynamic_module_utils import get_imports
def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
    """Return a remote-code module's imports, dropping ``flash_attn`` for Florence-2.

    Used as a drop-in replacement for
    ``transformers.dynamic_module_utils.get_imports`` while the model is
    loaded, so the Florence-2 remote code can be imported without the
    ``flash_attn`` package.

    Args:
        filename: Path of the module file whose imports are being inspected.

    Returns:
        The list produced by ``get_imports``. For ``modeling_florence2.py``
        any ``flash_attn`` entry is removed; other files are untouched.
    """
    # ``get_imports`` is the original function bound at import time (before
    # the patch is applied), so calling it here does not recurse.
    imports = get_imports(filename)
    # Compare the file's basename instead of ``endswith("/...")``: the
    # hard-coded "/" never matched Windows paths, which use backslashes.
    if os.path.basename(str(filename)) != "modeling_florence2.py":
        return imports
    if "flash_attn" in imports:
        imports.remove("flash_attn")
    return imports

device = "cuda" if torch.cuda.is_available() else "cpu"

# Location of the fine-tuned checkpoint passed to ``from_pretrained``.
# Override with the NUMBER_PLATE_MODEL_PATH environment variable; the default
# is a relative path, resolved against the current working directory.
MODEL_PATH = os.environ.get("NUMBER_PLATE_MODEL_PATH", "numberPlate_model_2")

# Apply the flash-attn import workaround while loading the remote-code model
# and its processor from the same checkpoint.
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True).to(device)
    processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)

# Initialize FastAPI
app = FastAPI()

def process_image(image, task_token, *, max_new_tokens=1024, num_beams=3):
    """Run the loaded model on one image and return the post-processed result.

    Args:
        image: PIL image to analyse. Its ``(width, height)`` is passed to the
            post-processor as ``image_size``.
        task_token: Task prompt (e.g. ``"<OD>"``). It is used both as the text
            input to the processor and as the ``task`` for post-processing.
        max_new_tokens: Upper bound on generated tokens (default 1024).
        num_beams: Beam-search width (default 3); sampling is disabled.

    Returns:
        Whatever ``processor.post_process_generation`` returns for *task_token*.
    """
    inputs = processor(text=task_token, images=image, return_tensors="pt", padding=True).to(device)
    # Inference only: disable autograd so the forward pass (including image
    # encoding done by the remote-code generate) does not build a graph.
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=False,
        )
    # Special tokens are kept; presumably the post-processor needs the raw
    # decoded text (e.g. location tokens) — confirm before changing.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    return processor.post_process_generation(
        generated_text,
        task=task_token,
        image_size=(image.width, image.height),
    )

# Serializes model inference. Before inference was moved off the event loop,
# requests ran one at a time; the lock keeps that guarantee for the shared
# model/processor now that work runs in worker threads.
_inference_lock = threading.Lock()


def _run_inference(image, task_token):
    """Call ``process_image`` while holding the inference lock (runs in a worker thread)."""
    with _inference_lock:
        return process_image(image, task_token)


@app.post("/process-image/")
async def process_image_endpoint(file: UploadFile = File(...), task_token: str = "<OD>"):
    """Decode an uploaded image and return the model's parsed output for *task_token*.

    Responds with HTTP 400 when the upload cannot be decoded as an image.
    """
    try:
        image = Image.open(io.BytesIO(await file.read())).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # OSError covers PIL's UnidentifiedImageError and truncated files;
        # surface these as a client error instead of a 500.
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from exc
    # Model inference is blocking and slow; run it in a worker thread so the
    # event loop keeps serving other requests meanwhile.
    return await asyncio.to_thread(_run_inference, image, task_token)