mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-02-09 08:18:40 +01:00
xmaps for moving things onto TPU
This commit is contained in:
parent
49e2bcab1a
commit
a3d6dc93e8
@ -993,7 +993,7 @@ else:
|
||||
-1,
|
||||
tpu_mtj_backend.params["d_model"],
|
||||
)
|
||||
vars.sp = tensor
|
||||
vars.sp = tpu_mtj_backend.shard_xmap(tensor)
|
||||
soft_tokens = np.arange(
|
||||
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
|
||||
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
|
||||
@ -4005,7 +4005,7 @@ def spRequest(filename):
|
||||
-1,
|
||||
tpu_mtj_backend.params["d_model"],
|
||||
)
|
||||
vars.sp = np.float32(tensor)
|
||||
vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
|
||||
else:
|
||||
vars.sp = torch.from_numpy(tensor)
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
import multiprocessing
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
import progressbar
|
||||
import time
|
||||
import os
|
||||
@ -28,6 +28,31 @@ def show_spinner():
|
||||
time.sleep(0.1)
|
||||
i += 1
|
||||
|
||||
|
||||
__F = TypeVar("__F", bound=Callable)
|
||||
__T = TypeVar("__T")
|
||||
|
||||
def __move_xmap(f: __F, out_axis: str) -> __F:
|
||||
return maps.xmap(
|
||||
f,
|
||||
in_axes=(["shard", ...], ["batch", ...]),
|
||||
out_axes=[out_axis, ...],
|
||||
axis_resources={'shard': 'mp', 'batch': 'dp'},
|
||||
)
|
||||
|
||||
def __shard_xmap(batch_dim=1):
|
||||
xmap = __move_xmap(lambda s, b: s, "shard")
|
||||
def inner(x: __T) -> __T:
|
||||
return xmap(x, np.empty(batch_dim))
|
||||
return inner
|
||||
|
||||
def __batch_xmap(shard_dim=1):
|
||||
xmap = __move_xmap(lambda s, b: b, "batch")
|
||||
def inner(x: __T) -> __T:
|
||||
return xmap(np.empty(shard_dim), x)
|
||||
return inner
|
||||
|
||||
|
||||
def apply_repetition_penalty(logits, tokens, repetition_penalty):
|
||||
'''
|
||||
This gets called by generate_loop_fn to apply repetition penalty
|
||||
@ -255,16 +280,17 @@ class PenalizingCausalTransformer(CausalTransformer):
|
||||
key = hk.PRNGSequence(random.randint(0, 2 ** 60))
|
||||
batch_size = ctx.shape[0]
|
||||
self.batch_size = batch_size
|
||||
return self.generate_xmap(
|
||||
self.state,
|
||||
jnp.array(key.take(batch_size)),
|
||||
ctx,
|
||||
np.array(ctx_length, dtype=np.uint32),
|
||||
np.array(gen_length, dtype=np.uint32),
|
||||
xargs = (
|
||||
shard_xmap(self.state),
|
||||
batch_xmap(jnp.array(key.take(batch_size))),
|
||||
batch_xmap(ctx),
|
||||
batch_xmap(np.array(ctx_length, dtype=np.uint32)),
|
||||
batch_xmap(np.array(gen_length, dtype=np.uint32)),
|
||||
np.empty((batch_size, numseqs), dtype=np.uint8),
|
||||
sampler_options,
|
||||
soft_embeddings,
|
||||
batch_xmap(sampler_options),
|
||||
shard_xmap(soft_embeddings),
|
||||
)
|
||||
return self.generate_xmap(*xargs)
|
||||
|
||||
|
||||
def infer(
|
||||
@ -354,6 +380,10 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
|
||||
maps.thread_resources.env = thread_resources_env
|
||||
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
|
||||
|
||||
global shard_xmap, batch_xmap
|
||||
shard_xmap = __shard_xmap()
|
||||
batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
|
||||
|
||||
if not path.endswith("/"):
|
||||
path += "/"
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user