Text model weights differ from Llama 3.1 8B Instruct
#32
by
jlli
- opened
Hi there,
When I compare the weights of `language_model` against Llama 3.1 8B Instruct (here), they differ substantially. Truncated output of my script:
$ python3 check_llama3_2_8b.py
Verifying weights: 0%| | 0/4 [00:00<?, ?it/s]
Param lm_head.weight has a difference of 1960.0!
Param model.layers.31.mlp.down_proj.weight has a difference of 486.0!
Verifying weights: 25%|████████ | 1/4 [00:02<00:08, 2.84s/it]
Param model.embed_tokens.weight has a difference of 2304.0!
...
The full script is at the bottom; TL;DR: for each parameter, I take the sum of the absolute element-wise differences between the two models. Interestingly, when I compare 90B Vision Instruct (here) to 3.1 70B Instruct (here), the weights match exactly! That behavior is expected, given this excerpt from the Llama 3.1 paper:
Comparison script:
(apologies for the broken indentation; copy-pasting into Markdown is proving difficult)
import safetensors
import bisect
import math
import torch
import json
import re
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm
# Local checkpoint directories to compare; edit these for your machine.
# Both are read for 'model.safetensors.index.json'; the vision directory
# must also provide 'config.json'.
vision_path, language_path = (
    Path('/home/jonathanl/lama3.2/Llama-3.2-11B-Vision-Instruct'),
    Path('/home/jonathanl/Llama-3.1-8B-Instruct/'),
)
def convert_vision_to_language_name(param_name, cross_attention_layers):
    """Translate a vision-checkpoint parameter name into the text-checkpoint name.

    The layer index in the first ``model.layers.<N>.`` segment is renumbered
    by subtracting how many entries of ``cross_attention_layers`` are <= N.
    This script treats those indices as extra layers that the text-only model
    does not have. Every ``language_model.`` prefix is then removed. Names
    without a ``model.layers.<N>.`` segment only have the prefix stripped.

    Args:
        param_name: Parameter name as listed in the vision checkpoint.
        cross_attention_layers: Cross-attention layer indices. They must be
            sorted ascending because they are searched with ``bisect``.

    Returns:
        The corresponding parameter name in the text-only checkpoint.
    """
    def _renumber(match):
        vision_idx = int(match.group(1))
        orig_idx = vision_idx - bisect.bisect_right(cross_attention_layers, vision_idx)
        return f'model.layers.{orig_idx}.'

    # Substitute only inside the matched segment, so that digits elsewhere in
    # the name are never rewritten.
    param_name = re.sub(r'model\.layers\.([0-9]+)\.', _renumber, param_name, count=1)
    return param_name.replace('language_model.', '')
def _layer_index(param_name):
    """Return N from the first ``model.layers.N.`` segment of *param_name*, or None."""
    match = re.search(r'model\.layers\.([0-9]+)\.', param_name)
    return int(match.group(1)) if match else None


def _load_json(path):
    """Parse and return the JSON document stored at *path*."""
    with open(path) as f:
        return json.load(f)


def _map_vision_params(vision_weight_map, cross_attention_layers):
    """Index comparable vision-checkpoint params by their text-checkpoint name.

    Args:
        vision_weight_map: The ``weight_map`` of the vision checkpoint's
            safetensors index (param name -> shard file name).
        cross_attention_layers: Layer indices to skip. This script treats
            them as having no text-model counterpart.

    Returns:
        A dict mapping text-model param name -> (vision shard file name,
        vision param name). Only names containing 'language_model' are kept.
    """
    vision_sources = {}
    for vision_param_name, vision_filename in vision_weight_map.items():
        if 'language_model' not in vision_param_name:
            continue
        layer_idx = _layer_index(vision_param_name)
        if layer_idx is not None and layer_idx in cross_attention_layers:
            continue
        language_param_name = convert_vision_to_language_name(vision_param_name, cross_attention_layers)
        vision_sources[language_param_name] = (vision_filename, vision_param_name)
    return vision_sources


