amaye15 commited on
Commit
6fda347
1 Parent(s): 36a0697

Create convert_weights.py

Browse files
Files changed (1) hide show
  1. convert_weights.py +70 -0
convert_weights.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModel, AutoConfig
2
+ from DaViT.modeling_davit import DaViTModel
3
+ from DaViT.configuration_davit import DaViTConfig
4
+ from unittest.mock import patch
5
+ import os
6
+ import logging
7
+ import requests
8
+ from PIL import Image
9
+ import torch
10
+ from transformers import AutoProcessor, AutoModelForCausalLM
11
+ from unittest.mock import patch
12
+ from transformers.dynamic_module_utils import get_imports
13
+ from typing import Tuple, Dict, Any, Union, List
14
+
15
+
16
+ def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
17
+ """
18
+ Custom workaround for the import error related to flash_attn.
19
+ Args:
20
+ filename (str | os.PathLike): The filename to check for imports.
21
+ Returns:
22
+ list[str]: List of required imports.
23
+ """
24
+ if not str(filename).endswith("modeling_florence2.py"):
25
+ return get_imports(filename)
26
+ imports = get_imports(filename)
27
+ if "flash_attn" in imports:
28
+ imports.remove("flash_attn")
29
+ return imports
30
+
31
+
32
+ current_directory = os.getcwd()
33
+
34
+ # Register the configuration and model
35
+ AutoConfig.register("davit", DaViTConfig)
36
+ AutoModel.register(DaViTConfig, DaViTModel)
37
+
38
+
39
+ # Register Huggingface Model
40
+ DaViTConfig.register_for_auto_class()
41
+ DaViTModel.register_for_auto_class("AutoModel")
42
+
43
+ AutoConfig.register("davit", DaViTConfig)
44
+ AutoModel.register(DaViTConfig, DaViTModel)
45
+
46
+ # Step 1: Create a configuration object
47
+ config = DaViTConfig()
48
+ with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
49
+ model = AutoModelForCausalLM.from_pretrained(
50
+ "microsoft/Florence-2-large-ft",
51
+ trust_remote_code=True,
52
+ cache_dir=current_directory,
53
+ device_map="cpu",
54
+ torch_dtype=torch.float16,
55
+ )
56
+ processor = AutoProcessor.from_pretrained(
57
+ "microsoft/Florence-2-large-ft",
58
+ trust_remote_code=True,
59
+ cache_dir=current_directory,
60
+ device_map="cpu",
61
+ )
62
+ # Step 2: Create a model object
63
+ model2 = AutoModel.from_config(config)
64
+ model2.to(torch.float16)
65
+
66
+ model2.load_state_dict(model.vision_tower.state_dict())
67
+
68
+
69
+ model2.push_to_hub("DaViT-Florence-2-large-ft")
70
+ processor.push_to_hub("DaViT-Florence-2-large-ft")