mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge pull request #473 from ebolam/united
Abort code moved to model classes
This commit is contained in:
20
aiserver.py
20
aiserver.py
@@ -2654,9 +2654,8 @@ def get_message(msg):
|
||||
if(koboldai_vars.mode == "play"):
|
||||
if(koboldai_vars.aibusy):
|
||||
if(msg.get('allowabort', False)):
|
||||
koboldai_vars.abort = True
|
||||
model.abort_generation()
|
||||
return
|
||||
koboldai_vars.abort = False
|
||||
koboldai_vars.lua_koboldbridge.feedback = None
|
||||
if(koboldai_vars.chatmode):
|
||||
if(type(msg['chatname']) is not str):
|
||||
@@ -2676,9 +2675,8 @@ def get_message(msg):
|
||||
elif(msg['cmd'] == 'retry'):
|
||||
if(koboldai_vars.aibusy):
|
||||
if(msg.get('allowabort', False)):
|
||||
koboldai_vars.abort = True
|
||||
model.abort_generation()
|
||||
return
|
||||
koboldai_vars.abort = False
|
||||
if(koboldai_vars.chatmode):
|
||||
if(type(msg['chatname']) is not str):
|
||||
raise ValueError("Chatname must be a string")
|
||||
@@ -3344,7 +3342,7 @@ def actionsubmit(
|
||||
# Clear the startup text from game screen
|
||||
emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True, room="UI_1")
|
||||
calcsubmit("", gen_mode=gen_mode) # Run the first action through the generator
|
||||
if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
|
||||
if(not model.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
|
||||
data = ""
|
||||
force_submit = True
|
||||
disable_recentrng = True
|
||||
@@ -3370,13 +3368,13 @@ def actionsubmit(
|
||||
refresh_story()
|
||||
if(len(koboldai_vars.actions) > 0):
|
||||
emit('from_server', {'cmd': 'texteffect', 'data': koboldai_vars.actions.get_last_key() + 1}, broadcast=True, room="UI_1")
|
||||
if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None):
|
||||
if(not model.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None):
|
||||
data = ""
|
||||
force_submit = True
|
||||
disable_recentrng = True
|
||||
continue
|
||||
else:
|
||||
if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and koboldai_vars.lua_koboldbridge.restart_sequence > 0):
|
||||
if(not model.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and koboldai_vars.lua_koboldbridge.restart_sequence > 0):
|
||||
genresult(genout[koboldai_vars.lua_koboldbridge.restart_sequence-1]["generated_text"], flash=False)
|
||||
refresh_story()
|
||||
data = ""
|
||||
@@ -3410,7 +3408,7 @@ def actionsubmit(
|
||||
if(not no_generate and not koboldai_vars.noai and koboldai_vars.lua_koboldbridge.generating):
|
||||
# Off to the tokenizer!
|
||||
calcsubmit("", gen_mode=gen_mode)
|
||||
if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
|
||||
if(not model.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
|
||||
data = ""
|
||||
force_submit = True
|
||||
disable_recentrng = True
|
||||
@@ -3431,13 +3429,13 @@ def actionsubmit(
|
||||
genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]
|
||||
if(len(genout) == 1):
|
||||
genresult(genout[0]["generated_text"])
|
||||
if(not no_generate and not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None):
|
||||
if(not no_generate and not model.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None):
|
||||
data = ""
|
||||
force_submit = True
|
||||
disable_recentrng = True
|
||||
continue
|
||||
else:
|
||||
if(not no_generate and not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and koboldai_vars.lua_koboldbridge.restart_sequence > 0):
|
||||
if(not no_generate and not model.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and koboldai_vars.lua_koboldbridge.restart_sequence > 0):
|
||||
genresult(genout[koboldai_vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
|
||||
data = ""
|
||||
force_submit = True
|
||||
@@ -6204,7 +6202,7 @@ def UI_2_submit(data):
|
||||
def UI_2_abort(data):
|
||||
if koboldai_vars.debug:
|
||||
print("got abort")
|
||||
koboldai_vars.abort = True
|
||||
model.abort_generation()
|
||||
|
||||
|
||||
#==================================================================#
|
||||
|
@@ -171,6 +171,7 @@ class InferenceModel:
|
||||
"""Root class for all models."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.abort = False
|
||||
self.gen_state = {}
|
||||
self.post_token_hooks = []
|
||||
self.stopper_hooks = []
|
||||
@@ -669,6 +670,9 @@ class InferenceModel:
|
||||
for hook in self.post_token_hooks:
|
||||
hook(self, input_ids)
|
||||
|
||||
def abort_generation(self, abort=True):
|
||||
self.abort=abort
|
||||
|
||||
def get_supported_gen_modes(self) -> List[GenerationMode]:
|
||||
"""Returns a list of compatible `GenerationMode`s for the current model.
|
||||
|
||||
|
@@ -29,19 +29,22 @@ class model_backend(InferenceModel):
|
||||
super().__init__()
|
||||
self.url = "https://horde.koboldai.net"
|
||||
self.key = "0000000000"
|
||||
self.models = self.get_cluster_models()
|
||||
self.models = []
|
||||
self.model_name = "Horde"
|
||||
self.model = []
|
||||
self.request_id = None
|
||||
|
||||
|
||||
# Do not allow API to be served over the API
|
||||
self.capabilties = ModelCapabilities(api_host=False)
|
||||
|
||||
def is_valid(self, model_name, model_path, menu_path):
|
||||
self.models = self.get_cluster_models()
|
||||
logger.debug("Horde Models: {}".format(self.models))
|
||||
return model_name == "CLUSTER" or model_name in [x['value'] for x in self.models]
|
||||
|
||||
def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}):
|
||||
self.models = self.get_cluster_models()
|
||||
if os.path.exists("settings/horde.model_backend.settings") and 'base_url' not in vars(self):
|
||||
with open("settings/horde.model_backend.settings", "r") as f:
|
||||
temp = json.load(f)
|
||||
@@ -222,18 +225,18 @@ class model_backend(InferenceModel):
|
||||
logger.error(errmsg)
|
||||
raise HordeException(errmsg)
|
||||
|
||||
request_id = req_status["id"]
|
||||
logger.debug("Horde Request ID: {}".format(request_id))
|
||||
self.request_id = req_status["id"]
|
||||
logger.debug("Horde Request ID: {}".format(self.request_id))
|
||||
|
||||
# We've sent the request and got the ID back, now we need to watch it to see when it finishes
|
||||
finished = False
|
||||
self.finished = False
|
||||
|
||||
cluster_agent_headers = {"Client-Agent": client_agent}
|
||||
|
||||
while not finished:
|
||||
while not self.finished:
|
||||
try:
|
||||
req = requests.get(
|
||||
f"{self.url}/api/v2/generate/text/status/{request_id}",
|
||||
f"{self.url}/api/v2/generate/text/status/{self.request_id}",
|
||||
headers=cluster_agent_headers,
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
@@ -260,15 +263,16 @@ class model_backend(InferenceModel):
|
||||
logger.error(errmsg)
|
||||
raise HordeException(errmsg)
|
||||
|
||||
finished = req_status["done"]
|
||||
self.finished = req_status["done"]
|
||||
utils.koboldai_vars.horde_wait_time = req_status["wait_time"]
|
||||
utils.koboldai_vars.horde_queue_position = req_status["queue_position"]
|
||||
utils.koboldai_vars.horde_queue_size = req_status["waiting"]
|
||||
|
||||
if not finished:
|
||||
if not self.finished:
|
||||
logger.debug(req_status)
|
||||
time.sleep(1)
|
||||
|
||||
self.request_id = None
|
||||
logger.debug("Last Horde Status Message: {}".format(req_status))
|
||||
|
||||
if req_status["faulted"]:
|
||||
@@ -287,3 +291,33 @@ class model_backend(InferenceModel):
|
||||
is_whole_generation=True,
|
||||
single_line=single_line,
|
||||
)
|
||||
|
||||
def abort_generation(self, abort=True):
|
||||
logger.info("Attempting to stop horde gen")
|
||||
self.finished = True
|
||||
try:
|
||||
# Create request
|
||||
client_agent = "KoboldAI:2.0.0:koboldai.org"
|
||||
cluster_headers = {
|
||||
"apikey": self.key,
|
||||
"Client-Agent": client_agent,
|
||||
}
|
||||
req = requests.delete(
|
||||
f"{self.url}/v2/generate/text/status/{self.request_id}",
|
||||
headers=cluster_headers,
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
errmsg = f"Horde unavailable. Please try again later"
|
||||
logger.error(errmsg)
|
||||
raise HordeException(errmsg)
|
||||
|
||||
if req.status_code == 503:
|
||||
errmsg = f"KoboldAI API Error: No available KoboldAI servers found in Horde to fulfil this request using the selected models or other properties."
|
||||
logger.error(errmsg)
|
||||
raise HordeException(errmsg)
|
||||
elif not req.ok:
|
||||
errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
|
||||
logger.error(req.url)
|
||||
logger.error(errmsg)
|
||||
logger.error(req.text)
|
||||
raise HordeException(errmsg)
|
@@ -26,11 +26,11 @@ class Stoppers:
|
||||
f"Inconsistency detected between KoboldAI Python and Lua backends ({utils.koboldai_vars.generated_tkns} != {utils.koboldai_vars.lua_koboldbridge.generated_cols})"
|
||||
)
|
||||
|
||||
if utils.koboldai_vars.abort or (
|
||||
if model.abort or (
|
||||
utils.koboldai_vars.inference_config.stop_at_genamt
|
||||
and utils.koboldai_vars.generated_tkns >= utils.koboldai_vars.genamt
|
||||
):
|
||||
utils.koboldai_vars.abort = False
|
||||
model.abort = False
|
||||
model.gen_state["regeneration_required"] = False
|
||||
model.gen_state["halt"] = False
|
||||
return True
|
||||
|
@@ -100,7 +100,7 @@
|
||||
</div>
|
||||
</div><br>
|
||||
<div class="statusbar_outer_horde var_sync_alt_system_aibusy var_sync_alt_model_horde_wait_time" id="status_bar_horde">
|
||||
<div class="statusbar_inner_horde" style="width:100%">
|
||||
<div class="statusbar_inner_horde" style="width:100%" onclick="socket.emit('abort','');">
|
||||
<div> </div>
|
||||
<div>Queue <span class="var_sync_model_horde_queue_position"></span> of <span class="var_sync_model_horde_queue_size"></span></div>
|
||||
<div><span class="var_sync_model_horde_wait_time"></span> sec left</div>
|
||||
|
Reference in New Issue
Block a user