Mirror of https://github.com/KoboldAI/KoboldAI-Client.git (synced 2025-06-05 21:59:24 +02:00)
Merge branch 'united' into xmap
 aiserver.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)
@@ -69,7 +69,7 @@ modellist = [
     ["C1 6B (Chatbot)", "hakurei/c1-6B", "12GB"],
     ["Picard 2.7B (Novel)", "KoboldAI/GPT-Neo-2.7B-Picard", "6GB"],
     ["Horni 2.7B (NSFW)", "KoboldAI/GPT-Neo-2.7B-Horni", "6GB"],
-    ["Horni-LN 2.7B (Novel/NSFW)", "KoboldAI/GPT-Neo-2.7B-Horni-LN", "6GB"],
+    ["Horni-LN 2.7B (Novel)", "KoboldAI/GPT-Neo-2.7B-Horni-LN", "6GB"],
     ["Shinen 2.7B (NSFW)", "KoboldAI/GPT-Neo-2.7B-Shinen", "6GB"],
     ["GPT-J 6B", "EleutherAI/gpt-j-6B", "12GB"],
     ["GPT-Neo 2.7B", "EleutherAI/gpt-neo-2.7B", "6GB"],
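Note (not part of the commit): every modellist row above follows the same three-field layout, [menu label, Hugging Face model ID, what appears to be a VRAM estimate], so this hunk only rewords a menu label. A hypothetical extra row, purely for illustration:

    # Illustrative only; this entry is not in the commit.
    ["Example 1.3B (Novel)", "EleutherAI/gpt-neo-1.3B", "4GB"],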
@@ -388,6 +388,7 @@ parser.add_argument("--breakmodel_gpulayers", type=str, help="If using a model t
 parser.add_argument("--override_delete", action='store_true', help="Deleting stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow deleting stories if using --remote and prevent deleting stories otherwise.")
 parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
 parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
+parser.add_argument("--colab", action='store_true', help="Optimize for Google Colab.")
 
 args: argparse.Namespace = None
 if(os.environ.get("KOBOLDAI_ARGS") is not None):
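Note (not part of the commit): the new --colab switch is read through the same two paths as every other flag here, the real command line or the KOBOLDAI_ARGS environment variable. A minimal, self-contained sketch of that dispatch, assuming nothing beyond what this hunk shows:

    import argparse, os, shlex

    parser = argparse.ArgumentParser()
    parser.add_argument("--colab", action='store_true', help="Optimize for Google Colab.")

    # When KOBOLDAI_ARGS is set, it replaces the real command line entirely.
    if os.environ.get("KOBOLDAI_ARGS") is not None:
        args = parser.parse_args(shlex.split(os.environ["KOBOLDAI_ARGS"]))
    else:
        args = parser.parse_args()

    print(args.colab)  # True when launched with --colab or KOBOLDAI_ARGS="--colab"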
@@ -395,8 +396,14 @@ if(os.environ.get("KOBOLDAI_ARGS") is not None):
     args = parser.parse_args(shlex.split(os.environ["KOBOLDAI_ARGS"]))
 else:
     args = parser.parse_args()
 
 vars.model = args.model;
+
+if args.colab:
+    args.remote = True;
+    args.override_rename = True;
+    args.override_delete = True;
+
 if args.remote:
     vars.remote = True;
 
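Note (not part of the commit): the added block makes --colab act as shorthand for a remote-friendly configuration:

    # Reading aid (not from the commit): after the block above runs,
    # launching with --colab behaves like launching with
    #     --remote --override_rename --override_delete
    # plus whatever later code keys off args.colab directly.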
@@ -454,7 +461,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
         vars.model_type = "gpt_neo"
 print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
 vars.hascuda = torch.cuda.is_available()
-vars.bmsupported = vars.model_type in ("gpt_neo", "gptj")
+vars.bmsupported = vars.model_type in ("gpt_neo", "gptj") and not args.colab
 if(args.breakmodel is not None and args.breakmodel):
     print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --layers is used (see --help for details).", file=sys.stderr)
 if(args.breakmodel_layers is not None):
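Note (not part of the commit): vars.bmsupported gates breakmodel, the layer-splitting mode that the --breakmodel_gpulayers option above configures. The amended line additionally requires a non-Colab session:

    # Same statement with the two conditions made explicit:
    vars.bmsupported = (vars.model_type in ("gpt_neo", "gptj")) and (not args.colab)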
@@ -934,7 +941,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
         except ValueError as e:
             model = GPTNeoForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
     else:
-        print("Model does not exist locally, attempting to download from Huggingface...")
         try:
             tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache/")
         except ValueError as e:
@@ -944,11 +950,13 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
             model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **lowmem)
         except ValueError as e:
             model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **lowmem)
-        model = model.half()
-        import shutil
-        shutil.rmtree("cache/")
-        model.save_pretrained(vars.model.replace('/', '_'))
-        tokenizer.save_pretrained(vars.model.replace('/', '_'))
+
+        if not args.colab:
+            model = model.half()
+            import shutil
+            shutil.rmtree("cache/")
+            model.save_pretrained(vars.model.replace('/', '_'))
+            tokenizer.save_pretrained(vars.model.replace('/', '_'))
 
 if(vars.hascuda):
     if(vars.usegpu):
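Note (not part of the commit): with this hunk, a --colab session uses the freshly downloaded model as-is; the half-precision conversion and the save-back-to-disk step now run only outside Colab, presumably because re-serializing a multi-gigabyte model onto an ephemeral Colab disk costs time and space for no lasting benefit. A condensed view of the resulting download path, using only names from the diff:

    # After a successful download (model/tokenizer bound as in the hunk above):
    if not args.colab:
        model = model.half()                                    # convert weights to fp16, halving memory use
        import shutil
        shutil.rmtree("cache/")                                 # drop the raw Hugging Face download cache
        model.save_pretrained(vars.model.replace('/', '_'))     # persist the converted copy under a path-safe name
        tokenizer.save_pretrained(vars.model.replace('/', '_'))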