diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 6f5500b7..03f4cdde 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1045,9 +1045,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             except ValueError as e:
                 tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
             try:
-                model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
+                model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
             except ValueError as e:
-                model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
+                model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
         else:
             try:
                 tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")