DamarJati commited on
Commit
c8b0cc9
·
verified ·
1 Parent(s): ac8cc1a

Update modules/model.py

Browse files
Files changed (1) hide show
  1. modules/model.py +0 -78
modules/model.py CHANGED
@@ -1,78 +0,0 @@
1
- # modules/model.py
2
- import os
3
- import torch
4
- from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
5
- from transformers import AutoencoderKL
6
- import spaces
7
-
8
- def get_checkpoints(folder):
9
- checkpoints = []
10
- for file in os.listdir(folder):
11
- if file.endswith(('.safetensors', '.ckpt', '.pt', '.pth')):
12
- checkpoints.append(file)
13
- return checkpoints
14
-
15
- def load_model(checkpoint, vae, checkpoint_folder, vae_folder):
16
- # Memilih pipeline yang sesuai
17
- if "sdxl" in checkpoint.lower():
18
- pipeline_class = StableDiffusionXLPipeline
19
- else:
20
- pipeline_class = StableDiffusionPipeline
21
-
22
- # Load checkpoint
23
- if checkpoint in get_checkpoints(checkpoint_folder):
24
- checkpoint_path = os.path.join(checkpoint_folder, checkpoint)
25
- try:
26
- model = pipeline_class.from_single_file(checkpoint_path, torch_dtype=torch.float16)
27
- except Exception as e:
28
- model = pipeline_class.from_pretrained(checkpoint_path, torch_dtype=torch.float16)
29
- else:
30
- if checkpoint.startswith("http"):
31
- try:
32
- model = pipeline_class.from_single_file(checkpoint, torch_dtype=torch.float16)
33
- except Exception as e:
34
- model = pipeline_class.from_pretrained(checkpoint, torch_dtype=torch.float16)
35
- else:
36
- model = pipeline_class.from_pretrained(checkpoint, torch_dtype=torch.float16)
37
-
38
- # Load VAE
39
- if vae != "none":
40
- if vae in get_checkpoints(vae_folder):
41
- vae_path = os.path.join(vae_folder, vae)
42
- vae_model = AutoencoderKL.from_pretrained(vae_path, torch_dtype=torch.float16)
43
- else:
44
- vae_model = AutoencoderKL.from_pretrained(vae, torch_dtype=torch.float16)
45
- model.vae = vae_model
46
-
47
- return model
48
-
49
- def get_model_and_vae_options():
50
- checkpoint_folder = "../models/checkpoint/"
51
- vae_folder = "../models/vae/"
52
- model_file = "../models/model.py"
53
-
54
- # Membaca model dan VAE dari models/model.py
55
- exec(open(model_file).read())
56
-
57
- # Mendapatkan daftar checkpoint dan VAE dari folder
58
- checkpoints = get_checkpoints(checkpoint_folder)
59
- vae_files = get_checkpoints(vae_folder)
60
-
61
- # Menggabungkan daftar checkpoint, model Diffusers, dan VAE
62
- all_models = checkpoints + diffusers
63
- all_vaes = ["none"] + vae_files + vae
64
-
65
- # Mengubah format dropdown
66
- formatted_models = [os.path.basename(model) if not model.startswith("http") else model for model in all_models]
67
- formatted_vaes = [os.path.basename(vae) if not vae.startswith("http") else vae for vae in all_vaes]
68
-
69
- return formatted_models, formatted_vaes
70
-
71
- @spaces.GPU()
72
- def generate_image(text, neg_prompt, width, height, scheduler, num_steps, num_images, cfg_scale, seed, model, vae):
73
- checkpoint_folder = "../models/checkpoint/"
74
- vae_folder = "../models/vae/"
75
- model = load_model(model, vae, checkpoint_folder, vae_folder)
76
- images = model([text], height=height, width=width, num_inference_steps=num_steps, guidance_scale=cfg_scale, negative_prompt=neg_prompt)
77
- return images.images
78
-