"""Blend two safetensors checkpoints into one weighted-average checkpoint."""

import os

import requests
import torch
import torch.nn.functional as F
from safetensors.torch import load_file, save_file
from tqdm import tqdm

# Drop any cached CUDA allocator memory before the large checkpoints are loaded.
torch.cuda.empty_cache()

# Per-socket-operation timeout (seconds) so a stalled server cannot hang the
# script forever; it does not limit the total download time.
_DOWNLOAD_TIMEOUT_SECONDS = 30
# Checkpoints are multi-gigabyte files; stream them in 1 MiB chunks.
_DOWNLOAD_CHUNK_BYTES = 1024 * 1024


def download_file(url, dest_path):
    """Stream the resource at ``url`` into the file ``dest_path``.

    Args:
        url: HTTP(S) address of the file to fetch.
        dest_path: Local path the response body is written to.

    Raises:
        RuntimeError: If the server answers with a status other than 200.
    """
    print(f"Downloading {url} to {dest_path}")
    # The context manager returns the connection to the pool even on errors.
    with requests.get(
        url, stream=True, timeout=_DOWNLOAD_TIMEOUT_SECONDS
    ) as response:
        if response.status_code != 200:
            raise RuntimeError(f"Failed to download file from {url}")
        with open(dest_path, 'wb') as f:
            for chunk in response.iter_content(_DOWNLOAD_CHUNK_BYTES):
                f.write(chunk)


def load_model(file_path):
    """Return the tensors stored in the safetensors file at ``file_path``."""
    return load_file(file_path)


def save_model(merged_model, output_file):
    """Write the ``merged_model`` tensor dict to ``output_file`` as safetensors."""
    print(f"Saving merged model to {output_file}")
    save_file(merged_model, output_file)


def _zero_pad_to(tensor, target_shape):
    """Zero-pad ``tensor`` at the end of each dimension up to ``target_shape``."""
    pad = []
    # F.pad expects (left, right) pairs ordered from the LAST dimension backwards.
    for current, target in zip(reversed(tensor.shape), reversed(target_shape)):
        pad.extend((0, target - current))
    return F.pad(tensor, pad)


def resize_tensor_shapes(tensor1, tensor2):
    """Zero-pad two tensors so they share the same shape.

    Every dimension is grown to the larger of the two sizes; existing values
    keep their positions and the new trailing entries are zero.

    Args:
        tensor1: First tensor.
        tensor2: Second tensor; must have the same number of dimensions.

    Returns:
        A ``(tensor1, tensor2)`` tuple of equally shaped tensors. The inputs
        are returned unchanged when their shapes already match.

    Raises:
        ValueError: If the tensors have a different number of dimensions, in
            which case there is no meaningful element-wise alignment.
    """
    if tensor1.shape == tensor2.shape:
        return tensor1, tensor2

    if tensor1.dim() != tensor2.dim():
        raise ValueError(
            "Cannot align tensors with different ranks: "
            f"{tuple(tensor1.shape)} vs {tuple(tensor2.shape)}"
        )

    target_shape = [max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape)]
    return _zero_pad_to(tensor1, target_shape), _zero_pad_to(tensor2, target_shape)


