ginipick commited on
Commit
036dfc6
·
verified ·
1 Parent(s): f416411

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -260
app.py CHANGED
@@ -1,295 +1,226 @@
1
- import spaces
2
- import argparse
3
  import os
4
- import time
5
- from os import path
6
- import shutil
7
  from datetime import datetime
8
- from safetensors.torch import load_file
9
- from huggingface_hub import hf_hub_download
10
  import gradio as gr
 
 
11
  import torch
12
- from diffusers import FluxPipeline
13
  from PIL import Image
14
- from transformers import pipeline
15
-
16
- translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
17
-
18
- # Hugging Face 토큰 설정
19
- HF_TOKEN = os.getenv("HF_TOKEN")
20
- if HF_TOKEN is None:
21
- raise ValueError("HF_TOKEN environment variable is not set")
22
-
23
- # Setup and initialization code
24
- cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
25
- PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
26
- gallery_path = path.join(PERSISTENT_DIR, "gallery")
27
 
28
- os.environ["TRANSFORMERS_CACHE"] = cache_path
29
- os.environ["HF_HUB_CACHE"] = cache_path
30
- os.environ["HF_HOME"] = cache_path
 
31
 
32
- torch.backends.cuda.matmul.allow_tf32 = True
 
 
33
 
34
- # Create gallery directory if it doesn't exist
35
- if not path.exists(gallery_path):
36
- os.makedirs(gallery_path, exist_ok=True)
37
 
38
- class timer:
39
- def __init__(self, method_name="timed process"):
40
- self.method = method_name
41
- def __enter__(self):
42
- self.start = time.time()
43
- print(f"{self.method} starts")
44
- def __exit__(self, exc_type, exc_val, exc_tb):
45
- end = time.time()
46
- print(f"{self.method} took {str(round(end - self.start, 2))}s")
47
 
48
- # Model initialization
49
- if not path.exists(cache_path):
50
- os.makedirs(cache_path, exist_ok=True)
51
-
52
- # 인증된 모델 로드
53
- pipe = FluxPipeline.from_pretrained(
54
- "black-forest-labs/FLUX.1-dev",
55
- torch_dtype=torch.bfloat16,
56
- use_auth_token=HF_TOKEN
57
- )
58
-
59
- # Hyper-SD LoRA 로드
60
- pipe.load_lora_weights(
61
- hf_hub_download(
62
- "ByteDance/Hyper-SD",
63
- "Hyper-FLUX.1-dev-8steps-lora.safetensors",
64
- use_auth_token=HF_TOKEN
65
- )
66
- )
67
- pipe.fuse_lora(lora_scale=0.125)
68
- pipe.to(device="cuda", dtype=torch.bfloat16)
69
 
70
- def save_image(image):
71
- """Save the generated image and return the path"""
72
- try:
73
- if not os.path.exists(gallery_path):
74
- try:
75
- os.makedirs(gallery_path, exist_ok=True)
76
- except Exception as e:
77
- print(f"Failed to create gallery directory: {str(e)}")
78
- return None
79
-
80
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
81
- random_suffix = os.urandom(4).hex()
82
- filename = f"generated_{timestamp}_{random_suffix}.png"
83
- filepath = os.path.join(gallery_path, filename)
84
-
85
- try:
86
- if isinstance(image, Image.Image):
87
- image.save(filepath, "PNG", quality=100)
88
- else:
89
- image = Image.fromarray(image)
90
- image.save(filepath, "PNG", quality=100)
91
-
92
- return filepath
93
- except Exception as e:
94
- print(f"Failed to save image: {str(e)}")
95
- return None
96
-
97
- except Exception as e:
98
- print(f"Error in save_image: {str(e)}")
99
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- # 예시 프롬프트 정의
102
  examples = [
103
- ["A 3D Star Wars Darth Vader helmet, highly detailed metallic finish"],
104
- ["A 3D Iron Man mask with glowing eyes and metallic red-gold finish"],
105
- ["A detailed 3D Pokemon Pikachu figure with glossy surface"],
106
- ["A 3D geometric abstract cube transforming into a sphere, metallic finish"],
107
- ["A 3D steampunk mechanical heart with brass and copper details"],
108
- ["A 3D crystal dragon with transparent iridescent scales"],
109
- ["A 3D futuristic hovering drone with neon light accents"],
110
- ["A 3D ancient Greek warrior helmet with ornate details"],
111
- ["A 3D robotic butterfly with mechanical wings and metallic finish"],
112
- ["A 3D floating magical crystal orb with internal energy swirls"]
113
- ]
114
-
115
- @spaces.GPU
116
- def process_and_save_image(height=1024, width=1024, steps=8, scales=3.5, prompt="", seed=None):
117
- global pipe
118
 
