Merge pull request #82 from VE-FORBRYDERNE/tpu-config
Allow TPU models to specify settings/config in config.json
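With this change, settings placed in a TPU model's config.json are read into vars.modelconfig by loadmodelsettings() and forwarded to tpu_mtj_backend.load_model() as keyword arguments. A hypothetical config.json for a fairseq-style model, written here with json.dump; the key names ("compat", "tokenizer_class", "tokenizer") come from this diff, while the values are illustrative only:

import json

# Hypothetical config.json contents: the key names are taken from this
# diff, the values are examples and not a recommendation.
example_config = {
    "compat": "fairseq_lm",                      # fairseq-style model; also selects the fairseq tokenizer default in load_model
    "tokenizer_class": "GPT2TokenizerFast",      # must end in "Tokenizer" or "TokenizerFast"
    "tokenizer": "KoboldAI/fairseq-dense-125M",  # repo ID or path passed to from_pretrained()
}

with open("config.json", "w") as f:
    json.dump(example_config, f, indent=4)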
commit 8e9d9faa97

Files changed:
  aiserver.py (21)
  tpu_mtj_backend.py
aiserver.py:

@@ -134,6 +134,7 @@ class vars:
     wifolders_d = {}  # Dictionary of World Info folder UID-info pairs
     wifolders_l = []  # List of World Info folder UIDs
     wifolders_u = {}  # Dictionary of pairs of folder UID - list of WI UID
+    modelconfig = {}  # Raw contents of the model's config.json, or empty dictionary if none found
     lua_state = None  # Lua state of the Lua scripting system
     lua_koboldbridge = None  # `koboldbridge` from bridge.lua
     lua_kobold = None  # `kobold` from bridge.lua
@@ -417,10 +418,16 @@ def loadmodelsettings():
         js = json.loads(model_js_config)
     except Exception as e:
         try:
-            model_js_config = open(vars.custmodpth + "/config.json", "r")
+            try:
+                model_js_config = open(vars.custmodpth + "/config.json", "r")
+            except Exception as e:
+                model_js_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r")
+            js = json.load(model_js_config)
         except Exception as e:
-            model_js_config = open(vars.custmodpth.replace('/', '_') + "/config.json", "r")
-        js = json.load(model_js_config)
+            js = {}
+    if vars.model_type == "xglm" or js.get("modelcompat", "j") == "fairseq_lm":
+        vars.newlinemode = "s"  # Default to </s> newline mode if using XGLM
+    vars.modelconfig = js
     if("badwordsids" in js):
         vars.badwordsids = js["badwordsids"]
     if("nobreakmodel" in js):
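The nested try/except added here makes a missing or unreadable config.json non-fatal: the loader first tries the model folder, then the cache-style folder name with slashes replaced by underscores, and finally falls back to an empty dict so the **vars.modelconfig splat later in the load path is always valid. A standalone sketch of that fallback chain (a hypothetical helper, not code from this diff):

import json

def read_model_config(custmodpth):
    # Try the model folder first, then the cache-style folder name,
    # and fall back to an empty dict if neither file can be read.
    try:
        try:
            f = open(custmodpth + "/config.json", "r")
        except Exception:
            f = open(custmodpth.replace('/', '_') + "/config.json", "r")
        with f:
            return json.load(f)
    except Exception:
        return {}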
@@ -1192,7 +1199,8 @@ else:
 # Load the TPU backend if requested
 elif(vars.model == "TPUMeshTransformerGPTJ"):
     print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
-    assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
+    if not vars.custmodpth or not os.path.isdir(vars.custmodpth):
+        raise FileNotFoundError(f"The specified model path {repr(vars.custmodpth)} is not the path to a valid folder")
     import tpu_mtj_backend
     tpu_mtj_backend.vars = vars
     tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
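Swapping the assert for an explicit check produces a descriptive FileNotFoundError and still runs under python -O, which strips asserts. The same check as a minimal hypothetical helper (not part of this diff):

import os

def validate_model_path(path):
    # Hypothetical helper mirroring the check above: fail loudly with a
    # descriptive error instead of a bare AssertionError.
    if not path or not os.path.isdir(path):
        raise FileNotFoundError(f"The specified model path {repr(path)} is not the path to a valid folder")
    return path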
@@ -1200,7 +1208,8 @@ else:
     tpu_mtj_backend.compiling_callback = tpumtjgenerate_compiling_callback
     tpu_mtj_backend.stopped_compiling_callback = tpumtjgenerate_stopped_compiling_callback
     tpu_mtj_backend.settings_callback = tpumtjgenerate_settings_callback
-    tpu_mtj_backend.load_model(vars.custmodpth)
+    loadmodelsettings()
+    tpu_mtj_backend.load_model(vars.custmodpth, **vars.modelconfig)
     vars.allowsp = True
     vars.modeldim = int(tpu_mtj_backend.params["d_model"])
     tokenizer = tpu_mtj_backend.tokenizer
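Calling loadmodelsettings() before load_model() means vars.modelconfig is populated in time to be splatted into the call, so every top-level key in config.json arrives as a keyword argument that can override a backend default. A sketch of the call pattern (the stub, path, and config values below are illustrative, not from the codebase):

def load_model(path, driver_version="tpu_driver0.1_dev20210607", **kwargs):
    # Stub standing in for tpu_mtj_backend.load_model; mirrors its
    # `params = kwargs` line so config.json keys become overrides.
    params = kwargs
    print("loading", path, "with overrides", params)

modelconfig = {"seq": 2048, "compat": "fairseq_lm"}  # stands in for vars.modelconfig
load_model("models/my-tpu-model", **modelconfig)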
@@ -4921,7 +4930,7 @@ if(vars.model in ("TPUMeshTransformerGPTJ",)):
 def send_debug():
     if vars.debug:
         debug_info = ""
-        for variable in [["Action Length", len(vars.actions)], ["Actions Metadata Length", len(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata], ["Newline Mode", vars.newlinemode]]:
+        for variable in [["Newline Mode", vars.newlinemode], ["Action Length", len(vars.actions)], ["Actions Metadata Length", len(vars.actions_metadata)], ["Actions Metadata", vars.actions_metadata]]:
             debug_info = "{}{}: {}\n".format(debug_info, variable[0], variable[1])
         emit('from_server', {'cmd': 'debug_info', 'data': debug_info}, broadcast=True)
 
tpu_mtj_backend.py:
@@ -791,12 +791,24 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
         "pe_rotary_dims": 64,
         "seq": 2048,
         "cores_per_replica": 8,
+        "tokenizer_class": "GPT2TokenizerFast",
+        "tokenizer": "gpt2",
     }
     params = kwargs
+    if "compat" in params:
+        default_params["compat"] = params["compat"]
+    if default_params["compat"] == "fairseq_lm":
+        default_params["tokenizer"] = "KoboldAI/fairseq-dense-125M"
     for param in default_params:
         if param not in params:
             params[param] = default_params[param]
 
+    # Load tokenizer
+    if not isinstance(params["tokenizer_class"], str) or not any(params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")):
+        raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")
+    tokenizer_class = getattr(__import__("transformers"), params["tokenizer_class"])
+    tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])
+
     # Disable JAX warnings about these two functions having been renamed
     jax.host_count = jax.process_count
     jax.host_id = jax.process_index
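With these lines the tokenizer is resolved dynamically: the class named by tokenizer_class is looked up on the transformers module and instantiated from the tokenizer value, replacing the hardcoded GPT2TokenizerFast removed in the next hunk. The added logic, extracted into a runnable sketch using the defaults from this diff:

import transformers

params = {"tokenizer_class": "GPT2TokenizerFast", "tokenizer": "gpt2"}  # defaults from this diff

# Validate the class name, resolve it on the transformers module,
# then instantiate it, as the added lines above do.
if not isinstance(params["tokenizer_class"], str) or not any(
    params["tokenizer_class"].endswith(s) for s in ("Tokenizer", "TokenizerFast")
):
    raise ValueError("`tokenizer_class` must be a string ending in 'Tokenizer' or 'TokenizerFast'")

tokenizer_class = getattr(transformers, params["tokenizer_class"])
tokenizer = tokenizer_class.from_pretrained(params["tokenizer"])
print(type(tokenizer).__name__)  # GPT2TokenizerFast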
@@ -819,7 +831,6 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
     devices = np.array(jax.devices()[:cores_per_replica]).reshape(mesh_shape)
     thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
     maps.thread_resources.env = thread_resources_env
-    tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
 
     global shard_xmap, batch_xmap
     shard_xmap = __shard_xmap()