From fb0b2f0467e43e535fc57c005c54204f631a0080 Mon Sep 17 00:00:00 2001
From: somebody
Date: Thu, 9 Mar 2023 19:08:08 -0600
Subject: [PATCH] Model: Ditch awful current_model hack

thanks to whjms for spotting that this could be zapped
---
 modeling/inference_model.py           |  7 -------
 modeling/inference_models/hf_torch.py | 18 +++++++++---------
 2 files changed, 9 insertions(+), 16 deletions(-)

diff --git a/modeling/inference_model.py b/modeling/inference_model.py
index 5157acb0..9663f929 100644
--- a/modeling/inference_model.py
+++ b/modeling/inference_model.py
@@ -22,10 +22,6 @@ except ModuleNotFoundError as e:
     if utils.koboldai_vars.use_colab_tpu:
         raise e
 
-# I don't really like this way of pointing to the current model but I can't
-# find a way around it in some areas.
-current_model = None
-
 # We only want to use logit manipulations and such on our core text model
 class use_core_manipulations:
     """Use in a `with` block to patch functions for core story model sampling."""
@@ -176,9 +172,6 @@ class InferenceModel:
         self._load(save_model=save_model, initial_load=initial_load)
         self._post_load()
 
-        global current_model
-        current_model = self
-
     def _post_load(self) -> None:
         """Post load hook. Called after `_load()`."""
 
diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index bf339a24..e101c6da 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -113,7 +113,7 @@ class HFTorchInferenceModel(HFInferenceModel):
             pre = torch.Tensor(scores)
         return scores
 
-    def _post_load(self) -> None:
+    def _post_load(model_self) -> None:
         # Patch stopping_criteria
 
         class PTHStopper(StoppingCriteria):
@@ -122,10 +122,10 @@ class HFTorchInferenceModel(HFInferenceModel):
                 input_ids: torch.LongTensor,
                 scores: torch.FloatTensor,
             ) -> None:
-                self._post_token_gen(input_ids)
+                model_self._post_token_gen(input_ids)
 
-                for stopper in self.stopper_hooks:
-                    do_stop = stopper(self, input_ids)
+                for stopper in model_self.stopper_hooks:
+                    do_stop = stopper(model_self, input_ids)
                     if do_stop:
                         return True
                 return False
@@ -238,11 +238,11 @@ class HFTorchInferenceModel(HFInferenceModel):
             # Handle direct phrases
             if phrase.startswith("{") and phrase.endswith("}"):
                 no_brackets = phrase[1:-1]
-                return [inference_model.current_model.tokenizer.encode(no_brackets)]
+                return [model_self.tokenizer.encode(no_brackets)]
 
             # Handle untamperable phrases
             if not self._allow_leftwards_tampering(phrase):
-                return [inference_model.current_model.tokenizer.encode(phrase)]
+                return [model_self.tokenizer.encode(phrase)]
 
             # Handle slight alterations to original phrase
             phrase = phrase.strip(" ")
@@ -250,7 +250,7 @@ class HFTorchInferenceModel(HFInferenceModel):
 
             for alt_phrase in [phrase, f" {phrase}"]:
                 ret.append(
-                    inference_model.current_model.tokenizer.encode(alt_phrase)
+                    model_self.tokenizer.encode(alt_phrase)
                 )
 
             return ret
@@ -440,8 +440,8 @@ class HFTorchInferenceModel(HFInferenceModel):
            # sampler_order = [6] + sampler_order
            # for k in sampler_order:
            #     scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
-            scores = self._apply_warpers(scores=scores, input_ids=input_ids)
-            visualize_probabilities(inference_model.current_model, scores)
+            scores = model_self._apply_warpers(scores=scores, input_ids=input_ids)
+            visualize_probabilities(model_self, scores)
             return scores
 
         def new_get_logits_warper(