adr2432's picture
Upload 302 files
070b43a
raw
history blame
2.31 kB
import gc
import traceback
from queue import Queue
from threading import Thread
import torch
import transformers
import modules.shared as shared
class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
    """Stopping criterion driven by the global ``shared.stop_everything`` flag."""

    def __init__(self):
        super().__init__()

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
        """Return the current value of ``shared.stop_everything``.

        Both arguments are ignored; only the shared flag decides the result.
        """
        should_stop = shared.stop_everything
        return should_stop
class Stream(transformers.StoppingCriteria):
    """Stopping criterion that never stops; it only reports progress.

    On every invocation the optional callback receives ``input_ids[0]``
    (presumably the first sequence of the batch -- confirm against caller).
    """

    def __init__(self, callback_func=None):
        # Optional callable taking a single argument; ``None`` disables reporting.
        self.callback_func = callback_func

    def __call__(self, input_ids, scores) -> bool:
        """Forward ``input_ids[0]`` to the callback, if any, and return False."""
        notify = self.callback_func
        if notify is None:
            return False
        notify(input_ids[0])
        return False
class Iteratorize:
    """
    Transforms a function that takes a callback
    into a lazy iterator (generator).

    The wrapped function is started on a background thread as soon as the
    object is constructed. Every value it passes to its ``callback`` keyword
    argument is queued and handed out by ``__next__``; a private sentinel
    marks the end of the stream.

    Adapted from: https://stackoverflow.com/a/9969000
    """

    def __init__(self, func, args=None, kwargs=None, callback=None):
        """Start ``func`` on a worker thread.

        Args:
            func: Callable invoked as ``func(*args, callback=..., **kwargs)``.
            args: Positional arguments for ``func``; ``None`` means none.
            kwargs: Keyword arguments for ``func``; ``None`` means none.
            callback: Optional callable invoked with the return value of
                ``func`` once it has finished (``None`` if ``func`` did not
                return normally).
        """
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        self.sentinel = object()
        self.args = args or []
        self.kwargs = kwargs or {}
        self.stop_now = False

        def _callback(val):
            # Raising ValueError is the stop signal: it unwinds out of the
            # wrapped function and is swallowed in gentask() below.
            if self.stop_now or shared.stop_everything:
                raise ValueError
            self.q.put(val)

        def gentask():
            # Pre-bind so the completion callback still gets a defined value
            # when the wrapped function is interrupted or fails.
            ret = None
            try:
                # Use the normalized self.args: the raw ``args`` parameter may
                # be None, and ``*None`` raises TypeError.
                ret = self.mfunc(*self.args, callback=_callback, **self.kwargs)
            except ValueError:
                # Deliberate stop requested through _callback.
                pass
            except Exception:
                traceback.print_exc()
            finally:
                clear_torch_cache()
                # Always unblock the consumer waiting in __next__.
                self.q.put(self.sentinel)

            if self.c_callback:
                self.c_callback(ret)

        self.thread = Thread(target=gentask)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        """Block until the next value is available; stop at the sentinel."""
        obj = self.q.get(True, None)
        if obj is self.sentinel:
            raise StopIteration
        return obj

    def __del__(self):
        clear_torch_cache()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Ask the worker to stop: its next _callback invocation raises.
        self.stop_now = True
        clear_torch_cache()
def clear_torch_cache():
    """Run the garbage collector, then empty the CUDA cache unless in CPU mode.

    CPU mode is determined by ``shared.args.cpu``.
    """
    gc.collect()
    if shared.args.cpu:
        return
    torch.cuda.empty_cache()