diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 743e5fc0..e8594721 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -46,7 +46,7 @@ from jax.experimental import maps
 import jax.numpy as jnp
 import numpy as np
 import haiku as hk
-from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
+from transformers import AutoTokenizer, GPT2Tokenizer, AutoModelForCausalLM, GPTNeoForCausalLM
 from tokenizers import Tokenizer
 from mesh_transformer.checkpoint import read_ckpt_lowmem
 from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor
@@ -1061,7 +1061,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         "pe_rotary_dims": 64,
         "seq": 2048,
         "cores_per_replica": 8,
-        "tokenizer_class": "GPT2TokenizerFast",
+        "tokenizer_class": "GPT2Tokenizer",
         "tokenizer": "gpt2",
     }
     params = kwargs
@@ -1079,7 +1079,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         "pe_rotary_dims": 24,
         "seq": 2048,
         "cores_per_replica": 8,
-        "tokenizer_class": "GPT2TokenizerFast",
+        "tokenizer_class": "GPT2Tokenizer",
         "tokenizer": "gpt2",
     }
 
@@ -1357,9 +1357,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
         except Exception as e:
             try:
-                tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
+                tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
             except Exception as e:
-                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
+                tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
         try:
             model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
         except Exception as e:
@@ -1372,9 +1372,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
         except Exception as e:
             try:
-                tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
+                tokenizer = GPT2Tokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
             except Exception as e:
-                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
+                tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
         try:
             model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
         except Exception as e:
@@ -1387,9 +1387,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             tokenizer = AutoTokenizer.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
         except Exception as e:
             try:
-                tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
+                tokenizer = GPT2Tokenizer.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
             except Exception as e:
-                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
+                tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
         try:
             model = AutoModelForCausalLM.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
         except Exception as e:
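Note (not part of the diff): the swap from GPT2TokenizerFast to GPT2Tokenizer should be drop-in for callers, since both transformers classes expose the same from_pretrained/encode interface; the only structural change is trading the Rust-backed fast tokenizer for the pure-Python one. Below is a minimal sketch of that equivalence; the "gpt2" model name is taken from the diff's fallback path, and everything else is illustrative only.

# Hedged sketch: GPT2Tokenizer (pure Python) and GPT2TokenizerFast (Rust-backed)
# are loaded the same way and produce identical token ids for ordinary text,
# which is why the replacement above does not require changes at call sites.
from transformers import GPT2Tokenizer, GPT2TokenizerFast

slow = GPT2Tokenizer.from_pretrained("gpt2", cache_dir="cache")
fast = GPT2TokenizerFast.from_pretrained("gpt2", cache_dir="cache")

text = "Hello world"
# Both tokenizers implement encode() via the shared PreTrainedTokenizerBase API.
assert slow.encode(text) == fast.encode(text)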