mirror of https://github.com/KoboldAI/KoboldAI-Client.git
xmaps for moving things onto TPU
parent 49e2bcab1a
commit a3d6dc93e8
@@ -993,7 +993,7 @@ else:
             -1,
             tpu_mtj_backend.params["d_model"],
         )
-        vars.sp = tensor
+        vars.sp = tpu_mtj_backend.shard_xmap(tensor)
         soft_tokens = np.arange(
             tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
             tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
@@ -4005,7 +4005,7 @@ def spRequest(filename):
             -1,
             tpu_mtj_backend.params["d_model"],
         )
-        vars.sp = np.float32(tensor)
+        vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
     else:
         vars.sp = torch.from_numpy(tensor)
 
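Note: both hunks above make the same change on the TPU code path: instead of storing the raw host-side NumPy soft-prompt tensor in vars.sp, it is first passed through tpu_mtj_backend.shard_xmap so that its leading axis is treated as the "shard" axis of the TPU mesh (see the helpers added in the backend hunks below). The GPU/CPU branch keeps torch.from_numpy unchanged. A minimal sketch of that call pattern, with a do-nothing placeholder standing in for shard_xmap so it runs without a TPU; use_tpu, d_model, and the tensor shape are made up for illustration:

import numpy as np
import torch

def shard_xmap_placeholder(x):
    # stand-in only: the real tpu_mtj_backend.shard_xmap distributes x's
    # leading axis across the TPU model-parallel ('mp') mesh axis via an xmap
    return x

use_tpu = True                    # assumption: selects the TPU (MTJ) backend branch
d_model = 4096                    # assumption: model embedding width
tensor = np.zeros((3, d_model))   # a fake 3-row soft prompt

if use_tpu:
    sp = shard_xmap_placeholder(np.float32(tensor))  # new behaviour in this commit
else:
    sp = torch.from_numpy(tensor)                    # unchanged GPU/CPU path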
@@ -1,5 +1,5 @@
 import multiprocessing
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, TypeVar
 import progressbar
 import time
 import os
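The only change here widens the typing imports for the helpers added in the next hunk: Callable and TypeVar let a wrapper declare that it returns "the same kind of callable or value it was given". A small self-contained illustration of that idiom (the names here are made up, not from the diff):

from typing import Callable, TypeVar

F = TypeVar("F", bound=Callable)

def passthrough(f: F) -> F:
    # returns its argument unchanged; the bound TypeVar tells type checkers
    # that the returned object keeps the original callable's signature
    return f

@passthrough
def add(a: int, b: int) -> int:
    return a + b

print(add(2, 3))  # 5; add is still typed as (int, int) -> int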
@@ -28,6 +28,31 @@ def show_spinner():
         time.sleep(0.1)
         i += 1
 
+
+__F = TypeVar("__F", bound=Callable)
+__T = TypeVar("__T")
+
+def __move_xmap(f: __F, out_axis: str) -> __F:
+    return maps.xmap(
+        f,
+        in_axes=(["shard", ...], ["batch", ...]),
+        out_axes=[out_axis, ...],
+        axis_resources={'shard': 'mp', 'batch': 'dp'},
+    )
+
+def __shard_xmap(batch_dim=1):
+    xmap = __move_xmap(lambda s, b: s, "shard")
+    def inner(x: __T) -> __T:
+        return xmap(x, np.empty(batch_dim))
+    return inner
+
+def __batch_xmap(shard_dim=1):
+    xmap = __move_xmap(lambda s, b: b, "batch")
+    def inner(x: __T) -> __T:
+        return xmap(np.empty(shard_dim), x)
+    return inner
+
+
 def apply_repetition_penalty(logits, tokens, repetition_penalty):
     '''
     This gets called by generate_loop_fn to apply repetition penalty
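These helpers wrap jax.experimental.maps.xmap: __move_xmap names the leading axis of its first argument "shard" and of its second argument "batch", and pins those named axes to the 'mp' and 'dp' mesh axes through axis_resources. __shard_xmap and __batch_xmap then curry in a dummy second or first argument (np.empty(...)) so a single array can be pushed onto the mesh along the wanted axis. A minimal sketch of the same structure, assuming a JAX version that still ships xmap, with axis_resources omitted so it runs on CPU without a TPU mesh:

import numpy as np
from jax.experimental import maps  # xmap lived here in JAX versions of this era

# same shape as __move_xmap above, minus axis_resources, so the named axes are
# only vectorized (vmap-style) instead of being laid out on a device mesh
move = maps.xmap(
    lambda s, b: s,
    in_axes=(["shard", ...], ["batch", ...]),
    out_axes=["shard", ...],
)

x = np.arange(8, dtype=np.float32).reshape(8, 1)  # leading axis is named "shard"
y = move(x, np.empty(1))                          # dummy arg only carries the "batch" axis
print(y.shape)  # (8, 1): x passes through; with axis_resources it would land on the 'mp' axis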
@@ -255,16 +280,17 @@ class PenalizingCausalTransformer(CausalTransformer):
         key = hk.PRNGSequence(random.randint(0, 2 ** 60))
         batch_size = ctx.shape[0]
         self.batch_size = batch_size
-        return self.generate_xmap(
-            self.state,
-            jnp.array(key.take(batch_size)),
-            ctx,
-            np.array(ctx_length, dtype=np.uint32),
-            np.array(gen_length, dtype=np.uint32),
+        xargs = (
+            shard_xmap(self.state),
+            batch_xmap(jnp.array(key.take(batch_size))),
+            batch_xmap(ctx),
+            batch_xmap(np.array(ctx_length, dtype=np.uint32)),
+            batch_xmap(np.array(gen_length, dtype=np.uint32)),
             np.empty((batch_size, numseqs), dtype=np.uint8),
-            sampler_options,
-            soft_embeddings,
+            batch_xmap(sampler_options),
+            shard_xmap(soft_embeddings),
         )
+        return self.generate_xmap(*xargs)
 
 
     def infer(
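With the helpers in place, generate no longer hands raw host arrays straight to generate_xmap: model-wide arrays (the network state and the soft-embedding matrix) are staged through shard_xmap, per-sequence arrays (RNG keys, context tokens, lengths, sampler options) through batch_xmap, and only the uint8 scratch buffer is left untouched; the compiled function is then called with the staged tuple. A rough sketch of that staging pattern, with trivial stand-ins for the three callables so it runs anywhere; all names and shapes here are illustrative:

import numpy as np

def shard_stub(x):   # stands in for shard_xmap: placement along the 'mp' mesh axis
    return x

def batch_stub(x):   # stands in for batch_xmap: placement along the 'dp' mesh axis
    return x

def generate_stub(state, ctx, ctx_length):  # stands in for self.generate_xmap
    return ctx[:, :1]

state = np.zeros((8, 16), dtype=np.float32)   # fake sharded model state
ctx = np.zeros((1, 32), dtype=np.uint32)      # fake context tokens
ctx_length = np.array([32], dtype=np.uint32)

xargs = (
    shard_stub(state),        # model-wide array -> shard placement
    batch_stub(ctx),          # per-sequence arrays -> batch placement
    batch_stub(ctx_length),
)
out = generate_stub(*xargs)
print(out.shape)  # (1, 1)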
@@ -354,6 +380,10 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
     maps.thread_resources.env = thread_resources_env
     tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
 
+    global shard_xmap, batch_xmap
+    shard_xmap = __shard_xmap()
+    batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
+
     if not path.endswith("/"):
         path += "/"
 
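Finally, load_model publishes module-level shard_xmap and batch_xmap wrappers, which is what the hunks at the top call as tpu_mtj_backend.shard_xmap. They are built inside load_model, after the mesh environment is installed, because batch_xmap needs the shard count (cores_per_replica), which is only known once the model configuration has been read there. A tiny sketch of that late-binding pattern, with stand-in factories so it is self-contained and runnable; everything named *_stub or make_* is hypothetical:

# module-level slots, filled in by the loader (mirrors the 'global' trick above)
shard_xmap = None
batch_xmap = None

def make_shard_xmap():              # stand-in for __shard_xmap()
    return lambda x: x

def make_batch_xmap(shard_dim=1):   # stand-in for __batch_xmap(shard_dim=...)
    return lambda x: x

def load_model_stub(cores_per_replica=8):
    # the real function sets up the TPU mesh first, then builds the wrappers
    global shard_xmap, batch_xmap
    shard_xmap = make_shard_xmap()
    batch_xmap = make_batch_xmap(shard_dim=cores_per_replica)

load_model_stub()
assert callable(shard_xmap) and callable(batch_xmap)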