diff --git a/aiserver.py b/aiserver.py
index 0809be22..91156b14 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -386,6 +386,7 @@ parser.add_argument("--breakmodel_gpulayers", type=str, help="If using a model t
 parser.add_argument("--override_delete", action='store_true', help="Deleting stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow deleting stories if using --remote and prevent deleting stories otherwise.")
 parser.add_argument("--override_rename", action='store_true', help="Renaming stories from inside the browser is disabled if you are using --remote and enabled otherwise. Using this option will instead allow renaming stories if using --remote and prevent renaming stories otherwise.")
 parser.add_argument("--configname", help="Force a fixed configuration name to aid with config management.")
+parser.add_argument("--colab", action='store_true', help="Optimize for Google Colab.")
 
 args: argparse.Namespace = None
 if(os.environ.get("KOBOLDAI_ARGS") is not None):
@@ -393,8 +394,14 @@ if(os.environ.get("KOBOLDAI_ARGS") is not None):
     args = parser.parse_args(shlex.split(os.environ["KOBOLDAI_ARGS"]))
 else:
     args = parser.parse_args()
+
 vars.model = args.model;
 
+if args.colab:
+    args.remote = True;
+    args.override_rename = True;
+    args.override_delete = True;
+
 if args.remote:
     vars.remote = True;
 
@@ -452,7 +459,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
         vars.model_type = "gpt_neo"
     print("{0}Looking for GPU support...{1}".format(colors.PURPLE, colors.END), end="")
     vars.hascuda = torch.cuda.is_available()
-    vars.bmsupported = vars.model_type in ("gpt_neo", "gptj")
+    vars.bmsupported = vars.model_type in ("gpt_neo", "gptj") and not args.colab
     if(args.breakmodel is not None and args.breakmodel):
         print("WARNING: --breakmodel is no longer supported. Breakmodel mode is now automatically enabled when --layers is used (see --help for details).", file=sys.stderr)
     if(args.breakmodel_layers is not None):
@@ -932,7 +939,6 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
                 except ValueError as e:
                     model = GPTNeoForCausalLM.from_pretrained(vars.model.replace('/', '_'), cache_dir="cache/", **lowmem)
             else:
-                print("Model does not exist locally, attempting to download from Huggingface...")
                 try:
                     tokenizer = AutoTokenizer.from_pretrained(vars.model, cache_dir="cache/")
                 except ValueError as e:
@@ -942,11 +948,13 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
                         model = AutoModelForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **lowmem)
                     except ValueError as e:
                         model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache/", **lowmem)
-                model = model.half()
-                import shutil
-                shutil.rmtree("cache/")
-                model.save_pretrained(vars.model.replace('/', '_'))
-                tokenizer.save_pretrained(vars.model.replace('/', '_'))
+
+                if not args.colab:
+                    model = model.half()
+                    import shutil
+                    shutil.rmtree("cache/")
+                    model.save_pretrained(vars.model.replace('/', '_'))
+                    tokenizer.save_pretrained(vars.model.replace('/', '_'))
 
     if(vars.hascuda):
         if(vars.usegpu):
diff --git a/colabkobold.sh b/colabkobold.sh
index 2aea408a..3c05d611 100644
--- a/colabkobold.sh
+++ b/colabkobold.sh
@@ -47,7 +47,7 @@ function launch
 	else
 		cd /content/KoboldAI-Client
 		echo "Launching KoboldAI with the following options : python3 aiserver.py$model$kmpath$configname$ngrok --remote --override_delete --override_rename"
-		python3 aiserver.py$model$kmpath$configname$ngrok --remote --override_delete --override_rename
+		python3 aiserver.py$model$kmpath$configname$ngrok --remote --override_delete --override_rename --colab
 		exit
 	fi
 }
@@ -106,9 +106,9 @@ if [ "$init" != "skip" ]; then
 	cd /content/KoboldAI-Client
 
-	cp -rn stories/* /content/drive/MyDrive/KoboldAI/stories/
-	cp -rn userscripts/* /content/drive/MyDrive/KoboldAI/userscripts/
-	cp -rn softprompts/* /content/drive/MyDrive/KoboldAI/softprompts/
+	cp -n stories/* /content/drive/MyDrive/KoboldAI/stories/
+	cp -n userscripts/* /content/drive/MyDrive/KoboldAI/userscripts/
+	cp -n softprompts/* /content/drive/MyDrive/KoboldAI/softprompts/
 	rm stories
 	rm -rf stories/
 	rm userscripts