# modules/model.py
import os
import runpy

import torch
from diffusers import (
    AutoencoderKL,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
)

# NOTE(review): these paths are relative to the process's working directory,
# not to this file — confirm the app is always launched from the same place.
CHECKPOINT_FOLDER = "../models/checkpoint/"
VAE_FOLDER = "../models/vae/"
MODEL_LIST_FILE = "../models/models.py"

# File extensions treated as single-file weights (checkpoints or VAEs).
CHECKPOINT_EXTENSIONS = (".safetensors", ".ckpt", ".pt", ".pth")


def get_checkpoints(folder):
    """Return the names of weight files directly inside *folder*.

    A file qualifies when its extension, compared case-insensitively, is one
    of CHECKPOINT_EXTENSIONS. Subdirectories are skipped. Names are returned
    without the folder prefix and sorted, so dropdown order is stable.

    Raises:
        FileNotFoundError: if *folder* does not exist.
    """
    return sorted(
        name
        for name in os.listdir(folder)
        if name.lower().endswith(CHECKPOINT_EXTENSIONS)
        and os.path.isfile(os.path.join(folder, name))
    )


def _resolve_source(name, folder):
    """Return (source, is_single_file) for a checkpoint/VAE dropdown value.

    Names of files inside *folder* and http(s) URLs are single-file weights.
    Any other value is passed through as a Diffusers model id or directory.
    """
    if name in get_checkpoints(folder):
        return os.path.join(folder, name), True
    if name.startswith("http"):
        return name, True
    return name, False


def _load_weights(loader_class, name, folder):
    """Load *loader_class* (a pipeline class or AutoencoderKL) in float16."""
    source, is_single_file = _resolve_source(name, folder)
    # from_pretrained only understands Diffusers-layout repos/directories.
    # Falling back to it for a file or URL would just replace the real
    # loading error with a misleading one, so there is no fallback.
    if is_single_file:
        return loader_class.from_single_file(source, torch_dtype=torch.float16)
    return loader_class.from_pretrained(source, torch_dtype=torch.float16)


def load_model(checkpoint, vae, checkpoint_folder, vae_folder):
    """Build a Stable Diffusion pipeline in float16, optionally with a custom VAE.

    Args:
        checkpoint: One of:
            - a file name inside *checkpoint_folder*;
            - an http(s) URL to a single-file checkpoint;
            - a Diffusers model id or directory.
            Names containing "sdxl" (case-insensitive) are loaded with the
            SDXL pipeline.
        vae: "none" to keep the checkpoint's own VAE. Otherwise one of:
            - a file name inside *vae_folder*;
            - an http(s) URL;
            - a Diffusers model id or directory.
        checkpoint_folder: Folder scanned for local checkpoint files.
        vae_folder: Folder scanned for local VAE files.

    Returns:
        The loaded pipeline. It is not moved to any device here.
        NOTE(review): float16 weights generally need a GPU for inference —
        confirm where the caller places the pipeline.
    """
    # Choose the pipeline class by name. An SDXL checkpoint whose name does
    # not contain "sdxl" will be loaded with the non-XL pipeline class.
    if "sdxl" in checkpoint.lower():
        pipeline_class = StableDiffusionXLPipeline
    else:
        pipeline_class = StableDiffusionPipeline

    model = _load_weights(pipeline_class, checkpoint, checkpoint_folder)

    if vae != "none":
        model.vae = _load_weights(AutoencoderKL, vae, vae_folder)
    return model


def get_model_and_vae_options():
    """Collect the dropdown choices for checkpoints and VAEs.

    Returns:
        A ``(models, vaes)`` tuple:
        - models: the local checkpoint files, followed by the ``diffusers``
          list defined in MODEL_LIST_FILE;
        - vaes: "none", then the local VAE files, then the ``vae`` list
          defined in that file.
        Values are returned unchanged so they can be passed straight back
        to load_model / generate_image.
    """
    # Execute the model list file in its own namespace and read its names
    # from the returned dict. A bare exec() inside a function cannot create
    # variables that the function body can then reference.
    # SECURITY: this runs arbitrary code, so MODEL_LIST_FILE must be trusted.
    namespace = runpy.run_path(MODEL_LIST_FILE)
    # Presumably lists of model ids / URLs — verify against models.py.
    diffusers_models = list(namespace["diffusers"])
    extra_vaes = list(namespace["vae"])

    checkpoints = get_checkpoints(CHECKPOINT_FOLDER)
    vae_files = get_checkpoints(VAE_FOLDER)

    # No basename() here: local file names are already bare, and stripping
    # the "org/" part of a Hub id would make the entry unloadable.
    all_models = checkpoints + diffusers_models
    all_vaes = ["none"] + vae_files + extra_vaes
    return all_models, all_vaes


# Wrapper for the generate_image function in text2img.
def generate_image(text, checkpoint, vae):
    """Generate one image for the prompt *text* with the chosen checkpoint and VAE.

    Returns:
        The first image of the pipeline output.
    """
    # NOTE(review): the pipeline is reloaded from disk on every call; consider
    # caching it if generation latency matters.
    model = load_model(checkpoint, vae, CHECKPOINT_FOLDER, VAE_FOLDER)
    # Indexing the pipeline output with [0] yields the whole ``images`` list,
    # so select the first image explicitly.
    return model(prompt=text).images[0]