# SEED-LLaMA / scripts/seed_llama_inference_8B.py
# Last commit: "update demo" (bd63939) by sjzhao
import hydra
import pyrootutils
import os
import torch
from omegaconf import OmegaConf
import json
from typing import Optional
import transformers
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
# Anchor the project root on the ".project-root" marker file, starting the
# search from this script. NOTE(review): pythonpath=True presumably puts the
# root on sys.path so the relative 'configs/...' targets resolve -- confirm
# against the pyrootutils documentation.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# Marker tokens that open and close a run of image tokens inside a prompt or
# inside generated output.
BOI_TOKEN = '<img>'
EOI_TOKEN = '</img>'
# Template for one discrete image code, zero-padded to five digits.
IMG_TOKEN = '<img_{:05d}>'
# Not referenced in the visible part of this script.
IMG_FLAG = '<image>'
# NOTE(review): the name is misspelled ("TOKNES"); it is not referenced in the
# visible part of this script. Presumably the number of tokens per image --
# confirm before renaming.
NUM_IMG_TOKNES = 32
# Not referenced in the visible part of this script; presumably the size of
# the image codebook -- confirm.
NUM_IMG_CODES = 8192
# Offset subtracted from generated ids between BOI and EOI to recover the
# image codes handed to tokenizer.decode_image.
image_id_shift = 32000
def _model_device(model):
    """Return the device the prompt ids should be moved to for ``model``."""
    device = getattr(model, "device", None)
    if device is not None:
        return device
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        # Nothing to inspect: keep the historical behaviour of this script.
        return torch.device("cuda")


def generate(tokenizer, input_tokens, generation_config, model):
    """Tokenize a prompt, run ``model.generate`` and return only the new ids.

    Args:
        tokenizer: Callable invoked as
            ``tokenizer(text, add_special_tokens=False, return_tensors='pt')``;
            the result must expose ``.input_ids``.
        input_tokens: Fully formatted prompt string. Special tokens are not
            added here, so the caller is responsible for any BOS prefix.
        generation_config: Mapping of keyword arguments forwarded unchanged
            to ``model.generate``.
        model: Object exposing ``generate(input_ids=..., **kwargs)``.

    Returns:
        The first generated sequence with the leading prompt positions
        removed (the slice assumes the output starts with the prompt ids).
    """
    input_ids = tokenizer(input_tokens, add_special_tokens=False, return_tensors='pt').input_ids
    # Follow the model's own device instead of hard-coding "cuda", so the
    # function also works when the model lives on CPU or another accelerator.
    input_ids = input_ids.to(_model_device(model))
    generate_ids = model.generate(
        input_ids=input_ids,
        **generation_config
    )
    prompt_length = input_ids.shape[1]
    return generate_ids[0][prompt_length:]
def decode_image_text(generate_ids, tokenizer, save_path=None):
    """Print the generated text and, if present, save the generated image.

    The output is scanned for the first BOI token. Everything before it is
    decoded and printed as text. The ids between that BOI and the first EOI
    after it are shifted by ``image_id_shift`` and passed to
    ``tokenizer.decode_image``; the first returned image is saved.

    Args:
        generate_ids: 1-D tensor of generated token ids.
        tokenizer: Tokenizer that is callable on a string (result exposes
            ``.input_ids``) and provides ``decode`` and ``decode_image``.
        save_path: Destination for the decoded image. When ``None`` the
            image span is reported but neither decoded nor saved.

    Returns:
        None. Results are printed and/or written to ``save_path``.
    """
    boi_id = tokenizer(BOI_TOKEN, add_special_tokens=False).input_ids[0]
    eoi_id = tokenizer(EOI_TOKEN, add_special_tokens=False).input_ids[0]

    boi_positions = torch.where(generate_ids == boi_id)[0]
    if len(boi_positions) == 0:
        # No image was started, so the whole output is text. This also covers
        # a stray EOI without BOI, which used to raise IndexError.
        texts = tokenizer.decode(generate_ids, skip_special_tokens=True)
        print(texts)
        return

    boi_index = boi_positions[0]
    text_ids = generate_ids[:boi_index]
    if len(text_ids) != 0:
        texts = tokenizer.decode(text_ids, skip_special_tokens=True)
        print(texts)

    # Only an EOI located after the BOI can close the image span.
    eoi_positions = torch.where(generate_ids == eoi_id)[0]
    eoi_positions = eoi_positions[eoi_positions > boi_index]
    if len(eoi_positions) == 0:
        # Typically the generation hit its token budget mid-image.
        print("[decode_image_text] image span is not terminated by "
              f"{EOI_TOKEN}; skipping image decoding.")
        return
    eoi_index = eoi_positions[0]

    image_ids = (generate_ids[boi_index + 1:eoi_index] - image_id_shift).reshape(1, -1)
    if image_ids.numel() == 0:
        print("[decode_image_text] image span is empty; skipping image decoding.")
        return

    if save_path is None:
        # Without a destination the decoded image would be discarded, and
        # saving to None fails, so skip the decode entirely.
        print("[decode_image_text] an image was generated but no save_path "
              "was given; skipping image decoding.")
        return

    images = tokenizer.decode_image(image_ids)
    images[0].save(save_path)
