Spaces:
Running
Running
import json | |
import traceback | |
import sys | |
import os | |
# Directory the server process was started from.
SD_DIR = os.getcwd()
print('started in ', SD_DIR)

# Location of the UI files, relative to the working directory.
SD_UI_DIR = './ui'

# Make the directory that contains the UI folder importable.
sys.path.append(os.path.dirname(SD_UI_DIR))

# Absolute locations of the config and model folders.
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'config'))
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, 'models'))

OUTPUT_DIRNAME = "Stable Diffusion UI"  # in the user's home folder
from fastapi import FastAPI, HTTPException | |
from fastapi.staticfiles import StaticFiles | |
from starlette.responses import FileResponse, StreamingResponse | |
from pydantic import BaseModel | |
import logging | |
from sd_internal import Request, Response | |
# The ASGI application object; static UI assets are mounted on it below.
app = FastAPI()

# Model-loading state, updated by the ping() handler.
model_loaded = model_is_loading = False
modifiers_cache = None

# Output folder inside the user's home directory.
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)

# Access log entries whose message contains one of these path fragments are
# not shown (see LogSuppressFilter).
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/modifier-thumbnails']

app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media")
# defaults from https://huggingface.co/blog/stable_diffusion
class ImageRequest(BaseModel):
    # Request body accepted by the image() handler. Every field has a default,
    # so a client may send only the fields it wants to override.

    # -- session and prompt --
    session_id: str = "session"
    prompt: str = ""
    negative_prompt: str = ""

    # -- optional input images --
    init_image: str = None  # base64
    mask: str = None  # base64

    # -- generation settings --
    num_outputs: int = 1
    num_inference_steps: int = 50
    guidance_scale: float = 7.5
    width: int = 512
    height: int = 512
    seed: int = 42
    prompt_strength: float = 0.8
    sampler: str = None  # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
    # allow_nsfw: bool = False
    save_to_disk_path: str = None

    # -- runtime options --
    turbo: bool = True
    use_cpu: bool = False
    use_full_precision: bool = False

    # -- optional post-processing and model selection --
    use_face_correction: str = None  # or "GFPGANv1.3"
    use_upscale: str = None  # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
    use_stable_diffusion_model: str = "sd-v1-4"
    show_only_filtered_image: bool = False

    # -- output and streaming --
    output_format: str = "jpeg"  # or "png"
    stream_progress_updates: bool = False
    stream_image_progress: bool = False
class SetAppConfigRequest(BaseModel):
    # Request body accepted by setAppConfig().
    update_branch: str = "main"
def read_root():
    """Serve the UI's index.html with headers that disable client caching."""
    index_path = os.path.join(SD_UI_DIR, 'index.html')
    no_cache_headers = {
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Pragma": "no-cache",
        "Expires": "0",
    }
    return FileResponse(index_path, headers=no_cache_headers)
async def ping():
    """Report whether the model is ready, loading it on the first call.

    Returns:
        ``{'OK'}`` once the model is loaded, or ``{'ERROR'}`` while another
        call is still loading it.

    Raises:
        HTTPException: status 500 if loading the model fails. The exception
            is raised (not returned) so the client receives a real error
            status instead of a 200 response describing the error.
    """
    global model_loaded, model_is_loading

    if model_loaded:
        return {'OK'}
    if model_is_loading:
        return {'ERROR'}

    model_is_loading = True
    try:
        from . import runtime

        runtime.load_model_ckpt(ckpt_to_use=get_initial_model_to_load())
        model_loaded = True
        return {'OK'}
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always clear the flag so a failed load can be retried.
        model_is_loading = False
# needs to support the legacy installations | |
def get_initial_model_to_load():
    """Pick the checkpoint path (without the '.ckpt' suffix) to load at startup.

    The default is 'custom-model' in SD_DIR when that checkpoint file exists,
    otherwise 'sd-v1-4' in SD_DIR. A model named in the saved config takes
    precedence, but only if its checkpoint file is actually present.
    """
    legacy_custom_ckpt = os.path.join(SD_DIR, 'custom-model.ckpt')
    default_name = "custom-model" if os.path.exists(legacy_custom_ckpt) else "sd-v1-4"
    ckpt_to_use = os.path.join(SD_DIR, default_name)

    config = getConfig()
    if 'model' not in config or 'stable-diffusion' not in config['model']:
        return ckpt_to_use

    configured_path = resolve_model_to_use(config['model']['stable-diffusion'])
    configured_ckpt = configured_path + '.ckpt'
    if os.path.exists(configured_ckpt):
        return configured_path

    # The configured model is missing on disk: report it and keep the default.
    print('Could not find the configured custom model at:', configured_ckpt, '. Using the default one:', ckpt_to_use + '.ckpt')
    return ckpt_to_use
