GPT2Tokenizer for TPU

parent 60d09899ea
commit 6c32bc18d7
@@ -46,7 +46,7 @@ from jax.experimental import maps
 import jax.numpy as jnp
 import numpy as np
 import haiku as hk
-from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
+from transformers import AutoTokenizer, GPT2Tokenizer, AutoModelForCausalLM, GPTNeoForCausalLM
 from tokenizers import Tokenizer
 from mesh_transformer.checkpoint import read_ckpt_lowmem
 from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor
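Note: GPT2TokenizerFast is backed by the Rust `tokenizers` library, while GPT2Tokenizer is the pure-Python byte-level BPE implementation; both expose the same from_pretrained/encode/decode interface, so the import swap is drop-in. A minimal sketch, using the stock gpt2 vocabulary rather than the checkpoint this backend actually loads:

# Sketch only: the stock "gpt2" vocabulary stands in for whatever
# checkpoint the backend loads; both tokenizer classes accept the
# same calls, so swapping the class does not change this code.
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
ids = tokenizer.encode("Hello world")   # [15496, 995]
text = tokenizer.decode(ids)            # "Hello world"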
@@ -1061,7 +1061,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         "pe_rotary_dims": 64,
         "seq": 2048,
         "cores_per_replica": 8,
-        "tokenizer_class": "GPT2TokenizerFast",
+        "tokenizer_class": "GPT2Tokenizer",
         "tokenizer": "gpt2",
     }
     params = kwargs
@@ -1079,7 +1079,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         "pe_rotary_dims": 24,
         "seq": 2048,
         "cores_per_replica": 8,
-        "tokenizer_class": "GPT2TokenizerFast",
+        "tokenizer_class": "GPT2Tokenizer",
         "tokenizer": "gpt2",
     }
 
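The kwargs dicts above only record the tokenizer class by name; the code that resolves that string is outside this diff, so the following is an assumption about how such a field is typically consumed (attribute lookup on the transformers module):

# Assumption: the loader resolves "tokenizer_class" by attribute
# lookup on the transformers module; the actual consumer of this
# field is not shown in this diff.
import transformers

params = {"tokenizer_class": "GPT2Tokenizer", "tokenizer": "gpt2"}
tokenizer_class = getattr(transformers, params["tokenizer_class"])
tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])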
@@ -1357,9 +1357,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
     except Exception as e:
         try:
-            tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
+            tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
         except Exception as e:
-            tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
+            tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
     try:
         model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
     except Exception as e:
@@ -1372,9 +1372,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
     except Exception as e:
         try:
-            tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
+            tokenizer = GPT2Tokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
         except Exception as e:
-            tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
+            tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
     try:
         model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
     except Exception as e:
@@ -1387,9 +1387,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         tokenizer = AutoTokenizer.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
     except Exception as e:
         try:
-            tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
+            tokenizer = GPT2Tokenizer.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
         except Exception as e:
-            tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
+            tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
     try:
         model = AutoModelForCausalLM.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
     except Exception as e:
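The three hunks above all apply the same three-level fallback; distilled into one sketch below, where load_tokenizer and model_path are hypothetical names standing in for the inline code and for vars.custmodpth, the local models/ directory, or vars.model respectively:

# Distilled sketch of the fallback chain; load_tokenizer and
# model_path are hypothetical, not names from this codebase.
from transformers import AutoTokenizer, GPT2Tokenizer

def load_tokenizer(model_path, revision=None):
    try:
        # Preferred: AutoTokenizer picks the class the checkpoint declares.
        return AutoTokenizer.from_pretrained(model_path, revision=revision, cache_dir="cache")
    except Exception:
        try:
            # Fall back to the pure-Python GPT-2 tokenizer for the same path.
            return GPT2Tokenizer.from_pretrained(model_path, revision=revision, cache_dir="cache")
        except Exception:
            # Last resort: the stock "gpt2" vocabulary, as in the diff.
            return GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir="cache")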