feat: main feature
app.py
ADDED
@@ -0,0 +1,99 @@
from __future__ import annotations

import shutil
import subprocess
from pathlib import Path
from textwrap import dedent

import torch
import streamlit as st
import numpy as np
from PIL import Image
from transformers import CLIPTokenizer


def hex_to_rgb(s: str) -> tuple[int, int, int]:
    value = s.lstrip("#")
    return (int(value[:2], 16), int(value[2:4], 16), int(value[4:6], 16))


col1, col2 = st.columns([15, 85])
color = col1.color_picker("Pick a color", "#00f900")
col2.text_input("", color, disabled=True)

emb_name = st.text_input("Embedding name", color.lstrip("#").upper())
rgb = hex_to_rgb(color)

img_array = np.zeros((128, 128, 3), dtype=np.uint8)
for i in range(3):
    img_array[..., i] = rgb[i]

dataset_path = Path("dataset")
output_path = Path("output")
if dataset_path.exists():
    shutil.rmtree(dataset_path)
if output_path.exists():
    shutil.rmtree(output_path)

dataset_path.mkdir()
img_path = dataset_path / f"{emb_name}.png"
Image.fromarray(img_array).save(img_path)
tokenizer = CLIPTokenizer.from_pretrained(
    "Linaqruf/anything-v3.0", subfolder="tokenizer"
)

with st.sidebar:
    init_text = st.text_input("Initializer", "init token name")
    steps = st.slider("Steps", 1, 100, 30, step=1)
    learning_rate = st.text_input("Learning rate", "0.005")
    learning_rate = float(learning_rate)

# case 1: init_text is not a single token
token = tokenizer.tokenize(init_text)
if len(token) > 1:
    st.warning("init_text must be a single token")
    st.stop()

# case 2: emb_name already exists in the tokenizer
num_added_tokens = tokenizer.add_tokens(emb_name)
if num_added_tokens == 0:
    st.warning(f"The tokenizer already contains the token {emb_name}")
    st.stop()

cmd = """
accelerate launch textual_inversion.py \
  --pretrained_model_name_or_path="Linaqruf/anything-v3.0" \
  --train_data_dir="dataset" \
  --learnable_property="style" \
  --placeholder_token="{emb_name}" \
  --initializer_token="{init}" \
  --resolution=128 \
  --train_batch_size=1 \
  --repeats=1 \
  --gradient_accumulation_steps=1 \
  --max_train_steps={steps} \
  --learning_rate={lr} \
  --output_dir="output" \
  --only_save_embeds
""".strip()

cmd = dedent(cmd).format(
    emb_name=emb_name, init=init_text, lr=learning_rate, steps=steps
)

if st.button("Start"):
    with st.spinner("Training..."):
        subprocess.run(cmd, shell=True)

    result_path = Path("output") / "learned_embeds.bin"
    if not result_path.exists():
        st.stop()

    # fix unknown error: round-trip each tensor through NumPy before
    # re-saving so the downloaded file loads cleanly
    trained_emb = torch.load(result_path, map_location="cpu")
    for k, v in trained_emb.items():
        trained_emb[k] = torch.from_numpy(v.numpy())
    torch.save(trained_emb, result_path)

    file = result_path.read_bytes()
    st.download_button("Download", file, f"{emb_name}.pt")
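Usage note (not part of the commit): the downloaded .pt file is a textual-inversion embedding whose placeholder token is the embedding name chosen in the UI. The sketch below is an assumption: it presumes a diffusers release that provides load_textual_inversion (TextualInversionLoaderMixin), and "00F900.pt" / "00F900" stand in for whatever name the app produced.

# Hypothetical usage sketch, not code from this Space.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "Linaqruf/anything-v3.0", torch_dtype=torch.float16
).to("cuda")

# "00F900.pt" is an example file name from the app's download button; the
# placeholder token stored inside it is the embedding name chosen in the UI.
pipe.load_textual_inversion("00F900.pt")

# Prompt with the placeholder token to apply the learned color style.
image = pipe("a flat poster in the style of 00F900").images[0]
image.save("poster.png")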