DucHaiten committed on
Commit
73feb13
1 Parent(s): d07a4d8

Update image_to_caption.py

Browse files
Files changed (1) hide show
  1. image_to_caption.py +947 -844
image_to_caption.py CHANGED
@@ -1,844 +1,947 @@
1
- import tkinter as tk
2
- from tkinter import filedialog, messagebox, ttk
3
- from PIL import Image as PILImage, ImageTk
4
- import os
5
- import queue
6
- import threading
7
- import torch
8
- from transformers import AutoModelForCausalLM, LlamaTokenizer
9
- import json
10
- import traceback
11
- import math
12
-
13
# This tool only runs inference, so autograd is switched off for the whole process.
torch.set_grad_enabled(False)

# --- Processing state shared between the GUI and the worker threads ---------
stop_processing = False
error_messages = []
selected_files = []
save_directory = ""

# --- Caption browser window state --------------------------------------------
caption_window = caption_frame = None
thumbnails = []
caption_text_widgets = []
error_window = None

# --- Tk variables; created in open_image_to_caption() once a root exists -----
status_var = num_files_var = errors_var = progress = None
prompt_var = max_new_tokens_var = do_sample_var = None
temperature_var = top_k_var = top_p_var = None
thread_count_var = precision_var = batch_size_var = None
prepend_text_var = append_text_var = None
caption_handling_var = None  # backs the "existing caption" radio buttons

# --- Widgets and model handle, also filled in lazily -------------------------
start_button = stop_button = None
model = None
prompt_entry = None
select_files_button = show_captions_button = None
thread_count_entry = precision_entry = batch_size_entry = None
prepend_text_entry = append_text_entry = None
root = None
q = queue.Queue()

# --- Pagination / search state of the caption browser ------------------------
current_page = 0
images_per_page = 20
total_pages = 1
content_canvas = None
search_var = None
original_selected_files = []
action_var = action_entry = None
62
-
63
def load_model():
    """Load the CogVLM model and its tokenizer on first use.

    The result is cached in the module-level ``model`` / ``tokenizer``
    globals, so later calls return immediately. The torch dtype is chosen
    from ``precision_var``: float16 when its value is <= 1, float32
    otherwise. The model is moved to CUDA and put in eval mode.
    """
    global model, tokenizer
    # generate_caption() calls this from several worker threads at once.
    # Without the lock, every thread that sees ``model is None`` would start
    # its own multi-gigabyte load.
    with load_model._lock:
        if model is not None:
            return
        tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
        dtype = torch.float16 if precision_var.get() <= 1 else torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            'THUDM/cogvlm-chat-hf',
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        ).to('cuda').eval()


# Guards the one-time load above; kept on the function so the block is self-contained.
load_model._lock = threading.Lock()
74
-
75
def update_and_save_config():
    """Snapshot the current UI settings and write them to ``captions.json``.

    ``top_p`` is stored as ``None`` when sampling is disabled. A failure to
    write the file is reported on stdout and otherwise ignored.
    """
    top_p_setting = None
    if do_sample_var.get():
        top_p_setting = float(top_p_var.get())

    snapshot = {
        'prompt': prompt_var.get(),
        'max_new_tokens': max_new_tokens_var.get(),
        'temperature': temperature_var.get(),
        'top_k': top_k_var.get(),
        'top_p': top_p_setting,
        'precision': precision_var.get(),
        'thread_count': thread_count_var.get(),
        'batch_size': batch_size_var.get(),
        'prepend_text': prepend_text_var.get(),
        'append_text': append_text_var.get(),
        # Which radio button is selected for images that already have a caption.
        'caption_handling': caption_handling_var.get(),
    }

    try:
        with open('captions.json', 'w') as config_file:
            json.dump(snapshot, config_file, indent=2)
    except Exception as exc:
        print(f"Error saving config to captions.json: {exc}")
96
-
97
def load_config_from_json():
    """Restore UI settings from ``captions.json`` if that file exists.

    Missing keys fall back to the built-in defaults. The prompt text widget
    is refreshed to show the stored prompt. Any error is printed and
    otherwise ignored.
    """
    global prompt_entry
    try:
        if not os.path.exists('captions.json'):
            return

        with open('captions.json', 'r') as config_file:
            saved = json.load(config_file)

        for variable, key, fallback in (
            (prompt_var, 'prompt', ''),
            (max_new_tokens_var, 'max_new_tokens', 200),
            (temperature_var, 'temperature', 1.0),
            (top_k_var, 'top_k', 50),
        ):
            variable.set(saved.get(key, fallback))

        # top_p is saved as null when sampling was off; fall back to the default.
        stored_top_p = saved.get('top_p', 0.95)
        top_p_var.set(0.95 if stored_top_p is None else stored_top_p)

        for variable, key, fallback in (
            (precision_var, 'precision', 1),
            (thread_count_var, 'thread_count', 4),
            (batch_size_var, 'batch_size', 1),
            (prepend_text_var, 'prepend_text', ''),
            (append_text_var, 'append_text', ''),
            (caption_handling_var, 'caption_handling', 'skip'),
        ):
            variable.set(saved.get(key, fallback))

        prompt_entry.delete("1.0", tk.END)
        prompt_entry.insert(tk.END, saved.get('prompt', ''))
    except Exception as exc:
        print(f"Error loading config from captions.json: {exc}")
120
-
121
def on_config_change(*args):
    """Variable-trace callback: schedule ``update_config`` to run 100 ms from now."""
    delay_ms = 100
    root.after(delay_ms, update_config)
123
-
124
def update_config():
    """Persist the current settings unless the precision value reads as empty.

    Any exception raised while reading the variables or saving is printed
    and swallowed.
    """
    try:
        if precision_var.get() == "":
            # Nothing to do while the value is an empty string.
            return
        update_and_save_config()
    except Exception as exc:
        print(f"Lỗi khi xử lý giá trị: {exc}")
133
-
134
def on_prompt_change(event=None):
    """Copy the prompt text widget's content into ``prompt_var`` and save the config."""
    typed_prompt = prompt_entry.get("1.0", tk.END).strip()
    prompt_var.set(typed_prompt)
    update_and_save_config()
137
-
138
def show_errors():
    """Open a read-only window listing every recorded error message.

    Only one such window exists at a time; calling this while it is open
    does nothing.
    """
    global error_window
    if error_window is not None:
        return

    error_window = tk.Toplevel(root)
    error_window.title("Error Details")
    error_window.geometry("500x400")

    details = tk.Text(error_window, wrap='word')
    details.pack(expand=True, fill='both')

    if error_messages:
        report = "".join(message + '\n' for message in error_messages)
    else:
        report = "No errors recorded."
    details.insert('end', report)
    details.config(state='disabled')

    def on_close_error_window():
        # Forget the window so the next show_errors() call can create a new one.
        global error_window
        error_window.destroy()
        error_window = None

    error_window.protocol("WM_DELETE_WINDOW", on_close_error_window)
164
-
165
def validate_numeric_input(value):
    """Return True if ``value`` is acceptable content for a numeric entry.

    An empty string and a lone minus sign are accepted so the user can clear
    the field or start typing a negative number. Anything else must be
    parseable by ``float``.
    """
    if value in ("", "-"):
        return True
    try:
        float(value)
    except ValueError:
        return False
    return True