119
- if seed is None:
120
- seed = torch.randint(0, 1000000, (1,)).item()
121
 
122
- # 한글 감지 번역
123
- def contains_korean(text):
124
- return any(ord('가') <= ord(c) <= ord('힣') for c in text)
125
 
126
- # 프롬프트 전처리
127
- if contains_korean(prompt):
128
- translated = translator(prompt)[0]['translation_text']
129
- prompt = translated
130
 
131
- formatted_prompt = f"wbgmsst, 3D, {prompt} ,white background"
132
 
133
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
134
- try:
135
- generated_image = pipe(
136
- prompt=[formatted_prompt],
137
- generator=torch.Generator().manual_seed(int(seed)),
138
- num_inference_steps=int(steps),
139
- guidance_scale=float(scales),
140
- height=int(height),
141
- width=int(width),
142
- max_sequence_length=256
143
- ).images[0]
144
-
145
- saved_path = save_image(generated_image)
146
- if saved_path is None:
147
- print("Warning: Failed to save generated image")
148
-
149
- return generated_image
150
- except Exception as e:
151
- print(f"Error in image generation: {str(e)}")
152
- return None
153
 
154
- def get_random_seed():
155
- return torch.randint(0, 1000000, (1,)).item()
 
 
 
156
 
157
-
158
- def process_example(prompt):
159
- return process_and_save_image(
160
- height=1024,
161
- width=1024,
162
- steps=8,
163
- scales=3.5,
164
- prompt=prompt,
165
- seed=get_random_seed()
166
- )
167
 
168
-
169
- # Gradio 인터페이스
170
- with gr.Blocks(
171
- theme=gr.themes.Soft(),
172
- css="""
173
- .container {
174
- background: linear-gradient(to bottom right, #1a1a1a, #4a4a4a);
175
- border-radius: 20px;
176
- padding: 20px;
177
- }
178
- .generate-btn {
179
- background: linear-gradient(45deg, #2196F3, #00BCD4);
180
- border: none;
181
- color: white;
182
- font-weight: bold;
183
- border-radius: 10px;
184
- }
185
- .output-image {
186
- border-radius: 15px;
187
- box-shadow: 0 8px 16px rgba(0,0,0,0.2);
188
- }
189
- .fixed-width {
190
- max-width: 1024px;
191
- margin: auto;
192
- }
193
- """
194
- ) as demo:
195
- gr.HTML(
196
- """
197
- <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
198
- <h1 style="font-size: 2.5rem; color: #2196F3;">3D Style Image Generator</h1>
199
- <p style="font-size: 1.2rem; color: #666;">Create amazing 3D-style images with AI</p>
200
- </div>
201
- """
202
- )
203
-
204
- with gr.Row(elem_classes="container"):
205
- with gr.Column(scale=3):
206
- prompt = gr.Textbox(
207
- label="Image Description",
208
- placeholder="Describe the 3D image you want to create...",
209
- lines=3
210
- )
211
-
212
- with gr.Accordion("Advanced Settings", open=False):
213
  with gr.Row():
