File size: 946 Bytes
1de83e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64a3ce9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Build an untrained causal-LM from a local configuration file
# (weights are randomly initialized; no checkpoint is loaded).
CONFIG_FILE = "config.json"  # adjust path as necessary
model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(CONFIG_FILE))


def reinitialize_weights(module):
    """Re-initialize a module's parameters in place (for use with ``Module.apply``).

    Weights are drawn from N(0, 0.02^2) — a standard controlled init for
    transformer LMs — and biases are zeroed. LayerNorm weights are left at
    their default (1.0) so normalization layers start as identities.

    Args:
        module: Any ``torch.nn.Module``; modules without parameters are a no-op.
    """
    # Guard against weight=None (e.g. norm layers created with affine=False),
    # mirroring the existing None-check on bias below.
    if (
        hasattr(module, "weight")
        and module.weight is not None
        and not isinstance(module, torch.nn.LayerNorm)
    ):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    if hasattr(module, "bias") and module.bias is not None:
        torch.nn.init.constant_(module.bias, 0.0)


# Recursively apply the custom initialization to every submodule.
model.apply(reinitialize_weights)

# Cast all floating-point parameters to bfloat16 for a smaller footprint.
model = model.to(dtype=torch.bfloat16)

# Save with SafeTensors serialization. The correct keyword is
# `safe_serialization` — the previous `save_in_safe_tensors_format` is not a
# recognized parameter of save_pretrained and was being silently ignored.
model.save_pretrained("./micro_mistral", safe_serialization=True)