def merge_checkpoints(ckpt1, ckpt2, blend_ratio=0.6, target_size_gb=26):
    """Blend two checkpoints key by key.

    Keys present in both checkpoints are combined as
    ``blend_ratio * t1 + (1 - blend_ratio) * t2`` (after shape alignment);
    keys present in only one checkpoint are copied over unchanged.

    Args:
        ckpt1: Mapping of tensor name to tensor; weighted by ``blend_ratio``.
        ckpt2: Mapping of tensor name to tensor; weighted by
            ``1 - blend_ratio``.
        blend_ratio: Weight given to ``ckpt1``.
        target_size_gb: Upper bound (GiB) enforced on the merged tensors via
            ``control_output_size``. Pass ``None`` to skip the size cap.

    Returns:
        A new dict holding the merged tensors.
    """
    print(f"Merging checkpoints with blend ratio: {blend_ratio}")
    merged = {}

    # Keep a deterministic order (ckpt1's keys, then ckpt2-only keys) so that
    # repeated runs produce the same layout; a set would shuffle the keys.
    all_keys = list(ckpt1)
    all_keys.extend(key for key in ckpt2 if key not in ckpt1)

    for key in tqdm(all_keys, desc="Merging Checkpoints", unit="layer"):
        t1, t2 = ckpt1.get(key), ckpt2.get(key)
        if t1 is not None and t2 is not None:
            t1, t2 = resize_tensor_shapes(t1, t2)
            merged[key] = blend_ratio * t1 + (1 - blend_ratio) * t2
        elif t1 is not None:
            merged[key] = t1
        else:
            merged[key] = t2

    if target_size_gb is not None:
        control_output_size(merged, target_size_gb=target_size_gb)

    return merged


def control_output_size(merged, target_size_gb):
    """Trim ``merged`` in place so its tensor data fits in ``target_size_gb``.

    When the total tensor payload exceeds the target, trailing elements are
    dropped greedily, tensor by tensor in dict order, until the excess is
    gone. Nothing happens when the payload already fits.

    NOTE: a trimmed tensor cannot keep its original shape (it has fewer
    elements), so it is stored as a shortened 1-D tensor. A checkpoint that
    needed trimming will therefore no longer match the original architecture.

    Args:
        merged: Mapping of tensor name to tensor; modified in place.
        target_size_gb: Maximum total size of the tensor data, in GiB.
    """
    target_size_bytes = int(target_size_gb * 1024**3)
    current_size_bytes = sum(
        tensor.numel() * tensor.element_size() for tensor in merged.values()
    )
    if current_size_bytes <= target_size_bytes:
        return

    excess_size = current_size_bytes - target_size_bytes
    print(f"Current size exceeds target by {excess_size / (1024**2):.2f} MB. Adjusting...")

    # Only values are reassigned, but iterate over a snapshot of the keys anyway.
    for key in list(merged):
        if excess_size <= 0:
            break

        tensor = merged[key]
        num_elements = tensor.numel()
        # Use the tensor's real element width (e.g. 2 bytes for bfloat16)
        # rather than assuming float32.
        bytes_per_element = tensor.element_size()
        # Ceiling division: drop enough whole elements to cover the excess.
        needed = -(-excess_size // bytes_per_element)
        reduction = min(needed, num_elements)
        if reduction == 0:
            continue

        # clone() detaches the result from the full-size storage so the
        # dropped elements can actually be freed.
        merged[key] = tensor.flatten()[:num_elements - reduction].clone()
        excess_size -= reduction * bytes_per_element


def cleanup_files(*file_paths):
    """Delete each of ``file_paths`` that exists, reporting every deletion."""
    for file_path in file_paths:
        try:
            os.remove(file_path)
        except FileNotFoundError:
            # Already gone: nothing to clean up for this path.
            continue
        print(f"Deleted {file_path}")


def main():
    """Merge the two configured checkpoints and remove the consumed input."""
    model1_path = "mangledMergeFlux_v0Bfloat16Dev.safetensors"
    model2_path = "output_checkpoint.safetensors"
    blend_ratio = 0.6  # Set to 60%
    # NOTE(review): identical to model2_path, so the second input is
    # overwritten by the result -- presumably intentional (running merge);
    # confirm before changing.
    output_file = "output_checkpoint.safetensors"

    # Loading models
    model1 = load_model(model1_path)
    model2 = load_model(model2_path)

    # Merging models
    merged_model = merge_checkpoints(model1, model2, blend_ratio)

    # The source dicts are no longer needed; drop them to lower peak memory
    # while the merged checkpoint is serialized.
    del model1, model2

    # Saving merged model
    save_model(merged_model, output_file)

    # Cleaning up downloaded files (only reached after a successful save)
    cleanup_files(model1_path)


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print(f"An error occurred: {e}")
        # Report the failure to the calling shell instead of exiting with 0.
        raise SystemExit(1) from e