API: Fix /story/load

This commit is contained in:
somebody
2023-07-19 13:01:07 -05:00
parent b9b3cd3aba
commit 6da7a9629a
2 changed files with 16 additions and 9 deletions

View File

@@ -5130,9 +5130,13 @@ def load_story_v1(js, from_file=None):
def load_story_v2(js, from_file=None): def load_story_v2(js, from_file=None):
logger.debug("Loading V2 Story") logger.debug("Loading V2 Story")
logger.debug("Called from {}".format(inspect.stack()[1].function)) logger.debug("Called from {}".format(inspect.stack()[1].function))
leave_room(session['story'])
session['story'] = js['story_name'] new_story = js["story_name"]
join_room(session['story']) # In socket context
if hasattr(request, "sid"):
leave_room(session['story'])
join_room(new_story)
session['story'] = new_story
koboldai_vars.load_story(session['story'], js) koboldai_vars.load_story(session['story'], js)

View File

@@ -6,7 +6,7 @@ import os, re, time, threading, json, pickle, base64, copy, tqdm, datetime, sys
import shutil import shutil
from typing import List, Union from typing import List, Union
from io import BytesIO from io import BytesIO
from flask import has_request_context, session from flask import has_request_context, session, request
from flask_socketio import join_room, leave_room from flask_socketio import join_room, leave_room
from collections import OrderedDict from collections import OrderedDict
import multiprocessing import multiprocessing
@@ -130,11 +130,14 @@ class koboldai_vars(object):
original_story_name = story_name original_story_name = story_name
if not multi_story: if not multi_story:
story_name = 'default' story_name = 'default'
#Leave the old room and join the new one
logger.debug("Leaving room {}".format(session['story'])) # Leave the old room and join the new one if in socket context
leave_room(session['story']) if hasattr(request, "sid"):
logger.debug("Joining room {}".format(story_name)) logger.debug("Leaving room {}".format(session['story']))
join_room(story_name) leave_room(session['story'])
logger.debug("Joining room {}".format(story_name))
join_room(story_name)
session['story'] = story_name session['story'] = story_name
logger.debug("Sending story reset") logger.debug("Sending story reset")
self._story_settings[story_name]._socketio.emit("reset_story", {}, broadcast=True, room=story_name) self._story_settings[story_name]._socketio.emit("reset_story", {}, broadcast=True, room=story_name)