GPT2Tokenizer for TPU
parent 60d09899ea
commit 6c32bc18d7
@@ -46,7 +46,7 @@ from jax.experimental import maps
 import jax.numpy as jnp
 import numpy as np
 import haiku as hk
-from transformers import AutoTokenizer, GPT2TokenizerFast, AutoModelForCausalLM, GPTNeoForCausalLM
+from transformers import AutoTokenizer, GPT2Tokenizer, AutoModelForCausalLM, GPTNeoForCausalLM
 from tokenizers import Tokenizer
 from mesh_transformer.checkpoint import read_ckpt_lowmem
 from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerShard, PlaceholderTensor
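
Note (not part of the diff): the slow GPT2Tokenizer and the Rust-backed GPT2TokenizerFast expose the same from_pretrained()/encode() interface and should produce the same token IDs for ordinary text, so the import swap above is a drop-in replacement. A minimal sanity check, assuming a local transformers install:

    from transformers import GPT2Tokenizer, GPT2TokenizerFast

    slow = GPT2Tokenizer.from_pretrained("gpt2")      # pure-Python implementation
    fast = GPT2TokenizerFast.from_pretrained("gpt2")  # Rust-backed implementation

    text = "Hello TPU world"
    # Both implementations should yield identical token IDs for plain text.
    assert slow.encode(text) == fast.encode(text)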
@@ -1061,7 +1061,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         "pe_rotary_dims": 64,
         "seq": 2048,
         "cores_per_replica": 8,
-        "tokenizer_class": "GPT2TokenizerFast",
+        "tokenizer_class": "GPT2Tokenizer",
         "tokenizer": "gpt2",
     }
     params = kwargs
@@ -1079,7 +1079,7 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
         "pe_rotary_dims": 24,
         "seq": 2048,
         "cores_per_replica": 8,
-        "tokenizer_class": "GPT2TokenizerFast",
+        "tokenizer_class": "GPT2Tokenizer",
         "tokenizer": "gpt2",
     }

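Note (not part of the diff): "tokenizer_class" is a string naming a class in the transformers package. A plausible sketch of how a backend could resolve that name is below; the getattr lookup is an assumption about this codebase, not something the diff confirms.

    import transformers

    params = {"tokenizer_class": "GPT2Tokenizer", "tokenizer": "gpt2"}
    # Resolve the class named in the config, then load the named vocabulary.
    tokenizer_class = getattr(transformers, params["tokenizer_class"])
    tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])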
@@ -1357,9 +1357,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
         except Exception as e:
             try:
-                tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
+                tokenizer = GPT2Tokenizer.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
             except Exception as e:
-                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
+                tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
         try:
             model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, revision=vars.revision, cache_dir="cache")
         except Exception as e:
@@ -1372,9 +1372,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
         except Exception as e:
             try:
-                tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
+                tokenizer = GPT2Tokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
             except Exception as e:
-                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
+                tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
         try:
             model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), revision=vars.revision, cache_dir="cache")
         except Exception as e:
@@ -1387,9 +1387,9 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
             tokenizer = AutoTokenizer.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
         except Exception as e:
             try:
-                tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
+                tokenizer = GPT2Tokenizer.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
             except Exception as e:
-                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
+                tokenizer = GPT2Tokenizer.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
         try:
            model = AutoModelForCausalLM.from_pretrained(vars.model, revision=vars.revision, cache_dir="cache")
         except Exception as e:
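
Note (not part of the diff): the three hunks above repeat one fallback pattern. Reduced to a standalone sketch, with the hypothetical model_path standing in for vars.custmodpth / "models/..." / vars.model:

    from transformers import AutoTokenizer, GPT2Tokenizer

    def load_tokenizer(model_path, revision=None, cache_dir="cache"):
        # Prefer the tokenizer shipped with the model, then a GPT-2 tokenizer
        # from the same path, then the stock "gpt2" vocabulary as a last resort.
        try:
            return AutoTokenizer.from_pretrained(model_path, revision=revision, cache_dir=cache_dir)
        except Exception:
            try:
                return GPT2Tokenizer.from_pretrained(model_path, revision=revision, cache_dir=cache_dir)
            except Exception:
                return GPT2Tokenizer.from_pretrained("gpt2", revision=revision, cache_dir=cache_dir)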