Model: Ditch awful current_model hack

thanks to whjms for spotting that this could be zapped
This commit is contained in:
somebody
2023-03-09 19:08:08 -06:00
parent 885c226651
commit fb0b2f0467
2 changed files with 9 additions and 16 deletions

View File

@@ -22,10 +22,6 @@ except ModuleNotFoundError as e:
if utils.koboldai_vars.use_colab_tpu:
raise e
# I don't really like this way of pointing to the current model but I can't
# find a way around it in some areas.
current_model = None
# We only want to use logit manipulations and such on our core text model
class use_core_manipulations:
"""Use in a `with` block to patch functions for core story model sampling."""
@@ -176,9 +172,6 @@ class InferenceModel:
self._load(save_model=save_model, initial_load=initial_load)
self._post_load()
global current_model
current_model = self
def _post_load(self) -> None:
"""Post load hook. Called after `_load()`."""

View File

@@ -113,7 +113,7 @@ class HFTorchInferenceModel(HFInferenceModel):
pre = torch.Tensor(scores)
return scores
def _post_load(self) -> None:
def _post_load(model_self) -> None:
# Patch stopping_criteria
class PTHStopper(StoppingCriteria):
@@ -122,10 +122,10 @@ class HFTorchInferenceModel(HFInferenceModel):
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
) -> None:
self._post_token_gen(input_ids)
model_self._post_token_gen(input_ids)
for stopper in self.stopper_hooks:
do_stop = stopper(self, input_ids)
for stopper in model_self.stopper_hooks:
do_stop = stopper(model_self, input_ids)
if do_stop:
return True
return False
@@ -238,11 +238,11 @@ class HFTorchInferenceModel(HFInferenceModel):
# Handle direct phrases
if phrase.startswith("{") and phrase.endswith("}"):
no_brackets = phrase[1:-1]
return [inference_model.current_model.tokenizer.encode(no_brackets)]
return [model_self.tokenizer.encode(no_brackets)]
# Handle untamperable phrases
if not self._allow_leftwards_tampering(phrase):
return [inference_model.current_model.tokenizer.encode(phrase)]
return [model_self.tokenizer.encode(phrase)]
# Handle slight alterations to original phrase
phrase = phrase.strip(" ")
@@ -250,7 +250,7 @@ class HFTorchInferenceModel(HFInferenceModel):
for alt_phrase in [phrase, f" {phrase}"]:
ret.append(
inference_model.current_model.tokenizer.encode(alt_phrase)
model_self.tokenizer.encode(alt_phrase)
)
return ret
@@ -440,8 +440,8 @@ class HFTorchInferenceModel(HFInferenceModel):
# sampler_order = [6] + sampler_order
# for k in sampler_order:
# scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
scores = self._apply_warpers(scores=scores, input_ids=input_ids)
visualize_probabilities(inference_model.current_model, scores)
scores = model_self._apply_warpers(scores=scores, input_ids=input_ids)
visualize_probabilities(model_self, scores)
return scores
def new_get_logits_warper(