From 6da7a9629ad9c5ae2b25415e12174addc6b3b545 Mon Sep 17 00:00:00 2001 From: somebody Date: Wed, 19 Jul 2023 13:01:07 -0500 Subject: [PATCH] API: Fix /story/load --- aiserver.py | 10 +++++++--- koboldai_settings.py | 15 +++++++++------ 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/aiserver.py b/aiserver.py index 2278015c..153e6d07 100644 --- a/aiserver.py +++ b/aiserver.py @@ -5130,9 +5130,13 @@ def load_story_v1(js, from_file=None): def load_story_v2(js, from_file=None): logger.debug("Loading V2 Story") logger.debug("Called from {}".format(inspect.stack()[1].function)) - leave_room(session['story']) - session['story'] = js['story_name'] - join_room(session['story']) + + new_story = js["story_name"] + # In socket context + if hasattr(request, "sid"): + leave_room(session['story']) + join_room(new_story) + session['story'] = new_story koboldai_vars.load_story(session['story'], js) diff --git a/koboldai_settings.py b/koboldai_settings.py index ebd8c019..3bc0eb86 100644 --- a/koboldai_settings.py +++ b/koboldai_settings.py @@ -6,7 +6,7 @@ import os, re, time, threading, json, pickle, base64, copy, tqdm, datetime, sys import shutil from typing import List, Union from io import BytesIO -from flask import has_request_context, session +from flask import has_request_context, session, request from flask_socketio import join_room, leave_room from collections import OrderedDict import multiprocessing @@ -130,11 +130,14 @@ class koboldai_vars(object): original_story_name = story_name if not multi_story: story_name = 'default' - #Leave the old room and join the new one - logger.debug("Leaving room {}".format(session['story'])) - leave_room(session['story']) - logger.debug("Joining room {}".format(story_name)) - join_room(story_name) + + # Leave the old room and join the new one if in socket context + if hasattr(request, "sid"): + logger.debug("Leaving room {}".format(session['story'])) + leave_room(session['story']) + logger.debug("Joining room {}".format(story_name)) + join_room(story_name) + session['story'] = story_name logger.debug("Sending story reset") self._story_settings[story_name]._socketio.emit("reset_story", {}, broadcast=True, room=story_name)