Spaces:
Running
Running
import gradio as gr | |
from diffusers import DiffusionPipeline | |
from huggingface_hub import export_folder_as_dduf, create_repo, upload_file | |
import tempfile | |
import torch | |
import os | |
# Maps the dtype names offered in the UI dropdown to the torch dtype passed to
# `DiffusionPipeline.from_pretrained`.
_DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

# Markdown rendered below the Gradio interface. Blank lines separate the bullet
# from the following paragraphs so Markdown does not fold them together.
article = """
## DDUF

* [DDUF](https://huggingface.co/docs/diffusers/main/en/using-diffusers/other-formats#dduf)

Currently, we require a `repo_id` to have all the pipeline components in the Diffusers format. Examples include:
[black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev), [stabilityai/stable-video-diffusion-img2vid-xt](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt), etc.

Partial components will be supported in the future.
"""
def make_dduf(repo_id: str, destination_repo_id: str, dduf_filename: str, token: str, torch_dtype: str):
    """Convert a Diffusers-format Hub repository into a single DDUF file and push it.

    The pipeline at ``repo_id`` is loaded with the requested dtype, re-serialized
    with safetensors into a temporary directory, packed into a ``.dduf`` archive,
    and uploaded to ``destination_repo_id``.

    Args:
        repo_id: Hub repository holding all pipeline components in Diffusers format.
        destination_repo_id: Repository that receives the DDUF file. Falls back to
            ``repo_id`` when empty. It is created (``exist_ok=True``) if missing.
        dduf_filename: Name of the DDUF file in the destination repo. When empty,
            it is inferred from the pipeline class name, with ``_<dtype>``
            appended for non-fp32 dtypes. Only the final path component of a
            user-supplied name is used.
        token: Hugging Face token used to read ``repo_id`` and to create and
            write to the destination repo.
        torch_dtype: One of the keys of ``_DTYPE_MAP`` ("fp32", "fp16", "bf16").

    Returns:
        A Markdown string. On success it links to the upload commit. On failure
        it describes the first step that failed; later steps are skipped.
    """
    # Gradio may hand back None for an untouched textbox, so treat None like "".
    if not destination_repo_id:
        destination_repo_id = repo_id
    try:
        # Pass the caller's token so the repo is created with the same
        # credentials that are used for the upload below.
        destination_repo_id = create_repo(repo_id=destination_repo_id, exist_ok=True, token=token).repo_id
    except Exception as e:
        return f"❌ Got the following error while creating the repository: \n{e}"

    try:
        # The token is needed for gated/private source repos.
        pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=_DTYPE_MAP[torch_dtype], token=token)
    except Exception as e:
        return f"❌ Got the following error while loading the pipeline: \n{e}"

    # Keep only the final path component of a user-supplied name. It is joined
    # onto a temp dir below, and an absolute path would otherwise escape it.
    dduf_filename = os.path.basename((dduf_filename or "").strip())
    if dduf_filename in ("", os.curdir, os.pardir):
        dduf_filename = pipe.__class__.__name__.lower()
        if torch_dtype != "fp32":
            dduf_filename += f"_{torch_dtype}"
        dduf_filename += ".dduf"

    # Separate directories: the archive is never written inside the folder
    # that is being archived.
    with tempfile.TemporaryDirectory() as pipeline_dir, tempfile.TemporaryDirectory() as output_dir:
        dduf_path = os.path.join(output_dir, dduf_filename)
        try:
            pipe.save_pretrained(pipeline_dir, safe_serialization=True)
            export_folder_as_dduf(dduf_path, folder_path=pipeline_dir)
        except Exception as e:
            return f"❌ Got the following error while exporting: \n{e}"
        try:
            commit_url = upload_file(
                repo_id=destination_repo_id,
                path_in_repo=dduf_filename,
                path_or_fileobj=dduf_path,
                token=token,
            ).commit_url
        except Exception as e:
            return f"❌ Got the following error while pushing: \n{e}"
    return f"Success 🔥. Find the DDUF in [this commit]({commit_url})."
# Gradio UI wiring: the five inputs map positionally onto make_dduf's
# (repo_id, destination_repo_id, dduf_filename, token, torch_dtype).
demo = gr.Interface(
    title="DDUF my repo 🤗",
    article=article,
    fn=make_dduf,
    inputs=[
        gr.Textbox(
            lines=1,
            label="Repo ID",
            placeholder="Repo ID which should be DDUF'd.",
        ),
        gr.Textbox(
            # A repo id is a single line; a multi-line box lets a newline slip in.
            lines=1,
            value=None,
            label="Destination Repo ID",
            placeholder="Destination Repo ID that should be used to store the resultant DDUF. Leave it if you want to use the `repo_id` here.",
        ),
        gr.Textbox(
            lines=1,
            value=None,
            label="DDUF file name",
            placeholder="Name of the DDUF file. If it's not provided we will infer it based on the pipeline class.",
        ),
        gr.Textbox(
            # Mask the secret so it is not displayed in plain text on screen.
            lines=1,
            type="password",
            label="HF token",
            placeholder="HF token. You can obtain it from hf.co/settings/tokens.",
        ),
        gr.Dropdown(
            list(_DTYPE_MAP.keys()),
            value="fp32",
            multiselect=False,
            label="dtype",
            info="dtype to load the pipeline in.",
        ),
    ],
    outputs="markdown",
    examples=[
        ["stabilityai/stable-video-diffusion-img2vid-xt", "sayakpaul/svd-dduf", "svd.dduf", "hf_XXX", "fp32"],
    ],
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch(debug=True, show_error=True)