Model: Monkey release detection fix

This commit is contained in:
somebody
2023-02-26 13:35:29 -06:00
parent b99c16f562
commit 35bbd78326

View File

@@ -246,12 +246,12 @@ class use_core_manipulations:
get_stopping_criteria: callable = None get_stopping_criteria: callable = None
# We set these automatically # We set these automatically
old_get_logits_processor: callable old_get_logits_processor: callable = None
old_sample: callable old_sample: callable = None
old_get_stopping_criteria: callable old_get_stopping_criteria: callable = None
def __enter__(self): def __enter__(self):
if self.get_logits_processor: if use_core_manipulations.get_logits_processor:
use_core_manipulations.old_get_logits_processor = ( use_core_manipulations.old_get_logits_processor = (
transformers.GenerationMixin._get_logits_processor transformers.GenerationMixin._get_logits_processor
) )
@@ -259,11 +259,11 @@ class use_core_manipulations:
use_core_manipulations.get_logits_processor use_core_manipulations.get_logits_processor
) )
if self.sample: if use_core_manipulations.sample:
use_core_manipulations.old_sample = transformers.GenerationMixin.sample use_core_manipulations.old_sample = transformers.GenerationMixin.sample
transformers.GenerationMixin.sample = use_core_manipulations.sample transformers.GenerationMixin.sample = use_core_manipulations.sample
if self.get_stopping_criteria: if use_core_manipulations.get_stopping_criteria:
use_core_manipulations.old_get_stopping_criteria = ( use_core_manipulations.old_get_stopping_criteria = (
transformers.GenerationMixin._get_stopping_criteria transformers.GenerationMixin._get_stopping_criteria
) )
@@ -273,24 +273,24 @@ class use_core_manipulations:
return self return self
def __exit__(self, exc_type, exc_value, exc_traceback): def __exit__(self, exc_type, exc_value, exc_traceback):
if self.old_get_logits_processor: if use_core_manipulations.old_get_logits_processor:
transformers.GenerationMixin._get_logits_processor = ( transformers.GenerationMixin._get_logits_processor = (
use_core_manipulations.old_get_logits_processor use_core_manipulations.old_get_logits_processor
) )
else: else:
assert not self.get_logits_processor, "Patch leak: THE MONKEYS HAVE ESCAPED" assert not use_core_manipulations.get_logits_processor, "Patch leak: THE MONKEYS HAVE ESCAPED"
if self.old_sample: if use_core_manipulations.old_sample:
transformers.GenerationMixin.sample = use_core_manipulations.old_sample transformers.GenerationMixin.sample = use_core_manipulations.old_sample
else: else:
assert not self.sample, "Patch leak: THE MONKEYS HAVE ESCAPED" assert not use_core_manipulations.sample, "Patch leak: THE MONKEYS HAVE ESCAPED"
if self.old_get_stopping_criteria: if use_core_manipulations.old_get_stopping_criteria:
transformers.GenerationMixin._get_stopping_criteria = ( transformers.GenerationMixin._get_stopping_criteria = (
use_core_manipulations.old_get_stopping_criteria use_core_manipulations.old_get_stopping_criteria
) )
else: else:
assert not self.get_stopping_criteria, "Patch leak: THE MONKEYS HAVE ESCAPED" assert not use_core_manipulations.get_stopping_criteria, "Patch leak: THE MONKEYS HAVE ESCAPED"
def patch_transformers_download(): def patch_transformers_download():