Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Model: Fix bugs and introduce hack for visualization
Hopefully I remove that atrocity before the PR
model.py (114 lines changed)
@@ -18,7 +18,7 @@ import json
 import os
 import time
 import traceback
-from typing import Dict, Iterable, List, Optional, Union
+from typing import Dict, Iterable, List, Optional, Set, Union
 import zipfile
 from tqdm.auto import tqdm
 from logger import logger
@@ -48,6 +48,8 @@ import utils
 import breakmodel
 import koboldai_settings

+HACK_currentmodel = None
+
 try:
     import tpu_mtj_backend
 except ModuleNotFoundError as e:
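The new module-level HACK_currentmodel slot exists so that code patched into transformers, which is never handed the KoboldAI model object, can still reach the active InferenceModel. A minimal sketch of that registration pattern follows; the class and function names here are illustrative, not KoboldAI's real API.

# Sketch of a module-level "current model" slot; names are illustrative.
from typing import Optional


class ToyModel:
    def __init__(self, name: str) -> None:
        self.name = name
        # Register this instance so patched library code can find it later
        # without being handed a reference explicitly.
        global HACK_currentmodel
        HACK_currentmodel = self


HACK_currentmodel: Optional[ToyModel] = None


def patched_library_hook() -> str:
    # Stands in for the monkey-patched warper call that needs the model.
    return HACK_currentmodel.name if HACK_currentmodel else "no model loaded"


print(patched_library_hook())  # "no model loaded"
ToyModel("demo")
print(patched_library_hook())  # "demo"

The diff sets the real slot from InferenceModel's load path further down, which is why the commit message calls it a hack: it is global mutable state standing in for proper dependency injection.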
@@ -147,10 +149,35 @@ class Stoppers:
         return model.gen_state["regeneration_required"] or model.gen_state["halt"]

     @staticmethod
-    def wi_scanner(
+    def dynamic_wi_scanner(
         model: InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
         if not utils.koboldai_vars.inference_config.do_dynamic_wi:
             return False

         if not utils.koboldai_vars.dynamicscan:
             return False

         if len(model.gen_state["wi_scanner_excluded_keys"]) != input_ids.shape[0]:
-            model.gen_state["wi_scanner_excluded_keys"]
+            print(model.tokenizer.decode(model.gen_state["wi_scanner_excluded_keys"]))
+            print(model.tokenizer.decode(input_ids.shape[0]))

         assert len(model.gen_state["wi_scanner_excluded_keys"]) == input_ids.shape[0]

         tail = input_ids[..., -utils.koboldai_vars.generated_tkns :]
         for i, t in enumerate(tail):
             decoded = utils.decodenewlines(model.tokenizer.decode(t))
             _, _, _, found = utils.koboldai_vars.calc_ai_text(
                 submitted_text=decoded, send_context=False
             )
             found = list(
                 set(found) - set(model.gen_state["wi_scanner_excluded_keys"][i])
             )
             if found:
                 print("FOUNDWI", found)
                 return True
         return False

     @staticmethod
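dynamic_wi_scanner halts generation as soon as the newly generated tail of any sequence matches a world-info key that has not already been injected into context. A self-contained sketch of that idea, with plain substring matching standing in for koboldai_vars.calc_ai_text and hypothetical argument names:

from typing import List, Set


def tail_triggers_new_keys(
    generated_texts: List[str],     # decoded tail (new tokens only) per sequence
    wi_keys: Set[str],              # all world-info keys
    excluded_keys: List[Set[str]],  # keys already triggered, per sequence
) -> bool:
    """Return True if any sequence's new text matches a not-yet-used key."""
    assert len(excluded_keys) == len(generated_texts)
    for text, excluded in zip(generated_texts, excluded_keys):
        found = {key for key in wi_keys if key.lower() in text.lower()}
        if found - excluded:
            return True  # new world info should be inserted; halt and regenerate
    return False


# Example: "dragon" was already injected, "castle" is new, so generation stops.
print(tail_triggers_new_keys(
    ["The knight rode toward the castle"],
    wi_keys={"dragon", "castle"},
    excluded_keys=[{"dragon"}],
))  # True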
@@ -342,7 +369,6 @@ def patch_transformers():
         TopKLogitsWarper,
         TopPLogitsWarper,
         TemperatureLogitsWarper,
-        RepetitionPenaltyLogitsProcessor,
     )
     from warpers import (
         AdvancedRepetitionPenaltyLogitsProcessor,
@@ -370,6 +396,7 @@ def patch_transformers():

         cls.__call__ = new_call

+    # TODO: Make samplers generic
     dynamic_processor_wrap(
         AdvancedRepetitionPenaltyLogitsProcessor,
         ("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"),
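dynamic_processor_wrap, referenced by the new TODO, patches a processor's __call__ so sampler settings are read at call time instead of being frozen when the processor is built. A rough sketch of that wrapping trick with a toy processor and a plain settings dict; nothing here is KoboldAI's real API:

live_settings = {"temperature": 1.0}


class TemperatureProcessor:
    def __init__(self, temperature: float) -> None:
        self.temperature = temperature

    def __call__(self, scores):
        return [s / self.temperature for s in scores]


def dynamic_processor_wrap(cls, setting_name: str) -> None:
    original_call = cls.__call__

    def new_call(self, scores):
        # Refresh the parameter from live settings right before applying it.
        setattr(self, setting_name, live_settings[setting_name])
        return original_call(self, scores)

    cls.__call__ = new_call


dynamic_processor_wrap(TemperatureProcessor, "temperature")
processor = TemperatureProcessor(temperature=2.0)
live_settings["temperature"] = 0.5
print(processor([1.0, 2.0]))  # [2.0, 4.0]: the live 0.5 wins over the stale 2.0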
@@ -579,7 +606,10 @@ def patch_transformers():

     from torch.nn import functional as F

-    def visualize_probabilities(scores: torch.FloatTensor) -> None:
+    def visualize_probabilities(
+        model: InferenceModel,
+        scores: torch.FloatTensor,
+    ) -> None:
         assert scores.ndim == 2

         if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs:
@@ -620,7 +650,9 @@ def patch_transformers():
             token_prob_info.append(
                 {
                     "tokenId": token_id,
-                    "decoded": utils.decodenewlines(tokenizer.decode(token_id)),
+                    "decoded": utils.decodenewlines(
+                        model.tokenizer.decode(token_id)
+                    ),
                     "score": float(score),
                 }
             )
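visualize_probabilities turns the warped scores for one step into per-token probability entries like the dict above. A small sketch of the softmax-and-top-k step, assuming torch is available; the decode call is stubbed out where KoboldAI uses model.tokenizer.decode:

import torch
from torch.nn import functional as F


def top_token_probs(scores: torch.FloatTensor, k: int = 5) -> list:
    # scores is (batch, vocab_size), as asserted in the patched code.
    assert scores.ndim == 2
    probs = F.softmax(scores, dim=-1)
    top = torch.topk(probs[0], k)
    return [
        {
            "tokenId": int(token_id),
            "decoded": f"<token {int(token_id)}>",  # stand-in for tokenizer.decode
            "score": float(prob),
        }
        for prob, token_id in zip(top.values, top.indices)
    ]


print(top_token_probs(torch.randn(1, 50), k=3))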
@@ -680,7 +712,7 @@ def patch_transformers():
             sampler_order = [6] + sampler_order
             for k in sampler_order:
                 scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
-            visualize_probabilities(scores)
+            visualize_probabilities(HACK_currentmodel, scores)
             return scores

     def new_get_logits_warper(
@@ -714,45 +746,6 @@ def patch_transformers():
     )
     transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init

-    # Sets up dynamic world info scanner
-    class DynamicWorldInfoScanCriteria(StoppingCriteria):
-        def __init__(
-            self,
-            tokenizer,
-            excluded_world_info: List[Set],
-        ):
-            self.tokenizer = tokenizer
-            self.excluded_world_info = excluded_world_info
-
-        def __call__(
-            self,
-            input_ids: torch.LongTensor,
-            scores: torch.FloatTensor,
-            **kwargs,
-        ) -> bool:
-
-            if not utils.koboldai_vars.inference_config.do_dynamic_wi:
-                return False
-
-            if not utils.koboldai_vars.dynamicscan:
-                return False
-
-            if len(self.excluded_world_info) != input_ids.shape[0]:
-                print(tokenizer.decode(self.excluded_world_info))
-                print(tokenizer.decode(input_ids.shape[0]))
-            assert len(self.excluded_world_info) == input_ids.shape[0]
-
-            tail = input_ids[..., -utils.koboldai_vars.generated_tkns :]
-            for i, t in enumerate(tail):
-                decoded = utils.decodenewlines(tokenizer.decode(t))
-                _, _, _, found = utils.koboldai_vars.calc_ai_text(
-                    submitted_text=decoded, send_context=False
-                )
-                found = list(set(found) - set(self.excluded_world_info[i]))
-                if found:
-                    print("FOUNDWI", found)
-                    return True
-            return False

 class GenerationResult:
     def __init__(
@@ -811,6 +804,9 @@ class InferenceModel:
         self._load(save_model=save_model)
         self._post_load()

+        global HACK_currentmodel
+        HACK_currentmodel = self
+
     def _post_load(self) -> None:
         pass

@@ -981,8 +977,8 @@ class InferenceModel:
                 # stop temporarily to insert WI, we can assume that we are done
                 # generating. We shall break.
                 if (
-                    model.core_stopper.halt
-                    or not model.core_stopper.regeneration_required
+                    self.gen_state["halt"]
+                    or not self.gen_state["regeneration_required"]
                 ):
                     break

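The stop check now reads the halt and regeneration_required flags from the model's own gen_state dictionary instead of a separate core_stopper object. A toy loop showing how those two flags interact; only the flag names come from the diff, the loop itself is illustrative:

gen_state = {"halt": False, "regeneration_required": True}


def should_break(state: dict) -> bool:
    # Stop when a halt was requested, or when no stopper asked for another
    # pass to re-insert world info.
    return state["halt"] or not state["regeneration_required"]


passes = 0
while not should_break(gen_state):
    passes += 1
    gen_state["regeneration_required"] = False  # a stopper hook would set this
print(passes)  # 1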
@@ -1141,21 +1137,17 @@ class InferenceModel:
         if utils.koboldai_vars.model == "ReadOnly":
             raise NotImplementedError("No loaded model")

+        result: GenerationResult
         time_start = time.time()

         with use_core_manipulations():
-            self._raw_generate(
+            result = self._raw_generate(
                 prompt_tokens=prompt_tokens,
                 max_new=max_new,
                 batch_count=batch_count,
                 gen_settings=gen_settings,
                 single_line=single_line,
             )
-        # if koboldai_vars.use_colab_tpu or koboldai_vars.model in (
-        #     "TPUMeshTransformerGPTJ",
-        #     "TPUMeshTransformerGPTNeoX",
-        # ):

         time_end = round(time.time() - time_start, 2)
         tokens_per_second = round(len(result.encoded[0]) / time_end, 2)

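Keeping the GenerationResult lets raw_generate report throughput from the generated token count. A sketch of the tokens-per-second bookkeeping; the shape of result.encoded (a batch of token-id lists) is an assumption:

import time


def tokens_per_second(encoded_batches, seconds: float) -> float:
    # Mirrors the diff's formula: tokens in the first batch / elapsed seconds.
    return round(len(encoded_batches[0]) / seconds, 2)


time_start = time.time()
time.sleep(0.05)  # stand-in for the actual _raw_generate call
time_end = round(time.time() - time_start, 2)
print(tokens_per_second([[101, 7592, 2088, 102]], max(time_end, 0.01)))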
@@ -1250,7 +1242,7 @@ class HFMTJInferenceModel:
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-    ):
+    ) -> GenerationResult:
         soft_tokens = self.get_soft_tokens()

         genout = tpool.execute(
@@ -1297,7 +1289,7 @@ class HFTorchInferenceModel(InferenceModel):
         self.post_token_hooks = [
             Stoppers.core_stopper,
             PostTokenHooks.stream_tokens,
-            Stoppers.wi_scanner,
+            Stoppers.dynamic_wi_scanner,
             Stoppers.chat_mode_stopper,
         ]

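The hook list itself is the mechanism: after every generated token the model walks this list, and each stopper can end generation by returning True. A toy version of that dispatch loop; the hook names and stop rules are made up:

from typing import Callable, List

Stopper = Callable[[List[int]], bool]


def too_long(tokens: List[int]) -> bool:
    return len(tokens) >= 5


def saw_end_token(tokens: List[int]) -> bool:
    return bool(tokens) and tokens[-1] == 0


stopper_hooks: List[Stopper] = [too_long, saw_end_token]

tokens: List[int] = []
for new_token in [7, 3, 0, 9, 4, 2]:
    tokens.append(new_token)
    if any(stopper(tokens) for stopper in stopper_hooks):
        break
print(tokens)  # [7, 3, 0], stopped by saw_end_token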
@@ -1338,7 +1330,7 @@ class HFTorchInferenceModel(InferenceModel):
             **kwargs,
         ):
             stopping_criteria = old_gsc(hf_self, *args, **kwargs)
-            stopping_criteria.insert(0, PTHStopper)
+            stopping_criteria.insert(0, PTHStopper())
            return stopping_criteria

         use_core_manipulations.get_stopping_criteria = _get_stopping_criteria
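The PTHStopper fix matters because Hugging Face stopping criteria are called like functions: the list needs an instance, and inserting the bare class means the later call tries to construct a PTHStopper from input_ids and scores instead. A toy illustration of the difference; the stop rule is invented and PTHStopper's real body is not shown in this diff:

import torch


class ToyStopper:
    def __call__(self, input_ids: torch.LongTensor, scores) -> bool:
        # Invented rule: stop once any sequence exceeds 8 tokens.
        return input_ids.shape[-1] > 8


input_ids = torch.zeros(1, 10, dtype=torch.long)

criteria = [ToyStopper()]  # post-fix: an instance, calling it returns a bool
print(any(criterion(input_ids, None) for criterion in criteria))  # True

try:
    ToyStopper(input_ids, None)  # pre-fix bug: "calling" the class itself
except TypeError as error:
    print("broken criteria list:", error)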
@@ -1350,7 +1342,7 @@ class HFTorchInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-    ):
+    ) -> GenerationResult:
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
         else:
@@ -1379,7 +1371,13 @@ class HFTorchInferenceModel(InferenceModel):
             "torch_raw_generate: run generator {}s".format(time.time() - start_time)
         )

-        return genout
+        return GenerationResult(
+            self,
+            out_batches=genout,
+            prompt=prompt_tokens,
+            is_whole_generation=False,
+            output_includes_prompt=True,
+        )

     def _get_model(self, location: str, tf_kwargs: Dict):
         try:
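Returning a GenerationResult instead of the raw genout tensor gives callers a uniform object; the timing code earlier in the diff reads result.encoded, for example. A sketch of such a wrapper using the keyword names from the call site; the model argument is dropped here, and the prompt-stripping behaviour of output_includes_prompt is an assumption rather than KoboldAI's exact code:

from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyGenerationResult:
    # Field names mirror the keyword arguments at the call site above.
    out_batches: List[List[int]]
    prompt: List[int]
    is_whole_generation: bool = False
    output_includes_prompt: bool = True
    encoded: List[List[int]] = field(default_factory=list, init=False)

    def __post_init__(self) -> None:
        if self.output_includes_prompt:
            # Keep only the newly generated tokens in `encoded`.
            self.encoded = [batch[len(self.prompt):] for batch in self.out_batches]
        else:
            self.encoded = [list(batch) for batch in self.out_batches]


result = ToyGenerationResult(out_batches=[[1, 2, 3, 4, 5]], prompt=[1, 2])
print(result.encoded)  # [[3, 4, 5]]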