Upload modeling_molmo.py with huggingface_hub
Browse files — modeling_molmo.py (+9 −2)
modeling_molmo.py
CHANGED
@@ -762,7 +762,6 @@ class ViTMLP(nn.Module):
|
|
762 |
return x
|
763 |
|
764 |
|
765 |
-
|
766 |
class ResidualAttentionBlock(nn.Module):
|
767 |
|
768 |
def __init__(self, config: FullMolmoConfig):
|
@@ -819,6 +818,14 @@ class BlockCollection(nn.Module):
|
|
819 |
return hidden_states
|
820 |
|
821 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
822 |
class VisionTransformer(nn.Module):
|
823 |
|
824 |
def __init__(self, config: FullMolmoConfig):
|
@@ -844,7 +851,7 @@ class VisionTransformer(nn.Module):
|
|
844 |
device=config.init_device,
|
845 |
)
|
846 |
|
847 |
-
self.pre_ln =
|
848 |
v_cfg.image_emb_dim,
|
849 |
eps=v_cfg.image_norm_eps,
|
850 |
)
|
|
|
762 |
return x
|
763 |
|
764 |
|
|
|
765 |
class ResidualAttentionBlock(nn.Module):
|
766 |
|
767 |
def __init__(self, config: FullMolmoConfig):
|
|
|
818 |
return hidden_states
|
819 |
|
820 |
|
821 |
+
class LayerNormFp32(nn.LayerNorm):
    """LayerNorm that always computes in float32 for numerical stability.

    Casts the input (and the affine parameters, if present) to float32,
    applies layer normalization, then casts the result back to the input's
    original dtype.  This matters when the surrounding model runs in
    fp16/bf16, where the mean/variance reduction in low precision loses
    accuracy.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_type = x.dtype
        # nn.LayerNorm sets weight/bias to None when elementwise_affine=False;
        # guard the casts so the no-affine configuration does not crash.
        weight = self.weight.to(torch.float32) if self.weight is not None else None
        bias = self.bias.to(torch.float32) if self.bias is not None else None
        x = F.layer_norm(x.to(torch.float32), self.normalized_shape, weight, bias, self.eps)
        return x.to(orig_type)
|
827 |
+
|
828 |
+
|
829 |
class VisionTransformer(nn.Module):
|
830 |
|
831 |
def __init__(self, config: FullMolmoConfig):
|
|
|
851 |
device=config.init_device,
|
852 |
)
|
853 |
|
854 |
+
self.pre_ln = LayerNormFp32(
|
855 |
v_cfg.image_emb_dim,
|
856 |
eps=v_cfg.image_norm_eps,
|
857 |
)
|