Use torch.no_grad() and more garbage collection

Gnome Ann 2021-08-21 12:15:31 -04:00
parent fae15b8a17
commit 3c9ce2c541


@@ -1094,6 +1094,7 @@ def generate(txt, min, max):
     else:
         gen_in = txt
 
+    with torch.no_grad():
         genout = generator(
             gen_in,
             do_sample=True,
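
This first hunk wraps the call to the generator in torch.no_grad(), so PyTorch skips recording operations for autograd during inference and intermediate buffers can be freed as soon as they are no longer needed. Below is a minimal standalone sketch of the same pattern, assuming a Hugging Face transformers causal language model; the "gpt2" model name and the prompt are placeholders, not KoboldAI's code.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model; KoboldAI loads whichever model the user selected.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

gen_in = tokenizer("Once upon a time", return_tensors="pt").input_ids

# Nothing inside this block is tracked for autograd, so no gradient
# bookkeeping is kept alive by the generation pass.
with torch.no_grad():
    genout = model.generate(gen_in, do_sample=True, max_length=60)

print(tokenizer.decode(genout[0], skip_special_tokens=True))
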
@@ -1126,6 +1127,8 @@ def generate(txt, min, max):
 
     # Clear CUDA cache again if using GPU
     if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
+        del genout
+        gc.collect()
         torch.cuda.empty_cache()
 
     set_aibusy(0)
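
The second hunk releases the generation output before clearing the CUDA cache: del drops the last Python reference to the output tensor, gc.collect() reclaims any objects still holding onto it, and torch.cuda.empty_cache() then returns the now-unused cached blocks to the driver. The ordering matters, since empty_cache() cannot release memory that live tensors still occupy. Here is a hedged sketch of the same cleanup sequence, using a hypothetical generate_and_free() helper rather than the project's generator:

import gc
import torch

def generate_and_free(model, gen_in):
    # Hypothetical helper: run one sampled generation, keep a CPU copy
    # of the result, and release the GPU memory the output occupied.
    with torch.no_grad():
        genout = model.generate(gen_in, do_sample=True, max_length=60)
    result = genout.to("cpu")        # keep only a CPU copy of the output
    del genout                       # drop the reference to the GPU tensor
    gc.collect()                     # collect anything still pointing at it
    if torch.cuda.is_available():
        torch.cuda.empty_cache()     # hand cached CUDA blocks back to the driver
    return result
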