mirror of
https://github.com/KoboldAI/KoboldAI-Client.git
synced 2025-06-05 21:59:24 +02:00
Merge pull request #58 from VE-FORBRYDERNE/xmap
Dynamic TPU backend xmaps
This commit is contained in:
130
aiserver.py
130
aiserver.py
@ -22,7 +22,7 @@ import packaging
|
|||||||
import contextlib
|
import contextlib
|
||||||
import traceback
|
import traceback
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Callable, TypeVar, Union, Dict, Set, List
|
from typing import Any, Callable, TypeVar, Tuple, Union, Dict, Set, List
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import html
|
import html
|
||||||
@ -993,7 +993,7 @@ else:
|
|||||||
-1,
|
-1,
|
||||||
tpu_mtj_backend.params["d_model"],
|
tpu_mtj_backend.params["d_model"],
|
||||||
)
|
)
|
||||||
vars.sp = tensor
|
vars.sp = tpu_mtj_backend.shard_xmap(tensor)
|
||||||
soft_tokens = np.arange(
|
soft_tokens = np.arange(
|
||||||
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
|
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"],
|
||||||
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
|
tpu_mtj_backend.params["n_vocab"] + tpu_mtj_backend.params["n_vocab_padding"] + vars.sp_length,
|
||||||
@ -1001,6 +1001,49 @@ else:
|
|||||||
)
|
)
|
||||||
return soft_tokens
|
return soft_tokens
|
||||||
|
|
||||||
|
def tpumtjgenerate_warper_callback(scores) -> "np.array":
|
||||||
|
scores_shape = scores.shape
|
||||||
|
scores_list = scores.tolist()
|
||||||
|
vars.lua_koboldbridge.logits = vars.lua_state.table()
|
||||||
|
for r, row in enumerate(scores_list):
|
||||||
|
vars.lua_koboldbridge.logits[r+1] = vars.lua_state.table(*row)
|
||||||
|
vars.lua_koboldbridge.vocab_size = scores_shape[-1]
|
||||||
|
|
||||||
|
execute_genmod()
|
||||||
|
|
||||||
|
scores = np.array(
|
||||||
|
tuple(tuple(row.values()) for row in vars.lua_koboldbridge.logits.values()),
|
||||||
|
dtype=scores.dtype,
|
||||||
|
)
|
||||||
|
assert scores.shape == scores_shape
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
def tpumtjgenerate_stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
|
||||||
|
vars.generated_tkns += 1
|
||||||
|
|
||||||
|
assert len(excluded_world_info) == len(generated)
|
||||||
|
regeneration_required = vars.lua_koboldbridge.regeneration_required
|
||||||
|
halt = not vars.lua_koboldbridge.generating or vars.generated_tkns >= vars.genamt
|
||||||
|
vars.lua_koboldbridge.regeneration_required = False
|
||||||
|
|
||||||
|
global past
|
||||||
|
|
||||||
|
for i in range(vars.numseqs):
|
||||||
|
vars.lua_koboldbridge.generated[i+1][vars.generated_tkns] = int(generated[i, tpu_mtj_backend.params["seq"] + n_generated - 1].item())
|
||||||
|
|
||||||
|
if(not vars.dynamicscan or halt):
|
||||||
|
return excluded_world_info, regeneration_required, halt
|
||||||
|
|
||||||
|
for i, t in enumerate(generated):
|
||||||
|
decoded = tokenizer.decode(past[i]) + tokenizer.decode(t[tpu_mtj_backend.params["seq"] : tpu_mtj_backend.params["seq"] + n_generated])
|
||||||
|
_, found = checkworldinfo(decoded, force_use_txt=True)
|
||||||
|
found -= excluded_world_info[i]
|
||||||
|
if(len(found) != 0):
|
||||||
|
regeneration_required = True
|
||||||
|
break
|
||||||
|
return excluded_world_info, regeneration_required, halt
|
||||||
|
|
||||||
# If we're running Colab or OAI, we still need a tokenizer.
|
# If we're running Colab or OAI, we still need a tokenizer.
|
||||||
if(vars.model == "Colab"):
|
if(vars.model == "Colab"):
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
@ -1013,6 +1056,8 @@ else:
|
|||||||
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
|
print("{0}Initializing Mesh Transformer JAX, please wait...{1}".format(colors.PURPLE, colors.END))
|
||||||
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
|
assert vars.model == "TPUMeshTransformerGPTJ" and vars.custmodpth and os.path.isdir(vars.custmodpth)
|
||||||
import tpu_mtj_backend
|
import tpu_mtj_backend
|
||||||
|
tpu_mtj_backend.warper_callback = tpumtjgenerate_warper_callback
|
||||||
|
tpu_mtj_backend.stopping_callback = tpumtjgenerate_stopping_callback
|
||||||
tpu_mtj_backend.load_model(vars.custmodpth)
|
tpu_mtj_backend.load_model(vars.custmodpth)
|
||||||
vars.allowsp = True
|
vars.allowsp = True
|
||||||
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
vars.modeldim = int(tpu_mtj_backend.params["d_model"])
|
||||||
@ -1020,12 +1065,14 @@ else:
|
|||||||
soft_tokens = tpumtjgetsofttokens()
|
soft_tokens = tpumtjgetsofttokens()
|
||||||
threading.Thread( # Compile backend code in background
|
threading.Thread( # Compile backend code in background
|
||||||
target=tpu_mtj_backend.infer,
|
target=tpu_mtj_backend.infer,
|
||||||
args=(np.uint32((23403, 727, 20185)),),
|
args=(np.tile(np.uint32((23403, 727, 20185)), (vars.numseqs, 1)),),
|
||||||
kwargs={
|
kwargs={
|
||||||
"soft_embeddings": vars.sp,
|
"soft_embeddings": vars.sp,
|
||||||
"soft_tokens": soft_tokens,
|
"soft_tokens": soft_tokens,
|
||||||
|
"use_callback": False,
|
||||||
"gen_len": 1,
|
"gen_len": 1,
|
||||||
"numseqs": vars.numseqs,
|
"numseqs": vars.numseqs,
|
||||||
|
"excluded_world_info": list(set() for _ in range(vars.numseqs)),
|
||||||
},
|
},
|
||||||
).start()
|
).start()
|
||||||
|
|
||||||
@ -2890,32 +2937,69 @@ def sendtocolab(txt, min, max):
|
|||||||
# Send text to TPU mesh transformer backend
|
# Send text to TPU mesh transformer backend
|
||||||
#==================================================================#
|
#==================================================================#
|
||||||
def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
||||||
|
vars.generated_tkns = 0
|
||||||
|
|
||||||
if(found_entries is None):
|
if(found_entries is None):
|
||||||
found_entries = set()
|
found_entries = set()
|
||||||
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
|
found_entries = tuple(found_entries.copy() for _ in range(vars.numseqs))
|
||||||
|
|
||||||
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
|
print("{0}Min:{1}, Max:{2}, Txt:{3}{4}".format(colors.YELLOW, minimum, maximum, tokenizer.decode(txt), colors.END))
|
||||||
|
|
||||||
|
vars._actions = vars.actions
|
||||||
|
vars._prompt = vars.prompt
|
||||||
|
if(vars.dynamicscan):
|
||||||
|
vars._actions = vars._actions.copy()
|
||||||
|
|
||||||
# Submit input text to generator
|
# Submit input text to generator
|
||||||
try:
|
try:
|
||||||
if(vars.dynamicscan):
|
context = np.tile(np.uint32(txt), (vars.numseqs, 1))
|
||||||
raise ValueError("Dynamic world info scanning is not supported by the TPU backend yet")
|
|
||||||
|
|
||||||
soft_tokens = tpumtjgetsofttokens()
|
soft_tokens = tpumtjgetsofttokens()
|
||||||
|
|
||||||
genout = tpool.execute(
|
global past
|
||||||
tpu_mtj_backend.infer,
|
past = np.empty((vars.numseqs, 0), dtype=np.uint32)
|
||||||
np.uint32(txt),
|
|
||||||
gen_len = maximum-minimum+1,
|
while(True):
|
||||||
temp=vars.temp,
|
genout, n_generated, regeneration_required, halt = tpool.execute(
|
||||||
top_p=vars.top_p,
|
tpu_mtj_backend.infer,
|
||||||
top_k=vars.top_k,
|
context,
|
||||||
tfs=vars.tfs,
|
gen_len = maximum-minimum+1,
|
||||||
numseqs=vars.numseqs,
|
temp=vars.temp,
|
||||||
repetition_penalty=vars.rep_pen,
|
top_p=vars.top_p,
|
||||||
soft_embeddings=vars.sp,
|
top_k=vars.top_k,
|
||||||
soft_tokens=soft_tokens,
|
tfs=vars.tfs,
|
||||||
)
|
numseqs=vars.numseqs,
|
||||||
|
repetition_penalty=vars.rep_pen,
|
||||||
|
soft_embeddings=vars.sp,
|
||||||
|
soft_tokens=soft_tokens,
|
||||||
|
excluded_world_info=found_entries,
|
||||||
|
)
|
||||||
|
|
||||||
|
past = np.pad(past, ((0, 0), (0, n_generated)))
|
||||||
|
for r in range(vars.numseqs):
|
||||||
|
for c in range(vars.lua_koboldbridge.generated_cols):
|
||||||
|
assert vars.lua_koboldbridge.generated[r+1][c+1] is not None
|
||||||
|
past[r, c] = vars.lua_koboldbridge.generated[r+1][c+1]
|
||||||
|
|
||||||
|
if(halt or not regeneration_required):
|
||||||
|
break
|
||||||
|
print("(regeneration triggered)")
|
||||||
|
|
||||||
|
encoded = []
|
||||||
|
for i in range(vars.numseqs):
|
||||||
|
txt = tokenizer.decode(past[i])
|
||||||
|
winfo, mem, anotetxt, _found_entries = calcsubmitbudgetheader(txt, force_use_txt=True)
|
||||||
|
found_entries[i].update(_found_entries)
|
||||||
|
txt, _, _ = calcsubmitbudget(len(vars._actions), winfo, mem, anotetxt, vars._actions, submission=txt)
|
||||||
|
encoded.append(np.array(txt, dtype=np.uint32))
|
||||||
|
max_length = len(max(encoded, key=len))
|
||||||
|
encoded = np.stack(tuple(np.pad(e, (max_length - len(e), 0), constant_values=tpu_mtj_backend.pad_token_id) for e in encoded))
|
||||||
|
context = np.concatenate(
|
||||||
|
(
|
||||||
|
encoded,
|
||||||
|
past,
|
||||||
|
),
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if(issubclass(type(e), lupa.LuaError)):
|
if(issubclass(type(e), lupa.LuaError)):
|
||||||
@ -2931,10 +3015,10 @@ def tpumtjgenerate(txt, minimum, maximum, found_entries=None):
|
|||||||
print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
|
print("{0}{1}{2}".format(colors.RED, traceback.format_exc().replace("\033", ""), colors.END), file=sys.stderr)
|
||||||
set_aibusy(0)
|
set_aibusy(0)
|
||||||
return
|
return
|
||||||
|
|
||||||
for i in range(vars.numseqs):
|
for i in range(vars.numseqs):
|
||||||
vars.lua_koboldbridge.generated[i+1] = vars.lua_state.table(*genout[i].tolist())
|
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(past[i])
|
||||||
vars.lua_koboldbridge.outputs[i+1] = tokenizer.decode(genout[i])
|
genout = past
|
||||||
|
|
||||||
execute_outmod()
|
execute_outmod()
|
||||||
if(vars.lua_koboldbridge.regeneration_required):
|
if(vars.lua_koboldbridge.regeneration_required):
|
||||||
@ -4005,7 +4089,7 @@ def spRequest(filename):
|
|||||||
-1,
|
-1,
|
||||||
tpu_mtj_backend.params["d_model"],
|
tpu_mtj_backend.params["d_model"],
|
||||||
)
|
)
|
||||||
vars.sp = np.float32(tensor)
|
vars.sp = tpu_mtj_backend.shard_xmap(np.float32(tensor))
|
||||||
else:
|
else:
|
||||||
vars.sp = torch.from_numpy(tensor)
|
vars.sp = torch.from_numpy(tensor)
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
|
||||||
import progressbar
|
import progressbar
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
@ -20,6 +20,13 @@ from mesh_transformer.transformer_shard import CausalTransformer, CausalTransfor
|
|||||||
params: Dict[str, Any] = {}
|
params: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def warper_callback(logits) -> np.array:
|
||||||
|
raise NotImplementedError("`tpu_mtj_backend.warper_callback()` needs to be defined")
|
||||||
|
|
||||||
|
def stopping_callback(generated, n_generated, excluded_world_info) -> Tuple[List[set], bool, bool]:
|
||||||
|
raise NotImplementedError("`tpu_mtj_backend.stopping_callback()` needs to be defined")
|
||||||
|
|
||||||
|
|
||||||
def show_spinner():
|
def show_spinner():
|
||||||
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')])
|
bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength, widgets=[progressbar.Timer(), ' ', progressbar.BouncingBar(left='[', right=']', marker='█')])
|
||||||
i = 0
|
i = 0
|
||||||
@ -28,6 +35,31 @@ def show_spinner():
|
|||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
|
|
||||||
|
__F = TypeVar("__F", bound=Callable)
|
||||||
|
__T = TypeVar("__T")
|
||||||
|
|
||||||
|
def __move_xmap(f: __F, out_axis: str) -> __F:
|
||||||
|
return maps.xmap(
|
||||||
|
f,
|
||||||
|
in_axes=(["shard", ...], ["batch", ...]),
|
||||||
|
out_axes=[out_axis, ...],
|
||||||
|
axis_resources={'shard': 'mp', 'batch': 'dp'},
|
||||||
|
)
|
||||||
|
|
||||||
|
def __shard_xmap(batch_dim=1):
|
||||||
|
xmap = __move_xmap(lambda s, b: s, "shard")
|
||||||
|
def inner(x: __T) -> __T:
|
||||||
|
return xmap(x, np.empty(batch_dim))
|
||||||
|
return inner
|
||||||
|
|
||||||
|
def __batch_xmap(shard_dim=1):
|
||||||
|
xmap = __move_xmap(lambda s, b: b, "batch")
|
||||||
|
def inner(x: __T) -> __T:
|
||||||
|
return xmap(np.empty(shard_dim), x)
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
def apply_repetition_penalty(logits, tokens, repetition_penalty):
|
def apply_repetition_penalty(logits, tokens, repetition_penalty):
|
||||||
'''
|
'''
|
||||||
This gets called by generate_loop_fn to apply repetition penalty
|
This gets called by generate_loop_fn to apply repetition penalty
|
||||||
@ -38,19 +70,20 @@ def apply_repetition_penalty(logits, tokens, repetition_penalty):
|
|||||||
# logits array; e.g.
|
# logits array; e.g.
|
||||||
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
|
# if logits is [77, 5, 3, 98] and tokens is [0, 1, 2, 3, 2, 3, 1],
|
||||||
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
|
# then penalty_logits will be [77, 5, 3, 98, 3, 98, 5]
|
||||||
penalty_logits = jnp.take(logits, tokens)
|
penalty_logits = np.take(logits, tokens)
|
||||||
# Divide positive values by repetition_penalty and multiply negative
|
# Divide positive values by repetition_penalty and multiply negative
|
||||||
# values by repetition_penalty (the academic publication that described
|
# values by repetition_penalty (the academic publication that described
|
||||||
# this technique actually just only divided, but that would cause tokens
|
# this technique actually just only divided, but that would cause tokens
|
||||||
# with negative logits to become more likely, which is obviously wrong)
|
# with negative logits to become more likely, which is obviously wrong)
|
||||||
penalty_logits = jnp.where(
|
penalty_logits = np.where(
|
||||||
penalty_logits > 0,
|
penalty_logits > 0,
|
||||||
penalty_logits/repetition_penalty,
|
penalty_logits/repetition_penalty,
|
||||||
penalty_logits*repetition_penalty,
|
penalty_logits*repetition_penalty,
|
||||||
)
|
)
|
||||||
# Finally, put those penalized logit values back into their original
|
# Finally, put those penalized logit values back into their original
|
||||||
# positions in the logits array
|
# positions in the logits array
|
||||||
return logits.at[tokens].set(penalty_logits)
|
logits[tokens] = penalty_logits
|
||||||
|
return logits
|
||||||
|
|
||||||
def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
||||||
'''
|
'''
|
||||||
@ -66,15 +99,16 @@ def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
|||||||
# in the sorted logits array we want to remove and False for ones
|
# in the sorted logits array we want to remove and False for ones
|
||||||
# we want to keep, in this case the first top_k elements will be
|
# we want to keep, in this case the first top_k elements will be
|
||||||
# False and the rest will be True
|
# False and the rest will be True
|
||||||
sorted_indices_to_remove = jnp.arange(len(logits)) >= top_k
|
sorted_indices_to_remove = np.arange(len(logits)) >= top_k
|
||||||
# Unsort the logits array back to its original configuration and
|
# Unsort the logits array back to its original configuration and
|
||||||
# remove tokens we need to remove
|
# remove tokens we need to remove
|
||||||
_, indices_to_remove = jax.lax.sort_key_val(
|
_, indices_to_remove = jax.lax.sort_key_val(
|
||||||
jnp.argsort(-logits),
|
np.argsort(-logits),
|
||||||
sorted_indices_to_remove,
|
sorted_indices_to_remove,
|
||||||
)
|
)
|
||||||
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
return np.where(indices_to_remove, -np.inf, logits)
|
||||||
logits = jax.lax.cond(top_k > 0, top_k_filter, lambda x: x, logits)
|
if top_k > 0:
|
||||||
|
logits = top_k_filter(logits)
|
||||||
# Top-p (after sorting the remaining tokens again in descending order of
|
# Top-p (after sorting the remaining tokens again in descending order of
|
||||||
# logit, remove the ones that have cumulative softmax probability
|
# logit, remove the ones that have cumulative softmax probability
|
||||||
# greater than p)
|
# greater than p)
|
||||||
@ -83,109 +117,167 @@ def kobold_sample(key, logits, top_p=0.9, temp=0.5, top_k=0, tfs=1.0):
|
|||||||
# with e (Euler's number) to the power of that element, and divide
|
# with e (Euler's number) to the power of that element, and divide
|
||||||
# each element of the new array by the sum of the elements in the
|
# each element of the new array by the sum of the elements in the
|
||||||
# new array
|
# new array
|
||||||
sorted_logits = -jnp.sort(-logits)
|
sorted_logits = -np.sort(-logits)
|
||||||
probabilities = jax.nn.softmax(sorted_logits)
|
probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
|
||||||
# Calculate cumulative_probabilities as the prefix-sum array of
|
# Calculate cumulative_probabilities as the prefix-sum array of
|
||||||
# probabilities
|
# probabilities
|
||||||
cumulative_probabilities = jnp.cumsum(probabilities, axis=-1)
|
cumulative_probabilities = np.cumsum(probabilities, axis=-1)
|
||||||
# We want to remove tokens with cumulative probability higher
|
# We want to remove tokens with cumulative probability higher
|
||||||
# than top_p
|
# than top_p
|
||||||
sorted_indices_to_remove = cumulative_probabilities > top_p
|
sorted_indices_to_remove = cumulative_probabilities > top_p
|
||||||
# Don't ever remove the token with the highest logit, even if
|
# Don't ever remove the token with the highest logit, even if
|
||||||
# the probability is higher than top_p
|
# the probability is higher than top_p
|
||||||
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
|
sorted_indices_to_remove[0] = False
|
||||||
# Unsort and remove
|
# Unsort and remove
|
||||||
_, indices_to_remove = jax.lax.sort_key_val(
|
_, indices_to_remove = jax.lax.sort_key_val(
|
||||||
jnp.argsort(-logits),
|
np.argsort(-logits),
|
||||||
sorted_indices_to_remove,
|
sorted_indices_to_remove,
|
||||||
)
|
)
|
||||||
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
return np.where(indices_to_remove, -np.inf, logits)
|
||||||
logits = jax.lax.cond(top_p < 1.0, top_p_filter, lambda x: x, logits)
|
if top_p < 1.0:
|
||||||
|
logits = top_p_filter(logits)
|
||||||
# Tail free sampling (basically top-p a second time on remaining tokens
|
# Tail free sampling (basically top-p a second time on remaining tokens
|
||||||
# except it's the "cumulative normalized absolute second finite
|
# except it's the "cumulative normalized absolute second finite
|
||||||
# differences of the softmax probabilities" instead of just the
|
# differences of the softmax probabilities" instead of just the
|
||||||
# cumulative softmax probabilities)
|
# cumulative softmax probabilities)
|
||||||
def tail_free_filter(logits):
|
def tail_free_filter(logits):
|
||||||
# Sort in descending order
|
# Sort in descending order
|
||||||
sorted_logits = -jnp.sort(-logits)
|
sorted_logits = -np.sort(-logits)
|
||||||
# Softmax again
|
# Softmax again
|
||||||
probabilities = jax.nn.softmax(sorted_logits)
|
probabilities = np.array(jax.nn.softmax(sorted_logits), copy=True)
|
||||||
# Calculate the second finite differences of that array (i.e.
|
# Calculate the second finite differences of that array (i.e.
|
||||||
# calculate the difference array and then calculate the difference
|
# calculate the difference array and then calculate the difference
|
||||||
# array of the difference array)
|
# array of the difference array)
|
||||||
d2 = jnp.diff(jnp.diff(probabilities))
|
d2 = np.diff(np.diff(probabilities))
|
||||||
# Get the absolute values of all those second finite differences
|
# Get the absolute values of all those second finite differences
|
||||||
d2 = jnp.abs(d2)
|
d2 = np.abs(d2)
|
||||||
# Normalize (all elements in the array are divided by the sum of the
|
# Normalize (all elements in the array are divided by the sum of the
|
||||||
# array's elements)
|
# array's elements)
|
||||||
d2 = d2 / d2.sum(axis=-1, keepdims=True)
|
d2 = d2 / d2.sum(axis=-1, keepdims=True)
|
||||||
# Get the prefix-sum array
|
# Get the prefix-sum array
|
||||||
cumulative_d2 = jnp.cumsum(d2, axis=-1)
|
cumulative_d2 = np.cumsum(d2, axis=-1)
|
||||||
# We will remove the tokens with a cumulative normalized absolute
|
# We will remove the tokens with a cumulative normalized absolute
|
||||||
# second finite difference larger than the TFS value
|
# second finite difference larger than the TFS value
|
||||||
sorted_indices_to_remove = cumulative_d2 > tfs
|
sorted_indices_to_remove = cumulative_d2 > tfs
|
||||||
# Don't remove the token with the highest logit
|
# Don't remove the token with the highest logit
|
||||||
sorted_indices_to_remove = sorted_indices_to_remove.at[0].set(False)
|
sorted_indices_to_remove[0] = False
|
||||||
# Since the d2 array has two fewer elements than the logits array,
|
# Since the d2 array has two fewer elements than the logits array,
|
||||||
# we'll add two extra Trues to the end
|
# we'll add two extra Trues to the end
|
||||||
sorted_indices_to_remove = jnp.pad(
|
sorted_indices_to_remove = np.pad(
|
||||||
sorted_indices_to_remove,
|
sorted_indices_to_remove,
|
||||||
(0, 2),
|
(0, 2),
|
||||||
constant_values=True,
|
constant_values=True,
|
||||||
)
|
)
|
||||||
# Unsort and remove
|
# Unsort and remove
|
||||||
_, indices_to_remove = jax.lax.sort_key_val(
|
_, indices_to_remove = jax.lax.sort_key_val(
|
||||||
jnp.argsort(-logits),
|
np.argsort(-logits),
|
||||||
sorted_indices_to_remove,
|
sorted_indices_to_remove,
|
||||||
)
|
)
|
||||||
return jnp.where(indices_to_remove, -jnp.inf, logits)
|
return np.where(indices_to_remove, -np.inf, logits)
|
||||||
logits = jax.lax.cond(tfs < 1.0, tail_free_filter, lambda x: x, logits)
|
if tfs < 1.0:
|
||||||
|
logits = tail_free_filter(logits)
|
||||||
# Temperature (just divide the logits by the temperature)
|
# Temperature (just divide the logits by the temperature)
|
||||||
def temp_filter(logits):
|
logits /= temp
|
||||||
return logits / temp
|
|
||||||
logits = jax.lax.cond(True, temp_filter, lambda x: x, logits)
|
|
||||||
# Finally, pick one token using the softmax thingy again (it gives
|
# Finally, pick one token using the softmax thingy again (it gives
|
||||||
# an array whose elements sum to 1 so it can be used nicely as a
|
# an array whose elements sum to 1 so it can be used nicely as a
|
||||||
# probability distribution)
|
# probability distribution)
|
||||||
return jax.random.categorical(key, logits, -1).astype(jnp.uint32)[jnp.newaxis]
|
return jax.random.categorical(key, logits, -1).astype(np.uint32)
|
||||||
|
|
||||||
pad_token_id = 50256
|
pad_token_id = 50256
|
||||||
|
|
||||||
|
def sample_func(data, key, numseqs_aux, badwords, repetition_penalty, sampler_options):
|
||||||
|
numseqs = numseqs_aux.shape[0]
|
||||||
|
gi = data[0][1]
|
||||||
|
def sample_loop_fn(carry):
|
||||||
|
generated, generated_index, logits, _ = carry[0][0]
|
||||||
|
sample_key = carry[1]
|
||||||
|
# Get the pseudo-random number generator key that will
|
||||||
|
# be used by kobold_sample to randomly pick a token
|
||||||
|
sample_key, new_key = jax.random.split(sample_key, num=2)
|
||||||
|
# Apply repetition penalty to all tokens that are
|
||||||
|
# currently inside the "generated" array
|
||||||
|
logits = apply_repetition_penalty(
|
||||||
|
logits,
|
||||||
|
generated,
|
||||||
|
repetition_penalty
|
||||||
|
)
|
||||||
|
# Remove any tokens in the badwords list by setting
|
||||||
|
# their logits to negative infinity which effectively
|
||||||
|
# makes their probabilities of being chosen zero
|
||||||
|
logits[badwords] = -np.inf
|
||||||
|
# Use the sampler (kobold_sample) to pick one token
|
||||||
|
# based on the logits array as a 0D uint32 array
|
||||||
|
# (higher logit means higher probability of being
|
||||||
|
# picked, non-linearly)
|
||||||
|
next_token = kobold_sample(
|
||||||
|
sample_key,
|
||||||
|
logits,
|
||||||
|
**sampler_options,
|
||||||
|
)
|
||||||
|
# Remember what token was picked
|
||||||
|
generated[generated_index] = next_token
|
||||||
|
generated_index += 1
|
||||||
|
# Re-pack the current sample_loop_fn's state so we can
|
||||||
|
# get back the same variables the next time
|
||||||
|
carry[0][0] = [generated, generated_index, logits, next_token]
|
||||||
|
carry[0].append(carry[0].pop(0))
|
||||||
|
return carry[0], new_key
|
||||||
|
# return jax.lax.while_loop(
|
||||||
|
# lambda carry: carry[0][0][1] == gi,
|
||||||
|
# sample_loop_fn,
|
||||||
|
# (data, key),
|
||||||
|
# )
|
||||||
|
carry = (data, key)
|
||||||
|
while carry[0][0][1] == gi:
|
||||||
|
carry = sample_loop_fn(carry)
|
||||||
|
return carry
|
||||||
|
|
||||||
class PenalizingCausalTransformer(CausalTransformer):
|
class PenalizingCausalTransformer(CausalTransformer):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
# Initialize
|
# Initialize
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
def generate(state, key, ctx, ctx_length, gen_length, numseqs_aux, sampler_options, soft_embeddings=None):
|
def generate_initial(state, key, ctx, ctx_length, numseqs_aux, soft_embeddings=None):
|
||||||
numseqs = numseqs_aux.shape[0]
|
numseqs = numseqs_aux.shape[0]
|
||||||
# These are the tokens that we don't want the AI to ever write
|
@hk.transform
|
||||||
self.badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
|
def generate_initial_inner(context, ctx_length):
|
||||||
def generate_sample(context, ctx_length):
|
|
||||||
# Give the initial context to the transformer
|
# Give the initial context to the transformer
|
||||||
transformer = CausalTransformerShard(config)
|
transformer = CausalTransformerShard(config)
|
||||||
def generate_initial_scan_fn(sequence_index, _):
|
def generate_initial_scan_fn(sequence_index, c):
|
||||||
_, initial_state = transformer.generate_initial(context, ctx_length, soft_embeddings=soft_embeddings)
|
_, initial_state = transformer.generate_initial(c, ctx_length, soft_embeddings=soft_embeddings)
|
||||||
# The "generated" array will contain the tokens from the
|
|
||||||
# context as well as the tokens picked by the sampler at
|
|
||||||
# each stage, padded with a bunch of 50256s, so we know
|
|
||||||
# which tokens have to be repetition penalized
|
|
||||||
generated = jnp.pad(context, (0, config["seq"]), constant_values=pad_token_id) # Let it start off with just the 2048 context tokens, plus some 50256s which will be eventually filled with sampler-chosen tokens
|
|
||||||
generated_index = config["seq"]
|
generated_index = config["seq"]
|
||||||
# Add that information to generate_loop_fn's starting state
|
# Add that information to generate_loop_fn's starting state
|
||||||
initial_state = (generated, generated_index, sequence_index) + initial_state
|
initial_state = (jnp.empty(config["n_vocab"], dtype=jnp.float32), generated_index, sequence_index) + initial_state
|
||||||
return sequence_index+1, initial_state
|
return sequence_index+1, initial_state
|
||||||
_, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, None, numseqs)
|
_, initial_states = jax.lax.scan(generate_initial_scan_fn, 0, context, numseqs)
|
||||||
sample_key = initial_states[-1][0]
|
sample_key = initial_states[-1][0]
|
||||||
initial_states = list(jax.tree_map(lambda x: x[i], initial_states[:-1]) for i in range(numseqs))
|
initial_states = list(list(jax.tree_map(lambda x: x[i], initial_states[:-1])) for i in range(numseqs))
|
||||||
# Get repetition penalty from the arguments
|
return initial_states, sample_key
|
||||||
repetition_penalty = sampler_options.pop('repetition_penalty', None)
|
return generate_initial_inner.apply(state["params"], key, ctx, ctx_length)
|
||||||
|
self.generate_initial_xmap = jax.experimental.maps.xmap(
|
||||||
|
fun=generate_initial,
|
||||||
|
in_axes=(
|
||||||
|
["shard", ...],
|
||||||
|
["batch", ...],
|
||||||
|
["batch", ...],
|
||||||
|
["batch", ...],
|
||||||
|
["batch", ...],
|
||||||
|
["shard", ...],
|
||||||
|
),
|
||||||
|
out_axes=["shard", "batch", ...],
|
||||||
|
axis_resources={'shard': 'mp', 'batch': 'dp'},
|
||||||
|
)
|
||||||
|
def generate_once(data, state, numseqs_aux, soft_embeddings=None):
|
||||||
|
numseqs = numseqs_aux.shape[0]
|
||||||
|
@hk.without_apply_rng
|
||||||
|
@hk.transform
|
||||||
|
def generate_once_inner():
|
||||||
|
gi = data[0][1]
|
||||||
|
# Give the initial context to the transformer
|
||||||
|
transformer = CausalTransformerShard(config)
|
||||||
# This is the main generation loop
|
# This is the main generation loop
|
||||||
def generate_loop_fn(carry):
|
def generate_loop_fn(carry):
|
||||||
# Unpack current generate_loop_fn state
|
# Unpack current generate_loop_fn state
|
||||||
generated, generated_index, sequence_index, next_token, decode_state = carry[0][0]
|
_, generated_index, sequence_index, next_token, decode_state = carry[0][0]
|
||||||
sample_key = carry[1]
|
|
||||||
# Get the pseudo-random number generator key that will
|
|
||||||
# be used by kobold_sample to randomly pick a token
|
|
||||||
sample_key, new_key = jax.random.split(sample_key)
|
|
||||||
# Give the context to the model and get the logits it
|
# Give the context to the model and get the logits it
|
||||||
# spits out
|
# spits out
|
||||||
# (a 2D array with 1 row and 50400 columns representing
|
# (a 2D array with 1 row and 50400 columns representing
|
||||||
@ -196,75 +288,78 @@ class PenalizingCausalTransformer(CausalTransformer):
|
|||||||
# Verify that logits does indeed have that many rows and
|
# Verify that logits does indeed have that many rows and
|
||||||
# columns (if you get an error here, pray for mercy)
|
# columns (if you get an error here, pray for mercy)
|
||||||
assert logits.shape == (1, config["n_vocab"])
|
assert logits.shape == (1, config["n_vocab"])
|
||||||
|
assert logits.dtype == jnp.float32
|
||||||
# Flatten it into a 1D array to make it easier to use
|
# Flatten it into a 1D array to make it easier to use
|
||||||
logits = logits[0]
|
logits = logits[0]
|
||||||
# Apply repetition penalty to all tokens that are
|
|
||||||
# currently inside the "generated" array
|
|
||||||
if repetition_penalty is not None:
|
|
||||||
logits = apply_repetition_penalty(
|
|
||||||
logits,
|
|
||||||
generated,
|
|
||||||
repetition_penalty
|
|
||||||
)
|
|
||||||
# Remove any tokens in the badwords list by setting
|
|
||||||
# their logits to negative infinity which effectively
|
|
||||||
# makes their probabilities of being chosen zero
|
|
||||||
logits = logits.at[self.badwords].set(-jnp.inf)
|
|
||||||
# Use the sampler (kobold_sample) to pick one token
|
|
||||||
# based on the logits array as a 1D array with 1 element
|
|
||||||
# (higher logit means higher probability of being
|
|
||||||
# picked, non-linearly)
|
|
||||||
next_token = kobold_sample(
|
|
||||||
sample_key,
|
|
||||||
logits,
|
|
||||||
**sampler_options,
|
|
||||||
)
|
|
||||||
# Remember what token was picked
|
|
||||||
generated = generated.at[generated_index].set(next_token[0])
|
|
||||||
generated_index += 1
|
|
||||||
# Re-pack the current generate_loop_fn's state so we can
|
# Re-pack the current generate_loop_fn's state so we can
|
||||||
# get back the same variables the next time
|
# get back the same variables the next time
|
||||||
carry[0][0] = (generated, generated_index, sequence_index, next_token, new_state)
|
generated_index += 1
|
||||||
|
carry[0][0] = [logits, generated_index, sequence_index, next_token, new_state]
|
||||||
carry[0].append(carry[0].pop(0))
|
carry[0].append(carry[0].pop(0))
|
||||||
return carry[0], new_key
|
return carry[0],
|
||||||
final_state = jax.lax.while_loop(
|
return jax.lax.while_loop(
|
||||||
lambda carry: carry[0][0][1] - config["seq"] < gen_length,
|
lambda carry: carry[0][0][1] == gi,
|
||||||
generate_loop_fn,
|
generate_loop_fn,
|
||||||
(initial_states, sample_key),
|
(data,),
|
||||||
)
|
)
|
||||||
return final_state
|
return generate_once_inner.apply(state["params"])
|
||||||
generate_fn = hk.transform(generate_sample).apply
|
self.generate_once_xmap = jax.experimental.maps.xmap(
|
||||||
return generate_fn(state["params"], key, ctx, ctx_length)
|
fun=generate_once,
|
||||||
self.generate_xmap = jax.experimental.maps.xmap(
|
|
||||||
fun=generate,
|
|
||||||
in_axes=(
|
in_axes=(
|
||||||
|
["shard", "batch", ...],
|
||||||
["shard", ...],
|
["shard", ...],
|
||||||
["batch", ...],
|
["batch", ...],
|
||||||
["batch", ...],
|
|
||||||
["batch", ...],
|
|
||||||
["batch", ...],
|
|
||||||
["batch", ...],
|
|
||||||
["batch", ...],
|
|
||||||
["shard", ...],
|
["shard", ...],
|
||||||
),
|
),
|
||||||
out_axes=["shard", "batch", ...],
|
out_axes=["shard", "batch", ...],
|
||||||
axis_resources={'shard': 'mp', 'batch': 'dp'},
|
axis_resources={'shard': 'mp', 'batch': 'dp'},
|
||||||
)
|
)
|
||||||
def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None):
|
def generate(self, ctx, ctx_length, gen_length, numseqs, sampler_options, return_logits=False, soft_embeddings=None, excluded_world_info=None, use_callback=True):
|
||||||
|
assert excluded_world_info is not None
|
||||||
assert not return_logits
|
assert not return_logits
|
||||||
|
assert gen_length.ndim == 1
|
||||||
|
assert soft_embeddings is not None
|
||||||
key = hk.PRNGSequence(random.randint(0, 2 ** 60))
|
key = hk.PRNGSequence(random.randint(0, 2 ** 60))
|
||||||
batch_size = ctx.shape[0]
|
batch_size = ctx.shape[0]
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
return self.generate_xmap(
|
_numseqs_aux = jnp.empty((batch_size, numseqs), dtype=np.uint32)
|
||||||
self.state,
|
numseqs_aux = batch_xmap(_numseqs_aux)
|
||||||
jnp.array(key.take(batch_size)),
|
sample_data = [
|
||||||
ctx,
|
[
|
||||||
np.array(ctx_length, dtype=np.uint32),
|
np.pad(ctx[0][i], (0, params["seq"]), constant_values=pad_token_id),
|
||||||
np.array(gen_length, dtype=np.uint32),
|
params["seq"],
|
||||||
np.empty((batch_size, numseqs), dtype=np.uint8),
|
None,
|
||||||
sampler_options,
|
np.empty((), dtype=np.uint32),
|
||||||
soft_embeddings,
|
]
|
||||||
)
|
for i in range(numseqs)
|
||||||
|
]
|
||||||
|
repetition_penalty = sampler_options.pop("repetition_penalty", 1.0)
|
||||||
|
n_generated = 0
|
||||||
|
regeneration_required = False
|
||||||
|
halt = False
|
||||||
|
generate_data, sample_key = self.generate_initial_xmap(self.state, jnp.array(key.take(batch_size)), ctx, ctx_length, numseqs_aux, soft_embeddings)
|
||||||
|
sample_key = np.asarray(sample_key[0, 0])
|
||||||
|
while True:
|
||||||
|
generate_data, = self.generate_once_xmap(generate_data, self.state, numseqs_aux, soft_embeddings)
|
||||||
|
for i in range(numseqs):
|
||||||
|
sample_data[i][2] = np.array(generate_data[i][0][0, 0], copy=True)
|
||||||
|
if use_callback:
|
||||||
|
logits = np.float32(tuple(d[2] for d in sample_data))
|
||||||
|
logits = warper_callback(logits)
|
||||||
|
for i in range(numseqs):
|
||||||
|
sample_data[i][2] = logits[i]
|
||||||
|
sample_data, sample_key = sample_func(sample_data, sample_key, _numseqs_aux, badwords, repetition_penalty, sampler_options)
|
||||||
|
n_generated += 1
|
||||||
|
for i in range(numseqs):
|
||||||
|
generate_data[i][3] = np.tile(sample_data[i][0][sample_data[i][1]-1][np.newaxis, np.newaxis], (params["cores_per_replica"], 1, 1))
|
||||||
|
if use_callback:
|
||||||
|
generated = np.uint32(tuple(d[0] for d in sample_data))
|
||||||
|
excluded_world_info, regeneration_required, halt = stopping_callback(generated, n_generated, excluded_world_info)
|
||||||
|
if regeneration_required or halt:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
return sample_data, n_generated, regeneration_required, halt
|
||||||
|
|
||||||
|
|
||||||
def infer(
|
def infer(
|
||||||
@ -278,35 +373,40 @@ def infer(
|
|||||||
gen_len=80,
|
gen_len=80,
|
||||||
soft_embeddings: Optional[np.array] = None,
|
soft_embeddings: Optional[np.array] = None,
|
||||||
soft_tokens: Optional[np.array] = None,
|
soft_tokens: Optional[np.array] = None,
|
||||||
) -> List[str]:
|
excluded_world_info = None,
|
||||||
|
use_callback=True,
|
||||||
|
) -> Tuple[List[np.array], int, bool, bool]:
|
||||||
|
assert excluded_world_info is not None
|
||||||
maps.thread_resources.env = thread_resources_env
|
maps.thread_resources.env = thread_resources_env
|
||||||
total_batch = 1
|
total_batch = 1
|
||||||
tokens = context
|
tokens = context
|
||||||
if(soft_tokens is not None):
|
if(soft_tokens is not None):
|
||||||
tokens = np.uint32(np.concatenate((soft_tokens, tokens)))
|
tokens = np.uint32(np.concatenate((np.tile(soft_tokens, (tokens.shape[0], 1)), tokens), axis=-1))
|
||||||
provided_ctx = tokens.shape[0]
|
provided_ctx = tokens.shape[-1]
|
||||||
pad_amount = seq - provided_ctx
|
pad_amount = seq - provided_ctx
|
||||||
padded_tokens = np.pad(tokens, ((pad_amount, 0),), constant_values=pad_token_id)
|
padded_tokens = np.pad(tokens, ((0, 0), (pad_amount, 0)), constant_values=pad_token_id)
|
||||||
batched_tokens = np.array([padded_tokens] * total_batch)
|
batched_tokens = np.array([padded_tokens] * total_batch)
|
||||||
samples = []
|
samples = []
|
||||||
batched_generator_params = {
|
generator_params = {
|
||||||
"temp": temp * np.ones(total_batch),
|
"temp": float(temp),
|
||||||
"top_p": top_p * np.ones(total_batch),
|
"top_p": float(top_p),
|
||||||
"tfs": tfs * np.ones(total_batch),
|
"tfs": float(tfs),
|
||||||
"repetition_penalty": repetition_penalty * np.ones(total_batch),
|
"repetition_penalty": float(repetition_penalty),
|
||||||
"top_k": np.full(total_batch, top_k, dtype=np.uint32)
|
"top_k": int(top_k),
|
||||||
}
|
}
|
||||||
output = network.generate(
|
output = network.generate(
|
||||||
batched_tokens,
|
batched_tokens,
|
||||||
np.ones(total_batch, dtype=np.uint32) * provided_ctx,
|
np.ones(total_batch, dtype=np.uint32) * provided_ctx,
|
||||||
np.ones(total_batch, dtype=np.uint32) * gen_len,
|
np.ones(total_batch, dtype=np.uint32) * gen_len,
|
||||||
numseqs,
|
numseqs,
|
||||||
batched_generator_params,
|
generator_params,
|
||||||
soft_embeddings=soft_embeddings,
|
soft_embeddings=soft_embeddings,
|
||||||
)[0]
|
excluded_world_info=excluded_world_info,
|
||||||
for o in output:
|
use_callback=use_callback,
|
||||||
samples.append(o[0][0, 0, params["seq"] : params["seq"] + gen_len])
|
)
|
||||||
return samples
|
for out in output[0]:
|
||||||
|
samples.append(out[0][params["seq"] : params["seq"] + gen_len])
|
||||||
|
return (samples,) + output[1:]
|
||||||
|
|
||||||
|
|
||||||
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None:
|
def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs) -> None:
|
||||||
@ -354,6 +454,14 @@ def load_model(path: str, driver_version="tpu_driver0.1_dev20210607", **kwargs)
|
|||||||
maps.thread_resources.env = thread_resources_env
|
maps.thread_resources.env = thread_resources_env
|
||||||
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
|
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
|
||||||
|
|
||||||
|
global shard_xmap, batch_xmap
|
||||||
|
shard_xmap = __shard_xmap()
|
||||||
|
batch_xmap = __batch_xmap(shard_dim=cores_per_replica)
|
||||||
|
|
||||||
|
global badwords
|
||||||
|
# These are the tokens that we don't want the AI to ever write
|
||||||
|
badwords = jnp.array([6880, 50256, 42496, 4613, 17414, 22039, 16410, 27, 29, 38430, 37922, 15913, 24618, 28725, 58, 47175, 36937, 26700, 12878, 16471, 37981, 5218, 29795, 13412, 45160, 3693, 49778, 4211, 20598, 36475, 33409, 44167, 32406, 29847, 29342, 42669, 685, 25787, 7359, 3784, 5320, 33994, 33490, 34516, 43734, 17635, 24293, 9959, 23785, 21737, 28401, 18161, 26358, 32509, 1279, 38155, 18189, 26894, 6927, 14610, 23834, 11037, 14631, 26933, 46904, 22330, 25915, 47934, 38214, 1875, 14692, 41832, 13163, 25970, 29565, 44926, 19841, 37250, 49029, 9609, 44438, 16791, 17816, 30109, 41888, 47527, 42924, 23984, 49074, 33717, 31161, 49082, 30138, 31175, 12240, 14804, 7131, 26076, 33250, 3556, 38381, 36338, 32756, 46581, 17912, 49146])
|
||||||
|
|
||||||
if not path.endswith("/"):
|
if not path.endswith("/"):
|
||||||
path += "/"
|
path += "/"
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user