DamarJati commited on
Commit
ed85b56
·
verified ·
1 Parent(s): 202b8a5

Update modules/model.py

Browse files
Files changed (1) hide show
  1. modules/model.py +77 -0
modules/model.py CHANGED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modules/model.py
2
+ import os
3
+ import torch
4
+ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
5
+ from transformers import AutoencoderKL
6
+
7
+ def get_checkpoints(folder):
8
+ checkpoints = []
9
+ for file in os.listdir(folder):
10
+ if file.endswith(('.safetensors', '.ckpt', '.pt', '.pth')):
11
+ checkpoints.append(file)
12
+ return checkpoints
13
+
14
+ def load_model(checkpoint, vae, checkpoint_folder, vae_folder):
15
+ # Memilih pipeline yang sesuai
16
+ if "sdxl" in checkpoint.lower():
17
+ pipeline_class = StableDiffusionXLPipeline
18
+ else:
19
+ pipeline_class = StableDiffusionPipeline
20
+
21
+ # Load checkpoint
22
+ if checkpoint in get_checkpoints(checkpoint_folder):
23
+ checkpoint_path = os.path.join(checkpoint_folder, checkpoint)
24
+ try:
25
+ model = pipeline_class.from_single_file(checkpoint_path, torch_dtype=torch.float16)
26
+ except Exception as e:
27
+ model = pipeline_class.from_pretrained(checkpoint_path, torch_dtype=torch.float16)
28
+ else:
29
+ if checkpoint.startswith("http"):
30
+ try:
31
+ model = pipeline_class.from_single_file(checkpoint, torch_dtype=torch.float16)
32
+ except Exception as e:
33
+ model = pipeline_class.from_pretrained(checkpoint, torch_dtype=torch.float16)
34
+ else:
35
+ model = pipeline_class.from_pretrained(checkpoint, torch_dtype=torch.float16)
36
+
37
+ # Load VAE
38
+ if vae != "none":
39
+ if vae in get_checkpoints(vae_folder):
40
+ vae_path = os.path.join(vae_folder, vae)
41
+ vae_model = AutoencoderKL.from_pretrained(vae_path, torch_dtype=torch.float16)
42
+ else:
43
+ vae_model = AutoencoderKL.from_pretrained(vae, torch_dtype=torch.float16)
44
+ model.vae = vae_model
45
+
46
+ return model
47
+
48
+ def get_model_and_vae_options():
49
+ checkpoint_folder = "../models/checkpoint/"
50
+ vae_folder = "../models/vae/"
51
+ model_file = "../models/models.py"
52
+
53
+ # Membaca model dan VAE dari models/model.py
54
+ exec(open(model_file).read())
55
+
56
+ # Mendapatkan daftar checkpoint dan VAE dari folder
57
+ checkpoints = get_checkpoints(checkpoint_folder)
58
+ vae_files = get_checkpoints(vae_folder)
59
+
60
+ # Menggabungkan daftar checkpoint, model Diffusers, dan VAE
61
+ all_models = checkpoints + diffusers
62
+ all_vaes = ["none"] + vae_files + vae
63
+
64
+ # Mengubah format dropdown
65
+ formatted_models = [os.path.basename(model) if not model.startswith("http") else model for model in all_models]
66
+ formatted_vaes = [os.path.basename(vae) if not vae.startswith("http") else vae for vae in all_vaes]
67
+
68
+ return formatted_models, formatted_vaes
69
+
70
+ # Wrapper untuk fungsi generate_image di text2img
71
+ def generate_image(text, checkpoint, vae):
72
+ checkpoint_folder = "../models/checkpoint/"
73
+ vae_folder = "../models/vae/"
74
+ model = load_model(checkpoint, vae, checkpoint_folder, vae_folder)
75
+ image = model([text])[0]
76
+ return image
77
+