chrisc36 committed on
Commit e3ff8ff
1 Parent(s): 6b128f8

Delete convert_to_hf.py

Files changed (1)
  1. convert_to_hf.py +0 -89
convert_to_hf.py DELETED
@@ -1,89 +0,0 @@
- import argparse
- import logging
- import os
-
- import torch
-
- from hf_molmo.config_molmo import MolmoConfig
- from hf_molmo.image_preprocessing_molmo import MolmoImageProcessor
- from hf_molmo.modelling_molmo import MOLMoForCausalLM
- from hf_molmo.preprocessing_molmo import MolmoProcessor
- from olmo import ModelConfig
- from olmo.mm_data.data_utils import build_tokenizer
-
- logger = logging.getLogger(__name__)
-
-
- def write_config(checkpoint_dir: str, output_dir: str):
-     # save config as HF config
-
-     logger.info(f"Loading checkpoint from {checkpoint_dir}")
-
-     config_path = os.path.join(checkpoint_dir, "config.yaml")
-     model_config = ModelConfig.load(config_path, key="model")
-     config_kwargs = model_config.asdict()
-     config_kwargs["use_cache"] = True
-     config_kwargs["vit_load_path"] = None
-     config_kwargs["llm_load_path"] = None
-     config = MolmoConfig(
-         vocab_size=model_config.vocab_size,
-         embedding_size=model_config.embedding_size,
-         hidden_size=model_config.d_model,
-         intermediate_size=model_config.mlp_hidden_size,
-         num_hidden_layers=model_config.n_layers,
-         num_attention_heads=model_config.n_heads,
-         num_key_value_heads=model_config.n_kv_heads,
-         max_position_embeddings=model_config.max_position_embeddings or model_config.max_sequence_length,
-         initializer_range=model_config.initializer_range,
-         use_cache=True,
-         layer_norm_eps=model_config.layer_norm_eps,
-         rope_theta=model_config.rope_theta,
-         clip_qkv=model_config.clip_qkv,
-         qkv_bias=model_config.qkv_bias,
-         weight_tying=model_config.weight_tying,
-         use_position_ids=True,
-         tie_word_embeddings=False
-     )
-
-     logger.info(f"Saving HF-compatible config to {os.path.join(checkpoint_dir, 'config.json')}")
-     config.save_pretrained(output_dir)
-
-     preprocessor = MolmoProcessor(
-         MolmoImageProcessor(
-             max_crops=model_config.max_crops
-         ), # FIXME now just assumes everything if fixed
-         build_tokenizer(model_config.tokenizer.identifier.split("m:")[1]).tokenizer
-     )
-     preprocessor.save_pretrained(output_dir)
-
-
- def write_model(checkpoint_dir: str, output_dir: str, ignore_olmo_compatibility: bool = False):
-     # For device_map = "auto", etc. the models are loaded in a way that start_prefix is not computed correctly.
-     # So, we explicitly store the model with the expected prefix.
-     old_model_path = os.path.join(checkpoint_dir, "model.pt")
-     new_model_path = os.path.join(output_dir, "pytorch_model.bin")
-
-     state_dict = torch.load(old_model_path)
-     new_state_dict = {f"{MOLMoForCausalLM.base_model_prefix}.{key}": val for key, val in state_dict.items()}
-     torch.save(new_state_dict, new_model_path)
-
-
- def convert_checkpoint(checkpoint_dir: str, output_dir: str):
-     os.makedirs(output_dir, exist_ok=True)
-     write_config(checkpoint_dir, output_dir)
-     write_model(checkpoint_dir, output_dir)
-
-
- def main():
-     parser = argparse.ArgumentParser(
-         description="Adds a config.json to the checkpoint directory, and creates pytorch_model.bin, "
-                     "making it easier to load weights as HF models."
-     )
-     parser.add_argument("checkpoint_dir")
-     parser.add_argument("output_dir")
-     args = parser.parse_args()
-     convert_checkpoint(args.checkpoint_dir, args.output_dir)
-
-
- if __name__ == "__main__":
-     main()
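
For context, the deleted script's help text describes its purpose: write a config.json, pytorch_model.bin, and processor files into the output directory so the checkpoint can be loaded through the Hugging Face interface. Below is a minimal loading sketch of that workflow, assuming the hf_molmo classes expose the standard from_pretrained interface and that convert_checkpoint() has already populated the output directory; the path is hypothetical and not part of the commit.

    # Minimal loading sketch (hypothetical path; assumes hf_molmo is importable).
    from hf_molmo.modelling_molmo import MOLMoForCausalLM
    from hf_molmo.preprocessing_molmo import MolmoProcessor

    output_dir = "converted/molmo"  # hypothetical directory written by convert_checkpoint()

    # The processor bundles the image preprocessor and tokenizer saved by write_config().
    processor = MolmoProcessor.from_pretrained(output_dir)

    # write_model() re-keyed the weights with the model's base_model_prefix, so the
    # standard from_pretrained path should resolve pytorch_model.bin directly.
    model = MOLMoForCausalLM.from_pretrained(output_dir)
    model.eval()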