Dynamic TPU backend should support dynamic warpers and abort button
parent 31735c4239
commit 3ba0e3f9d9

aiserver.py | 18
aiserver.py

@@ -1035,7 +1035,7 @@ else:
         assert len(excluded_world_info) == len(generated)
 
         regeneration_required = vars.lua_koboldbridge.regeneration_required
-        halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
+        halt = vars.abort or not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
         vars.lua_koboldbridge.regeneration_required = False
 
         global past
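The change above is the abort half of the commit: the per-token stopping callback that the TPU backend calls now also halts when the user presses the abort button (`vars.abort`). A minimal sketch of that decision, with a stand-in `GenState` object instead of KoboldAI's shared `vars` (the class and function names here are illustrative only):

```python
# Sketch only: GenState stands in for KoboldAI's shared `vars` object.
from dataclasses import dataclass

@dataclass
class GenState:
    abort: bool = False        # set when the UI's abort button is pressed
    generating: bool = True    # cleared by the Lua bridge to stop generation early
    generated_tkns: int = 0    # tokens emitted so far
    genamt: int = 80           # requested number of tokens

def should_halt(state: GenState) -> bool:
    # Abort is checked first so it wins over every other stopping condition.
    return state.abort or not state.generating or state.generated_tkns >= state.genamt

state = GenState(generated_tkns=10)
assert not should_halt(state)
state.abort = True
assert should_halt(state)  # halts even though the token budget is not exhausted
```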
@@ -1061,6 +1061,15 @@ else:
 
     def tpumtjgenerate_stopped_compiling_callback() -> None:
         vars.compiling = False
 
+    def tpumtjgenerate_settings_callback() -> dict:
+        return {
+            "top_p": float(vars.top_p),
+            "temp": float(vars.temp),
+            "top_k": int(vars.top_k),
+            "tfs": float(vars.tfs),
+            "repetition_penalty": float(vars.rep_pen),
+        }
+
     # If we're running Colab or OAI, we still need a tokenizer.
     if(vars.model == "Colab"):
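The new `tpumtjgenerate_settings_callback` is the frontend half of the dynamic-warper support: instead of baking the sampler settings into the generation request, it snapshots the current UI values into a plain dict of Python floats and ints each time the backend asks for them. A hedged sketch of the same contract against an illustrative `Settings` stand-in (not KoboldAI's real `vars`):

```python
# Illustrative stand-in for the object the callback reads its values from.
class Settings:
    top_p = 0.9
    temp = 0.5
    top_k = 0
    tfs = 1.0
    rep_pen = 1.0

settings = Settings()

def settings_callback() -> dict:
    # Coerce to plain floats/ints so the dict is cheap to hand across to the backend.
    return {
        "top_p": float(settings.top_p),
        "temp": float(settings.temp),
        "top_k": int(settings.top_k),
        "tfs": float(settings.tfs),
        "repetition_penalty": float(settings.rep_pen),
    }

settings.temp = 0.8                        # user moves a slider mid-generation
assert settings_callback()["temp"] == 0.8  # the next call sees the new value
```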
@@ -3009,12 +3018,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
             tpu_mtj_backend.infer_dynamic,
             context,
             gen_len = maximum-minimum+1,
-            temp=vars.temp,
-            top_p=vars.top_p,
-            top_k=vars.top_k,
-            tfs=vars.tfs,
             numseqs=vars.numseqs,
-            repetition_penalty=vars.rep_pen,
             soft_embeddings=vars.sp,
             soft_tokens=soft_tokens,
             excluded_world_info=found_entries,
@@ -3026,7 +3030,7 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
                 assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
                 past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1]
 
-        if(halt or not regeneration_required):
+        if(vars.abort or halt or not regeneration_required):
             break
         print("(regeneration triggered)")
 
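The second `vars.abort` check sits in the outer loop of `tpumtjgenerate`: after each backend pass the code either stops because generation halted, stops because an abort was requested, or loops again when the Lua bridge asked for a regeneration. A rough sketch of that control flow; `run_one_pass` is a hypothetical stand-in for the `tpu_mtj_backend.infer_dynamic` call:

```python
def run_one_pass(pass_index: int):
    # Hypothetical stand-in for one tpu_mtj_backend.infer_dynamic call; returns the
    # two flags the real code reads back from the Lua bridge after the pass.
    regeneration_required = pass_index < 2   # pretend the first two passes ask to retry
    halt = not regeneration_required
    return regeneration_required, halt

def generate(abort_requested=lambda: False, max_passes=10):
    passes = 0
    for i in range(max_passes):
        passes += 1
        regeneration_required, halt = run_one_pass(i)
        if abort_requested() or halt or not regeneration_required:
            break
        print("(regeneration triggered)")
    return passes

assert generate() == 3                              # two retries, then a normal halt
assert generate(abort_requested=lambda: True) == 1  # abort stops after the first pass
```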
tpu_mtj_backend.py

@@ -26,6 +26,15 @@ def warper_callback(logits) -> np.array:
 def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
     raise NotImplementedError("`tpu_mtj_backend.stopping_callback()` needs to be defined")
 
+def settings_callback() -> dict:
+    return {
+        "top_p": 0.9,
+        "temp": 0.5,
+        "top_k": 0,
+        "tfs": 1.0,
+        "repetition_penalty": 1.0,
+    }
+
 def started_compiling_callback() -> None:
     pass
 
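Like `warper_callback` and `stopping_callback` above it, `settings_callback` is a module-level hook with a conservative default that the frontend is expected to overwrite before generating (in this commit, presumably with `tpumtjgenerate_settings_callback` from aiserver.py). A small sketch of that override pattern, flattened into one file for brevity; the `generate_one_token` helper is hypothetical:

```python
# Default hook, as the backend module would define it.
def settings_callback() -> dict:
    return {"top_p": 0.9, "temp": 0.5, "top_k": 0, "tfs": 1.0, "repetition_penalty": 1.0}

def generate_one_token() -> dict:
    # The generation loop resolves the hook by name on every call, so a later
    # reassignment of `settings_callback` takes effect from the very next token.
    return settings_callback()

# Frontend side: install a replacement before (or even during) generation.
def frontend_settings_callback() -> dict:
    return {"top_p": 1.0, "temp": 0.7, "top_k": 40, "tfs": 0.95, "repetition_penalty": 1.1}

settings_callback = frontend_settings_callback
assert generate_one_token()["top_k"] == 40
```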
@@ -541,7 +550,7 @@ class PenalizingCausalTransformer(CausalTransformer):
             out_axes=["shard", "batch", ...],
             axis_resources={'shard': 'mp', 'batch': 'dp'},
         )
-    def generate_dynamic(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True):
+    def generate_dynamic(self, ctx, ctx_length, gen_length, numseqs, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True):
         assert excluded_world_info is not None
         assert not return_logits
         assert gen_length.ndim == 1
@@ -560,7 +569,6 @@ class PenalizingCausalTransformer(CausalTransformer):
             ]
             for i in range(numseqs)
         ]
-        repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
         n_generated = 0
         regeneration_required = False
         halt = False
@@ -576,6 +584,8 @@ class PenalizingCausalTransformer(CausalTransformer):
                 logits = warper_callback(logits)
                 for i in range(numseqs):
                     sample_data[i][2] = logits[i]
+            sampler_options = settings_callback()
+            repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
             sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
             n_generated += 1
             for i in range(numseqs):
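This hunk is the heart of the dynamic-warper change: the sampler options are no longer popped once before the loop but re-fetched from `settings_callback()` on every token, so settings changed while the TPU is generating apply from the next token onward. A compressed sketch of that per-token re-read, with a stub `sample_func` in place of the real sharded sampler:

```python
settings = {"repetition_penalty": 1.0, "temp": 0.5}

def settings_callback() -> dict:
    return dict(settings)  # fresh snapshot on every call

def sample_func(repetition_penalty, sampler_options):
    # Stub for the real sampling step; just records what it was given.
    return {"repetition_penalty": repetition_penalty, **sampler_options}

seen = []
for step in range(3):
    sampler_options = settings_callback()                              # re-read each token
    repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
    seen.append(sample_func(repetition_penalty, sampler_options))
    if step == 0:
        settings["temp"] = 0.9   # user tweaks the temperature mid-generation

assert seen[0]["temp"] == 0.5 and seen[1]["temp"] == 0.9
```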
@@ -611,11 +621,6 @@ class PenalizingCausalTransformer(CausalTransformer):
 
 def infer_dynamic(
     context: np.array,
-    top_p=0.9,
-    temp=0.5,
-    top_k=0,
-    tfs=1.0,
-    repetition_penalty=1.0,
     numseqs=1,
     gen_len=80,
     soft_embeddings: Optional[np.array] = None,
@@ -634,19 +639,11 @@ def infer_dynamic(
     padded_tokens = np.pad(tokens, ((0, 0), (pad_amount, 0)), constant_values=pad_token_id)
     batched_tokens = np.array([padded_tokens] * total_batch)
     samples = []
-    generator_params = {
-        "temp": float(temp),
-        "top_p": float(top_p),
-        "tfs": float(tfs),
-        "repetition_penalty": float(repetition_penalty),
-        "top_k": int(top_k),
-    }
     output = network.generate_dynamic(
         batched_tokens,
         np.ones(total_batch, dtype=np.uint32) * provided_ctx,
         np.ones(total_batch, dtype=np.uint32) * gen_len,
         numseqs,
-        generator_params,
         soft_embeddings=soft_embeddings,
         excluded_world_info=excluded_world_info,
         use_callback=use_callback,
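With the sampler keyword arguments and the `generator_params` dict removed, `infer_dynamic` no longer carries any sampling settings; callers pass only the context, lengths, sequence count, and the soft-prompt and world-info plumbing, and everything tunable flows through `settings_callback` instead. A hedged sketch of a call against the slimmed-down signature, with a stub body standing in for the real TPU code (argument values are illustrative):

```python
import numpy as np

def infer_dynamic(context, numseqs=1, gen_len=80, soft_embeddings=None,
                  soft_tokens=None, excluded_world_info=None, use_callback=True):
    # Stub with the post-commit shape of the signature: note the absence of
    # temp/top_p/top_k/tfs/repetition_penalty parameters.
    assert excluded_world_info is not None
    return np.zeros((numseqs, gen_len), dtype=np.uint32)

context = np.array([1, 2, 3, 4], dtype=np.uint32)
out = infer_dynamic(
    context,
    gen_len=40,
    numseqs=2,
    soft_tokens=None,
    excluded_world_info=[set(), set()],
)
assert out.shape == (2, 40)
```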