# blip2-opt-6.7b / handler.py — uploaded by advaitadasein (commit e93799d, 1.21 kB).
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from typing import Dict, List, Any
import torch
import base64
from io import BytesIO
class EndpointHandler:
    """Inference handler that runs a BLIP-2 model on a base64-encoded image.

    A request is a dict with a base64-encoded image under ``"inputs"``, an
    optional text prompt under ``"text"`` and optional ``generate`` keyword
    arguments under ``"parameters"``. The handler returns the decoded text of
    the first generated sequence.
    """

    # Checkpoint used for both the processor and the model unless overridden.
    DEFAULT_MODEL_ID = "Salesforce/blip2-opt-6.7b"

    def __init__(self, path: str = "", *, model_id: str = DEFAULT_MODEL_ID) -> None:
        """Load the processor and model, then move the model to GPU if one is available.

        Args:
            path: Accepted for compatibility with the endpoint-handler
                interface; it is not used to locate weights. (Presumably the
                serving framework passes the repository directory here —
                confirm before switching loading to it.)
            model_id: Hub id or local directory given to ``from_pretrained``
                for both the processor and the model.
        """
        self.processor = Blip2Processor.from_pretrained(model_id)
        self.model = Blip2ForConditionalGeneration.from_pretrained(model_id)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Generate text for a single request.

        Args:
            data: Request payload (not modified).
                ``"inputs"`` (required): base64-encoded image, optionally as a
                ``data:<mime>;base64,`` URI.
                ``"text"`` (optional): prompt; when absent, no prompt is given
                to the processor (unconditional captioning).
                ``"parameters"`` (optional): dict of extra keyword arguments
                forwarded to ``model.generate``.

        Returns:
            The first generated sequence, decoded with special tokens removed.

        Raises:
            ValueError: If ``"inputs"`` is missing or ``"parameters"`` is not a
                dict. Malformed base64 raises ``binascii.Error``, which is a
                ``ValueError`` subclass.
        """
        # Use .get rather than .pop so the caller's payload is left intact and
        # a missing image is reported clearly, instead of decoding the whole dict.
        image_encoded = data.get("inputs")
        if image_encoded is None:
            raise ValueError("Request payload must contain a base64-encoded image under 'inputs'.")

        generate_kwargs = data.get("parameters") or {}
        if not isinstance(generate_kwargs, dict):
            raise ValueError("'parameters' must be a dict of generation keyword arguments.")

        # The prompt is optional: BLIP-2 can caption an image without one.
        text = data.get("text")

        image = self.decode_base64_image(image_encoded)
        processed = self.processor(images=image, text=text, return_tensors="pt").to(self.device)
        out = self.model.generate(**processed, **generate_kwargs)
        return self.processor.decode(out[0], skip_special_tokens=True)

    def decode_base64_image(self, image_string):
        """Decode a base64 string (or ``data:`` URI) into a PIL image.

        ``Image.open`` reads lazily; the returned image keeps a reference to
        its in-memory buffer, so the buffer stays alive as long as the image.
        """
        return Image.open(BytesIO(self._b64_to_bytes(image_string)))

    @staticmethod
    def _b64_to_bytes(image_string):
        """Return the raw bytes of a base64 payload, dropping any ``data:...,`` prefix."""
        if isinstance(image_string, str) and image_string.startswith("data:"):
            # Data URIs put a media-type header before the first comma.
            _, _, image_string = image_string.partition(",")
        return base64.b64decode(image_string)