Spaces:
No application file
No application file
import time | |
import torch | |
import contextlib | |
from ldm_patched.modules import model_management | |
from ldm_patched.modules.ops import use_patched_ops | |
@contextlib.contextmanager
def automatic_memory_management():
    """Record torch modules touched inside a ``with`` body and offload them to CPU afterwards.

    On entry this calls ``model_management.free_memory`` with
    ``memory_required=3 * 1024 * 1024 * 1024`` (3 GiB, assuming the unit is
    bytes) for ``model_management.get_torch_device()``.  It then temporarily
    replaces ``torch.nn.Module.__init__`` and ``torch.nn.Module.to`` so that
    every module initialized, or moved with ``.to(...)``, while the context
    is active is recorded.

    The original methods are always restored, even if the body raises.  Only
    when the body exits normally is every recorded module moved with
    ``.cpu()``.  ``model_management.soft_empty_cache()`` is then called and a
    one-line timing summary is printed.

    Usage::

        with automatic_memory_management():
            ...  # construct / .to() torch modules here

    NOTE: the patches are installed on the ``torch.nn.Module`` class itself,
    so they affect every module created anywhere in the process while the
    context is active.
    """
    model_management.free_memory(
        memory_required=3 * 1024 * 1024 * 1024,
        device=model_management.get_torch_device(),
    )

    # Keyed by id() so de-duplication is by object identity and does not rely
    # on the module's __hash__/__eq__ (a subclass that overrides __eq__ is
    # unhashable). The dict holds strong references, so ids cannot be reused
    # while tracked.
    tracked = {}

    original_init = torch.nn.Module.__init__
    original_to = torch.nn.Module.to

    def patched_init(self, *args, **kwargs):
        result = original_init(self, *args, **kwargs)
        # Record only after the base nn.Module initializer succeeded, so an
        # object whose base initialization failed is never passed to .cpu().
        tracked[id(self)] = self
        return result

    def patched_to(self, *args, **kwargs):
        tracked[id(self)] = self
        return original_to(self, *args, **kwargs)

    try:
        torch.nn.Module.__init__ = patched_init
        torch.nn.Module.to = patched_to
        yield
    finally:
        torch.nn.Module.__init__ = original_init
        torch.nn.Module.to = original_to

    # Reached only when the with-body exits without an exception: an exception
    # raised in the body propagates out of the `finally` above.
    start = time.perf_counter()
    for module in tracked.values():
        module.cpu()
    model_management.soft_empty_cache()
    end = time.perf_counter()

    print(f'Automatic Memory Management: {len(tracked)} Modules in {(end - start):.2f} seconds.')