JUNJIE99 committed (verified)
Commit 2eb6ce7 · 1 Parent(s): e136115

Upload folder using huggingface_hub

Files changed (1):
  modeling_MMRet_CLIP.py (+9 -10)
modeling_MMRet_CLIP.py CHANGED
@@ -22,12 +22,12 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from PIL import Image
-from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
-from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
-from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import is_torch_greater_or_equal_than_2_2
-from ...utils import (
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
+from transformers.utils import (
     ModelOutput,
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -37,7 +37,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
-from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
+from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
 
 
 if is_flash_attn_2_available():
@@ -47,11 +47,10 @@ if is_flash_attn_2_available():
 logger = logging.get_logger(__name__)
 
 # General docstring
-_CONFIG_FOR_DOC = "CLIPConfig"
-_CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"
+_CONFIG_FOR_DOC = "MMRet_CLIP"
 
 # Image classification docstring
-_IMAGE_CLASS_CHECKPOINT = "openai/clip-vit-base-patch32"
+_IMAGE_CLASS_CHECKPOINT = "JUNJIE/MMRet-large"
 _IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"
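The diff above rewrites the module's package-relative imports (from ...activations, from .configuration_clip, and so on), which only resolve when the file lives inside the transformers source tree, into absolute transformers.* imports, so modeling_MMRet_CLIP.py can execute as standalone remote code shipped inside the checkpoint repo. A minimal loading sketch under that assumption; the repo id is taken from the _IMAGE_CLASS_CHECKPOINT value in the diff, and whether this checkpoint actually wires the file up via auto_map in its config.json is an assumption here:

from transformers import AutoModel

# trust_remote_code=True makes transformers download and execute
# modeling_MMRet_CLIP.py from the repo rather than using a built-in model class.
model = AutoModel.from_pretrained(
    "JUNJIE/MMRet-large",  # repo id from _IMAGE_CLASS_CHECKPOINT above (assumed to host this file)
    trust_remote_code=True,
)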