Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00

Commit: Model: Reformat and clean up
@@ -53,6 +53,7 @@ except ModuleNotFoundError as e:
 # scores for each token.
 LOG_SAMPLER_NO_EFFECT = False
 
+
 class HFTorchInferenceModel(HFInferenceModel):
     def __init__(
         self,
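The `LOG_SAMPLER_NO_EFFECT` flag above is only declared in this hunk; its consumer is not shown. As a hedged sketch of how such a debug flag is typically used (the helper name and its call sites are assumptions, not KoboldAI code):

import torch

LOG_SAMPLER_NO_EFFECT = False  # module-level debug flag, as in the hunk above

def _warn_if_no_effect(sampler_name: str, before: torch.Tensor, after: torch.Tensor) -> None:
    # Hypothetical helper: compare the score tensor before/after a sampler
    # runs and log when the sampler changed nothing, which usually signals
    # a setting with no effect (e.g. top_k=0 or temperature=1.0).
    if LOG_SAMPLER_NO_EFFECT and torch.equal(before, after):
        print(f"Sampler {sampler_name!r} had no effect on the scores")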
@@ -249,9 +250,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         ret = []
 
         for alt_phrase in [phrase, f" {phrase}"]:
-            ret.append(
-                model_self.tokenizer.encode(alt_phrase)
-            )
+            ret.append(model_self.tokenizer.encode(alt_phrase))
 
         return ret
 
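The simplification above keeps the behavior of encoding each phrase twice, with and without a leading space. BPE-style tokenizers assign different ids to the two spellings, so both must be collected. A minimal sketch of the effect, assuming the Hugging Face `transformers` GPT-2 tokenizer stands in for `model_self.tokenizer`:

# A minimal sketch, assuming the `transformers` GPT-2 tokenizer;
# KoboldAI's actual tokenizer object may differ.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

phrase = "dragon"
for alt_phrase in [phrase, f" {phrase}"]:
    # "dragon" and " dragon" encode to different token ids, so a bias or
    # ban list has to cover both spellings to catch the phrase mid-sentence.
    print(repr(alt_phrase), tokenizer.encode(alt_phrase))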
@@ -433,13 +432,6 @@ class HFTorchInferenceModel(HFInferenceModel):
                 *args,
                 **kwargs,
             ):
-                # sampler_order = utils.koboldai_vars.sampler_order[:]
-                # if (
-                #     len(sampler_order) < 7
-                # ):  # Add repetition penalty at beginning if it's not present
-                #     sampler_order = [6] + sampler_order
-                # for k in sampler_order:
-                #     scores = self.__warper_list[k](input_ids, scores, *args, **kwargs)
                 scores = model_self._apply_warpers(scores=scores, input_ids=input_ids)
                 visualize_probabilities(model_self, scores)
                 return scores
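The deleted comments preserve the old hand-rolled loop over `sampler_order`; the live code now delegates to `model_self._apply_warpers`. For illustration, the same score-warping pattern expressed with stock `transformers` warpers (the chain contents and ordering here are assumptions, not KoboldAI's `sampler_order`):

# Illustrative only: a generic warper chain in the style the old loop used.
# The warper classes are real `transformers` exports; which warpers appear
# and in what order is an assumption, not KoboldAI's configuration.
import torch
from transformers import (
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

warper_list = LogitsProcessorList(
    [
        TemperatureLogitsWarper(temperature=0.7),
        TopKLogitsWarper(top_k=40),
        TopPLogitsWarper(top_p=0.9),
    ]
)

input_ids = torch.tensor([[1, 2, 3]])  # (batch, seq_len)
scores = torch.randn(1, 50257)         # (batch, vocab_size) raw logits

# Each warper rewrites the scores in sequence, the same shape of computation
# as the commented-out `for k in sampler_order` loop over __warper_list.
scores = warper_list(input_ids, scores)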
@@ -469,7 +461,7 @@ class HFTorchInferenceModel(HFInferenceModel):
         gen_settings: GenerationSettings,
         single_line: bool = False,
         batch_count: int = 1,
-        **kwargs
+        **kwargs,
     ) -> GenerationResult:
         if not isinstance(prompt_tokens, torch.Tensor):
             gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]
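The trailing context lines show the prompt being promoted from a plain token list to a batched `torch.long` tensor. A standalone illustration of that `[None]` batch-dimension trick (token values are made up):

import torch

prompt_tokens = [15496, 11, 995]  # example token ids, values illustrative

if not isinstance(prompt_tokens, torch.Tensor):
    # [None] prepends a batch axis: shape (3,) becomes (1, 3), which is
    # the layout Hugging Face generate() expects for a single prompt.
    gen_in = torch.tensor(prompt_tokens, dtype=torch.long)[None]

print(gen_in.shape)  # torch.Size([1, 3])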