From 68281184bf23123abf23ee77a9d56b45ad213364 Mon Sep 17 00:00:00 2001
From: henk717
Date: Wed, 9 Mar 2022 19:21:15 +0100
Subject: [PATCH] Remove Lowmem from TPU

---
 tpu_mtj_backend.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 6f5500b7..03f4cdde 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1045,9 +1045,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             except ValueError as e:
                 tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
             try:
-                model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
+                model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
             except ValueError as e:
-                model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
+                model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
         else:
             try:
                 tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
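
Note for reviewers: the `**lowmem` expansion removed above is a kwargs dict used to hint low-memory CPU loading; it is presumably unneeded on the TPU path, where the checkpoint weights are converted for the MTJ backend anyway. A minimal sketch of the before/after calls, assuming `lowmem` resolves to {"low_cpu_mem_usage": True} (the transformers >= 4.11 loading hint) and using a hypothetical local checkpoint directory:

    from transformers import AutoModelForCausalLM

    model_dir = "models/EleutherAI_gpt-neo-2.7B"  # hypothetical local checkpoint

    # Before this patch: pass the low-memory loading hint through **lowmem.
    lowmem = {"low_cpu_mem_usage": True}  # assumed contents of `lowmem`
    model = AutoModelForCausalLM.from_pretrained(model_dir, cache_dir="cache", **lowmem)

    # After this patch: a plain load on the TPU path.
    model = AutoModelForCausalLM.from_pretrained(model_dir, cache_dir="cache")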