jadechoghari commited on
Commit
fb2ff23
·
verified ·
1 Parent(s): 56f618d

Update ferret_arch.py

Browse files
Files changed (1) hide show
  1. ferret_arch.py +37 -2
ferret_arch.py CHANGED
@@ -21,9 +21,44 @@ import torch
21
  import torch.nn as nn
22
  import torch.nn.functional as F
23
  import torch.distributed as dist
 
24
 
25
- from .multimodal_encoder.builder import build_vision_tower
26
- from .multimodal_projector.builder import build_vision_projector
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  from .constants import (IGNORE_INDEX, IMAGE_TOKEN_INDEX,
29
  DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN,
 
21
  import torch.nn as nn
22
  import torch.nn.functional as F
23
  import torch.distributed as dist
24
+ import re
25
 
26
+ from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
27
+ import os
28
+ ## modified add build_vision_tower
29
+ def build_vision_tower(vision_tower_cfg, **kwargs):
30
+ vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
31
+ is_absolute_path_exists = os.path.exists(vision_tower)
32
+ use_s2 = getattr(vision_tower_cfg, 's2', False)
33
+ if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
34
+ if use_s2:
35
+ return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
36
+ else:
37
+ return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
38
+
39
+ raise ValueError(f'Unknown vision tower: {vision_tower}')
40
+
41
+
42
+ # from .multimodal_projector.builder import build_vision_projector
43
+ def build_vision_projector(config, delay_load=False, **kwargs):
44
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
45
+
46
+ if projector_type == 'linear':
47
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
48
+
49
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
50
+ if mlp_gelu_match:
51
+ mlp_depth = int(mlp_gelu_match.group(1))
52
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
53
+ for _ in range(1, mlp_depth):
54
+ modules.append(nn.GELU())
55
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
56
+ return nn.Sequential(*modules)
57
+
58
+ if projector_type == 'identity':
59
+ return IdentityMap()
60
+
61
+ raise ValueError(f'Unknown projector type: {projector_type}')
62
 
63
  from .constants import (IGNORE_INDEX, IMAGE_TOKEN_INDEX,
64
  DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN,