mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-01-23 13:50:25 +01:00
Bugfix for save function not appending .json extension by default
Bugfix for New Story function not clearing World Info from previous story Torch will not be initialized unless you select a local model, as there's no reason to invoke it for InferKit/Colab Changed JSON file writes to use indentation for readability
This commit is contained in:
parent
429c9b13f5
commit
2cef3bceaf
76
aiserver.py
76
aiserver.py
@ -9,7 +9,6 @@ from os import path, getcwd
|
||||
import tkinter as tk
|
||||
from tkinter import messagebox
|
||||
import json
|
||||
import torch
|
||||
|
||||
# KoboldAI
|
||||
import fileops
|
||||
@ -84,15 +83,10 @@ class vars:
|
||||
# Function to get model selection at startup
|
||||
#==================================================================#
|
||||
def getModelSelection():
|
||||
print(" # Model {0}\n ==================================="
|
||||
.format("VRAM" if vars.hascuda else " "))
|
||||
|
||||
print(" # Model V/RAM\n =========================================")
|
||||
i = 1
|
||||
for m in modellist:
|
||||
if(vars.hascuda):
|
||||
print(" {0} - {1}\t\t{2}".format(i, m[0].ljust(15), m[2]))
|
||||
else:
|
||||
print(" {0} - {1}".format(i, m[0]))
|
||||
print(" {0} - {1}\t\t{2}".format("{:<2}".format(i), m[0].ljust(15), m[2]))
|
||||
i += 1
|
||||
print(" ");
|
||||
modelsel = 0
|
||||
@ -123,36 +117,39 @@ def getModelSelection():
|
||||
# Startup
|
||||
#==================================================================#
|
||||
|
||||
# Test for GPU support
|
||||
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
|
||||
vars.hascuda = torch.cuda.is_available()
|
||||
if(vars.hascuda):
|
||||
print("{0}FOUND!{1}".format(colors.GREEN, colors.END))
|
||||
else:
|
||||
print("{0}NOT FOUND!{1}".format(colors.YELLOW, colors.END))
|
||||
|
||||
# Select a model to run
|
||||
print("{0}Welcome to the KoboldAI Client!\nSelect an AI model to continue:{1}\n".format(colors.CYAN, colors.END))
|
||||
getModelSelection()
|
||||
|
||||
# If transformers model was selected & GPU available, ask to use CPU or GPU
|
||||
if((not vars.model in ["InferKit", "Colab"]) and vars.hascuda):
|
||||
if(not vars.model in ["InferKit", "Colab"]):
|
||||
# Test for GPU support
|
||||
import torch
|
||||
print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
|
||||
vars.hascuda = torch.cuda.is_available()
|
||||
if(vars.hascuda):
|
||||
print("{0}FOUND!{1}".format(colors.GREEN, colors.END))
|
||||
else:
|
||||
print("{0}NOT FOUND!{1}".format(colors.YELLOW, colors.END))
|
||||
|
||||
print("{0}Use GPU or CPU for generation?: (Default GPU){1}\n".format(colors.CYAN, colors.END))
|
||||
print(" 1 - GPU\n 2 - CPU\n")
|
||||
genselected = False
|
||||
while(genselected == False):
|
||||
genselect = input("Mode> ")
|
||||
if(genselect == ""):
|
||||
vars.usegpu = True
|
||||
genselected = True
|
||||
elif(genselect.isnumeric() and int(genselect) == 1):
|
||||
vars.usegpu = True
|
||||
genselected = True
|
||||
elif(genselect.isnumeric() and int(genselect) == 2):
|
||||
vars.usegpu = False
|
||||
genselected = True
|
||||
else:
|
||||
print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END))
|
||||
|
||||
if(vars.hascuda):
|
||||
print(" 1 - GPU\n 2 - CPU\n")
|
||||
genselected = False
|
||||
while(genselected == False):
|
||||
genselect = input("Mode> ")
|
||||
if(genselect == ""):
|
||||
vars.usegpu = True
|
||||
genselected = True
|
||||
elif(genselect.isnumeric() and int(genselect) == 1):
|
||||
vars.usegpu = True
|
||||
genselected = True
|
||||
elif(genselect.isnumeric() and int(genselect) == 2):
|
||||
vars.usegpu = False
|
||||
genselected = True
|
||||
else:
|
||||
print("{0}Please enter a valid selection.{1}".format(colors.RED, colors.END))
|
||||
|
||||
# Ask for API key if InferKit was selected
|
||||
if(vars.model == "InferKit"):
|
||||
@ -163,7 +160,8 @@ if(vars.model == "InferKit"):
|
||||
# Write API key to file
|
||||
file = open("client.settings", "w")
|
||||
try:
|
||||
file.write("{\"apikey\": \""+vars.apikey+"\"}")
|
||||
js = {"apikey": vars.apikey}
|
||||
file.write(json.dumps(js, indent=3))
|
||||
finally:
|
||||
file.close()
|
||||
else:
|
||||
@ -183,7 +181,7 @@ if(vars.model == "InferKit"):
|
||||
# Write API key to file
|
||||
file = open("client.settings", "w")
|
||||
try:
|
||||
file.write(json.dumps(js))
|
||||
file.write(json.dumps(js, indent=3))
|
||||
finally:
|
||||
file.close()
|
||||
|
||||
@ -456,7 +454,7 @@ def savesettings():
|
||||
# Write it
|
||||
file = open("client.settings", "w")
|
||||
try:
|
||||
file.write(json.dumps(js))
|
||||
file.write(json.dumps(js, indent=3))
|
||||
finally:
|
||||
file.close()
|
||||
|
||||
@ -712,7 +710,7 @@ def generate(txt, min, max):
|
||||
#==================================================================#
|
||||
def sendtocolab(txt, min, max):
|
||||
# Log request to console
|
||||
print("{0}Len:{1}, Txt:{2}{3}".format(colors.YELLOW, len(txt), txt, colors.END))
|
||||
print("{0}Tokens:{1}, Txt:{2}{3}".format(colors.YELLOW, min-1, txt, colors.END))
|
||||
|
||||
# Build request JSON data
|
||||
reqdata = {
|
||||
@ -752,7 +750,7 @@ def sendtocolab(txt, min, max):
|
||||
elif("errors" in er):
|
||||
code = er["errors"][0]["extensions"]["code"]
|
||||
|
||||
errmsg = "InferKit API Error: {0} - {1}".format(req.status_code, code)
|
||||
errmsg = "Colab API Error: {0} - {1}".format(req.status_code, code)
|
||||
emit('from_server', {'cmd': 'errmsg', 'data': errmsg})
|
||||
set_aibusy(0)
|
||||
|
||||
@ -1148,7 +1146,7 @@ def saveRequest():
|
||||
# Write it
|
||||
file = open(savpath, "w")
|
||||
try:
|
||||
file.write(json.dumps(js))
|
||||
file.write(json.dumps(js, indent=3))
|
||||
finally:
|
||||
file.close()
|
||||
|
||||
@ -1303,8 +1301,10 @@ def newGameRequest():
|
||||
vars.actions = []
|
||||
vars.savedir = getcwd()+"\stories"
|
||||
vars.authornote = ""
|
||||
vars.worldinfo = []
|
||||
|
||||
# Refresh game screen
|
||||
sendwi()
|
||||
setStartState()
|
||||
|
||||
|
||||
|
@ -10,7 +10,8 @@ def getsavepath(dir, title, types):
|
||||
path = tk.filedialog.asksaveasfile(
|
||||
initialdir=dir,
|
||||
title=title,
|
||||
filetypes = types
|
||||
filetypes = types,
|
||||
defaultextension="*.*"
|
||||
)
|
||||
root.destroy()
|
||||
if(path != "" and path != None):
|
||||
|
Loading…
Reference in New Issue
Block a user