Potential Colab URL fix

This commit is contained in:
ebolam
2022-10-28 14:19:50 -04:00
parent 46a870ef87
commit c3ccdb0b50
3 changed files with 12 additions and 8 deletions

View File

@@ -91,7 +91,6 @@ if lupa.LUA_VERSION[:2] != (5, 4):
logger.error(f"Please install lupa==1.10. You have lupa {lupa.__version__}.")
patch_causallm_patched = False
cloudflare = ""
# Make sure tqdm progress bars display properly in Colab
from tqdm.auto import tqdm
@@ -3094,8 +3093,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX")
and koboldai_vars.use_colab_tpu,
socketio_queue=koboldai_settings.queue,
initial_load=initial_load, logger=logger, cloudflare=cloudflare,
**koboldai_vars.modelconfig)
initial_load=initial_load, logger=logger, **koboldai_vars.modelconfig)
#tpool.execute(tpu_mtj_backend.load_model, koboldai_vars.custmodpth, hf_checkpoint=koboldai_vars.model not in ("TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX") and koboldai_vars.use_colab_tpu, **koboldai_vars.modelconfig)
koboldai_vars.modeldim = int(tpu_mtj_backend.params.get("d_embed", tpu_mtj_backend.params["d_model"]))
tokenizer = tpu_mtj_backend.tokenizer

View File

@@ -987,8 +987,14 @@ class user_settings(settings):
process_variable_changes(self.socketio, self.__class__.__name__.replace("_settings", ""), name, value, old_value)
class system_settings(settings):
local_only_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'regex_sl', 'acregex_ai', 'acregex_ui', 'comregex_ai', 'comregex_ui', 'sp', '_horde_pid', 'inference_config', 'image_pipeline', 'summarizer', 'summary_tokenizer']
no_save_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold', 'lua_koboldcore', 'sp', 'sp_length', '_horde_pid', 'horde_share', 'aibusy', 'serverstarted', 'inference_config', 'image_pipeline', 'summarizer', 'summary_tokenizer', 'use_colab_tpu', 'noai', 'disable_set_aibusy']
local_only_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold',
'lua_koboldcore', 'regex_sl', 'acregex_ai', 'acregex_ui', 'comregex_ai',
'comregex_ui', 'sp', '_horde_pid', 'inference_config', 'image_pipeline',
'summarizer', 'summary_tokenizer']
no_save_variables = ['socketio', 'lua_state', 'lua_logname', 'lua_koboldbridge', 'lua_kobold',
'lua_koboldcore', 'sp', 'sp_length', '_horde_pid', 'horde_share', 'aibusy',
'serverstarted', 'inference_config', 'image_pipeline', 'summarizer',
'summary_tokenizer', 'use_colab_tpu', 'noai', 'disable_set_aibusy', 'cloudflare_link']
settings_name = "system"
def __init__(self, socketio, koboldai_var):
self.socketio = socketio

View File

@@ -1080,7 +1080,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
koboldai_vars.status_message = ""
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, cloudflare="", **kwargs) -> None:
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpoint=False, socketio_queue=None, initial_load=False, logger=None, **kwargs) -> None:
global thread_resources_env, seq, tokenizer, network, params, pad_token_id
if "pad_token_id" in kwargs:
@@ -1251,8 +1251,8 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
thread_resources_env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')), ())
maps.thread_resources.env = thread_resources_env
if initial_load:
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {cloudflare}")
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {cloudflare}/new_ui")
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 1: {koboldai_vars.cloudflare_link}")
logger.message(f"KoboldAI has finished loading and is available at the following link for UI 2: {koboldai_vars.cloudflare_link}/new_ui")
global shard_xmap, batch_xmap