Kaggle TPU support
This commit is contained in:
parent
1a1f2f6428
commit
707316de31
|
@@ -320,7 +320,7 @@ class vars:
|
|||
quiet = False # If set will suppress any story text from being printed to the console (will only be seen on the client web page)
|
||||
debug = False # If set to true, will send debug information to the client for display
|
||||
lazy_load = True # Whether or not to use torch_lazy_loader.py for transformers models in order to reduce CPU memory usage
|
||||
use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" # Whether or not we're in a Colab TPU instance and are going to use the TPU rather than the CPU
|
||||
use_colab_tpu = os.environ.get("COLAB_TPU_ADDR", "") != "" or os.environ.get("TPU_NAME", "") != "" # Whether or not we're in a Colab TPU instance or Kaggle TPU instance and are going to use the TPU rather than the CPU
|
||||
|
||||
utils.vars = vars
|
||||
|
||||
|
|
|
@@ -1094,13 +1094,18 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", hf_checkpo
|
|||
print("Connecting to your Colab instance's TPU", flush=True)
|
||||
spinner = multiprocessing.Process(target=show_spinner, args=())
|
||||
spinner.start()
|
||||
colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
|
||||
url = f'http://{colab_tpu_addr}:8475/requestversion/{driver_version}'
|
||||
if os.environ.get('COLAB_TPU_ADDR', '') != '':
|
||||
tpu_address = os.environ['COLAB_TPU_ADDR'] # Colab
|
||||
else:
|
||||
tpu_address = os.environ['TPU_NAME'] # Kaggle
|
||||
tpu_address = tpu_address.replace("grpc://", "")
|
||||
tpu_address_without_port = tpu_address.split(':', 1)[0]
|
||||
url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}'
|
||||
config.FLAGS.jax_xla_backend = "tpu_driver"
|
||||
config.FLAGS.jax_backend_target = "grpc://" + tpu_address
|
||||
requests.post(url)
|
||||
spinner.terminate()
|
||||
print()
|
||||
config.FLAGS.jax_xla_backend = "tpu_driver"
|
||||
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
|
||||
|
||||
cores_per_replica = params["cores_per_replica"]
|
||||
seq = params["seq"]
|
||||
|
|
Loading…
Reference in New Issue