import io
import logging
from typing import Any, Dict, List

import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration

# NOTE: the former `import flash-attn` line was removed because a hyphen is not
# valid in a Python import statement (SyntaxError). The flash-attn package must
# still be installed; it is activated below through
# `attn_implementation="flash_attention_2"` rather than by importing it.

logger = logging.getLogger(__name__)

# CUDA device index used for both the model weights and the processed inputs;
# they must live on the same device for `generate` to work.
_DEVICE = 0
# Seconds to wait on an image download before giving up on that URL.
_REQUEST_TIMEOUT = 30


class EndpointHandler:
    """Inference handler that answers one text prompt about each of several images.

    The model is loaded in float16 with FlashAttention-2 and placed on GPU
    ``_DEVICE``. Each request supplies image URLs plus a prompt, and the handler
    returns one decoded generation (or an error string) per URL.
    """

    def __init__(self, path=""):
        """Load the LLaVA model and its processor.

        Args:
            path: Local directory or Hub id of the LLaVA checkpoint.
        """
        model_id = path
        self.model = LlavaForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            # Replaces the deprecated `use_flash_attention_2=True` flag.
            attn_implementation="flash_attention_2",
        ).to(_DEVICE)  # previously left on CPU while inputs went to GPU 0
        self.processor = AutoProcessor.from_pretrained(model_id)

    def __call__(self, data: Dict[str, Any]) -> List[str]:
        """Run the prompt against every image URL in the request.

        Args:
            data: Request payload. ``"inputs"`` is an image URL or a list of
                URLs; ``"prompt"`` is the question asked about each image.
                Both keys are popped from ``data``.

        Returns:
            One string per URL, in input order: the decoded generation, or
            ``"Error processing image from <url>: <error>"`` if that URL failed.
        """
        links = data.pop("inputs", None)
        given_prompt = data.pop("prompt", "")

        # Accept a single URL as well as a list; a bare string would otherwise
        # be iterated character by character.
        if links is None:
            links = []
        elif isinstance(links, str):
            links = [links]

        logger.info("Processing image URLs: %s", links)

        # LLaVA-1.5 prompt template; `<image>` marks where the processor
        # inserts the image tokens and must be present in the text.
        prompt = f"USER: <image>\n{given_prompt}?\nASSISTANT:"

        outputs = []
        for link in links:
            try:
                # Fetch the whole body so Content-Encoding (e.g. gzip) is decoded.
                response = requests.get(link, timeout=_REQUEST_TIMEOUT)
                response.raise_for_status()  # raise on 4xx / 5xx status codes
                # Normalise palette / RGBA / grayscale images to 3-channel RGB.
                raw_image = Image.open(io.BytesIO(response.content)).convert("RGB")

                # Keyword arguments work across processor signature changes
                # (older: (text, images); newer: (images, text)).
                inputs = self.processor(
                    images=raw_image, text=prompt, return_tensors="pt"
                ).to(_DEVICE, torch.float16)
                output = self.model.generate(**inputs, max_new_tokens=200, do_sample=False)
                # Slice kept from the original handler (matches the llava-hf
                # model-card example); the decoded text presumably still
                # contains the echoed prompt — TODO confirm clients rely on it.
                readable = self.processor.decode(output[0][2:], skip_special_tokens=True)
                outputs.append(readable)
            except Exception as e:
                # Best-effort per URL: record the failure and keep going, but
                # log the traceback instead of silently swallowing it.
                logger.exception("Failed to process image from %s", link)
                outputs.append(f"Error processing image from {link}: {str(e)}")
        return outputs