Wi-zz commited on
Commit
6b99536
·
verified ·
1 Parent(s): 0bc9e8b

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. README.md +54 -0
  2. app.py +259 -0
  3. requirements.txt +8 -0
  4. wpkklhc6/config.yaml +32 -0
  5. wpkklhc6/image_adapter.pt +3 -0
README.md ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Here's a concise, structured, and aesthetically formatted markdown description for your GitHub repo README:
2
+
3
+ # Image Captioning App
4
+
5
+ ## Overview
6
+
7
+ This application generates descriptive captions for images using advanced ML models. It processes single images or entire directories, leveraging CLIP and LLM models for accurate and contextual captions. It has NSFW captioning support with natural language.
8
+
9
+ ## Features
10
+
11
+ - Single image and batch processing
12
+ - Multiple directory support
13
+ - Custom output directory
14
+ - Adjustable batch size
15
+ - Progress tracking
16
+
17
+ ## Usage
18
+
19
+ | Command | Description |
20
+ |---------|-------------|
21
+ | `python app.py image.jpg` | Process a single image |
22
+ | `python app.py /path/to/directory` | Process all images in a directory |
23
+ | `python app.py /path/to/dir1 /path/to/dir2` | Process multiple directories |
24
+ | `python app.py /path/to/dir --output /path/to/output` | Specify output directory |
25
+ | `python app.py /path/to/dir --bs 8` | Set batch size (default: 4) |
26
+
27
+ ## Technical Details
28
+
29
+ - **Models**: CLIP (vision), LLM (language), custom ImageAdapter
30
+ - **Optimization**: CUDA-enabled GPU support
31
+ - **Error Handling**: Skips problematic images in batch processing
32
+
33
+ ## Requirements
34
+
35
+ - Python 3.x
36
+ - PyTorch
37
+ - Transformers library
38
+ - CUDA-capable GPU (recommended)
39
+
40
+ ## Installation
41
+
42
+ ```bash
43
+ git clone https://huggingface.co/Wi-zz/joy-caption-pre-alpha
44
+ cd joy-caption-pre-alpha
45
+ pip install -r requirements.txt
46
+ ```
47
+
48
+ ## Contributing
49
+
50
+ Contributions are welcome! Please feel free to submit a Pull Request.
51
+
52
+ ## License
53
+
54
+ This project is licensed under the [MIT License](LICENSE).
app.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # # For a single image
2
+ # python app.py image.jpg
3
+
4
+ # # For a single directory
5
+ # python app.py /path/to/directory
6
+
7
+ # # For multiple directories
8
+ # python app.py /path/to/directory1 /path/to/directory2 /path/to/directory3
9
+
10
+ # # With output directory specified
11
+ # python app.py /path/to/directory1 /path/to/directory2 --output /path/to/output
12
+
13
+ # # With batch size specified
14
+ # python app.py /path/to/directory1 /path/to/directory2 --bs 8
15
+
16
+ import torch
17
+ import torch.amp.autocast_mode
18
+ import os
19
+ import sys
20
+ import logging
21
+ import warnings
22
+ import argparse
23
+ from PIL import Image
24
+ from pathlib import Path
25
+ from tqdm import tqdm
26
+ from torch import nn
27
+ from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
28
+ from typing import List
29
+
30
+ CLIP_PATH = "google/siglip-so400m-patch14-384"
31
+ VLM_PROMPT = "A descriptive caption for this image:\n"
32
+ MODEL_PATH = "unsloth/Meta-Llama-3.1-8B-bnb-4bit"
33
+ CHECKPOINT_PATH = Path("wpkklhc6")
34
+ warnings.filterwarnings("ignore", category=UserWarning)
35
+
36
+ class ImageAdapter(nn.Module):
37
+ def __init__(self, input_features: int, output_features: int):
38
+ super().__init__()
39
+ self.linear1 = nn.Linear(input_features, output_features)
40
+ self.activation = nn.GELU()
41
+ self.linear2 = nn.Linear(output_features, output_features)
42
+
43
+ def forward(self, vision_outputs: torch.Tensor):
44
+ x = self.linear1(vision_outputs)
45
+ x = self.activation(x)
46
+ x = self.linear2(x)
47
+ return x
48
+
49
+ # Load CLIP
50
+ print("Loading CLIP 📎")
51
+ clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
52
+ clip_model = AutoModel.from_pretrained(CLIP_PATH)
53
+ clip_model = clip_model.vision_model
54
+ clip_model.eval()
55
+ clip_model.requires_grad_(False)
56
+ clip_model.to("cuda")
57
+
58
+ # Tokenizer
59
+ print("Loading tokenizer 🪙")
60
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
61
+ assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"
62
+
63
+ # LLM
64
+ print("Loading LLM 🤖")
65
+ logging.getLogger("transformers").setLevel(logging.ERROR)
66
+ text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
67
+ text_model.eval()
68
+
69
+ # Image Adapter
70
+ print("Loading image adapter 🖼️")
71
+ image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size)
72
+ image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu", weights_only=True))
73
+ image_adapter.eval()
74
+ image_adapter.to("cuda")
75
+
76
+ @torch.no_grad()
77
+ def stream_chat(input_images: List[Image.Image], batch_size=4, pbar=None):
78
+ torch.cuda.empty_cache()
79
+ all_captions = []
80
+
81
+ if not isinstance(input_images, list):
82
+ input_images = [input_images]
83
+
84
+ for i in range(0, len(input_images), batch_size):
85
+ batch = input_images[i:i+batch_size]
86
+
87
+ # Preprocess image batch
88
+ try:
89
+ images = clip_processor(images=batch, return_tensors='pt', padding=True).pixel_values
90
+ except ValueError as e:
91
+ print(f"Error processing image batch: {e}")
92
+ print("Skipping this batch and continuing...")
93
+ continue
94
+
95
+ images = images.to('cuda')
96
+
97
+ # Embed image batch
98
+ with torch.amp.autocast_mode.autocast('cuda', enabled=True):
99
+ vision_outputs = clip_model(pixel_values=images, output_hidden_states=True)
100
+ image_features = vision_outputs.hidden_states[-2]
101
+ embedded_images = image_adapter(image_features)
102
+ embedded_images = embedded_images.to(dtype=torch.bfloat16)
103
+
104
+ # Embed prompt
105
+ prompt = tokenizer.encode(VLM_PROMPT, return_tensors='pt')
106
+ prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda')).to(dtype=torch.bfloat16)
107
+ embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64)).to(dtype=torch.bfloat16)
108
+
109
+ # Construct prompts
110
+ inputs_embeds = torch.cat([
111
+ embedded_bos.expand(embedded_images.shape[0], -1, -1),
112
+ embedded_images,
113
+ prompt_embeds.expand(embedded_images.shape[0], -1, -1),
114
+ ], dim=1).to(dtype=torch.bfloat16)
115
+
116
+ input_ids = torch.cat([
117
+ torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).expand(embedded_images.shape[0], -1),
118
+ torch.zeros((embedded_images.shape[0], embedded_images.shape[1]), dtype=torch.long),
119
+ prompt.expand(embedded_images.shape[0], -1),
120
+ ], dim=1).to('cuda')
121
+
122
+ attention_mask = torch.ones_like(input_ids)
123
+
124
+ generate_ids = text_model.generate(
125
+ input_ids=input_ids,
126
+ inputs_embeds=inputs_embeds,
127
+ attention_mask=attention_mask,
128
+ max_new_tokens=300,
129
+ do_sample=True,
130
+ top_k=10,
131
+ temperature=0.5,
132
+ )
133
+
134
+ if pbar:
135
+ pbar.update(len(batch))
136
+
137
+ # Trim off the prompt
138
+ generate_ids = generate_ids[:, input_ids.shape[1]:]
139
+
140
+ for ids in generate_ids:
141
+ if ids[-1] == tokenizer.eos_token_id:
142
+ ids = ids[:-1]
143
+ caption = tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
144
+ # Remove any remaining special tokens
145
+ caption = caption.replace('<|end_of_text|>', '').replace('<|finetune_right_pad_id|>', '').strip()
146
+ all_captions.append(caption)
147
+
148
+ return all_captions
149
+
150
+ def preprocess_image(img):
151
+ return img.convert('RGBA')
152
+
153
+ def process_image(image_path, output_path, pbar=None):
154
+ try:
155
+ with Image.open(image_path) as img:
156
+ # Convert image to RGB
157
+ img = img.convert('RGB')
158
+ caption = stream_chat([img], pbar=pbar)[0]
159
+ with open(output_path, 'w', encoding='utf-8') as f:
160
+ f.write(caption)
161
+ except Exception as e:
162
+ print(f"Error processing {image_path}: {e}")
163
+ if pbar:
164
+ pbar.update(1)
165
+ return
166
+
167
+ with Image.open(image_path) as img:
168
+ # Pass the image as a list to stream_chat
169
+ caption = stream_chat([img], pbar=pbar)[0] # Get the first (and only) caption
170
+
171
+ with open(output_path, 'w', encoding='utf-8') as f:
172
+ f.write(caption)
173
+
174
+ def process_directory(input_dir, output_dir, batch_size):
175
+ input_path = Path(input_dir)
176
+ output_path = Path(output_dir)
177
+ output_path.mkdir(parents=True, exist_ok=True)
178
+
179
+ image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
180
+ image_files = [f for f in input_path.iterdir() if f.suffix.lower() in image_extensions]
181
+
182
+ # Create a list to store images that need processing
183
+ images_to_process = []
184
+
185
+ # Check which images need processing
186
+ for file in image_files:
187
+ output_file = output_path / (file.stem + '.txt')
188
+ if not output_file.exists():
189
+ images_to_process.append(file)
190
+ else:
191
+ print(f"Skipping {file.name} - Caption already exists")
192
+
193
+ # Process images in batches
194
+ with tqdm(total=len(images_to_process), desc="Processing images", unit="image") as pbar:
195
+ for i in range(0, len(images_to_process), batch_size):
196
+ batch_files = images_to_process[i:i+batch_size]
197
+ batch_images = []
198
+ for f in batch_files:
199
+ try:
200
+ img = Image.open(f).convert('RGB')
201
+ batch_images.append(img)
202
+ except Exception as e:
203
+ print(f"Error opening {f}: {e}")
204
+ continue
205
+
206
+ if batch_images:
207
+ captions = stream_chat(batch_images, batch_size, pbar)
208
+ for file, caption in zip(batch_files, captions):
209
+ output_file = output_path / (file.stem + '.txt')
210
+ with open(output_file, 'w', encoding='utf-8') as f:
211
+ f.write(caption)
212
+
213
+ # Close the image files
214
+ for img in batch_images:
215
+ img.close()
216
+
217
+ def parse_arguments():
218
+ parser = argparse.ArgumentParser(description="Process images and generate captions.")
219
+ parser.add_argument("input", nargs='+', help="Input image file or directory (or multiple directories)")
220
+ parser.add_argument("--output", help="Output directory (optional)")
221
+ parser.add_argument("--bs", type=int, default=4, help="Batch size (default: 4)")
222
+ return parser.parse_args()
223
+
224
+ def is_image_file(file_path):
225
+ image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
226
+ return Path(file_path).suffix.lower() in image_extensions
227
+
228
+ # Main execution
229
+ if __name__ == "__main__":
230
+ args = parse_arguments()
231
+ input_paths = [Path(input_path) for input_path in args.input]
232
+ batch_size = args.bs
233
+
234
+ for input_path in input_paths:
235
+ if input_path.is_file() and is_image_file(input_path):
236
+ # Single file processing
237
+ output_path = input_path.with_suffix('.txt')
238
+ print(f"Processing single image 🎞️: {input_path.name}")
239
+ with tqdm(total=1, desc="Processing image", unit="image") as pbar:
240
+ process_image(input_path, output_path, pbar)
241
+ print(f"Output saved to {output_path}")
242
+ elif input_path.is_dir():
243
+ # Directory processing
244
+ output_path = Path(args.output) if args.output else input_path
245
+ print(f"Processing directory 📁: {input_path}")
246
+ print(f"Output directory 📦: {output_path}")
247
+ print(f"Batch size 🗄️: {batch_size}")
248
+ process_directory(input_path, output_path, batch_size)
249
+ else:
250
+ print(f"Invalid input: {input_path}")
251
+ print("Skipping...")
252
+
253
+ if not input_paths:
254
+ print("Usage:")
255
+ print("For single image: python app.py [image_file] [--bs batch_size]")
256
+ print("For directory (same input/output): python app.py [directory] [--bs batch_size]")
257
+ print("For directory (separate input/output): python app.py [directory] --output [output_directory] [--bs batch_size]")
258
+ print("For multiple directories: python app.py [directory1] [directory2] ... [--output output_directory] [--bs batch_size]")
259
+ sys.exit(1)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ huggingface_hub==0.24.3
2
+ accelerate
3
+ torch
4
+ transformers==4.43.3
5
+ sentencepiece
6
+ bitsandbytes
7
+ Pillow
8
+ protobuf
wpkklhc6/config.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_project: joy-caption-1
2
+ device_batch_size: 2
3
+ batch_size: 256
4
+ learning_rate: 0.001
5
+ warmup_samples: 18000
6
+ max_samples: 600000
7
+ save_every: 50000
8
+ test_every: 50000
9
+ use_amp: true
10
+ grad_scaler: true
11
+ lr_scheduler_type: cosine
12
+ min_lr_ratio: 0.0
13
+ allow_tf32: true
14
+ seed: 42
15
+ num_workers: 8
16
+ optimizer_type: adamw
17
+ adam_beta1: 0.9
18
+ adam_beta2: 0.999
19
+ adam_eps: 1.0e-08
20
+ adam_weight_decay: 0.0
21
+ clip_grad_norm: 1.0
22
+ dataset: fancyfeast/joy-captioning-20240729a
23
+ clip_model: google/siglip-so400m-patch14-384
24
+ text_model: meta-llama/Meta-Llama-3.1-8B
25
+ resume: null
26
+ gradient_checkpointing: false
27
+ test_size: 2048
28
+ grad_scaler_init: 65536.0
29
+ max_caption_length: 257
30
+ num_image_tokens: 32
31
+ adapter_type: mlp
32
+ text_model_dtype: float16
wpkklhc6/image_adapter.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ebb1d1437bbb3264a6f25a896b25a7c7dd06c570c5de909dc2f19d3a5c5c110
3
+ size 86018240