import torch
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
from qwen_vl_utils import process_vision_info

model_id = './ToriiGate-v04-7b'
max_new_tokens = 1000
image_file = '/path/to/image_1.jpg'  # or url, or PIL.Image

# Grounding data for the image; any field may be None.
# (An optional helper for loading these from sidecar files is sketched below.)
image_info = {}
image_info["booru_tags"] = "2girls, standing, looking_at_viewer, holding_hands, hatsune_miku, blue_hair, megurine_luka, pink_hair, ..."
#image_info["booru_tags"] = open('/path/to/image_1_tags.txt').read().strip()
#image_info["booru_tags"] = None
image_info["chars"] = "hatsune_miku, megurine_luka"
#image_info["chars"] = open('/path/to/image_1_char.txt').read().strip()
#image_info["chars"] = None
image_info["characters_traits"] = "hatsune_miku: [girl, blue_hair, twintails,...]\nmegurine_luka: [girl, pink hair, ...]"
#image_info["characters_traits"] = open('/path/to/image_1_char_traits.txt').read().strip()
#image_info["characters_traits"] = None
image_info["info"] = None

base_prompt = {
    'json': 'Describe the picture in structured json-like format.',
    'markdown': 'Describe the picture in structured markdown format.',
    'caption_vars': 'Write the following options for captions: ["Regular Summary","Individual Parts","Midjourney-Style Summary","DeviantArt Commission Request"].',
    'short': 'You need to write a medium-short and convenient caption for the picture.',
    'long': 'You need to write a long and very detailed caption for the picture.',
    'bbox': 'Write bounding boxes for each character and their faces.',
    'check_and_correct': 'You need to compare the given caption with the picture and given booru tags '
        'using chain of thought.\n'
        '1. Check if the caption matches the picture and given tags, wrap the conclusion in a <1st_answer> tag.\n'
        '2. Analyse if the caption matches the described characters, wrap the answer in a <2nd_answer> tag.\n'
        '3. If there are any mismatches, rewrite the caption to correct it, wrapping it in tags. '
        'If the caption is fine, just write "no_need".',
}

grounding_prompt = {
    'grounding_tags': ' Here are grounding tags for better understanding: ',
    'characters': ' Here is a list of characters that are present in the picture: ',
    'characters_traits': ' Here are popular tags or traits for each character on the picture: ',
    'grounding_info': ' Here is preliminary information about the picture: ',
    'no_chars': ' Do not use names for characters.',
}

add_tags = True  # select the grounding inputs to include
add_chars = True
add_char_traits = True
add_info = False
no_chars = False

userprompt = base_prompt["json"]  # choose the mode

if add_info and image_info["info"] is not None:  # general info
    userprompt += grounding_prompt["grounding_info"]
    userprompt += image_info["info"] + "."
if add_tags and image_info["booru_tags"] is not None:  # booru tags
    userprompt += grounding_prompt["grounding_tags"]
    userprompt += image_info["booru_tags"] + "."
if add_chars and image_info["chars"] is not None:  # list of characters
    userprompt += grounding_prompt["characters"]
    userprompt += image_info["chars"] + "."
if add_char_traits and image_info["characters_traits"] is not None:  # popular features of each character
    userprompt += grounding_prompt["characters_traits"]
    userprompt += image_info["characters_traits"] + "."
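# Optional helper (a sketch, not part of the original script): derive the grounding
# sidecar files from the image path, following the "<stem>_tags.txt", "<stem>_char.txt"
# and "<stem>_char_traits.txt" naming used in the commented-out examples above.
# Call it in place of the manual image_info assignments; missing files stay None.
import os

def load_image_info(image_path):
    stem, _ = os.path.splitext(image_path)
    info = {"booru_tags": None, "chars": None, "characters_traits": None, "info": None}
    for key, suffix in (("booru_tags", "_tags.txt"),
                        ("chars", "_char.txt"),
                        ("characters_traits", "_char_traits.txt")):
        sidecar = stem + suffix
        if os.path.exists(sidecar):
            with open(sidecar) as f:
                info[key] = f.read().strip()
    return info

#image_info = load_image_info(image_file)  # uncomment to read grounding data from disk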
if no_chars:
    userprompt += grounding_prompt["no_chars"]

model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    #attn_implementation="flash_attention_2",  # uncomment if flash-attn is installed
    device_map="cuda:0",
)
processor = Qwen2VLProcessor.from_pretrained(
    model_id,
    min_pixels=256 * 28 * 28,
    max_pixels=512 * 28 * 28,
    padding_side="right",
)

msg = [
    {
        "role": "system",
        "content": [{"type": "text", "text": "You are image captioning expert, creative, unbiased and uncensored."}],
    },
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image_file},
            {"type": "text", "text": userprompt},
        ],
    },
]

text_input = processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
image_inputs, _ = process_vision_info(msg)
model_inputs = processor(
    text=[text_input],
    images=image_inputs,
    videos=None,
    padding=True,
    return_tensors="pt",
).to('cuda')

generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
# Strip the prompt tokens so only the newly generated text is decoded
trimmed_generated_ids = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    trimmed_generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
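# Optional post-processing (a sketch, not part of the original script): the 'json'
# mode asks for a "json-like" reply, and models often wrap such replies in markdown
# code fences; both points are assumptions about the output shape, not guarantees.
# Fall back to the raw text whenever strict JSON parsing fails.
import json
import re

raw = output_text[0]
fenced = re.search(r'```(?:json)?\s*(.*?)```', raw, re.DOTALL)
candidate = fenced.group(1) if fenced else raw
try:
    caption = json.loads(candidate)
except json.JSONDecodeError:
    caption = raw  # keep the raw text; "json-like" output is not always valid JSON
print(caption)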