diff --git a/aiserver.py b/aiserver.py index 5732128f..ad6e7fd6 100644 --- a/aiserver.py +++ b/aiserver.py @@ -1149,8 +1149,12 @@ def get_model_info(model, directory=""): breakmodel = False else: breakmodel = True - if path.exists("settings/{}.breakmodel".format(model.replace("/", "_"))): - with open("settings/{}.breakmodel".format(model.replace("/", "_")), "r") as file: + if model in ["NeoCustom", "GPT2Custom"]: + filename = "settings/{}.breakmodel".format(os.path.basename(os.path.normpath(directory))) + else: + filename = "settings/{}.breakmodel".format(model.replace("/", "_")) + if path.exists(filename): + with open(filename, "r") as file: data = file.read().split("\n")[:2] if len(data) < 2: data.append("0") @@ -3245,7 +3249,11 @@ def get_message(msg): if gpu_layers == msg['gpu_layers'] and disk_layers == msg['disk_layers']: changed = False if changed: - f = open("settings/" + vars.model.replace('/', '_') + ".breakmodel", "w") + if vars.model in ["NeoCustom", "GPT2Custom"]: + filename = "settings/{}.breakmodel".format(os.path.basename(os.path.normpath(vars.custmodpth))) + else: + filename = "settings/{}.breakmodel".format(vars.model.replace('/', '_')) + f = open(filename, "w") f.write(msg['gpu_layers'] + '\n' + msg['disk_layers']) f.close() vars.colaburl = msg['url'] + "/request" diff --git a/test_aiserver.py b/test_aiserver.py index 855b71f5..b9bf4606 100644 --- a/test_aiserver.py +++ b/test_aiserver.py @@ -125,7 +125,7 @@ def test_load_model_from_web_ui(client_data, model, expected_load_options): assert response['url'] == expected_load_options['url'] #Now send the load - socketio_client.emit('message',{'cmd': 'load_model', 'use_gpu': True, 'key': '', 'gpu_layers': '', 'url': '', 'online_model': ''}) + socketio_client.emit('message',{'cmd': 'load_model', 'use_gpu': True, 'key': '', 'gpu_layers': str(expected_load_options['layer_count']), 'disk_layers': '0', 'url': '', 'online_model': ''}) #wait until the game state turns back to start state = 'wait' start_time = time.time() @@ -208,11 +208,17 @@ def test_back_redo(client_data): response = socketio_client.get_received()[0]['args'][0] assert response == {'cmd': 'errmsg', 'data': 'Cannot delete the prompt.'} socketio_client.emit('message',{'cmd': 'redo', 'data': ''}) - socketio_client.emit('message',{'cmd': 'redo', 'data': ''}) + response = socketio_client.get_received() + assert response == [{'name': 'from_server', 'args': [{'cmd': 'updatescreen', 'gamestarted': True, 'data': 'Niko the kobold stalked carefully down the alley, his small scaly figure obscured by a dusky cloak that fluttered lightly in the cold winter breeze. Holding up his tail to keep it from dragging in the dirty snow that covered the cobblestone, he waited patiently for the butcher to turn his attention from his stall so that he could pilfer his next meal: a tender-looking chicken. He crouched just slightly as he neared the stall to ensure that no one was watching, not that anyone would be dumb enough to hassle a small kobold. What else was there for a lowly kobold to'}], 'namespace': '/'}, + {'name': 'from_server', 'args': [{'cmd': 'texteffect', 'data': 1}], 'namespace': '/'}] socketio_client.emit('message',{'cmd': 'redo', 'data': ''}) response = socketio_client.get_received() - assert response == [{'name': 'from_server', 'args': [{'cmd': 'updatescreen', 'gamestarted': True, 'data': 'Niko the kobold stalked carefully down the alley, his small scaly figure obscured by a dusky cloak that fluttered lightly in the cold winter breeze. Holding up his tail to keep it from dragging in the dirty snow that covered the cobblestone, he waited patiently for the butcher to turn his attention from his stall so that he could pilfer his next meal: a tender-looking chicken. He crouched just slightly as he neared the stall to ensure that no one was watching, not that anyone would be dumb enough to hassle a small kobold. What else was there for a lowly kobold to'}], 'namespace': '/'}, {'name': 'from_server', 'args': [{'cmd': 'texteffect', 'data': 1}], 'namespace': '/'}, {'name': 'from_server', 'args': [{'cmd': 'updatechunk', 'data': {'index': 2, 'html': ' do in a city? All that Niko needed to know was'}}], 'namespace': '/'}, {'name': 'from_server', 'args': [{'cmd': 'texteffect', 'data': 2}], 'namespace': '/'}, {'name': 'from_server', 'args': [{'cmd': 'updatechunk', 'data': {'index': 3, 'html': ' where to find the chicken and then how to make off with it.

A soft thud caused Niko to quickly lift his head. Standing behind the stall where the butcher had been cutting his chicken,
'}}], 'namespace': '/'}, {'name': 'from_server', 'args': [{'cmd': 'texteffect', 'data': 3}], 'namespace': '/'}] - + assert response == [{'name': 'from_server', 'args': [{'cmd': 'updatechunk', 'data': {'index': 2, 'html': ' do in a city? All that Niko needed to know was'}}], 'namespace': '/'}, + {'name': 'from_server', 'args': [{'cmd': 'texteffect', 'data': 2}], 'namespace': '/'}] + socketio_client.emit('message',{'cmd': 'redo', 'data': ''}) + response = socketio_client.get_received() + assert response == [{'name': 'from_server', 'args': [{'cmd': 'updatechunk', 'data': {'index': 3, 'html': ' where to find the chicken and then how to make off with it.

A soft thud caused Niko to quickly lift his head. Standing behind the stall where the butcher had been cutting his chicken,
'}}], 'namespace': '/'}, + {'name': 'from_server', 'args': [{'cmd': 'texteffect', 'data': 3}], 'namespace': '/'}]