Don't patch lazyload on TPU

This commit is contained in:
somebody
2023-07-03 16:52:18 -05:00
parent 39049e8a46
commit 686c3d1592
2 changed files with 6 additions and 9 deletions

View File

@@ -27,9 +27,6 @@ from ansi2html import Ansi2HTMLConverter
logging.getLogger("urllib3").setLevel(logging.ERROR)
from modeling import patches
patches.patch_transformers_for_lazyload()
import attention_bias
attention_bias.do_patches()
@@ -10809,7 +10806,7 @@ def run():
Session(app)
logger.init_ok("Flask", status="OK")
logger.init("Webserver", status="Starting")
patch_transformers()
patch_transformers(use_tpu=koboldai_vars.use_colab_tpu)
# Start Flask/SocketIO (Blocking, so this must be last method!)
port = args.port if "port" in args and args.port is not None else 5000
@@ -10906,7 +10903,7 @@ else:
logger.init("Flask", status="Starting")
Session(app)
logger.init_ok("Flask", status="OK")
patch_transformers()
patch_transformers(use_tpu=koboldai_vars.use_colab_tpu)
startup(command_line_backend)
koboldai_settings.port = args.port if "port" in args and args.port is not None else 5000
print("{0}\nServer started in WSGI mode!{1}".format(colors.GREEN, colors.END), flush=True)

View File

@@ -164,7 +164,6 @@ def patch_transformers_for_lazyload() -> None:
# both for short term compatibility
load_in_8bit=False,
is_quantized=False,
is_safetensors=False,
keep_in_fp32_modules=None,
):
@@ -303,9 +302,10 @@ def patch_transformers_for_lazyload() -> None:
)
def patch_transformers() -> None:
def patch_transformers(use_tpu: bool) -> None:
patch_transformers_download()
patch_transformers_loader()
# Doesn't do anything for TPU
patch_transformers_generation()
if not use_tpu:
patch_transformers_generation()
patch_transformers_for_lazyload()