Fixed model downloading problem where models were downloaded multiple times

Author: ebolam
Date:   2022-02-06 13:42:46 -05:00
Parent: 8195360fcc
Commit: 9e17ea9636


@@ -970,43 +970,50 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
         lowmem = {}
     # Download model from Huggingface if it does not exist, otherwise load locally
+    #If we specify a model and it's in the root directory, we need to move it to the models directory (legacy folder structure to new)
+    if os.path.isdir(vars.model.replace('/', '_')):
+        import shutil
+        shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
     if(os.path.isdir(vars.custmodpth)):
         with(maybe_use_float16()):
             try:
-                tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache/")
+                tokenizer = AutoTokenizer.from_pretrained(vars.custmodpth, cache_dir="cache")
             except ValueError as e:
-                tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache/")
+                tokenizer = GPT2TokenizerFast.from_pretrained(vars.custmodpth, cache_dir="cache")
             try:
-                model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **lowmem)
+                model = AutoModelForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
             except ValueError as e:
-                model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache/", **lowmem)
-    elif(os.path.isdir(vars.model.replace('/', '_'))):
+                model = GPTNeoForCausalLM.from_pretrained(vars.custmodpth, cache_dir="cache", **lowmem)
+    elif(os.path.isdir("models/{}".format(vars.model.replace('/', '_')))):
         with(maybe_use_float16()):
             try:
-                tokenizer = AutoTokenizer.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/")
+                tokenizer = AutoTokenizer.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
             except ValueError as e:
-                tokenizer = GPT2TokenizerFast.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/")
+                tokenizer = GPT2TokenizerFast.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache")
             try:
-                model = AutoModelForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
+                model = AutoModelForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
             except ValueError as e:
-                model = GPTNeoForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
+                model = GPTNeoForCausalLM.from_pretrained("models/{}".format(vars.model.replace('/', '_')), cache_dir="cache", **lowmem)
     else:
         try:
-            tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache/")
+            tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache")
         except ValueError as e:
-            tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache/")
+            tokenizer = GPT2TokenizerFast.from_pretrained(vars.model, cache_dir="cache")
         with(maybe_use_float16()):
             try:
-                model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **lowmem)
+                model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem)
             except ValueError as e:
-                model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **lowmem)
+                model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem)
         if not args.colab:
-            model = model.half()
+            print("Trying to save model")
             import shutil
             shutil.rmtree("cache/")
-            model.save_pretrained("/models/{}".format(vars.model.replace('/', '_')))
-            tokenizer.save_pretrained("/models/{}".format(vars.model.replace('/', '_')))
+            model = model.half()
+            model.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
+            tokenizer.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
+            print("Saved")
     if(vars.hascuda):
         if(vars.usegpu):
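
The bug visible in the diff: the previous code saved freshly downloaded models to an absolute "/models/…" path, while the directory check that decides whether to skip the download looked for a bare relative folder name, so the saved copy was never found and every launch re-downloaded the model. The commit normalizes both sides to the relative "models/" directory and also migrates legacy root-level model folders into it. Below is a minimal, self-contained sketch of the resulting load-or-download pattern, assuming a standard Hugging Face transformers setup; load_or_download, model_id, and models_dir are illustrative names rather than KoboldAI's actual identifiers, and the fp16 conversion and Colab check from the real code are omitted:

import os
import shutil
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_or_download(model_id, models_dir="models", cache_dir="cache"):
    # Slashes in Hugging Face model IDs (e.g. "EleutherAI/gpt-neo-1.3B")
    # are not valid in folder names, so flatten them as the diff does.
    local_name = model_id.replace('/', '_')
    local_path = os.path.join(models_dir, local_name)

    # Legacy layout: a model folder sitting in the root directory is
    # migrated into models/ so the local check below can find it.
    if os.path.isdir(local_name):
        os.makedirs(models_dir, exist_ok=True)
        shutil.move(local_name, local_path)

    if os.path.isdir(local_path):
        # Already saved locally: load from disk, no download needed.
        tokenizer = AutoTokenizer.from_pretrained(local_path)
        model = AutoModelForCausalLM.from_pretrained(local_path)
    else:
        # First run: download into a scratch cache, then persist a full
        # copy under models/ so later launches hit the branch above.
        tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
        model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_dir)
        model.save_pretrained(local_path)
        tokenizer.save_pretrained(local_path)
        shutil.rmtree(cache_dir, ignore_errors=True)
    return tokenizer, model

Unlike the diff, which clears the cache before saving, the sketch removes the scratch cache only after save_pretrained succeeds; since the weights are already in memory either order works, but deleting last avoids losing the download if the save fails.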