pranavajay committed on
Commit
8e7098c
·
verified ·
1 Parent(s): 936aca5

Update A.py

Browse files
Files changed (1) hide show
  1. A.py +16 -10
A.py CHANGED
@@ -34,7 +34,7 @@ def resize_tensor_shapes(tensor1, tensor2):
34
 
35
  return tensor1_resized, tensor2_resized
36
 
37
- def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.75):
38
  print(f"Merging checkpoints with blend ratio: {blend_ratio}")
39
  merged = {}
40
  all_keys = set(ckpt1.keys()).union(set(ckpt2.keys()))
@@ -49,7 +49,7 @@ def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.75):
49
  else:
50
  merged[key] = t2
51
 
52
- # Control the final size to be approximately 26 GB
53
  control_output_size(merged, target_size_gb=26)
54
 
55
  return merged
@@ -64,14 +64,20 @@ def control_output_size(merged, target_size_gb):
64
  excess_size = current_size_bytes - target_size_bytes
65
  print(f"Current size exceeds target by {excess_size / (1024**2):.2f} MB. Adjusting...")
66
 
67
- # Adjusting the tensors to meet the target size
 
 
 
 
68
  for key in merged.keys():
69
  tensor = merged[key]
70
- # Calculate how much we can reduce
71
- reduce_size = excess_size // tensor.element_size() # Number of elements to reduce
72
- if tensor.numel() > reduce_size:
73
- # Truncate the tensor
74
- merged[key] = tensor.flatten()[:tensor.numel() - reduce_size].view(tensor.shape)
 
 
75
 
76
  def cleanup_files(*file_paths):
77
  for file_path in file_paths:
@@ -83,8 +89,8 @@ if __name__ == "__main__":
83
  try:
84
  model1_path = "mangledMergeFlux_v0Bfloat16Dev.safetensors"
85
  model2_path = "output_checkpoint.safetensors"
86
- blend_ratio = 0.75 # Adjust ratio based on requirement
87
- output_file = "output_checkpoints.safetensors"
88
 
89
  # Loading models
90
  model1 = load_model(model1_path)
 
34
 
35
  return tensor1_resized, tensor2_resized
36
 
37
+ def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.6):
38
  print(f"Merging checkpoints with blend ratio: {blend_ratio}")
39
  merged = {}
40
  all_keys = set(ckpt1.keys()).union(set(ckpt2.keys()))
 
49
  else:
50
  merged[key] = t2
51
 
52
+ # Control the final size to be strictly 26 GB
53
  control_output_size(merged, target_size_gb=26)
54
 
55
  return merged
 
64
  excess_size = current_size_bytes - target_size_bytes
65
  print(f"Current size exceeds target by {excess_size / (1024**2):.2f} MB. Adjusting...")
66
 
67
+ # Calculate the total number of elements to reduce
68
+ elements_to_reduce = excess_size // 4 # Assuming 4 bytes per float32 tensor
69
+ total_elements = sum(tensor.numel() for tensor in merged.values())
70
+
71
+ # Take the reduction from tensors in iteration order until the excess is covered
72
  for key in merged.keys():
73
  tensor = merged[key]
74
+ num_elements = tensor.numel()
75
+ # Calculate how much to reduce from this tensor
76
+ reduction = min(elements_to_reduce, num_elements)
77
+ merged[key] = tensor.flatten()[:num_elements - reduction].view(tensor.shape)
78
+ elements_to_reduce -= reduction
79
+ if elements_to_reduce <= 0:
80
+ break
81
 
82
  def cleanup_files(*file_paths):
83
  for file_path in file_paths:
 
89
  try:
90
  model1_path = "mangledMergeFlux_v0Bfloat16Dev.safetensors"
91
  model2_path = "output_checkpoint.safetensors"
92
+ blend_ratio = 0.6 # Set to 60%
93
+ output_file = "output_checkpoint.safetensors"
94
 
95
  # Loading models
96
  model1 = load_model(model1_path)