# mk / A.py
# Author: pranavajay
# Commit message: "Update A.py"
# Commit: 8e7098c (verified)
import os
import requests
from safetensors.torch import load_file, save_file
import torch
# NOTE(review): runs at import time, before this script has done any CUDA work,
# so it presumably frees nothing here — confirm it is still needed.
torch.cuda.empty_cache()
import torch.nn.functional as F
from tqdm import tqdm
def download_file(url, dest_path, chunk_size=1024 * 1024, timeout=60):
    """Stream the resource at *url* into the local file *dest_path*.

    Args:
        url: HTTP(S) URL to fetch.
        dest_path: Destination path; an existing file is overwritten.
        chunk_size: Number of bytes requested per chunk from the response
            body (default 1 MiB; large files download much faster than with
            tiny chunks).
        timeout: Seconds ``requests`` waits to connect / for data before
            giving up, instead of blocking forever on a stalled server.

    Raises:
        Exception: if the server responds with a status code other than 200.
        requests.RequestException: on connection errors or timeouts.
    """
    print(f"Downloading {url} to {dest_path}")
    # `with` guarantees the streamed connection is released on every exit path.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise Exception(f"Failed to download file from {url}")
        try:
            with open(dest_path, "wb") as f:
                for chunk in response.iter_content(chunk_size):
                    f.write(chunk)
        except BaseException:
            # Don't leave a truncated file behind that looks like a finished download.
            if os.path.exists(dest_path):
                os.remove(dest_path)
            raise
def load_model(file_path):
    """Read the safetensors file at *file_path* and return its contents.

    The result is whatever ``safetensors.torch.load_file`` produces; the merge
    code treats it as a mapping of tensor name -> tensor.
    """
    state_dict = load_file(file_path)
    return state_dict
def save_model(merged_model, output_file):
    """Serialize the tensor mapping *merged_model* to *output_file* (safetensors)."""
    message = f"Saving merged model to {output_file}"
    print(message)
    save_file(merged_model, output_file)
def resize_tensor_shapes(tensor1, tensor2):
    """Zero-pad two tensors so that both end up with the same shape.

    Every dimension is padded at its end (high-index side) up to the larger of
    the two sizes, so no values are dropped from either tensor. Tensors that
    already share a shape are returned unchanged (the same objects).

    Args:
        tensor1: First tensor.
        tensor2: Second tensor; must have the same number of dimensions.

    Returns:
        ``(tensor1_resized, tensor2_resized)`` with identical shapes.

    Raises:
        ValueError: if the tensors have a different number of dimensions,
            since there is no unambiguous way to align them.
    """
    if tensor1.shape == tensor2.shape:
        return tensor1, tensor2
    if tensor1.dim() != tensor2.dim():
        raise ValueError(
            "Cannot align tensors of different rank: "
            f"{tuple(tensor1.shape)} vs {tuple(tensor2.shape)}"
        )
    target_shape = [max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape)]

    def _pad_to_target(tensor):
        # F.pad takes (before, after) pairs starting from the LAST dimension,
        # so walk the dims in reverse; we only ever pad "after".
        padding = []
        for size, target in zip(reversed(tensor.shape), reversed(target_shape)):
            padding.extend((0, target - size))
        return F.pad(tensor, tuple(padding))

    return _pad_to_target(tensor1), _pad_to_target(tensor2)
