mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Model: Monkey release detection fix
This commit is contained in:
24
model.py
24
model.py
@@ -246,12 +246,12 @@ class use_core_manipulations:
|
||||
get_stopping_criteria: callable = None
|
||||
|
||||
# We set these automatically
|
||||
old_get_logits_processor: callable
|
||||
old_sample: callable
|
||||
old_get_stopping_criteria: callable
|
||||
old_get_logits_processor: callable = None
|
||||
old_sample: callable = None
|
||||
old_get_stopping_criteria: callable = None
|
||||
|
||||
def __enter__(self):
|
||||
if self.get_logits_processor:
|
||||
if use_core_manipulations.get_logits_processor:
|
||||
use_core_manipulations.old_get_logits_processor = (
|
||||
transformers.GenerationMixin._get_logits_processor
|
||||
)
|
||||
@@ -259,11 +259,11 @@ class use_core_manipulations:
|
||||
use_core_manipulations.get_logits_processor
|
||||
)
|
||||
|
||||
if self.sample:
|
||||
if use_core_manipulations.sample:
|
||||
use_core_manipulations.old_sample = transformers.GenerationMixin.sample
|
||||
transformers.GenerationMixin.sample = use_core_manipulations.sample
|
||||
|
||||
if self.get_stopping_criteria:
|
||||
if use_core_manipulations.get_stopping_criteria:
|
||||
use_core_manipulations.old_get_stopping_criteria = (
|
||||
transformers.GenerationMixin._get_stopping_criteria
|
||||
)
|
||||
@@ -273,24 +273,24 @@ class use_core_manipulations:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
if self.old_get_logits_processor:
|
||||
if use_core_manipulations.old_get_logits_processor:
|
||||
transformers.GenerationMixin._get_logits_processor = (
|
||||
use_core_manipulations.old_get_logits_processor
|
||||
)
|
||||
else:
|
||||
assert not self.get_logits_processor, "Patch leak: THE MONKEYS HAVE ESCAPED"
|
||||
assert not use_core_manipulations.get_logits_processor, "Patch leak: THE MONKEYS HAVE ESCAPED"
|
||||
|
||||
if self.old_sample:
|
||||
if use_core_manipulations.old_sample:
|
||||
transformers.GenerationMixin.sample = use_core_manipulations.old_sample
|
||||
else:
|
||||
assert not self.sample, "Patch leak: THE MONKEYS HAVE ESCAPED"
|
||||
assert not use_core_manipulations.sample, "Patch leak: THE MONKEYS HAVE ESCAPED"
|
||||
|
||||
if self.old_get_stopping_criteria:
|
||||
if use_core_manipulations.old_get_stopping_criteria:
|
||||
transformers.GenerationMixin._get_stopping_criteria = (
|
||||
use_core_manipulations.old_get_stopping_criteria
|
||||
)
|
||||
else:
|
||||
assert not self.get_stopping_criteria, "Patch leak: THE MONKEYS HAVE ESCAPED"
|
||||
assert not use_core_manipulations.get_stopping_criteria, "Patch leak: THE MONKEYS HAVE ESCAPED"
|
||||
|
||||
|
||||
def patch_transformers_download():
|
||||
|
Reference in New Issue
Block a user