from .diffusion_utils import build_pipeline
|
|
|
# Registry of supported model names. Each entry holds the four identifiers
# that get_model() passes on to build_pipeline(): "model", "unet",
# "tokenizer" and "text_encoder".
#
# NOTE: the two keys use different separators (hyphens vs. underscores).
# They are kept exactly as written because get_model() looks names up by
# exact string match; renaming either key would break existing callers.
NAME_TO_MODEL = {
    "stable-diffusion-v1-4": {
        "model": "CompVis/stable-diffusion-v1-4",
        "unet": "CompVis/stable-diffusion-v1-4",
        "tokenizer": "openai/clip-vit-large-patch14",
        "text_encoder": "openai/clip-vit-large-patch14",
    },
    "stable_diffusion_v2_1": {
        "model": "stabilityai/stable-diffusion-2-1",
        "unet": "stabilityai/stable-diffusion-2-1",
        "tokenizer": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
        "text_encoder": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    },
}
|
|
|
def get_model(model_name):
    """Build the pipeline components registered under ``model_name``.

    Args:
        model_name: A key of ``NAME_TO_MODEL``; matched by exact string.

    Returns:
        The ``(vae, tokenizer, text_encoder, unet)`` tuple unpacked from the
        result of ``build_pipeline``.

    Raises:
        ValueError: If ``model_name`` is not a key of ``NAME_TO_MODEL``. The
            message lists the available names.
    """
    model = NAME_TO_MODEL.get(model_name)
    if model is None:
        raise ValueError(
            f"Model name {model_name} not found. "
            f"Available models: {list(NAME_TO_MODEL)}"
        )

    # Argument order follows the original call: model, tokenizer,
    # text_encoder, unet. Unpacking into four names also enforces that
    # build_pipeline returns exactly four components.
    vae, tokenizer, text_encoder, unet = build_pipeline(
        model["model"],
        model["tokenizer"],
        model["text_encoder"],
        model["unet"],
    )
    return vae, tokenizer, text_encoder, unet