Model: Formatting fixes

This commit is contained in:
somebody
2023-02-26 14:14:29 -06:00
parent 10842e964b
commit f771ae38cf

View File

@@ -278,19 +278,25 @@ class use_core_manipulations:
use_core_manipulations.old_get_logits_processor
)
else:
assert not use_core_manipulations.get_logits_processor, "Patch leak: THE MONKEYS HAVE ESCAPED"
assert (
not use_core_manipulations.get_logits_processor
), "Patch leak: THE MONKEYS HAVE ESCAPED"
if use_core_manipulations.old_sample:
transformers.GenerationMixin.sample = use_core_manipulations.old_sample
else:
assert not use_core_manipulations.sample, "Patch leak: THE MONKEYS HAVE ESCAPED"
assert (
not use_core_manipulations.sample
), "Patch leak: THE MONKEYS HAVE ESCAPED"
if use_core_manipulations.old_get_stopping_criteria:
transformers.GenerationMixin._get_stopping_criteria = (
use_core_manipulations.old_get_stopping_criteria
)
else:
assert not use_core_manipulations.get_stopping_criteria, "Patch leak: THE MONKEYS HAVE ESCAPED"
assert (
not use_core_manipulations.get_stopping_criteria
), "Patch leak: THE MONKEYS HAVE ESCAPED"
def patch_transformers_download():
@@ -1798,9 +1804,7 @@ class HFTorchInferenceModel(InferenceModel):
# (the folder doesn't contain any subfolders so os.remove will do just fine)
for filename in os.listdir("accelerate-disk-cache"):
try:
os.remove(
os.path.join("accelerate-disk-cache", filename)
)
os.remove(os.path.join("accelerate-disk-cache", filename))
except OSError:
pass
os.makedirs("accelerate-disk-cache", exist_ok=True)
@@ -2019,7 +2023,10 @@ class HFTorchInferenceModel(InferenceModel):
breakmodel.gpu_blocks = [0] * n_layers
return
elif utils.args.breakmodel_gpulayers is not None or utils.args.breakmodel_disklayers is not None:
elif (
utils.args.breakmodel_gpulayers is not None
or utils.args.breakmodel_disklayers is not None
):
try:
if not utils.args.breakmodel_gpulayers:
breakmodel.gpu_blocks = []
@@ -2632,7 +2639,7 @@ class HordeInferenceModel(InferenceModel):
client_agent = "KoboldAI:2.0.0:koboldai.org"
cluster_headers = {
"apikey": utils.koboldai_vars.horde_api_key,
"Client-Agent": client_agent
"Client-Agent": client_agent,
}
try:
@@ -2671,15 +2678,13 @@ class HordeInferenceModel(InferenceModel):
# We've sent the request and got the ID back, now we need to watch it to see when it finishes
finished = False
cluster_agent_headers = {
"Client-Agent": client_agent
}
cluster_agent_headers = {"Client-Agent": client_agent}
while not finished:
try:
req = requests.get(
f"{utils.koboldai_vars.colaburl[:-8]}/api/v2/generate/text/status/{request_id}",
headers=cluster_agent_headers
headers=cluster_agent_headers,
)
except requests.exceptions.ConnectionError:
errmsg = f"Horde unavailable. Please try again later"
@@ -2725,7 +2730,9 @@ class HordeInferenceModel(InferenceModel):
return GenerationResult(
model=self,
out_batches=np.array([self.tokenizer.encode(cgen["text"]) for cgen in generations]),
out_batches=np.array(
[self.tokenizer.encode(cgen["text"]) for cgen in generations]
),
prompt=prompt_tokens,
is_whole_generation=True,
single_line=single_line,