Desm0nt commited on
Commit
897c136
·
verified ·
1 Parent(s): 43517f3

Upload phi_captioning.py

Browse files
Files changed (1) hide show
  1. phi_captioning.py +86 -0
phi_captioning.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
3
+ from swift.tuners import Swift #chinese toolkit for finetunin and inference
4
+
5
+
6
+ from swift.llm import (
7
+ get_model_tokenizer, get_template, inference, ModelType,
8
+ get_default_template_type, inference_stream
9
+ )
10
+ from swift.utils import seed_everything
11
+ import torch
12
+ from tqdm import tqdm
13
+ import time
14
+
15
+ model_type = ModelType.phi3_vision_128k_instruct # model type
16
+ template_type = get_default_template_type(model_type)
17
+ print(f'template_type: {template_type}')
18
+
19
+ model_path = "./phi3-1476" # by default it is the lora path, not sure if it works the same way with merged checkpoint
20
+ model, tokenizer = get_model_tokenizer(model_type, torch.bfloat16, model_kwargs={'device_map': 'auto'})
21
+ model.generation_config.max_new_tokens = 1256 #generation params. As for me - defaults with do_sample=False works better than anything.
22
+ model.generation_config.do_sample = False
23
+ #model.generation_config.top_p = 0.7
24
+ #model.generation_config.temperature = 0.3
25
+ model = Swift.from_pretrained(model, model_path, "lora", inference_mode=True)
26
+ template = get_template(template_type, tokenizer)
27
+ #seed_everything(6321)
28
+
29
+ text = 'Make a caption that describe this image'
30
+ image_dir = './images/' # path to images
31
+ txt_dir = './tags/' # path to txt files with tags (from danbooru or from WD_Tagger)
32
+ maintxt_dir = './maintxt/' # path for result txt caprtions in natureal language
33
+
34
+ # image parsing
35
+ image_files = [f for f in os.listdir(image_dir) if f.endswith('.jpg')]
36
+
37
+ total_files = len(image_files)
38
+ start_time = time.time()
39
+
40
+ progress_bar = tqdm(total=total_files, unit='file', bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]')
41
+ total_elapsed_time = 0
42
+ processed_files = 0
43
+
44
+ # Main captioning cycle
45
+ for image_file in image_files:
46
+ image_path = os.path.join(image_dir, image_file)
47
+ if os.path.exists(image_path):
48
+ txt_file = os.path.splitext(image_file)[0] + '.txt'
49
+ txt_path = os.path.join(txt_dir, txt_file)
50
+
51
+ if os.path.exists(txt_path):
52
+ with open(txt_path, 'r', encoding='utf-8') as f:
53
+ tags = f.read().strip()
54
+
55
+ text = f'<img>{image_path}</img> Make a caption that describe this image. Here is the tags describing image: {tags}\n Find the relevant character\'s names in the tags and use it.'
56
+ print(text)
57
+ step_start_time = time.time()
58
+ response, history = inference(model, template, text, do_sample=True, temperature=0, repetition_penalty=1.05)
59
+ step_end_time = time.time()
60
+ step_time = step_end_time - step_start_time
61
+ total_elapsed_time += step_time
62
+ remaining_time = (total_elapsed_time / (processed_files + 1)) * (total_files - processed_files)
63
+
64
+ remaining_hours = int(remaining_time // 3600)
65
+ remaining_minutes = int((remaining_time % 3600) // 60)
66
+ remaining_seconds = int(remaining_time % 60)
67
+
68
+ progress_bar.set_postfix(remaining=f'\n', refresh=False)
69
+ print(f"\n\n\nFile {image_file}\nConsumed time: {step_time:.2f} s\n{response}")
70
+
71
+ # Создаем имя файла для сохранения ответа
72
+ output_file = os.path.splitext(image_file)[0] + '.txt'
73
+ output_path = os.path.join(maintxt_dir, output_file)
74
+
75
+ # Записываем ответ в файл
76
+ with open(output_path, 'w', encoding='utf-8') as f:
77
+ f.write(response)
78
+
79
+ print(f"Caption saved in file: {output_file} \n")
80
+ processed_files += 1
81
+ progress_bar.update(1)
82
+ else:
83
+ print(f"File {txt_file} doesn't exist.")
84
+ else:
85
+ print(f"Image {image_file} not found.")
86
+ progress_bar.close()