|
#!/usr/bin/env python |
|
|
|
import os |
|
import shutil |
|
import sys |
|
|
|
from diffusers import StableDiffusionPipeline |
|
from diffusers.pipelines.stable_diffusion.safety_checker import \ |
|
StableDiffusionSafetyChecker |
|
|
|
# Append the current working directory ('.') to sys.path so predict.py can be
# imported; this assumes the script is run from the directory containing predict.py.
|
sys.path.append('.') |
|
|
|
from predict import MODEL_CACHE, MODEL_ID, SAFETY_MODEL_ID |
|
|
|
# Start from a clean cache: delete any existing MODEL_CACHE tree, then
# create the directory again so the downloads below land in an empty folder.
if os.path.exists(MODEL_CACHE):
    shutil.rmtree(MODEL_CACHE)

# exist_ok=True keeps this call from raising if the directory is already there.
os.makedirs(MODEL_CACHE, exist_ok=True)
|
|
|
# Load the safety checker identified by SAFETY_MODEL_ID, passing MODEL_CACHE
# as cache_dir so the files it fetches are stored under that directory.
# The returned object is not used again in the visible part of this script.
saftey_checker = StableDiffusionSafetyChecker.from_pretrained(
    SAFETY_MODEL_ID, cache_dir=MODEL_CACHE
)
|
|
|
# Load the Stable Diffusion pipeline identified by MODEL_ID, passing
# MODEL_CACHE as cache_dir so the files it fetches are stored under that
# directory. The returned object is not used again in the visible part of
# this script.
pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, cache_dir=MODEL_CACHE)
|
|