Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Model: Ditch awful current_model hack
Thanks to whjms for spotting that this could be zapped.
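
Summary: the modeling code kept a module-level `current_model` global, set as a side effect of `InferenceModel.load()`, solely so that the functions patched in during `_post_load()` could find the active model. This commit deletes the global. Instead, `HFTorchInferenceModel._post_load()` names its receiver `model_self`, and the classes and closures defined inside it capture that name directly, leaving `self` free for the inner classes' own instances.

A minimal sketch of the pattern (toy classes, not the actual KoboldAI code):

    class Model:
        def __init__(self):
            self.stopper_hooks = []

        def _post_load(model_self):
            # Receiver renamed so the inner class can use `self` for itself
            # while still reaching the model through the closure.
            class Stopper:
                def __call__(self, input_ids):
                    for stopper in model_self.stopper_hooks:
                        if stopper(model_self, input_ids):
                            return True
                    return False

            model_self.stopper = Stopper()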
@@ -22,10 +22,6 @@ except ModuleNotFoundError as e:
     if utils.koboldai_vars.use_colab_tpu:
         raise e
 
-# I don't really like this way of pointing to the current model but I can't
-# find a way around it in some areas.
-current_model = None
-
 # We only want to use logit manipulations and such on our core text model
 class use_core_manipulations:
     """Use in a `with` block to patch functions for core story model sampling."""
@@ -176,9 +172,6 @@ class InferenceModel:
         self._load(save_model=save_model, initial_load=initial_load)
         self._post_load()
 
-        global current_model
-        current_model = self
-
     def _post_load(self) -> None:
         """Post load hook. Called after `_load()`."""
 
@@ -113,7 +113,7 @@ class HFTorchInferenceModel(HFInferenceModel):
             pre = torch.Tensor(scores)
             return scores
 
-    def _post_load(self) -> None:
+    def _post_load(model_self) -> None:
         # Patch stopping_criteria
 
         class PTHStopper(StoppingCriteria):
@@ -122,10 +122,10 @@ class HFTorchInferenceModel(HFInferenceModel):
                input_ids: torch.LongTensor,
                scores: torch.FloatTensor,
            ) -> None:
-                self._post_token_gen(input_ids)
+                model_self._post_token_gen(input_ids)
 
-                for stopper in self.stopper_hooks:
-                    do_stop = stopper(self, input_ids)
+                for stopper in model_self.stopper_hooks:
+                    do_stop = stopper(model_self, input_ids)
                     if do_stop:
                         return True
                 return False
@@ -238,11 +238,11 @@ class HFTorchInferenceModel(HFInferenceModel):
             # Handle direct phrases
             if phrase.startswith("{") and phrase.endswith("}"):
                 no_brackets = phrase[1:-1]
-                return [inference_model.current_model.tokenizer.encode(no_brackets)]
+                return [model_self.tokenizer.encode(no_brackets)]
 
             # Handle untamperable phrases
             if not self._allow_leftwards_tampering(phrase):
-                return [inference_model.current_model.tokenizer.encode(phrase)]
+                return [model_self.tokenizer.encode(phrase)]
 
             # Handle slight alterations to original phrase
             phrase = phrase.strip(" ")
@@ -250,7 +250,7 @@ class HFTorchInferenceModel(HFInferenceModel):
 
             for alt_phrase in [phrase, f" {phrase}"]:
                 ret.append(
-                    inference_model.current_model.tokenizer.encode(alt_phrase)
+                    model_self.tokenizer.encode(alt_phrase)
                 )
 
             return ret
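
The banned-phrase encoding above no longer routes through the global: each closure encodes with the tokenizer of the model that installed it. A standalone toy (hypothetical `Tokenizer`, not the transformers API) showing why the closure form is safer once a second model gets loaded:

    class Tokenizer:
        def __init__(self, name):
            self.name = name

        def encode(self, phrase):
            # Tag tokens with the owning model's name so mix-ups are visible.
            return [f"{self.name}:{tok}" for tok in phrase.split()]

    class Model:
        def __init__(self, name):
            self.tokenizer = Tokenizer(name)

        def _post_load(model_self):
            def encode_phrase(phrase):
                # Captures model_self: always this model's tokenizer,
                # even if another model is loaded afterwards.
                return model_self.tokenizer.encode(phrase)

            model_self.encode_phrase = encode_phrase

    a, b = Model("a"), Model("b")
    a._post_load()
    b._post_load()
    assert a.encode_phrase("hello")[0] == "a:hello"  # b's load did not clobber a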
@@ -440,8 +440,8 @@ class HFTorchInferenceModel(HFInferenceModel):
            # sampler_order = [6] + sampler_order
            # for k in sampler_order:
            #     scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
-            scores = self._apply_warpers(scores=scores, input_ids=input_ids)
-            visualize_probabilities(inference_model.current_model, scores)
+            scores = model_self._apply_warpers(scores=scores, input_ids=input_ids)
+            visualize_probabilities(model_self, scores)
             return scores
 
         def new_get_logits_warper(
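
The patched sampling path gets the same treatment: `_apply_warpers` and `visualize_probabilities` now go through `model_self`, so the warper and visualization hooks act on the model that installed the patch rather than on whatever `current_model` last happened to point at.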