mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Automatically calculate badwords and pad_token_id
This commit is contained in:
@ -1018,7 +1018,12 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
|
||||
|
||||
|
||||
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, **kwargs) -> None:
|
||||
global thread_resources_env, seq, tokenizer, network, params
|
||||
global thread_resources_env, seq, tokenizer, network, params, pad_token_id
|
||||
|
||||
if "pad_token_id" in kwargs:
|
||||
pad_token_id = kwargs["pad_token_id"]
|
||||
elif "eos_token_id" in kwargs:
|
||||
pad_token_id = kwargs["eos_token_id"]
|
||||
|
||||
if not hasattr(vars, "sampler_order") or not vars.sampler_order:
|
||||
vars.sampler_order = utils.default_sampler_order.copy()
|
||||
|
Reference in New Issue
Block a user