Model: Reformat and clean up

somebody
2023-03-09 19:12:44 -06:00
parent fb0b2f0467
commit 52095054a3
7 changed files with 11 additions and 19 deletions


@@ -360,9 +360,8 @@ class InferenceModel:
             # amount (controlled by halt), or Dynamic WI has not told us to
             # stop temporarily to insert WI, we can assume that we are done
             # generating. We shall break.
-            if (
-                self.gen_state.get("halt")
-                or not self.gen_state.get("regeneration_required")
+            if self.gen_state.get("halt") or not self.gen_state.get(
+                "regeneration_required"
             ):
                 break
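
The condition is unchanged; only the line wrapping differs, consistent with how a formatter such as Black breaks a long boolean test inside an existing call's parentheses instead of adding an outer pair. A minimal sketch of the equivalence, using made-up gen_state contents purely for illustration:

    # Illustration only; the gen_state values here are made up.
    gen_state = {"halt": False, "regeneration_required": True}

    # Old wrapping: extra parentheses around the whole expression.
    old_style = (
        gen_state.get("halt")
        or not gen_state.get("regeneration_required")
    )

    # New wrapping: break inside the trailing .get() call.
    new_style = gen_state.get("halt") or not gen_state.get(
        "regeneration_required"
    )

    assert old_style == new_style  # same condition, different layout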


@@ -40,7 +40,7 @@ class APIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ):
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
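
The same one-character change, a trailing comma after **kwargs, recurs in several files below. It is the "magic trailing comma" that formatters such as Black rely on: keeping a comma after the final parameter tells the formatter to leave the signature in its exploded, one-parameter-per-line form. A generic sketch (the function name is made up; only the listed parameters come from the hunks):

    # Illustration only; not code from the repository.
    def generate(
        gen_settings,
        single_line=False,
        batch_count=1,
        **kwargs,  # magic trailing comma keeps the signature one parameter per line
    ):
        ...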


@@ -36,7 +36,7 @@ class BasicAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ):
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))


@@ -53,6 +53,7 @@ except ModuleNotFoundError as e:
 # scores for each token.
 LOG_SAMPLER_NO_EFFECT = False
 
+
 class HFTorchInferenceModel(HFInferenceModel):
     def __init__(
         self,
@@ -249,9 +250,7 @@ class HFTorchInferenceModel(HFInferenceModel):
             ret = []
             for alt_phrase in [phrase, f" {phrase}"]:
-                ret.append(
-                    model_self.tokenizer.encode(alt_phrase)
-                )
+                ret.append(model_self.tokenizer.encode(alt_phrase))
             return ret
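
The phrase is encoded both with and without a leading space because BPE-style tokenizers usually assign different token ids to the two spellings, so both variants need to be collected. A standalone sketch of the same pattern (the tokenizer argument is a stand-in, not the repository's object):

    # Illustration only: gather token ids for a phrase and its space-prefixed variant,
    # since e.g. "Hello" and " Hello" generally tokenize differently.
    def encode_phrase_variants(tokenizer, phrase):
        return [tokenizer.encode(alt_phrase) for alt_phrase in (phrase, f" {phrase}")]
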
@@ -433,13 +432,6 @@ class HFTorchInferenceModel(HFInferenceModel):
                 *args,
                 **kwargs,
             ):
-                # sampler_order = utils.koboldai_vars.sampler_order[:]
-                # if (
-                #     len(sampler_order) < 7
-                # ):  # Add repetition penalty at beginning if it's not present
-                #     sampler_order = [6] + sampler_order
-                # for k in sampler_order:
-                #     scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
                 scores = model_self._apply_warpers(scores=scores, input_ids=input_ids)
                 visualize_probabilities(model_self, scores)
                 return scores
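
The deleted comments were the old inline version of applying samplers in a configurable order, which the live code now delegates to _apply_warpers. A standalone reconstruction of the pattern those comments describe, purely as an illustration (this is not the actual _apply_warpers implementation):

    # Hypothetical sketch of the ordered-sampler loop from the deleted comments.
    def apply_warpers_in_order(warper_list, sampler_order, input_ids, scores):
        if len(sampler_order) < 7:
            # Prepend repetition penalty (sampler id 6) if it is not already in the order.
            sampler_order = [6] + sampler_order
        for k in sampler_order:
            scores = warper_list[k](input_ids, scores)
        return scores
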
@@ -469,7 +461,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ) -> GenerationResult:
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]


@@ -42,7 +42,7 @@ class HordeInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ) -> GenerationResult:
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))


@@ -18,6 +18,7 @@ class OpenAIAPIError(Exception):
 class OpenAIAPIInferenceModel(InferenceModel):
     """InferenceModel for interfacing with OpenAI's generation API."""
+
     def _load(self, save_model: bool, initial_load: bool) -> None:
         self.tokenizer = self._get_tokenizer("gpt2")
@@ -28,7 +29,7 @@ class OpenAIAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ) -> GenerationResult:
         # Taken mainly from oairequest()


@@ -131,7 +131,7 @@ class Stoppers:
         # a newline.
         if "newline_in_sequence" not in model.gen_state:
             model.gen_state["newline_in_sequence"] = [False] * len(input_ids)
         for sequence_idx, batch_sequence in enumerate(input_ids):
             if model.tokenizer.decode(batch_sequence[-1]) == "\n":
                 model.gen_state["newline_in_sequence"][sequence_idx] = True
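
The loop lazily initialises a per-sequence flag list in gen_state and marks a sequence once its most recent token decodes to a newline, so each sequence in the batch is tracked independently. A self-contained sketch of that bookkeeping with a toy tokenizer standing in for the model's real one (illustration only):

    # Illustration only; the tokenizer and "token" values are stand-ins.
    class ToyTokenizer:
        def decode(self, token):
            return token

    def track_newlines(tokenizer, input_ids, gen_state):
        if "newline_in_sequence" not in gen_state:
            gen_state["newline_in_sequence"] = [False] * len(input_ids)
        for sequence_idx, batch_sequence in enumerate(input_ids):
            if tokenizer.decode(batch_sequence[-1]) == "\n":
                gen_state["newline_in_sequence"][sequence_idx] = True
        return gen_state["newline_in_sequence"]

    gen_state = {}
    batch = [["Hello", "\n"], ["Hello", "world"]]  # two sequences; only the first ends in a newline
    print(track_newlines(ToyTokenizer(), batch, gen_state))  # [True, False]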