From 52095054a388d9039cd6d3372d2e65e1ce66e3d4 Mon Sep 17 00:00:00 2001
From: somebody
Date: Thu, 9 Mar 2023 19:12:44 -0600
Subject: [PATCH] Model: Reformat and clean up

---
 modeling/inference_model.py            |  5 ++---
 modeling/inference_models/api.py       |  2 +-
 modeling/inference_models/basic_api.py |  2 +-
 modeling/inference_models/hf_torch.py  | 14 +++-----------
 modeling/inference_models/horde.py     |  2 +-
 modeling/inference_models/openai.py    |  3 ++-
 modeling/stoppers.py                   |  2 +-
 7 files changed, 11 insertions(+), 19 deletions(-)

diff --git a/modeling/inference_model.py b/modeling/inference_model.py
index 9663f929..4a44ba56 100644
--- a/modeling/inference_model.py
+++ b/modeling/inference_model.py
@@ -360,9 +360,8 @@ class InferenceModel:
             # amount (controlled by halt), or Dynamic WI has not told us to
             # stop temporarily to insert WI, we can assume that we are done
             # generating. We shall break.
-            if (
-                self.gen_state.get("halt")
-                or not self.gen_state.get("regeneration_required")
+            if self.gen_state.get("halt") or not self.gen_state.get(
+                "regeneration_required"
             ):
                 break
 
diff --git a/modeling/inference_models/api.py b/modeling/inference_models/api.py
index 83fcd7ab..7f1f4ea8 100644
--- a/modeling/inference_models/api.py
+++ b/modeling/inference_models/api.py
@@ -40,7 +40,7 @@ class APIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ):
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
 
diff --git a/modeling/inference_models/basic_api.py b/modeling/inference_models/basic_api.py
index 9f1a147f..9e6a6713 100644
--- a/modeling/inference_models/basic_api.py
+++ b/modeling/inference_models/basic_api.py
@@ -36,7 +36,7 @@ class BasicAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ):
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
 
diff --git a/modeling/inference_models/hf_torch.py b/modeling/inference_models/hf_torch.py
index e101c6da..1509797d 100644
--- a/modeling/inference_models/hf_torch.py
+++ b/modeling/inference_models/hf_torch.py
@@ -53,6 +53,7 @@ except ModuleNotFoundError as e:
 # scores for each token.
 LOG_SAMPLER_NO_EFFECT = False
 
+
 class HFTorchInferenceModel(HFInferenceModel):
     def __init__(
         self,
@@ -249,9 +250,7 @@ class HFTorchInferenceModel(HFInferenceModel):
 
                 ret = []
                 for alt_phrase in [phrase, f" {phrase}"]:
-                    ret.append(
-                        model_self.tokenizer.encode(alt_phrase)
-                    )
+                    ret.append(model_self.tokenizer.encode(alt_phrase))
 
                 return ret
 
@@ -433,13 +432,6 @@ class HFTorchInferenceModel(HFInferenceModel):
                 *args,
                 **kwargs,
             ):
-                # sampler_order = utils.koboldai_vars.sampler_order[:]
-                # if (
-                #     len(sampler_order) < 7
-                # ):  # Add repetition penalty at beginning if it's not present
-                #     sampler_order = [6] + sampler_order
-                # for k in sampler_order:
-                #     scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
                 scores = model_self._apply_warpers(scores=scores, input_ids=input_ids)
                 visualize_probabilities(model_self, scores)
                 return scores
@@ -469,7 +461,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ) -> GenerationResult:
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
diff --git a/modeling/inference_models/horde.py b/modeling/inference_models/horde.py
index 90f7a474..9bdc62b2 100644
--- a/modeling/inference_models/horde.py
+++ b/modeling/inference_models/horde.py
@@ -42,7 +42,7 @@ class HordeInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ) -> GenerationResult:
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
 
diff --git a/modeling/inference_models/openai.py b/modeling/inference_models/openai.py
index 4a28d5f4..c6f07e0e 100644
--- a/modeling/inference_models/openai.py
+++ b/modeling/inference_models/openai.py
@@ -18,6 +18,7 @@ class OpenAIAPIError(Exception):
 
 class OpenAIAPIInferenceModel(InferenceModel):
     """InferenceModel for interfacing with OpenAI's generation API."""
+
     def _load(self, save_model: bool, initial_load: bool) -> None:
         self.tokenizer = self._get_tokenizer("gpt2")
 
@@ -28,7 +29,7 @@ class OpenAIAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ) -> GenerationResult:
         # Taken mainly from oairequest()
 
diff --git a/modeling/stoppers.py b/modeling/stoppers.py
index 3ec7a156..f4c4ff20 100644
--- a/modeling/stoppers.py
+++ b/modeling/stoppers.py
@@ -131,7 +131,7 @@ class Stoppers:
         # a newline.
         if "newline_in_sequence" not in model.gen_state:
             model.gen_state["newline_in_sequence"] = [False] * len(input_ids)
-        
+
         for sequence_idx, batch_sequence in enumerate(input_ids):
             if model.tokenizer.decode(batch_sequence[-1]) == "\n":
                 model.gen_state["newline_in_sequence"][sequence_idx] = True