mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge pull request #128 from VE-FORBRYDERNE/opt
OPT breakmodel and TPU support
This commit is contained in:
102
aiserver.py
102
aiserver.py
@@ -240,6 +240,7 @@ class vars:
|
||||
# badwords = [] # Array of str/chr values that should be removed from output
|
||||
badwordsids = [[13460], [6880], [50256], [42496], [4613], [17414], [22039], [16410], [27], [29], [38430], [37922], [15913], [24618], [28725], [58], [47175], [36937], [26700], [12878], [16471], [37981], [5218], [29795], [13412], [45160], [3693], [49778], [4211], [20598], [36475], [33409], [44167], [32406], [29847], [29342], [42669], [685], [25787], [7359], [3784], [5320], [33994], [33490], [34516], [43734], [17635], [24293], [9959], [23785], [21737], [28401], [18161], [26358], [32509], [1279], [38155], [18189], [26894], [6927], [14610], [23834], [11037], [14631], [26933], [46904], [22330], [25915], [47934], [38214], [1875], [14692], [41832], [13163], [25970], [29565], [44926], [19841], [37250], [49029], [9609], [44438], [16791], [17816], [30109], [41888], [47527], [42924], [23984], [49074], [33717], [31161], [49082], [30138], [31175], [12240], [14804], [7131], [26076], [33250], [3556], [38381], [36338], [32756], [46581], [17912], [49146]] # Tokenized array of badwords used to prevent AI artifacting
|
||||
badwordsids_neox = [[0], [1], [44162], [9502], [12520], [31841], [36320], [49824], [34417], [6038], [34494], [24815], [26635], [24345], [3455], [28905], [44270], [17278], [32666], [46880], [7086], [43189], [37322], [17778], [20879], [49821], [3138], [14490], [4681], [21391], [26786], [43134], [9336], [683], [48074], [41256], [19181], [29650], [28532], [36487], [45114], [46275], [16445], [15104], [11337], [1168], [5647], [29], [27482], [44965], [43782], [31011], [42944], [47389], [6334], [17548], [38329], [32044], [35487], [2239], [34761], [7444], [1084], [12399], [18990], [17636], [39083], [1184], [35830], [28365], [16731], [43467], [47744], [1138], [16079], [40116], [45564], [18297], [42368], [5456], [18022], [42696], [34476], [23505], [23741], [39334], [37944], [45382], [38709], [33440], [26077], [43600], [34418], [36033], [6660], [48167], [48471], [15775], [19884], [41533], [1008], [31053], [36692], [46576], [20095], [20629], [31759], [46410], [41000], [13488], [30952], [39258], [16160], [27655], [22367], [42767], [43736], [49694], [13811], [12004], [46768], [6257], [37471], [5264], [44153], [33805], [20977], [21083], [25416], [14277], [31096], [42041], [18331], [33376], [22372], [46294], [28379], [38475], [1656], [5204], [27075], [50001], [16616], [11396], [7748], [48744], [35402], [28120], [41512], [4207], [43144], [14767], [15640], [16595], [41305], [44479], [38958], [18474], [22734], [30522], [46267], [60], [13976], [31830], [48701], [39822], [9014], [21966], [31422], [28052], [34607], [2479], [3851], [32214], [44082], [45507], [3001], [34368], [34758], [13380], [38363], [4299], [46802], [30996], [12630], [49236], [7082], [8795], [5218], [44740], [9686], [9983], [45301], [27114], [40125], [1570], [26997], [544], [5290], [49193], [23781], [14193], [40000], [2947], [43781], [9102], [48064], [42274], [18772], [49384], [9884], [45635], [43521], [31258], [32056], [47686], [21760], [13143], [10148], [26119], [44308], [31379], [36399], [23983], [46694], [36134], [8562], [12977], [35117], [28591], [49021], [47093], [28653], [29013], [46468], [8605], [7254], [25896], [5032], [8168], [36893], [38270], [20499], [27501], [34419], [29547], [28571], [36586], [20871], [30537], [26842], [21375], [31148], [27618], [33094], [3291], [31789], [28391], [870], [9793], [41361], [47916], [27468], [43856], [8850], [35237], [15707], [47552], [2730], [41449], [45488], [3073], [49806], [21938], [24430], [22747], [20924], [46145], [20481], [20197], [8239], [28231], [17987], [42804], [47269], [29972], [49884], [21382], [46295], [36676], [34616], [3921], [26991], [27720], [46265], [654], [9855], [40354], [5291], [34904], [44342], [2470], [14598], [880], [19282], [2498], [24237], [21431], [16369], [8994], [44524], [45662], [13663], [37077], [1447], [37786], [30863], [42854], [1019], [20322], [4398], [12159], [44072], [48664], [31547], [18736], [9259], [31], [16354], [21810], [4357], [37982], [5064], [2033], [32871], [47446], [62], [22158], [37387], [8743], [47007], [17981], [11049], [4622], [37916], [36786], [35138], [29925], [14157], [18095], [27829], [1181], [22226], [5709], [4725], [30189], [37014], [1254], [11380], [42989], [696], [24576], [39487], [30119], [1092], [8088], [2194], [9899], [14412], [21828], [3725], [13544], [5180], [44679], [34398], [3891], [28739], [14219], [37594], [49550], [11326], [6904], [17266], [5749], [10174], [23405], [9955], [38271], [41018], [13011], [48392], [36784], [24254], [21687], [23734], [5413], [41447], [45472], [10122], [17555], [15830], [47384], [12084], [31350], [47940], [11661], [27988], [45443], [905], [49651], [16614], [34993], [6781], [30803], [35869], [8001], [41604], [28118], [46462], [46762], [16262], [17281], [5774], [10943], [5013], [18257], [6750], [4713], [3951], [11899], [38791], [16943], [37596], [9318], [18413], [40473], [13208], [16375]]
|
||||
badwordsids_opt = [[44717], [46613], [48513], [49923], [50185], [48755], [8488], [43303], [49659], [48601], [49817], [45405], [48742], [49925], [47720], [11227], [48937], [48784], [50017], [42248], [49310], [48082], [49895], [50025], [49092], [49007], [8061], [44226], [0], [742], [28578], [15698], [49784], [46679], [39365], [49281], [49609], [48081], [48906], [46161], [48554], [49670], [48677], [49721], [49632], [48610], [48462], [47457], [10975], [46077], [28696], [48709], [43839], [49798], [49154], [48203], [49625], [48395], [50155], [47161], [49095], [48833], [49420], [49666], [48443], [22176], [49242], [48651], [49138], [49750], [40389], [48021], [21838], [49070], [45333], [40862], [1], [49915], [33525], [49858], [50254], [44403], [48992], [48872], [46117], [49853], [47567], [50206], [41552], [50068], [48999], [49703], [49940], [49329], [47620], [49868], [49962], [2], [44082], [50236], [31274], [50260], [47052], [42645], [49177], [17523], [48691], [49900], [49069], [49358], [48794], [47529], [46479], [48457], [646], [49910], [48077], [48935], [46386], [48902], [49151], [48759], [49803], [45587], [48392], [47789], [48654], [49836], [49230], [48188], [50264], [46844], [44690], [48505], [50161], [27779], [49995], [41833], [50154], [49097], [48520], [50018], [8174], [50084], [49366], [49526], [50193], [7479], [49982], [3]]
|
||||
deletewi = None # Temporary storage for UID to delete
|
||||
wirmvwhtsp = False # Whether to remove leading whitespace from WI entries
|
||||
widepth = 3 # How many historical actions to scan for WI hits
|
||||
@@ -274,7 +275,7 @@ class vars:
|
||||
recentrngm = None # If a new random game was recently generated without Submitting after, this is the memory used (as a string), otherwise this is None
|
||||
useprompt = False # Whether to send the full prompt with every submit action
|
||||
breakmodel = False # For GPU users, whether to use both system RAM and VRAM to conserve VRAM while offering speedup compared to CPU-only
|
||||
bmsupported = False # Whether the breakmodel option is supported (GPT-Neo/GPT-J/XGLM only, currently)
|
||||
bmsupported = False # Whether the breakmodel option is supported (GPT-Neo/GPT-J/XGLM/OPT only, currently)
|
||||
nobreakmodel = False # Something specifically requested Breakmodel to be disabled (For example a models config)
|
||||
smandelete = False # Whether stories can be deleted from inside the browser
|
||||
smanrename = False # Whether stories can be renamed from inside the browser
|
||||
@@ -391,7 +392,7 @@ def device_list(n_layers, primary=None, selected=None):
|
||||
def device_config(config):
|
||||
global breakmodel, generator
|
||||
import breakmodel
|
||||
n_layers = config.num_layers if hasattr(config, "num_layers") else config.n_layer
|
||||
n_layers = utils.num_layers(config)
|
||||
if(args.breakmodel_gpulayers is not None):
|
||||
try:
|
||||
breakmodel.gpu_blocks = list(map(int, args.breakmodel_gpulayers.split(',')))
|
||||
@@ -464,7 +465,7 @@ def device_config(config):
|
||||
# If all layers are on the same device, use the old GPU generation mode
|
||||
while(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] == 0):
|
||||
breakmodel.gpu_blocks.pop()
|
||||
if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, config.num_layers if hasattr(config, "num_layers") else config.n_layer)):
|
||||
if(len(breakmodel.gpu_blocks) and breakmodel.gpu_blocks[-1] in (-1, utils.num_layers(config))):
|
||||
vars.breakmodel = False
|
||||
vars.usegpu = True
|
||||
vars.gpu_device = len(breakmodel.gpu_blocks)-1
|
||||
@@ -496,22 +497,33 @@ def move_model_to_devices(model):
|
||||
model.lm_head.to(breakmodel.primary_device)
|
||||
if(hasattr(model.transformer, 'wpe')):
|
||||
model.transformer.wpe.to(breakmodel.primary_device)
|
||||
else:
|
||||
elif(not hasattr(model.model, "decoder")):
|
||||
model.model.embed_tokens.to(breakmodel.primary_device)
|
||||
model.model.layer_norm.to(breakmodel.primary_device)
|
||||
model.lm_head.to(breakmodel.primary_device)
|
||||
model.model.embed_positions.to(breakmodel.primary_device)
|
||||
else:
|
||||
model.model.decoder.embed_tokens.to(breakmodel.primary_device)
|
||||
if(model.model.decoder.project_in is not None):
|
||||
model.model.decoder.project_in.to(breakmodel.primary_device)
|
||||
if(model.model.decoder.project_out is not None):
|
||||
model.model.decoder.project_out.to(breakmodel.primary_device)
|
||||
model.model.decoder.embed_positions.to(breakmodel.primary_device)
|
||||
gc.collect()
|
||||
GPTNeoModel.forward = breakmodel.new_forward_neo
|
||||
if("GPTJModel" in globals()):
|
||||
GPTJModel.forward = breakmodel.new_forward_neo # type: ignore
|
||||
if("XGLMModel" in globals()):
|
||||
XGLMModel.forward = breakmodel.new_forward_xglm # type: ignore
|
||||
if("OPTDecoder" in globals()):
|
||||
OPTDecoder.forward = breakmodel.new_forward_opt # type: ignore
|
||||
generator = model.generate
|
||||
if(hasattr(model, "transformer")):
|
||||
breakmodel.move_hidden_layers(model.transformer)
|
||||
else:
|
||||
elif(not hasattr(model.model, "decoder")):
|
||||
breakmodel.move_hidden_layers(model.model, model.model.layers)
|
||||
else:
|
||||
breakmodel.move_hidden_layers(model.model.decoder, model.model.decoder.layers)
|
||||
|
||||
#==================================================================#
|
||||
# Allow the models to override some settings
|
||||
@@ -774,7 +786,7 @@ def spRequest(filename):
|
||||
tensor = tensor.reshape(
|
||||
tpu_mtj_backend.params["cores_per_replica"],
|
||||
-1,
|
||||
tpu_mtj_backend.params["d_model"],
|
||||
tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]),
|
||||
)
|
||||
vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
|
||||
else:
|
||||
@@ -908,12 +920,15 @@ if(vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMe
|
||||
print("WARNING: No model type detected, assuming Neo (If this is a GPT2 model use the other menu option or --model GPT2Custom)")
|
||||
vars.model_type = "gpt_neo"
|
||||
|
||||
if(vars.model_type == "opt"):
|
||||
vars.badwordsids = vars.badwordsids_opt
|
||||
|
||||
if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
|
||||
loadmodelsettings()
|
||||
loadsettings()
|
||||
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
|
||||
vars.hascuda = torch.cuda.is_available()
|
||||
vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm") and not vars.nobreakmodel
|
||||
vars.bmsupported = vars.model_type in ("gpt_neo", "gptj", "xglm", "opt") and not vars.nobreakmodel
|
||||
if(args.breakmodel is not None and args.breakmodel):
|
||||
print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --breakmodel_gpulayers is used (see --help for details).", file=sys.stderr)
|
||||
if(args.breakmodel_layers is not None):
|
||||
@@ -1125,17 +1140,29 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
||||
globals()[m] = getattr(__import__("transformers"), m)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
from transformers.models.opt.modeling_opt import OPTDecoder
|
||||
except:
|
||||
pass
|
||||
import transformers.generation_utils
|
||||
from transformers import __version__ as transformers_version
|
||||
|
||||
from transformers import PreTrainedModel
|
||||
from transformers import modeling_utils
|
||||
old_from_pretrained = PreTrainedModel.from_pretrained.__func__
|
||||
@classmethod
|
||||
def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
utils.num_shards = None
|
||||
utils.current_shard = 0
|
||||
if not args.no_aria2:
|
||||
utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
|
||||
return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
PreTrainedModel.from_pretrained = new_from_pretrained
|
||||
old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
|
||||
def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
|
||||
utils.num_shards = utils.get_num_shards(index_filename)
|
||||
return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
|
||||
modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
|
||||
|
||||
# Lazy loader
|
||||
import torch_lazy_loader
|
||||
@@ -1172,7 +1199,9 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
||||
last_storage_key = None
|
||||
f = None
|
||||
current_offset = 0
|
||||
for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors"):
|
||||
if utils.num_shards is not None:
|
||||
utils.current_shard += 1
|
||||
for key in tqdm(sorted(device_map.keys(), key=lambda k: (model_dict[k].key, model_dict[k].seek_offset)), desc="Loading model tensors" + (f" (shard {utils.current_shard}/{utils.num_shards})" if utils.num_shards is not None else "")):
|
||||
storage_key = model_dict[key].key
|
||||
if storage_key != last_storage_key or model_dict[key].seek_offset < current_offset:
|
||||
last_storage_key = storage_key
|
||||
@@ -1245,8 +1274,10 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
||||
input_ids.clamp_(max=self.config.vocab_size-1)
|
||||
if(hasattr(self, "transformer")):
|
||||
inputs_embeds = self.transformer.wte(input_ids)
|
||||
else:
|
||||
elif(not hasattr(self.model, "decoder")):
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
else:
|
||||
inputs_embeds = self.model.decoder.embed_tokens(input_ids)
|
||||
if(vars.sp is not None):
|
||||
vars.sp = vars.sp.to(inputs_embeds.dtype).to(inputs_embeds.device)
|
||||
inputs_embeds = torch.where(
|
||||
@@ -1254,20 +1285,39 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
||||
vars.sp[shifted_input_ids.clamp(min=0)],
|
||||
inputs_embeds,
|
||||
)
|
||||
if(not hasattr(self, "transformer")):
|
||||
if(hasattr(self, "model") and hasattr(self.model, "embed_scale")):
|
||||
inputs_embeds *= self.model.embed_scale
|
||||
kwargs['inputs_embeds'] = inputs_embeds
|
||||
return old_forward(self, *args, **kwargs)
|
||||
cls.forward = new_causallm_forward
|
||||
for cls in (GPT2LMHeadModel, GPTNeoForCausalLM):
|
||||
patch_causallm(cls)
|
||||
for c in ("GPTJForCausalLM", "XGLMForCausalLM"):
|
||||
for c in ("GPTJForCausalLM", "XGLMForCausalLM", "OPTForCausalLM"):
|
||||
try:
|
||||
patch_causallm(getattr(__import__("transformers"), c))
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
# Fix a bug in OPTForCausalLM where self.lm_head is the wrong size
|
||||
if(transformers_version == "4.19.0"):
|
||||
try:
|
||||
from transformers import OPTForCausalLM, OPTModel
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
# This is the same as the original __init__ but with
|
||||
# config.hidden_size
|
||||
# replaced with
|
||||
# config.word_embed_proj_dim
|
||||
def new_init(self, config):
|
||||
super(OPTForCausalLM, self).__init__(config)
|
||||
self.model = OPTModel(config)
|
||||
self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
|
||||
self.post_init()
|
||||
OPTForCausalLM.__init__ = new_init
|
||||
|
||||
|
||||
# Patch transformers to use our custom logit warpers
|
||||
from transformers import LogitsProcessorList, LogitsWarper, LogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper, RepetitionPenaltyLogitsProcessor
|
||||
from warpers import AdvancedRepetitionPenaltyLogitsProcessor, TailFreeLogitsWarper, TypicalLogitsWarper
|
||||
@@ -1422,12 +1472,18 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
||||
|
||||
def get_hidden_size_from_model(model):
|
||||
try:
|
||||
return int(model.transformer.hidden_size)
|
||||
return int(model.model.decoder.project_in.in_features)
|
||||
except:
|
||||
try:
|
||||
return int(model.transformer.embed_dim)
|
||||
return int(model.model.decoder.embed_tokens.out_features)
|
||||
except:
|
||||
return int(model.lm_head.in_features)
|
||||
try:
|
||||
return int(model.transformer.hidden_size)
|
||||
except:
|
||||
try:
|
||||
return int(model.transformer.embed_dim)
|
||||
except:
|
||||
return int(model.lm_head.in_features)
|
||||
|
||||
def maybe_low_cpu_mem_usage() -> Dict[str, Any]:
|
||||
if(packaging.version.parse(transformers_version) < packaging.version.parse("4.11.0")):
|
||||
@@ -1482,7 +1538,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
||||
import shutil
|
||||
shutil.move(vars.model.replace('/', '_'), "models/{}".format(vars.model.replace('/', '_')))
|
||||
print("\n", flush=True)
|
||||
with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(model_config.num_layers if hasattr(model_config, "num_layers") else model_config.n_layer) if vars.lazy_load else None, dematerialized_modules=True):
|
||||
with maybe_use_float16(), torch_lazy_loader.use_lazy_torch_load(enable=vars.lazy_load, callback=get_lazy_load_callback(utils.num_layers(model_config)) if vars.lazy_load else None, dematerialized_modules=True):
|
||||
if(vars.lazy_load): # torch_lazy_loader.py and low_cpu_mem_usage can't be used at the same time
|
||||
lowmem = {}
|
||||
if(os.path.isdir(vars.custmodpth)):
|
||||
@@ -1562,13 +1618,21 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", revision=vars.revision, cache_dir="cache")
|
||||
else:
|
||||
from transformers import PreTrainedModel
|
||||
from transformers import modeling_utils
|
||||
old_from_pretrained = PreTrainedModel.from_pretrained.__func__
|
||||
@classmethod
|
||||
def new_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
utils.num_shards = None
|
||||
utils.current_shard = 0
|
||||
if not args.no_aria2:
|
||||
utils.aria2_hook(pretrained_model_name_or_path, **kwargs)
|
||||
return old_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
PreTrainedModel.from_pretrained = new_from_pretrained
|
||||
old_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
|
||||
def new_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs):
|
||||
utils.num_shards = utils.get_num_shards(index_filename)
|
||||
return old_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
|
||||
modeling_utils.get_checkpoint_shard_files = new_get_checkpoint_shard_files
|
||||
|
||||
def tpumtjgetsofttokens():
|
||||
soft_tokens = None
|
||||
@@ -1576,14 +1640,14 @@ else:
|
||||
global np
|
||||
if 'np' not in globals():
|
||||
import numpy as np
|
||||
tensor = np.zeros((1, tpu_mtj_backend.params["d_model"]), dtype=np.float32)
|
||||
tensor = np.zeros((1, tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"])), dtype=np.float32)
|
||||
rows = tensor.shape[0]
|
||||
padding_amount = tpu_mtj_backend.params["seq"] - (tpu_mtj_backend.params["seq"] % -tpu_mtj_backend.params["cores_per_replica"]) - rows
|
||||
tensor = np.pad(tensor, ((0, padding_amount), (0, 0)))
|
||||
tensor = tensor.reshape(
|
||||
tpu_mtj_backend.params["cores_per_replica"],
|
||||
-1,
|
||||
tpu_mtj_backend.params["d_model"],
|
||||
tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]),
|
||||
)
|
||||
vars.sp = tpu_mtj_backend.shard_xmap(tensor)
|
||||
soft_tokens = np.arange(
|
||||
@@ -1672,7 +1736,7 @@ else:
|
||||
if vars.model in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and (not vars.custmodpth or not os.path.isdir(vars.custmodpth)):
|
||||
raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
|
||||
import tpu_mtj_backend
|
||||
if(vars.model == "TPUMeshTransformerGPTNeoX"):
|
||||
if(vars.model == "TPUMeshTransformerGPTNeoX" or vars.model_type == "opt"):
|
||||
tpu_mtj_backend.pad_token_id = 1
|
||||
tpu_mtj_backend.vars = vars
|
||||
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
|
||||
@@ -1684,7 +1748,7 @@ else:
|
||||
loadmodelsettings()
|
||||
loadsettings()
|
||||
tpu_mtj_backend.load_model(vars.custmodpth, hf_checkpoint=vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and vars.use_colab_tpu, **vars.modelconfig)
|
||||
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
||||
vars.modeldim = int(tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]))
|
||||
tokenizer = tpu_mtj_backend.tokenizer
|
||||
else:
|
||||
loadsettings()
|
||||
|
Reference in New Issue
Block a user