diff --git a/model.py b/model.py index 746c0571..8d53f934 100644 --- a/model.py +++ b/model.py @@ -246,12 +246,12 @@ class use_core_manipulations: get_stopping_criteria: callable = None # We set these automatically - old_get_logits_processor: callable - old_sample: callable - old_get_stopping_criteria: callable + old_get_logits_processor: callable = None + old_sample: callable = None + old_get_stopping_criteria: callable = None def __enter__(self): - if self.get_logits_processor: + if use_core_manipulations.get_logits_processor: use_core_manipulations.old_get_logits_processor = ( transformers.GenerationMixin._get_logits_processor ) @@ -259,11 +259,11 @@ class use_core_manipulations: use_core_manipulations.get_logits_processor ) - if self.sample: + if use_core_manipulations.sample: use_core_manipulations.old_sample = transformers.GenerationMixin.sample transformers.GenerationMixin.sample = use_core_manipulations.sample - if self.get_stopping_criteria: + if use_core_manipulations.get_stopping_criteria: use_core_manipulations.old_get_stopping_criteria = ( transformers.GenerationMixin._get_stopping_criteria ) @@ -273,24 +273,24 @@ class use_core_manipulations: return self def __exit__(self, exc_type, exc_value, exc_traceback): - if self.old_get_logits_processor: + if use_core_manipulations.old_get_logits_processor: transformers.GenerationMixin._get_logits_processor = ( use_core_manipulations.old_get_logits_processor ) else: - assert not self.get_logits_processor, "Patch leak: THE MONKEYS HAVE ESCAPED" + assert not use_core_manipulations.get_logits_processor, "Patch leak: THE MONKEYS HAVE ESCAPED" - if self.old_sample: + if use_core_manipulations.old_sample: transformers.GenerationMixin.sample = use_core_manipulations.old_sample else: - assert not self.sample, "Patch leak: THE MONKEYS HAVE ESCAPED" + assert not use_core_manipulations.sample, "Patch leak: THE MONKEYS HAVE ESCAPED" - if self.old_get_stopping_criteria: + if use_core_manipulations.old_get_stopping_criteria: transformers.GenerationMixin._get_stopping_criteria = ( use_core_manipulations.old_get_stopping_criteria ) else: - assert not self.get_stopping_criteria, "Patch leak: THE MONKEYS HAVE ESCAPED" + assert not use_core_manipulations.get_stopping_criteria, "Patch leak: THE MONKEYS HAVE ESCAPED" def patch_transformers_download():