Remove the lowmem loading arguments from the TPU model-loading path

This commit is contained in:
henk717
2022-03-09 19:21:15 +01:00
parent 9dee9b5c6d
commit 68281184bf

View File

@ -1045,9 +1045,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
except ValueError as e:
tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
try:
model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
except ValueError as e:
model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
else:
try:
tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")