Upload folder using huggingface_hub
Browse files- modeling_MMRet_CLIP.py +9 -10
modeling_MMRet_CLIP.py
CHANGED
@@ -22,12 +22,12 @@ import torch.utils.checkpoint
|
|
22 |
from torch import nn
|
23 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
24 |
from PIL import Image
|
25 |
-
from
|
26 |
-
from
|
27 |
-
from
|
28 |
-
from
|
29 |
-
from
|
30 |
-
from
|
31 |
ModelOutput,
|
32 |
add_code_sample_docstrings,
|
33 |
add_start_docstrings,
|
@@ -37,7 +37,7 @@ from ...utils import (
|
|
37 |
logging,
|
38 |
replace_return_docstrings,
|
39 |
)
|
40 |
-
from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
41 |
|
42 |
|
43 |
if is_flash_attn_2_available():
|
@@ -47,11 +47,10 @@ if is_flash_attn_2_available():
|
|
47 |
logger = logging.get_logger(__name__)
|
48 |
|
49 |
# General docstring
|
50 |
-
_CONFIG_FOR_DOC = "
|
51 |
-
_CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"
|
52 |
|
53 |
# Image classification docstring
|
54 |
-
_IMAGE_CLASS_CHECKPOINT = "
|
55 |
_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"
|
56 |
|
57 |
|
|
|
22 |
from torch import nn
|
23 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
24 |
from PIL import Image
|
25 |
+
from transformers.activations import ACT2FN
|
26 |
+
from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
|
27 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
|
28 |
+
from transformers.modeling_utils import PreTrainedModel
|
29 |
+
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
|
30 |
+
from transformers.utils import (
|
31 |
ModelOutput,
|
32 |
add_code_sample_docstrings,
|
33 |
add_start_docstrings,
|
|
|
37 |
logging,
|
38 |
replace_return_docstrings,
|
39 |
)
|
40 |
+
from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
41 |
|
42 |
|
43 |
if is_flash_attn_2_available():
|
|
|
47 |
logger = logging.get_logger(__name__)
|
48 |
|
49 |
# General docstring
|
50 |
+
_CONFIG_FOR_DOC = "MMRet_CLIP"
|
|
|
51 |
|
52 |
# Image classification docstring
|
53 |
+
_IMAGE_CLASS_CHECKPOINT = "JUNJIE/MMRet-large"
|
54 |
_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"
|
55 |
|
56 |
|