From 707316de31c0dd6f67f06dd32bf8e1f21a18dd6f Mon Sep 17 00:00:00 2001
From: Gnome Ann <>
Date: Tue, 31 May 2022 12:20:16 -0400
Subject: [PATCH] Kaggle TPU support

---
 aiserver.py        |  2 +-
 tpu_mtj_backend.py | 13 +++++++++----
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/aiserver.py b/aiserver.py
index 09e89a1a..cead5e44 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -320,7 +320,7 @@ class vars:
     quiet = False # If set will suppress any story text from being printed to the console (will only be seen on the client web page)
     debug = False # If set to true, will send debug information to the client for display
     lazy_load = True # Whether or not to use torch_lazy_loader.py for transformers models in order to reduce CPU memory usage
-    use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" # Whether or not we're in a Colab TPU instance and are going to use the TPU rather than the CPU
+    use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" or os.environ.get("TPU_NAME", "") != "" # Whether or not we're in a Colab TPU instance or Kaggle TPU instance and are going to use the TPU rather than the CPU
 
 utils.vars = vars
 
diff --git a/tpu_mtj_backend.py b/tpu_mtj_backend.py
index 41746d37..fb2dc7ae 100644
--- a/tpu_mtj_backend.py
+++ b/tpu_mtj_backend.py
@@ -1094,13 +1094,18 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
     print("Connecting to your Colab instance's TPU", flush=True)
     spinner = multiprocessing.Process(target=show_spinner, args=())
     spinner.start()
-    colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
-    url = f'http://{colab_tpu_addr}:8475/requestversion/{driver_version}'
+    if os.environ.get('COLAB_TPU_ADDR', '') != '':
+        tpu_address = os.environ['COLAB_TPU_ADDR']  # Colab
+    else:
+        tpu_address = os.environ['TPU_NAME']  # Kaggle
+    tpu_address = tpu_address.replace("grpc://", "")
+    tpu_address_without_port = tpu_address.split(':', 1)[0]
+    url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}'
+    config.FLAGS.jax_xla_backend = "tpu_driver"
+    config.FLAGS.jax_backend_target = "grpc://" + tpu_address
     requests.post(url)
     spinner.terminate()
     print()
-    config.FLAGS.jax_xla_backend = "tpu_driver"
-    config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
 
     cores_per_replica = params["cores_per_replica"]
     seq = params["seq"]
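
Note: the following is a minimal standalone sketch (not part of the patch) of the address-resolution logic the second hunk introduces, assuming only what the diff shows: Colab exposes COLAB_TPU_ADDR, Kaggle exposes TPU_NAME, the Kaggle value may carry a "grpc://" prefix that must be stripped, and the host without its port is what the port-8475 requestversion URL needs. The helper name resolve_tpu_address is hypothetical.

import os

def resolve_tpu_address():
    # Prefer Colab's COLAB_TPU_ADDR, otherwise fall back to Kaggle's TPU_NAME
    # (assumption: one of the two is set, as in the patched load_model()).
    if os.environ.get("COLAB_TPU_ADDR", "") != "":
        tpu_address = os.environ["COLAB_TPU_ADDR"]  # Colab
    else:
        tpu_address = os.environ["TPU_NAME"]        # Kaggle
    tpu_address = tpu_address.replace("grpc://", "")
    host = tpu_address.split(":", 1)[0]
    # The patch uses the full address for jax_backend_target ("grpc://" + tpu_address)
    # and the bare host for the http://<host>:8475/requestversion/<driver_version> POST.
    return tpu_address, host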