John6666 committed
Commit 625b8ab · verified · 1 Parent(s): 394e8ba

Upload app.py

Files changed (1)
  1. app.py +80 -79
app.py CHANGED
@@ -206,6 +206,7 @@ def load_models():
     print("Loading CLIP 📎")
     clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
     clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model
+    assert (CHECKPOINT_PATH / "clip_model.pt").exists()
     if (CHECKPOINT_PATH / "clip_model.pt").exists():
         print("Loading VLM's custom vision model 📎")
         checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu', weights_only=False)
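
The only change to load_models() is the new assert: previously, a missing clip_model.pt simply skipped the if branch and left the stock CLIP_PATH vision weights in place; now loading fails immediately. A minimal sketch of the same fail-fast guard as a standalone helper (the helper name and message text are illustrative, not part of the Space):

from pathlib import Path

def require_vision_checkpoint(checkpoint_path: Path) -> Path:
    # Fail fast with a readable message instead of silently keeping the
    # stock vision weights when the finetuned checkpoint is missing.
    clip_ckpt = checkpoint_path / "clip_model.pt"
    assert clip_ckpt.exists(), f"Expected finetuned vision weights at {clip_ckpt}"
    return clip_ckpt
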
@@ -312,88 +313,88 @@ def stream_chat(input_images: List[Image.Image], caption_type: str, caption_leng
 
     for i in range(0, len(input_images), batch_size):
         batch = input_images[i:i+batch_size]
-        try:
-            # Preprocess image
-            # NOTE: I found the default processor for so400M to have worse results than just using PIL directly
-            #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
-            all_images = []
-            for input_image in batch:
+
+        for input_image in input_images:
+            try:
+                # Preprocess image
+                # NOTE: I found the default processor for so400M to have worse results than just using PIL directly
+                #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
                 image = input_image.resize((384, 384), Image.LANCZOS)
                 pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
                 pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
-                all_images.append(TVF.to_pil_image(pixel_values.squeeze()))
-            batch_pixel_values = clip_processor(images=all_images, return_tensors='pt', padding=True).pixel_values.to(device)
-        except ValueError as e:
-            print(f"Error processing image batch: {e}")
-            print("Skipping this batch and continuing...")
-            continue
-
-        # Embed image
-        # This results in Batch x Image Tokens x Features
-        with torch.amp.autocast_mode.autocast(device, enabled=True):
-            vision_outputs = clip_model(pixel_values=batch_pixel_values, output_hidden_states=True)
-            image_features = vision_outputs.hidden_states
-            embedded_images = image_adapter(image_features).to(device)
-
-        # Build the conversation
-        convo = [
-            {
-                "role": "system",
-                "content": "You are a helpful image captioner.",
-            },
-            {
-                "role": "user",
-                "content": prompt_str,
-            },
-        ]
-
-        # Format the conversation
-        convo_string = tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = True)
-        assert isinstance(convo_string, str)
-
-        # Tokenize the conversation
-        # prompt_str is tokenized separately so we can do the calculations below
-        convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False)
-        prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False)
-        assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
-        convo_tokens = convo_tokens.squeeze(0)   # Squeeze just to make the following easier
-        prompt_tokens = prompt_tokens.squeeze(0)
-
-        # Calculate where to inject the image
-        eot_id_indices = (convo_tokens == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
-        assert len(eot_id_indices) == 2, f"Expected 2 <|eot_id|> tokens, got {len(eot_id_indices)}"
-
-        preamble_len = eot_id_indices[1] - prompt_tokens.shape[0]   # Number of tokens before the prompt
-
-        # Embed the tokens
-        convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))
-
-        # Construct the input
-        input_embeds = torch.cat([
-            convo_embeds[:, :preamble_len],   # Part before the prompt
-            embedded_images.to(dtype=convo_embeds.dtype),   # Image
-            convo_embeds[:, preamble_len:],   # The prompt and anything after it
-        ], dim=1).to(device)
-
-        input_ids = torch.cat([
-            convo_tokens[:preamble_len].unsqueeze(0),
-            torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),   # Dummy tokens for the image (TODO: Should probably use a special token here so as not to confuse any generation algorithms that might be inspecting the input)
-            convo_tokens[preamble_len:].unsqueeze(0),
-        ], dim=1).to(device)
-        attention_mask = torch.ones_like(input_ids)
-
-        # Debugging
-        #print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")
-
-        generate_ids = text_model.generate(input_ids=input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, do_sample=True,
-                                           suppress_tokens=None, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature)
-
-        generate_ids = generate_ids[:, input_ids.shape[1]:]
-
-        for ids in generate_ids:
-            caption = tokenizer.decode(ids[:-1] if ids[-1] == tokenizer.eos_token_id else ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
-            caption = caption.replace('<|end_of_text|>', '').replace('<|finetune_right_pad_id|>', '').strip()
-            all_captions.append(caption)
+                pixel_values = pixel_values.to(device)
+            except ValueError as e:
+                print(f"Error processing image: {e}")
+                print("Skipping this image and continuing...")
+                continue
+
+            # Embed image
+            # This results in Batch x Image Tokens x Features
+            with torch.amp.autocast_mode.autocast(device, enabled=True):
+                vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
+                image_features = vision_outputs.hidden_states
+                embedded_images = image_adapter(image_features).to(device)
+
+            # Build the conversation
+            convo = [
+                {
+                    "role": "system",
+                    "content": "You are a helpful image captioner.",
+                },
+                {
+                    "role": "user",
+                    "content": prompt_str,
+                },
+            ]
+
+            # Format the conversation
+            convo_string = tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = True)
+            assert isinstance(convo_string, str)
+
+            # Tokenize the conversation
+            # prompt_str is tokenized separately so we can do the calculations below
+            convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False)
+            prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False)
+            assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
+            convo_tokens = convo_tokens.squeeze(0)   # Squeeze just to make the following easier
+            prompt_tokens = prompt_tokens.squeeze(0)
+
+            # Calculate where to inject the image
+            eot_id_indices = (convo_tokens == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
+            assert len(eot_id_indices) == 2, f"Expected 2 <|eot_id|> tokens, got {len(eot_id_indices)}"
+
+            preamble_len = eot_id_indices[1] - prompt_tokens.shape[0]   # Number of tokens before the prompt
+
+            # Embed the tokens
+            convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))
+
+            # Construct the input
+            input_embeds = torch.cat([
+                convo_embeds[:, :preamble_len],   # Part before the prompt
+                embedded_images.to(dtype=convo_embeds.dtype),   # Image
+                convo_embeds[:, preamble_len:],   # The prompt and anything after it
+            ], dim=1).to(device)
+
+            input_ids = torch.cat([
+                convo_tokens[:preamble_len].unsqueeze(0),
+                torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),   # Dummy tokens for the image (TODO: Should probably use a special token here so as not to confuse any generation algorithms that might be inspecting the input)
+                convo_tokens[preamble_len:].unsqueeze(0),
+            ], dim=1).to(device)
+            attention_mask = torch.ones_like(input_ids)
+
+            # Debugging
+            #print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")
+
+            generate_ids = text_model.generate(input_ids=input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, do_sample=True,
+                                               suppress_tokens=None, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature)
+
+            # Trim off the prompt
+            generate_ids = generate_ids[:, input_ids.shape[1]:]
+            if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
+                generate_ids = generate_ids[:, :-1]
+
+            caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
+            all_captions.append(caption.strip())
 
     if pbar:
         pbar.update(len(batch))
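
The rewritten loop drops the clip_processor batching and preprocesses each PIL image directly: resize to 384×384 with LANCZOS, scale to [0, 1], normalize with mean 0.5 and std 0.5, and move the resulting 1×3×384×384 tensor to the device. A self-contained sketch of that path (the helper name is illustrative, not from the Space):

from PIL import Image
import torch
import torchvision.transforms.functional as TVF

def preprocess_image(input_image: Image.Image, device: torch.device) -> torch.Tensor:
    # Resize to the 384x384 input expected by the so400m vision tower.
    image = input_image.resize((384, 384), Image.LANCZOS)
    # uint8 CxHxW -> float 1x3x384x384 in [0, 1], then roughly [-1, 1].
    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
    return pixel_values.to(device)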
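
For context on the injection math that this commit leaves unchanged: with the Llama-3-style chat template used here, the second <|eot_id|> closes the user turn, so eot_id_indices[1] - prompt_tokens.shape[0] is the index at which the user prompt begins, and the image embeddings are spliced in just before it, between the preamble (system turn plus user header) and the prompt. A small hypothetical check of that invariant, assuming the prompt tokenizes identically standalone and inside the template:

import torch

def check_injection_point(convo_tokens: torch.Tensor, prompt_tokens: torch.Tensor,
                          second_eot_index: int) -> int:
    # preamble_len is where the user prompt starts in the tokenized conversation;
    # the image embeddings are inserted immediately before this position.
    preamble_len = second_eot_index - prompt_tokens.shape[0]
    # Only holds if the prompt tokenizes the same on its own as inside the template.
    assert torch.equal(convo_tokens[preamble_len:second_eot_index], prompt_tokens)
    return preamble_len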
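
On the output side, the per-image path trims the echoed prompt, strips a single trailing eos or <|eot_id|> token, and decodes with batch_decode rather than the old per-sequence decode followed by string replacement of '<|end_of_text|>' and '<|finetune_right_pad_id|>'. A sketch of that decode step as a standalone helper (the function name is hypothetical):

import torch

def trim_and_decode(generate_ids: torch.Tensor, input_len: int, tokenizer) -> str:
    # Drop the prompt tokens that generate() returns ahead of the new tokens.
    generate_ids = generate_ids[:, input_len:]
    # Strip one trailing eos / <|eot_id|> token, mirroring the new code path.
    last_token = generate_ids[0, -1].item()
    if last_token in (tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")):
        generate_ids = generate_ids[:, :-1]
    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False,
                                     clean_up_tokenization_spaces=False)[0]
    return caption.strip()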