Use torch.no_grad() and more garbage collection

This commit is contained in:
Gnome Ann 2021-08-21 12:15:31 -04:00
parent fae15b8a17
commit 3c9ce2c541
1 changed files with 18 additions and 15 deletions

View File

@ -1094,21 +1094,22 @@ def generate(txt, min, max):
else:
gen_in = txt
genout = generator(
gen_in,
do_sample=True,
min_length=min,
max_length=max,
repetition_penalty=vars.rep_pen,
top_p=top_p,
top_k=top_k,
tfs=tfs,
temperature=vars.temp,
bad_words_ids=vars.badwordsids,
use_cache=True,
return_full_text=False,
num_return_sequences=vars.numseqs
)
with torch.no_grad():
genout = generator(
gen_in,
do_sample=True,
min_length=min,
max_length=max,
repetition_penalty=vars.rep_pen,
top_p=top_p,
top_k=top_k,
tfs=tfs,
temperature=vars.temp,
bad_words_ids=vars.badwordsids,
use_cache=True,
return_full_text=False,
num_return_sequences=vars.numseqs
)
except Exception as e:
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
print("{0}{1}{2}".format(colors.RED, e, colors.END))
@ -1126,6 +1127,8 @@ def generate(txt, min, max):
# Clear CUDA cache again if using GPU
if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
del genout
gc.collect()
torch.cuda.empty_cache()
set_aibusy(0)