mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Model saving for colab mode
This commit is contained in:
@ -738,7 +738,7 @@ parser.add_argument("--nobreakmodel", action='store_true', help="Disables Breakm
|
||||
parser.add_argument("--unblock", action='store_true', default=False, help="Unblocks the KoboldAI port to be accessible from other machines without optimizing for remote play (It is recommended to use --host instead)")
|
||||
parser.add_argument("--quiet", action='store_true', default=False, help="If present will suppress any story related text from showing on the console")
|
||||
parser.add_argument("--lowmem", action='store_true', help="Extra Low Memory loading for the GPU, slower but memory does not peak to twice the usage")
|
||||
|
||||
parser.add_argument("--savemodel", action='store_true', help="Saves the model to the models folder even if --colab is used (Allows you to save models to Google Drive)")
|
||||
args: argparse.Namespace = None
|
||||
if(os.environ.get("KOBOLDAI_ARGS") is not None):
|
||||
import shlex
|
||||
@ -1420,7 +1420,7 @@ if(not vars.use_colab_tpu and vars.model not in ["InferKit", "Colab", "OAI", "Go
|
||||
except Exception as e:
|
||||
model = GPTNeoForCausalLM.from_pretrained(vars.model, cache_dir="cache", **lowmem)
|
||||
|
||||
if not args.colab:
|
||||
if not args.colab or args.savemodel:
|
||||
import shutil
|
||||
model = model.half()
|
||||
model.save_pretrained("models/{}".format(vars.model.replace('/', '_')))
|
||||
|
Reference in New Issue
Block a user