214
- height = gr.Slider(
215
- label="Height",
216
- minimum=256,
217
- maximum=1152,
218
- step=64,
219
- value=1024
220
- )
221
- width = gr.Slider(
222
- label="Width",
223
- minimum=256,
224
- maximum=1152,
225
- step=64,
226
- value=1024
227
  )
228
-
229
- with gr.Row():
230
- steps = gr.Slider(
231
- label="Inference Steps",
232
- minimum=6,
233
- maximum=25,
 
 
 
234
  step=1,
235
- value=8
236
- )
237
- scales = gr.Slider(
238
- label="Guidance Scale",
239
- minimum=0.0,
240
- maximum=5.0,
241
- step=0.1,
242
- value=3.5
243
  )
244
-
245
- seed = gr.Number(
246
- label="Seed (random by default, set for reproducibility)",
247
- value=get_random_seed(),
248
- precision=0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  )
250
-
251
- randomize_seed = gr.Button("🎲 Randomize Seed", elem_classes=["generate-btn"])
252
-
253
- generate_btn = gr.Button(
254
- "✨ Generate Image",
255
- elem_classes=["generate-btn"]
256
- )
257
 
258
- with gr.Column(scale=4, elem_classes=["fixed-width"]):
259
- output = gr.Image(
260
- label="Generated Image",
261
- elem_id="output-image",
262
- elem_classes=["output-image", "fixed-width"],
263
- value="3d.webp"
 
 
 
264
  )
 
265
 
266
- # Examples 섹션
267
- gr.Examples(
268
- examples=examples,
269
- inputs=prompt,
270
- outputs=output,
271
- fn=process_example, # 수정된 함수 사용
272
- cache_examples=False,
273
- examples_per_page=5
274
- )
275
 
276
- def update_seed():
277
- return get_random_seed()
 
278
 
279
- # 이벤트 핸들러
280
- generate_btn.click(
281
- process_and_save_image,
282
- inputs=[height, width, steps, scales, prompt, seed],
283
- outputs=output
284
- ).then(
285
- update_seed,
286
- outputs=[seed]
287
  )
288
 
289
- randomize_seed.click(
290
- update_seed,
291
- outputs=[seed]
 
 
 
 
 
 
 
 
 
 
 
292
  )
293
 
294
- if __name__ == "__main__":
295
- demo.launch(allowed_paths=[PERSISTENT_DIR])
 
1
+ import random
 
2
  import os
3
+ import uuid
 
 
4
  from datetime import datetime
 
 
5
  import gradio as gr
6
+ import numpy as np
7
+ import spaces
8
  import torch
9
+ from diffusers import DiffusionPipeline
10
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # Create permanent storage directory
13
+ SAVE_DIR = "saved_images" # Gradio will handle the persistence
14
+ if not os.path.exists(SAVE_DIR):
15
+ os.makedirs(SAVE_DIR, exist_ok=True)
16
 
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ repo_id = "black-forest-labs/FLUX.1-dev"
19
+ adapter_id = "ginipick/flux-lora-eric-cat"
20
 
21
+ pipeline = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
22
+ pipeline.load_lora_weights(adapter_id)
23
+ pipeline = pipeline.to(device)
24
 
25
+ MAX_SEED = np.iinfo(np.int32).max
26
+ MAX_IMAGE_SIZE = 1024
 
 
 
 
 
 
 
27
 
28
+ def save_generated_image(image, prompt):
29
+ # Generate unique filename with timestamp
30
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
31
+ unique_id = str(uuid.uuid4())[:8]
32
+ filename = f"{timestamp}_{unique_id}.png"
33
+ filepath = os.path.join(SAVE_DIR, filename)
34
+
35
+ # Save the image
36
+ image.save(filepath)
37
+
38
+ # Save metadata
39
+ metadata_file = os.path.join(SAVE_DIR, "metadata.txt")
40
+ with open(metadata_file, "a", encoding="utf-8") as f:
41
+ f.write(f"{filename}|{prompt}|{timestamp}\n")
42
+
43
+ return filepath
 
 
 
 
 
