Model: Ditch awful current_model hack

thanks to whjms for spotting that this could be zapped
somebody
2023-03-09 19:08:08 -06:00
parent 885c226651
commit fb0b2f0467
2 changed files with 9 additions and 16 deletions
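
This commit deletes the module-level current_model global from the core inference module; instead, the HF Torch backend's _post_load takes the instance as model_self, and the functions and classes patched inside it close over that name directly. Below is a minimal sketch of the pattern being removed, using only the method and attribute names visible in the diff; the load() wrapper signature and the stub bodies are assumptions for illustration, not the actual KoboldAI code.

    # Before: a module-level global tracked "the" loaded model, so code anywhere
    # could reach it via inference_model.current_model.
    current_model = None


    class InferenceModel:
        def _load(self, save_model: bool, initial_load: bool) -> None: ...

        def _post_load(self) -> None: ...

        def load(self, save_model: bool = False, initial_load: bool = False) -> None:
            self._load(save_model=save_model, initial_load=initial_load)
            self._post_load()
            global current_model
            current_model = self  # the hack this commit deletes

    # After: the global is gone. Anything that needs the model keeps a reference
    # to the instance it was handed; in the HF Torch backend, _post_load binds it
    # as model_self so nested scopes can close over it (see the diff below).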

@@ -22,10 +22,6 @@ except ModuleNotFoundError as e:
     if utils.koboldai_vars.use_colab_tpu:
         raise e
 
-# I don't really like this way of pointing to the current model but I can't
-# find a way around it in some areas.
-current_model = None
-
 # We only want to use logit manipulations and such on our core text model
 class use_core_manipulations:
     """Use in a `with` block to patch functions for core story model sampling."""
@@ -176,9 +172,6 @@ class InferenceModel:
         self._load(save_model=save_model, initial_load=initial_load)
         self._post_load()
 
-        global current_model
-        current_model = self
-
     def _post_load(self) -> None:
         """Post load hook. Called after `_load()`."""

@@ -113,7 +113,7 @@ class HFTorchInferenceModel(HFInferenceModel):
             pre = torch.Tensor(scores)
             return scores
 
-    def _post_load(self) -> None:
+    def _post_load(model_self) -> None:
         # Patch stopping_criteria
         class PTHStopper(StoppingCriteria):
@@ -122,10 +122,10 @@ class HFTorchInferenceModel(HFInferenceModel):
                 input_ids: torch.LongTensor,
                 scores: torch.FloatTensor,
             ) -> None:
-                self._post_token_gen(input_ids)
+                model_self._post_token_gen(input_ids)
 
-                for stopper in self.stopper_hooks:
-                    do_stop = stopper(self, input_ids)
+                for stopper in model_self.stopper_hooks:
+                    do_stop = stopper(model_self, input_ids)
                     if do_stop:
                         return True
                 return False
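
The rename from self to model_self matters because _post_load defines nested classes and functions (such as PTHStopper above) whose own first parameter is not the KoboldAI model; with the outer parameter renamed, they can refer to the model unambiguously through the closure instead of through the deleted current_model global. Here is a minimal, self-contained sketch of that pattern: stopper_hooks, _post_token_gen, and the class names come from the diff, but the stub bodies and the _stopper attribute are illustrative assumptions.

    import torch
    from transformers import StoppingCriteria


    class HFTorchInferenceModel:
        """Illustrative stand-in; the real class lives in the KoboldAI modeling package."""

        def __init__(self):
            self.stopper_hooks = []  # hooks are registered elsewhere (stubbed here)

        def _post_token_gen(self, input_ids: torch.LongTensor) -> None:
            pass  # illustrative no-op

        def _post_load(model_self) -> None:
            # model_self is simply this method's "self", renamed so the nested
            # class can close over the KoboldAI model without clashing with its
            # own first parameter.
            class PTHStopper(StoppingCriteria):
                def __call__(hf_self, input_ids, scores) -> bool:
                    # hf_self is the stopper instance; model_self is the model.
                    model_self._post_token_gen(input_ids)
                    return any(
                        stopper(model_self, input_ids)
                        for stopper in model_self.stopper_hooks
                    )

            model_self._stopper = PTHStopper()  # hypothetical: keep a reference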
@@ -238,11 +238,11 @@ class HFTorchInferenceModel(HFInferenceModel):
                 # Handle direct phrases
                 if phrase.startswith("{") and phrase.endswith("}"):
                     no_brackets = phrase[1:-1]
-                    return [inference_model.current_model.tokenizer.encode(no_brackets)]
+                    return [model_self.tokenizer.encode(no_brackets)]
 
                 # Handle untamperable phrases
                 if not self._allow_leftwards_tampering(phrase):
-                    return [inference_model.current_model.tokenizer.encode(phrase)]
+                    return [model_self.tokenizer.encode(phrase)]
 
                 # Handle slight alterations to original phrase
                 phrase = phrase.strip(" ")
@@ -250,7 +250,7 @@ class HFTorchInferenceModel(HFInferenceModel):
                 for alt_phrase in [phrase, f" {phrase}"]:
                     ret.append(
-                        inference_model.current_model.tokenizer.encode(alt_phrase)
+                        model_self.tokenizer.encode(alt_phrase)
                     )
 
                 return ret
@@ -440,8 +440,8 @@ class HFTorchInferenceModel(HFInferenceModel):
                 # sampler_order = [6] + sampler_order
                 # for k in sampler_order:
                 #     scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
-                scores = self._apply_warpers(scores=scores, input_ids=input_ids)
-                visualize_probabilities(inference_model.current_model, scores)
+                scores = model_self._apply_warpers(scores=scores, input_ids=input_ids)
+                visualize_probabilities(model_self, scores)
                 return scores
 
         def new_get_logits_warper(