RWKV Work

somebody
2022-09-24 20:46:52 -05:00
parent 42d4f49730
commit fa443487e3
2 changed files with 188 additions and 3 deletions

.gitignore

@@ -29,6 +29,11 @@ flask_session
accelerate-disk-cache
.ipynb_checkpoints
# Temporary until HF port
!models/RWKV-v4
models/RWKV-v4/20B_tokenizer.json
models/RWKV-v4/models
# Ignore PyCharm project files.
.idea
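The new rules un-ignore a models/RWKV-v4 checkout while keeping the tokenizer file and the checkpoint directory out of git. For orientation, the loader added later in this commit (rwkv_init) reads from models/RWKV4 (no dash) and imports the inference code as models.RWKV4.src.model_run, so the checkout it expects looks roughly like the check below. Illustrative only; the paths are copied from that function:

import os

# Paths as used by rwkv_init below; the RWKV_RNN implementation must also be
# importable as models.RWKV4.src.model_run.
assert os.path.exists("models/RWKV4/20B_tokenizer.json")  # tokenizer file
assert os.path.isdir("models/RWKV4/models")               # RWKV-4-Pile-<size>-*.pth checkpoints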

aiserver.py

@@ -140,6 +140,7 @@ model_menu = {
        ["Untuned Fairseq Dense", "fsdlist", "", True],
        ["Untuned Bloom", "bloomlist", "", True],
        ["Untuned XGLM", "xglmlist", "", True],
        ["Untuned RWKV-4", "rwkvlist", "", True],
        ["Untuned GPT2", "gpt2list", "", True],
        ["Online Services", "apilist", "", True],
        ["Read Only (No AI)", "ReadOnly", "", False]
@@ -244,6 +245,19 @@ model_menu = {
        ["XGLM 564M", "facebook/xglm-564M", "4GB", False],
        ["Return to Main Menu", "mainmenu", "", True],
    ],
    'rwkvlist': [
        ["RWKV-4 7B (GPU)", "RWKV-7B-GPU", "??GB", False],
        ["RWKV-4 7B (CPU)", "RWKV-7B-CPU", "??GB", False],
        ["RWKV-4 3B (GPU)", "RWKV-3B-GPU", "?GB", False],
        ["RWKV-4 3B (CPU)", "RWKV-3B-CPU", "?GB", False],
        ["RWKV-4 1.5B (GPU)", "RWKV-1B5-GPU", "9GB", False],
        ["RWKV-4 1.5B (CPU)", "RWKV-1B5-CPU", "6GB", False],
        ["RWKV-4 430M (GPU)", "RWKV-430M-GPU", "?GB", False],
        ["RWKV-4 430M (CPU)", "RWKV-430M-CPU", "?GB", False],
        ["RWKV-4 169M (GPU)", "RWKV-169M-GPU", "?GB", False],
        ["RWKV-4 169M (CPU)", "RWKV-169M-CPU", "?GB", False],
        ["Return to Main Menu", "mainmenu", "", True],
    ],
    'apilist': [
        ["GooseAI API (requires API key)", "GooseAI", "", False],
        ["OpenAI API (requires API key)", "OAI", "", False],
@@ -1464,6 +1478,8 @@ def get_model_info(model, directory=""):
                    print(":(")
                    pass
        key = True
    elif model.startswith("RWKV"):
        pass
    elif model == 'ReadOnly':
        pass
    elif not utils.HAS_ACCELERATE and not torch.cuda.is_available():
@@ -2351,7 +2367,7 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
    # If transformers model was selected & GPU available, ask to use CPU or GPU
-    if(koboldai_vars.model not in ["InferKit", "Colab", "API", "CLUSTER", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
+    if(koboldai_vars.model not in ["InferKit", "Colab", "API", "CLUSTER", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"] and not koboldai_vars.model.startswith("RWKV")):
        koboldai_vars.allowsp = True
    # Test for GPU support
@@ -2443,7 +2459,16 @@ def load_model(use_gpu=True, gpu_layers=None, disk_layers=None, initial_load=Fal
        koboldai_vars.noai = True

    # Start transformers and create pipeline
-    if(not koboldai_vars.use_colab_tpu and koboldai_vars.model not in ["InferKit", "Colab", "API", "CLUSTER", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
+    if koboldai_vars.model.startswith("RWKV"):
+        _, model_class, device = koboldai_vars.model.split("-")
+        model, tokenizer = rwkv_init(
+            model_class=model_class,
+            use_gpu=(device == "GPU")
+        )
+        global breakmodel
+        import breakmodel
+    elif (not koboldai_vars.use_colab_tpu and koboldai_vars.model not in ["InferKit", "Colab", "API", "CLUSTER", "OAI", "GooseAI" , "ReadOnly", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]):
        if(not koboldai_vars.noai):
            logger.init("Transformers", status='Starting')
            for m in ("GPTJModel", "XGLMModel"):
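The menu ids added above ("RWKV-7B-GPU", "RWKV-1B5-CPU", and so on) pack the parameter count and the target device into one string, and the split in this hunk unpacks them again. A minimal illustration, with the id string chosen as an example:

# Illustrative only: decomposing one of the new menu ids.
_, model_class, device = "RWKV-1B5-GPU".split("-")
use_gpu = (device == "GPU")
assert (model_class, use_gpu) == ("1B5", True)   # passed on to rwkv_init()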
@@ -5001,7 +5026,13 @@ def core_generate(text: list, min: int, max: int, found_entries: set):
            genout = result.encoded
            already_generated += len(genout[0]) - 1
-            assert already_generated <= koboldai_vars.genamt
+            try:
+                assert already_generated <= koboldai_vars.genamt
+            except AssertionError:
+                print("AlreadyGenerated", already_generated)
+                print("genamt", koboldai_vars.genamt)
+                raise
            if result.is_whole_generation:
                break
@@ -5165,6 +5196,16 @@ def raw_generate(
        return GenerationResult(
            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True
        )
    elif koboldai_vars.model.startswith("RWKV"):
        batch_encoded = rwkv_raw_generate(
            prompt_tokens=prompt_tokens,
            max_new=max_new,
            batch_count=batch_count,
            gen_settings=gen_settings
        )
        return GenerationResult(
            out_batches=batch_encoded, prompt=prompt_tokens, is_whole_generation=True, output_includes_prompt=True
        )

    # Torch HF
    batch_encoded = torch_raw_generate(
@@ -5555,6 +5596,145 @@ def api_raw_generate(
    genout = [obj["text"] for obj in js["results"]]
    return np.array([tokenizer.encode(x) for x in genout])


def rwkv_raw_generate(
    prompt_tokens: List[int],
    max_new: int,
    batch_count: int,
    gen_settings: GenerationSettings,
):
    import types

    model.clear()
    context = list(prompt_tokens)
    input_length = len(prompt_tokens)

    # TODO: Not needed every run? I think this is creating that huge wait time
    # between generations.
    init_state = types.SimpleNamespace()
    for i in range(input_length):
        x = context[:i+1]
        if i == input_length - 1:
            init_state.out = model.run(x)
        else:
            model.run(x)
    model.save(init_state)

    for ni, i in enumerate(range(input_length, input_length + max_new)):
        x = context[:i+1]
        x = x[-model.ctx_len:]

        if i == input_length:
            out = copy.deepcopy(init_state.out)
        else:
            out = model.run(x)

        # Don't generate EOS
        out[0] = -9999999

        char = tokenizer.sample_logits(
            out=out,
            x=x,
            ctx_len=model.ctx_len,
            temperature=gen_settings.temp,
            top_p=gen_settings.top_p,
        )
        char = char.item()
        context.append(char)

        if koboldai_vars.output_streaming:
            koboldai_vars.actions.stream_tokens([utils.decodenewlines(tokenizer.decode(char))])

        # HACK
        if ni > max_new:
            break

    return np.array([context])
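A note on the return value: rwkv_raw_generate produces a single sequence (batch_count is accepted but not used yet), and that sequence still starts with the prompt, which is why the caller in raw_generate sets output_includes_prompt=True. A toy illustration of what a consumer sees, pure NumPy with made-up token ids:

import numpy as np

# Illustrative only: pretend this came back from rwkv_raw_generate().
prompt_tokens = [11, 22, 33]
out_batches = np.array([[11, 22, 33, 44, 55]])   # prompt + two sampled tokens

# Because the prompt is included, the new tokens start at len(prompt_tokens).
new_tokens = out_batches[0][len(prompt_tokens):]
assert list(new_tokens) == [44, 55]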
@dataclass
class RWKVConfig:
    n_layer: int
    n_embed: int
    ctx_len: int
def rwkv_init(model_class: str, use_gpu: bool = False):
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
    os.environ["RWKV_FLOAT_MODE"] = "bf16"

    logger.info("[RWKV] RWKV support is in super-duper-uber-schmoober alpha and will ignore many options.")

    device = "cpu"
    if use_gpu:
        logger.warning("[RWKV] Using GPU. This may not work out of the box and may require significant setup.")
        device = "cuda"
    os.environ["RWKV_RUN_DEVICE"] = device

    TOKENIZER_PATH = "models/RWKV4/20B_tokenizer.json"
    MODEL_DIR = "models/RWKV4/models"

    model_files = os.listdir(MODEL_DIR)
    matching_models = [f for f in model_files if f.startswith(f"RWKV-4-Pile-{model_class}")]
    if not matching_models:
        raise RuntimeError(f"No models of class '{model_class}' found in '{MODEL_DIR}'. Did you rename the model?")
    model_path = os.path.join(MODEL_DIR, sorted(matching_models)[-1])

    model_config = {
        "169M": RWKVConfig(n_layer=12, n_embed=768, ctx_len=1024),
        "430M": RWKVConfig(n_layer=24, n_embed=1024, ctx_len=1024),
        "1B5": RWKVConfig(n_layer=24, n_embed=2048, ctx_len=1024),
        "3B": RWKVConfig(n_layer=32, n_embed=2560, ctx_len=1024),
        "7B": RWKVConfig(n_layer=32, n_embed=4096, ctx_len=1024),
    }.get(model_class)

    if not model_config:
        raise RuntimeError(f"No config for model '{model_class}' found!")

    if not os.path.exists(TOKENIZER_PATH):
        raise RuntimeError(f"Can't find tokenizer at '{TOKENIZER_PATH}'. Did you download it and put it at that location?")

    # Model stuff
    from models.RWKV4.src.model_run import RWKV_RNN
    from transformers import PreTrainedTokenizerFast
    from torch.nn import functional as F

    model = RWKV_RNN(
        model_path.split(".")[0],
        device,
        "RWKV",
        model_config.n_layer,
        model_config.n_embed,
        model_config.ctx_len,
    )

    tokenizer = PreTrainedTokenizerFast(tokenizer_file=TOKENIZER_PATH)
    # We'll just patch the tokenizer ourselves to make it easier
    def _sample_logits(self, out, x, ctx_len, temperature, top_p):
        last_char = int(x[-1])  # currently unused

        # Top-p (nucleus) filtering: zero out every token outside the smallest
        # set whose cumulative probability exceeds top_p.
        probs = F.softmax(torch.tensor(out), dim=-1)
        sorted_probs, s_index = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
        probs[probs < cutoff] = 0

        # Temperature: flatten (T > 1) or sharpen (T < 1) the surviving probabilities.
        if temperature != 1.0:
            probs = probs.pow(1.0 / temperature)

        # torch.multinomial renormalizes its weights, so no explicit rescale is needed.
        return torch.multinomial(probs, num_samples=1)[0]

    tokenizer.sample_logits = _sample_logits.__get__(tokenizer, AutoTokenizer)

    tokenizer._koboldai_header = []
    tokenizer.add_bos_token = False
    tokenizer.add_prefix_space = False

    logger.info("[RWKV] Loaded :^)")
    return model, tokenizer
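Taken together, a minimal sketch of how the two new functions are meant to be driven. The model class, prompt, and gen_settings object are assumptions for illustration; in aiserver itself, model and tokenizer are module-level globals bound during load_model, and rwkv_raw_generate reads them directly:

# Illustrative only: exercising rwkv_init and rwkv_raw_generate by hand.
model, tokenizer = rwkv_init(model_class="1B5", use_gpu=False)

prompt_tokens = tokenizer.encode("The quick brown fox")

out = rwkv_raw_generate(
    prompt_tokens=prompt_tokens,
    max_new=32,
    batch_count=1,                # currently unused inside the function
    gen_settings=gen_settings,    # assumed object exposing .temp and .top_p
)

# The returned batch still includes the prompt, so decode only the tail.
print(tokenizer.decode(out[0][len(prompt_tokens):]))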
#==================================================================#
# Send text to generator and deal with output