Upload app.py
app.py CHANGED
@@ -206,6 +206,7 @@ def load_models():
     print("Loading CLIP 📎")
     clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
     clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model
+    assert (CHECKPOINT_PATH / "clip_model.pt").exists()
     if (CHECKPOINT_PATH / "clip_model.pt").exists():
         print("Loading VLM's custom vision model 📎")
         checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu', weights_only=False)
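For context, a minimal sketch of the load pattern this hunk touches: the stock vision tower is pulled from the hub and then, now that the added assert guarantees the file is present, overwritten from clip_model.pt. The model id, the Path("checkpoint") location, and the assumption that the .pt file holds a plain state dict are illustrative guesses, not taken from app.py (the diff cuts off right after the torch.load line).

from pathlib import Path

import torch
from transformers import AutoModel

CLIP_PATH = "google/siglip-so400m-patch14-384"   # assumed id; the NOTE in the next hunk mentions so400M
CHECKPOINT_PATH = Path("checkpoint")             # hypothetical location of clip_model.pt

clip_model = AutoModel.from_pretrained(CLIP_PATH).vision_model

ckpt_file = CHECKPOINT_PATH / "clip_model.pt"
assert ckpt_file.exists(), f"missing {ckpt_file}"     # fail fast instead of silently keeping the stock weights

checkpoint = torch.load(ckpt_file, map_location="cpu", weights_only=False)
clip_model.load_state_dict(checkpoint)                # assumes the file stores a plain state dict for the vision tower
clip_model.eval()

The practical effect of the added assert is that a missing checkpoint now stops the app immediately rather than falling through the if and running with unadapted vision weights.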
@@ -312,88 +313,88 @@ def stream_chat(input_images: List[Image.Image], caption_type: str, caption_leng
 
     for i in range(0, len(input_images), batch_size):
         batch = input_images[i:i+batch_size]
-        (old lines 315-320 removed; not captured in this view)
+
+        for input_image in input_images:
+            try:
+                # Preprocess image
+                # NOTE: I found the default processor for so400M to have worse results than just using PIL directly
+                #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
                 image = input_image.resize((384, 384), Image.LANCZOS)
                 pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
                 pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
-        (old lines 324-394 removed; not captured in this view)
-        caption = …
-        all_captions.append(caption)
+                pixel_values = pixel_values.to(device)
+            except ValueError as e:
+                print(f"Error processing image: {e}")
+                print("Skipping this image and continuing...")
+                continue
+
+            # Embed image
+            # This results in Batch x Image Tokens x Features
+            with torch.amp.autocast_mode.autocast(device, enabled=True):
+                vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
+                image_features = vision_outputs.hidden_states
+                embedded_images = image_adapter(image_features).to(device)
+
+            # Build the conversation
+            convo = [
+                {
+                    "role": "system",
+                    "content": "You are a helpful image captioner.",
+                },
+                {
+                    "role": "user",
+                    "content": prompt_str,
+                },
+            ]
+
+            # Format the conversation
+            convo_string = tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = True)
+            assert isinstance(convo_string, str)
+
+            # Tokenize the conversation
+            # prompt_str is tokenized separately so we can do the calculations below
+            convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False)
+            prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False)
+            assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
+            convo_tokens = convo_tokens.squeeze(0)   # Squeeze just to make the following easier
+            prompt_tokens = prompt_tokens.squeeze(0)
+
+            # Calculate where to inject the image
+            eot_id_indices = (convo_tokens == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
+            assert len(eot_id_indices) == 2, f"Expected 2 <|eot_id|> tokens, got {len(eot_id_indices)}"
+
+            preamble_len = eot_id_indices[1] - prompt_tokens.shape[0]   # Number of tokens before the prompt
+
+            # Embed the tokens
+            convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))
+
+            # Construct the input
+            input_embeds = torch.cat([
+                convo_embeds[:, :preamble_len],   # Part before the prompt
+                embedded_images.to(dtype=convo_embeds.dtype),   # Image
+                convo_embeds[:, preamble_len:],   # The prompt and anything after it
+            ], dim=1).to(device)
+
+            input_ids = torch.cat([
+                convo_tokens[:preamble_len].unsqueeze(0),
+                torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),   # Dummy tokens for the image (TODO: Should probably use a special token here so as not to confuse any generation algorithms that might be inspecting the input)
+                convo_tokens[preamble_len:].unsqueeze(0),
+            ], dim=1).to(device)
+            attention_mask = torch.ones_like(input_ids)
+
+            # Debugging
+            #print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")
+
+            generate_ids = text_model.generate(input_ids=input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, do_sample=True,
+                                               suppress_tokens=None, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature)
+
+            # Trim off the prompt
+            generate_ids = generate_ids[:, input_ids.shape[1]:]
+            if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
+                generate_ids = generate_ids[:, :-1]
+
+            caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
+            all_captions.append(caption.strip())
 
         if pbar:
             pbar.update(len(batch))
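The core of the new path is how the adapted image features are spliced into the prompt. Below is a minimal, self-contained sketch of just that splice, using a toy embedding table and made-up sizes (vocab_size, hidden, preamble_len, and the three image tokens are all illustrative; in app.py, preamble_len comes from the position of the second <|eot_id|> minus the prompt length).

import torch

vocab_size, hidden = 128, 16
embed_tokens = torch.nn.Embedding(vocab_size, hidden)   # stand-in for text_model.model.embed_tokens

convo_tokens = torch.randint(0, vocab_size, (10,))      # toy tokenized chat template
preamble_len = 5                                        # toy "number of tokens before the prompt"
embedded_images = torch.randn(1, 3, hidden)             # Batch x ImageTokens x Features, as produced by image_adapter

convo_embeds = embed_tokens(convo_tokens.unsqueeze(0))  # (1, 10, hidden)

# Splice the image between the preamble and the prompt, in the same order as the diff's torch.cat
input_embeds = torch.cat([
    convo_embeds[:, :preamble_len],                     # part before the prompt
    embedded_images.to(dtype=convo_embeds.dtype),       # image tokens
    convo_embeds[:, preamble_len:],                     # the prompt and everything after it
], dim=1)

# Parallel ids: zeros stand in for the image slots so every tensor has the same sequence length
input_ids = torch.cat([
    convo_tokens[:preamble_len].unsqueeze(0),
    torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
    convo_tokens[preamble_len:].unsqueeze(0),
], dim=1)
attention_mask = torch.ones_like(input_ids)

assert input_embeds.shape[1] == input_ids.shape[1] == attention_mask.shape[1]   # 10 text + 3 image slots
print(input_embeds.shape, input_ids.shape)

With the real model, inputs_embeds is what the language model actually consumes, while input_ids only keeps the sequence bookkeeping consistent; that alignment is what lets the new code trim the prompt afterwards with generate_ids[:, input_ids.shape[1]:].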