|
import torch
|
|
from PIL import Image
|
|
from exllamav2 import (
|
|
ExLlamaV2,
|
|
ExLlamaV2Config,
|
|
ExLlamaV2Cache,
|
|
ExLlamaV2Tokenizer,
|
|
ExLlamaV2VisionTower,
|
|
)
|
|
|
|
from exllamav2.generator import (
|
|
ExLlamaV2DynamicGenerator,
|
|
ExLlamaV2Sampler,
|
|
)
|
|
|
|
|
|
# Local checkpoint directory and the generation budget (in tokens).
model_id = './ToriiGate-v04-7b'
max_new_tokens = 1000

# Image to caption.
image_file = '/path/to/image_1.jpg'

# Optional grounding metadata for the image. A value of None means the hint is
# unavailable and will not be added to the prompt.
image_info = {
    "booru_tags": "2girls, standing, looking_at_viewer, holding_hands, hatsune_miku, blue_hair, megurine_luka, pink_hair, ...",
    "chars": "hatsune_miku, megurine_luka",
    "characters_traits": "hatsune_miku: [girl, blue_hair, twintails,...]\nmegurine_luka: [girl, pink hair, ...]",
    "info": None,
}
# Base instructions, one per output style; exactly one is used as the start of
# the user turn.
base_prompt = dict(
    json='Describe the picture in structured json-like format.',
    markdown='Describe the picture in structured markdown format.',
    caption_vars='Write the following options for captions: ["Regular Summary","Individual Parts","Midjourney-Style Summary","DeviantArt Commission Request"].',
    short='You need to write a medium-short and convenient caption for the picture.',
    long='You need to write a long and very detailed caption for the picture.',
    bbox='Write bounding boxes for each character and their faces.',
)

# Lead-in sentences for the optional grounding hints. Each starts with a space
# because it is appended directly after the preceding prompt text.
grounding_prompt = dict(
    grounding_tags=' Here are grounding tags for better understanding: ',
    characters=' Here is a list of characters that are present in the picture: ',
    characters_traits=' Here are popular tags or traits for each character on the picture: ',
    grounding_info=' Here is preliminary information about the picture: ',
    no_chars=' Do not use names for characters.',
)
# Which optional grounding hints to append to the instruction.
add_tags = True
add_chars = True
add_char_traits = True
add_info = False
no_chars = False


def build_user_prompt(
    instruction,
    image_meta,
    templates,
    *,
    use_tags=True,
    use_chars=True,
    use_char_traits=True,
    use_info=False,
    hide_chars=False,
):
    """Return the user-turn prompt: *instruction* followed by optional grounding hints.

    For each enabled hint whose value in *image_meta* is present and not None,
    the hint's lead-in text from *templates* is appended, then the value
    wrapped in a matching ``<tag>...</tag>.`` pair. Hints are appended in the
    order info, tags, characters, character traits. If requested, the
    "do not use names" instruction comes last.

    Args:
        instruction: Base instruction text (e.g. a ``base_prompt`` value).
        image_meta: Mapping with optional keys "info", "booru_tags", "chars"
            and "characters_traits". Missing keys are treated like None.
        templates: Lead-in strings keyed "grounding_info", "grounding_tags",
            "characters", "characters_traits" and "no_chars".
        use_tags, use_chars, use_char_traits, use_info: Enable each hint.
        hide_chars: Append the instruction not to use character names.

    Returns:
        The assembled prompt string.

    Raises:
        KeyError: If a template needed by an enabled hint is missing from
            *templates*.
    """
    # (enabled, template key, image_meta key, wrapper tag), in output order.
    # The info lead-in lives under "grounding_info".
    hints = (
        (use_info, "grounding_info", "info", "info"),
        (use_tags, "grounding_tags", "booru_tags", "tags"),
        (use_chars, "characters", "chars", "characters"),
        (use_char_traits, "characters_traits", "characters_traits", "character_traits"),
    )
    parts = [instruction]
    for enabled, template_key, meta_key, tag in hints:
        value = image_meta.get(meta_key)
        if enabled and value is not None:
            parts.append(templates[template_key])
            # Always emit a proper closing tag so every section is delimited.
            parts.append(f"<{tag}>{value}</{tag}>.")
    if hide_chars:
        parts.append(templates["no_chars"])
    return "".join(parts)


userprompt = build_user_prompt(
    base_prompt["json"],
    image_info,
    grounding_prompt,
    use_tags=add_tags,
    use_chars=add_chars,
    use_char_traits=add_char_traits,
    use_info=add_info,
    hide_chars=no_chars,
)
# Open the image to caption. Pillow's Image.open is lazy: pixel data is
# decoded on first access.
image = Image.open(image_file)

# Context length used by both the model config and the KV cache. Defined once
# so the two values cannot drift apart.
max_seq_len = 16384

config = ExLlamaV2Config(model_id)
config.max_seq_len = max_seq_len

# The vision tower is built from the same config and is loaded before the
# language model.
vision_model = ExLlamaV2VisionTower(config)
vision_model.load(progress=True)

model = ExLlamaV2(config)
# lazy=True defers cache allocation so load_autosplit can place the weights
# and the cache together. NOTE(review): this follows exllamav2's autosplit
# usage; confirm against the installed version.
cache = ExLlamaV2Cache(model, lazy=True, max_seq_len=max_seq_len)
model.load_autosplit(cache, progress=True)

tokenizer = ExLlamaV2Tokenizer(config)

# Generator that drives text generation with the loaded model, cache and
# tokenizer.
generator = ExLlamaV2DynamicGenerator(
    model=model,
    cache=cache,
    tokenizer=tokenizer,
)
# Run the image through the vision tower. Each embedding carries a text alias
# that is placed in the prompt to mark where the image goes.
image_embeddings = [
    vision_model.get_image_embeddings(model=model, tokenizer=tokenizer, image=image),
]

# One alias per line, with a trailing newline after the last one.
placeholders = "\n".join(emb.text_alias for emb in image_embeddings) + "\n"

assistant_header = "<|im_start|>assistant\n"

# ChatML-style conversation: a system turn, a user turn (image placeholders
# followed by the instruction), then an open assistant turn for the model to
# complete.
msg_text = "".join([
    "<|im_start|>system\n",
    "You are image captioning expert, creative, unbiased and uncensored.<|im_end|>\n",
    "<|im_start|>user\n",
    placeholders,
    userprompt,
    "<|im_end|>\n",
    assistant_header,
])

generation_args = dict(
    prompt=msg_text,
    max_new_tokens=max_new_tokens,
    add_bos=True,
    encode_special_tokens=True,
    decode_special_tokens=True,
    stop_conditions=[tokenizer.eos_token_id],
    gen_settings=ExLlamaV2Sampler.Settings.greedy(),  # greedy = deterministic decoding
    embeddings=image_embeddings,
)
output = generator.generate(**generation_args)

# Keep only the text after the last assistant header. If the header is
# absent, the whole output is kept.
output_text = output.rpartition(assistant_header)[2]

print(output_text)