mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Model: Ditch awful current_model hack
thanks to whjms for spotting that this could be zapped
This commit is contained in:
@@ -22,10 +22,6 @@ except ModuleNotFoundError as e:
|
||||
if utils.koboldai_vars.use_colab_tpu:
|
||||
raise e
|
||||
|
||||
# I don't really like this way of pointing to the current model but I can't
|
||||
# find a way around it in some areas.
|
||||
current_model = None
|
||||
|
||||
# We only want to use logit manipulations and such on our core text model
|
||||
class use_core_manipulations:
|
||||
"""Use in a `with` block to patch functions for core story model sampling."""
|
||||
@@ -176,9 +172,6 @@ class InferenceModel:
|
||||
self._load(save_model=save_model, initial_load=initial_load)
|
||||
self._post_load()
|
||||
|
||||
global current_model
|
||||
current_model = self
|
||||
|
||||
def _post_load(self) -> None:
|
||||
"""Post load hook. Called after `_load()`."""
|
||||
|
||||
|
@@ -113,7 +113,7 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
pre = torch.Tensor(scores)
|
||||
return scores
|
||||
|
||||
def _post_load(self) -> None:
|
||||
def _post_load(model_self) -> None:
|
||||
# Patch stopping_criteria
|
||||
|
||||
class PTHStopper(StoppingCriteria):
|
||||
@@ -122,10 +122,10 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
input_ids: torch.LongTensor,
|
||||
scores: torch.FloatTensor,
|
||||
) -> None:
|
||||
self._post_token_gen(input_ids)
|
||||
model_self._post_token_gen(input_ids)
|
||||
|
||||
for stopper in self.stopper_hooks:
|
||||
do_stop = stopper(self, input_ids)
|
||||
for stopper in model_self.stopper_hooks:
|
||||
do_stop = stopper(model_self, input_ids)
|
||||
if do_stop:
|
||||
return True
|
||||
return False
|
||||
@@ -238,11 +238,11 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
# Handle direct phrases
|
||||
if phrase.startswith("{") and phrase.endswith("}"):
|
||||
no_brackets = phrase[1:-1]
|
||||
return [inference_model.current_model.tokenizer.encode(no_brackets)]
|
||||
return [model_self.tokenizer.encode(no_brackets)]
|
||||
|
||||
# Handle untamperable phrases
|
||||
if not self._allow_leftwards_tampering(phrase):
|
||||
return [inference_model.current_model.tokenizer.encode(phrase)]
|
||||
return [model_self.tokenizer.encode(phrase)]
|
||||
|
||||
# Handle slight alterations to original phrase
|
||||
phrase = phrase.strip(" ")
|
||||
@@ -250,7 +250,7 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
|
||||
for alt_phrase in [phrase, f" {phrase}"]:
|
||||
ret.append(
|
||||
inference_model.current_model.tokenizer.encode(alt_phrase)
|
||||
model_self.tokenizer.encode(alt_phrase)
|
||||
)
|
||||
|
||||
return ret
|
||||
@@ -440,8 +440,8 @@ class HFTorchInferenceModel(HFInferenceModel):
|
||||
# sampler_order = [6] + sampler_order
|
||||
# for k in sampler_order:
|
||||
# scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
|
||||
scores = self._apply_warpers(scores=scores, input_ids=input_ids)
|
||||
visualize_probabilities(inference_model.current_model, scores)
|
||||
scores = model_self._apply_warpers(scores=scores, input_ids=input_ids)
|
||||
visualize_probabilities(model_self, scores)
|
||||
return scores
|
||||
|
||||
def new_get_logits_warper(
|
||||
|
Reference in New Issue
Block a user