Mirror of https://github.com/KoboldAI/KoboldAI-Client.git
Model: Reformat and clean up
@@ -360,9 +360,8 @@ class InferenceModel:
             # amount (controlled by halt), or Dynamic WI has not told us to
             # stop temporarily to insert WI, we can assume that we are done
             # generating. We shall break.
-            if (
-                self.gen_state.get("halt")
-                or not self.gen_state.get("regeneration_required")
+            if self.gen_state.get("halt") or not self.gen_state.get(
+                "regeneration_required"
             ):
                 break
 
@@ -40,7 +40,7 @@ class APIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ):
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
 
@@ -36,7 +36,7 @@ class BasicAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ):
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
 
@@ -53,6 +53,7 @@ except ModuleNotFoundError as e:
 # scores for each token.
 LOG_SAMPLER_NO_EFFECT = False
 
+
 class HFTorchInferenceModel(HFInferenceModel):
     def __init__(
         self,
@@ -249,9 +250,7 @@ class HFTorchInferenceModel(HFInferenceModel):
             ret = []
 
             for alt_phrase in [phrase, f" {phrase}"]:
-                ret.append(
-                    model_self.tokenizer.encode(alt_phrase)
-                )
+                ret.append(model_self.tokenizer.encode(alt_phrase))
 
             return ret
 
@@ -433,13 +432,6 @@ class HFTorchInferenceModel(HFInferenceModel):
                 *args,
                 **kwargs,
             ):
-                # sampler_order = utils.koboldai_vars.sampler_order[:]
-                # if (
-                #     len(sampler_order) < 7
-                # ):  # Add repetition penalty at beginning if it's not present
-                #     sampler_order = [6] + sampler_order
-                # for k in sampler_order:
-                #     scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
                 scores = model_self._apply_warpers(scores=scores, input_ids=input_ids)
                 visualize_probabilities(model_self, scores)
                 return scores
@@ -469,7 +461,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ) -> GenerationResult:
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
@@ -42,7 +42,7 @@ class HordeInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ) -> GenerationResult:
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))
 
@@ -18,6 +18,7 @@ class OpenAIAPIError(Exception):
 
 
 class OpenAIAPIInferenceModel(InferenceModel):
     """InferenceModel for interfacing with OpenAI's generation API."""
+
     def _load(self, save_model: bool, initial_load: bool) -> None:
         self.tokenizer = self._get_tokenizer("gpt2")
@@ -28,7 +29,7 @@ class OpenAIAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ) -> GenerationResult:
         # Taken mainly from oairequest()
 