"""k-diffusion transformer diffusion models, version 2.

Code adapted from https://github.com/crowsonkb/k-diffusion

This module holds small runtime switches: environment-variable feature flags,
a per-thread ``checkpointing`` flag, and ``compile_wrap``, which applies
``torch.compile`` lazily on first call.
"""

from contextlib import contextmanager
from functools import update_wrapper
import os
import threading

import torch


def get_use_compile():
    """Return True unless ``K_DIFFUSION_USE_COMPILE`` is set to a value other than ``"1"``."""
    return os.environ.get("K_DIFFUSION_USE_COMPILE", "1") == "1"


def get_use_flash_attention_2():
    """Return True unless ``K_DIFFUSION_USE_FLASH_2`` is set to a value other than ``"1"``."""
    return os.environ.get("K_DIFFUSION_USE_FLASH_2", "1") == "1"


# Thread-local storage: every thread sees its own ``checkpointing`` attribute.
# The assignment below initialises it only for the thread that imports this
# module, so every reader must fall back to ``False`` when it is absent.
state = threading.local()
state.checkpointing = False


@contextmanager
def checkpointing(enable=True):
    """Set the current thread's checkpointing flag for the duration of a ``with`` block.

    Args:
        enable: Value the flag takes inside the block (default ``True``).

    The previous value is restored on exit, even if the block raises, so
    nested uses compose correctly.
    """
    # Read the old value *before* entering ``try``. In threads other than the
    # importing one the attribute does not exist yet; reading it inside the
    # ``try`` would raise and leave ``old_checkpointing`` unbound in ``finally``.
    old_checkpointing = getattr(state, "checkpointing", False)
    state.checkpointing = enable
    try:
        yield
    finally:
        state.checkpointing = old_checkpointing


def get_checkpointing():
    """Return the current thread's checkpointing flag (``False`` if never set in this thread)."""
    return getattr(state, "checkpointing", False)


class compile_wrap:
    """Callable wrapper that ``torch.compile``-s ``function`` on first use.

    Compilation is deferred until the first call (or first access to
    ``compiled_function``), and the result is cached. If compilation is
    disabled via ``K_DIFFUSION_USE_COMPILE``, or ``torch.compile`` raises
    ``RuntimeError``, the original uncompiled function is used instead.

    Because the outcome is cached, changing ``K_DIFFUSION_USE_COMPILE`` after
    the first call has no effect on an existing wrapper.

    Args:
        function: The callable to wrap.
        *args, **kwargs: Forwarded unchanged to ``torch.compile``.
    """

    def __init__(self, function, *args, **kwargs):
        self.function = function
        self.args = args
        self.kwargs = kwargs
        self._compiled_function = None
        # Copy ``__name__``, ``__doc__`` etc. from ``function`` onto this
        # instance and set ``__wrapped__``, so it looks like the wrapped callable.
        update_wrapper(self, function)

    @property
    def compiled_function(self):
        """The compiled callable, or ``self.function`` if compilation is off or failed."""
        if self._compiled_function is not None:
            return self._compiled_function
        if get_use_compile():
            try:
                self._compiled_function = torch.compile(self.function, *self.args, **self.kwargs)
            except RuntimeError:
                # Deliberate best-effort fallback: run uncompiled rather than fail.
                self._compiled_function = self.function
        else:
            self._compiled_function = self.function
        return self._compiled_function

    def __call__(self, *args, **kwargs):
        return self.compiled_function(*args, **kwargs)