import os
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration  # , BitsAndBytesConfig

from .env_utils import get_device, low_vram_mode

device = get_device()

blip2_model_id = "Salesforce/blip2-opt-2.7b"  # or replace with your local model path
blip2_precision = torch.bfloat16

# Load BLIP-2 model and processor from HuggingFace
blip2_processor = Blip2Processor.from_pretrained(blip2_model_id)
if low_vram_mode:
    blip2_model = Blip2ForConditionalGeneration.from_pretrained(
        blip2_model_id,
        torch_dtype=blip2_precision,
        device_map=device,
        # quantization_config=BitsAndBytesConfig(load_in_8bit=True) if low_vram_mode else None,  # ZeroGPU does not support quantization.
    ).eval()
else:
    blip2_model = Blip2ForConditionalGeneration.from_pretrained(blip2_model_id, torch_dtype=blip2_precision, device_map=device).eval()

def blip2_caption(raw_image):
    # unconditional image captioning
    inputs = blip2_processor(raw_image, return_tensors="pt")
    inputs = inputs.to(device=device, dtype=blip2_precision)
    out = blip2_model.generate(**inputs)
    caption = blip2_processor.decode(out[0], skip_special_tokens=True)
    return caption
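

# The helper above performs unconditional captioning only. The sketch below is NOT
# part of the original file: it is a minimal example of prompted (conditional)
# captioning with the same processor and model, assuming Blip2Processor's standard
# images/text call signature; the default prompt string is purely illustrative.
def blip2_caption_with_prompt(raw_image, prompt="a photo of"):
    # conditional image captioning: the model completes the given text prompt
    inputs = blip2_processor(images=raw_image, text=prompt, return_tensors="pt")
    inputs = inputs.to(device=device, dtype=blip2_precision)
    out = blip2_model.generate(**inputs)
    caption = blip2_processor.decode(out[0], skip_special_tokens=True).strip()
    return caption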


# if __name__ == "__main__":
#     from PIL import Image
#
#     # Test the BLIP-2 captioning model
#     image_path = os.path.join(os.path.dirname(__file__), "../sources/test_imgs/1.jpg")
#     image = Image.open(image_path)
#     result = blip2_caption(image)
#     print(result)