def _align_shapes(language_param_name, vision_param_name, language_tensor, vision_tensor):
    """Return *vision_tensor*, keeping only its leading rows when the embedding is larger.

    Only the token embedding may differ in shape. The vision model's
    embedding table is expanded slightly, but the extra rows are not used,
    so only the first ``language_tensor.shape[0]`` rows are compared.

    Raises:
        ValueError: If any other parameter has mismatched shapes.
    """
    if language_tensor.shape == vision_tensor.shape:
        return vision_tensor
    if 'embed_tokens' not in language_param_name or 'embed_tokens' not in vision_param_name:
        raise ValueError(
            f'Shape mismatch for {language_param_name}: '
            f'{tuple(language_tensor.shape)} vs {tuple(vision_tensor.shape)}'
        )
    return vision_tensor[:language_tensor.shape[0], :]


def _sum_abs_difference(language_tensor, vision_tensor):
    """Return sum(|language_tensor - vision_tensor|) as a Python float.

    Both tensors are upcast to float32 before subtracting and reducing.
    Reducing in the checkpoint's own dtype (possibly a 16-bit float) rounds
    the reported total coarsely.
    """
    return torch.sum(torch.abs(language_tensor.float() - vision_tensor.float())).item()


def _main():
    """Compare every text-model parameter against its copy in the vision checkpoint.

    Reads both safetensors indices plus the vision ``config.json``. Prints
    each parameter whose summed absolute difference is non-zero. Raises
    AssertionError if any text-model parameter was never compared.
    """
    vision_json = _load_json(vision_path / 'model.safetensors.index.json')
    vision_config = _load_json(vision_path / 'config.json')
    language_json = _load_json(language_path / 'model.safetensors.index.json')
    cross_attention_layers = vision_config['text_config']['cross_attention_layers']

    vision_sources = _map_vision_params(vision_json['weight_map'], cross_attention_layers)

    language_weight_sets = defaultdict(set)
    for param_name, language_filename in language_json['weight_map'].items():
        language_weight_sets[language_filename].add(param_name)
    all_language_param_names = set(language_json['weight_map'])

    verified_language_params = set()
    for language_file_name, param_name_set in tqdm(
        language_weight_sets.items(),
        dynamic_ncols=True,
        total=len(language_weight_sets),
        desc='Verifying weights',
    ):
        # Group this shard's comparable params by the vision shard that holds
        # them, so each vision shard is opened once per language shard.
        pairs_by_vision_file = defaultdict(list)
        for language_param_name in sorted(param_name_set):
            if language_param_name in vision_sources:
                vision_file_name, vision_param_name = vision_sources[language_param_name]
                pairs_by_vision_file[vision_file_name].append((language_param_name, vision_param_name))
        if not pairs_by_vision_file:
            continue

        # safe_open is used as a context manager so shard handles are released.
        with safetensors.safe_open(language_path / language_file_name, framework='pt') as language_weights:
            for vision_file_name, pairs in pairs_by_vision_file.items():
                with safetensors.safe_open(vision_path / vision_file_name, framework='pt') as vision_weights:
                    for language_param_name, vision_param_name in pairs:
                        language_tensor = language_weights.get_tensor(language_param_name)
                        vision_tensor = _align_shapes(
                            language_param_name,
                            vision_param_name,
                            language_tensor,
                            vision_weights.get_tensor(vision_param_name),
                        )
                        sum_abs_difference = _sum_abs_difference(language_tensor, vision_tensor)
                        # Exact test on purpose: math.isclose(x, 0.0) with the
                        # default abs_tol=0.0 was already equivalent to x == 0.0.
                        if sum_abs_difference != 0.0:
                            # tqdm.write keeps the message from splitting the progress bar.
                            tqdm.write(f'Param {language_param_name} has a difference of {sum_abs_difference}!')
                        verified_language_params.add(language_param_name)

    missing = all_language_param_names - verified_language_params
    if missing:
        raise AssertionError(
            f'{len(missing)} of {len(all_language_param_names)} text-model params were never compared, '
            f'e.g. {sorted(missing)[:5]}'
        )


if __name__ == '__main__':
    _main()