From 465e22fa5ce5a062fe08b5cb4e7754286ed6c186 Mon Sep 17 00:00:00 2001
From: somebody
Date: Sat, 25 Feb 2023 18:12:49 -0600
Subject: [PATCH] Model

Fix bugs and introduce hack for visualization
Hopefully I remove that atrocity before the PR
---
 model.py | 114 +++++++++++++++++++++++++++----------------------------
 1 file changed, 56 insertions(+), 58 deletions(-)

diff --git a/model.py b/model.py
index 4f560695..2cf50c0c 100644
--- a/model.py
+++ b/model.py
@@ -18,7 +18,7 @@ import json
 import os
 import time
 import traceback
-from typing import Dict, Iterable, List, Optional, Union
+from typing import Dict, Iterable, List, Optional, Set, Union
 import zipfile
 from tqdm.auto import tqdm
 from logger import logger
@@ -48,6 +48,8 @@ import utils
 import breakmodel
 import koboldai_settings
 
+HACK_currentmodel = None
+
 try:
     import tpu_mtj_backend
 except ModuleNotFoundError as e:
@@ -147,10 +149,35 @@ class Stoppers:
         return model.gen_state["regeneration_required"] or model.gen_state["halt"]
 
     @staticmethod
-    def wi_scanner(
+    def dynamic_wi_scanner(
         model: InferenceModel,
         input_ids: torch.LongTensor,
     ) -> bool:
+        if not utils.koboldai_vars.inference_config.do_dynamic_wi:
+            return False
+
+        if not utils.koboldai_vars.dynamicscan:
+            return False
+
+        if len(model.gen_state["wi_scanner_excluded_keys"]) != input_ids.shape[0]:
+            model.gen_state["wi_scanner_excluded_keys"]
+            print(model.tokenizer.decode(model.gen_state["wi_scanner_excluded_keys"]))
+            print(model.tokenizer.decode(input_ids.shape[0]))
+
+        assert len(model.gen_state["wi_scanner_excluded_keys"]) == input_ids.shape[0]
+
+        tail = input_ids[..., -utils.koboldai_vars.generated_tkns :]
+        for i, t in enumerate(tail):
+            decoded = utils.decodenewlines(model.tokenizer.decode(t))
+            _, _, _, found = utils.koboldai_vars.calc_ai_text(
+                submitted_text=decoded, send_context=False
+            )
+            found = list(
+                set(found) - set(model.gen_state["wi_scanner_excluded_keys"][i])
+            )
+            if found:
+                print("FOUNDWI", found)
+                return True
         return False
 
     @staticmethod
@@ -342,7 +369,6 @@ def patch_transformers():
         TopKLogitsWarper,
         TopPLogitsWarper,
         TemperatureLogitsWarper,
-        RepetitionPenaltyLogitsProcessor,
     )
     from warpers import (
         AdvancedRepetitionPenaltyLogitsProcessor,
@@ -370,6 +396,7 @@ def patch_transformers():
 
         cls.__call__ = new_call
 
+    # TODO: Make samplers generic
     dynamic_processor_wrap(
         AdvancedRepetitionPenaltyLogitsProcessor,
         ("penalty", "penalty_slope", "penalty_range", "use_alt_rep_pen"),
@@ -579,7 +606,10 @@ def patch_transformers():
 
     from torch.nn import functional as F
 
-    def visualize_probabilities(scores: torch.FloatTensor) -> None:
+    def visualize_probabilities(
+        model: InferenceModel,
+        scores: torch.FloatTensor,
+    ) -> None:
         assert scores.ndim == 2
 
         if utils.koboldai_vars.numseqs > 1 or not utils.koboldai_vars.show_probs:
@@ -620,7 +650,9 @@ def patch_transformers():
             token_prob_info.append(
                 {
                     "tokenId": token_id,
-                    "decoded": utils.decodenewlines(tokenizer.decode(token_id)),
+                    "decoded": utils.decodenewlines(
+                        model.tokenizer.decode(token_id)
+                    ),
                     "score": float(score),
                 }
             )
@@ -680,7 +712,7 @@ def patch_transformers():
                 sampler_order = [6] + sampler_order
             for k in sampler_order:
                 scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
-            visualize_probabilities(scores)
+            visualize_probabilities(HACK_currentmodel, scores)
            return scores
 
     def new_get_logits_warper(
@@ -714,45 +746,6 @@ def patch_transformers():
     )
     transformers.generation.logits_process.NoBadWordsLogitsProcessor.__init__ = new_init
 
-    # Sets up dynamic world info scanner
-    class DynamicWorldInfoScanCriteria(StoppingCriteria):
-        def __init__(
-            self,
-            tokenizer,
-            excluded_world_info: List[Set],
-        ):
-            self.tokenizer = tokenizer
-            self.excluded_world_info = excluded_world_info
-
-        def __call__(
-            self,
-            input_ids: torch.LongTensor,
-            scores: torch.FloatTensor,
-            **kwargs,
-        ) -> bool:
-
-            if not utils.koboldai_vars.inference_config.do_dynamic_wi:
-                return False
-
-            if not utils.koboldai_vars.dynamicscan:
-                return False
-
-            if len(self.excluded_world_info) != input_ids.shape[0]:
-                print(tokenizer.decode(self.excluded_world_info))
-                print(tokenizer.decode(input_ids.shape[0]))
-            assert len(self.excluded_world_info) == input_ids.shape[0]
-
-            tail = input_ids[..., -utils.koboldai_vars.generated_tkns :]
-            for i, t in enumerate(tail):
-                decoded = utils.decodenewlines(tokenizer.decode(t))
-                _, _, _, found = utils.koboldai_vars.calc_ai_text(
-                    submitted_text=decoded, send_context=False
-                )
-                found = list(set(found) - set(self.excluded_world_info[i]))
-                if found:
-                    print("FOUNDWI", found)
-                    return True
-            return False
 
 class GenerationResult:
     def __init__(
@@ -811,6 +804,9 @@ class InferenceModel:
         self._load(save_model=save_model)
         self._post_load()
 
+        global HACK_currentmodel
+        HACK_currentmodel = self
+
     def _post_load(self) -> None:
         pass
 
@@ -981,8 +977,8 @@ class InferenceModel:
             # stop temporarily to insert WI, we can assume that we are done
             # generating. We shall break.
             if (
-                model.core_stopper.halt
-                or not model.core_stopper.regeneration_required
+                self.gen_state["halt"]
+                or not self.gen_state["regeneration_required"]
             ):
                 break
 
@@ -1141,21 +1137,17 @@ class InferenceModel:
         if utils.koboldai_vars.model == "ReadOnly":
             raise NotImplementedError("No loaded model")
 
-        result: GenerationResult
         time_start = time.time()
 
         with use_core_manipulations():
-            self._raw_generate(
+            result = self._raw_generate(
                 prompt_tokens=prompt_tokens,
                 max_new=max_new,
                 batch_count=batch_count,
                 gen_settings=gen_settings,
                 single_line=single_line,
             )
-        # if i_vars.use_colab_tpu or koboldai_vars.model in (
-        #     "TPUMeshTransformerGPTJ",
-        #     "TPUMeshTransformerGPTNeoX",
-        # ):
+
         time_end = round(time.time() - time_start, 2)
         tokens_per_second = round(len(result.encoded[0]) / time_end, 2)
 
@@ -1250,7 +1242,7 @@ class HFMTJInferenceModel:
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-    ):
+    ) -> GenerationResult:
         soft_tokens = self.get_soft_tokens()
 
         genout = tpool.execute(
@@ -1297,7 +1289,7 @@ class HFTorchInferenceModel(InferenceModel):
         self.post_token_hooks = [
             Stoppers.core_stopper,
             PostTokenHooks.stream_tokens,
-            Stoppers.wi_scanner,
+            Stoppers.dynamic_wi_scanner,
             Stoppers.chat_mode_stopper,
         ]
 
@@ -1338,7 +1330,7 @@ class HFTorchInferenceModel(InferenceModel):
             **kwargs,
         ):
             stopping_criteria = old_gsc(hf_self, *args, **kwargs)
-            stopping_criteria.insert(0, PTHStopper)
+            stopping_criteria.insert(0, PTHStopper())
             return stopping_criteria
 
         use_core_manipulations.get_stopping_criteria = _get_stopping_criteria
@@ -1350,7 +1342,7 @@ class HFTorchInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-    ):
+    ) -> GenerationResult:
        if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
         else:
@@ -1379,7 +1371,13 @@ class HFTorchInferenceModel(InferenceModel):
             "torch_raw_generate: run generator {}s".format(time.time() - start_time)
         )
 
-        return genout
+        return GenerationResult(
+            self,
+            out_batches=genout,
+            prompt=prompt_tokens,
+            is_whole_generation=False,
+            output_includes_prompt=True,
+        )
 
     def _get_model(self, location: str, tf_kwargs: Dict):
         try:
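
For context on the hook contract this patch moves to: the diff drops the transformers StoppingCriteria subclass and instead makes the dynamic world-info scan an ordinary stopper hook, a plain callable that receives the model and the in-progress input_ids after each generated token and returns True to pause generation so triggered world info can be inserted, with per-batch bookkeeping kept in model.gen_state rather than on an instance. Below is a minimal, hypothetical sketch of that contract; the names FakeModel, run_stoppers, and demo_stopper are invented for illustration and do not exist in model.py.

# Hypothetical sketch of the stopper-hook contract used by Stoppers.dynamic_wi_scanner.
# FakeModel, run_stoppers, and demo_stopper are illustrative only; they are not in model.py.
from typing import Callable, List

import torch


class FakeModel:
    def __init__(self, stopper_hooks: List[Callable]):
        self.stopper_hooks = stopper_hooks
        # Per-generation scratch state, mirroring how the patch uses model.gen_state.
        self.gen_state = {"wi_scanner_excluded_keys": [set()]}


def run_stoppers(model: FakeModel, input_ids: torch.LongTensor) -> bool:
    # After each generated token, every hook sees the model and the tokens so far;
    # any hook returning True pauses generation (e.g. to insert triggered world info).
    return any(hook(model, input_ids) for hook in model.stopper_hooks)


def demo_stopper(model: FakeModel, input_ids: torch.LongTensor) -> bool:
    # Toy stand-in for dynamic_wi_scanner: stop when a trigger token id appears.
    return bool((input_ids[..., -1] == 42).any())


model = FakeModel(stopper_hooks=[demo_stopper])
print(run_stoppers(model, torch.tensor([[1, 2, 3]])))   # False: keep generating
print(run_stoppers(model, torch.tensor([[1, 2, 42]])))  # True: pause for WI insertion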