Spaces:
Runtime error
Runtime error
Update modules/model.py
Browse files- modules/model.py +0 -78
modules/model.py
CHANGED
@@ -1,78 +0,0 @@
|
|
1 |
-
# modules/model.py
|
2 |
-
import os
|
3 |
-
import torch
|
4 |
-
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
5 |
-
from transformers import AutoencoderKL
|
6 |
-
import spaces
|
7 |
-
|
8 |
-
def get_checkpoints(folder):
|
9 |
-
checkpoints = []
|
10 |
-
for file in os.listdir(folder):
|
11 |
-
if file.endswith(('.safetensors', '.ckpt', '.pt', '.pth')):
|
12 |
-
checkpoints.append(file)
|
13 |
-
return checkpoints
|
14 |
-
|
15 |
-
def load_model(checkpoint, vae, checkpoint_folder, vae_folder):
|
16 |
-
# Memilih pipeline yang sesuai
|
17 |
-
if "sdxl" in checkpoint.lower():
|
18 |
-
pipeline_class = StableDiffusionXLPipeline
|
19 |
-
else:
|
20 |
-
pipeline_class = StableDiffusionPipeline
|
21 |
-
|
22 |
-
# Load checkpoint
|
23 |
-
if checkpoint in get_checkpoints(checkpoint_folder):
|
24 |
-
checkpoint_path = os.path.join(checkpoint_folder, checkpoint)
|
25 |
-
try:
|
26 |
-
model = pipeline_class.from_single_file(checkpoint_path, torch_dtype=torch.float16)
|
27 |
-
except Exception as e:
|
28 |
-
model = pipeline_class.from_pretrained(checkpoint_path, torch_dtype=torch.float16)
|
29 |
-
else:
|
30 |
-
if checkpoint.startswith("http"):
|
31 |
-
try:
|
32 |
-
model = pipeline_class.from_single_file(checkpoint, torch_dtype=torch.float16)
|
33 |
-
except Exception as e:
|
34 |
-
model = pipeline_class.from_pretrained(checkpoint, torch_dtype=torch.float16)
|
35 |
-
else:
|
36 |
-
model = pipeline_class.from_pretrained(checkpoint, torch_dtype=torch.float16)
|
37 |
-
|
38 |
-
# Load VAE
|
39 |
-
if vae != "none":
|
40 |
-
if vae in get_checkpoints(vae_folder):
|
41 |
-
vae_path = os.path.join(vae_folder, vae)
|
42 |
-
vae_model = AutoencoderKL.from_pretrained(vae_path, torch_dtype=torch.float16)
|
43 |
-
else:
|
44 |
-
vae_model = AutoencoderKL.from_pretrained(vae, torch_dtype=torch.float16)
|
45 |
-
model.vae = vae_model
|
46 |
-
|
47 |
-
return model
|
48 |
-
|
49 |
-
def get_model_and_vae_options():
|
50 |
-
checkpoint_folder = "../models/checkpoint/"
|
51 |
-
vae_folder = "../models/vae/"
|
52 |
-
model_file = "../models/model.py"
|
53 |
-
|
54 |
-
# Membaca model dan VAE dari models/model.py
|
55 |
-
exec(open(model_file).read())
|
56 |
-
|
57 |
-
# Mendapatkan daftar checkpoint dan VAE dari folder
|
58 |
-
checkpoints = get_checkpoints(checkpoint_folder)
|
59 |
-
vae_files = get_checkpoints(vae_folder)
|
60 |
-
|
61 |
-
# Menggabungkan daftar checkpoint, model Diffusers, dan VAE
|
62 |
-
all_models = checkpoints + diffusers
|
63 |
-
all_vaes = ["none"] + vae_files + vae
|
64 |
-
|
65 |
-
# Mengubah format dropdown
|
66 |
-
formatted_models = [os.path.basename(model) if not model.startswith("http") else model for model in all_models]
|
67 |
-
formatted_vaes = [os.path.basename(vae) if not vae.startswith("http") else vae for vae in all_vaes]
|
68 |
-
|
69 |
-
return formatted_models, formatted_vaes
|
70 |
-
|
71 |
-
@spaces.GPU()
|
72 |
-
def generate_image(text, neg_prompt, width, height, scheduler, num_steps, num_images, cfg_scale, seed, model, vae):
|
73 |
-
checkpoint_folder = "../models/checkpoint/"
|
74 |
-
vae_folder = "../models/vae/"
|
75 |
-
model = load_model(model, vae, checkpoint_folder, vae_folder)
|
76 |
-
images = model([text], height=height, width=width, num_inference_steps=num_steps, guidance_scale=cfg_scale, negative_prompt=neg_prompt)
|
77 |
-
return images.images
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|