Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Merge pull request #473 from ebolam/united
Abort code moved to model classes
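In outline: the global koboldai_vars.abort flag moves onto the model object. InferenceModel gains an abort attribute and an abort_generation() setter, callers in aiserver.py switch to model.abort_generation() and model.abort, the core stopper consumes model.abort, and the Horde backend overrides the hook to also cancel its remote request. Below is a minimal, self-contained sketch of the resulting contract; the driving loop and the one-argument core_stopper signature are illustrative assumptions, while the flag, the setter, and the halt logic follow the diff.

    # Minimal sketch of the abort contract introduced by this PR.
    # The dummy generation loop is illustrative only, not KoboldAI's
    # real loop; the flag/setter/stopper shape mirrors the diff.

    class InferenceModel:
        def __init__(self) -> None:
            self.abort = False          # per-model abort flag (new in this PR)
            self.gen_state = {}
            self.stopper_hooks = []

        def abort_generation(self, abort=True):
            # Base implementation just flips the flag; backends such as
            # Horde override this to also cancel the remote request.
            self.abort = abort


    def core_stopper(model):
        # Mirrors the updated Stoppers logic: consume the flag and halt.
        if model.abort:
            model.abort = False
            model.gen_state["regeneration_required"] = False
            model.gen_state["halt"] = False
            return True
        return False


    model = InferenceModel()
    model.stopper_hooks.append(core_stopper)

    for step in range(10):
        if step == 3:
            model.abort_generation()    # e.g. triggered by the UI 'abort' event
        if any(stopper(model) for stopper in model.stopper_hooks):
            print(f"generation halted at step {step}")
            break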
aiserver.py (20 changes)

@@ -2654,9 +2654,8 @@ def get_message(msg):
         if(koboldai_vars.mode == "play"):
             if(koboldai_vars.aibusy):
                 if(msg.get('allowabort', False)):
-                    koboldai_vars.abort = True
+                    model.abort_generation()
                 return
-            koboldai_vars.abort = False
             koboldai_vars.lua_koboldbridge.feedback = None
             if(koboldai_vars.chatmode):
                 if(type(msg['chatname']) is not str):
@@ -2676,9 +2675,8 @@ def get_message(msg):
     elif(msg['cmd'] == 'retry'):
         if(koboldai_vars.aibusy):
             if(msg.get('allowabort', False)):
-                koboldai_vars.abort = True
+                model.abort_generation()
             return
-        koboldai_vars.abort = False
         if(koboldai_vars.chatmode):
             if(type(msg['chatname']) is not str):
                 raise ValueError("Chatname must be a string")
@@ -3344,7 +3342,7 @@ def actionsubmit(
                 # Clear the startup text from game screen
                 emit('from_server', {'cmd': 'updatescreen', 'gamestarted': False, 'data': 'Please wait, generating story...'}, broadcast=True, room="UI_1")
                 calcsubmit("", gen_mode=gen_mode) # Run the first action through the generator
-                if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
+                if(not model.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
                     data = ""
                     force_submit = True
                     disable_recentrng = True
@@ -3370,13 +3368,13 @@ def actionsubmit(
                     refresh_story()
                     if(len(koboldai_vars.actions) > 0):
                         emit('from_server', {'cmd': 'texteffect', 'data': koboldai_vars.actions.get_last_key() + 1}, broadcast=True, room="UI_1")
-                    if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None):
+                    if(not model.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None):
                         data = ""
                         force_submit = True
                         disable_recentrng = True
                         continue
                 else:
-                    if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and koboldai_vars.lua_koboldbridge.restart_sequence > 0):
+                    if(not model.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and koboldai_vars.lua_koboldbridge.restart_sequence > 0):
                         genresult(genout[koboldai_vars.lua_koboldbridge.restart_sequence-1]["generated_text"], flash=False)
                     refresh_story()
                     data = ""
@@ -3410,7 +3408,7 @@ def actionsubmit(
             if(not no_generate and not koboldai_vars.noai and koboldai_vars.lua_koboldbridge.generating):
                 # Off to the tokenizer!
                 calcsubmit("", gen_mode=gen_mode)
-                if(not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
+                if(not model.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and len(koboldai_vars.genseqs) == 0):
                     data = ""
                     force_submit = True
                     disable_recentrng = True
@@ -3431,13 +3429,13 @@ def actionsubmit(
                 genout = [{"generated_text": x['text']} for x in koboldai_vars.actions.get_current_options()]
                 if(len(genout) == 1):
                     genresult(genout[0]["generated_text"])
-                    if(not no_generate and not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None):
+                    if(not no_generate and not model.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None):
                         data = ""
                         force_submit = True
                         disable_recentrng = True
                         continue
                 else:
-                    if(not no_generate and not koboldai_vars.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and koboldai_vars.lua_koboldbridge.restart_sequence > 0):
+                    if(not no_generate and not model.abort and koboldai_vars.lua_koboldbridge.restart_sequence is not None and koboldai_vars.lua_koboldbridge.restart_sequence > 0):
                         genresult(genout[koboldai_vars.lua_koboldbridge.restart_sequence-1]["generated_text"])
                         data = ""
                         force_submit = True
@@ -6204,7 +6202,7 @@ def UI_2_submit(data):
 def UI_2_abort(data):
     if koboldai_vars.debug:
         print("got abort")
-    koboldai_vars.abort = True
+    model.abort_generation()
 
 
 #==================================================================#
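UI_2_abort is the socket handler behind the UI's abort control; with this change it delegates to the model instead of touching the global flag, and the old koboldai_vars.abort = False reset in get_message is dropped because the stopper now clears model.abort itself. A runnable sketch of the server-side path with stand-in globals follows; the @socketio.on('abort') registration is an assumption inferred from the UI_2_* handler naming convention, and only the handler body comes from the diff.

    # Server-side abort path, sketched with stand-ins so it runs alone.
    from flask import Flask
    from flask_socketio import SocketIO

    app = Flask(__name__)
    socketio = SocketIO(app)

    class _Vars:                       # stand-in for koboldai_vars
        debug = True

    class _Model:                      # stand-in for the loaded InferenceModel
        abort = False
        def abort_generation(self, abort=True):
            self.abort = abort

    koboldai_vars = _Vars()
    model = _Model()

    @socketio.on('abort')              # assumed registration for UI_2_abort
    def UI_2_abort(data):
        if koboldai_vars.debug:
            print("got abort")
        model.abort_generation()       # was: koboldai_vars.abort = True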
@@ -171,6 +171,7 @@ class InferenceModel:
     """Root class for all models."""
 
     def __init__(self) -> None:
+        self.abort = False
         self.gen_state = {}
         self.post_token_hooks = []
         self.stopper_hooks = []
@@ -669,6 +670,9 @@ class InferenceModel:
         for hook in self.post_token_hooks:
             hook(self, input_ids)
 
+    def abort_generation(self, abort=True):
+        self.abort=abort
+
     def get_supported_gen_modes(self) -> List[GenerationMode]:
         """Returns a list of compatible `GenerationMode`s for the current model.
 
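The base hook is deliberately trivial, which lets backends replace it wholesale: a local model only needs the flag flipped, while a remote backend can tear down its outstanding request as well, as the Horde class below does. A sketch of the override pattern; RemoteBackend and _cancel_remote_request are hypothetical names, and note that the actual Horde override does not call super(), it sets its own self.finished instead.

    # Override pattern enabled by the new base hook (sketch only).

    class InferenceModel:                      # abridged base, as in the diff
        def __init__(self) -> None:
            self.abort = False

        def abort_generation(self, abort=True):
            self.abort = abort


    class RemoteBackend(InferenceModel):       # hypothetical subclass
        def _cancel_remote_request(self) -> None:
            print("cancelling remote job")     # stand-in for real teardown

        def abort_generation(self, abort=True):
            self._cancel_remote_request()      # backend-specific cancellation
            super().abort_generation(abort)    # keep the shared flag in sync


    RemoteBackend().abort_generation()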
@@ -29,19 +29,22 @@ class model_backend(InferenceModel):
         super().__init__()
         self.url = "https://horde.koboldai.net"
         self.key = "0000000000"
-        self.models = self.get_cluster_models()
+        self.models = []
         self.model_name = "Horde"
         self.model = []
+        self.request_id = None
 
 
         # Do not allow API to be served over the API
         self.capabilties = ModelCapabilities(api_host=False)
 
     def is_valid(self, model_name, model_path, menu_path):
+        self.models = self.get_cluster_models()
         logger.debug("Horde Models: {}".format(self.models))
         return model_name == "CLUSTER" or model_name in [x['value'] for x in self.models]
 
     def get_requested_parameters(self, model_name, model_path, menu_path, parameters = {}):
+        self.models = self.get_cluster_models()
         if os.path.exists("settings/horde.model_backend.settings") and 'base_url' not in vars(self):
             with open("settings/horde.model_backend.settings", "r") as f:
                 temp = json.load(f)
@@ -222,18 +225,18 @@ class model_backend(InferenceModel):
             logger.error(errmsg)
             raise HordeException(errmsg)
 
-        request_id = req_status["id"]
-        logger.debug("Horde Request ID: {}".format(request_id))
+        self.request_id = req_status["id"]
+        logger.debug("Horde Request ID: {}".format(self.request_id))
 
         # We've sent the request and got the ID back, now we need to watch it to see when it finishes
-        finished = False
+        self.finished = False
 
         cluster_agent_headers = {"Client-Agent": client_agent}
 
-        while not finished:
+        while not self.finished:
             try:
                 req = requests.get(
-                    f"{self.url}/api/v2/generate/text/status/{request_id}",
+                    f"{self.url}/api/v2/generate/text/status/{self.request_id}",
                     headers=cluster_agent_headers,
                 )
             except requests.exceptions.ConnectionError:
@@ -260,15 +263,16 @@ class model_backend(InferenceModel):
             logger.error(errmsg)
             raise HordeException(errmsg)
 
-            finished = req_status["done"]
+            self.finished = req_status["done"]
             utils.koboldai_vars.horde_wait_time = req_status["wait_time"]
             utils.koboldai_vars.horde_queue_position = req_status["queue_position"]
             utils.koboldai_vars.horde_queue_size = req_status["waiting"]
 
-            if not finished:
+            if not self.finished:
                 logger.debug(req_status)
                 time.sleep(1)
 
+        self.request_id = None
         logger.debug("Last Horde Status Message: {}".format(req_status))
 
         if req_status["faulted"]:
@@ -287,3 +291,33 @@ class model_backend(InferenceModel):
             is_whole_generation=True,
             single_line=single_line,
         )
+
+    def abort_generation(self, abort=True):
+        logger.info("Attempting to stop horde gen")
+        self.finished = True
+        try:
+            # Create request
+            client_agent = "KoboldAI:2.0.0:koboldai.org"
+            cluster_headers = {
+                "apikey": self.key,
+                "Client-Agent": client_agent,
+            }
+            req = requests.delete(
+                f"{self.url}/v2/generate/text/status/{self.request_id}",
+                headers=cluster_headers,
+            )
+        except requests.exceptions.ConnectionError:
+            errmsg = f"Horde unavailable. Please try again later"
+            logger.error(errmsg)
+            raise HordeException(errmsg)
+
+        if req.status_code == 503:
+            errmsg = f"KoboldAI API Error: No available KoboldAI servers found in Horde to fulfil this request using the selected models or other properties."
+            logger.error(errmsg)
+            raise HordeException(errmsg)
+        elif not req.ok:
+            errmsg = f"KoboldAI API Error: Failed to get a standard reply from the Horde. Please check the console."
+            logger.error(req.url)
+            logger.error(errmsg)
+            logger.error(req.text)
+            raise HordeException(errmsg)
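The Horde override works in two steps: request_id is now kept on self (and cleared once polling ends) so the abort hook can reach it, and abort_generation() first sets self.finished = True so the polling loop exits even if the cancellation request fails, then sends an HTTP DELETE to the status endpoint to ask the cluster to drop the queued job. One caveat: the DELETE path in the diff omits the /api prefix that the polling GET uses, and whether the cluster serves both paths cannot be verified from the diff alone, so the standalone sketch below assumes the /api/v2 base.

    # Standalone sketch of cancelling a queued Horde text job. Endpoint
    # shape and headers follow the diff; API_BASE is an assumption (the
    # polling GET in this backend uses the /api/v2 prefix).
    import requests

    API_BASE = "https://horde.koboldai.net/api/v2"

    def cancel_horde_job(request_id: str, apikey: str = "0000000000") -> None:
        headers = {
            "apikey": apikey,
            "Client-Agent": "KoboldAI:2.0.0:koboldai.org",
        }
        # DELETE on the status endpoint asks the cluster to drop the job.
        req = requests.delete(
            f"{API_BASE}/generate/text/status/{request_id}",
            headers=headers,
        )
        req.raise_for_status()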
@@ -26,11 +26,11 @@ class Stoppers:
                 f"Inconsistency detected between KoboldAI Python and Lua backends ({utils.koboldai_vars.generated_tkns} != {utils.koboldai_vars.lua_koboldbridge.generated_cols})"
             )
 
-        if utils.koboldai_vars.abort or (
+        if model.abort or (
             utils.koboldai_vars.inference_config.stop_at_genamt
             and utils.koboldai_vars.generated_tkns >= utils.koboldai_vars.genamt
         ):
-            utils.koboldai_vars.abort = False
+            model.abort = False
             model.gen_state["regeneration_required"] = False
             model.gen_state["halt"] = False
             return True
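With this change the core stopper halts on either condition through one check: a manual abort (model.abort) or the token budget (generated_tkns reaching genamt while stop_at_genamt is set), and clearing the consumed flag now happens on the model rather than on koboldai_vars. The same predicate as a free function over stand-in objects, illustrative only, with names following the diff:

    # Combined halt predicate from the updated core_stopper, rewritten
    # as a free function over stand-in classes (illustrative only).

    class KVars:                       # stand-in for utils.koboldai_vars
        generated_tkns = 80
        genamt = 80
        class inference_config:
            stop_at_genamt = True

    class Model:                       # stand-in for the loaded model
        abort = False
        gen_state = {}

    def should_halt(model, kvars) -> bool:
        if model.abort or (
            kvars.inference_config.stop_at_genamt
            and kvars.generated_tkns >= kvars.genamt
        ):
            model.abort = False        # consume the abort flag
            model.gen_state["regeneration_required"] = False
            model.gen_state["halt"] = False
            return True
        return False

    assert should_halt(Model(), KVars())   # budget reached -> halt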
@@ -100,7 +100,7 @@
         </div>
     </div><br>
     <div class="statusbar_outer_horde var_sync_alt_system_aibusy var_sync_alt_model_horde_wait_time" id="status_bar_horde">
-        <div class="statusbar_inner_horde" style="width:100%">
+        <div class="statusbar_inner_horde" style="width:100%" onclick="socket.emit('abort','');">
             <div> </div>
             <div>Queue <span class="var_sync_model_horde_queue_position"></span> of <span class="var_sync_model_horde_queue_size"></span></div>
             <div><span class="var_sync_model_horde_wait_time"></span> sec left</div>
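On the UI side the only change is that clicking the Horde status bar now emits the abort socket event, which lands in the UI_2_abort handler shown earlier. For testing, the same event can be sent from a python-socketio client; the URL and port are assumptions (aiserver.py's conventional default), while the event name and empty payload mirror the HTML's socket.emit('abort','').

    # Emit the same 'abort' event the status bar sends (test sketch).
    import socketio

    sio = socketio.Client()
    sio.connect("http://localhost:5000")   # assumed default KoboldAI port
    sio.emit("abort", "")                  # matches socket.emit('abort','')
    sio.disconnect()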