from __future__ import annotations

import shutil
import subprocess
from pathlib import Path
from textwrap import dedent

import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import CLIPTokenizer


def hex_to_rgb(s: str) -> tuple[int, int, int]:
    """Convert a "#rrggbb" hex string to an (r, g, b) tuple of ints."""
    value = s.lstrip("#")
    return (int(value[:2], 16), int(value[2:4], 16), int(value[4:6], 16))
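

# UI: a color picker, a read-only display of the chosen hex code, and a
# text field for the embedding name (defaults to the hex digits, uppercased).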
col1, col2 = st.columns([15, 85])
color = col1.color_picker("Pick a color", "#00f900")
col2.text_input("", color, disabled=True)
emb_name = st.text_input("Embedding name", color.lstrip("#").upper())
rgb = hex_to_rgb(color)
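
# Fill a 128x128 RGB image with the picked color; this single image is the
# entire training set for the embedding.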
img_array = np.zeros((128, 128, 3), dtype=np.uint8)
for i in range(3):
    img_array[..., i] = rgb[i]
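
# Recreate the dataset and output directories from scratch so stale files
# from a previous run are never picked up.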
dataset_path = Path("dataset")
output_path = Path("output")
if dataset_path.exists():
    shutil.rmtree(dataset_path)
if output_path.exists():
    shutil.rmtree(output_path)
dataset_path.mkdir()
img_path = dataset_path / f"{emb_name}.png"
Image.fromarray(img_array).save(img_path)
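
# Only the tokenizer of the base model is needed up front, to validate the
# initializer and placeholder tokens before launching training.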
tokenizer = CLIPTokenizer.from_pretrained(
    "Linaqruf/anything-v3.0", subfolder="tokenizer"
)
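
# Sidebar: training hyperparameters.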
with st.sidebar:
    init_text = st.text_input("Initializer", "init token name")
    steps = st.slider("Steps", 1, 100, 30, step=1)
    # Note: float() raises ValueError if the field is not numeric.
    learning_rate = float(st.text_input("Learning rate", "0.005"))

# case 1: init_text is not a single token
tokens = tokenizer.tokenize(init_text)
if len(tokens) > 1:
    st.warning("init_text must be a single token")
    st.stop()

# case 2: the placeholder token emb_name already exists in the tokenizer
num_added_tokens = tokenizer.add_tokens(emb_name)
if num_added_tokens == 0:
    st.warning(f"The tokenizer already contains the token {emb_name}")
    st.stop()
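
# Assemble the textual_inversion.py launch command; {emb_name}, {init},
# {steps}, and {lr} are filled in with .format() below.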
cmd = """ | |
accelerate launch textual_inversion.py \ | |
--pretrained_model_name_or_path="Linaqruf/anything-v3.0" \ | |
--train_data_dir="dataset" \ | |
--learnable_property="style" \ | |
--placeholder_token="{emb_name}" \ | |
--initializer_token="{init}" \ | |
--resolution=128 \ | |
--train_batch_size=1 \ | |
--repeats=1 \ | |
--gradient_accumulation_steps=1 \ | |
--max_train_steps={steps} \ | |
--learning_rate={lr} \ | |
--output_dir="output" \ | |
--only_save_embeds | |
""".strip() | |
cmd = dedent(cmd).format( | |
emb_name=emb_name, init=init_text, lr=learning_rate, steps=steps | |
) | |
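
# Run training on demand; once it finishes, the learned embedding is
# post-processed and offered as a .pt download.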
if st.button("Start"):
    with st.spinner("Training..."):
        subprocess.run(cmd, shell=True)
    result_path = output_path / "learned_embeds.bin"
    if not result_path.exists():
        st.stop()
    # Round-trip each tensor through NumPy so the file contains plain,
    # detached CPU tensors; this works around an otherwise unexplained
    # error when the saved embedding is loaded elsewhere.
    trained_emb = torch.load(result_path, map_location="cpu")
    for k, v in trained_emb.items():
        trained_emb[k] = torch.from_numpy(v.detach().numpy())
    torch.save(trained_emb, result_path)
    file = result_path.read_bytes()
    st.download_button("Download", file, f"{emb_name}.pt")