import os

import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration  # , BitsAndBytesConfig

from .env_utils import get_device, low_vram_mode

device = get_device()
blip2_model_id = "Salesforce/blip2-opt-2.7b"  # or replace with your local model path
blip2_precision = torch.bfloat16

# Load the BLIP-2 model and processor from Hugging Face.
blip2_processor = Blip2Processor.from_pretrained(blip2_model_id)
# The low-VRAM and normal paths load the model identically: ZeroGPU does not
# support quantization, so the 8-bit config below stays commented out and
# low_vram_mode is currently unused.
blip2_model = Blip2ForConditionalGeneration.from_pretrained(
    blip2_model_id,
    torch_dtype=blip2_precision,
    device_map=device,
    # quantization_config=BitsAndBytesConfig(load_in_8bit=True) if low_vram_mode else None,
).eval()


def blip2_caption(raw_image):
    """Generate an unconditional caption for a PIL image."""
    inputs = blip2_processor(raw_image, return_tensors="pt")
    # Move tensors to the target device; floating-point tensors (pixel_values)
    # are also cast to the model's precision.
    inputs = inputs.to(device=device, dtype=blip2_precision)
    out = blip2_model.generate(**inputs)
    caption = blip2_processor.decode(out[0], skip_special_tokens=True)
    return caption


# if __name__ == "__main__":
#     from PIL import Image
#
#     # Test the BLIP-2 captioning model
#     image_path = os.path.join(os.path.dirname(__file__), "../sources/test_imgs/1.jpg")
#     image = Image.open(image_path)
#     result = blip2_caption(image)
#     print(result)
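

# A minimal sketch of prompted (conditional) captioning with the same model,
# added for illustration; the function name and default prompt below are
# assumptions, not part of the original module. BLIP-2 continues the given
# text prompt, so a "Question: ... Answer:" prompt doubles as lightweight
# visual question answering.
def blip2_caption_prompted(raw_image, prompt="Question: what is in the picture? Answer:"):
    inputs = blip2_processor(raw_image, text=prompt, return_tensors="pt")
    inputs = inputs.to(device=device, dtype=blip2_precision)
    out = blip2_model.generate(**inputs, max_new_tokens=30)
    return blip2_processor.decode(out[0], skip_special_tokens=True).strip()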