commit 59e3a40496

aiserver.py | 121

@@ -1,7 +1,7 @@
 #!/usr/bin/python3
 #==================================================================#
 # KoboldAI
-# Version: 1.19.0
+# Version: 1.19.1
 # By: The KoboldAI Community
 #==================================================================#
 
@@ -377,6 +377,7 @@ class vars:
     comregex_ai = re.compile(r'(?:\n<\|(?:.|\n)*?\|>(?=\n|$))|(?:<\|(?:.|\n)*?\|>\n?)') # Pattern for matching comments to remove them before sending them to the AI
     comregex_ui = re.compile(r'(<\|(?:.|\n)*?\|>)') # Pattern for matching comments in the editor
     sampler_order = utils.default_sampler_order.copy()
+    rng_states = {} # Used by the POST /generate endpoint to store sampler RNG states
     chatmode = False
     chatname = "You"
     adventure = False
@@ -630,7 +631,7 @@ tags = [
 api_version = None # This gets set automatically so don't change this value
 
 api_v1 = KoboldAPISpec(
-    version="1.1.4",
+    version="1.2.0",
     prefixes=["/api/v1", "/api/latest"],
     tags=tags,
 )
@@ -2963,7 +2964,7 @@ def load_lua_scripts():
         if(vars.serverstarted):
             emit('from_server', {'cmd': 'errmsg', 'data': 'Lua script error; please check console.'}, broadcast=True)
             sendUSStatItems()
-        logger.debug('LUA ERROR: ' + str(e).replace("\033", ""))
+        logger.error('LUA ERROR: ' + str(e).replace("\033", ""))
         logger.warning("Lua engine stopped; please open 'Userscripts' and press Load to reinitialize scripts.")
         if(vars.serverstarted):
             set_aibusy(0)
@@ -7450,6 +7451,13 @@ def story_load_validator(name: str):
         raise ValidationError("Must be a valid story name.")
     return True
 
+def permutation_validator(lst: list):
+    if any(not isinstance(e, int) for e in lst):
+        return
+    if min(lst) != 0 or max(lst) != len(lst) - 1 or len(set(lst)) != len(lst):
+        raise ValidationError("Must be a permutation of the first N non-negative integers, where N is the length of this array")
+    return True
+
 class GenerationInputSchema(SamplerSettingsSchema):
     prompt: str = fields.String(required=True, metadata={"description": "This is the submission."})
     use_memory: bool = fields.Boolean(load_default=False, metadata={"description": "Whether or not to use the memory from the KoboldAI GUI when generating text."})
@@ -7469,6 +7477,9 @@ class GenerationInputSchema(SamplerSettingsSchema):
     disable_input_formatting: bool = fields.Boolean(load_default=True, metadata={"description": "When enabled, all input formatting options default to `false` instead of the value in the KoboldAI GUI"})
     frmtadsnsp: Optional[bool] = fields.Boolean(metadata={"description": "Input formatting option. When enabled, adds a leading space to your input if there is no trailing whitespace at the end of the previous action.\n\nIf `disable_input_formatting` is `true`, this defaults to `false` instead of the value in the KoboldAI GUI."})
     quiet: Optional[bool] = fields.Boolean(metadata={"description": "When enabled, Generated output will not be displayed in the console."})
+    sampler_order: Optional[List[int]] = fields.List(fields.Integer(), validate=[validate.Length(min=6), permutation_validator], metadata={"description": "Sampler order to be used. If N is the length of this array, then N must be greater than or equal to 6 and the array must be a permutation of the first N non-negative integers."})
+    sampler_seed: Optional[int] = fields.Integer(validate=validate.Range(min=0, max=2**64 - 1), metadata={"description": "RNG seed to use for sampling. If not specified, the global RNG will be used."})
+    sampler_full_determinism: Optional[bool] = fields.Boolean(metadata={"description": "If enabled, the generated text will always be the same as long as you use the same RNG seed, input and settings. If disabled, only the *sequence* of generated texts that you get when repeatedly generating text will be the same given the same RNG seed, input and settings."})
 
 class GenerationResultSchema(KoboldSchema):
     text: str = fields.String(required=True, metadata={"description": "Generated output as plain text."})
@@ -7559,6 +7570,29 @@ def _generate_text(body: GenerationInputSchema):
             "msg": "Server is busy; please try again later.",
             "type": "service_unavailable",
         }}), mimetype="application/json", status=503))
