Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)

Model: Reformat and clean up

@@ -360,9 +360,8 @@ class InferenceModel:
             # amount (controlled by halt), or Dynamic WI has not told us to
             # stop temporarily to insert WI, we can assume that we are done
             # generating. We shall break.
-            if (
-                self.gen_state.get("halt")
-                or not self.gen_state.get("regeneration_required")
+            if self.gen_state.get("halt") or not self.gen_state.get(
+                "regeneration_required"
             ):
                 break

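For context, the condition being reflowed here gates the outer generation loop: generation stops once a halt has been requested or Dynamic WI has not asked for another regeneration pass. A minimal runnable sketch, with a plain dict standing in for self.gen_state (illustrative only, not KoboldAI code):

# should_stop mirrors the diffed predicate; only its line wrapping changed in the commit.
def should_stop(gen_state: dict) -> bool:
    return bool(gen_state.get("halt") or not gen_state.get("regeneration_required"))

assert should_stop({"halt": True, "regeneration_required": True})        # explicit halt
assert should_stop({})                                                   # neither flag set
assert should_stop({"regeneration_required": False})                     # no WI regeneration pass
assert not should_stop({"halt": False, "regeneration_required": True})   # keep generating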
|
@@ -40,7 +40,7 @@ class APIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ):
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

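Several hunks in this commit make the same change: a trailing comma after **kwargs in a multi-line signature, which black-style formatting prefers and which current Python 3 accepts in function definitions. A hypothetical stub (not from the repository) in the same style:

def raw_generate_stub(
    prompt_tokens,
    max_new: int = 80,
    single_line: bool = False,
    batch_count: int = 1,
    **kwargs,  # trailing comma after **kwargs is valid in current Python 3
):
    # Echo the inputs so the stub is runnable on its own.
    return {"prompt_tokens": prompt_tokens, "max_new": max_new,
            "single_line": single_line, "batch_count": batch_count, **kwargs}

print(raw_generate_stub([1, 2, 3], temperature=0.7))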
|
@@ -36,7 +36,7 @@ class BasicAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ):
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

|
@@ -53,6 +53,7 @@ except ModuleNotFoundError as e:
 # scores for each token.
 LOG_SAMPLER_NO_EFFECT = False

+
 class HFTorchInferenceModel(HFInferenceModel):
     def __init__(
         self,

@@ -249,9 +250,7 @@ class HFTorchInferenceModel(HFInferenceModel):
             ret = []

             for alt_phrase in [phrase, f" {phrase}"]:
-                ret.append(
-                    model_self.tokenizer.encode(alt_phrase)
-                )
+                ret.append(model_self.tokenizer.encode(alt_phrase))

             return ret

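The ret.append change above is purely cosmetic, but the helper it sits in is worth a gloss: GPT-2-style BPE tokenizers typically assign different token ids to a word with and without a leading space, so the phrase is encoded in both forms. A self-contained sketch with a hypothetical toy tokenizer standing in for model_self.tokenizer:

class ToyTokenizer:
    # Stand-in only: gives "word" and " word" distinct ids, loosely like GPT-2 BPE.
    def __init__(self):
        self._ids = {}

    def encode(self, text: str) -> list:
        return [self._ids.setdefault(text, len(self._ids))]

def phrase_token_variants(tokenizer, phrase: str) -> list:
    # Same shape as the helper in the diff: collect ids for the phrase both with
    # and without a leading space so downstream bias/ban logic sees both forms.
    ret = []
    for alt_phrase in [phrase, f" {phrase}"]:
        ret.append(tokenizer.encode(alt_phrase))
    return ret

print(phrase_token_variants(ToyTokenizer(), "dragon"))  # -> [[0], [1]]
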
@@ -433,13 +432,6 @@ class HFTorchInferenceModel(HFInferenceModel):
                 *args,
                 **kwargs,
             ):
-                # sampler_order = utils.koboldai_vars.sampler_order[:]
-                # if (
-                #     len(sampler_order) < 7
-                # ): # Add repetition penalty at beginning if it's not present
-                #     sampler_order = [6] + sampler_order
-                # for k in sampler_order:
-                #     scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
                 scores = model_self._apply_warpers(scores=scores, input_ids=input_ids)
                 visualize_probabilities(model_self, scores)
                 return scores
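The deleted comments were a leftover sketch of applying logits warpers in a configurable sampler order; the live code delegates that work to model_self._apply_warpers. For illustration only, here is a minimal version of the same idea (the warper names, ordering, and apply_warpers helper are assumptions for this sketch, not KoboldAI's actual implementation; it assumes torch is installed):

import torch

def temperature_warper(scores: torch.Tensor, temperature: float = 0.7) -> torch.Tensor:
    # Flatten or sharpen the distribution by scaling the logits.
    return scores / temperature

def top_k_warper(scores: torch.Tensor, k: int = 40) -> torch.Tensor:
    # Mask everything outside the top-k logits of each batch row.
    kth_best = torch.topk(scores, k, dim=-1).values[..., -1, None]
    return scores.masked_fill(scores < kth_best, -float("inf"))

def apply_warpers(scores: torch.Tensor, order: list) -> torch.Tensor:
    # Apply each named warper in the requested order, as the old comment hinted at.
    warper_list = {"temperature": temperature_warper, "top_k": top_k_warper}
    for name in order:
        scores = warper_list[name](scores)
    return scores

logits = torch.randn(2, 50)  # dummy (batch, vocab) scores
print(apply_warpers(logits, order=["top_k", "temperature"]).shape)
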
@@ -469,7 +461,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ) -> GenerationResult:
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
|
@@ -42,7 +42,7 @@ class HordeInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ) -> GenerationResult:
         decoded_prompt = utils.decodenewlines(self.tokenizer.decode(prompt_tokens))

|
@@ -18,6 +18,7 @@ class OpenAIAPIError(Exception):

 class OpenAIAPIInferenceModel(InferenceModel):
     """InferenceModel for interfacing with OpenAI's generation API."""

     def _load(self, save_model: bool, initial_load: bool) -> None:
         self.tokenizer = self._get_tokenizer("gpt2")
+

@@ -28,7 +29,7 @@ class OpenAIAPIInferenceModel(InferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ) -> GenerationResult:
         # Taken mainly from oairequest()

|
@@ -131,7 +131,7 @@ class Stoppers:
         # a newline.
         if "newline_in_sequence" not in model.gen_state:
             model.gen_state["newline_in_sequence"] = [False] * len(input_ids)

         for sequence_idx, batch_sequence in enumerate(input_ids):
             if model.tokenizer.decode(batch_sequence[-1]) == "\n":
                 model.gen_state["newline_in_sequence"][sequence_idx] = True
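The final hunk is the per-sequence newline bookkeeping used by the single-line stopper: each batch row records whether it has already emitted a newline. A self-contained sketch where StubModel and StubTokenizer are illustrative stand-ins (in the real code, model.tokenizer is the loaded model's tokenizer and input_ids is a batch of generated token ids):

class StubTokenizer:
    def decode(self, token_id) -> str:
        return {0: "Hello", 1: " world", 2: "\n"}.get(int(token_id), "?")

class StubModel:
    def __init__(self):
        self.tokenizer = StubTokenizer()
        self.gen_state = {}

def track_newlines(model, input_ids) -> list:
    # Mirrors the logic in the hunk above, minus the surrounding stopper class.
    if "newline_in_sequence" not in model.gen_state:
        model.gen_state["newline_in_sequence"] = [False] * len(input_ids)
    for sequence_idx, batch_sequence in enumerate(input_ids):
        if model.tokenizer.decode(batch_sequence[-1]) == "\n":
            model.gen_state["newline_in_sequence"][sequence_idx] = True
    return model.gen_state["newline_in_sequence"]

print(track_newlines(StubModel(), [[0, 1, 2], [0, 1, 1]]))  # -> [True, False]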
|