def resolve_model_to_use(model_name):
    """Return the path (without the '.ckpt' suffix) for the named model.

    Models are looked up under MODELS_DIR/stable-diffusion. For the two
    built-in names ('sd-v1-4' and 'custom-model') the legacy location directly
    under SD_DIR is used instead when the checkpoint exists only there.
    """
    resolved = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
    if model_name not in ('sd-v1-4', 'custom-model'):
        return resolved

    legacy = os.path.join(SD_DIR, model_name)
    if not os.path.exists(resolved + '.ckpt') and os.path.exists(legacy + '.ckpt'):
        return legacy
    return resolved
def save_model_to_config(model_name):
    """Record model_name as the 'stable-diffusion' model in the saved config."""
    config = getConfig()

    # Reuse the existing 'model' section, or create an empty one.
    if 'model' in config:
        model_section = config['model']
    else:
        model_section = config['model'] = {}

    model_section['stable-diffusion'] = model_name
    setConfig(config)
def image(req : ImageRequest):
    """Render images for the given request.

    Copies the request fields onto a runtime ``Request``, resolves the model
    name to a path, remembers the chosen model in the saved config, and hands
    the request to ``runtime.mk_img``.

    Returns:
        A ``StreamingResponse`` over the runtime's results when
        ``req.stream_progress_updates`` is set; otherwise the last result,
        parsed from JSON.

    Raises:
        HTTPException: status 500 if rendering fails. The exception is raised
            (not returned) so the client receives a real error status.
    """
    from . import runtime

    # Fields copied unchanged from the HTTP request to the runtime request.
    # allow_nsfw is not forwarded (it is disabled in ImageRequest as well).
    passthrough_fields = (
        'session_id', 'prompt', 'negative_prompt', 'init_image', 'mask',
        'num_outputs', 'num_inference_steps', 'guidance_scale',
        'width', 'height', 'seed', 'prompt_strength', 'sampler',
        'turbo', 'use_cpu', 'use_full_precision', 'save_to_disk_path',
        'use_upscale', 'use_face_correction', 'show_only_filtered_image',
        'output_format', 'stream_image_progress',
    )

    r = Request()
    for field_name in passthrough_fields:
        setattr(r, field_name, getattr(req, field_name))

    r.stream_progress_updates = True  # the underlying implementation only supports streaming
    r.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)

    save_model_to_config(req.use_stable_diffusion_model)

    try:
        if not req.stream_progress_updates:
            r.stream_image_progress = False

        res = runtime.mk_img(r)

        if req.stream_progress_updates:
            return StreamingResponse(res, media_type='application/json')

        # compatibility mode: buffer the streaming responses, and return the last one
        last_result = None
        for result in res:
            last_result = result

        return json.loads(last_result)
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e
def stop():
    """Ask the runtime to stop the current processing.

    Returns:
        ``{'ERROR'}`` while the model is still loading, otherwise ``{'OK'}``
        after setting ``runtime.stop_processing``.

    Raises:
        HTTPException: status 500 if the flag could not be set. The exception
            is raised (not returned) so the client receives a real error status.
    """
    if model_is_loading:
        return {'ERROR'}

    try:
        from . import runtime

        runtime.stop_processing = True
        return {'OK'}
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e
def get_image(session_id, img_id):
    """Stream a temporary image held by the runtime.

    The image is looked up in ``runtime.temp_images`` under the key
    ``"<session_id>/<img_id>"``.

    Raises:
        HTTPException: status 404 when no image is stored under that key
            (previously an unhandled KeyError surfaced as a server error).
    """
    from . import runtime

    image_key = session_id + '/' + img_id
    try:
        buf = runtime.temp_images[image_key]
    except KeyError:
        raise HTTPException(status_code=404, detail='Image not found') from None

    # Rewind so the whole buffer is sent, even if it was read before.
    buf.seek(0)
    return StreamingResponse(buf, media_type='image/jpeg')
