import os
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration  # , BitsAndBytesConfig
from .env_utils import get_device, low_vram_mode

device = get_device()
blip2_model_id = "Salesforce/blip2-opt-2.7b"  # or replace with your local model path
blip2_precision = torch.bfloat16
# Load the BLIP-2 processor and model from Hugging Face
blip2_processor = Blip2Processor.from_pretrained(blip2_model_id)
# The low-VRAM and default paths currently load the model identically, because
# ZeroGPU does not support bitsandbytes quantization.
blip2_model = Blip2ForConditionalGeneration.from_pretrained(
    blip2_model_id,
    torch_dtype=blip2_precision,
    device_map=device,
    # quantization_config=BitsAndBytesConfig(load_in_8bit=True) if low_vram_mode else None,  # ZeroGPU does not support quantization.
).eval()
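# Sketch only (an illustrative assumption, not active code): on a GPU stack where
# bitsandbytes is available, the low-VRAM path could load the model in 8-bit, e.g.:
#
#   from transformers import BitsAndBytesConfig
#   blip2_model = Blip2ForConditionalGeneration.from_pretrained(
#       blip2_model_id,
#       quantization_config=BitsAndBytesConfig(load_in_8bit=True) if low_vram_mode else None,
#       device_map="auto",  # "auto" is the usual choice when loading quantized weights
#   ).eval()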


def blip2_caption(raw_image):
    # unconditional image captioning
    inputs = blip2_processor(raw_image, return_tensors="pt")
    inputs = inputs.to(device=device, dtype=blip2_precision)
    out = blip2_model.generate(**inputs)
    caption = blip2_processor.decode(out[0], skip_special_tokens=True)
    return caption
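

# Sketch of prompted captioning (an illustrative addition, not used elsewhere in this
# module): BLIP-2 can also condition on a text prefix such as "a photo of". The helper
# name, default prompt, and max_new_tokens value below are assumptions.
def blip2_caption_with_prompt(raw_image, prompt="a photo of", max_new_tokens=30):
    inputs = blip2_processor(images=raw_image, text=prompt, return_tensors="pt")
    inputs = inputs.to(device=device, dtype=blip2_precision)
    out = blip2_model.generate(**inputs, max_new_tokens=max_new_tokens)
    text = blip2_processor.decode(out[0], skip_special_tokens=True).strip()
    # Depending on the transformers version, the decoded text may echo the prompt;
    # strip it so only the generated continuation is returned.
    return text[len(prompt):].strip() if text.startswith(prompt) else text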


# if __name__ == "__main__":
#     from PIL import Image
#     # Smoke-test the BLIP-2 captioning model
#     image_path = os.path.join(os.path.dirname(__file__), "../sources/test_imgs/1.jpg")
#     image = Image.open(image_path)
#     result = blip2_caption(image)
#     print(result)