Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Model: Fix bugs and introduce hack for visualization
Hopefully I remove that atrocity before the PR
model.py (114 changed lines)
--- a/model.py
+++ b/model.py
@@ -18,7 +18,7 @@ import json
 import os
 import time
 import traceback
-from typing import Dict, Iterable, List, Optional, Union
+from typing import Dict, Iterable, List, Optional, Set, Union
 import zipfile
 from tqdm.auto import tqdm
 from logger import logger
@@ -48,6 +48,8 @@ import utils
 import breakmodel
 import koboldai_settings

+HACK_currentmodel = None
+
 try:
     import tpu_mtj_backend
 except ModuleNotFoundError as e:
@@ -147,10 +149,35 @@ class Stoppers:
         return model.gen_state["regeneration_required"] or model.gen_state["halt"]

     @staticmethod
-    def wi_scanner(
+    def dynamic_wi_scanner(
         model: InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
+        if not utils.koboldai_vars.inference_config.do_dynamic_wi:
+            return False
+
+        if not utils.koboldai_vars.dynamicscan:
+            return False
+
+        if len(model.gen_state["wi_scanner_excluded_keys"]) != input_ids.shape[0]:
+            model.gen_state["wi_scanner_excluded_keys"]
+            print(model.tokenizer.decode(model.gen_state["wi_scanner_excluded_keys"]))
+            print(model.tokenizer.decode(input_ids.shape[0]))
+
+        assert len(model.gen_state["wi_scanner_excluded_keys"]) == input_ids.shape[0]
+
+        tail = input_ids[..., -utils.koboldai_vars.generated_tkns :]
+        for i, t in enumerate(tail):
+            decoded = utils.decodenewlines(model.tokenizer.decode(t))
+            _, _, _, found = utils.koboldai_vars.calc_ai_text(
+                submitted_text=decoded, send_context=False
+            )
+            found = list(
+                set(found) - set(model.gen_state["wi_scanner_excluded_keys"][i])
+            )
+            if found:
+                print("FOUNDWI", found)
+                return True
         return False

     @staticmethod
@@ -342,7 +369,6 @@ def patch_transformers():
         TopKLogitsWarper,
         TopPLogitsWarper,
         TemperatureLogitsWarper,
-        RepetitionPenaltyLogitsProcessor,
     )
     from warpers import (
         AdvancedRepetitionPenaltyLogitsProcessor,
@@ -370,6 +396,7 @@ def patch_transformers():

         cls.__call__ = new_call

+    # TODO: Make samplers generic
     dynamic_processor_wrap(
         AdvancedRepetitionPenaltyLogitsProcessor,
         ("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"),
@@ -579,7 +606,10 @@ def patch_transformers():

     from torch.nn import functional as F

-    def visualize_probabilities(scores: torch.FloatTensor) -> None:
+    def visualize_probabilities(
+        model: InferenceModel,
+        scores: torch.FloatTensor,
+    ) -> None:
         assert scores.ndim == 2

         if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs:
@@ -620,7 +650,9 @@ def patch_transformers():
                 token_prob_info.append(
                     {
                         "tokenId": token_id,
-                        "decoded": utils.decodenewlines(tokenizer.decode(token_id)),
+                        "decoded": utils.decodenewlines(
+                            model.tokenizer.decode(token_id)
+                        ),
                         "score": float(score),
                     }
                 )
@@ -680,7 +712,7 @@ def patch_transformers():
                 sampler_order = [6] + sampler_order
             for k in sampler_order:
                 scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
-            visualize_probabilities(scores)
+            visualize_probabilities(HACK_currentmodel, scores)
             return scores

     def new_get_logits_warper(
@@ -714,45 +746,6 @@ def patch_transformers():
     )
     transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init

-    # Sets up dynamic world info scanner
-    class DynamicWorldInfoScanCriteria(StoppingCriteria):
-        def __init__(
-            self,
-            tokenizer,
-            excluded_world_info: List[Set],
-        ):
-            self.tokenizer = tokenizer
-            self.excluded_world_info = excluded_world_info
-
-        def __call__(
-            self,
-            input_ids: torch.LongTensor,
-            scores: torch.FloatTensor,
-            **kwargs,
-        ) -> bool:
-
-            if not utils.koboldai_vars.inference_config.do_dynamic_wi:
-                return False
-
-            if not utils.koboldai_vars.dynamicscan:
-                return False
-
-            if len(self.excluded_world_info) != input_ids.shape[0]:
-                print(tokenizer.decode(self.excluded_world_info))
-                print(tokenizer.decode(input_ids.shape[0]))
-            assert len(self.excluded_world_info) == input_ids.shape[0]
-
-            tail = input_ids[..., -utils.koboldai_vars.generated_tkns :]
-            for i, t in enumerate(tail):
-                decoded = utils.decodenewlines(tokenizer.decode(t))
-                _, _, _, found = utils.koboldai_vars.calc_ai_text(
-                    submitted_text=decoded, send_context=False
-                )
-                found = list(set(found) - set(self.excluded_world_info[i]))
-                if found:
-                    print("FOUNDWI", found)
-                    return True
-            return False
-
 class GenerationResult:
     def __init__(
@@ -811,6 +804,9 @@ class InferenceModel:
         self._load(save_model=save_model)
         self._post_load()

+        global HACK_currentmodel
+        HACK_currentmodel = self
+
     def _post_load(self) -> None:
         pass

@@ -981,8 +977,8 @@ class InferenceModel:
            # stop temporarily to insert WI, we can assume that we are done
            # generating. We shall break.
            if (
-                self.gen_state["halt"]
-                or not self.gen_state["regeneration_required"]
+                model.core_stopper.halt
+                or not model.core_stopper.regeneration_required
            ):
                break

@@ -1141,21 +1137,17 @@ class InferenceModel:
         if utils.koboldai_vars.model == "ReadOnly":
             raise NotImplementedError("No loaded model")

-        result: GenerationResult
         time_start = time.time()

         with use_core_manipulations():
-            self._raw_generate(
+            result = self._raw_generate(
                 prompt_tokens=prompt_tokens,
                 max_new=max_new,
                 batch_count=batch_count,
                 gen_settings=gen_settings,
                 single_line=single_line,
             )
-        # if i_vars.use_colab_tpu or koboldai_vars.model in (
-        #     "TPUMeshTransformerGPTJ",
-        #     "TPUMeshTransformerGPTNeoX",
-        # ):
+
         time_end = round(time.time() - time_start, 2)
         tokens_per_second = round(len(result.encoded[0]) / time_end, 2)

@@ -1250,7 +1242,7 @@ class HFMTJInferenceModel:
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-    ):
+    ) -> GenerationResult:
         soft_tokens = self.get_soft_tokens()

         genout = tpool.execute(
@@ -1297,7 +1289,7 @@ class HFTorchInferenceModel(InferenceModel):
         self.post_token_hooks = [
             Stoppers.core_stopper,
             PostTokenHooks.stream_tokens,
-            Stoppers.wi_scanner,
+            Stoppers.dynamic_wi_scanner,
             Stoppers.chat_mode_stopper,
         ]

@@ -1338,7 +1330,7 @@ class HFTorchInferenceModel(InferenceModel):
             **kwargs,
         ):
             stopping_criteria = old_gsc(hf_self, *args, **kwargs)
-            stopping_criteria.insert(0, PTHStopper)
+            stopping_criteria.insert(0, PTHStopper())
             return stopping_criteria

         use_core_manipulations.get_stopping_criteria = _get_stopping_criteria
@@ -1350,7 +1342,7 @@ class HFTorchInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-    ):
+    ) -> GenerationResult:
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
         else:
@@ -1379,7 +1371,13 @@ class HFTorchInferenceModel(InferenceModel):
             "torch_raw_generate: run generator {}s".format(time.time() - start_time)
         )

-        return genout
+        return GenerationResult(
+            self,
+            out_batches=genout,
+            prompt=prompt_tokens,
+            is_whole_generation=False,
+            output_includes_prompt=True,
+        )

     def _get_model(self, location: str, tf_kwargs: Dict):
         try: