Use torch.no_grad() and more garbage collection
commit 3c9ce2c541
parent fae15b8a17

1 changed file: aiserver.py, 33 lines changed
@@ -1094,21 +1094,22 @@ def generate(txt, min, max):
         else:
             gen_in = txt
 
-        genout = generator(
-            gen_in,
-            do_sample=True,
-            min_length=min,
-            max_length=max,
-            repetition_penalty=vars.rep_pen,
-            top_p=top_p,
-            top_k=top_k,
-            tfs=tfs,
-            temperature=vars.temp,
-            bad_words_ids=vars.badwordsids,
-            use_cache=True,
-            return_full_text=False,
-            num_return_sequences=vars.numseqs
-            )
+        with torch.no_grad():
+            genout = generator(
+                gen_in,
+                do_sample=True,
+                min_length=min,
+                max_length=max,
+                repetition_penalty=vars.rep_pen,
+                top_p=top_p,
+                top_k=top_k,
+                tfs=tfs,
+                temperature=vars.temp,
+                bad_words_ids=vars.badwordsids,
+                use_cache=True,
+                return_full_text=False,
+                num_return_sequences=vars.numseqs
+                )
     except Exception as e:
         emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
         print("{0}{1}{2}".format(colors.RED, e, colors.END))
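Why the wrapper helps: by default PyTorch records an autograd graph for every forward pass and keeps intermediate activations alive for a backward() call that inference never makes, so wrapping the generator call in torch.no_grad() lowers peak memory. A minimal standalone sketch of the effect (generic module and tensor names, not the KoboldAI generator):

import torch

model = torch.nn.Linear(512, 512)   # stand-in for any inference-only model
x = torch.randn(8, 512)

y = model(x)                        # graph recorded; activations retained
print(y.requires_grad)              # True

with torch.no_grad():               # no graph, buffers can be freed eagerly
    y = model(x)
print(y.requires_grad)              # False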
@@ -1126,6 +1127,8 @@ def generate(txt, min, max):
 
     # Clear CUDA cache again if using GPU
     if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
+        del genout
+        gc.collect()
         torch.cuda.empty_cache()
 
     set_aibusy(0)
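The added cleanup follows the usual three-step PyTorch pattern: drop the last Python reference to the output, run the garbage collector to break any reference cycles still pinning tensors, then release cached blocks back to the CUDA driver so other processes (and nvidia-smi) see the memory as free. A hedged sketch of that sequence (the tensor is a stand-in for genout; note that empty_cache() frees memory for other consumers but does not speed up PyTorch itself):

import gc
import torch

if torch.cuda.is_available():
    out = torch.randn(4, 2048, device="cuda")   # stand-in for genout
    del out                                     # drop the last reference
    gc.collect()                                # collect cycles pinning tensors
    torch.cuda.empty_cache()                    # return cached blocks to driver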