Merge branch 'united' of https://github.com/henk717/KoboldAI into fixing-time

This commit is contained in:
somebody
2023-08-07 16:22:04 -05:00
30 changed files with 1940 additions and 807 deletions

View File

@@ -17,9 +17,12 @@ from transformers import (
StoppingCriteria,
GPTNeoForCausalLM,
GPT2LMHeadModel,
AutoModelForCausalLM,
LogitsProcessorList,
)
try:
from hf_bleeding_edge import AutoModelForCausalLM
except ImportError:
from transformers import AutoModelForCausalLM
import utils
import modeling.lazy_loader as lazy_loader
@@ -31,6 +34,7 @@ from modeling.stoppers import Stoppers
from modeling.post_token_hooks import PostTokenHooks
from modeling.inference_models.hf import HFInferenceModel
from modeling.inference_model import (
GenerationMode,
GenerationResult,
GenerationSettings,
ModelCapabilities,
@@ -127,8 +131,13 @@ class HFTorchInferenceModel(HFInferenceModel):
return ret
def get_auxilary_device(self) -> Union[str, int, torch.device]:
return self.breakmodel_config.primary_device
if self.breakmodel:
return self.breakmodel_config.primary_device
if self.usegpu:
return "cuda:0"
else:
return "cpu"
def _get_target_dtype(self) -> Union[torch.float16, torch.float32]:
if self.breakmodel_config.primary_device == "cpu":
return torch.float32
@@ -229,9 +238,6 @@ class HFTorchInferenceModel(HFInferenceModel):
)
class KoboldLogitsWarperList(LogitsProcessorList):
def __init__(self):
pass
def __call__(
lw_self,
input_ids: torch.LongTensor,
@@ -248,17 +254,14 @@ class HFTorchInferenceModel(HFInferenceModel):
), f"Scores are None; processor '{processor}' is to blame"
return scores
def new_get_logits_warper(
beams: int = 1,
) -> LogitsProcessorList:
return KoboldLogitsWarperList()
def new_sample(self, *args, **kwargs):
assert kwargs.pop("logits_warper", None) is not None
kwargs["logits_warper"] = new_get_logits_warper(
beams=1,
)
if utils.koboldai_vars.newlinemode in ["s", "ns"]:
kwargs["logits_warper"] = KoboldLogitsWarperList()
if (
utils.koboldai_vars.newlinemode in ["s", "ns"]
and not m_self.gen_state["allow_eos"]
):
kwargs["eos_token_id"] = -1
kwargs.setdefault("pad_token_id", 2)
return new_sample.old_sample(self, *args, **kwargs)
@@ -330,7 +333,7 @@ class HFTorchInferenceModel(HFInferenceModel):
with torch.no_grad():
start_time = time.time()
genout = self.model.generate(
gen_in,
input_ids=gen_in,
do_sample=True,
max_length=min(
len(prompt_tokens) + max_new, utils.koboldai_vars.max_length
@@ -609,3 +612,9 @@ class HFTorchInferenceModel(HFInferenceModel):
self.breakmodel = False
self.usegpu = False
return
def get_supported_gen_modes(self) -> List[GenerationMode]:
# This changes a torch patch to disallow eos as a bad word.
return super().get_supported_gen_modes() + [
GenerationMode.UNTIL_EOS
]