+    if vars.use_colab_tpu:
+        import tpu_mtj_backend
+    if hasattr(body, "sampler_seed"):
+        # If a seed was specified, we need to save the global RNG state so we
+        # can restore it later
+        old_seed = vars.seed
+        old_rng_state = tpu_mtj_backend.get_rng_state() if vars.use_colab_tpu else torch.get_rng_state()
+        vars.seed = body.sampler_seed
+        # We should try to use a previously saved RNG state with the same seed
+        if body.sampler_seed in vars.rng_states:
+            if vars.use_colab_tpu:
+                tpu_mtj_backend.set_rng_state(vars.rng_states[body.sampler_seed])
+            else:
+                torch.set_rng_state(vars.rng_states[body.sampler_seed])
+        else:
+            if vars.use_colab_tpu:
+                tpu_mtj_backend.set_rng_state(tpu_mtj_backend.new_rng_state(body.sampler_seed))
+            else:
+                torch.manual_seed(body.sampler_seed)
+        vars.rng_states[body.sampler_seed] = tpu_mtj_backend.get_rng_state() if vars.use_colab_tpu else torch.get_rng_state()
+    if hasattr(body, "sampler_order"):
+        if len(body.sampler_order) < 7:
+            body.sampler_order = [6] + body.sampler_order
     # This maps each property of the setting to use when sending the generate idempotently
     # To the object which typically contains it's value
     # This allows to set the property only for the API generation, and then revert the setting
@@ -7584,6 +7618,8 @@ def _generate_text(body: GenerationInputSchema):
         "max_context_length": ("vars", "max_length", None),
         "n": ("vars", "numseqs", None),
         "quiet": ("vars", "quiet", None),
+        "sampler_order": ("vars", "sampler_order", None),
+        "sampler_full_determinism": ("vars", "full_determinism", None),
     }
     saved_settings = {}
     set_aibusy(1)
@@ -7633,6 +7669,12 @@ def _generate_text(body: GenerationInputSchema):
     vars.output_streaming = output_streaming
     if vars.allowsp and getattr(body, "soft_prompt", None) is not None:
         spRequest(old_spfilename)
+    if hasattr(body, "sampler_seed"):
+        vars.seed = old_seed
+        if vars.use_colab_tpu:
+            tpu_mtj_backend.set_rng_state(old_rng_state)
+        else:
+            torch.set_rng_state(old_rng_state)
     set_aibusy(0)
     return output
 
@@ -9838,6 +9880,60 @@ def put_config_soft_prompt(body: SoftPromptSettingSchema):
     settingschanged()
     return {}
 
+class SamplerSeedSettingSchema(KoboldSchema):
+    value: int = fields.Integer(validate=validate.Range(min=0, max=2**64 - 1), required=True)
+
+@api_v1.get("/config/sampler_seed")
+@api_schema_wrap
+def get_config_sampler_seed():
+    """---
+    get:
+      summary: Retrieve the current global sampler seed value
+      tags:
+        - config
+      responses:
+        200:
+          description: Successful request
+          content:
+            application/json:
+              schema: SamplerSeedSettingSchema
+              example:
+                value: 3475097509890965500
+    """
+    return {"value": __import__("tpu_mtj_backend").get_rng_seed() if vars.use_colab_tpu else __import__("torch").initial_seed()}
+
+@api_v1.put("/config/sampler_seed")
+@api_schema_wrap
+def put_config_sampler_seed(body: SamplerSeedSettingSchema):
+    """---
+    put:
+      summary: Set the global sampler seed value
+      tags:
+        - config
+      requestBody:
+        required: true
+        content:
+          application/json:
+            schema: SamplerSeedSettingSchema
+            example:
+              value: 3475097509890965500
+      responses:
+        200:
+          description: Successful request
+          content:
+            application/json:
+              schema: EmptySchema
+        {api_validation_error_response}
+    """
+    if vars.use_colab_tpu:
+        import tpu_mtj_backend
+        tpu_mtj_backend.set_rng_seed(body.value)
+    else:
+        import torch
+        torch.manual_seed(body.value)
+    vars.seed = body.value
+    return {}
+
 config_endpoint_schemas: List[Type[KoboldSchema]] = []
 
 def config_endpoint_schema(c: Type[KoboldSchema]):
