Merge branch 'united' of https://github.com/henk717/KoboldAI into model-structure-and-maybe-rwkv
@@ -110,19 +110,43 @@ class HFTorchInferenceModel(HFInferenceModel):
         pre = torch.Tensor(scores)
         return scores
 
-    def _post_load(model_self) -> None:
-        # Patch stopping_criteria
+    def get_model_type(self) -> str:
+        if not self.model_config:
+            return "Read Only"
+
+        if not isinstance(self.model_config, dict):
+            return str(self.model_config.model_type)
+
+        model_type = self.model_config.get("model_type")
+
+        if model_type:
+            return model_type
+
+        if utils.koboldai_vars.mode.endswith("gpt2"):
+            return "gpt2"
+        else:
+            return "Unknown"
+
+    def _post_load(m_self) -> None:
+        if not utils.koboldai_vars.model_type:
+            utils.koboldai_vars.model_type = m_self.get_model_type()
+
+        # Model specific overrides if a model has bad defaults
+        if utils.koboldai_vars.model_type == "llama":
+            m_self.tokenizer.decode_with_prefix_space = True
+            m_self.tokenizer.add_bos_token = False
+
+        # Patch stopping_criteria
         class PTHStopper(StoppingCriteria):
             def __call__(
                 hf_self,
                 input_ids: torch.LongTensor,
                 scores: torch.FloatTensor,
             ) -> None:
-                model_self._post_token_gen(input_ids)
+                m_self._post_token_gen(input_ids)
 
-                for stopper in model_self.stopper_hooks:
-                    do_stop = stopper(model_self, input_ids)
+                for stopper in m_self.stopper_hooks:
+                    do_stop = stopper(m_self, input_ids)
                     if do_stop:
                         return True
                 return False
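Note on the hunk above: `_post_load` binds its instance as `m_self` (instead of `self`) so the nested `PTHStopper` can close over the model without colliding with the stopper's own `hf_self` parameter. The stopper follows the Hugging Face `StoppingCriteria` protocol: `generate()` invokes it after each token, and returning `True` halts generation (the patched `__call__` is annotated `-> None`, but it actually returns a bool). A minimal standalone sketch of the same pattern; `HookedStopper` and the demo hook are illustrative names, not KoboldAI's API:

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class HookedStopper(StoppingCriteria):
    """Runs registered hooks after each generated token."""

    def __init__(self, hooks):
        self.hooks = hooks  # callables: (input_ids) -> bool

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Any hook returning True stops generation.
        return any(hook(input_ids) for hook in self.hooks)

def stop_after_32(input_ids):
    # Demo hook: stop once the sequence grows past 32 tokens.
    return input_ids.shape[-1] > 32

criteria = StoppingCriteriaList([HookedStopper([stop_after_32])])
# Used as: model.generate(input_ids, stopping_criteria=criteria, ...)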
@@ -235,18 +259,18 @@ class HFTorchInferenceModel(HFInferenceModel):
         # Handle direct phrases
         if phrase.startswith("{") and phrase.endswith("}"):
             no_brackets = phrase[1:-1]
-            return [model_self.tokenizer.encode(no_brackets)]
+            return [m_self.tokenizer.encode(no_brackets)]
 
         # Handle untamperable phrases
         if not self._allow_leftwards_tampering(phrase):
-            return [model_self.tokenizer.encode(phrase)]
+            return [m_self.tokenizer.encode(phrase)]
 
         # Handle slight alterations to original phrase
         phrase = phrase.strip(" ")
         ret = []
 
         for alt_phrase in [phrase, f" {phrase}"]:
-            ret.append(model_self.tokenizer.encode(alt_phrase))
+            ret.append(m_self.tokenizer.encode(alt_phrase))
 
         return ret
 
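Note on the hunk above: this method turns a bias phrase into candidate token-id sequences. Encoding both `phrase` and `f" {phrase}"` is necessary because BPE-style tokenizers assign different ids to a word depending on whether it carries a leading space, so covering only one form would miss either sentence-initial or mid-sentence occurrences. A quick illustration, assuming a stock GPT-2 tokenizer (the model choice is only for demonstration):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")

# The same word yields different id sequences with and without the
# leading space, so both variants must be covered.
print(tok.encode("Hello"))   # ids for a sentence-initial match
print(tok.encode(" Hello"))  # different ids for a mid-sentence match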
@@ -428,8 +452,8 @@ class HFTorchInferenceModel(HFInferenceModel):
             *args,
             **kwargs,
         ):
-            scores = model_self._apply_warpers(scores=scores, input_ids=input_ids)
-            visualize_probabilities(model_self, scores)
+            scores = m_self._apply_warpers(scores=scores, input_ids=input_ids)
+            visualize_probabilities(m_self, scores)
             return scores
 
         def new_get_logits_warper(
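Note on the hunk above: the rename also reaches the patched sampling path, where `_apply_warpers` transforms the raw next-token logits (`scores`, a `[batch, vocab]` tensor) before sampling and `visualize_probabilities` feeds the resulting distribution to the UI. A sketch of that interception point using stock Hugging Face warpers; the warper stack below is illustrative, since KoboldAI applies its own samplers inside `_apply_warpers`:

import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper

# Stand-in warper stack; temperature rescales logits, top-k masks the rest.
warpers = LogitsProcessorList([
    TemperatureLogitsWarper(0.8),
    TopKLogitsWarper(40),
])

input_ids = torch.tensor([[1, 2, 3]])  # dummy prompt ids
scores = torch.randn(1, 50257)         # dummy next-token logits (GPT-2 vocab size)
scores = warpers(input_ids, scores)    # same (input_ids, scores) -> scores shape the patch wraps
probs = torch.softmax(scores, dim=-1)  # the distribution a visualizer would display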