Model: Fix bugs and introduce hack for visualization

Hopefully I'll remove that atrocity before the PR
Author: somebody
Date: 2023-02-25 18:12:49 -06:00
Parent: ffe4f25349
Commit: 465e22fa5c

model.py (114 changed lines)

@@ -18,7 +18,7 @@ import json
 import os
 import time
 import traceback
-from typing import Dict, Iterable, List, Optional, Union
+from typing import Dict, Iterable, List, Optional, Set, Union
 import zipfile
 from tqdm.auto import tqdm
 from logger import logger
@@ -48,6 +48,8 @@ import utils
 import breakmodel
 import koboldai_settings
 
+HACK_currentmodel = None
+
 try:
     import tpu_mtj_backend
 except ModuleNotFoundError as e:
@@ -147,10 +149,35 @@ class Stoppers:
         return model.gen_state["regeneration_required"] or model.gen_state["halt"]
 
     @staticmethod
-    def wi_scanner(
+    def dynamic_wi_scanner(
         model: InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
+        if not utils.koboldai_vars.inference_config.do_dynamic_wi:
+            return False
+
+        if not utils.koboldai_vars.dynamicscan:
+            return False
+
+        if len(model.gen_state["wi_scanner_excluded_keys"]) != input_ids.shape[0]:
+            model.gen_state["wi_scanner_excluded_keys"]
+            print(model.tokenizer.decode(model.gen_state["wi_scanner_excluded_keys"]))
+            print(model.tokenizer.decode(input_ids.shape[0]))
+
+        assert len(model.gen_state["wi_scanner_excluded_keys"]) == input_ids.shape[0]
+
+        tail = input_ids[..., -utils.koboldai_vars.generated_tkns :]
+        for i, t in enumerate(tail):
+            decoded = utils.decodenewlines(model.tokenizer.decode(t))
+            _, _, _, found = utils.koboldai_vars.calc_ai_text(
+                submitted_text=decoded, send_context=False
+            )
+            found = list(
+                set(found) - set(model.gen_state["wi_scanner_excluded_keys"][i])
+            )
+            if found:
+                print("FOUNDWI", found)
+                return True
         return False
 
     @staticmethod
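The new dynamic_wi_scanner stopper boils down to a per-row set difference: each batch row keeps the world-info keys that have already fired, and only keys that are new for that row should halt generation so the WI entry can be inserted. A self-contained sketch of just that bookkeeping (the data and variable names below are illustrative, not from the repo):

already_triggered = [{"king", "castle"}, set()]       # stand-in for gen_state["wi_scanner_excluded_keys"]
found_this_step = [["castle", "dragon"], ["dragon"]]  # stand-in for calc_ai_text()'s found keys per row

for row, found in enumerate(found_this_step):
    new_hits = list(set(found) - already_triggered[row])
    if new_hits:
        print("FOUNDWI", new_hits)  # the real stopper returns True here to pause generation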
@@ -342,7 +369,6 @@ def patch_transformers():
TopKLogitsWarper, TopKLogitsWarper,
TopPLogitsWarper, TopPLogitsWarper,
TemperatureLogitsWarper, TemperatureLogitsWarper,
RepetitionPenaltyLogitsProcessor,
) )
from warpers import ( from warpers import (
AdvancedRepetitionPenaltyLogitsProcessor, AdvancedRepetitionPenaltyLogitsProcessor,
@@ -370,6 +396,7 @@ def patch_transformers():
         cls.__call__ = new_call
 
+    # TODO: Make samplers generic
     dynamic_processor_wrap(
         AdvancedRepetitionPenaltyLogitsProcessor,
         ("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"),
@@ -579,7 +606,10 @@ def patch_transformers():
     from torch.nn import functional as F
 
-    def visualize_probabilities(scores: torch.FloatTensor) -> None:
+    def visualize_probabilities(
+        model: InferenceModel,
+        scores: torch.FloatTensor,
+    ) -> None:
         assert scores.ndim == 2
 
         if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs:
@@ -620,7 +650,9 @@ def patch_transformers():
             token_prob_info.append(
                 {
                     "tokenId": token_id,
-                    "decoded": utils.decodenewlines(tokenizer.decode(token_id)),
+                    "decoded": utils.decodenewlines(
+                        model.tokenizer.decode(token_id)
+                    ),
                     "score": float(score),
                 }
             )
@@ -680,7 +712,7 @@ def patch_transformers():
                 sampler_order = [6] + sampler_order
             for k in sampler_order:
                 scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
-            visualize_probabilities(scores)
+            visualize_probabilities(HACK_currentmodel, scores)
             return scores
 
     def new_get_logits_warper(
@@ -714,45 +746,6 @@ def patch_transformers():
         )
 
     transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init
 
-    # Sets up dynamic world info scanner
-    class DynamicWorldInfoScanCriteria(StoppingCriteria):
-        def __init__(
-            self,
-            tokenizer,
-            excluded_world_info: List[Set],
-        ):
-            self.tokenizer = tokenizer
-            self.excluded_world_info = excluded_world_info
-
-        def __call__(
-            self,
-            input_ids: torch.LongTensor,
-            scores: torch.FloatTensor,
-            **kwargs,
-        ) -> bool:
-            if not utils.koboldai_vars.inference_config.do_dynamic_wi:
-                return False
-
-            if not utils.koboldai_vars.dynamicscan:
-                return False
-
-            if len(self.excluded_world_info) != input_ids.shape[0]:
-                print(tokenizer.decode(self.excluded_world_info))
-                print(tokenizer.decode(input_ids.shape[0]))
-            assert len(self.excluded_world_info) == input_ids.shape[0]
-
-            tail = input_ids[..., -utils.koboldai_vars.generated_tkns :]
-            for i, t in enumerate(tail):
-                decoded = utils.decodenewlines(tokenizer.decode(t))
-                _, _, _, found = utils.koboldai_vars.calc_ai_text(
-                    submitted_text=decoded, send_context=False
-                )
-                found = list(set(found) - set(self.excluded_world_info[i]))
-                if found:
-                    print("FOUNDWI", found)
-                    return True
-            return False
-
 
 class GenerationResult:
     def __init__(
@@ -811,6 +804,9 @@ class InferenceModel:
         self._load(save_model=save_model)
         self._post_load()
 
+        global HACK_currentmodel
+        HACK_currentmodel = self
+
     def _post_load(self) -> None:
         pass
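This is the "hack" from the commit message in its entirety: a module-level global is pointed at the model once loading finishes, so the monkey-patched warper call in patch_transformers(), which never receives a model argument, can still reach the tokenizer. A reduced sketch of the pattern (CURRENT_MODEL and the dummy classes are stand-ins, not the repo's names):

CURRENT_MODEL = None  # stand-in for HACK_currentmodel

class DummyModel:
    class tokenizer:  # minimal stand-in for model.tokenizer
        @staticmethod
        def decode(token_id: int) -> str:
            return f"<tok {token_id}>"

def load_model() -> DummyModel:
    global CURRENT_MODEL
    model = DummyModel()
    CURRENT_MODEL = model  # mirrors HACK_currentmodel = self above
    return model

def patched_visualizer(token_id: int) -> None:
    # the patched hook has no model parameter, so it reads the global instead
    print(CURRENT_MODEL.tokenizer.decode(token_id))

load_model()
patched_visualizer(42)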
@@ -981,8 +977,8 @@ class InferenceModel:
                 # stop temporarily to insert WI, we can assume that we are done
                 # generating. We shall break.
                 if (
-                    model.core_stopper.halt
-                    or not model.core_stopper.regeneration_required
+                    self.gen_state["halt"]
+                    or not self.gen_state["regeneration_required"]
                 ):
                     break
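For context, a sketch of the gen_state contract this condition now leans on (the two dictionary keys come from the diff; the surrounding scaffolding is assumed): stoppers write flags into the model's gen_state dict, and the generation loop reads the same flags back instead of poking into model.core_stopper.

gen_state = {"halt": False, "regeneration_required": False}  # stand-in for self.gen_state

def core_stopper(user_pressed_stop: bool) -> None:
    # stopper side: record the decision on gen_state
    gen_state["halt"] = user_pressed_stop

core_stopper(user_pressed_stop=True)

# generation-loop side: the condition as rewritten in this hunk
if gen_state["halt"] or not gen_state["regeneration_required"]:
    print("done generating; breaking out of the token loop")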
@@ -1141,21 +1137,17 @@ class InferenceModel:
         if utils.koboldai_vars.model == "ReadOnly":
             raise NotImplementedError("No loaded model")
 
-        result: GenerationResult
         time_start = time.time()
 
         with use_core_manipulations():
-            self._raw_generate(
+            result = self._raw_generate(
                 prompt_tokens=prompt_tokens,
                 max_new=max_new,
                 batch_count=batch_count,
                 gen_settings=gen_settings,
                 single_line=single_line,
             )
 
-        # if i_vars.use_colab_tpu or koboldai_vars.model in (
-        #     "TPUMeshTransformerGPTJ",
-        #     "TPUMeshTransformerGPTNeoX",
-        # ):
         time_end = round(time.time() - time_start, 2)
         tokens_per_second = round(len(result.encoded[0]) / time_end, 2)
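The timing code at the end of this hunk dereferences result, so the generator's return value has to be captured; a tiny illustration of the failure mode the result = self._raw_generate(...) change removes (the helper names here are made up):

def generate_tokens() -> list:
    return [1, 2, 3]

def broken_caller() -> int:
    generate_tokens()           # return value dropped on the floor
    return len(result)          # NameError: 'result' was never bound

def fixed_caller() -> int:
    result = generate_tokens()  # capture it, as the hunk above now does
    return len(result)

print(fixed_caller())  # 3
# broken_caller()      # would raise NameError at runtime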
@@ -1250,7 +1242,7 @@ class HFMTJInferenceModel:
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-    ):
+    ) -> GenerationResult:
         soft_tokens = self.get_soft_tokens()
 
         genout = tpool.execute(
@@ -1297,7 +1289,7 @@ class HFTorchInferenceModel(InferenceModel):
         self.post_token_hooks = [
             Stoppers.core_stopper,
             PostTokenHooks.stream_tokens,
-            Stoppers.wi_scanner,
+            Stoppers.dynamic_wi_scanner,
             Stoppers.chat_mode_stopper,
         ]
@@ -1338,7 +1330,7 @@ class HFTorchInferenceModel(InferenceModel):
             **kwargs,
         ):
             stopping_criteria = old_gsc(hf_self, *args, **kwargs)
-            stopping_criteria.insert(0, PTHStopper)
+            stopping_criteria.insert(0, PTHStopper())
             return stopping_criteria
 
         use_core_manipulations.get_stopping_criteria = _get_stopping_criteria
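The added parentheses matter because a StoppingCriteriaList invokes each entry as a callable object, so it needs an instance rather than the class itself. A minimal sketch with a stand-in criterion (PTHStopper is not shown in this diff, so AlwaysContinue below is a substitute):

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class AlwaysContinue(StoppingCriteria):
    """Stand-in for PTHStopper: never asks generation to stop."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return False

criteria = StoppingCriteriaList()
criteria.insert(0, AlwaysContinue())  # an instance, matching the fixed line
# criteria.insert(0, AlwaysContinue)  # the old form inserted the class object,
#                                     # which fails once the list tries to call it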
@@ -1350,7 +1342,7 @@ class HFTorchInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-    ):
+    ) -> GenerationResult:
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
         else:
@@ -1379,7 +1371,13 @@ class HFTorchInferenceModel(InferenceModel):
             "torch_raw_generate: run generator {}s".format(time.time() - start_time)
         )
 
-        return genout
+        return GenerationResult(
+            self,
+            out_batches=genout,
+            prompt=prompt_tokens,
+            is_whole_generation=False,
+            output_includes_prompt=True,
+        )
 
     def _get_model(self, location: str, tf_kwargs: Dict):
         try:
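With _raw_generate now handing back a GenerationResult instead of the raw genout tensor, callers such as raw_generate (earlier hunk) can time generation off result.encoded. A caller-side sketch with a stand-in class (FakeResult is not the repo's GenerationResult):

from dataclasses import dataclass
from typing import List

@dataclass
class FakeResult:
    encoded: List[List[int]]  # one token list per batch, as result.encoded is used above

def fake_raw_generate() -> FakeResult:
    return FakeResult(encoded=[[101, 7592, 2088, 102]])  # one batch of four tokens

result = fake_raw_generate()
elapsed = 0.5  # pretend generation took half a second
tokens_per_second = round(len(result.encoded[0]) / elapsed, 2)
print(tokens_per_second)  # 8.0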