# Device used for the tokenizer, the image tensor and the model below.
device = "cuda"
# Build the tokenizer from its YAML config. load_diffusion=True is forwarded
# to the tokenizer constructor; NOTE(review): presumably it enables the image
# decoder behind tokenizer.decode_image -- confirm in the tokenizer class.
tokenizer_cfg_path = 'configs/tokenizer/seed_llama_tokenizer.yaml'
tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
tokenizer = hydra.utils.instantiate(tokenizer_cfg, device=device, load_diffusion=True)
# Image preprocessing applied to PIL images before tokenizer.encode_image.
transform_cfg_path = 'configs/transform/clip_transform.yaml'
transform_cfg = OmegaConf.load(transform_cfg_path)
transform = hydra.utils.instantiate(transform_cfg)
# Instantiate the language model in float16, switch it to eval mode and move
# it to the target device.
model_cfg = OmegaConf.load('configs/llm/seed_llama_8b.yaml')
model = hydra.utils.instantiate(model_cfg, torch_dtype=torch.float16)
model = model.eval().to(device)
# Keyword arguments forwarded to model.generate by generate(): sampling is
# enabled (top_p 0.5, temperature 1.0, one beam) with up to 512 new tokens.
generation_config = {
'temperature': 1.0,
'num_beams': 1,
'max_new_tokens': 512,
'top_p': 0.5,
'do_sample': True
}
# Prompt layout used by the demos: bos + "USER: " + content + "\n" + "ASSISTANT:".
s_token = "USER:"
e_token = "ASSISTANT:"
sep = "\n"
### visual question answering
# Load the example image, preprocess it and encode it into image code ids.
image_path = "images/cat.jpg"
image = Image.open(image_path)
image = image.convert('RGB')
image_tensor = transform(image).to(device)
img_ids = tokenizer.encode_image(image_torch=image_tensor)
img_ids = img_ids.view(-1).cpu().numpy()
# Spell each code as an image token and wrap the run in the BOI/EOI markers.
img_tokens = ''.join((BOI_TOKEN, *map(IMG_TOKEN.format, img_ids), EOI_TOKEN))
question = "What is this animal?"
# Prompt layout: bos, user marker, a space, image tokens, question, newline,
# assistant marker.
input_tokens = ''.join((tokenizer.bos_token, s_token, " ", img_tokens, question, sep, e_token))
generate_ids = generate(
    tokenizer=tokenizer,
    input_tokens=input_tokens,
    generation_config=generation_config,
    model=model,
)
# No save path is passed for this demo.
decode_image_text(generate_ids=generate_ids, tokenizer=tokenizer)
### text-to-image generation
# Text-only prompt asking for an image; the decoded image goes to save_path.
prompt = "Can you generate an image of a dog on the green grass?"
save_path = 'dog.jpg'
input_tokens = ''.join((tokenizer.bos_token, s_token, " ", prompt, sep, e_token))
generate_ids = generate(
    tokenizer=tokenizer,
    input_tokens=input_tokens,
    generation_config=generation_config,
    model=model,
)
decode_image_text(generate_ids=generate_ids, tokenizer=tokenizer, save_path=save_path)
### multimodal prompt image generation
# Reuses img_tokens (the encoded cat image) built by the visual question
# answering demo and combines it with an editing instruction.
instruction = "Can you make the cat wear sunglasses?"
save_path = 'cat_sunglasses.jpg'
input_tokens = ''.join((tokenizer.bos_token, s_token, " ", img_tokens, instruction, sep, e_token))
generate_ids = generate(
    tokenizer=tokenizer,
    input_tokens=input_tokens,
    generation_config=generation_config,
    model=model,
)
decode_image_text(generate_ids=generate_ids, tokenizer=tokenizer, save_path=save_path)