Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Merge branch 'united' of https://github.com/henk717/KoboldAI into fixing-time
@@ -17,9 +17,12 @@ from transformers import (
    StoppingCriteria,
    GPTNeoForCausalLM,
    GPT2LMHeadModel,
    AutoModelForCausalLM,
    LogitsProcessorList,
)
try:
    from hf_bleeding_edge import AutoModelForCausalLM
except ImportError:
    from transformers import AutoModelForCausalLM

import utils
import modeling.lazy_loader as lazy_loader
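The try/except block in this hunk is the usual optional-dependency pattern: prefer the `hf_bleeding_edge` drop-in replacement for `AutoModelForCausalLM` when that package is installed, and fall back to the stock `transformers` class otherwise. A minimal sketch of the same pattern, assuming a small public checkpoint for the usage line (the `HAS_BLEEDING_EDGE` flag is illustrative, not part of the diff):

```python
# Optional-dependency import: use the drop-in replacement when available,
# otherwise fall back to the standard transformers implementation.
try:
    from hf_bleeding_edge import AutoModelForCausalLM
    HAS_BLEEDING_EDGE = True
except ImportError:
    from transformers import AutoModelForCausalLM
    HAS_BLEEDING_EDGE = False

# Downstream code uses the same name regardless of which import succeeded.
model = AutoModelForCausalLM.from_pretrained("gpt2")
```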
@@ -31,6 +34,7 @@ from modeling.stoppers import Stoppers
from modeling.post_token_hooks import PostTokenHooks
from modeling.inference_models.hf import HFInferenceModel
from modeling.inference_model import (
    GenerationMode,
    GenerationResult,
    GenerationSettings,
    ModelCapabilities,
@@ -127,8 +131,13 @@ class HFTorchInferenceModel(HFInferenceModel):
        return ret

    def get_auxilary_device(self) -> Union[str, int, torch.device]:
        return self.breakmodel_config.primary_device

        if self.breakmodel:
            return self.breakmodel_config.primary_device
        if self.usegpu:
            return "cuda:0"
        else:
            return "cpu"

    def _get_target_dtype(self) -> Union[torch.float16, torch.float32]:
        if self.breakmodel_config.primary_device == "cpu":
            return torch.float32
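This hunk reworks `get_auxilary_device` so the auxiliary device comes straight from the breakmodel config rather than the older `breakmodel`/`usegpu` branching, while `_get_target_dtype` keeps weights in float32 whenever the primary device is the CPU. A self-contained sketch of that selection rule, assuming the hypothetical helper name `pick_device_and_dtype`:

```python
import torch

def pick_device_and_dtype(primary_device: str) -> tuple:
    # Illustrative only: the auxiliary device is whatever the config names as
    # primary, and half precision is only used off-CPU, mirroring the hunk above.
    dtype = torch.float32 if primary_device == "cpu" else torch.float16
    return primary_device, dtype

device, dtype = pick_device_and_dtype("cuda:0" if torch.cuda.is_available() else "cpu")
print(device, dtype)
```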
@@ -229,9 +238,6 @@ class HFTorchInferenceModel(HFInferenceModel):
            )

        class KoboldLogitsWarperList(LogitsProcessorList):
            def __init__(self):
                pass

            def __call__(
                lw_self,
                input_ids: torch.LongTensor,
@@ -248,17 +254,14 @@ class HFTorchInferenceModel(HFInferenceModel):
                ), f"Scores are None; processor '{processor}' is to blame"
                return scores

        def new_get_logits_warper(
            beams: int = 1,
        ) -> LogitsProcessorList:
            return KoboldLogitsWarperList()

        def new_sample(self, *args, **kwargs):
            assert kwargs.pop("logits_warper", None) is not None
            kwargs["logits_warper"] = new_get_logits_warper(
                beams=1,
            )
            if utils.koboldai_vars.newlinemode in ["s", "ns"]:
            kwargs["logits_warper"] = KoboldLogitsWarperList()

            if (
                utils.koboldai_vars.newlinemode in ["s", "ns"]
                and not m_self.gen_state["allow_eos"]
            ):
                kwargs["eos_token_id"] = -1
            kwargs.setdefault("pad_token_id", 2)
            return new_sample.old_sample(self, *args, **kwargs)
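In this hunk `new_sample` drops the `new_get_logits_warper` indirection and installs `KoboldLogitsWarperList` directly as the `logits_warper`, disables EOS (by setting `eos_token_id` to -1) when the newline mode forbids ending and EOS is not explicitly allowed, and then delegates to the saved original sample method. A minimal, self-contained sketch of the same save-wrap-delegate patching idea; the `Sampler` class and its behaviour are stand-ins, only the pattern matches the diff:

```python
class Sampler:
    def sample(self, **kwargs):
        return kwargs  # stand-in for the real sampling loop

def patched_sample(self, **kwargs):
    kwargs.setdefault("pad_token_id", 2)   # same default the hunk sets
    kwargs["eos_token_id"] = -1            # e.g. when EOS must not end generation
    return patched_sample.old_sample(self, **kwargs)

patched_sample.old_sample = Sampler.sample  # keep a handle on the original
Sampler.sample = patched_sample             # install the wrapper

print(Sampler().sample())  # {'pad_token_id': 2, 'eos_token_id': -1}
```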
@@ -330,7 +333,7 @@ class HFTorchInferenceModel(HFInferenceModel):
        with torch.no_grad():
            start_time = time.time()
            genout = self.model.generate(
                gen_in,
                input_ids=gen_in,
                do_sample=True,
                max_length=min(
                    len(prompt_tokens) + max_new, utils.koboldai_vars.max_length
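The one-line change here is how the prompt tensor is handed to `model.generate` (positional `gen_in` versus the keyword `input_ids=gen_in`); the surrounding context samples with a hard cap of prompt length plus the new-token budget. A standalone sketch of the same call shape, assuming a small public checkpoint and an 80-token budget (neither value comes from the diff):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = tokenizer("You are standing in an open field.", return_tensors="pt").input_ids
max_new = 80
max_length_cap = 1024  # stands in for utils.koboldai_vars.max_length

with torch.no_grad():
    genout = model.generate(
        input_ids=prompt,
        do_sample=True,
        # Never generate past the configured context limit.
        max_length=min(prompt.shape[-1] + max_new, max_length_cap),
        pad_token_id=tokenizer.eos_token_id,
    )

print(tokenizer.decode(genout[0], skip_special_tokens=True))
```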
@@ -609,3 +612,9 @@ class HFTorchInferenceModel(HFInferenceModel):
            self.breakmodel = False
            self.usegpu = False
            return

    def get_supported_gen_modes(self) -> List[GenerationMode]:
        # This changes a torch patch to disallow eos as a bad word.
        return super().get_supported_gen_modes() + [
            GenerationMode.UNTIL_EOS
        ]
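The added override advertises that this torch backend can also run in `UNTIL_EOS` mode by appending it to whatever the base class reports. A sketch of that extend-the-parent capability pattern; the enum members and the base list below are stand-ins, only the override shape matches the diff:

```python
from enum import Enum, auto
from typing import List

class GenerationMode(Enum):
    STANDARD = auto()
    UNTIL_EOS = auto()

class BaseModel:
    def get_supported_gen_modes(self) -> List[GenerationMode]:
        return [GenerationMode.STANDARD]

class TorchModel(BaseModel):
    def get_supported_gen_modes(self) -> List[GenerationMode]:
        # EOS can be un-banned on this backend, so UNTIL_EOS is available too.
        return super().get_supported_gen_modes() + [GenerationMode.UNTIL_EOS]

print(TorchModel().get_supported_gen_modes())
```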