jadechoghari
commited on
Update ferret_arch.py
Browse files- ferret_arch.py +37 -2
ferret_arch.py
CHANGED
@@ -21,9 +21,44 @@ import torch
|
|
21 |
import torch.nn as nn
|
22 |
import torch.nn.functional as F
|
23 |
import torch.distributed as dist
|
|
|
24 |
|
25 |
-
from .
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
from .constants import (IGNORE_INDEX, IMAGE_TOKEN_INDEX,
|
29 |
DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN,
|
|
|
21 |
import torch.nn as nn
|
22 |
import torch.nn.functional as F
|
23 |
import torch.distributed as dist
|
24 |
+
import re
|
25 |
|
26 |
+
from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
|
27 |
+
import os
|
28 |
+
## modified add build_vision_tower
|
29 |
+
def build_vision_tower(vision_tower_cfg, **kwargs):
|
30 |
+
vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
|
31 |
+
is_absolute_path_exists = os.path.exists(vision_tower)
|
32 |
+
use_s2 = getattr(vision_tower_cfg, 's2', False)
|
33 |
+
if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
|
34 |
+
if use_s2:
|
35 |
+
return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
|
36 |
+
else:
|
37 |
+
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
38 |
+
|
39 |
+
raise ValueError(f'Unknown vision tower: {vision_tower}')
|
40 |
+
|
41 |
+
|
42 |
+
# from .multimodal_projector.builder import build_vision_projector
|
43 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
|
44 |
+
projector_type = getattr(config, 'mm_projector_type', 'linear')
|
45 |
+
|
46 |
+
if projector_type == 'linear':
|
47 |
+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
48 |
+
|
49 |
+
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
|
50 |
+
if mlp_gelu_match:
|
51 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
52 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
53 |
+
for _ in range(1, mlp_depth):
|
54 |
+
modules.append(nn.GELU())
|
55 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
56 |
+
return nn.Sequential(*modules)
|
57 |
+
|
58 |
+
if projector_type == 'identity':
|
59 |
+
return IdentityMap()
|
60 |
+
|
61 |
+
raise ValueError(f'Unknown projector type: {projector_type}')
|
62 |
|
63 |
from .constants import (IGNORE_INDEX, IMAGE_TOKEN_INDEX,
|
64 |
DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN,
|