@@ -10035,6 +10131,25 @@ class AddSentenceSpacingSettingsSchema(KoboldSchema):
         name = "add sentence spacing (input formatting)"
         example_yaml_value = "false"
 
+@config_endpoint_schema
+class SamplerOrderSettingSchema(KoboldSchema):
+    value = fields.List(fields.Integer(), validate=[validate.Length(min=6), permutation_validator], required=True)
+    class KoboldMeta:
+        route_name = "sampler_order"
+        obj = "vars"
+        var_name = "sampler_order"
+        name = "sampler order"
+        example_yaml_value = "[6, 0, 1, 2, 3, 4, 5]"
+
+@config_endpoint_schema
+class SamplerFullDeterminismSettingSchema(KoboldSchema):
+    value = fields.Boolean(required=True)
+    class KoboldMeta:
+        route_name = "sampler_full_determinism"
+        obj = "vars"
+        var_name = "full_determinism"
+        name = "sampler full determinism"
+        example_yaml_value = "false"
+
 
 for schema in config_endpoint_schemas:
torch_lazy_loader.py

@@ -50,9 +50,12 @@ import itertools
 import zipfile
 import pickle
 import torch
+import numpy as np
+import collections
+import _codecs
 import utils
 from torch.nn import Module
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
 
 
 _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
@@ -111,8 +114,50 @@ class LazyTensor:
         tensor._backward_hooks = self.backward_hooks
         return tensor
 
 
-class _LazyUnpickler(pickle.Unpickler):
+class RestrictedUnpickler(pickle.Unpickler):
+    def original_persistent_load(self, saved_id):
+        return super().persistent_load(saved_id)
+
+    def forced_persistent_load(self, saved_id):
+        if saved_id[0] != "storage":
+            raise pickle.UnpicklingError("`saved_id[0]` must be 'storage'")
+        return self.original_persistent_load(saved_id)
+
+    def find_class(self, module, name):
+        if module == "collections" and name == "OrderedDict":
+            return collections.OrderedDict
+        elif module == "torch._utils" and name == "_rebuild_tensor_v2":
+            return torch._utils._rebuild_tensor_v2
+        elif module == "torch" and name in (
+            "DoubleStorage",
+            "FloatStorage",
+            "HalfStorage",
+            "LongStorage",
+            "IntStorage",
+            "ShortStorage",
+            "CharStorage",
+            "ByteStorage",
+            "BoolStorage",
+            "BFloat16Storage",
+        ):
+            return getattr(torch, name)
+        elif module == "numpy.core.multiarray" and name == "scalar":
+            return np.core.multiarray.scalar
+        elif module == "numpy" and name == "dtype":
+            return np.dtype
+        elif module == "_codecs" and name == "encode":
+            return _codecs.encode
+        else:
+            # Forbid everything else.
+            qualified_name = name if module == "__builtin__" else f"{module}.{name}"
+            raise pickle.UnpicklingError(f"`{qualified_name}` is forbidden; the model you are loading probably contains malicious code")
+
+    def load(self, *args, **kwargs):
+        self.original_persistent_load = getattr(self, "persistent_load", pickle.Unpickler.persistent_load)
+        self.persistent_load = self.forced_persistent_load
+        return super().load(*args, **kwargs)
+
+class _LazyUnpickler(RestrictedUnpickler):
     lazy_loaded_storages: Dict[str, LazyTensor]
 
     def __init__(self, *args, **kwargs):
