diff --git a/models/RWKV4/LICENSE.txt b/models/RWKV4/LICENSE.txt new file mode 100644 index 00000000..72cd2da2 --- /dev/null +++ b/models/RWKV4/LICENSE.txt @@ -0,0 +1,204 @@ +Code in this directory is taken from https://github.com/BlinkDL/RWKV-LM. +The license for this code is as follows: + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/models/RWKV4/cuda/wkv_cuda.cu b/models/RWKV4/cuda/wkv_cuda.cu new file mode 100644 index 00000000..6acd0f36 --- /dev/null +++ b/models/RWKV4/cuda/wkv_cuda.cu @@ -0,0 +1,125 @@ +#include +#include + +#define MIN_VALUE (-1e38) + +template +__global__ void kernel_forward(const int B, const int T, const int C, + const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, + F *__restrict__ const _y) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + F *__restrict__ const y = _y + _offset; + + F p = 0, q = 0, o = MIN_VALUE; + // p and q are running sums divided by exp(o) (to avoid overflows) + for (int i = 0; i < T; i++) { + const int ii = i * C; + + F no = max(o, u + k[ii]); + F A = exp(o - no); + F B = exp(u + k[ii] - no); + y[ii] = (A * p + B * v[ii]) / (A * q + B); + + no = max(w + o, k[ii]); + A = exp(w + o - no); + B = exp(k[ii] - no); + p = A * p + B * v[ii]; + q = A * q + B; + o = no; + } +} + +template +__global__ void kernel_backward(const int B, const int T, const int C, + const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy, + F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + const F *__restrict__ const gy = _gy + _offset; + + F *__restrict__ const gk = _gk + _offset; + F *__restrict__ const gv = _gv + _offset; + + F y[Tmax], z[Tmax], zexp[Tmax]; + + F gw = 0, gu = 0; + F p = 0, q = 0; + F dpdw = 0, dqdw = 0; + F o = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + F no = max(o, k[ii] + u); + F A = exp(o - no); + F B = exp(k[ii] + u - no); + + F num = A * p + B * v[ii]; + F iden = 1 / (A * q + B); + + y[i] = num * iden; + z[i] = iden; + zexp[i] = k[ii] + u - no; + + gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A; + gu += gy[ii] * (v[ii] - y[i]) * B * iden; + + no = max(w + o, k[ii]); + A = exp(w + o - no); + B = exp(k[ii] - no); + dpdw = A * (p + dpdw); + dqdw = A * (q + dqdw); + p = A * p + B * v[ii]; + q = A * q + B; + o = no; + } + + F gp = 0, gq = 0; + o = MIN_VALUE; + for (int i = T - 1; i >= 0; i--) { + const int ii = i * C; + F A = gy[ii] * z[i] * exp(zexp[i]); + F B = exp(k[ii] + o); + gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq); + gv[ii] = A + B * gp; + + F no = max(w + o, zexp[i] - k[ii] - u); + A = exp(w + o - no); + B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no); + gp = A * gp + B; + gq = A * gq - B * y[i]; + o = no; + } + + // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass + const int _offsetBC = _b * C + _c; + _gw[_offsetBC] += gw * _w[_c]; + _gu[_offsetBC] += gu; +} + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward<<>>(B, T, C, w, u, k, v, y); +} + +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_backward<<>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv); +} diff --git a/models/RWKV4/cuda/wkv_op.cpp b/models/RWKV4/cuda/wkv_op.cpp new file mode 100644 index 00000000..efe56d8d --- /dev/null +++ b/models/RWKV4/cuda/wkv_op.cpp @@ -0,0 +1,21 @@ +#include + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv); + +void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { + cuda_forward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { + cuda_backward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv forward"); + m.def("backward", &backward, "wkv backward"); +} + +TORCH_LIBRARY(wkv, m) { + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/models/RWKV4/src/__pycache__/model_run.cpython-37.pyc b/models/RWKV4/src/__pycache__/model_run.cpython-37.pyc new file mode 100644 index 00000000..4e1e2973 Binary files /dev/null and b/models/RWKV4/src/__pycache__/model_run.cpython-37.pyc differ diff --git a/models/RWKV4/src/__pycache__/model_run.cpython-38.pyc b/models/RWKV4/src/__pycache__/model_run.cpython-38.pyc new file mode 100644 index 00000000..c025a550 Binary files /dev/null and b/models/RWKV4/src/__pycache__/model_run.cpython-38.pyc differ diff --git a/models/RWKV4/src/__pycache__/model_run.cpython-39.pyc b/models/RWKV4/src/__pycache__/model_run.cpython-39.pyc new file mode 100644 index 00000000..8b15dc60 Binary files /dev/null and b/models/RWKV4/src/__pycache__/model_run.cpython-39.pyc differ diff --git a/models/RWKV4/src/__pycache__/utils.cpython-38.pyc b/models/RWKV4/src/__pycache__/utils.cpython-38.pyc new file mode 100644 index 00000000..c2c9710f Binary files /dev/null and b/models/RWKV4/src/__pycache__/utils.cpython-38.pyc differ diff --git a/models/RWKV4/src/__pycache__/utils.cpython-39.pyc b/models/RWKV4/src/__pycache__/utils.cpython-39.pyc new file mode 100644 index 00000000..9e250a33 Binary files /dev/null and b/models/RWKV4/src/__pycache__/utils.cpython-39.pyc differ diff --git a/models/RWKV4/src/model.py b/models/RWKV4/src/model.py new file mode 100644 index 00000000..085a0bcb --- /dev/null +++ b/models/RWKV4/src/model.py @@ -0,0 +1,416 @@ +# File from RWKV-v4 Repo - Small changes made for compatibility + +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import math, os +import numpy as np +import logging +import torch +import torch.nn as nn +from torch.nn import functional as F +try: + from deepspeed.ops.adam import FusedAdam +except: + pass # some poor windows users cant install deepspeed + +logger = logging.getLogger(__name__) + +RWKV_HEAD_QK_DIM = 0 +print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n') + +class L2Wrap(torch.autograd.Function): + @staticmethod + def forward(ctx, loss, y): + ctx.save_for_backward(y) + return loss + @staticmethod + def backward(ctx, grad_output): + y = ctx.saved_tensors[0] + # to encourage the logits to be close to 0 + factor = 1e-4 / (y.shape[0] * y.shape[1]) + maxx, ids = torch.max(y, -1, keepdim=True) + gy = torch.zeros_like(y) + gy.scatter_(-1, ids, maxx * factor) + return (grad_output, gy) + +######################################################################################################## +# CUDA Kernel +######################################################################################################## + +T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!] +# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice + +from torch.utils.cpp_extension import load +wkv_cuda = load(name="wkv", sources=["models/RWKV4/cuda/wkv_op.cpp", "models/RWKV4/cuda/wkv_cuda.cu"], + verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}']) + +class WKV(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, w, u, k, v): + ctx.B = B + ctx.T = T + ctx.C = C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + if '32' in os.environ['RWKV_FLOAT_MODE']: + w = -torch.exp(w.contiguous()) + u = u.contiguous() + k = k.contiguous() + v = v.contiguous() + else: + w = -torch.exp(w.float().contiguous()) + u = u.float().contiguous() + k = k.float().contiguous() + v = v.float().contiguous() + ctx.save_for_backward(w, u, k, v) + y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format) + wkv_cuda.forward(B, T, C, w, u, k, v, y) + if '32' in os.environ['RWKV_FLOAT_MODE']: + return y + elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': + return y.half() + elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': + return y.bfloat16() + + @staticmethod + def backward(ctx, gy): + B = ctx.B + T = ctx.T + C = ctx.C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + w, u, k, v = ctx.saved_tensors + gw = torch.zeros((B, C), device='cuda').contiguous() + gu = torch.zeros((B, C), device='cuda').contiguous() + gk = torch.zeros((B, T, C), device='cuda').contiguous() + gv = torch.zeros((B, T, C), device='cuda').contiguous() + if '32' in os.environ['RWKV_FLOAT_MODE']: + wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv) + else: + wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv) + gw = torch.sum(gw, dim=0) + gu = torch.sum(gu, dim=0) + if '32' in os.environ['RWKV_FLOAT_MODE']: + return (None, None, None, gw, gu, gk, gv) + elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': + return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) + elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': + return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16()) + +def RUN_CUDA(B, T, C, w, u, k, v): + return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda()) + +######################################################################################################## +# RWKV: RWKV Time-mix + RWKV Channel-mix +######################################################################################################## + +def RWKV_Init(model, args): # fancy initialization of all lin & emb layer in the model + print("\n[--> first run, init model params (very slow for large models) <--]") + print("[so you shall only do it for 1 single GPU and save the checkpt and load it when using multiple GPU]\n") + + for mm in model.modules(): + if "RecursiveScriptModule" in str(type(mm)): + if mm.original_name not in ["Linear"]: + continue + ww = None + for name, param in mm.named_parameters(): + if name == "weight": + ww = param + else: + m = mm + if not isinstance(m, (nn.Linear, nn.Embedding)): + continue + ww = m.weight + with torch.no_grad(): + name = "[unknown weight]" + for name, parameter in model.named_parameters(): # find the name of the weight + if id(ww) == id(parameter): + break + + shape = ww.shape + gain = 1.0 + scale = 1.0 # extra scale for gain + + if isinstance(m, nn.Embedding): + gain = math.sqrt(max(shape[0], shape[1])) + if shape[0] == args.vocab_size and shape[1] == args.n_embd: # token emb? + scale = 1e-4 + else: + scale = 0 + + if isinstance(m, nn.Linear): + if shape[0] > shape[1]: + gain = math.sqrt(shape[0] / shape[1]) + if shape[0] == args.vocab_size and shape[1] == args.n_embd: # final projection? + scale = 0.5 + + if hasattr(m, "scale_init"): + scale = m.scale_init + + # print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {name}") + + gain *= scale + if scale == -999: + nn.init.eye_(ww) + elif gain == 0: + # zero init is great for some RWKV matrices + nn.init.zeros_(ww) + elif gain > 0: + nn.init.orthogonal_(ww, gain=gain) + else: + nn.init.normal_(ww, mean=0.0, std=-scale) + + +class RWKV_TimeMix(torch.jit.ScriptModule): + def __init__(self, config, layer_id): + super().__init__() + self.layer_id = layer_id + self.ctx_len = config.ctx_len + self.n_embd = config.n_embd + + attn_sz = config.n_embd + + with torch.no_grad(): # fancy init + ratio_0_to_1 = (layer_id / (config.n_layer - 1)) # 0 to 1 + ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0 + + # fancy time_decay + decay_speed = torch.ones(attn_sz) + for h in range(attn_sz): + decay_speed[h] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1) + self.time_decay = nn.Parameter(decay_speed) + # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy()) + + # fancy time_first + zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5) + self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag) + + # fancy time_mix + x = torch.ones(1, 1, config.n_embd) + for i in range(config.n_embd): + x[0, 0, i] = i / config.n_embd + self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) + self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0)) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + self.key = nn.Linear(config.n_embd, attn_sz, bias=False) + self.value = nn.Linear(config.n_embd, attn_sz, bias=False) + self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False) + + self.output = nn.Linear(attn_sz, config.n_embd, bias=False) + + self.key.scale_init = 0 + self.receptance.scale_init = 0 + self.output.scale_init = 0 + + @torch.jit.script_method + def jit_func(self, x): + + # Mix x with the previous timestep to produce xk, xv, xr + xx = self.time_shift(x) + xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) + xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) + + # Use xk, xv, xr to produce k, v, r + k = self.key(xk) + v = self.value(xv) + r = self.receptance(xr) + sr = torch.sigmoid(r) + + return sr, k, v + + def forward(self, x): + B, T, C = x.size() # x = (Batch,Time,Channel) + + sr, k, v = self.jit_func(x) + + rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v) + rwkv = self.output(rwkv) + return rwkv + + +class RWKV_ChannelMix(torch.jit.ScriptModule): + def __init__(self, config, layer_id): + super().__init__() + self.layer_id = layer_id + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + with torch.no_grad(): # fancy init of time_mix + ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0 + + x = torch.ones(1, 1, config.n_embd) + for i in range(config.n_embd): + x[0, 0, i] = i / config.n_embd + + self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) + self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) + + hidden_sz = 4 * config.n_embd + self.key = nn.Linear(config.n_embd, hidden_sz, bias=False) + self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False) + self.value = nn.Linear(hidden_sz, config.n_embd, bias=False) + + self.value.scale_init = 0 + self.receptance.scale_init = 0 + + @torch.jit.script_method + def forward(self, x): + xx = self.time_shift(x) + xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) + + k = self.key(xk) + k = torch.square(torch.relu(k)) + kv = self.value(k) + + rkv = torch.sigmoid(self.receptance(xr)) * kv + return rkv + +######################################################################################################## +# The GPT Model with our blocks +######################################################################################################## + + +class GPTConfig: + def __init__(self, vocab_size, ctx_len, **kwargs): + self.vocab_size = vocab_size + self.ctx_len = ctx_len + for k, v in kwargs.items(): + setattr(self, k, v) + + +class Block(nn.Module): + def __init__(self, config, layer_id): + super().__init__() + self.config = config + self.layer_id = layer_id + + self.ln1 = nn.LayerNorm(config.n_embd) + self.ln2 = nn.LayerNorm(config.n_embd) + + if self.layer_id == 0: + self.ln0 = nn.LayerNorm(config.n_embd) + + if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre': + self.ffnPre = RWKV_ChannelMix(config, 0) + else: + self.att = RWKV_TimeMix(config, layer_id) + + self.ffn = RWKV_ChannelMix(config, layer_id) + + def forward(self, x): + if self.layer_id == 0: + x = self.ln0(x) + if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre': + x = x + self.ffnPre(self.ln1(x)) # better in some cases + else: + x = x + self.att(self.ln1(x)) + x = x + self.ffn(self.ln2(x)) + return x + + +class GPT(nn.Module): + def __init__(self, config): + super().__init__() + self.step = 0 + self.config = config + + self.emb = nn.Embedding(config.vocab_size, config.n_embd) + + self.blocks = nn.Sequential(*[Block(config, i) + for i in range(config.n_layer)]) + + self.ln_out = nn.LayerNorm(config.n_embd) + self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + if RWKV_HEAD_QK_DIM > 0: + self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False) + self.head_q.scale_init = 0 + self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False) + self.head_k.scale_init = 0.1 + self.register_buffer("copy_mask", torch.tril( + torch.ones(config.ctx_len, config.ctx_len))) + + self.ctx_len = config.ctx_len + + try: + if os.environ['RWKV_LOAD_MODEL'] == str(False): + RWKV_Init(self, config) + except: + pass + + logger.info("number of parameters: %e", sum(p.numel() + for p in self.parameters())) + + def get_ctx_len(self): + return self.ctx_len + + def _init_weights(self, module): + if isinstance(module, (nn.Linear)): + module.weight.data.normal_(mean=0.0, std=0.01) + if isinstance(module, (nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=1e-5) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def configure_optimizers(self, train_config): + no_decay = set() + + for mn, m in self.named_modules(): # here we disable weight_decay + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + no_decay.add(fpn) + + param_dict = {pn: p for pn, p in self.named_parameters()} + optim_groups = [ + {"params": [param_dict[pn] + for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + + try: + optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) + except: + print('\n\nDeepSpeed not found. Using torch optimizer instead (probably slower)\n\n') + optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps) + + return optimizer + + def forward(self, idx, targets=None): + idx = idx.to(self.emb.weight.device) + + self.step += 1 + B, T = idx.size() + assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len." + + x = self.emb(idx) + x = self.blocks(x) + x = self.ln_out(x) + + if RWKV_HEAD_QK_DIM > 0: + q = self.head_q(x)[:, :T, :] + k = self.head_k(x)[:, :T, :] + c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM) + c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) + + if '32' in os.environ['RWKV_FLOAT_MODE']: + c = c @ F.one_hot(idx, num_classes=self.config.vocab_size) + elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': + c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half() + elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': + c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).bfloat16() + + x = self.head(x) + c + else: + x = self.head(x) + + loss = None + if targets is not None: + loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1)) + + return L2Wrap.apply(loss, x) diff --git a/models/RWKV4/src/model_run.py b/models/RWKV4/src/model_run.py new file mode 100644 index 00000000..0373dda5 --- /dev/null +++ b/models/RWKV4/src/model_run.py @@ -0,0 +1,392 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import types +import copy +import torch +import math, os +from torch.nn import functional as F +import torch.nn as nn + +RWKV_HEAD_QK_DIM = 0 +print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n') + +DEBUG_TIME = False # True False - show trained time-coeffs + +######################################################################################################## +# CUDA Kernel +######################################################################################################## + +if os.environ['RWKV_RUN_DEVICE'] == 'cuda': + T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!] + # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice + + from torch.utils.cpp_extension import load + wkv_cuda = load(name="wkv", sources=["models/RWKV4/cuda/wkv_op.cpp", "models/RWKV4/cuda/wkv_cuda.cu"], + verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}']) + + class WKV(torch.autograd.Function): + @staticmethod + def forward(ctx, B, T, C, w, u, k, v): + ctx.B = B + ctx.T = T + ctx.C = C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + if '32' in os.environ['RWKV_FLOAT_MODE']: + w = -torch.exp(w.contiguous()) + u = u.contiguous() + k = k.contiguous() + v = v.contiguous() + else: + w = -torch.exp(w.float().contiguous()) + u = u.float().contiguous() + k = k.float().contiguous() + v = v.float().contiguous() + ctx.save_for_backward(w, u, k, v) + y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format) + wkv_cuda.forward(B, T, C, w, u, k, v, y) + if '32' in os.environ['RWKV_FLOAT_MODE']: + return y + elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': + return y.half() + elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': + return y.bfloat16() + + @staticmethod + def backward(ctx, gy): + B = ctx.B + T = ctx.T + C = ctx.C + assert T <= T_MAX + assert B * C % min(C, 1024) == 0 + w, u, k, v = ctx.saved_tensors + gw = torch.zeros((B, C), device='cuda').contiguous() + gu = torch.zeros((B, C), device='cuda').contiguous() + gk = torch.zeros((B, T, C), device='cuda').contiguous() + gv = torch.zeros((B, T, C), device='cuda').contiguous() + if '32' in os.environ['RWKV_FLOAT_MODE']: + wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv) + else: + wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv) + gw = torch.sum(gw, dim=0) + gu = torch.sum(gu, dim=0) + if '32' in os.environ['RWKV_FLOAT_MODE']: + return (None, None, None, gw, gu, gk, gv) + elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': + return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) + elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': + return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16()) + + def RUN_CUDA(B, T, C, w, u, k, v): + return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda()) + +############################################################################################################ + +RWKV_CFG = types.SimpleNamespace() + +class RWKV_ChannelMix(nn.Module): + def __init__(self, layer_id): + super().__init__() + self.layer_id = layer_id + + self.time_shift = nn.ZeroPad2d((0,0,1,-1)) + self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd)) + self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd)) + + hidden_sz = 4 * RWKV_CFG.n_embd + self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False) + self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) + self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False) + + def forward(self, x): + xx = self.time_shift(x) + xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) + + k = self.key(xk) + k = torch.square(torch.relu(k)) + kv = self.value(k) + + rkv = torch.sigmoid(self.receptance(xr)) * kv + return rkv + +class RWKV_TimeMix(nn.Module): + def __init__(self, layer_id): + super().__init__() + self.layer_id = layer_id + self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd)) + self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd) * math.log(0.3)) + + self.time_shift = nn.ZeroPad2d((0,0,1,-1)) + self.time_mix_k = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd)) + self.time_mix_v = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd)) + self.time_mix_r = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd)) + + self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) + self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) + self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) + + self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) + + def forward(self, x): + B, T, C = x.size() + + xx = self.time_shift(x) + xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) + xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) + xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) + + k = self.key(xk) + v = self.value(xv) + r = self.receptance(xr) + + rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v) + + rwkv = self.output(rwkv) + return rwkv + +class Block(nn.Module): + def __init__(self, layer_id): + super().__init__() + self.layer_id = layer_id + + self.ln1 = nn.LayerNorm(RWKV_CFG.n_embd) + self.ln2 = nn.LayerNorm(RWKV_CFG.n_embd) + if self.layer_id == 0: + self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd) + + if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre': + self.ffnPre = RWKV_ChannelMix(layer_id+1000) + else: + self.att = RWKV_TimeMix(layer_id) + + self.ffn = RWKV_ChannelMix(layer_id) + + def forward(self, x): + if self.layer_id == 0: + x = self.ln0(x) + if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre': + x = x + self.ffnPre(self.ln1(x)) + else: + x = x + self.att(self.ln1(x)) + x = x + self.ffn(self.ln2(x)) + return x + +class RWKV_GPT(nn.Module): + def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len): + global RWKV_CFG + super().__init__() + + RWKV_CFG.RUN_DEVICE = RUN_DEVICE + RWKV_CFG.model_type = model_type + RWKV_CFG.vocab_size = vocab_size + RWKV_CFG.n_layer = n_layer + RWKV_CFG.n_embd = n_embd + RWKV_CFG.ctx_len = ctx_len + + print('\nloading RWKV-GPT', MODEL_NAME) + + self.emb = nn.Embedding(vocab_size, n_embd) + + self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)]) + + self.ln_out = nn.LayerNorm(n_embd) + self.head = nn.Linear(n_embd, vocab_size, bias=False) + + if RWKV_HEAD_QK_DIM > 0: + self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False) + self.head_q.scale_init = 0 + self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False) + self.head_k.scale_init = 0.1 + self.register_buffer("copy_mask", torch.tril( + torch.ones(ctx_len, ctx_len))) + + self.ctx_len = ctx_len + self.eval() + self.load_state_dict(torch.load(MODEL_NAME + '.pth')) + self.eval() + + def forward(self, idx): + B, T = idx.size() + assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len." + + x = self.emb(idx) + x = self.blocks(x) + x = self.ln_out(x) + + if RWKV_HEAD_QK_DIM > 0: + q = self.head_q(x)[:, :T, :] + k = self.head_k(x)[:, :T, :] + c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM) + c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) + + if '32' in os.environ['RWKV_FLOAT_MODE']: + c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size) + elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': + c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).half() + elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': + c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).bfloat16() + + x = self.head(x) + c + else: + x = self.head(x) + + return x + +############################################################################################################ + +class RWKV_RNN(): # this is running in FP32 at this moment + def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len): + self.RUN_DEVICE = RUN_DEVICE + self.model_type = model_type + self.n_layer = n_layer + self.n_embd = n_embd + self.ctx_len = ctx_len + + self.w = types.SimpleNamespace() + + w = torch.load(MODEL_NAME + '.pth', + map_location=torch.device(RUN_DEVICE)) + for x in w.keys(): + w[x] = w[x].float() + if '.time_' in x: + w[x] = w[x].squeeze() + if '.time_decay' in x: + w[x] = -torch.exp(w[x]) + if DEBUG_TIME and '.time_' in x: + print(x, w[x].squeeze().cpu().numpy()) + + xx = x.split('.') + here = self.w + for i in range(len(xx)): + if xx[i].isdigit(): + ii = int(xx[i]) + if ii not in here: + here[ii] = types.SimpleNamespace() + here = here[ii] + else: + if i == len(xx) - 1: + setattr(here, xx[i], w[x]) + elif not hasattr(here, xx[i]): + if xx[i+1].isdigit(): + setattr(here, xx[i], {}) + else: + setattr(here, xx[i], types.SimpleNamespace()) + here = getattr(here, xx[i]) + + self.clear() + + def clear(self): + self.xx = {} + self.aa = {} + self.bb = {} + self.pp = {} + self.hk = None + + def save(self, target): + target.xx = copy.deepcopy(self.xx) + target.aa = copy.deepcopy(self.aa) + target.bb = copy.deepcopy(self.bb) + target.pp = copy.deepcopy(self.pp) + target.hk = copy.deepcopy(self.hk) + + def load(self, target): + self.xx = copy.deepcopy(target.xx) + self.aa = copy.deepcopy(target.aa) + self.bb = copy.deepcopy(target.bb) + self.pp = copy.deepcopy(target.pp) + self.hk = copy.deepcopy(target.hk) + + def LN(self, xx, w): + return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias) + + def FF(self, xx, w, name): + if name not in self.xx: + self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) + xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k) + xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r) + self.xx[name] = xx + + r = torch.sigmoid(w.receptance.weight @ xr) + k = torch.square(torch.relu(w.key.weight @ xk)) + kv = w.value.weight @ k + + return r * kv + + def SA(self, xx, w, name): + if name not in self.xx: + self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) + self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) + self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) + self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30 + + xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k) + xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v) + xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r) + self.xx[name] = xx + + r = torch.sigmoid(w.receptance.weight @ xr) + + k = w.key.weight @ xk + v = w.value.weight @ xv + + pp = self.pp[name] + aa = self.aa[name] + bb = self.bb[name] + ww = w.time_first + k + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + a = e1 * aa + e2 * v + b = e1 * bb + e2 + ww = pp + w.time_decay + p = torch.maximum(ww, k) + e1 = torch.exp(ww - p) + e2 = torch.exp(k - p) + self.aa[name] = e1 * aa + e2 * v + self.bb[name] = e1 * bb + e2 + self.pp[name] = p + + rwkv = r * a / b + + return w.output.weight @ rwkv + + def run(self, ctx): + w = self.w + x = w.emb.weight[ctx[-1]] + + for i in range(self.n_layer): + if i == 0: + x = self.LN(x, w.blocks[i].ln0) + if i == 0 and self.model_type == 'RWKV-ffnPre': + x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}') + else: + x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}') + x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}') + + x = self.LN(x, w.ln_out) + + if RWKV_HEAD_QK_DIM > 0: + if self.hk == None: + self.hk = (w.head_k.weight @ x).unsqueeze(0) + else: + self.hk = torch.cat( + [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0) + if self.hk.shape[0] > self.ctx_len: + self.hk = self.hk[-self.ctx_len:, :] + + q = w.head_q.weight @ x + + x = w.head.weight @ x + x = x.cpu().numpy().tolist() + + c = (self.hk @ q) / RWKV_HEAD_QK_DIM + for i in range(len(c)): + x[ctx[i]] += c[i] + else: + x = w.head.weight @ x + x = x.cpu().numpy().tolist() + + return x diff --git a/models/RWKV4/src/utils.py b/models/RWKV4/src/utils.py new file mode 100644 index 00000000..3393fd85 --- /dev/null +++ b/models/RWKV4/src/utils.py @@ -0,0 +1,95 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import os +try: + NUM_GPUS = int(os.environ['RWKV_NUM_GPUS']) +except: + NUM_GPUS = 1 + +import json +import random +import numpy as np +import torch +from torch.nn import functional as F + +class TOKENIZER(): + def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): + if 'list' in str(type(WORD_NAME)): + self.charMode = False + if WORD_NAME[0] == WORD_NAME[1]: + from transformers import PreTrainedTokenizerFast + self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0]) + else: + from transformers import GPT2TokenizerFast + self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1]) + self.vocab_size = len(self.tokenizer) + else: + self.charMode = True + with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file: + self.word_table = json.load(result_file) + + self.vocab_size = len(self.word_table) + + self.stoi = {v: int(k) for k, v in self.word_table.items()} + self.itos = {int(k): v for k, v in self.word_table.items()} + + self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] + + def refine_context(self, context): + context = context.strip().split('\n') + for c in range(len(context)): + context[c] = context[c].strip().strip('\u3000').strip('\r') + context = list(filter(lambda c: c != '', context)) + context = '\n' + ('\n'.join(context)).strip() + if context == '': + context = '\n' + return context + + def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None): + # out[self.UNKNOWN_CHAR] = -float('Inf') + + lastChar = int(x[-1]) + + probs = F.softmax(torch.tensor(out), dim=-1) + + if self.charMode: + if self.itos[lastChar] == '\n': + top_p = top_p_newline + else: + top_p = top_p_usual + else: + top_p = top_p_usual + + sorted_probs, s_index = torch.sort(probs, descending=True) + + # for j in range(30): + # pp = sorted_probs[j].item() + # if pp < 0.005: + # break + # ss = self.itos[int(s_index[j])].replace('\n','_') + # print(f'{math.floor(pp*100):>3.0f}{ss}', end='') + # print('') + + cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy() + cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) + + probs[probs < cutoff] = 0 + # print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "") + + if temperature != 1.0: + probs = probs.pow(1.0 / temperature) + + return torch.multinomial(probs, num_samples=1)[0] + + +def to_float(x): + return x.cpu().detach().numpy().flatten()[0].astype(float) + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed)