Mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Merge branch 'model-structure-and-maybe-rwkv' of https://github.com/one-some/KoboldAI into model-structure-and-maybe-rwkv
@@ -805,7 +805,7 @@ tags = [
 api_version = None # This gets set automatically so don't change this value
 
 api_v1 = KoboldAPISpec(
-    version="1.2.1",
+    version="1.2.2",
     prefixes=["/api/v1", "/api/latest"],
     tags=tags,
 )
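The bump from 1.2.1 to 1.2.2 advertises the new stop_sequence capability to API clients. A hedged sketch of gating on it from a client (the /api/v1/info/version route is assumed from the standard KoboldAI API; host and port are illustrative):

    import requests

    # Hedged sketch: check the server's API version before using stop_sequence.
    # The /api/v1/info/version route is an assumption; host/port are illustrative.
    version = requests.get("http://localhost:5000/api/v1/info/version").json()["result"]
    if tuple(map(int, version.split("."))) >= (1, 2, 2):
        print("stop_sequence is supported")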
@@ -1169,6 +1169,8 @@ def processsettings(js):
         koboldai_vars.nogenmod = js["nogenmod"]
     if("fulldeterminism" in js):
         koboldai_vars.full_determinism = js["fulldeterminism"]
+    if("stop_sequence" in js):
+        koboldai_vars.stop_sequence = js["stop_sequence"]
     if("autosave" in js):
         koboldai_vars.autosave = js["autosave"]
     if("newlinemode" in js):
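For illustration, a settings dict with the newly handled key, as processsettings would receive it (key names come from the hunk above; the values are invented):

    # Illustrative only: values are invented, key names match the parser above.
    js = {
        "fulldeterminism": False,
        "stop_sequence": ["\nYou:", "***"],
        "autosave": True,
    }
    # processsettings(js) now copies stop_sequence onto koboldai_vars as well.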
@@ -8281,6 +8283,7 @@ class GenerationInputSchema(SamplerSettingsSchema):
     sampler_order: Optional[List[int]] = fields.List(fields.Integer(), validate=[validate.Length(min=6), permutation_validator], metadata={"description": "Sampler order to be used. If N is the length of this array, then N must be greater than or equal to 6 and the array must be a permutation of the first N non-negative integers."})
     sampler_seed: Optional[int] = fields.Integer(validate=validate.Range(min=0, max=2**64 - 1), metadata={"description": "RNG seed to use for sampling. If not specified, the global RNG will be used."})
     sampler_full_determinism: Optional[bool] = fields.Boolean(metadata={"description": "If enabled, the generated text will always be the same as long as you use the same RNG seed, input and settings. If disabled, only the *sequence* of generated texts that you get when repeatedly generating text will be the same given the same RNG seed, input and settings."})
+    stop_sequence: Optional[List[str]] = fields.List(fields.String(), metadata={"description": "An array of string sequences where the API will stop generating further tokens. The returned text WILL contain the stop sequence."}, validate=[validate.Length(max=10)])
 
 class GenerationResultSchema(KoboldSchema):
     text: str = fields.String(required=True, metadata={"description": "Generated output as plain text."})
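With the schema field in place, clients can pass stop sequences on the generate route. A minimal usage sketch (assuming a server on localhost:5000 and the standard /api/v1/generate route; the prompt and values are invented):

    import requests

    # Hedged usage sketch: stop_sequence takes at most 10 strings per the
    # validator above, and the returned text WILL contain the stop sequence.
    payload = {
        "prompt": "You are standing in a dark cave.",
        "max_length": 80,
        "stop_sequence": ["\nYou:", "The End"],
    }
    r = requests.post("http://localhost:5000/api/v1/generate", json=payload)
    print(r.json()["results"][0]["text"])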
@@ -8422,6 +8425,7 @@ def _generate_text(body: GenerationInputSchema):
         "quiet": ("koboldai_vars", "quiet", None),
         "sampler_order": ("koboldai_vars", "sampler_order", None),
         "sampler_full_determinism": ("koboldai_vars", "full_determinism", None),
+        "stop_sequence": ("koboldai_vars", "stop_sequence", None),
     }
     saved_settings = {}
     set_aibusy(1)
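Each entry in this table maps a request field to a (settings object, attribute, transform) triple that the endpoint applies for the duration of the request. A hedged sketch of that save/apply/restore pattern (the helper names are invented, not the actual implementation):

    def apply_settings(mapping, body, koboldai_vars, saved_settings):
        # For each request field present on the body, remember the old value,
        # then overwrite the named attribute on koboldai_vars.
        for field, (_obj, attr, transform) in mapping.items():
            if not hasattr(body, field):
                continue
            saved_settings[attr] = getattr(koboldai_vars, attr)
            value = getattr(body, field)
            setattr(koboldai_vars, attr, transform(value) if transform else value)

    def restore_settings(koboldai_vars, saved_settings):
        # Put every overridden attribute back once generation finishes.
        for attr, value in saved_settings.items():
            setattr(koboldai_vars, attr, value)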
@@ -874,6 +874,7 @@ class story_settings(settings):
         self.chatmode = False
         self.chatname = "You"
         self.botname = "Bot"
+        self.stop_sequence = []  # used for configuring stop sequences
         self.adventure = False
         self.actionmode = 0
         self.storymode = 0
@@ -71,6 +71,7 @@ class HFTorchInferenceModel(HFInferenceModel):
             Stoppers.dynamic_wi_scanner,
             Stoppers.singleline_stopper,
             Stoppers.chat_mode_stopper,
+            Stoppers.stop_sequence_stopper,
         ]
 
         self.capabilties = ModelCapabilities(
@@ -72,6 +72,7 @@ class RWKVInferenceModel(InferenceModel):
             Stoppers.dynamic_wi_scanner,
             Stoppers.singleline_stopper,
             Stoppers.chat_mode_stopper,
+            Stoppers.stop_sequence_stopper,
         ]
 
         self.capabilties = ModelCapabilities(
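Both backends now register the same ordered stopper list. Judging by the signatures, each stopper is a callable that receives the model and the tokens generated so far and returns True to halt. A hedged sketch of how such hooks could be polled each step (the stopper_hooks attribute name and the loop are assumptions):

    def should_stop(model, input_ids) -> bool:
        # Illustrative polling loop; the attribute name is an assumption.
        for stopper in model.stopper_hooks:
            if stopper(model, input_ids):
                return True  # any single stopper may end generation
        return False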
@@ -116,6 +116,33 @@ class Stoppers:
                 return True
         return False
 
+    @staticmethod
+    def stop_sequence_stopper(
+        model: InferenceModel,
+        input_ids: torch.LongTensor,
+    ) -> bool:
+        data = [model.tokenizer.decode(x) for x in input_ids]
+        # null_character = model.tokenizer.encode(chr(0))[0]
+        if "completed" not in model.gen_state:
+            model.gen_state["completed"] = [False] * len(input_ids)
+
+        # One issue is that the stop sequence may not actually align with the
+        # end of a token if it is a subsection of a longer token.
+        for stopper in utils.koboldai_vars.stop_sequence:
+            for i in range(len(input_ids)):
+                if data[i][-len(stopper):] == stopper:
+                    model.gen_state["completed"][i] = True
+
+        if all(model.gen_state["completed"]):
+            utils.koboldai_vars.generated_tkns = utils.koboldai_vars.genamt
+            del model.gen_state["completed"]
+            return True
+        return False
+
     @staticmethod
     def singleline_stopper(
         model: InferenceModel,
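At its core the new stopper is a decoded-suffix comparison per batch member, with completion state persisting across steps in gen_state. A standalone sketch of that logic, runnable without a model (the function and variable names are invented for illustration):

    def check_stop(decoded_batch, stop_sequences, completed):
        # Mark a sequence completed once its decoded text ends with any stop string.
        for stopper in stop_sequences:
            for i, text in enumerate(decoded_batch):
                if text.endswith(stopper):
                    completed[i] = True
        # Halt only when every sequence in the batch has completed.
        return all(completed)

    completed = [False, False]
    print(check_stop(["...ahead.\nYou:", "...onward"], ["\nYou:"], completed))      # False
    print(check_stop(["...ahead.\nYou:", "...done\nYou:"], ["\nYou:"], completed))  # True

As the in-line comment warns, the suffix check can miss a stop string embedded inside a single longer token, since decoding only ever ends on a token boundary.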