@@ -127,7 +172,6 @@ class _LazyUnpickler(pickle.Unpickler):
         return LazyTensor(storage_type, key, location)
 
     def load(self, *args, **kwargs):
-        self.persistent_load = self.forced_persistent_load
         retval = super().load(*args, **kwargs)
         self.lazy_loaded_storages = {}
         return retval
@@ -213,16 +257,33 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss
                     unexpected_keys.append(key)
 
 
+@contextlib.contextmanager
+def use_custom_unpickler(unpickler: Type[pickle.Unpickler] = RestrictedUnpickler):
+    try:
+        old_unpickler = pickle.Unpickler
+        pickle.Unpickler = unpickler
+
+        old_pickle_load = pickle.load
+
+        def new_pickle_load(*args, **kwargs):
+            return pickle.Unpickler(*args, **kwargs).load()
+
+        pickle.load = new_pickle_load
+
+        yield
+
+    finally:
+        pickle.Unpickler = old_unpickler
+        pickle.load = old_pickle_load
+
+
 @contextlib.contextmanager
 def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, dematerialized_modules=False, use_accelerate_init_empty_weights=False):
     if not enable:
-        yield False
+        with use_custom_unpickler(RestrictedUnpickler):
+            yield False
         return
 
     try:
-        old_unpickler = pickle.Unpickler
-        pickle.Unpickler = _LazyUnpickler
-
         old_rebuild_tensor = torch._utils._rebuild_tensor
         torch._utils._rebuild_tensor = _rebuild_tensor
@@ -261,10 +322,10 @@ def use_lazy_torch_load(enable=True, callback: Optional[Callable] = None, demate
             old_load_from_state_dict = torch.nn.Module._load_from_state_dict
             torch.nn.Module._load_from_state_dict = _load_from_state_dict
 
-        yield True
+        with use_custom_unpickler(_LazyUnpickler):
+            yield True
 
     finally:
-        pickle.Unpickler = old_unpickler
         torch._utils._rebuild_tensor = old_rebuild_tensor
         torch.load = old_torch_load
         if dematerialized_modules:
tpu_mtj_backend.py

@@ -55,7 +55,7 @@ from mesh_transformer.util import to_bf16
 
 params: Dict[str, Any] = {}
 
-__seed = random.randrange(sys.maxsize)
+__seed = random.randrange(2**64)
 rng = random.Random(__seed)
 
 
@@ -69,8 +69,17 @@ def set_rng_seed(seed: int):
     return seed
 
 def randomize_rng_seed():
-    return set_rng_seed(random.randrange(sys.maxsize))
+    return set_rng_seed(random.randrange(2**64))
+
+def get_rng_state():
+    return rng
+
+def set_rng_state(state):
+    global rng
+    rng = state
+
+def new_rng_state(seed: int):
+    return random.Random(seed)
 
 def warper_callback(logits) -> np.array:
     raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
@@ -946,6 +955,7 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
 
     import torch
    import torch.utils.dlpack
+    import torch_lazy_loader
     from tqdm.auto import tqdm
 
     move_xmap = jax.experimental.maps.xmap(
@@ -987,8 +997,9 @@ def read_neox_checkpoint(state, path, config, checkpoint_shards=2):
             continue
         layer = checkpoint_layer - 2
         shards = []
-        for checkpoint_shard in range(checkpoint_shards):
-            shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
+        with torch_lazy_loader.use_custom_unpickler(torch_lazy_loader.RestrictedUnpickler):
+            for checkpoint_shard in range(checkpoint_shards):
+                shards.append(torch.load(path_template.format(layer=checkpoint_layer, shard=checkpoint_shard), map_location="cpu"))
         for key in shards[0]:
             if key == "attention.rotary_emb.inv_freq":
                 continue