handertrails / handler.py
hughtayloe's picture
Update handler.py
c8ebd57 verified
raw
history blame
1.13 kB
from typing import Dict, Any
from PIL import Image
import requests
import torch
import numpy as np
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
class EndpointHandler():
    """Custom inference handler that answers a text prompt about an image.

    The model is loaded once in ``__init__`` with 4-bit quantization; every
    call downloads one image, runs greedy generation and returns the decoded
    text.
    """

    # Fallbacks used when the request does not supply its own values.
    DEFAULT_IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
    DEFAULT_PROMPT = "USER: <image>\nWhat are these?\nASSISTANT:"
    # Seconds to wait for the image host before giving up on the download.
    REQUEST_TIMEOUT = 30

    def __init__(self, path=""):
        """Load the model and its processor from ``path``.

        Args:
            path: Model identifier or directory handed to ``from_pretrained``.
        """
        model_id = path
        self.model = LlavaForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            # Passing ``load_in_4bit`` directly is deprecated; the config
            # object is the supported way to request 4-bit loading.
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        )
        self.processor = AutoProcessor.from_pretrained(model_id)

    @classmethod
    def _parse_request(cls, data):
        """Return the ``(image_url, prompt)`` pair requested by ``data``.

        The payload is ``data["inputs"]`` when that key exists, otherwise
        ``data`` itself. A payload that is not a dict, or that lacks (or has
        empty) ``image_url`` / ``prompt`` entries, falls back to the class
        defaults. ``data`` is not modified.
        """
        payload = data.get("inputs", data)
        if not isinstance(payload, dict):
            payload = {}
        url = payload.get("image_url") or cls.DEFAULT_IMAGE_URL
        prompt = payload.get("prompt") or cls.DEFAULT_PROMPT
        return url, prompt

    def __call__(self, data: Dict[str, Any]):
        """Run one image + prompt generation for a request.

        Args:
            data: Request body. Optional ``image_url`` and ``prompt`` values
                are read from ``data["inputs"]`` (or from ``data`` itself
                when there is no ``"inputs"`` key).

        Returns:
            The decoded output sequence as a string, with its first two
            tokens dropped and special tokens skipped.

        Raises:
            requests.RequestException: If the image download fails, times
                out, or returns an HTTP error status.
        """
        url, prompt = self._parse_request(data)

        with requests.get(url, stream=True, timeout=self.REQUEST_TIMEOUT) as response:
            response.raise_for_status()
            raw_image = Image.open(response.raw)
            # Image.open is lazy; read the pixels before the connection closes.
            raw_image.load()

        # Keyword arguments keep the call independent of the processor's
        # positional parameter order.
        inputs = self.processor(
            text=prompt, images=raw_image, return_tensors='pt'
        ).to(self.model.device, torch.float16)
        output = self.model.generate(**inputs, max_new_tokens=200, do_sample=False)
        readable = self.processor.decode(output[0][2:], skip_special_tokens=True)
        return readable