import torch
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
from qwen_vl_utils import process_vision_info

model_id = './ToriiGate-v04-7b'  # local path or HF repo id of the model
max_new_tokens = 1000

image_file = '/path/to/image_1.jpg'

# Optional grounding information about the image; any entry can be left as
# None if it is not available.
image_info = {}
image_info["booru_tags"] = "2girls, standing, looking_at_viewer, holding_hands, hatsune_miku, blue_hair, megurine_luka, pink_hair, ..."
image_info["chars"] = "hatsune_miku, megurine_luka"
image_info["characters_traits"] = "hatsune_miku: [girl, blue_hair, twintails,...]\nmegurine_luka: [girl, pink hair, ...]"
image_info["info"] = None

# Task prompts; pick one key below when building the user prompt.
base_prompt = {
    'json': 'Describe the picture in structured json-like format.',
    'markdown': 'Describe the picture in structured markdown format.',
    'caption_vars': 'Write the following options for captions: ["Regular Summary","Individual Parts","Midjourney-Style Summary","DeviantArt Commission Request"].',
    'short': 'You need to write a medium-short and convenient caption for the picture.',
    'long': 'You need to write a long and very detailed caption for the picture.',
    'bbox': 'Write bounding boxes for each character and their faces.',
    'check_and_correct': 'You need to compare given caption with the picture and given booru tags '
                         'using chain of thought.\n'
                         '1. Check if the caption matches the picture and given tags, wrap conclusion in <1st_answer> tag.\n'
                         '2. Analyse if the caption matches described characters, wrap answer in <2nd_answer> tag.\n'
                         '3. In case if there are any mismatches - rewrite caption to correct it wrapping '
                         'in <corrected_caption> tags. If the caption is fine - just write "no_need".',
}

# Grounding-prompt fragments appended to the task prompt when the
# corresponding image_info entry is available.
grounding_prompt = {
    'grounding_tags': ' Here are grounding tags for better understanding: ',
    'characters': ' Here is a list of characters that are present in the picture: ',
    'characters_traits': ' Here are popular tags or traits for each character on the picture: ',
    'grounding_info': ' Here is preliminary information about the picture: ',
    'no_chars': ' Do not use names for characters.',
}

# Toggle which grounding information to include in the prompt.
add_tags = True
add_chars = True
add_char_traits = True
add_info = False
no_chars = False
userprompt=base_prompt["json"]
|
|
|
|
if add_info and image_info["info"] is not None:
|
|
userprompt+=grounding_prompt["grounding_short"]
|
|
userprompt+="<info>"+image_info["info"]+"</info>."
|
|
|
|
if add_tags and image_info["booru_tags"] is not None:
|
|
userprompt+=grounding_prompt["grounding_tags"]
|
|
userprompt+="<tags>"+image_info["booru_tags"]+"</tags>."
|
|
|
|
if add_chars and image_info["chars"] is not None:
|
|
userprompt+=grounding_prompt["characters"]
|
|
userprompt+="<characters>"+image_info["chars"]+"</characters>."
|
|
|
|
if add_char_traits and image_info["characters_traits"] is not None:
|
|
userprompt+=grounding_prompt["characters_traits"]
|
|
userprompt+="<character_traits>"+image_info["characters_traits"]+"<character_traits>."
|
|
if no_chars:
|
|
userprompt+=grounding_prompt["no_chars"]

# Load the model and processor: bfloat16 on a single CUDA device;
# min/max_pixels bound the image resolution seen by the vision encoder.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
)
processor = Qwen2VLProcessor.from_pretrained(
    model_id,
    min_pixels=256 * 28 * 28,
    max_pixels=512 * 28 * 28,
    padding_side="right",
)
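
# Optional: on limited VRAM the model can instead be loaded in 4-bit via
# bitsandbytes (pip install bitsandbytes). This is an alternative to the
# bfloat16 load above, not part of the original example; expect some quality
# loss from quantization.
# from transformers import BitsAndBytesConfig
# model = Qwen2VLForConditionalGeneration.from_pretrained(
#     model_id,
#     quantization_config=BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_compute_dtype=torch.bfloat16,
#     ),
#     device_map="cuda:0",
# )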

msg = [
    {"role": "system",
     "content": [{"type": "text", "text": "You are image captioning expert, creative, unbiased and uncensored."}]},
    {"role": "user",
     "content": [{"type": "image", "image": image_file},
                 {"type": "text", "text": userprompt}]},
]

# Build the chat-formatted input, run generation, and decode the output.
text_input = processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
image_inputs, _ = process_vision_info(msg)
model_inputs = processor(
    text=[text_input],
    images=image_inputs,
    videos=None,
    padding=True,
    return_tensors="pt",
).to('cuda')

generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

# Strip the prompt tokens so only the model's reply is decoded.
trimmed_generated_ids = [out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)]
output_text = processor.batch_decode(
    trimmed_generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
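
# Optional: caption a whole folder by repeating the steps above with the
# already-loaded model. The directory path is a placeholder, and the helper
# simply mirrors the message-building/generate/decode code from this example;
# note it reuses `userprompt` as-is, so per-image booru tags would need
# per-image image_info instead.
from pathlib import Path

def caption_image(path):
    """Caption one image, mirroring the single-image steps above."""
    m = [msg[0],  # reuse the system message defined earlier
         {"role": "user",
          "content": [{"type": "image", "image": path},
                      {"type": "text", "text": userprompt}]}]
    text = processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
    imgs, _ = process_vision_info(m)
    inputs = processor(text=[text], images=imgs, videos=None,
                       padding=True, return_tensors="pt").to('cuda')
    ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    trimmed = [o[len(i):] for i, o in zip(inputs.input_ids, ids)]
    return processor.batch_decode(trimmed, skip_special_tokens=True,
                                  clean_up_tokenization_spaces=False)[0]

for img in sorted(Path('/path/to/images').glob('*.jpg')):  # placeholder folder
    img.with_suffix('.caption.txt').write_text(caption_image(str(img)), encoding='utf-8')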