fancyfeast commited on
Commit
de9952a
·
verified ·
1 Parent(s): d17e136

Update README.md example, processor works now, vLLM instructions

Browse files
Files changed (1) hide show
  1. README.md +38 -39
README.md CHANGED
@@ -33,10 +33,8 @@ Example usage:
33
 
34
  ```
35
  import torch
36
- import torch.amp
37
- import torchvision.transforms.functional as TVF
38
  from PIL import Image
39
- from transformers import AutoTokenizer, LlavaForConditionalGeneration
40
 
41
 
42
  IMAGE_PATH = "image.jpg"
@@ -47,27 +45,14 @@ MODEL_NAME = "fancyfeast/llama-joycaption-alpha-two-hf-llava"
47
  # Load JoyCaption
48
  # bfloat16 is the native dtype of the LLM used in JoyCaption (Llama 3.1)
49
  # device_map=0 loads the model into the first GPU
50
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
51
  llava_model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype="bfloat16", device_map=0)
52
  llava_model.eval()
53
 
54
  with torch.no_grad():
55
- # Load and preprocess image
56
- # Normally you would use the Processor here, but the image module's processor
57
- # has some buggy behavior and a simple resize in Pillow yields higher quality results
58
  image = Image.open(IMAGE_PATH)
59
 
60
- if image.size != (384, 384):
61
- image = image.resize((384, 384), Image.LANCZOS)
62
-
63
- image = image.convert("RGB")
64
- pixel_values = TVF.pil_to_tensor(image)
65
-
66
- # Normalize the image
67
- pixel_values = pixel_values / 255.0
68
- pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
69
- pixel_values = pixel_values.to(torch.bfloat16).unsqueeze(0)
70
-
71
  # Build the conversation
72
  convo = [
73
  {
@@ -81,30 +66,44 @@ with torch.no_grad():
81
  ]
82
 
83
  # Format the conversation
84
- convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
85
-
86
- # Tokenize the conversation
87
- convo_tokens = tokenizer.encode(convo_string, add_special_tokens=False, truncation=False)
88
-
89
- # Repeat the image tokens
90
- input_tokens = []
91
- for token in convo_tokens:
92
- if token == llava_model.config.image_token_index:
93
- input_tokens.extend([llava_model.config.image_token_index] * llava_model.config.image_seq_length)
94
- else:
95
- input_tokens.append(token)
96
-
97
- input_ids = torch.tensor(input_tokens, dtype=torch.long).unsqueeze(0)
98
- attention_mask = torch.ones_like(input_ids)
99
-
100
- # Generate the caption
101
- generate_ids = llava_model.generate(input_ids=input_ids.to('cuda'), pixel_values=pixel_values.to('cuda'), attention_mask=attention_mask.to('cuda'), max_new_tokens=300, do_sample=True, suppress_tokens=None, use_cache=True)[0]
 
 
 
102
 
103
  # Trim off the prompt
104
- generate_ids = generate_ids[input_ids.shape[1]:]
105
 
106
  # Decode the caption
107
- caption = tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
108
  caption = caption.strip()
109
  print(caption)
110
- ```
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  ```
35
  import torch
 
 
36
  from PIL import Image
37
+ from transformers import AutoProcessor, LlavaForConditionalGeneration
38
 
39
 
40
  IMAGE_PATH = "image.jpg"
 
45
  # Load JoyCaption
46
  # bfloat16 is the native dtype of the LLM used in JoyCaption (Llama 3.1)
47
  # device_map=0 loads the model into the first GPU
48
+ processor = AutoProcessor.from_pretrained(MODEL_NAME)
49
  llava_model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype="bfloat16", device_map=0)
50
  llava_model.eval()
51
 
52
  with torch.no_grad():
53
+ # Load image
 
 
54
  image = Image.open(IMAGE_PATH)
55
 
 
 
 
 
 
 
 
 
 
 
 
56
  # Build the conversation
57
  convo = [
58
  {
 
66
  ]
67
 
68
  # Format the conversation
69
+ # WARNING: HF's handling of chat's on Llava models is very fragile. This specific combination of processor.apply_chat_template(), and processor() works
70
+ # but if using other combinations always inspect the final input_ids to ensure they are correct. Often times you will end up with multiple <bos> tokens
71
+ # if not careful, which can make the model perform poorly.
72
+ convo_string = processor.apply_chat_template(convo, tokenize = False, add_generation_prompt = True)
73
+ assert isinstance(convo_string, str)
74
+
75
+ # Process the inputs
76
+ inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to('cuda')
77
+ inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
78
+
79
+ # Generate the captions
80
+ generate_ids = llava_model.generate(
81
+ **inputs,
82
+ max_new_tokens=300,
83
+ do_sample=True,
84
+ suppress_tokens=None,
85
+ use_cache=True,
86
+ temperature=0.6,
87
+ top_k=None,
88
+ top_p=0.9,
89
+ )[0]
90
 
91
  # Trim off the prompt
92
+ generate_ids = generate_ids[inputs['input_ids'].shape[1]:]
93
 
94
  # Decode the caption
95
+ caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
96
  caption = caption.strip()
97
  print(caption)
98
+ ```
99
+
100
+
101
+ ## vLLM
102
+
103
+ vLLM provides the highest performance inference for JoyCaption, and an OpenAI compatible API so JoyCaption can be used like any other VLMs. Example usage:
104
+
105
+ ```
106
+ vllm serve fancyfeast/llama-joycaption-alpha-two-hf-llava --max-model-len 4096 --enable-prefix-caching
107
+ ```
108
+
109
+ VLMs are a bit finicky on vLLM, and vLLM is memory hungry, so you may have to adjust settings for your particular environment, such as forcing eager mode, adjusting max-model-len, adjusting gpu_memory_utilization, etc.