moooji's picture
Update handler.py
e190089
raw
history blame contribute delete
No virus
900 Bytes
from typing import Dict, List, Any
from PIL import Image
import torch
import base64
from io import BytesIO
from transformers import AutoProcessor, BlipForQuestionAnswering
# Run on the current CUDA device when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
class EndpointHandler:
    """Visual question answering handler backed by a BLIP VQA checkpoint.

    Each call takes a base64-encoded image plus a natural-language question
    and returns the model's generated answer as text.
    """

    # Checkpoint used for both the processor and the model unless overridden.
    MODEL_ID = "Salesforce/blip-vqa-capfilt-large"

    def __init__(self, path: str = "", model_id: str = MODEL_ID) -> None:
        """Load the processor and model, moving the model to ``device``.

        Args:
            path: Accepted so the constructor keeps its original signature;
                it is not used. Presumably supplied by the hosting runtime,
                so TODO confirm whether it should ever point at local weights.
            model_id: Identifier passed to ``from_pretrained`` for both the
                processor and the model. Defaults to ``MODEL_ID``.
        """
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = BlipForQuestionAnswering.from_pretrained(model_id).to(device)

    def __call__(self, data: Any) -> str:
        """Answer a question about an image.

        Args:
            data: Request payload, either ``{"inputs": {"image": ..., "question": ...}}``
                or the inner ``{"image": ..., "question": ...}`` dict directly.
                ``image`` is base64-encoded image bytes; an optional
                ``data:<mime>;base64,`` prefix is stripped. ``data`` is not mutated.

        Returns:
            The generated answer, decoded with special tokens removed.

        Raises:
            TypeError: If the payload (or its ``inputs`` value) is not a dict.
            KeyError: If ``image`` or ``question`` is missing.
        """
        encoded_image, question = self._parse_payload(data)
        # Defensive RGB normalisation (e.g. RGBA or palette PNGs); the
        # processor may already do this, in which case it is a no-op.
        image = Image.open(BytesIO(self._decode_base64(encoded_image))).convert("RGB")
        inputs = self.processor(image, question, return_tensors="pt").to(device)
        outputs = self.model.generate(**inputs)
        return self.processor.decode(outputs[0], skip_special_tokens=True)

    @staticmethod
    def _parse_payload(data: Any) -> tuple:
        """Return ``(image, question)`` from a request payload without mutating it."""
        if not isinstance(data, dict):
            raise TypeError(f"expected a JSON object payload, got {type(data).__name__}")
        # Accept both the wrapped {"inputs": {...}} form and a bare {...} body.
        payload = data.get("inputs", data)
        if not isinstance(payload, dict):
            raise TypeError(
                "expected 'inputs' to be an object with 'image' and 'question', "
                f"got {type(payload).__name__}"
            )
        missing = [key for key in ("image", "question") if key not in payload]
        if missing:
            raise KeyError(f"missing required input field(s): {', '.join(missing)}")
        return payload["image"], payload["question"]

    @staticmethod
    def _decode_base64(encoded: Any) -> bytes:
        """Decode base64 image data, tolerating a ``data:<mime>;base64,`` URL prefix."""
        if isinstance(encoded, str) and encoded.startswith("data:") and "," in encoded:
            # Without this, the prefix's alphabet-valid characters would be
            # decoded as part of the image bytes and corrupt them.
            encoded = encoded.split(",", 1)[1]
        return base64.b64decode(encoded)