async def setAppConfig(req : SetAppConfigRequest):
    """Persist the requested update branch to config.json, config.bat and config.sh.

    The existing config.json content is preserved: only its 'update_branch'
    key is replaced, so other saved settings (such as the 'model' section
    written by save_model_to_config) are not lost.

    Returns:
        ``{'OK'}`` on success.

    Raises:
        HTTPException: status 400 if ``update_branch`` contains characters
            outside a conservative branch-name alphabet; status 500 if the
            files could not be written.
    """
    import re  # local import: only this handler needs it

    update_branch = req.update_branch

    # The value is written verbatim into a batch file and a shell script, so
    # reject anything that is not a plain branch name (no spaces, quotes,
    # separators or newlines).
    if not isinstance(update_branch, str) or not re.fullmatch(r'[A-Za-z0-9][A-Za-z0-9._/-]*', update_branch):
        raise HTTPException(status_code=400, detail='Invalid update_branch')

    try:
        config = getConfig()
        if not isinstance(config, dict):
            config = {}
        config['update_branch'] = update_branch

        files_to_write = (
            (os.path.join(CONFIG_DIR, 'config.json'), json.dumps(config)),
            (os.path.join(CONFIG_DIR, 'config.bat'), f'@set update_branch={update_branch}'),
            (os.path.join(CONFIG_DIR, 'config.sh'), f'export update_branch={update_branch}'),
        )
        for file_path, content in files_to_write:
            with open(file_path, 'w') as f:
                f.write(content)

        return {'OK'}
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e
def getAppConfig():
    """Return the parsed contents of config.json.

    Raises:
        HTTPException: status 500 with detail "No config file" when the file
            does not exist, or with the error text when it cannot be read or
            parsed. The exception is raised (not returned) so the client
            receives a real error status.
    """
    config_json_path = os.path.join(CONFIG_DIR, 'config.json')

    # Checked outside the try block so this error is not re-wrapped below.
    if not os.path.exists(config_json_path):
        raise HTTPException(status_code=500, detail="No config file")

    try:
        with open(config_json_path, 'r') as f:
            return json.load(f)
    except Exception as e:
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=str(e)) from e
def getConfig():
    """Load config.json as a dict, falling back to an empty dict.

    This is a best-effort read: a missing file quietly yields ``{}``; an
    unreadable or malformed file is reported on stdout and also yields ``{}``.
    JSON content that is not an object is treated as empty, because callers
    index the result as a dict.
    """
    config_json_path = os.path.join(CONFIG_DIR, 'config.json')
    try:
        with open(config_json_path, 'r') as f:
            config = json.load(f)
    except FileNotFoundError:
        # No config has been saved yet.
        return {}
    except (OSError, ValueError):
        # ValueError covers JSON and text-decoding errors.
        print(traceback.format_exc())
        return {}

    return config if isinstance(config, dict) else {}
def setConfig(config):
    """Write config to config.json; failures are printed, never raised.

    The config is serialised before the file is opened, so a value that
    cannot be encoded as JSON does not leave a truncated config.json behind.
    Always returns None.
    """
    try:
        config_json_str = json.dumps(config)
        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
        with open(config_json_path, 'w') as f:
            f.write(config_json_str)
    except Exception:
        # Best-effort save: report the problem but let the caller carry on.
        # (Exception, not a bare except, so Ctrl-C / SystemExit still propagate.)
        print(traceback.format_exc())
def getModels():
    """List the selectable Stable Diffusion models and the active one.

    Returns:
        A dict of the form ``{'active': {'stable-diffusion': name},
        'options': {'stable-diffusion': [names...]}}``. Options start with
        'sd-v1-4', followed by every '*.ckpt' file found in
        MODELS_DIR/stable-diffusion (sorted by file name) and the legacy
        'custom-model' when present in SD_DIR. Each name appears once. The
        active model is the one saved in the config, else 'custom-model' when
        the legacy checkpoint exists, else 'sd-v1-4'.
    """
    active_model = 'sd-v1-4'
    options = ['sd-v1-4']

    def add_option(model_name):
        # Keep insertion order but avoid listing the same model twice.
        if model_name not in options:
            options.append(model_name)

    # custom models
    sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
    if os.path.isdir(sd_models_dir):  # the folder may be absent
        for file_name in sorted(os.listdir(sd_models_dir)):
            if file_name.endswith('.ckpt'):
                add_option(os.path.splitext(file_name)[0])

    # legacy
    custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
    if os.path.exists(custom_weight_path):
        active_model = 'custom-model'
        add_option('custom-model')

    config = getConfig()
    if 'model' in config and 'stable-diffusion' in config['model']:
        active_model = config['model']['stable-diffusion']

    return {
        'active': {
            'stable-diffusion': active_model,
        },
        'options': {
            'stable-diffusion': options,
        },
    }
def read_modifiers():
    """Serve the UI's modifiers.json with headers that disable client caching."""
    modifiers_path = os.path.join(SD_UI_DIR, 'modifiers.json')
    no_cache_headers = {
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Pragma": "no-cache",
        "Expires": "0",
    }
    return FileResponse(modifiers_path, headers=no_cache_headers)
def read_home_dir():
    """Return the output folder path, wrapped in a one-element set."""
    output_dirs = {outpath}
    return output_dirs
# don't log certain requests | |
class LogSuppressFilter(logging.Filter):
    """Drop log records whose message mentions a suppressed path.

    A record is rejected when its formatted message contains any entry of
    ACCESS_LOG_SUPPRESS_PATH_PREFIXES anywhere in the text.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        message = record.getMessage()
        suppressed = any(prefix in message for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES)
        return not suppressed


# Apply the filter to the 'uvicorn.access' logger.
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())