173
-
174
def center_window(window):
    """Move ``window`` to the centre of the screen, keeping its current size."""
    # Flush pending geometry work so the reported size is up to date.
    window.update_idletasks()
    width, height = window.winfo_width(), window.winfo_height()
    left = (window.winfo_screenwidth() // 2) - (width // 2)
    top = (window.winfo_screenheight() // 2) - (height // 2)
    window.geometry(f'{width}x{height}+{left}+{top}')
181
-
182
def toggle_sampling_options():
    """Show or hide the temperature / top-k / top-p controls.

    When sampling is enabled, the six widgets are packed one after another
    below the "Do Sample" checkbox and the window grows by 150 px; otherwise
    they are hidden and the window shrinks by 150 px. The window is then
    re-centred.
    """
    sampling_widgets = (
        temperature_label, temperature_entry,
        top_k_label, top_k_entry,
        top_p_label, top_p_entry,
    )

    if do_sample_var.get():
        anchor = do_sample_check
        for widget in sampling_widgets:
            # Each widget is packed directly after the previous one.
            widget.pack(pady=5, after=anchor)
            anchor = widget
        height_delta = 150
    else:
        for widget in sampling_widgets:
            widget.pack_forget()
        height_delta = -150

    root.geometry(f"{root.winfo_width()}x{root.winfo_height() + height_delta}")
    center_window(root)
200
-
201
def open_image_to_caption():
    """Build the main "Image to Caption" window and run its Tk event loop.

    Creates the root window, the Tk variables backing every setting, and all
    widgets; restores saved settings from ``captions.json``; attaches write
    traces so later changes are saved automatically; then blocks in
    ``root.mainloop()`` until the window is closed.
    """
    global stop_processing, error_messages, selected_files, save_directory, status_var, num_files_var, errors_var, progress
    global prompt_var, max_new_tokens_var, do_sample_var, temperature_var, top_k_var, top_p_var, thread_count_var, precision_var, batch_size_var
    global prepend_text_var, append_text_var, search_var, action_var, caption_handling_var
    global start_button, stop_button
    global temperature_label, temperature_entry, top_k_label, top_k_entry, top_p_label, top_p_entry
    global do_sample_check, prompt_entry, select_files_button, show_captions_button, thread_count_entry, precision_entry, batch_size_entry
    global prepend_text_entry, append_text_entry
    global root
    global q

    root = tk.Tk()
    root.title("Image to Caption")
    root.geometry("1050x950")

    # Tk variables can only be created after the root window exists.
    status_var = tk.StringVar()
    num_files_var = tk.StringVar()
    errors_var = tk.StringVar(value="Errors: 0")
    progress = tk.IntVar()
    prompt_var = tk.StringVar(value="Describe this image")
    max_new_tokens_var = tk.IntVar(value=200)
    do_sample_var = tk.BooleanVar(value=False)
    temperature_var = tk.DoubleVar(value=1.0)
    top_k_var = tk.IntVar(value=50)
    top_p_var = tk.DoubleVar(value=0.95)
    thread_count_var = tk.IntVar(value=4)
    precision_var = tk.IntVar(value=1)
    batch_size_var = tk.IntVar(value=1)
    prepend_text_var = tk.StringVar()
    append_text_var = tk.StringVar()
    caption_handling_var = tk.StringVar(value='skip')  # default: skip images that already have a caption
    search_var = tk.StringVar()
    action_var = tk.StringVar()

    q = queue.Queue()

    validate_cmd = root.register(validate_numeric_input)

    back_button = tk.Button(root, text="<-", font=('Helvetica', 14), command=return_to_menu)
    back_button.pack(anchor='nw', padx=10, pady=10)

    title_label = tk.Label(root, text="Image Caption Generator", font=('Helvetica', 16))
    title_label.pack(pady=10)

    warning_label = tk.Label(root, text="NOTE: To run CogVLM with the minimum configuration, you need at least 40GB RAM to load the model in FP16 with batch size 1 and a GPU with at least 24GB of VRAM. It is recommended to install ImageDucHaiten on an NVMe SSD to optimize speed.",
                             font=('Helvetica', 10), fg="red", wraplength=750, justify="left")
    warning_label.pack(pady=10)

    select_files_button = tk.Button(root, text="Select Files", command=select_files)
    select_files_button.pack(pady=10)

    show_captions_button = tk.Button(root, text="Show Captions", command=open_caption_window)
    show_captions_button.pack(pady=10)

    num_files_label = tk.Label(root, textvariable=num_files_var)
    num_files_label.pack(pady=5)

    prompt_label = tk.Label(root, text="Prompt (text to describe the image):")
    prompt_label.pack(pady=5)
    prompt_entry = tk.Text(root, height=3, wrap='word', width=60)
    prompt_entry.pack(pady=5, padx=10, fill='both', expand=True)
    prompt_entry.bind('<KeyRelease>', on_prompt_change)

    prepend_text_label = tk.Label(root, text="Prepend Text:")
    prepend_text_label.pack(pady=5)
    prepend_text_entry = tk.Entry(root, textvariable=prepend_text_var, justify='center', width=60)
    prepend_text_entry.pack(pady=5)

    append_text_label = tk.Label(root, text="Append Text:")
    append_text_label.pack(pady=5)
    append_text_entry = tk.Entry(root, textvariable=append_text_var, justify='center', width=60)
    append_text_entry.pack(pady=5)

    # Radio buttons choosing what to do when an image already has a caption.
    caption_handling_label = tk.Label(root, text="If a caption already exists for an image:", font=('Helvetica', 12))
    caption_handling_label.pack(pady=5)

    options_frame = tk.Frame(root)
    options_frame.pack(pady=5)

    overwrite_radio = tk.Radiobutton(options_frame, text="Overwrite existing caption", variable=caption_handling_var, value='overwrite')
    overwrite_radio.pack(side="left", padx=10)

    append_radio = tk.Radiobutton(options_frame, text="Append to existing caption", variable=caption_handling_var, value='append')
    append_radio.pack(side="left", padx=10)

    skip_radio = tk.Radiobutton(options_frame, text="Skip images with existing caption", variable=caption_handling_var, value='skip')
    skip_radio.pack(side="left", padx=10)

    # Restore saved settings before the traces are attached, so loading does
    # not immediately trigger a re-save.
    load_config_from_json()

    # trace_add() replaces the deprecated Variable.trace('w', ...) and matches
    # the style already used for the per-caption variables in this file.
    for traced_var in (
        prompt_var, max_new_tokens_var, temperature_var, top_k_var, top_p_var,
        precision_var, thread_count_var, batch_size_var,
        prepend_text_var, append_text_var, caption_handling_var,
    ):
        traced_var.trace_add('write', on_config_change)

    max_new_tokens_label = tk.Label(root, text="Max New Tokens (max number of tokens to generate):")
    max_new_tokens_label.pack(pady=5)
    max_new_tokens_entry = tk.Entry(root, textvariable=max_new_tokens_var, justify='center', width=5, validate='key', validatecommand=(validate_cmd, '%P'))
    max_new_tokens_entry.pack(pady=5)

    do_sample_check = tk.Checkbutton(root, text="Do Sample (random sampling):", variable=do_sample_var, command=toggle_sampling_options)
    do_sample_check.pack(pady=5)

    # The sampling widgets are created here but only packed by
    # toggle_sampling_options() when sampling is switched on.
    temperature_label = tk.Label(root, text="Temperature (control randomness of sampling):")
    top_k_label = tk.Label(root, text="Top-k (consider top k tokens):")
    top_p_label = tk.Label(root, text="Top-p (consider tokens with cumulative probability p):")

    temperature_entry = tk.Entry(root, textvariable=temperature_var, justify='center', width=5, validate='key', validatecommand=(validate_cmd, '%P'))
    top_k_entry = tk.Entry(root, textvariable=top_k_var, justify='center', width=5, validate='key', validatecommand=(validate_cmd, '%P'))
    top_p_entry = tk.Entry(root, textvariable=top_p_var, justify='center', width=5, validate='key', validatecommand=(validate_cmd, '%P'))

    # Thread count and batch size share one horizontal row.
    horizontal_frame = tk.Frame(root)
    horizontal_frame.pack(pady=5, padx=5)

    thread_count_label = tk.Label(horizontal_frame, text="Thread Count (number of threads to use):")
    thread_count_label.pack(side=tk.LEFT, padx=5)
    thread_count_entry = tk.Entry(horizontal_frame, textvariable=thread_count_var, justify='center', width=5, validate='key', validatecommand=(validate_cmd, '%P'))
    thread_count_entry.pack(side=tk.LEFT, padx=5)

    batch_size_label = tk.Label(horizontal_frame, text="Batch Size (number of images to process at once):")
    batch_size_label.pack(side=tk.LEFT, padx=5)
    batch_size_entry = tk.Entry(horizontal_frame, textvariable=batch_size_var, justify='center', width=5, validate='key', validatecommand=(validate_cmd, '%P'))
    batch_size_entry.pack(side=tk.LEFT, padx=5)

    errors_button = tk.Button(root, textvariable=errors_var, command=show_errors)
    errors_button.pack(pady=10)

    start_button = tk.Button(root, text="Generate Captions", command=lambda: [process_files(), update_and_save_config()])
    start_button.pack(pady=10)

    stop_button = tk.Button(root, text="Stop", command=stop_processing_func)
    stop_button.pack(pady=10)

    progress_bar = ttk.Progressbar(root, variable=progress, maximum=100)
    progress_bar.pack(pady=10, fill=tk.X)

    status_label = tk.Label(root, textvariable=status_var, fg="green")
    status_label.pack(pady=5)

    center_window(root)
    root.protocol("WM_DELETE_WINDOW", on_closing)
    root.mainloop()
355
-
356
def select_files():
    """Ask the user for image files and make them the current selection.

    On a non-empty choice the selection replaces the previous one, the
    original order is remembered for ``reset_order``, the save directory
    becomes the folder of the first file, pagination is recomputed, and the
    caption browser (if open) is refreshed. Cancelling the dialog leaves
    everything unchanged.
    """
    global selected_files, save_directory, total_pages, original_selected_files, current_page
    # Tk expects the patterns of one file type separated by spaces; the
    # ';'-separated form is only understood by the native Windows dialog.
    filetypes = [("All Image files", "*.jpg *.jpeg *.png *.gif *.bmp *.tiff *.tif *.svg *.webp")]
    filepaths = filedialog.askopenfilenames(title="Select Image Files", filetypes=filetypes)
    if not filepaths:
        return

    selected_files.clear()
    selected_files.extend(filepaths)
    original_selected_files = selected_files.copy()
    # Drops paths that no longer exist and updates the "files selected" label.
    validate_selected_files()

    num_files_var.set(f"{len(selected_files)} files selected.")
    if not selected_files:
        # Every chosen file vanished before validation; nothing to index below.
        return

    save_directory = os.path.dirname(selected_files[0])
    total_pages = (len(selected_files) + images_per_page - 1) // images_per_page
    # A page index left over from the previous selection may be past the end now.
    current_page = 0
    if caption_window is not None:
        update_image_preview(content_canvas)
371
-
372
def validate_selected_files():
    """Drop paths that no longer exist from ``selected_files`` and update the count label."""
    global selected_files, num_files_var
    selected_files = list(filter(os.path.exists, selected_files))
    remaining = len(selected_files)
    num_files_var.set(f"{remaining} files selected.")
376
-
377
def toggle_buttons(state):
    """Enable (truthy ``state``) or disable the input widgets of the main window.

    The Stop button is always left enabled.
    """
    widget_state = tk.NORMAL if state else tk.DISABLED
    for widget in (
        select_files_button,
        show_captions_button,
        prompt_entry,
        prepend_text_entry,
        append_text_entry,
        do_sample_check,
        temperature_entry,
        top_k_entry,
        top_p_entry,
        thread_count_entry,
        batch_size_entry,
        start_button,
    ):
        widget.config(state=widget_state)
    stop_button.config(state=tk.NORMAL)
392
-
393
def generate_caption(image_path, save_directory, q):
    """Generate a caption for one image and write it next to the image.

    The caption goes to ``<save_directory>/<image filename>_caption.txt``.
    If that file already exists, ``caption_handling_var`` decides whether the
    image is skipped, the new caption is appended to the old one, or the old
    one is overwritten.

    Args:
        image_path: Path of the image to caption.
        save_directory: Folder that receives the caption file.
        q: Queue used to report back. The image path is put on success or
            skip; an error message string is put on failure.

    Errors are never raised to the caller; they are printed, queued, and
    appended to ``error_messages``.
    """
    if stop_processing:
        return

    try:
        filename = os.path.basename(image_path)
        caption_file_path = os.path.join(save_directory, f"{filename}_caption.txt")

        # Decide what to do with an existing caption BEFORE loading the model,
        # so a run where every image is skipped never pays for the model load.
        existing_caption = ""
        if os.path.exists(caption_file_path):
            handling = caption_handling_var.get()
            if handling == 'skip':
                q.put(image_path)
                return
            if handling == 'append':
                with open(caption_file_path, 'r', encoding='utf-8') as f:
                    existing_caption = f.read()

        load_model()

        # convert() returns an independent in-memory copy, so the file handle
        # can be released immediately.
        with PILImage.open(image_path) as opened_image:
            image = opened_image.convert('RGB')

        inputs = model.build_conversation_input_ids(
            tokenizer,
            query=prompt_var.get(),
            history=[],
            images=[image]
        )
        # Match the image tensor to the dtype the model was actually loaded
        # with (load_model picks float16 or float32) instead of assuming float16.
        model_dtype = next(model.parameters()).dtype
        inputs = {
            'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
            'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
            'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
            'images': [[inputs['images'][0].to('cuda').to(model_dtype)]],
        }
        gen_kwargs = {
            "max_new_tokens": max_new_tokens_var.get(),
            "do_sample": do_sample_var.get(),
            "temperature": temperature_var.get(),
            "top_k": top_k_var.get(),
            "top_p": top_p_var.get() if do_sample_var.get() else None,
            # NOTE(review): the "precision" setting doubles as the beam count here.
            "num_beams": precision_var.get()
        }

        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
            # Keep only the newly generated tokens, dropping the echoed prompt.
            outputs = outputs[:, inputs['input_ids'].shape[1]:]
            new_caption = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Join only the non-empty pieces so an empty prepend/existing/append
        # text does not leave doubled spaces inside the caption.
        caption_parts = (prepend_text_var.get(), existing_caption, new_caption, append_text_var.get())
        final_caption = " ".join(part.strip() for part in caption_parts if part.strip())

        with open(caption_file_path, 'w', encoding='utf-8') as file:
            file.write(final_caption)

        q.put(image_path)
        torch.cuda.empty_cache()
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()
        error_message = f"CUDA OutOfMemoryError: {traceback.format_exc()}"
        print(error_message)
        q.put(error_message)
        error_messages.append(error_message)
    except Exception:
        error_message = f"Error processing image {image_path}: {traceback.format_exc()}"
        print(error_message)
        q.put(error_message)
        error_messages.append(error_message)
464
-
465
def worker(save_directory, num_threads, batch_size):
    """Caption every selected file, batch by batch, using worker threads.

    Each batch of ``batch_size`` files is split into at most ``num_threads``
    chunks; one thread per chunk runs ``generate_captions_for_batch`` and the
    batch is joined before the next one starts. A ``None`` sentinel is always
    put on ``q`` at the end so the progress consumer can terminate.

    Args:
        save_directory: Folder that receives the caption files.
        num_threads: Maximum number of threads per batch (values < 1 become 1).
        batch_size: Number of files per batch (values < 1 become 1).
    """
    try:
        progress.set(0)

        # Guard against zero/negative settings, which would otherwise raise
        # ZeroDivisionError below.
        batch_size = max(1, int(batch_size))
        num_threads = max(1, int(num_threads))
        # Round up so a batch never needs more than num_threads chunks.
        chunk_size = math.ceil(batch_size / num_threads)

        # Work on a snapshot: other code may rebind selected_files meanwhile.
        files = list(selected_files)

        for start_index in range(0, len(files), batch_size):
            if stop_processing:
                break

            batch = files[start_index:start_index + batch_size]
            threads = [
                threading.Thread(
                    target=generate_captions_for_batch,
                    args=(batch[offset:offset + chunk_size], save_directory, q),
                )
                for offset in range(0, len(batch), chunk_size)
            ]
            for thread in threads:
                thread.start()
            # Wait for the whole batch before starting the next one.
            for thread in threads:
                thread.join()
    except Exception:
        if not stop_processing:
            # Report as an error string; a raw exception object on the queue
            # would be counted as a processed file by the consumer.
            error_message = f"Error in worker: {traceback.format_exc()}"
            print(error_message)
            error_messages.append(error_message)
            q.put(error_message)
    finally:
        # Always send the sentinel, even after a failure or a stop request;
        # otherwise the progress thread blocks on q.get() forever.
        q.put(None)
497
-
498
def generate_captions_for_batch(batch, save_directory, q):
    """Caption each image path in ``batch`` in order, reporting through ``q``."""
    for path in batch:
        generate_caption(path, save_directory, q)
501
-
502
def update_progress():
    """Consume results from ``q`` and keep the progress UI up to date.

    Runs in its own thread until the ``None`` sentinel arrives. Image paths
    count as completed files; error reports only refresh the error counter.
    All widget and Tk-variable updates are scheduled through ``root.after``
    rather than performed directly from this thread.
    """
    try:
        completed = 0
        known_paths = set(selected_files)
        # Avoid dividing by zero if the selection is emptied mid-run.
        total = max(1, len(selected_files))

        while True:
            item = q.get()
            if item is None:
                break

            # An image path that happens to contain "Error" is still a
            # success, so known paths are checked first.
            is_error = isinstance(item, BaseException) or (
                isinstance(item, str) and item not in known_paths and "Error" in item
            )
            if is_error:
                root.after(0, errors_var.set, f"Errors: {len(error_messages)}")
                continue

            completed += 1
            root.after(0, progress.set, min(100, int((completed / total) * 100)))
            if not stop_processing:
                root.after(0, status_var.set, f"Processed {completed} files")
                root.after(0, root.update_idletasks)

        if not stop_processing:
            # Pass the callable and its argument separately. Writing
            # progress.set(100) here would run it immediately and hand None
            # to after().
            root.after(0, progress.set, 100)
            root.after(0, show_completion_message, completed)
    except Exception as e:
        if not stop_processing:
            root.after(0, status_var.set, f"Error: {e}")
    finally:
        try:
            root.after(0, toggle_buttons, True)
        except (tk.TclError, RuntimeError):
            # The main window was closed while processing; nothing to re-enable.
            pass
526
-
527
def show_completion_message(completed):
    """Show an info dialog with the processed-file count and, if any, the error count."""
    summary = [f"Processing complete. {completed} files processed."]
    if error_messages:
        summary.append(f" {len(error_messages)} errors occurred.")
    messagebox.showinfo("Process Complete", "".join(summary))
532
-
533
def process_files():
    """Start captioning the selected files in background threads.

    Resets the stop flag and the error list, re-validates the selection,
    checks the thread-count and batch-size settings, disables the input
    widgets, and then starts the worker thread and the progress thread.
    Problems are reported through ``status_var`` and nothing is started.
    """
    global stop_processing, error_messages
    stop_processing = False
    error_messages.clear()
    errors_var.set("Errors: 0")

    validate_selected_files()

    if not selected_files or not save_directory:
        status_var.set("Please select images.")
        return

    # Read the numeric settings BEFORE disabling the UI. An empty or invalid
    # entry makes IntVar.get() raise, and raising after toggle_buttons(False)
    # would leave every control disabled.
    try:
        num_threads = thread_count_var.get()
        batch_size = batch_size_var.get()
    except tk.TclError:
        status_var.set("Thread count and batch size must be whole numbers.")
        return
    if num_threads < 1 or batch_size < 1:
        status_var.set("Thread count and batch size must be at least 1.")
        return

    toggle_buttons(False)

    threading.Thread(target=worker, args=(save_directory, num_threads, batch_size)).start()
    threading.Thread(target=update_progress).start()
549
-
550
def stop_processing_func():
    """Raise the stop flag, release cached CUDA memory, and update the status line."""
    global stop_processing
    stop_processing = True  # polled by the worker and by generate_caption
    torch.cuda.empty_cache()
    status_var.set("Processing stopped.")
555
-
556
def open_caption_window():
    """Open the caption browser window (thumbnails with editable captions).

    The window has a search row, a bulk-edit row, and a scrollable canvas
    filled by ``update_image_preview``. Only one browser window exists at a
    time; calling this while it is open does nothing.
    """
    global caption_window, caption_frame, caption_text_widgets, current_page, total_pages, content_canvas
    if caption_window is not None:
        return

    validate_selected_files()

    caption_window = tk.Toplevel(root)
    caption_window.title("Image Thumbnails and Captions")
    caption_window.geometry("800x900")

    main_frame = tk.Frame(caption_window)
    main_frame.pack(fill=tk.BOTH, expand=True)

    # --- Search row -------------------------------------------------------
    search_frame = tk.Frame(main_frame)
    search_frame.pack(side=tk.TOP, fill=tk.X)

    search_entry = tk.Entry(search_frame, textvariable=search_var)
    search_entry.pack(side=tk.LEFT, padx=10, pady=5, fill=tk.X, expand=True)

    search_button = tk.Button(search_frame, text="Search", command=search_captions)
    search_button.pack(side=tk.LEFT, padx=10)

    reset_button = tk.Button(search_frame, text="Reset Order", command=reset_order)
    reset_button.pack(side=tk.LEFT, padx=10)

    # --- Bulk-edit row ----------------------------------------------------
    action_frame = tk.Frame(main_frame)
    action_frame.pack(side=tk.TOP, fill=tk.X)

    action_entry = tk.Entry(action_frame, textvariable=action_var)
    action_entry.pack(side=tk.LEFT, padx=10, pady=5, fill=tk.X, expand=True)

    prepend_button = tk.Button(action_frame, text="Add to Beginning", command=lambda: add_to_captions("prepend"))
    prepend_button.pack(side=tk.LEFT, padx=5)

    append_button = tk.Button(action_frame, text="Add to End", command=lambda: add_to_captions("append"))
    append_button.pack(side=tk.LEFT, padx=5)

    insert_middle_button = tk.Button(action_frame, text="Add to Middle", command=lambda: add_to_captions("insert_middle"))
    insert_middle_button.pack(side=tk.LEFT, padx=5)

    delete_keyword_button = tk.Button(action_frame, text="Delete Keyword", command=delete_keyword_from_captions)
    delete_keyword_button.pack(side=tk.LEFT, padx=5)

    delete_images_button = tk.Button(action_frame, text="Delete Images with Keyword", command=delete_images_with_keyword)
    delete_images_button.pack(side=tk.LEFT, padx=5)

    # --- Scrollable content area -------------------------------------------
    content_canvas = tk.Canvas(main_frame)
    content_canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

    caption_frame = tk.Frame(content_canvas)
    content_canvas.create_window((0, 0), window=caption_frame, anchor='nw')

    caption_scrollbar = tk.Scrollbar(main_frame, orient="vertical", command=content_canvas.yview)
    caption_scrollbar.pack(side=tk.LEFT, fill=tk.Y)
    content_canvas.configure(yscrollcommand=caption_scrollbar.set)

    # Keep the scroll region in sync with the frame as rows are added or removed.
    caption_frame.bind("<Configure>", lambda e: content_canvas.configure(scrollregion=content_canvas.bbox("all")))
    content_canvas.bind_all("<MouseWheel>", lambda event: content_canvas.yview_scroll(int(-1*(event.delta/120)), "units"))

    def on_caption_window_close():
        global caption_window
        # bind_all() is application-wide. Remove it first, or wheel events in
        # the main window would keep calling into the destroyed canvas.
        content_canvas.unbind_all("<MouseWheel>")
        caption_window.destroy()
        caption_window = None

    caption_window.protocol("WM_DELETE_WINDOW", on_caption_window_close)

    update_image_preview(content_canvas)
624
-
625
def update_image_preview(content_canvas):
    """Rebuild the current page of the caption browser.

    Clears the caption frame, then shows a thumbnail, the file name, and an
    editable caption for each file on the current page, followed by the page
    navigation controls. Edits are written to the caption file on every key
    release. Does nothing if the browser frame has not been created yet.

    Args:
        content_canvas: Canvas hosting the caption frame; passed on to the
            navigation callbacks.
    """
    global thumbnails, caption_text_widgets, current_page, images_per_page, total_pages
    if caption_frame is None:
        return

    for widget in caption_frame.winfo_children():
        widget.destroy()

    thumbnails.clear()
    caption_text_widgets.clear()

    if not selected_files:
        return

    # The selection may have shrunk (e.g. images deleted) since pagination was
    # last computed, so recompute it and clamp the current page.
    total_pages = max(1, math.ceil(len(selected_files) / images_per_page))
    current_page = min(max(current_page, 0), total_pages - 1)

    start_index = current_page * images_per_page
    files_to_display = selected_files[start_index:start_index + images_per_page]
    thumbnail_size = (200, 200)

    for i, file_path in enumerate(files_to_display):
        row = i * 2
        try:
            # Close the image file as soon as the Tk photo has been built.
            with PILImage.open(file_path) as image:
                image.thumbnail(thumbnail_size)
                thumbnail = ImageTk.PhotoImage(image)
            # Tk does not hold a Python reference to the photo; keep one so it
            # is not garbage-collected.
            thumbnails.append(thumbnail)

            img_label = tk.Label(caption_frame, image=thumbnail)
            img_label.grid(row=row, column=0, padx=5, pady=5, sticky="nsew")

            file_label = tk.Label(caption_frame, text=os.path.basename(file_path), font=('Helvetica', 12))
            file_label.grid(row=row, column=1, padx=5, pady=5, sticky="nsew")

            caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
            if os.path.exists(caption_file):
                with open(caption_file, 'r', encoding='utf-8') as file:
                    caption_text = file.read()
            else:
                caption_text = ""

            caption_text_widget = tk.Text(caption_frame, width=50, height=3, wrap=tk.WORD, font=('Helvetica', 12))
            caption_text_widget.insert(tk.END, caption_text)
            caption_text_widget.grid(row=row, column=2, padx=5, pady=5, sticky="nsew")

            # Save on every key release. Defaults bind the current path and
            # widget, so each row saves its own file.
            caption_text_widget.bind(
                "<KeyRelease>",
                lambda e, fp=file_path, w=caption_text_widget: save_caption(fp, w.get("1.0", "end-1c")),
            )
            caption_text_widgets.append(caption_text_widget)

        except Exception:
            # Best effort: one unreadable image must not break the whole page.
            tk.Label(caption_frame, text="Error loading image").grid(row=row, column=0, columnspan=4, padx=5, pady=5)

    nav_frame = tk.Frame(caption_frame)
    nav_frame.grid(row=images_per_page*2, column=0, columnspan=3, pady=10)

    if current_page > 0:
        prev_button = tk.Button(nav_frame, text="Previous", command=lambda: navigate(-1, content_canvas))
        prev_button.pack(side=tk.LEFT)

    page_label = tk.Label(nav_frame, text=f"Page {current_page + 1} of {total_pages}")
    page_label.pack(side=tk.LEFT, padx=5)

    page_entry = tk.Entry(nav_frame, width=5)
    page_entry.pack(side=tk.LEFT)

    go_button = tk.Button(nav_frame, text="Go", command=lambda: go_to_page(page_entry.get(), content_canvas))
    go_button.pack(side=tk.LEFT, padx=5)

    if current_page < total_pages - 1:
        next_button = tk.Button(nav_frame, text="Next", command=lambda: navigate(1, content_canvas))
        next_button.pack(side=tk.RIGHT)
698
-
699
def navigate(direction, content_canvas):
    """Shift the current page by ``direction`` and redraw the preview."""
    global current_page
    current_page = current_page + direction
    update_image_preview(content_canvas)
703
-
704
def go_to_page(page_number, content_canvas):
    """Jump to the 1-based page ``page_number`` typed by the user.

    Shows an error dialog if the text is not an integer or is outside
    ``1..total_pages``.
    """
    global current_page, total_pages
    try:
        requested = int(page_number)
        if not (1 <= requested <= total_pages):
            messagebox.showerror("Invalid Page", f"Please enter a valid page number between 1 and {total_pages}.")
        else:
            current_page = requested - 1
            update_image_preview(content_canvas)
    except ValueError:
        messagebox.showerror("Invalid Input", "Please enter a valid integer for the page number.")
715
-
716
def save_caption(file_path, caption_text):
    """Write the stripped ``caption_text`` to the caption file of ``file_path``.

    Write failures are printed and otherwise ignored.
    """
    caption_name = f"{os.path.basename(file_path)}_caption.txt"
    output_path = os.path.join(save_directory, caption_name)
    try:
        with open(output_path, 'w', encoding='utf-8') as caption_file:
            caption_file.write(caption_text.strip())
    except Exception as exc:
        print(f"Error saving captions: {exc}")
723
-
724
def search_captions():
    """Sort the selection so files whose captions mention the search term most come first.

    An empty search term is ignored. Sorting errors are printed and recorded
    in ``error_messages``; the preview is refreshed either way.
    """
    global selected_files
    search_term = search_var.get().lower().strip()
    if not search_term:
        return

    def relevance(path):
        return search_score(path, search_term)

    try:
        selected_files.sort(key=relevance, reverse=True)
    except Exception as exc:
        error_message = f"Error during sorting: {exc}"
        print(error_message)
        error_messages.append(error_message)

    update_image_preview(content_canvas)
738
-
739
def search_score(file_path, search_term):
    """Return how many times ``search_term`` occurs in the caption of ``file_path``.

    The caption is lower-cased before counting. Returns 0 when the caption
    file is missing or cannot be read; read errors are printed and recorded
    in ``error_messages``.
    """
    caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
    try:
        if not os.path.exists(caption_file):
            return 0
        with open(caption_file, 'r', encoding='utf-8') as handle:
            lowered_caption = handle.read().lower()
        return lowered_caption.count(search_term)
    except Exception as exc:
        error_message = f"Error reading file {caption_file}: {exc}"
        print(error_message)
        error_messages.append(error_message)
    return 0
753
-
754
def reset_order():
    """Restore the selection to the order in which the files were originally picked."""
    global selected_files
    selected_files = list(original_selected_files)
    update_image_preview(content_canvas)
758
-
759
def add_to_captions(position):
    """Insert the text from ``action_var`` into every existing caption file.

    Args:
        position: ``"prepend"``, ``"append"`` or ``"insert_middle"``. The
            middle is the character midpoint of the caption. Any other value
            leaves the caption text unchanged.

    Files without a caption file are skipped. The preview is refreshed at the end.
    """
    global selected_files
    keyword = action_var.get()
    if not keyword:
        return

    def insert_in_middle(text):
        midpoint = len(text) // 2
        return f"{text[:midpoint]} {keyword} {text[midpoint:]}"

    editors = {
        "prepend": lambda text: f"{keyword} {text}",
        "append": lambda text: f"{text} {keyword}",
        "insert_middle": insert_in_middle,
    }
    edit = editors.get(position, lambda text: text)

    for path in selected_files:
        caption_file = os.path.join(save_directory, f"{os.path.basename(path)}_caption.txt")
        if not os.path.exists(caption_file):
            continue
        with open(caption_file, 'r+', encoding='utf-8') as handle:
            updated_text = edit(handle.read())
            # Rewrite in place: rewind, write, then cut off any leftover tail.
            handle.seek(0)
            handle.write(updated_text)
            handle.truncate()

    update_image_preview(content_canvas)
782
-
783
def delete_keyword_from_captions():
    """Remove every occurrence of the keyword in ``action_var`` from all captions.

    Matching is case-insensitive, and the rest of each caption keeps its
    original letter case. Files without a caption file are skipped. The
    preview is refreshed at the end.
    """
    import re  # local import: this block is the only user of ``re``

    keyword = action_var.get().strip()
    if not keyword:
        return

    # Match case-insensitively instead of lower-casing the caption, which
    # would destroy the capitalisation of everything that is kept.
    keyword_pattern = re.compile(re.escape(keyword), re.IGNORECASE)

    for file_path in selected_files:
        caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
        if not os.path.exists(caption_file):
            continue
        with open(caption_file, 'r+', encoding='utf-8') as file:
            updated_caption = keyword_pattern.sub("", file.read()).strip()
            # Rewrite in place: rewind, write, then cut off the leftover tail.
            file.seek(0)
            file.write(updated_caption)
            file.truncate()

    update_image_preview(content_canvas)
801
-
802
def delete_images_with_keyword():
    """Delete every image (and its caption file) whose caption contains the keyword.

    Matching is case-insensitive. The keyword is stripped first, so a
    whitespace-only entry is treated as empty instead of matching nearly
    every caption. Only images that were actually removed from disk are
    dropped from ``selected_files``; failures are printed and recorded in
    ``error_messages``.
    """
    global selected_files
    keyword = action_var.get().lower().strip()
    if not keyword:
        return

    # Phase 1: collect matches without touching the disk.
    files_to_delete = []
    for file_path in selected_files:
        caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
        if os.path.exists(caption_file):
            with open(caption_file, 'r', encoding='utf-8') as file:
                if keyword in file.read().lower():
                    files_to_delete.append(file_path)

    # Phase 2: delete, remembering which images really went away.
    deleted = set()
    for file_path in files_to_delete:
        try:
            os.remove(file_path)
            deleted.add(file_path)
            caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
            if os.path.exists(caption_file):
                os.remove(caption_file)
        except OSError as e:
            error_message = f"Error deleting file {file_path} or its caption: {e}"
            print(error_message)
            error_messages.append(error_message)

    selected_files = [file_path for file_path in selected_files if file_path not in deleted]

    validate_selected_files()

    update_image_preview(content_canvas)
833
-
834
def return_to_menu():
    """Stop any running captioning job, close this window and reopen the main menu."""
    stop_processing_func()
    root.destroy()
    # Imported here, after the window is gone, rather than at module load.
    import main as menu_module
    menu_module.open_main_menu()
839
-
840
def on_closing():
    """Window-close handler: behave exactly like the back button."""
    return return_to_menu()
842
-
843
# Launch the caption tool directly when this file is run as a script.
if __name__ == "__main__":
    open_image_to_caption()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import tkinter as tk
3
+ from tkinter import filedialog, messagebox, ttk
4
+ from PIL import Image as PILImage, ImageTk
5
+ import os
6
+ import queue
7
+ import threading
8
+ import torch
9
+ from transformers import AutoModelForCausalLM, LlamaTokenizer
10
+ import json
11
+ import traceback
12
+ import math
13
+
14
# Inference only: gradients are never needed.
torch.set_grad_enabled(False)

# --- Processing state -------------------------------------------------------
stop_processing = False
error_messages = []
selected_files = []
original_selected_files = []
save_directory = ""
model = None
q = queue.Queue()

# --- Caption browser state --------------------------------------------------
caption_window = None
caption_frame = None
content_canvas = None
thumbnails = []
caption_text_widgets = []
error_window = None
current_page = 0
images_per_page = 20
total_pages = 1

# --- Tk variables (created in open_image_to_caption once a root exists) -----
status_var = None
num_files_var = None
errors_var = None
progress = None
prompt_var = None
max_new_tokens_var = None
do_sample_var = None
temperature_var = None
top_k_var = None
top_p_var = None
thread_count_var = None
precision_var = None
batch_size_var = None
prepend_text_var = None
append_text_var = None
caption_handling_var = None  # Variable to handle radio buttons for caption handling
search_var = None
action_var = None

# --- Widgets that other functions need to reach ------------------------------
root = None
start_button = None
stop_button = None
prompt_entry = None
select_files_button = None
show_captions_button = None
thread_count_entry = None
precision_entry = None
batch_size_entry = None
prepend_text_entry = None
append_text_entry = None
action_entry = None
63
+
64
# Serializes model loading: generate_caption() calls load_model() from
# several worker threads at once.
_model_load_lock = threading.Lock()


def load_model():
    """Load the tokenizer and CogVLM model once, at the selected bit precision.

    Reads ``bit_precision_var`` (4, 8, 16 or 32). 4/8-bit loading is
    requested through bitsandbytes; if bitsandbytes is not installed the
    model is loaded unquantized in float16 instead. Unquantized models are
    moved to CUDA explicitly. The globals ``tokenizer`` and ``model`` are
    only published after loading fully succeeds.

    Raises:
        ValueError: if the selected bit precision is not 4, 8, 16 or 32.
    """
    global model, tokenizer
    if model is not None:
        return

    with _model_load_lock:
        # Another thread may have finished loading while we waited.
        if model is not None:
            return

        bit_precision = bit_precision_var.get()
        if bit_precision not in (4, 8, 16, 32):
            raise ValueError(f"Unsupported bit precision: {bit_precision}")

        load_in_4bit = bit_precision == 4
        load_in_8bit = bit_precision == 8
        # Quantized (bitsandbytes) and 16-bit modes all compute in float16.
        torch_type = torch.float32 if bit_precision == 32 else torch.float16

        if load_in_4bit or load_in_8bit:
            try:
                import bitsandbytes  # noqa: F401 -- availability check only
            except ImportError:
                # Without bitsandbytes quantization is impossible; clear the
                # flags so the float16 model below is still moved to the GPU.
                print("bitsandbytes is not installed; loading without quantization.")
                load_in_4bit = load_in_8bit = False

        load_kwargs = {
            'torch_dtype': torch_type,
            'low_cpu_mem_usage': True,
            'trust_remote_code': True,
        }
        if load_in_4bit or load_in_8bit:
            load_kwargs['load_in_4bit'] = load_in_4bit
            load_kwargs['load_in_8bit'] = load_in_8bit

        loaded_tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
        loaded_model = AutoModelForCausalLM.from_pretrained('THUDM/cogvlm-chat-hf', **load_kwargs)

        # Quantized models are placed by the loader; only unquantized ones
        # are moved (and cast) here.
        if not load_in_4bit and not load_in_8bit:
            loaded_model = loaded_model.to(torch.device('cuda'))
            loaded_model = loaded_model.to(torch_type)

        loaded_model.eval()

        print(f"Model loaded with dtype: {torch_type}, 4bit: {load_in_4bit}, 8bit: {load_in_8bit}")

        # Publish the tokenizer first so a thread that sees a non-None
        # model can rely on the tokenizer being set too.
        tokenizer = loaded_tokenizer
        model = loaded_model
118
+
119
+
120
def update_and_save_config():
    """Snapshot the current UI settings into ``captions.json``.

    ``top_p`` is stored as ``null`` while sampling is disabled. A failure
    to write the file is printed, not raised.
    """
    if do_sample_var.get():
        top_p_setting = float(top_p_var.get())
    else:
        top_p_setting = None

    settings = {
        'prompt': prompt_var.get(),
        'max_new_tokens': max_new_tokens_var.get(),
        'temperature': temperature_var.get(),
        'top_k': top_k_var.get(),
        'top_p': top_p_setting,
        'bit_precision': bit_precision_var.get(),  # single setting covering every precision mode
        'thread_count': thread_count_var.get(),
        'batch_size': batch_size_var.get(),
        'prepend_text': prepend_text_var.get(),
        'append_text': append_text_var.get(),
        'caption_handling': caption_handling_var.get()
    }

    try:
        with open('captions.json', 'w') as config_file:
            json.dump(settings, config_file, indent=2)
    except Exception as e:
        print(f"Error saving config to captions.json: {e}")
141
+
142
def load_config_from_json():
    """Populate the Tk variables and prompt box from ``captions.json``.

    Missing keys *and* keys stored as ``null`` fall back to the defaults.
    The latter matters because ``top_p`` is saved as ``null`` whenever
    sampling is off, and pushing ``None`` into a ``DoubleVar`` would make
    later reads of it fail. Does nothing when the file does not exist;
    any error is printed, not raised.
    """
    try:
        if not os.path.exists('captions.json'):
            return
        with open('captions.json', 'r') as f:
            config_entry = json.load(f)

        def setting(key, default):
            # dict.get() only covers a missing key, not an explicit null.
            value = config_entry.get(key)
            return default if value is None else value

        prompt_text = setting('prompt', '')
        prompt_var.set(prompt_text)
        max_new_tokens_var.set(setting('max_new_tokens', 200))
        temperature_var.set(setting('temperature', 1.0))
        top_k_var.set(setting('top_k', 50))
        top_p_var.set(setting('top_p', 0.95))
        bit_precision_var.set(setting('bit_precision', 8))
        thread_count_var.set(setting('thread_count', 4))
        batch_size_var.set(setting('batch_size', 1))
        prepend_text_var.set(setting('prepend_text', ''))
        append_text_var.set(setting('append_text', ''))
        caption_handling_var.set(setting('caption_handling', 'skip'))

        # The prompt lives in a Text widget, which is not bound to prompt_var.
        prompt_entry.delete("1.0", tk.END)
        prompt_entry.insert(tk.END, prompt_text)
    except Exception as e:
        print(f"Error loading config from captions.json: {e}")
163
+
164
def on_config_change(*args):
    """Variable-trace callback: schedule a config save 100 ms from now."""
    delay_ms = 100
    root.after(delay_ms, update_config)
166
+
167
def update_config():
    """Save the configuration unless the precision value is currently unreadable.

    Reading ``precision_var`` acts as a sanity check; any error raised
    while reading or saving is printed rather than propagated.
    """
    try:
        # Do nothing when the value is an empty string.
        if precision_var.get() != "":
            update_and_save_config()
    except Exception as e:
        print(f"Lỗi khi xử lý giá trị: {e}")
176
+
177
def on_prompt_change(event=None):
    """Mirror the prompt Text widget into ``prompt_var`` and persist the config."""
    typed_prompt = prompt_entry.get("1.0", tk.END).strip()
    prompt_var.set(typed_prompt)
    update_and_save_config()
180
+
181
def show_errors():
    """Open a read-only window listing every recorded error message.

    Only one error window can exist at a time; calling this while it is
    open does nothing.
    """
    global error_window
    if error_window is not None:
        return

    def on_close_error_window():
        # Forget the window so it can be opened again later.
        global error_window
        error_window.destroy()
        error_window = None

    error_window = tk.Toplevel(root)
    error_window.title("Error Details")
    error_window.geometry("500x400")

    error_text = tk.Text(error_window, wrap='word')
    error_text.pack(expand=True, fill='both')

    if not error_messages:
        error_text.insert('end', "No errors recorded.")
    else:
        for message in error_messages:
            error_text.insert('end', message + '\n')

    error_text.config(state='disabled')

    error_window.protocol("WM_DELETE_WINDOW", on_close_error_window)
207
+
208
def validate_numeric_input(value):
    """Entry validator: accept text that is, or could become, a number.

    An empty string and a lone ``"-"`` are allowed so the user can clear
    the field or start typing a negative value; anything else must parse
    with ``float()``.
    """
    if value in ("", "-"):
        return True
    try:
        float(value)
    except ValueError:
        return False
    return True
216
+
217
def center_window(window):
    """Move *window* to the middle of the screen, keeping its current size."""
    # Flush pending geometry work so the reported size is current.
    window.update_idletasks()
    win_w = window.winfo_width()
    win_h = window.winfo_height()
    left = window.winfo_screenwidth() // 2 - win_w // 2
    top = window.winfo_screenheight() // 2 - win_h // 2
    window.geometry(f'{win_w}x{win_h}+{left}+{top}')
224
+
225
def toggle_sampling_options():
    """Show or hide the temperature / top-k / top-p rows under the sampling checkbox.

    The window grows or shrinks by 150 px to make room, then is re-centered.
    """
    sampling_widgets = (
        temperature_label, temperature_entry,
        top_k_label, top_k_entry,
        top_p_label, top_p_entry,
    )
    if do_sample_var.get():
        # Chain each widget directly below the previous one.
        anchor = do_sample_check
        for widget in sampling_widgets:
            widget.pack(pady=5, after=anchor)
            anchor = widget
        root.geometry(f"{root.winfo_width()}x{root.winfo_height() + 150}")
    else:
        for widget in sampling_widgets:
            widget.pack_forget()
        root.geometry(f"{root.winfo_width()}x{root.winfo_height() - 150}")
    center_window(root)
243
+
244
def open_image_to_caption():
    """Build the main "Image to Caption" window and run its event loop.

    Creates the Tk root, every Tk variable and widget the other functions
    reach through module globals, restores saved settings from
    ``captions.json`` and then blocks in ``root.mainloop()``.

    Variable observers are registered with ``trace_add('write', ...)``,
    the non-deprecated API already used by the caption browser in this
    file, instead of the legacy ``trace('w', ...)``.
    """
    global bit_precision_var, root
    global initial_bit_precision
    global app_initialized
    global stop_processing, error_messages, selected_files, save_directory, status_var, num_files_var, errors_var, progress
    global prompt_var, max_new_tokens_var, do_sample_var, temperature_var, top_k_var, top_p_var, thread_count_var, precision_var, batch_size_var
    global prepend_text_var, append_text_var, search_var, action_var, caption_handling_var
    global start_button, stop_button
    global temperature_label, temperature_entry, top_k_label, top_k_entry, top_p_label, top_p_entry
    global do_sample_check, prompt_entry, select_files_button, show_captions_button, thread_count_entry, precision_entry, batch_size_entry
    global prepend_text_entry, append_text_entry
    global q

    # Suppresses the restart prompt while saved settings are being restored.
    app_initialized = False

    def on_bit_precision_change(*args):
        """Persist the new precision and restart the app so it takes effect."""
        if not app_initialized:
            return

        update_and_save_config()

        result = messagebox.showinfo(
            "Bit Precision Changed",
            "You have changed the bit precision. Please restart the app for the changes to take effect."
        )

        if result == "ok":
            root.destroy()  # close the current application window
            python = sys.executable
            # Replace this process with a fresh run of the main menu script.
            os.execl(python, python, "main.py")

    # Initialize the main Tkinter root window
    root = tk.Tk()
    root.title("Image to Caption")
    root.geometry("1050x950")

    # Tk variables can only be created once the root exists.
    status_var = tk.StringVar()
    num_files_var = tk.StringVar()
    errors_var = tk.StringVar(value="Errors: 0")
    progress = tk.IntVar()
    prompt_var = tk.StringVar(value="Describe this image")
    max_new_tokens_var = tk.IntVar(value=200)
    do_sample_var = tk.BooleanVar(value=False)
    temperature_var = tk.DoubleVar(value=1.0)
    top_k_var = tk.IntVar(value=50)
    top_p_var = tk.DoubleVar(value=0.95)
    thread_count_var = tk.IntVar(value=4)
    precision_var = tk.IntVar(value=1)
    batch_size_var = tk.IntVar(value=1)
    prepend_text_var = tk.StringVar()
    append_text_var = tk.StringVar()
    caption_handling_var = tk.StringVar(value='skip')  # Default value is 'skip'
    search_var = tk.StringVar()
    action_var = tk.StringVar()

    bit_precision_var = tk.IntVar(value=8)
    initial_bit_precision = bit_precision_var.get()

    q = queue.Queue()

    validate_cmd = root.register(validate_numeric_input)

    back_button = tk.Button(root, text="<-", font=('Helvetica', 14), command=return_to_menu)
    back_button.pack(anchor='nw', padx=10, pady=10)

    title_label = tk.Label(root, text="Image Caption Generator", font=('Helvetica', 16))
    title_label.pack(pady=10)

    warning_label = tk.Label(root, text="NOTE: 4-bit requires 20GB RAM and 12GB VRAM, 8-bit requires 20GB RAM and 16GB VRAM, 16-bit requires 50GB RAM and 24GB VRAM, 32-bit requires 85GB RAM and 40GB VRAM.",
                             font=('Helvetica', 10), fg="red", wraplength=750, justify="left")
    warning_label.pack(pady=10)

    select_files_button = tk.Button(root, text="Select Files", command=select_files)
    select_files_button.pack(pady=10)

    show_captions_button = tk.Button(root, text="Show Captions", command=open_caption_window)
    show_captions_button.pack(pady=10)

    num_files_label = tk.Label(root, textvariable=num_files_var)
    num_files_label.pack(pady=5)

    bit_frame = tk.Frame(root)
    bit_frame.pack(pady=5)

    bit_label = tk.Label(bit_frame, text="Select Bit Precision:")
    bit_label.pack(side="left", padx=10)

    for bits in (4, 8, 16, 32):
        tk.Radiobutton(bit_frame, text=f"{bits}-bit", variable=bit_precision_var, value=bits).pack(side="left", padx=5)

    prompt_label = tk.Label(root, text="Prompt (text to describe the image):")
    prompt_label.pack(pady=5)
    prompt_entry = tk.Text(root, height=3, wrap='word', width=60)
    prompt_entry.pack(pady=5, padx=10, fill='both', expand=True)
    prompt_entry.bind('<KeyRelease>', on_prompt_change)

    prepend_text_label = tk.Label(root, text="Prepend Text:")
    prepend_text_label.pack(pady=5)
    prepend_text_entry = tk.Entry(root, textvariable=prepend_text_var, justify='center', width=60)
    prepend_text_entry.pack(pady=5)

    append_text_label = tk.Label(root, text="Append Text:")
    append_text_label.pack(pady=5)
    append_text_entry = tk.Entry(root, textvariable=append_text_var, justify='center', width=60)
    append_text_entry.pack(pady=5)

    # Radio buttons choosing what to do when an image already has a caption.
    caption_handling_label = tk.Label(root, text="If a caption already exists for an image:", font=('Helvetica', 12))
    caption_handling_label.pack(pady=5)

    # Frame holding the radio buttons
    options_frame = tk.Frame(root)
    options_frame.pack(pady=5)

    overwrite_radio = tk.Radiobutton(options_frame, text="Overwrite existing caption", variable=caption_handling_var, value='overwrite')
    overwrite_radio.pack(side="left", padx=10)

    append_radio = tk.Radiobutton(options_frame, text="Append to existing caption", variable=caption_handling_var, value='append')
    append_radio.pack(side="left", padx=10)

    skip_radio = tk.Radiobutton(options_frame, text="Skip images with existing caption", variable=caption_handling_var, value='skip')
    skip_radio.pack(side="left", padx=10)

    bit_precision_var.trace_add('write', on_bit_precision_change)

    load_config_from_json()

    app_initialized = True

    # Auto-save: any change to these settings schedules a config write.
    for tracked_var in (
        prompt_var,
        max_new_tokens_var,
        temperature_var,
        top_k_var,
        top_p_var,
        precision_var,
        thread_count_var,
        batch_size_var,
        prepend_text_var,
        append_text_var,
        caption_handling_var,
    ):
        tracked_var.trace_add('write', on_config_change)

    max_new_tokens_label = tk.Label(root, text="Max New Tokens (max number of tokens to generate):")
    max_new_tokens_label.pack(pady=5)
    max_new_tokens_entry = tk.Entry(root, textvariable=max_new_tokens_var, justify='center', width=5, validate='key', validatecommand=(validate_cmd, '%P'))
    max_new_tokens_entry.pack(pady=5)

    do_sample_check = tk.Checkbutton(root, text="Do Sample (random sampling):", variable=do_sample_var, command=toggle_sampling_options)
    do_sample_check.pack(pady=5)

    # Sampling widgets are created here but only packed by
    # toggle_sampling_options() when sampling is enabled.
    temperature_label = tk.Label(root, text="Temperature (control randomness of sampling):")
    top_k_label = tk.Label(root, text="Top-k (consider top k tokens):")
    top_p_label = tk.Label(root, text="Top-p (consider tokens with cumulative probability p):")

    temperature_entry = tk.Entry(root, textvariable=temperature_var, justify='center', width=5, validate='key', validatecommand=(validate_cmd, '%P'))
    top_k_entry = tk.Entry(root, textvariable=top_k_var, justify='center', width=5, validate='key', validatecommand=(validate_cmd, '%P'))
    top_p_entry = tk.Entry(root, textvariable=top_p_var, justify='center', width=5, validate='key', validatecommand=(validate_cmd, '%P'))

    # Frame holding the horizontally aligned thread-count / batch-size controls
    horizontal_frame = tk.Frame(root)
    horizontal_frame.pack(pady=5, padx=5)

    thread_count_label = tk.Label(horizontal_frame, text="Thread Count (number of threads to use):")
    thread_count_label.pack(side=tk.LEFT, padx=5)
    thread_count_entry = tk.Entry(horizontal_frame, textvariable=thread_count_var, justify='center', width=5, validate='key', validatecommand=(validate_cmd, '%P'))
    thread_count_entry.pack(side=tk.LEFT, padx=5)

    batch_size_label = tk.Label(horizontal_frame, text="Batch Size (number of images to process at once):")
    batch_size_label.pack(side=tk.LEFT, padx=5)
    batch_size_entry = tk.Entry(horizontal_frame, textvariable=batch_size_var, justify='center', width=5, validate='key', validatecommand=(validate_cmd, '%P'))
    batch_size_entry.pack(side=tk.LEFT, padx=5)

    errors_button = tk.Button(root, textvariable=errors_var, command=show_errors)
    errors_button.pack(pady=10)

    start_button = tk.Button(root, text="Generate Captions", command=lambda: [process_files(), update_and_save_config()])
    start_button.pack(pady=10)

    stop_button = tk.Button(root, text="Stop", command=stop_processing_func)
    stop_button.pack(pady=10)

    progress_bar = ttk.Progressbar(root, variable=progress, maximum=100)
    progress_bar.pack(pady=10, fill=tk.X)

    status_label = tk.Label(root, textvariable=status_var, fg="green")
    status_label.pack(pady=5)

    center_window(root)
    root.protocol("WM_DELETE_WINDOW", on_closing)
    root.mainloop()
439
+
440
def select_files():
    """Ask the user for image files and make them the current selection.

    Updates ``selected_files``, ``original_selected_files``,
    ``save_directory`` (the folder of the first image), ``total_pages``
    and resets ``current_page`` so the caption browser starts at the
    first page of the new selection. Cancelling the dialog changes nothing.
    """
    global selected_files, save_directory, total_pages, original_selected_files, current_page
    # Patterns are passed as a tuple: a single ';'-joined string is only
    # understood by the Windows native dialog.
    image_patterns = ("*.jpg", "*.jpeg", "*.png", "*.gif", "*.bmp", "*.tiff", "*.tif", "*.svg", "*.webp")
    filetypes = [("All Image files", image_patterns)]
    filepaths = filedialog.askopenfilenames(title="Select Image Files", filetypes=filetypes)
    if not filepaths:
        return

    selected_files.clear()
    selected_files.extend(filepaths)
    original_selected_files = selected_files.copy()
    # Drops paths that no longer exist and refreshes the file counter.
    validate_selected_files()

    num_files_var.set(f"{len(selected_files)} files selected.")
    if selected_files:
        save_directory = os.path.dirname(selected_files[0])
    # Always keep at least one page so the "Page x of y" label stays sane.
    total_pages = max(1, (len(selected_files) + images_per_page - 1) // images_per_page)
    current_page = 0
    if caption_window is not None:
        update_image_preview(content_canvas)
455
+
456
def validate_selected_files():
    """Drop selected paths that no longer exist on disk and refresh the counter label."""
    global selected_files
    still_present = []
    for path in selected_files:
        if os.path.exists(path):
            still_present.append(path)
    selected_files = still_present
    num_files_var.set(f"{len(selected_files)} files selected.")
460
+
461
def toggle_buttons(state):
    """Enable (truthy *state*) or disable the input controls.

    The Stop button is always left enabled so a running job can be cancelled.
    """
    widget_state = tk.NORMAL if state else tk.DISABLED
    controls = (
        select_files_button,
        show_captions_button,
        prompt_entry,
        prepend_text_entry,
        append_text_entry,
        do_sample_check,
        temperature_entry,
        top_k_entry,
        top_p_entry,
        thread_count_entry,
        batch_size_entry,
        start_button,
    )
    for control in controls:
        control.config(state=widget_state)
    stop_button.config(state=tk.NORMAL)
476
+
477
def generate_caption(image_path, save_directory, q):
    """Caption one image and write ``<image name>_caption.txt`` in *save_directory*.

    Honors the "existing caption" setting: ``skip`` leaves the file alone,
    ``append`` keeps the old text in front of the new caption, anything
    else overwrites. The prepend/append texts from the UI are added around
    the result; empty parts are left out so no doubled spaces appear.

    On success (or skip) *image_path* is put on *q*; on failure an error
    string is put on *q* and recorded in ``error_messages``. Returns
    immediately, without reporting, when processing has been stopped.
    """
    if stop_processing:
        return

    try:
        filename = os.path.basename(image_path)
        caption_file_path = os.path.join(save_directory, f"{filename}_caption.txt")

        # Decide what to do with an existing caption *before* loading the
        # model, so skipped images never trigger the expensive load.
        existing_caption = ""
        if os.path.exists(caption_file_path):
            handling = caption_handling_var.get()
            if handling == 'skip':
                q.put(image_path)
                return
            if handling == 'append':
                with open(caption_file_path, 'r', encoding='utf-8') as f:
                    existing_caption = f.read()

        load_model()

        # convert() returns an independent image, so the file can be closed.
        with PILImage.open(image_path) as source_image:
            image = source_image.convert('RGB')

        inputs = model.build_conversation_input_ids(
            tokenizer,
            query=prompt_var.get(),
            history=[],
            images=[image]
        )

        # The image tensor dtype follows the selected precision:
        # float32 in 32-bit mode, float16 otherwise.
        image_dtype = torch.float32 if bit_precision_var.get() == 32 else torch.float16
        image_tensor = inputs['images'][0].to('cuda').to(image_dtype)

        inputs = {
            'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
            'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
            'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
            'images': [[image_tensor]],
        }

        do_sample = do_sample_var.get()
        gen_kwargs = {
            "max_new_tokens": max_new_tokens_var.get(),
            "do_sample": do_sample,
            "temperature": temperature_var.get(),
            "top_k": top_k_var.get(),
            "top_p": top_p_var.get() if do_sample else None,
            "num_beams": precision_var.get()
        }

        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
            # Keep only the newly generated tokens, not the echoed prompt.
            outputs = outputs[:, inputs['input_ids'].shape[1]:]
            new_caption = tokenizer.decode(outputs[0], skip_special_tokens=True)

        parts = (prepend_text_var.get(), existing_caption, new_caption, append_text_var.get())
        final_caption = " ".join(part.strip() for part in parts if part.strip())

        with open(caption_file_path, 'w', encoding='utf-8') as file:
            file.write(final_caption)

        q.put(image_path)
        torch.cuda.empty_cache()
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()
        error_message = f"CUDA OutOfMemoryError: {traceback.format_exc()}"
        print(error_message)
        q.put(error_message)
        error_messages.append(error_message)
    except Exception:
        error_message = f"Error processing image {image_path}: {traceback.format_exc()}"
        print(error_message)
        q.put(error_message)
        error_messages.append(error_message)
556
+
557
+
558
def worker(save_directory, num_threads, batch_size):
    """Caption all selected files, batch by batch, using several threads per batch.

    Each batch of *batch_size* images is split into chunks of
    ``batch_size // num_threads`` (at least 1) and every chunk runs in its
    own thread; the next batch starts once all of them finish. Values
    below 1 for *num_threads* or *batch_size* are treated as 1.

    A ``None`` sentinel is always put on the queue when the worker ends,
    including after an error, so the progress consumer can never block
    forever. Unexpected errors are reported as an ``"Error ..."`` string
    on the queue and recorded in ``error_messages``.
    """
    try:
        progress.set(0)

        # Guard against 0 / negative input, which would divide by zero.
        num_threads = max(1, num_threads)
        batch_size = max(1, batch_size)
        batch_size_per_thread = max(1, batch_size // num_threads)  # images each thread handles within a batch

        for start_index in range(0, len(selected_files), batch_size):
            if stop_processing:
                break

            batch = selected_files[start_index:start_index + batch_size]

            # Split the batch between threads.
            threads = [
                threading.Thread(
                    target=generate_captions_for_batch,
                    args=(batch[i:i + batch_size_per_thread], save_directory, q),
                )
                for i in range(0, len(batch), batch_size_per_thread)
            ]
            for thread in threads:
                thread.start()

            # Wait for the current batch to finish before starting the next.
            for thread in threads:
                thread.join()
    except Exception as e:
        if not stop_processing:
            error_message = f"Error in worker thread: {e}"
            print(error_message)
            error_messages.append(error_message)
            q.put(error_message)
    finally:
        q.put(None)
590
+
591
def generate_captions_for_batch(batch, save_directory, q):
    """Thread target: caption each image in *batch*, one after another."""
    for path in batch:
        generate_caption(path, save_directory, q)
594
+
595
def update_progress():
    """Consume worker results from the queue and reflect them in the UI.

    Runs in a background thread until the ``None`` sentinel arrives.
    String items containing ``"Error"`` refresh the error counter; every
    other item counts as one processed file. All widget and Tk-variable
    updates are handed to the Tk thread with ``root.after(0, func, *args)``.
    The callable and its arguments are passed separately; the previous
    ``root.after(0, func(arg))`` form ran ``func`` immediately in this
    thread and scheduled ``None``.
    """
    try:
        completed = 0
        while True:
            item = q.get()
            if item is None:
                break
            if isinstance(item, str):
                if "Error" in item:
                    root.after(0, errors_var.set, f"Errors: {len(error_messages)}")
                    continue
            completed += 1
            # max() guards the division if the selection was emptied meanwhile.
            percent = int((completed / max(1, len(selected_files))) * 100)
            root.after(0, progress.set, percent)
            if not stop_processing:
                root.after(0, status_var.set, f"Processed {completed} files")
                root.after(0, root.update_idletasks)
        if not stop_processing:
            root.after(0, progress.set, 100)
            root.after(0, show_completion_message, completed)
    except Exception as e:
        if not stop_processing:
            root.after(0, status_var.set, f"Error: {e}")
    finally:
        # Re-enable the controls however the loop ended.
        root.after(0, toggle_buttons, True)
619
+
620
def show_completion_message(completed):
    """Show a summary dialog with the processed-file count and any error count."""
    summary = [f"Processing complete. {completed} files processed."]
    if error_messages:
        summary.append(f" {len(error_messages)} errors occurred.")
    messagebox.showinfo("Process Complete", "".join(summary))
625
+
626
def process_files():
    """Start captioning the selected files in background threads.

    Clears the previous run's error state, re-validates the selection and
    reads the thread-count / batch-size settings *before* disabling the
    UI. An empty or invalid entry now produces a status message instead
    of an exception that would leave every button disabled.
    """
    global stop_processing, error_messages
    stop_processing = False
    error_messages.clear()
    errors_var.set("Errors: 0")

    validate_selected_files()

    if not selected_files or not save_directory:
        status_var.set("Please select images.")
        return

    try:
        thread_count = thread_count_var.get()
        batch_size = batch_size_var.get()
    except (tk.TclError, ValueError):
        # IntVar.get() fails when the entry is empty or not a number.
        status_var.set("Thread count and batch size must be numbers.")
        return

    toggle_buttons(False)

    threading.Thread(target=worker, args=(save_directory, thread_count, batch_size)).start()
    threading.Thread(target=update_progress).start()
642
+
643
def stop_processing_func():
    """Ask the worker threads to stop, release cached GPU memory and update the status."""
    global stop_processing
    # Workers poll this flag between images and batches.
    stop_processing = True
    torch.cuda.empty_cache()
    status_var.set("Processing stopped.")
648
+
649
def open_caption_window():
    """Open the thumbnail / caption browser window (at most one at a time).

    Builds the search bar, the bulk-action bar, and a scrollable canvas
    that ``update_image_preview`` fills with thumbnails and editable
    captions. The application-wide mouse-wheel binding is installed once
    and removed again when the window is closed.
    """
    global caption_window, caption_frame, caption_text_widgets, current_page, total_pages, content_canvas
    if caption_window is not None:
        return

    validate_selected_files()

    caption_window = tk.Toplevel(root)
    caption_window.title("Image Thumbnails and Captions")
    caption_window.geometry("940x900")

    main_frame = tk.Frame(caption_window)
    main_frame.pack(fill=tk.BOTH, expand=True)

    # --- search bar ---
    search_frame = tk.Frame(main_frame)
    search_frame.pack(side=tk.TOP, fill=tk.X)

    search_entry = tk.Entry(search_frame, textvariable=search_var)
    search_entry.pack(side=tk.LEFT, padx=10, pady=5, fill=tk.X, expand=True)

    search_button = tk.Button(search_frame, text="Search", command=search_captions)
    search_button.pack(side=tk.LEFT, padx=10)

    reset_button = tk.Button(search_frame, text="Reset Order", command=reset_order)
    reset_button.pack(side=tk.LEFT, padx=10)

    # --- bulk keyword actions ---
    action_frame = tk.Frame(main_frame)
    action_frame.pack(side=tk.TOP, fill=tk.X)

    action_entry = tk.Entry(action_frame, textvariable=action_var)
    action_entry.pack(side=tk.LEFT, padx=10, pady=5, fill=tk.X, expand=True)

    prepend_button = tk.Button(action_frame, text="Add to Beginning", command=lambda: add_to_captions("prepend"))
    prepend_button.pack(side=tk.LEFT, padx=5)

    append_button = tk.Button(action_frame, text="Add to End", command=lambda: add_to_captions("append"))
    append_button.pack(side=tk.LEFT, padx=5)

    insert_middle_button = tk.Button(action_frame, text="Add to Middle", command=lambda: add_to_captions("insert_middle"))
    insert_middle_button.pack(side=tk.LEFT, padx=5)

    delete_keyword_button = tk.Button(action_frame, text="Delete Keyword", command=delete_keyword_from_captions)
    delete_keyword_button.pack(side=tk.LEFT, padx=5)

    delete_images_button = tk.Button(action_frame, text="Delete Images with Keyword", command=delete_images_with_keyword)
    delete_images_button.pack(side=tk.LEFT, padx=5)

    # --- scrollable content area ---
    content_canvas = tk.Canvas(main_frame)
    content_canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

    caption_frame = tk.Frame(content_canvas)
    content_canvas.create_window((0, 0), window=caption_frame, anchor='nw')

    caption_scrollbar = tk.Scrollbar(main_frame, orient="vertical", command=content_canvas.yview)
    caption_scrollbar.pack(side=tk.LEFT, fill=tk.Y)
    content_canvas.configure(yscrollcommand=caption_scrollbar.set)

    # Keep the scroll region in sync with the size of the inner frame.
    caption_frame.bind("<Configure>", lambda e: content_canvas.configure(scrollregion=content_canvas.bbox("all")))

    def on_mouse_wheel(event):
        """Scroll the canvas; ignore events that arrive after it is destroyed."""
        try:
            if content_canvas.winfo_exists():
                content_canvas.yview_scroll(int(-1*(event.delta/120)), "units")
        except tk.TclError:
            pass

    # Single application-wide binding (an earlier unguarded duplicate was
    # immediately overwritten by this one and has been dropped).
    content_canvas.bind_all("<MouseWheel>", on_mouse_wheel)

    def on_caption_window_close():
        """Remove the global wheel binding, destroy the window and allow reopening."""
        global caption_window
        caption_window.unbind_all("<MouseWheel>")
        caption_window.destroy()
        caption_window = None

    caption_window.protocol("WM_DELETE_WINDOW", on_caption_window_close)

    update_image_preview(content_canvas)
727
+
728
def update_image_preview(content_canvas):
    """Rebuild the current page of thumbnails, file names and editable captions.

    Clears ``caption_frame``, then shows up to ``images_per_page`` files
    starting at ``current_page``, followed by the page navigation row.
    Edits in a caption box are saved through ``save_caption`` on key
    release. Does nothing if the browser has never been opened; shows an
    empty frame when no files are selected.

    Image files are opened in a ``with`` block so their handles are
    released as soon as the thumbnail has been built.
    """
    global thumbnails, caption_text_widgets, current_page, images_per_page, total_pages
    if caption_frame is None:
        return

    for widget in caption_frame.winfo_children():
        if isinstance(widget, (tk.Label, tk.Text, tk.Frame)):
            widget.destroy()

    thumbnails.clear()
    caption_text_widgets.clear()

    if not selected_files:
        return

    start_index = current_page * images_per_page
    end_index = start_index + images_per_page
    files_to_display = selected_files[start_index:end_index]

    thumbnail_size = (200, 200)
    for i, file_path in enumerate(files_to_display):
        try:
            with PILImage.open(file_path) as image:
                image.thumbnail(thumbnail_size)
                thumbnail = ImageTk.PhotoImage(image)
            # Keep a reference: Tk does not, and the image would vanish.
            thumbnails.append(thumbnail)

            img_label = tk.Label(caption_frame, image=thumbnail)
            img_label.grid(row=i*2, column=0, padx=5, pady=5, sticky="nsew")

            file_label = tk.Label(caption_frame, text=os.path.basename(file_path), font=('Helvetica', 12), wraplength=300, justify="left")
            file_label.grid(row=i*2, column=1, padx=5, pady=5, sticky="nsew")

            caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
            if os.path.exists(caption_file):
                with open(caption_file, 'r', encoding='utf-8') as file:
                    caption_text = file.read()
            else:
                caption_text = ""

            caption_var = tk.StringVar(value=caption_text)

            caption_text_widget = tk.Text(caption_frame, width=50, height=3, wrap=tk.WORD, font=('Helvetica', 12))
            caption_text_widget.insert(tk.END, caption_text)
            caption_text_widget.grid(row=i*2, column=2, padx=5, pady=5, sticky="nsew")

            # Key release -> copy text into the StringVar -> its write trace
            # saves the caption. Defaults bind the per-row objects.
            caption_var.trace_add("write", lambda *args, fp=file_path, cv=caption_var: save_caption(fp, cv.get()))

            caption_text_widget.bind("<KeyRelease>", lambda e, cv=caption_var, w=caption_text_widget: cv.set(w.get("1.0", "end-1c")))
            caption_text_widgets.append(caption_text_widget)

        except Exception:
            # Best effort: one unreadable file must not break the whole page.
            tk.Label(caption_frame, text="Error loading image").grid(row=i*2, column=0, columnspan=4, padx=5, pady=5)

    # Navigation row sits below the last possible image row.
    nav_frame = tk.Frame(caption_frame)
    nav_frame.grid(row=images_per_page*2, column=0, columnspan=3, pady=10)

    if current_page > 0:
        prev_button = tk.Button(nav_frame, text="Previous", command=lambda: navigate(-1, content_canvas))
        prev_button.pack(side=tk.LEFT)

    page_label = tk.Label(nav_frame, text=f"Page {current_page + 1} of {total_pages}")
    page_label.pack(side=tk.LEFT, padx=5)

    page_entry = tk.Entry(nav_frame, width=5)
    page_entry.pack(side=tk.LEFT)

    go_button = tk.Button(nav_frame, text="Go", command=lambda: go_to_page(page_entry.get(), content_canvas))
    go_button.pack(side=tk.LEFT, padx=5)

    if current_page < total_pages - 1:
        next_button = tk.Button(nav_frame, text="Next", command=lambda: navigate(1, content_canvas))
        next_button.pack(side=tk.RIGHT)
801
+
802
def navigate(direction, content_canvas):
    """Move the caption browser by ``direction`` pages and redraw it.

    The resulting page index is clamped to ``[0, total_pages - 1]`` so that a
    stray call cannot leave the browser on a page that does not exist.

    Args:
        direction: Signed number of pages to move (the buttons pass -1 or 1).
        content_canvas: Canvas object forwarded to ``update_image_preview``.
    """
    global current_page
    last_page = max(total_pages - 1, 0)
    current_page = min(max(current_page + direction, 0), last_page)
    update_image_preview(content_canvas)
806
+
807
def go_to_page(page_number, content_canvas):
    """Jump to the 1-based page typed by the user and redraw the browser.

    An error dialog is shown instead when the value is not an integer or lies
    outside ``1..total_pages``.

    Args:
        page_number: Raw text taken from the page entry field.
        content_canvas: Canvas object forwarded to ``update_image_preview``.
    """
    global current_page, total_pages
    try:
        requested = int(page_number)
        if requested < 1 or requested > total_pages:
            messagebox.showerror("Invalid Page", f"Please enter a valid page number between 1 and {total_pages}.")
            return
        # Pages are shown 1-based but stored 0-based.
        current_page = requested - 1
        update_image_preview(content_canvas)
    except ValueError:
        messagebox.showerror("Invalid Input", "Please enter a valid integer for the page number.")
818
+
819
def save_caption(file_path, caption_text):
    """Write the stripped caption to the caption file paired with an image.

    The target is ``<save_directory>/<image file name>_caption.txt``. Failures
    are reported on stdout and otherwise ignored.

    Args:
        file_path: Path of the image the caption belongs to.
        caption_text: Caption text to store.
    """
    caption_name = f"{os.path.basename(file_path)}_caption.txt"
    output_path = os.path.join(save_directory, caption_name)
    try:
        with open(output_path, 'w', encoding='utf-8') as handle:
            handle.write(caption_text.strip())
    except Exception as exc:
        print(f"Error saving captions: {exc}")
826
+
827
def search_captions():
    """Reorder ``selected_files`` so the best caption matches come first.

    The lower-cased, stripped search term is scored against every caption with
    ``search_score`` and the list is sorted in place by descending score. An
    empty term leaves everything untouched. The preview is redrawn afterwards,
    even when sorting failed.
    """
    global selected_files
    search_term = search_var.get().lower().strip()
    if search_term:
        def relevance(path):
            return search_score(path, search_term)

        try:
            selected_files.sort(key=relevance, reverse=True)
        except Exception as exc:
            error_message = f"Error during sorting: {exc}"
            print(error_message)
            error_messages.append(error_message)

        update_image_preview(content_canvas)
841
+
842
def search_score(file_path, search_term):
    """Count how often ``search_term`` occurs in the caption of an image.

    Args:
        file_path: Path of the image whose caption file is inspected.
        search_term: Term to look for; compared against the lower-cased caption.

    Returns:
        Number of occurrences, or 0 when the caption file is missing or cannot
        be read (read errors are printed and appended to ``error_messages``).
    """
    caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
    try:
        if not os.path.exists(caption_file):
            return 0
        with open(caption_file, 'r', encoding='utf-8') as handle:
            lowered = handle.read().lower()
        # count() already yields 0 when the term is absent.
        return lowered.count(search_term)
    except Exception as exc:
        error_message = f"Error reading file {caption_file}: {exc}"
        print(error_message)
        error_messages.append(error_message)
    return 0
856
+
857
def reset_order():
    """Restore the file list from ``original_selected_files`` and redraw."""
    global selected_files
    # Work on a fresh list so later in-place sorting leaves the saved order intact.
    selected_files = list(original_selected_files)
    update_image_preview(content_canvas)
861
+
862
def add_to_captions(position):
    """Insert the keyword from ``action_var`` into every existing caption file.

    Files without a caption file are skipped. A file that cannot be read or
    written is reported through ``error_messages`` and does not stop the
    remaining files from being processed. The preview is redrawn afterwards.

    Args:
        position: ``"prepend"``, ``"append"`` or ``"insert_middle"``; any other
            value leaves the caption text unchanged.
    """
    global selected_files
    keyword = action_var.get()
    if not keyword:
        return

    for file_path in selected_files:
        caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
        if not os.path.exists(caption_file):
            continue
        try:
            with open(caption_file, 'r+', encoding='utf-8') as file:
                caption_text = _insert_keyword(file.read(), keyword, position)
                # Rewrite in place: rewind, write, then cut off any leftover tail.
                file.seek(0)
                file.write(caption_text)
                file.truncate()
        except (OSError, UnicodeDecodeError) as e:
            error_message = f"Error updating caption {caption_file}: {e}"
            print(error_message)
            error_messages.append(error_message)

    update_image_preview(content_canvas)


def _insert_keyword(caption_text, keyword, position):
    """Return ``caption_text`` with ``keyword`` added at ``position``.

    For ``"insert_middle"`` the keyword goes at the space nearest to the middle
    of the caption so that no word is cut in half; when the caption contains no
    space the exact character midpoint is used.
    """
    if position == "prepend":
        return f"{keyword} {caption_text}"
    if position == "append":
        return f"{caption_text} {keyword}"
    if position == "insert_middle":
        midpoint = len(caption_text) // 2
        nearby_spaces = [
            index
            for index in (caption_text.rfind(" ", 0, midpoint + 1), caption_text.find(" ", midpoint))
            if index != -1
        ]
        if not nearby_spaces:
            return f"{caption_text[:midpoint]} {keyword} {caption_text[midpoint:]}"
        split_at = min(nearby_spaces, key=lambda index: abs(index - midpoint))
        # Trim around the split so the keyword is separated by single spaces.
        return f"{caption_text[:split_at].rstrip()} {keyword} {caption_text[split_at:].lstrip()}"
    return caption_text
885
+
886
def delete_keyword_from_captions():
    """Remove the keyword from ``action_var`` from every existing caption file.

    Matching is case-insensitive, and the rest of the caption keeps its
    original casing. Surrounding whitespace of the result is stripped. A file
    that cannot be read or written is reported through ``error_messages`` and
    does not stop the remaining files from being processed. The preview is
    redrawn afterwards.
    """
    import re  # local import: only this action needs regular expressions

    keyword = action_var.get().strip()
    if not keyword:
        return

    # Escape the keyword so it is matched literally, not as a regex.
    pattern = re.compile(re.escape(keyword), re.IGNORECASE)

    for file_path in selected_files:
        caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
        if not os.path.exists(caption_file):
            continue
        try:
            with open(caption_file, 'r+', encoding='utf-8') as file:
                updated_caption = pattern.sub("", file.read()).strip()
                # Rewrite in place: rewind, write, then cut off the leftover tail.
                file.seek(0)
                file.write(updated_caption)
                file.truncate()
        except (OSError, UnicodeDecodeError) as e:
            error_message = f"Error updating caption {caption_file}: {e}"
            print(error_message)
            error_messages.append(error_message)

    update_image_preview(content_canvas)
904
+
905
def delete_images_with_keyword():
    """Delete every image (and its caption file) whose caption has the keyword.

    The keyword is read from ``action_var``, lower-cased and stripped; a blank
    or whitespace-only keyword is ignored so that it cannot match (and delete)
    nearly every captioned image. Only images that were actually removed from
    disk are dropped from ``selected_files``. Read and delete failures are
    reported through ``error_messages``. Afterwards the selection is
    re-validated and the preview redrawn.
    """
    global selected_files
    keyword = action_var.get().lower().strip()
    if not keyword:
        return

    files_to_delete = []
    for file_path in selected_files:
        caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
        if not os.path.exists(caption_file):
            continue
        try:
            with open(caption_file, 'r', encoding='utf-8') as file:
                caption_text = file.read().lower()
        except (OSError, UnicodeDecodeError) as e:
            error_message = f"Error reading file {caption_file}: {e}"
            print(error_message)
            error_messages.append(error_message)
            continue
        if keyword in caption_text:
            files_to_delete.append(file_path)

    deleted = set()
    for file_path in files_to_delete:
        try:
            os.remove(file_path)
            # The image is gone from disk even if the caption removal below fails.
            deleted.add(file_path)
            caption_file = os.path.join(save_directory, f"{os.path.basename(file_path)}_caption.txt")
            if os.path.exists(caption_file):
                os.remove(caption_file)
        except OSError as e:
            error_message = f"Error deleting file {file_path} or its caption: {e}"
            print(error_message)
            error_messages.append(error_message)

    selected_files = [file_path for file_path in selected_files if file_path not in deleted]

    validate_selected_files()

    update_image_preview(content_canvas)
936
+
937
def return_to_menu():
    """Leave this tool and hand control back to the main menu.

    Calls ``stop_processing_func``, destroys the root window and then opens
    the menu provided by the ``main`` module.
    """
    stop_processing_func()
    root.destroy()
    # Imported at call time rather than at module load
    # (presumably to avoid a circular import with ``main`` -- TODO confirm).
    import main as main_menu
    main_menu.open_main_menu()
942
+
943
def on_closing():
    """Window-close handler: behaves exactly like returning to the menu."""
    return_to_menu()
945
+
946
if __name__ == "__main__":
    # Launch the captioning tool directly when the file is run as a script.
    open_image_to_caption()