Use torch.no_grad() and more garbage collection

This commit is contained in:
Gnome Ann 2021-08-21 12:15:31 -04:00
parent fae15b8a17
commit 3c9ce2c541
1 changed files with 18 additions and 15 deletions

View File

@ -1094,21 +1094,22 @@ def generate(txt, min, max):
else: else:
gen_in = txt gen_in = txt
genout = generator( with torch.no_grad():
gen_in, genout = generator(
do_sample=True, gen_in,
min_length=min, do_sample=True,
max_length=max, min_length=min,
repetition_penalty=vars.rep_pen, max_length=max,
top_p=top_p, repetition_penalty=vars.rep_pen,
top_k=top_k, top_p=top_p,
tfs=tfs, top_k=top_k,
temperature=vars.temp, tfs=tfs,
bad_words_ids=vars.badwordsids, temperature=vars.temp,
use_cache=True, bad_words_ids=vars.badwordsids,
return_full_text=False, use_cache=True,
num_return_sequences=vars.numseqs return_full_text=False,
) num_return_sequences=vars.numseqs
)
except Exception as e: except Exception as e:
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True) emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
print("{0}{1}{2}".format(colors.RED, e, colors.END)) print("{0}{1}{2}".format(colors.RED, e, colors.END))
@ -1126,6 +1127,8 @@ def generate(txt, min, max):
# Clear CUDA cache again if using GPU # Clear CUDA cache again if using GPU
if(vars.hascuda and (vars.usegpu or vars.breakmodel)): if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
del genout
gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
set_aibusy(0) set_aibusy(0)