|
import torch |
|
from safetensors.torch import save_file, load_file |
|
import gradio as gr |
|
import os |
|
|
|
def convert_embedding(uploaded_file):
    """Convert an SD1.5 embedding file into an SDXL-style safetensors file.

    The SD1.5 vectors are copied into the leading columns of a ``clip_l``
    tensor (768 columns) and paired with an all-zero ``clip_g`` tensor
    (1280 columns) with the same number of rows. Both tensors are stored as
    float16 in ``embedding.safetensors`` in the current working directory.

    Args:
        uploaded_file: The embedding to convert. Either an object exposing
            the file path as ``.name`` or a plain path (``str`` or
            ``os.PathLike``). A ``.pt`` file must contain
            ``['string_to_param']['*']``; a ``.safetensors`` file must
            contain an ``emb_params`` tensor.

    Returns:
        The path of the written safetensors file.

    Raises:
        ValueError: If nothing was uploaded, the extension is neither
            ``.pt`` nor ``.safetensors``, or the embedding tensor is not
            2-D with at most 768 columns.
        KeyError: If the expected keys are missing from the loaded file.
    """
    if uploaded_file is None:
        raise ValueError("No file uploaded")

    output_path = "embedding.safetensors"
    clip_g_dim = 1280
    clip_l_dim = 768

    # Accept both path-like values and upload objects carrying the path in
    # ``.name``. Path-likes are checked first because ``pathlib.Path.name``
    # is only the basename, not the full path.
    if isinstance(uploaded_file, (str, os.PathLike)):
        source_path = os.fspath(uploaded_file)
    else:
        source_path = uploaded_file.name

    # Compare case-insensitively so ``.PT`` / ``.SAFETENSORS`` also work.
    file_extension = os.path.splitext(source_path)[1].lower()

    if file_extension == '.pt':
        # weights_only=True keeps torch.load from unpickling arbitrary
        # objects out of an untrusted upload.
        sd15_embedding = torch.load(
            source_path,
            map_location=torch.device('cpu'),
            weights_only=True,
        )
        sd15_tensor = sd15_embedding['string_to_param']['*']
    elif file_extension == '.safetensors':
        loaded_tensors = load_file(source_path)
        sd15_tensor = loaded_tensors['emb_params']
    else:
        raise ValueError("Unsupported file format")

    # Validate up front: a wider or non-2-D tensor would otherwise fail with
    # an opaque shape-mismatch error in the slice assignment below.
    if sd15_tensor.ndim != 2:
        raise ValueError(
            f"Expected a 2-D embedding tensor, got {sd15_tensor.ndim} dimension(s)"
        )
    num_vectors, embedding_dim = sd15_tensor.shape
    if embedding_dim > clip_l_dim:
        raise ValueError(
            f"Embedding width {embedding_dim} exceeds the clip_l width {clip_l_dim}"
        )

    clip_g = torch.zeros((num_vectors, clip_g_dim), dtype=torch.float16)
    clip_l = torch.zeros((num_vectors, clip_l_dim), dtype=torch.float16)
    # Narrower embeddings fill the leading columns; the remainder stays zero.
    clip_l[:, :embedding_dim] = sd15_tensor.to(dtype=torch.float16)

    save_file({"clip_g": clip_g, "clip_l": clip_l}, output_path)

    return output_path
|
|
|
# Web UI wiring: one file upload in, one downloadable file out, with
# convert_embedding doing the conversion in between.
iface = gr.Interface(
    title="SD1.5 to SDXL Embedding Converter",
    description="Upload an SD1.5 embedding file to convert it to SDXL.",
    fn=convert_embedding,
    inputs=gr.File(label="Upload SD1.5 embedding"),
    outputs=gr.File(label="Download converted SDXL safetensors embedding"),
)


# Start the Gradio server only when this file is run as a script, so that
# importing the module does not launch the app.
if __name__ == "__main__":
    iface.launch()
|
|
|
|