44
 
45
+ def load_generated_images():
46
+ if not os.path.exists(SAVE_DIR):
47
+ return []
48
+
49
+ # Load all images from the directory
50
+ image_files = [os.path.join(SAVE_DIR, f) for f in os.listdir(SAVE_DIR)
51
+ if f.endswith(('.png', '.jpg', '.jpeg', '.webp'))]
52
+ # Sort by creation time (newest first)
53
+ image_files.sort(key=lambda x: os.path.getctime(x), reverse=True)
54
+ return image_files
55
+
56
+ def load_predefined_images():
57
+ # Return empty list since we're not using predefined images
58
+ return []
59
+
60
+ @spaces.GPU(duration=120)
61
+ def inference(
62
+ prompt: str,
63
+ seed: int,
64
+ randomize_seed: bool,
65
+ width: int,
66
+ height: int,
67
+ guidance_scale: float,
68
+ num_inference_steps: int,
69
+ lora_scale: float,
70
+ progress: gr.Progress = gr.Progress(track_tqdm=True),
71
+ ):
72
+ if randomize_seed:
73
+ seed = random.randint(0, MAX_SEED)
74
+ generator = torch.Generator(device=device).manual_seed(seed)
75
+
76
+ image = pipeline(
77
+ prompt=prompt,
78
+ guidance_scale=guidance_scale,
79
+ num_inference_steps=num_inference_steps,
80
+ width=width,
81
+ height=height,
82
+ generator=generator,
83
+ joint_attention_kwargs={"scale": lora_scale},
84
+ ).images[0]
85
+
86
+ # Save the generated image
87
+ filepath = save_generated_image(image, prompt)
88
+
89
+ # Return the image, seed, and updated gallery
90
+ return image, seed, load_generated_images()
91
 
 
92
  examples = [
93
+ "A fluffy white Persian cat in a tiny royal hanbok, sitting regally on a velvet cushion. The cat wears a miniature golden crown and a silk hanbok in pastel pink and mint green. Its long fur is perfectly groomed, and it has bright blue eyes that sparkle with nobility. The background features traditional Korean palace decorations and cherry blossoms floating in the air. [trigger]",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ "An adorable orange tabby cat as a master chef, wearing a white chef's hat and apron, standing on its hind legs at a tiny kitchen counter. The cat is delicately holding a miniature rolling pin, preparing cat-shaped cookies. The cozy kitchen background features tiny cooking utensils and a warm, golden afternoon light streaming through the window. [trigger]",
 
96
 
97
+ "A playful black and white tuxedo cat as a professional photographer, wearing tiny round glasses and a vintage camera around its neck. The cat is perfectly balanced on a tree branch, attempting to photograph a butterfly. It wears a cute brown leather camera bag and a mini beret, looking artistic and focused. [trigger]",
 
 
98
 
99
+ "A sleepy Scottish Fold cat in astronaut gear, floating inside a spaceship cabin. The cat wears a custom-fit space suit with cute patches, gently batting at floating star-shaped toys. Through the spaceship window, Earth and twinkling stars create a magical cosmic background. [trigger]",
 
 
 
100
 
101
+ "A graceful Siamese ballet dancer cat in a sparkly pink tutu, performing a perfect pirouette on a miniature stage. The cat wears tiny satin ballet slippers on its paws and a crystal tiara. The stage is lit with soft spotlights, and rose petals are scattered around its dancing feet. [trigger]",
102
 
103
+ "A adventurous calico cat explorer in safari gear, riding on top of a friendly elephant. The cat wears a tiny khaki vest with many pockets, a safari hat, and carries a miniature map. The background shows a beautiful sunset over the African savanna with acacia trees and colorful birds flying overhead. [trigger]"
104
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
+ css = """
107
+ footer {
108
+ visibility: hidden;
109
+ }
110
+ """
111
 
112
+ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css, analytics_enabled=False) as demo:
113
+ gr.HTML('<div class="title"> First CAT of Huggingface </div>')
114
+ gr.HTML('<div class="title">😄Image to Video Explore: <a href="https://huggingface.co/spaces/ginigen/theater" target="_blank">https://huggingface.co/spaces/ginigen/theater</a></div>')
 
 
 
 
 
 
 
