Model: Fix bugs and introduce hack for visualization

Hopefully I remove that atrocity before the PR
somebody
2023-02-25 18:12:49 -06:00
parent ffe4f25349
commit 465e22fa5c
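For readers skimming the diff: the "hack" referred to above is a module-level global that remembers the most recently loaded model, so that code patched into transformers (which is never handed a model reference) can still reach the tokenizer for probability visualization. A condensed sketch of the pattern, using the names that appear in the diff below; the real InferenceModel does far more than this:

HACK_currentmodel = None  # set whenever a model finishes loading


class InferenceModel:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def load(self):
        # ... actual model loading elided ...
        global HACK_currentmodel
        HACK_currentmodel = self  # remember the live model for patched-in callbacks


def visualize_probabilities(model, scores):
    # The patched warper list has no model reference, so its caller passes the
    # global instead: visualize_probabilities(HACK_currentmodel, scores).
    # Having the model gives this function access to model.tokenizer.
    ...

Passing the model explicitly, or binding it when the warpers are patched in, would avoid the global entirely.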

model.py (114 changed lines)

@@ -18,7 +18,7 @@ import json
import os
import time
import traceback
from typing import Dict, Iterable, List, Optional, Union
from typing import Dict, Iterable, List, Optional, Set, Union
import zipfile
from tqdm.auto import tqdm
from logger import logger
@@ -48,6 +48,8 @@ import utils
import breakmodel
import koboldai_settings
HACK_currentmodel = None
try:
import tpu_mtj_backend
except ModuleNotFoundError as e:
@@ -147,10 +149,35 @@ class Stoppers:
return model.gen_state["regeneration_required"] or model.gen_state["halt"]
@staticmethod
def wi_scanner(
def dynamic_wi_scanner(
model: InferenceModel,
input_ids: torch.LongTensor,
) -> bool:
if not utils.koboldai_vars.inference_config.do_dynamic_wi:
return False
if not utils.koboldai_vars.dynamicscan:
return False
if len(model.gen_state["wi_scanner_excluded_keys"]) != input_ids.shape[0]:
model.gen_state["wi_scanner_excluded_keys"]
print(model.tokenizer.decode(model.gen_state["wi_scanner_excluded_keys"]))
print(model.tokenizer.decode(input_ids.shape[0]))
assert len(model.gen_state["wi_scanner_excluded_keys"]) == input_ids.shape[0]
tail = input_ids[..., -utils.koboldai_vars.generated_tkns :]
for i, t in enumerate(tail):
decoded = utils.decodenewlines(model.tokenizer.decode(t))
_, _, _, found = utils.koboldai_vars.calc_ai_text(
submitted_text=decoded, send_context=False
)
found = list(
set(found) - set(model.gen_state["wi_scanner_excluded_keys"][i])
)
if found:
print("FOUNDWI", found)
return True
return False
@staticmethod
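The entries in Stoppers all share the contract that dynamic_wi_scanner follows above: a static method taking the model and the current input_ids batch, returning True when generation should stop. A minimal illustration with an invented stopper (the name and hard-coded limit are made up for the example; real stoppers consult utils.koboldai_vars and model.gen_state instead):

import torch


class ExampleStoppers:
    @staticmethod
    def token_budget_stopper(model, input_ids: torch.LongTensor) -> bool:
        # Hypothetical: halt once any sequence in the batch reaches 32 tokens.
        return input_ids.shape[-1] >= 32

dynamic_wi_scanner itself returns True as soon as the newly generated tail matches a world-info key that is not already in that sequence's wi_scanner_excluded_keys, which is what lets the caller pause, insert the entry, and resume generating.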
@@ -342,7 +369,6 @@ def patch_transformers():
TopKLogitsWarper,
TopPLogitsWarper,
TemperatureLogitsWarper,
RepetitionPenaltyLogitsProcessor,
)
from warpers import (
AdvancedRepetitionPenaltyLogitsProcessor,
@@ -370,6 +396,7 @@ def patch_transformers():
cls.__call__ = new_call
# TODO: Make samplers generic
dynamic_processor_wrap(
AdvancedRepetitionPenaltyLogitsProcessor,
("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"),
@@ -579,7 +606,10 @@ def patch_transformers():
from torch.nn import functional as F
def visualize_probabilities(scores: torch.FloatTensor) -> None:
def visualize_probabilities(
model: InferenceModel,
scores: torch.FloatTensor,
) -> None:
assert scores.ndim == 2
if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs:
@@ -620,7 +650,9 @@ def patch_transformers():
token_prob_info.append(
{
"tokenId": token_id,
"decoded": utils.decodenewlines(tokenizer.decode(token_id)),
"decoded": utils.decodenewlines(
model.tokenizer.decode(token_id)
),
"score": float(score),
}
)
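For reference, the structure being filled in here is a per-step list of the most likely tokens and their post-warping probabilities, keyed as tokenId / decoded / score. A rough sketch of how such a list can be built for a single sequence (the function name and top-k cutoff are illustrative; the real code lives in visualize_probabilities above, wraps decoding in utils.decodenewlines, and also respects show_probs and numseqs):

import torch
from torch.nn import functional as F


def top_token_probs(model, scores: torch.FloatTensor, k: int = 5):
    # scores: (1, vocab_size) logits for one sequence, after the warpers ran
    probs = F.softmax(scores, dim=-1)[0]
    values, indices = torch.topk(probs, k)
    return [
        {
            "tokenId": int(token_id),
            "decoded": model.tokenizer.decode([int(token_id)]),
            "score": float(score),
        }
        for token_id, score in zip(indices, values)
    ]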
@@ -680,7 +712,7 @@ def patch_transformers():
sampler_order = [6] + sampler_order
for k in sampler_order:
scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
visualize_probabilities(scores)
visualize_probabilities(HACK_currentmodel, scores)
return scores
def new_get_logits_warper(
@@ -714,45 +746,6 @@ def patch_transformers():
)
transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init
# Sets up dynamic world info scanner
class DynamicWorldInfoScanCriteria(StoppingCriteria):
def __init__(
self,
tokenizer,
excluded_world_info: List[Set],
):
self.tokenizer = tokenizer
self.excluded_world_info = excluded_world_info
def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
**kwargs,
) -> bool:
if not utils.koboldai_vars.inference_config.do_dynamic_wi:
return False
if not utils.koboldai_vars.dynamicscan:
return False
if len(self.excluded_world_info) != input_ids.shape[0]:
print(tokenizer.decode(self.excluded_world_info))
print(tokenizer.decode(input_ids.shape[0]))
assert len(self.excluded_world_info) == input_ids.shape[0]
tail = input_ids[..., -utils.koboldai_vars.generated_tkns :]
for i, t in enumerate(tail):
decoded = utils.decodenewlines(tokenizer.decode(t))
_, _, _, found = utils.koboldai_vars.calc_ai_text(
submitted_text=decoded, send_context=False
)
found = list(set(found) - set(self.excluded_world_info[i]))
if found:
print("FOUNDWI", found)
return True
return False
class GenerationResult:
def __init__(
@@ -811,6 +804,9 @@ class InferenceModel:
self._load(save_model=save_model)
self._post_load()
global HACK_currentmodel
HACK_currentmodel = self
def _post_load(self) -> None:
pass
@@ -981,8 +977,8 @@ class InferenceModel:
# stop temporarily to insert WI, we can assume that we are done
# generating. We shall break.
if (
model.core_stopper.halt
or not model.core_stopper.regeneration_required
self.gen_state["halt"]
or not self.gen_state["regeneration_required"]
):
break
@@ -1141,21 +1137,17 @@ class InferenceModel:
if utils.koboldai_vars.model == "ReadOnly":
raise NotImplementedError("No loaded model")
result: GenerationResult
time_start = time.time()
with use_core_manipulations():
self._raw_generate(
result = self._raw_generate(
prompt_tokens=prompt_tokens,
max_new=max_new,
batch_count=batch_count,
gen_settings=gen_settings,
single_line=single_line,
)
# if i_vars.use_colab_tpu or koboldai_vars.model in (
# "TPUMeshTransformerGPTJ",
# "TPUMeshTransformerGPTNeoX",
# ):
time_end = round(time.time() - time_start, 2)
tokens_per_second = round(len(result.encoded[0]) / time_end, 2)
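The throughput figure added here is simple arithmetic: the length of the first returned sequence divided by wall-clock seconds. Illustrative numbers only:

# e.g. a 60-token first sequence over 3.00 s of wall-clock time:
# tokens_per_second = round(60 / 3.00, 2)   # -> 20.0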
@@ -1250,7 +1242,7 @@ class HFMTJInferenceModel:
gen_settings: GenerationSettings,
single_line: bool = False,
batch_count: int = 1,
):
) -> GenerationResult:
soft_tokens = self.get_soft_tokens()
genout = tpool.execute(
@@ -1297,7 +1289,7 @@ class HFTorchInferenceModel(InferenceModel):
self.post_token_hooks = [
Stoppers.core_stopper,
PostTokenHooks.stream_tokens,
Stoppers.wi_scanner,
Stoppers.dynamic_wi_scanner,
Stoppers.chat_mode_stopper,
]
@@ -1338,7 +1330,7 @@ class HFTorchInferenceModel(InferenceModel):
**kwargs,
):
stopping_criteria = old_gsc(hf_self, *args, **kwargs)
stopping_criteria.insert(0, PTHStopper)
stopping_criteria.insert(0, PTHStopper())
return stopping_criteria
use_core_manipulations.get_stopping_criteria = _get_stopping_criteria
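The one-character fix above matters because transformers calls each entry of its StoppingCriteriaList as an instance; inserting the PTHStopper class object itself would make that call try to construct the class with (input_ids, scores) rather than run its __call__. A stand-in with the same shape as PTHStopper (the class body here is invented; the real one dispatches to the model's stopper hooks):

import torch
from transformers import StoppingCriteria, StoppingCriteriaList


class ExampleStopper(StoppingCriteria):
    # Invented stand-in, just to show the calling convention.
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return False  # a real stopper would inspect the model's stopper hooks


criteria = StoppingCriteriaList()
criteria.insert(0, ExampleStopper())  # insert an *instance*, as in the fix above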
@@ -1350,7 +1342,7 @@ class HFTorchInferenceModel(InferenceModel):
gen_settings: GenerationSettings,
single_line: bool = False,
batch_count: int = 1,
):
) -> GenerationResult:
if not isinstance(prompt_tokens, torch.Tensor):
gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
else:
@@ -1379,7 +1371,13 @@ class HFTorchInferenceModel(InferenceModel):
"torch_raw_generate: run generator {}s".format(time.time() - start_time)
)
return genout
return GenerationResult(
self,
out_batches=genout,
prompt=prompt_tokens,
is_whole_generation=False,
output_includes_prompt=True,
)
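With this change _raw_generate hands back a GenerationResult instead of a bare tensor, which is what lets the shared generate() wrapper earlier in this diff time the call and compute tokens_per_second uniformly across backends. A hypothetical, heavily reduced version of the class, inferred only from the constructor arguments visible here; the real GenerationResult is defined earlier in model.py:

class GenerationResult:
    # Hypothetical reduction; field handling is guessed from the call sites.
    def __init__(self, model, out_batches, prompt, is_whole_generation,
                 output_includes_prompt=False):
        if output_includes_prompt:
            # Keep only the newly generated tail of each returned sequence.
            out_batches = out_batches[..., len(prompt):]
        self.encoded = out_batches
        self.prompt = prompt
        self.is_whole_generation = is_whole_generation
        # The model is passed in so results can be decoded with its tokenizer.
        self.decoded = [model.tokenizer.decode(seq) for seq in self.encoded]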
def _get_model(self, location: str, tf_kwargs: Dict):
try: