visheratin committed on
Commit
e8f5de1
1 Parent(s): e9369b2

Update README.md

Files changed (1)
  1. README.md +21 -84
README.md CHANGED
@@ -24,16 +24,12 @@ widget:
 
 ## Model details
 
- The core idea behind multi-crop LLaVA (MC-LLaVA) is that instead of N visual token embeddings per image, I generate one token embedding per N parts of the image.
- Having high-quality embeddings for smaller parts of the image helps to extract more details and understand the scene better.
+ Usually, in LLaVA models, we generate N embeddings for the image, which we then combine with text embeddings and send to the LLM. But what if, instead of creating N tokens
+ for one image, we create K << N tokens for M < N parts of the image (crops)? This would let us pull visual information from small parts of the image without inflating the
+ number of image "tokens" too much. I call this method multi-crop LLaVA (MC-LLaVA).
 
- For every crop of the image, I generate an embedding from the full SigLIP encoder (size [1, 1152]) and then push all N embeddings through the LLaVA adapter, which
- gives the token embedding of size [N, 2560]. Right now, the tokens do not contain explicit information about their position in the original image. I plan to add it later.
-
- MC-LLaVA-3b was fine-tuned from [Dolphin 2.6 Phi](https://huggingface.co/cognitivecomputations/dolphin-2_6-phi-2) using vision tower from
- [SigLIP 400M](https://huggingface.co/timm/ViT-SO400M-14-SigLIP-384).
-
- The context length during training was 1200 tokens, as the L4 GPUs I used didn't allow me to get more.
+ MC-LLaVA-3b was fine-tuned from a [Phi-2 merge](https://huggingface.co/vince62s/phi-2-psy) using the vision tower from
+ [SigLIP 400M](https://huggingface.co/google/siglip-so400m-patch14-384).
 
 Like Dolphin 2.6 Phi, LLaVA-3b uses the ChatML prompt format:
 
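To make the multi-crop flow concrete: every crop gets one pooled SigLIP embedding (size [1, 1152] according to the description removed above), and an adapter projects the stacked crop embeddings into the LLM embedding space (size [N, 2560]). The sketch below only illustrates those shapes and is not the repository's code; the grid cropping, the random stand-in embeddings, and the single linear layer standing in for the adapter are assumptions.

```python
# Minimal sketch of the multi-crop idea, assuming a simple grid of crops and a
# linear projection in place of the real (trained) LLaVA adapter.
import torch
import torch.nn as nn
from PIL import Image


def make_crops(image: Image.Image, grid: int = 3) -> list[Image.Image]:
    # Split the image into a grid x grid set of crops (M = grid ** 2).
    width, height = image.size
    crop_w, crop_h = width // grid, height // grid
    return [
        image.crop((x * crop_w, y * crop_h, (x + 1) * crop_w, (y + 1) * crop_h))
        for y in range(grid)
        for x in range(grid)
    ]


class MultiCropAdapter(nn.Module):
    # Maps one pooled vision embedding per crop (dim 1152) to one LLM-space
    # "image token" per crop (dim 2560).
    def __init__(self, vision_dim: int = 1152, llm_dim: int = 2560):
        super().__init__()
        self.proj = nn.Linear(vision_dim, llm_dim)

    def forward(self, crop_embeddings: torch.Tensor) -> torch.Tensor:
        # crop_embeddings: [num_crops, 1152] -> image tokens: [num_crops, 2560]
        return self.proj(crop_embeddings)


crops = make_crops(Image.new("RGB", (384, 384)), grid=3)
crop_embeddings = torch.randn(len(crops), 1152)  # stand-in for per-crop SigLIP outputs
image_tokens = MultiCropAdapter()(crop_embeddings)
print(image_tokens.shape)  # torch.Size([9, 2560])
```

The point of the scheme, as described above, is that the number of image tokens is tied to the number of crops rather than to the full patch grid of the image.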
@@ -47,91 +43,30 @@ You are Dolphin, a helpful AI assistant.<|im_end|>
 
 ## How to use
 
- **Install dependencies**
-
- ```bash
- !pip install -q open_clip_torch timm einops
- ```
-
- **Download modeling files**
-
 ```python
- from huggingface_hub import hf_hub_download
-
- hf_hub_download(repo_id="visheratin/LLaVA-3b", filename="configuration_llava.py", local_dir="./", force_download=True)
- hf_hub_download(repo_id="visheratin/LLaVA-3b", filename="configuration_phi.py", local_dir="./", force_download=True)
- hf_hub_download(repo_id="visheratin/LLaVA-3b", filename="modeling_llava.py", local_dir="./", force_download=True)
- hf_hub_download(repo_id="visheratin/LLaVA-3b", filename="modeling_phi.py", local_dir="./", force_download=True)
- hf_hub_download(repo_id="visheratin/LLaVA-3b", filename="processing_llava.py", local_dir="./", force_download=True)
- ```
-
- **Create a model**
-
- ```python
- from modeling_llava import LlavaForConditionalGeneration
+ from transformers import AutoModel, AutoProcessor
 import torch
 
- model = LlavaForConditionalGeneration.from_pretrained("visheratin/LLaVA-3b", torch_dtype=torch.float16)
- model = model.to("cuda")
- ```
-
- **Create processors**
-
- ```python
- from transformers import AutoTokenizer
- from processing_llava import LlavaProcessor, OpenCLIPImageProcessor
-
- tokenizer = AutoTokenizer.from_pretrained("visheratin/LLaVA-3b")
- image_processor = OpenCLIPImageProcessor(model.config.preprocess_config)
- processor = LlavaProcessor(image_processor, tokenizer)
- ```
-
- **Set image and text**
-
- ```python
- from PIL import Image
- import requests
+ model = AutoModel.from_pretrained("visheratin/MC-LLaVA-3b", torch_dtype=torch.float16, trust_remote_code=True).to("cuda")
 
- image_file = "https://images.unsplash.com/photo-1439246854758-f686a415d9da"
- raw_image = Image.open(requests.get(image_file, stream=True).raw)
+ processor = AutoProcessor.from_pretrained("visheratin/MC-LLaVA-3b", trust_remote_code=True)
 
- prompt = """<|im_start|>system
- A chat between a curious human and an artificial intelligence assistant.
- The assistant gives helpful, detailed, and polite answers to the human's questions.
- The assistant does not hallucinate and pays very close attention to the details.<|im_end|>
- <|im_start|>user
- <image>
- Describe the image.<|im_end|>
- <|im_start|>assistant
- """
- ```
-
- **Process inputs**
-
- ```python
- inputs = processor(prompt, raw_image, model, return_tensors='pt')
-
- inputs['input_ids'] = inputs['input_ids'].to(model.device)
- inputs['attention_mask'] = inputs['attention_mask'].to(model.device)
- ```
-
- **Generate the data**
+ with torch.inference_mode():
+     inputs = processor(prompt, [raw_image], model, max_crops=100, num_tokens=728)
 
- ```python
- import torch
+     output = model.generate(**inputs, max_new_tokens=200, use_cache=True, do_sample=False, eos_token_id=processor.tokenizer.eos_token_id, pad_token_id=processor.tokenizer.eos_token_id)
+ result = processor.tokenizer.decode(output[0]).replace(prompt, "").replace("<|im_end|>", "")
 
- with torch.inference_mode():
- output = model.generate(**inputs, max_new_tokens=200, do_sample=True, temperature=0.4, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id)
+ print(result)
 ```
 
 ## Benchmarks
 
- - TextVQA - 38.59%
- - GQA - 49.6%
- - VQAv2 - 64.24%
- - VizWiz - 24.88%
- - POPE - 80.59%
- - V*-bench - 52.25% (OCR - 46.66%, GPT4V-hard - 41.17%, direct attributes - 43.48%, relative position - 65.79%)
+ - TextVQA - 50.9%
+ - GQA - 59.5%
+ - VQAv2 - 76.72%
+ - VizWiz - 32.68%
+ - V*-bench - OCR - 56.66%, GPT4V-hard - 52.94%, direct attributes - 40.86%, relative position - 56.57%
 
 ## Examples
 
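Note that the usage snippet added above relies on `prompt` and `raw_image`, which are not defined in the lines shown in this diff. A minimal setup for them, reusing the example image URL and the ChatML prompt from the snippet that this commit removes (a sketch, not the model card's official example):

```python
import requests
from PIL import Image

# Example image carried over from the removed snippet.
image_file = "https://images.unsplash.com/photo-1439246854758-f686a415d9da"
raw_image = Image.open(requests.get(image_file, stream=True).raw)

# ChatML prompt carried over from the removed snippet; <image> marks where the
# image tokens are inserted.
prompt = """<|im_start|>system
A chat between a curious human and an artificial intelligence assistant.
The assistant gives helpful, detailed, and polite answers to the human's questions.
The assistant does not hallucinate and pays very close attention to the details.<|im_end|>
<|im_start|>user
<image>
Describe the image.<|im_end|>
<|im_start|>assistant
"""
```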
@@ -146,4 +81,6 @@ Which means don't create competitor models for them.
 
 ## Acknowledgments
 
- Thanks to [ML Collective](https://mlcollective.org/) for providing credits for computing resources.
+ Thanks to [Lambda](https://lambdalabs.com/) for providing a machine to train the model.
+
+ Thanks to [ML Collective](https://mlcollective.org/) for continuous support and providing compute resources for testing the model.