Model: Formatting fixes

Author: somebody
Date: 2023-02-26 14:14:29 -06:00
Parent: 10842e964b
Commit: f771ae38cf

@@ -278,19 +278,25 @@ class use_core_manipulations:
                 use_core_manipulations.old_get_logits_processor
             )
         else:
-            assert not use_core_manipulations.get_logits_processor, "Patch leak: THE MONKEYS HAVE ESCAPED"
+            assert (
+                not use_core_manipulations.get_logits_processor
+            ), "Patch leak: THE MONKEYS HAVE ESCAPED"
 
         if use_core_manipulations.old_sample:
             transformers.GenerationMixin.sample = use_core_manipulations.old_sample
         else:
-            assert not use_core_manipulations.sample, "Patch leak: THE MONKEYS HAVE ESCAPED"
+            assert (
+                not use_core_manipulations.sample
+            ), "Patch leak: THE MONKEYS HAVE ESCAPED"
 
         if use_core_manipulations.old_get_stopping_criteria:
             transformers.GenerationMixin._get_stopping_criteria = (
                 use_core_manipulations.old_get_stopping_criteria
             )
         else:
-            assert not use_core_manipulations.get_stopping_criteria, "Patch leak: THE MONKEYS HAVE ESCAPED"
+            assert (
+                not use_core_manipulations.get_stopping_criteria
+            ), "Patch leak: THE MONKEYS HAVE ESCAPED"
 
 
 def patch_transformers_download():
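For context, the assertions in this hunk guard a monkey-patch context manager: the original transformers methods are saved on entry, swapped back on exit, and the assert fires if a replacement is still registered when nothing was saved. Below is a minimal standalone sketch of that pattern, with invented names (Target, use_patch, greet); it is illustrative only and not the project's actual class.

# Minimal, self-contained sketch of the save/patch/restore pattern that the
# asserts above protect. All names here (Target, use_patch, greet) are
# hypothetical; this is not the project's actual class.


class Target:
    def greet(self):
        return "hello"


class use_patch:
    old_greet = None  # original method, saved while the block is active
    greet = None      # replacement to install inside the block

    def __enter__(self):
        if use_patch.greet:
            use_patch.old_greet = Target.greet
            Target.greet = use_patch.greet
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        if use_patch.old_greet:
            # Put the real method back and forget the saved copy.
            Target.greet = use_patch.old_greet
            use_patch.old_greet = None
        else:
            # Nothing was saved, so nothing should still be patched in.
            assert not use_patch.greet, "Patch leak: THE MONKEYS HAVE ESCAPED"


# Usage: the replacement is visible only inside the with-block.
use_patch.greet = lambda self: "patched"
with use_patch():
    assert Target().greet() == "patched"
use_patch.greet = None
assert Target().greet() == "hello"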
@@ -1798,9 +1804,7 @@ class HFTorchInferenceModel(InferenceModel):
                     # (the folder doesn't contain any subfolders so os.remove will do just fine)
                     for filename in os.listdir("accelerate-disk-cache"):
                         try:
-                            os.remove(
-                                os.path.join("accelerate-disk-cache", filename)
-                            )
+                            os.remove(os.path.join("accelerate-disk-cache", filename))
                         except OSError:
                             pass
                 os.makedirs("accelerate-disk-cache", exist_ok=True)
@@ -2019,7 +2023,10 @@ class HFTorchInferenceModel(InferenceModel):
             breakmodel.gpu_blocks = [0] * n_layers
             return
-        elif utils.args.breakmodel_gpulayers is not None or utils.args.breakmodel_disklayers is not None:
+        elif (
+            utils.args.breakmodel_gpulayers is not None
+            or utils.args.breakmodel_disklayers is not None
+        ):
             try:
                 if not utils.args.breakmodel_gpulayers:
                     breakmodel.gpu_blocks = []
@@ -2632,7 +2639,7 @@ class HordeInferenceModel(InferenceModel):
         client_agent = "KoboldAI:2.0.0:koboldai.org"
         cluster_headers = {
             "apikey": utils.koboldai_vars.horde_api_key,
-            "Client-Agent": client_agent
+            "Client-Agent": client_agent,
         }
         try:
@@ -2671,15 +2678,13 @@ class HordeInferenceModel(InferenceModel):
         # We've sent the request and got the ID back, now we need to watch it to see when it finishes
         finished = False
 
-        cluster_agent_headers = {
-            "Client-Agent": client_agent
-        }
+        cluster_agent_headers = {"Client-Agent": client_agent}
 
         while not finished:
             try:
                 req = requests.get(
                     f"{utils.koboldai_vars.colaburl[:-8]}/api/v2/generate/text/status/{request_id}",
-                    headers=cluster_agent_headers
+                    headers=cluster_agent_headers,
                 )
             except requests.exceptions.ConnectionError:
                 errmsg = f"Horde unavailable. Please try again later"
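This hunk belongs to the loop that polls the Horde text-status endpoint until the queued generation completes. A rough standalone sketch of that polling shape follows; the endpoint path and the Client-Agent header are taken from the diff, while the response fields ("done", "generations") and the one-second retry delay are assumptions made for illustration.

# Rough sketch of the status-polling loop shown above. The endpoint path and
# the Client-Agent header are taken from the diff; the response fields
# ("done", "generations") and the 1-second retry delay are assumptions.
import time

import requests

client_agent = "KoboldAI:2.0.0:koboldai.org"
cluster_agent_headers = {"Client-Agent": client_agent}


def wait_for_horde_generation(base_url: str, request_id: str) -> list:
    finished = False
    status = {}
    while not finished:
        try:
            req = requests.get(
                f"{base_url}/api/v2/generate/text/status/{request_id}",
                headers=cluster_agent_headers,
            )
        except requests.exceptions.ConnectionError:
            # Transient network failure: back off briefly and try again.
            time.sleep(1)
            continue
        status = req.json()
        finished = status.get("done", False)
        if not finished:
            time.sleep(1)
    return status.get("generations", [])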
@@ -2725,7 +2730,9 @@ class HordeInferenceModel(InferenceModel):
         return GenerationResult(
             model=self,
-            out_batches=np.array([self.tokenizer.encode(cgen["text"]) for cgen in generations]),
+            out_batches=np.array(
+                [self.tokenizer.encode(cgen["text"]) for cgen in generations]
+            ),
             prompt=prompt_tokens,
             is_whole_generation=True,
             single_line=single_line,
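The last hunk only re-wraps the expression that encodes each returned generation and stacks the results into the out_batches array. As a standalone illustration, assuming a toy stand-in tokenizer (toy_encode) and fake generation data rather than the model's actual tokenizer:

# Standalone illustration of how out_batches is built above. toy_encode stands
# in for self.tokenizer.encode, and the generations list is fake data; note
# that np.array only yields a clean 2-D batch when every row has equal length.
import numpy as np


def toy_encode(text: str) -> list:
    # Pretend tokenizer: one "token id" per character.
    return [ord(ch) for ch in text]


generations = [{"text": "abcd"}, {"text": "wxyz"}]

# Mirrors: np.array([self.tokenizer.encode(cgen["text"]) for cgen in generations])
out_batches = np.array([toy_encode(cgen["text"]) for cgen in generations])
print(out_batches.shape)  # (2, 4)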