kbora's picture
Upload 51 files
6af7294
raw
history blame
985 Bytes
from .diffusion_utils import build_pipeline
# Registry of the supported Stable Diffusion variants, keyed by the name that
# ``get_model`` accepts. Each entry lists the identifier to use for every
# pipeline component ("model", "unet", "tokenizer", "text_encoder"); these
# strings are handed positionally to ``build_pipeline`` by ``get_model``.
# NOTE(review): the two keys use different separators ("-" vs "_"); callers
# must spell them exactly as written here.
NAME_TO_MODEL = {
    "stable-diffusion-v1-4": dict(
        model="CompVis/stable-diffusion-v1-4",
        unet="CompVis/stable-diffusion-v1-4",
        tokenizer="openai/clip-vit-large-patch14",
        text_encoder="openai/clip-vit-large-patch14",
    ),
    "stable_diffusion_v2_1": dict(
        model="stabilityai/stable-diffusion-2-1",
        unet="stabilityai/stable-diffusion-2-1",
        tokenizer="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
        text_encoder="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    ),
}
def get_model(model_name):
    """Build the pipeline components registered under ``model_name``.

    Looks ``model_name`` up in ``NAME_TO_MODEL`` and passes the configured
    model, tokenizer, text-encoder and UNet identifiers to ``build_pipeline``.

    Args:
        model_name: One of the keys of ``NAME_TO_MODEL``.

    Returns:
        A 4-tuple ``(vae, tokenizer, text_encoder, unet)`` unpacked from the
        result of ``build_pipeline`` (presumably the loaded components --
        confirm against ``diffusion_utils``).

    Raises:
        ValueError: If ``model_name`` is not a registered model name.
    """
    # Guard clause: reject unknown names and list the valid choices.
    if model_name not in NAME_TO_MODEL:
        raise ValueError(f"Model name {model_name} not found. Available models: {list(NAME_TO_MODEL.keys())}")

    spec = NAME_TO_MODEL[model_name]
    # Positional order expected here: model, tokenizer, text_encoder, unet.
    vae, tokenizer, text_encoder, unet = build_pipeline(
        spec["model"],
        spec["tokenizer"],
        spec["text_encoder"],
        spec["unet"],
    )
    return vae, tokenizer, text_encoder, unet