Model: Monkey release detection fix

This commit is contained in:
somebody
2023-02-26 13:35:29 -06:00
parent b99c16f562
commit 35bbd78326

View File

@@ -246,12 +246,12 @@ class use_core_manipulations:
get_stopping_criteria: callable = None
# We set these automatically
old_get_logits_processor: callable
old_sample: callable
old_get_stopping_criteria: callable
old_get_logits_processor: callable = None
old_sample: callable = None
old_get_stopping_criteria: callable = None
def __enter__(self):
if self.get_logits_processor:
if use_core_manipulations.get_logits_processor:
use_core_manipulations.old_get_logits_processor = (
transformers.GenerationMixin._get_logits_processor
)
@@ -259,11 +259,11 @@ class use_core_manipulations:
use_core_manipulations.get_logits_processor
)
if self.sample:
if use_core_manipulations.sample:
use_core_manipulations.old_sample = transformers.GenerationMixin.sample
transformers.GenerationMixin.sample = use_core_manipulations.sample
if self.get_stopping_criteria:
if use_core_manipulations.get_stopping_criteria:
use_core_manipulations.old_get_stopping_criteria = (
transformers.GenerationMixin._get_stopping_criteria
)
@@ -273,24 +273,24 @@ class use_core_manipulations:
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
if self.old_get_logits_processor:
if use_core_manipulations.old_get_logits_processor:
transformers.GenerationMixin._get_logits_processor = (
use_core_manipulations.old_get_logits_processor
)
else:
assert not self.get_logits_processor, "Patch leak: THE MONKEYS HAVE ESCAPED"
assert not use_core_manipulations.get_logits_processor, "Patch leak: THE MONKEYS HAVE ESCAPED"
if self.old_sample:
if use_core_manipulations.old_sample:
transformers.GenerationMixin.sample = use_core_manipulations.old_sample
else:
assert not self.sample, "Patch leak: THE MONKEYS HAVE ESCAPED"
assert not use_core_manipulations.sample, "Patch leak: THE MONKEYS HAVE ESCAPED"
if self.old_get_stopping_criteria:
if use_core_manipulations.old_get_stopping_criteria:
transformers.GenerationMixin._get_stopping_criteria = (
use_core_manipulations.old_get_stopping_criteria
)
else:
assert not self.get_stopping_criteria, "Patch leak: THE MONKEYS HAVE ESCAPED"
assert not use_core_manipulations.get_stopping_criteria, "Patch leak: THE MONKEYS HAVE ESCAPED"
def patch_transformers_download():