Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-02-17 12:10:49 +01:00
AutoTokenizer
Future-proofing for upcoming tokenizers. For now this is not needed, since everything uses GPT2, but when that changes we want to be prepared. Not all models have a proper tokenizer config, so if we can't find one we fall back to GPT2.
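As a standalone illustration of that fallback, here is a minimal sketch assuming the Hugging Face transformers library; the load_tokenizer helper and its model_path parameter are illustrative names, not code from aiserver.py:

from transformers import AutoTokenizer, GPT2TokenizerFast

def load_tokenizer(model_path):
    try:
        # Prefer the tokenizer the model itself ships, when it has a usable config.
        return AutoTokenizer.from_pretrained(model_path, cache_dir="cache/")
    except ValueError:
        # No proper tokenizer config found; fall back to the GPT2 tokenizer,
        # matching the except ValueError branch in the diff below.
        return GPT2TokenizerFast.from_pretrained(model_path, cache_dir="cache/")

In the diff below the same try/except is written inline at each load path rather than factored into a helper.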
parent 00a0cea077
commit e1cd34268b
aiserver.py (17 changed lines)
@@ -615,7 +615,7 @@ print("{0}OK!{1}".format(colors.GREEN, colors.END))
 if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransformerGPTJ"]):
     if(not vars.noai):
         print("{0}Initializing transformers, please wait...{1}".format(colors.PURPLE, colors.END))
-        from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM
+        from transformers import StoppingCriteria, GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, AutoModelForCausalLM, AutoTokenizer
         try:
             from transformers import GPTJModel
         except:
@@ -885,21 +885,30 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
         # Download model from Huggingface if it does not exist, otherwise load locally
         if(os.path.isdir(vars.custmodpth)):
             with(maybe_use_float16()):
-                tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
+                try:
+                    tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
+                except ValueError as e:
+                    tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
                 try:
                     model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **lowmem)
                 except ValueError as e:
                     model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **lowmem)
         elif(os.path.isdir(vars.model.replace('/', '_'))):
             with(maybe_use_float16()):
-                tokenizer = GPT2TokenizerFast.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/")
+                try:
+                    tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
+                except ValueError as e:
+                    tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
                 try:
                     model = AutoModelForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
                 except ValueError as e:
                     model = GPTNeoForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
         else:
             print("Model does not exist locally, attempting to download from Huggingface...")
-            tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
+            except ValueError as e:
+                tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
             with(maybe_use_float16()):
                 tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
                 try: