PMRF / arch /hourglass /flags.py
ohayonguy
first commit fixed
b7f3942
raw
history blame
1.58 kB
"""k-diffusion transformer diffusion models, version 2.
Codes adopted from https://github.com/crowsonkb/k-diffusion
"""
from contextlib import contextmanager
from functools import update_wrapper
import os
import threading
import torch
def get_use_compile():
return os.environ.get("K_DIFFUSION_USE_COMPILE", "1") == "1"
def get_use_flash_attention_2():
return os.environ.get("K_DIFFUSION_USE_FLASH_2", "1") == "1"
state = threading.local()
state.checkpointing = False
@contextmanager
def checkpointing(enable=True):
try:
old_checkpointing, state.checkpointing = state.checkpointing, enable
yield
finally:
state.checkpointing = old_checkpointing
def get_checkpointing():
return getattr(state, "checkpointing", False)
class compile_wrap:
def __init__(self, function, *args, **kwargs):
self.function = function
self.args = args
self.kwargs = kwargs
self._compiled_function = None
update_wrapper(self, function)
@property
def compiled_function(self):
if self._compiled_function is not None:
return self._compiled_function
if get_use_compile():
try:
self._compiled_function = torch.compile(self.function, *self.args, **self.kwargs)
except RuntimeError:
self._compiled_function = self.function
else:
self._compiled_function = self.function
return self._compiled_function
def __call__(self, *args, **kwargs):
return self.compiled_function(*args, **kwargs)