Spaces:
Runtime error
Runtime error
# modules/model.py | |
import os
import runpy

import spaces
import torch
# AutoencoderKL is a diffusers model class; it does not exist in transformers.
from diffusers import AutoencoderKL, StableDiffusionPipeline, StableDiffusionXLPipeline
def get_checkpoints(folder):
    """Return the sorted names of checkpoint-like files directly inside ``folder``.

    A name qualifies when it ends in ``.safetensors``, ``.ckpt``, ``.pt`` or
    ``.pth`` (compared case-insensitively). Only bare names are returned, not
    full paths.

    A missing ``folder`` yields an empty list instead of raising. Callers that
    only use the result as a membership test keep working when no local models
    directory exists, e.g. when loading a model by Hub repo id.
    """
    try:
        names = os.listdir(folder)
    except FileNotFoundError:
        return []
    extensions = (".safetensors", ".ckpt", ".pt", ".pth")
    # os.listdir order is arbitrary; sort so the dropdown order is stable.
    return sorted(name for name in names if name.lower().endswith(extensions))
def _load_component(loader_cls, name, folder, dtype):
    """Load ``name`` with ``loader_cls``.

    Resolution order:
    1. A file name present in ``folder`` is loaded with ``from_single_file``.
    2. A name starting with ``"http"`` is treated as a URL and loaded with
       ``from_single_file``.
    3. Anything else goes to ``from_pretrained`` (e.g. a Hub repo id or a
       diffusers-format directory).
    """
    if name in get_checkpoints(folder):
        # Bare weight files cannot be read by from_pretrained, so there is no
        # fallback: let the from_single_file error surface as-is.
        return loader_cls.from_single_file(os.path.join(folder, name), torch_dtype=dtype)
    if name.startswith("http"):
        return loader_cls.from_single_file(name, torch_dtype=dtype)
    return loader_cls.from_pretrained(name, torch_dtype=dtype)


def load_model(checkpoint, vae, checkpoint_folder, vae_folder):
    """Build a Stable Diffusion pipeline in float16, optionally swapping its VAE.

    Args:
        checkpoint: A file name found in ``checkpoint_folder``, an ``http...``
            URL to a single-file checkpoint, or anything ``from_pretrained``
            accepts (Hub repo id, diffusers directory).
        vae: ``"none"`` (or ``None``/empty) to keep the pipeline's own VAE;
            otherwise a file name in ``vae_folder``, an ``http...`` URL, or
            anything ``AutoencoderKL.from_pretrained`` accepts.
        checkpoint_folder: Directory searched for local checkpoint files.
        vae_folder: Directory searched for local VAE files.

    Returns:
        The loaded pipeline. It is not moved to any device here.
    """
    dtype = torch.float16
    # Pipeline choice is a name heuristic: "sdxl" anywhere in the name selects
    # the SDXL pipeline.
    # NOTE(review): SDXL checkpoints whose names lack "sdxl" get the non-XL
    # pipeline.
    if "sdxl" in checkpoint.lower():
        pipeline_class = StableDiffusionXLPipeline
    else:
        pipeline_class = StableDiffusionPipeline

    model = _load_component(pipeline_class, checkpoint, checkpoint_folder, dtype)

    if vae and vae != "none":
        model.vae = _load_component(AutoencoderKL, vae, vae_folder, dtype)
    return model
def get_model_and_vae_options(
    checkpoint_folder="../models/checkpoint/",
    vae_folder="../models/vae/",
    model_file="../models/model.py",
):
    """Return ``(models, vaes)``, the dropdown choices for checkpoints and VAEs.

    ``model_file`` is executed as a Python script. It is presumably expected to
    define two sequences of strings (TODO confirm against models/model.py):
    ``diffusers`` holds extra model ids or URLs, and ``vae`` holds extra VAE ids
    or URLs. A missing name is treated as an empty list.

    Returns:
        models: Local checkpoint file names followed by the ``diffusers``
            entries.
        vaes: ``"none"``, then local VAE file names, then the ``vae`` entries.

    The default paths are relative to the current working directory.
    """
    # Run the config in its own namespace and read the results from the
    # returned dict. A bare exec() inside a function cannot create local names,
    # so reading `diffusers`/`vae` afterwards raised NameError.
    # NOTE(review): this still executes arbitrary code; only point it at
    # trusted files.
    config = runpy.run_path(model_file)
    diffusers_models = list(config.get("diffusers", []))
    extra_vaes = list(config.get("vae", []))

    checkpoints = get_checkpoints(checkpoint_folder)
    vae_files = get_checkpoints(vae_folder)

    # Entries are returned verbatim because the selected value is passed back
    # to load_model. Stripping the "org/" prefix of a Hub repo id with
    # os.path.basename made it unloadable. Local entries are already bare file
    # names.
    models = checkpoints + diffusers_models
    vaes = ["none"] + vae_files + extra_vaes
    return models, vaes
def generate_image(text, neg_prompt, width, height, scheduler, num_steps, num_images, cfg_scale, seed, model, vae):
    """Generate images for ``text`` with the selected checkpoint and VAE.

    Args:
        text: Prompt.
        neg_prompt: Negative prompt.
        width: Output width in pixels.
        height: Output height in pixels.
        scheduler: Accepted for interface compatibility but not applied; the
            pipeline's default scheduler is used.
        num_steps: Number of denoising steps.
        num_images: Number of images to generate for the prompt (at least 1).
        cfg_scale: Classifier-free guidance scale.
        seed: RNG seed. ``None`` or a negative value means unseeded (random).
        model: Checkpoint name, URL or repo id passed to ``load_model``.
        vae: VAE choice passed to ``load_model``.

    Returns:
        The ``images`` list produced by the pipeline call.
    """
    checkpoint_folder = "../models/checkpoint/"
    vae_folder = "../models/vae/"
    pipeline = load_model(model, vae, checkpoint_folder, vae_folder)

    # load_model returns float16 weights, which most ops only support on GPU.
    # NOTE(review): on a CPU-only host, float16 inference may still fail.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline = pipeline.to(device)

    generator = None
    if seed is not None and int(seed) >= 0:
        generator = torch.Generator(device=device).manual_seed(int(seed))

    # TODO(review): map `scheduler` to a diffusers scheduler class once the
    # set of names the UI sends is known.
    # int() casts guard against UI widgets delivering numbers as floats.
    result = pipeline(
        [text],
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_steps),
        guidance_scale=cfg_scale,
        negative_prompt=neg_prompt,
        num_images_per_prompt=max(1, int(num_images or 1)),
        generator=generator,
    )
    return result.images