Use torch.no_grad() around the generator call, and delete the output and run gc.collect() before emptying the CUDA cache
This commit is contained in:
parent
fae15b8a17
commit
3c9ce2c541
33
aiserver.py
33
aiserver.py
|
@ -1094,21 +1094,22 @@ def generate(txt, min, max):
|
||||||
else:
|
else:
|
||||||
gen_in = txt
|
gen_in = txt
|
||||||
|
|
||||||
genout = generator(
|
with torch.no_grad():
|
||||||
gen_in,
|
genout = generator(
|
||||||
do_sample=True,
|
gen_in,
|
||||||
min_length=min,
|
do_sample=True,
|
||||||
max_length=max,
|
min_length=min,
|
||||||
repetition_penalty=vars.rep_pen,
|
max_length=max,
|
||||||
top_p=top_p,
|
repetition_penalty=vars.rep_pen,
|
||||||
top_k=top_k,
|
top_p=top_p,
|
||||||
tfs=tfs,
|
top_k=top_k,
|
||||||
temperature=vars.temp,
|
tfs=tfs,
|
||||||
bad_words_ids=vars.badwordsids,
|
temperature=vars.temp,
|
||||||
use_cache=True,
|
bad_words_ids=vars.badwordsids,
|
||||||
return_full_text=False,
|
use_cache=True,
|
||||||
num_return_sequences=vars.numseqs
|
return_full_text=False,
|
||||||
)
|
num_return_sequences=vars.numseqs
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
|
emit('from_server', {'cmd': 'errmsg', 'data': 'Error occured during generator call, please check console.'}, broadcast=True)
|
||||||
print("{0}{1}{2}".format(colors.RED, e, colors.END))
|
print("{0}{1}{2}".format(colors.RED, e, colors.END))
|
||||||
|
@ -1126,6 +1127,8 @@ def generate(txt, min, max):
|
||||||
|
|
||||||
# Clear CUDA cache again if using GPU
|
# Clear CUDA cache again if using GPU
|
||||||
if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
|
if(vars.hascuda and (vars.usegpu or vars.breakmodel)):
|
||||||
|
del genout
|
||||||
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
set_aibusy(0)
|
set_aibusy(0)
|
||||||
|
|
Loading…
Reference in New Issue