From 48d07adb5440c2a06e2a679e908b1a0c61119467 Mon Sep 17 00:00:00 2001 From: Gnome Ann <> Date: Sat, 12 Mar 2022 23:19:35 -0500 Subject: [PATCH] Also fallback to generic GPT2 tokenizer in Colab TPU instances --- tpu_mtj_backend.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py index a565b578..a77eb8b9 100644 --- a/tpu_mtj_backend.py +++ b/tpu_mtj_backend.py @@ -1035,29 +1035,38 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo if(os.path.isdir(vars.custmodpth)): try: tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache") - except ValueError as e: - tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache") + except Exception as e: + try: + tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache") + except Exception as e: + tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache") try: model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache") - except ValueError as e: + except Exception as e: model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache") elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))): try: tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") - except ValueError as e: - tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") + except Exception as e: + try: + tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") + except Exception as e: + tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache") try: model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") - except ValueError as e: + except Exception as e: model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache") else: try: tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache") - except ValueError as e: - tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache") + except Exception as e: + try: + tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache") + except Exception as e: + tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache") try: model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache") - except ValueError as e: + except Exception as e: model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache") network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))