115
 
116
+ with gr.Tabs() as tabs:
117
+ with gr.Tab("Generation"):
118
+ with gr.Column(elem_id="col-container"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  with gr.Row():
120
+ prompt = gr.Text(
121
+ label="Prompt",
122
+ show_label=False,
123
+ max_lines=1,
124
+ placeholder="Enter your prompt",
125
+ container=False,
 
 
 
 
 
 
 
126
  )
127
+ run_button = gr.Button("Run", scale=0)
128
+
129
+ result = gr.Image(label="Result", show_label=False)
130
+
131
+ with gr.Accordion("Advanced Settings", open=False):
132
+ seed = gr.Slider(
133
+ label="Seed",
134
+ minimum=0,
135
+ maximum=MAX_SEED,
136
  step=1,
137
+ value=42,
 
 
 
 
 
 
 
138
  )
139
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
140
+
141
+ with gr.Row():
142
+ width = gr.Slider(
143
+ label="Width",
144
+ minimum=256,
145
+ maximum=MAX_IMAGE_SIZE,
146
+ step=32,
147
+ value=1024,
148
+ )
149
+ height = gr.Slider(
150
+ label="Height",
151
+ minimum=256,
152
+ maximum=MAX_IMAGE_SIZE,
153
+ step=32,
154
+ value=768,
155
+ )
156
+
157
+ with gr.Row():
158
+ guidance_scale = gr.Slider(
159
+ label="Guidance scale",
160
+ minimum=0.0,
161
+ maximum=10.0,
162
+ step=0.1,
163
+ value=3.5,
164
+ )
165
+ num_inference_steps = gr.Slider(
166
+ label="Number of inference steps",
167
+ minimum=1,
168
+ maximum=50,
169
+ step=1,
170
+ value=30,
171
+ )
172
+ lora_scale = gr.Slider(
173
+ label="LoRA scale",
174
+ minimum=0.0,
175
+ maximum=1.0,
176
+ step=0.1,
177
+ value=1.0,
178
+ )
179
+
180
+ gr.Examples(
181
+ examples=examples,
182
+ inputs=[prompt],
183
+ outputs=[result, seed],
184
  )
 
 
 
 
 
 
 
185
 
186
+ with gr.Tab("Gallery"):
187
+ gallery_header = gr.Markdown("### Generated Images Gallery")
188
+ generated_gallery = gr.Gallery(
189
+ label="Generated Images",
190
+ columns=6,
191
+ show_label=False,
192
+ value=load_generated_images(),
193
+ elem_id="generated_gallery",
194
+ height="auto"
195
  )
196
+ refresh_btn = gr.Button("🔄 Refresh Gallery")
197
 
 
 
 
 
 
 
 
 
 
198
 
199
+ # Event handlers
200
+ def refresh_gallery():
201
+ return load_generated_images()
202
 
203
+ refresh_btn.click(
204
+ fn=refresh_gallery,
205
+ inputs=None,
206
+ outputs=generated_gallery,
 
 
 
 
207
  )
208
 
209
+ gr.on(
210
+ triggers=[run_button.click, prompt.submit],
211
+ fn=inference,
212
+ inputs=[
213
+ prompt,
214
+ seed,
215
+ randomize_seed,
216
+ width,
217
+ height,
218
+ guidance_scale,
219
+ num_inference_steps,
220
+ lora_scale,
221
+ ],
222
+ outputs=[result, seed, generated_gallery],
223
  )
224
 
225
+ demo.queue()
226
+ demo.launch()