From 686c3d1592dbb97198e6ba62ce33a825ba60e3a6 Mon Sep 17 00:00:00 2001 From: somebody Date: Mon, 3 Jul 2023 16:52:18 -0500 Subject: [PATCH] Don't patch lazyload on TPU --- aiserver.py | 7 ++----- modeling/patches.py | 8 ++++---- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/aiserver.py b/aiserver.py index 613c735f..d5dbdfae 100644 --- a/aiserver.py +++ b/aiserver.py @@ -27,9 +27,6 @@ from ansi2html import Ansi2HTMLConverter logging.getLogger("urllib3").setLevel(logging.ERROR) -from modeling import patches -patches.patch_transformers_for_lazyload() - import attention_bias attention_bias.do_patches() @@ -10809,7 +10806,7 @@ def run(): Session(app) logger.init_ok("Flask", status="OK") logger.init("Webserver", status="Starting") - patch_transformers() + patch_transformers(use_tpu=koboldai_vars.use_colab_tpu) # Start Flask/SocketIO (Blocking, so this must be last method!) port = args.port if "port" in args and args.port is not None else 5000 @@ -10906,7 +10903,7 @@ else: logger.init("Flask", status="Starting") Session(app) logger.init_ok("Flask", status="OK") - patch_transformers() + patch_transformers(use_tpu=koboldai_vars.use_colab_tpu) startup(command_line_backend) koboldai_settings.port = args.port if "port" in args and args.port is not None else 5000 print("{0}\nServer started in WSGI mode!{1}".format(colors.GREEN, colors.END), flush=True) diff --git a/modeling/patches.py b/modeling/patches.py index b7c5370a..d8990327 100644 --- a/modeling/patches.py +++ b/modeling/patches.py @@ -164,7 +164,6 @@ def patch_transformers_for_lazyload() -> None: # both for short term compatibility load_in_8bit=False, is_quantized=False, - is_safetensors=False, keep_in_fp32_modules=None, ): @@ -303,9 +302,10 @@ def patch_transformers_for_lazyload() -> None: ) -def patch_transformers() -> None: +def patch_transformers(use_tpu: bool) -> None: patch_transformers_download() patch_transformers_loader() - # Doesn't do anything for TPU - patch_transformers_generation() + if not use_tpu: + patch_transformers_generation() + patch_transformers_for_lazyload()