def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.6, target_size_gb=26):
    """Linearly blend two checkpoints tensor by tensor.

    For a key present in both checkpoints the result is
    ``blend_ratio * t1 + (1 - blend_ratio) * t2``, after zero-padding
    mismatched shapes with ``resize_tensor_shapes``. When the ckpt1 tensor is
    floating point, the blend is cast back to its dtype so that mixing e.g. a
    bf16 base with an fp32 model does not silently upcast (and roughly double
    the size of) the output. A key present in only one checkpoint is copied
    through as the same tensor object.

    Args:
        ckpt1: Base checkpoint, a mapping of name -> tensor; weighted by
            ``blend_ratio``.
        ckpt2: Second checkpoint; weighted by ``1 - blend_ratio``.
        blend_ratio: Weight of ``ckpt1`` in the blend.
        target_size_gb: Size cap (GiB) passed to ``control_output_size``
            after merging; ``None`` skips the check.

    Returns:
        A new dict of merged tensors; the inputs are not modified.
    """
    print(f"Merging checkpoints with blend ratio: {blend_ratio}")
    merged = {}
    # Sorted so processing order is deterministic (set iteration order of
    # string keys can differ between interpreter runs).
    all_keys = sorted(set(ckpt1) | set(ckpt2))
    for key in tqdm(all_keys, desc="Merging Checkpoints", unit="layer"):
        t1, t2 = ckpt1.get(key), ckpt2.get(key)
        if t1 is None:
            merged[key] = t2
        elif t2 is None:
            merged[key] = t1
        else:
            t1, t2 = resize_tensor_shapes(t1, t2)
            blended = blend_ratio * t1 + (1 - blend_ratio) * t2
            if t1.is_floating_point():
                blended = blended.to(t1.dtype)
            merged[key] = blended
    if target_size_gb is not None:
        control_output_size(merged, target_size_gb=target_size_gb)
    return merged
def control_output_size(merged, target_size_gb):
    """Verify that the tensors in *merged* fit within *target_size_gb*.

    The size is the raw tensor payload, ``sum(numel() * element_size())``, so
    it respects each tensor's real dtype (2 bytes for bf16/fp16, 4 for fp32,
    ...). ``target_size_gb`` is interpreted in GiB (1024**3 bytes). The
    safetensors file written later adds a small header on top of this.

    Args:
        merged: Mapping of name -> tensor. It is never modified.
        target_size_gb: Maximum allowed size in GiB; ``None`` skips the check.

    Returns:
        The measured size in bytes.

    Raises:
        RuntimeError: if the payload exceeds the target. Shrinking tensors by
            dropping elements would corrupt the weights, so exceeding the cap
            is reported rather than "fixed" here.
    """
    current_size_bytes = sum(
        tensor.numel() * tensor.element_size() for tensor in merged.values()
    )
    if target_size_gb is None:
        return current_size_bytes
    target_size_bytes = target_size_gb * 1024**3
    if current_size_bytes > target_size_bytes:
        excess_mb = (current_size_bytes - target_size_bytes) / 1024**2
        raise RuntimeError(
            f"Merged checkpoint is {current_size_bytes / 1024**3:.2f} GB, "
            f"exceeding the {target_size_gb} GB target by {excess_mb:.2f} MB"
        )
    return current_size_bytes
def cleanup_files(*file_paths):
    """Delete each path in *file_paths*, skipping ones that do not exist.

    Prints ``Deleted <path>`` for every file actually removed. Other OS errors
    (e.g. permission denied, path is a directory) propagate to the caller.
    """
    for file_path in file_paths:
        # EAFP: avoids the race between an existence check and the removal.
        try:
            os.remove(file_path)
        except FileNotFoundError:
            continue
        print(f"Deleted {file_path}")
if __name__ == "__main__":
    model1_path = "mangledMergeFlux_v0Bfloat16Dev.safetensors"
    model2_path = "output_checkpoint.safetensors"
    blend_ratio = 0.6  # weight given to model1 (60%)
    output_file = "output_checkpoint.safetensors"
    # model2_path and output_file name the same file. Save to a temporary name
    # and swap it in with os.replace, so the input is never overwritten in
    # place and a failed save leaves the previous checkpoint intact.
    tmp_output_file = output_file + ".tmp"
    try:
        # Loading models
        model1 = load_model(model1_path)
        model2 = load_model(model2_path)
        # Merging models
        merged_model = merge_checkpoints(model1, model2, blend_ratio)
        # Drop the input dicts so only tensors still used by the merge stay alive.
        del model1, model2
        # Saving merged model
        save_model(merged_model, tmp_output_file)
        os.replace(tmp_output_file, output_file)
        # NOTE(review): this deletes the base model input file (this script did
        # not download it) — confirm that removing it is intended.
        cleanup_files(model1_path)
    except Exception as e:
        print(f"An error occurred: {e}")
        if os.path.exists(tmp_output_file):
            os.remove(tmp_output_file)
        # Exit non-zero so shells/CI see the failure instead of a silent success.
        raise SystemExit(1) from e