Merge branch 'united' of https://github.com/henk717/KoboldAI into model-structure-and-maybe-rwkv

somebody committed 2023-03-17 16:20:13 -05:00
12 changed files with 153 additions and 64 deletions


@@ -110,19 +110,43 @@ class HFTorchInferenceModel(HFInferenceModel):
             pre = torch.Tensor(scores)
         return scores
 
-    def _post_load(model_self) -> None:
-        # Patch stopping_criteria
+    def get_model_type(self) -> str:
+        if not self.model_config:
+            return "Read Only"
+
+        if not isinstance(self.model_config, dict):
+            return str(self.model_config.model_type)
+
+        model_type = self.model_config.get("model_type")
+        if model_type:
+            return model_type
+
+        if utils.koboldai_vars.model.endswith("gpt2"):
+            return "gpt2"
+        else:
+            return "Unknown"
+
+    def _post_load(m_self) -> None:
+        if not utils.koboldai_vars.model_type:
+            utils.koboldai_vars.model_type = m_self.get_model_type()
+
+        # Model specific overrides if a model has bad defaults
+        if utils.koboldai_vars.model_type == "llama":
+            m_self.tokenizer.decode_with_prefix_space = True
+            m_self.tokenizer.add_bos_token = False
+
+        # Patch stopping_criteria
         class PTHStopper(StoppingCriteria):
             def __call__(
                 hf_self,
                 input_ids: torch.LongTensor,
                 scores: torch.FloatTensor,
             ) -> None:
-                model_self._post_token_gen(input_ids)
+                m_self._post_token_gen(input_ids)
 
-                for stopper in model_self.stopper_hooks:
-                    do_stop = stopper(model_self, input_ids)
+                for stopper in m_self.stopper_hooks:
+                    do_stop = stopper(m_self, input_ids)
                     if do_stop:
                         return True
 
                 return False
@@ -235,18 +259,18 @@ class HFTorchInferenceModel(HFInferenceModel):
                 # Handle direct phrases
                 if phrase.startswith("{") and phrase.endswith("}"):
                     no_brackets = phrase[1:-1]
-                    return [model_self.tokenizer.encode(no_brackets)]
+                    return [m_self.tokenizer.encode(no_brackets)]
 
                 # Handle untamperable phrases
                 if not self._allow_leftwards_tampering(phrase):
-                    return [model_self.tokenizer.encode(phrase)]
+                    return [m_self.tokenizer.encode(phrase)]
 
                 # Handle slight alterations to original phrase
                 phrase = phrase.strip(" ")
                 ret = []
 
                 for alt_phrase in [phrase, f" {phrase}"]:
-                    ret.append(model_self.tokenizer.encode(alt_phrase))
+                    ret.append(m_self.tokenizer.encode(alt_phrase))
 
                 return ret
@@ -428,8 +452,8 @@ class HFTorchInferenceModel(HFInferenceModel):
                 *args,
                 **kwargs,
             ):
-                scores = model_self._apply_warpers(scores=scores, input_ids=input_ids)
-                visualize_probabilities(model_self, scores)
+                scores = m_self._apply_warpers(scores=scores, input_ids=input_ids)
+                visualize_probabilities(m_self, scores)
                 return scores
 
         def new_get_logits_warper(
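
The hunks above revolve around two Hugging Face extension points: StoppingCriteria, which the commit patches via the nested PTHStopper so KoboldAI's stopper hooks run after every token, and the logits warper list used by _apply_warpers. As a rough standalone sketch of the first mechanism only, and not code from this commit, the snippet below shows how a custom criterion is usually registered with generate(); the class name TokenCountStopper and its max-length rule are illustrative assumptions.

import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class TokenCountStopper(StoppingCriteria):
    """Illustrative criterion: stop once the sequence reaches max_length tokens."""

    def __init__(self, max_length: int):
        self.max_length = max_length

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # generate() calls this after each sampled token; returning True halts generation.
        return input_ids.shape[-1] >= self.max_length


# Usage sketch (assumes `model` and `tokenizer` are already loaded):
# output_ids = model.generate(
#     **tokenizer("Hello", return_tensors="pt"),
#     stopping_criteria=StoppingCriteriaList([TokenCountStopper(32)]),
# )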