diff --git a/RWKV4/LICENSE.txt b/RWKV4/LICENSE.txt deleted file mode 100644 index 72cd2da2..00000000 --- a/RWKV4/LICENSE.txt +++ /dev/null @@ -1,204 +0,0 @@ -Code in this directory is taken from https://github.com/BlinkDL/RWKV-LM. -The license for this code is as follows: - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. 
Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/RWKV4/cuda/wkv_cuda.cu b/RWKV4/cuda/wkv_cuda.cu
deleted file mode 100644
index 6acd0f36..00000000
--- a/RWKV4/cuda/wkv_cuda.cu
+++ /dev/null
@@ -1,125 +0,0 @@
-#include <stdio.h>
-#include <assert.h>
-
-#define MIN_VALUE (-1e38)
-
-template <typename F>
-__global__ void kernel_forward(const int B, const int T, const int C,
-                               const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
-                               F *__restrict__ const _y) {
-    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    const int _b = idx / C;
-    const int _c = idx % C;
-    const int _offset = _b * T * C + _c;
-
-    F u = _u[_c];
-    F w = _w[_c];
-    const F *__restrict__ const k = _k + _offset;
-    const F *__restrict__ const v = _v + _offset;
-    F *__restrict__ const y = _y + _offset;
-
-    F p = 0, q = 0, o = MIN_VALUE;
-    // p and q are running sums divided by exp(o) (to avoid overflows)
-    for (int i = 0; i < T; i++) {
-        const int ii = i * C;
-
-        F no = max(o, u + k[ii]);
-        F A = exp(o - no);
-        F B = exp(u + k[ii] - no);
-        y[ii] = (A * p + B * v[ii]) / (A * q + B);
-
-        no = max(w + o, k[ii]);
-        A = exp(w + o - no);
-        B = exp(k[ii] - no);
-        p = A * p + B * v[ii];
-        q = A * q + B;
-        o = no;
-    }
-}
-
-template <typename F>
-__global__ void kernel_backward(const int B, const int T, const int C,
-                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
-                                F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
-    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    const int _b = idx / C;
-    const int _c = idx % C;
-    const int _offset = _b * T * C + _c;
-
-    F u = _u[_c];
-    F w = _w[_c];
-    const F *__restrict__ const k = _k + _offset;
-    const F *__restrict__ const v = _v + _offset;
-    const F *__restrict__ const gy = _gy + _offset;
-
-    F *__restrict__ const gk = _gk + _offset;
-    F *__restrict__ const gv = _gv + _offset;
-
-    F y[Tmax], z[Tmax], zexp[Tmax];
-
-    F gw = 0, gu = 0;
-    F p = 0, q = 0;
-    F dpdw = 0, dqdw = 0;
-    F o = MIN_VALUE;
-    for (int i = 0; i < T; i++) {
-        const int ii = i * C;
-        F no = max(o, k[ii] + u);
-        F A = exp(o - no);
-        F B = exp(k[ii] + u - no);
-
-        F num = A * p + B * v[ii];
-        F iden = 1 / (A * q + B);
-
-        y[i] = num * iden;
-        z[i] = iden;
-        zexp[i] = k[ii] + u - no;
-
-        gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
-        gu += gy[ii] * (v[ii] - y[i]) * B * iden;
-
-        no = max(w + o, k[ii]);
-        A = exp(w + o - no);
-        B = exp(k[ii] - no);
-        dpdw = A * (p + dpdw);
-        dqdw = A * (q + dqdw);
-        p = A * p + B * v[ii];
-        q = A * q + B;
-        o = no;
-    }
-
-    F gp = 0, gq = 0;
-    o = MIN_VALUE;
-    for (int i = T - 1; i >= 0; i--) {
-        const int ii = i * C;
-        F A = gy[ii] * z[i] * exp(zexp[i]);
-        F B = exp(k[ii] + o);
-        gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
-        gv[ii] = A + B * gp;
-
-        F no = max(w + o, zexp[i] - k[ii] - u);
-        A = exp(w + o - no);
-        B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
-        gp = A * gp + B;
-        gq = A * gq - B * y[i];
-        o = no;
-    }
-
-    // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
-    const int _offsetBC = _b * C + _c;
-    _gw[_offsetBC] += gw * _w[_c];
-    _gu[_offsetBC] += gu;
-}
-
-void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
-    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
-    assert(B * C % threadsPerBlock.x == 0);
-    dim3 numBlocks(B * C / threadsPerBlock.x);
-    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
-}
-
-void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
-    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
-    assert(B * C % threadsPerBlock.x == 0);
-    dim3 numBlocks(B * C / threadsPerBlock.x);
-    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
-}
diff --git a/RWKV4/cuda/wkv_op.cpp b/RWKV4/cuda/wkv_op.cpp
deleted file mode 100644
index efe56d8d..00000000
--- a/RWKV4/cuda/wkv_op.cpp
+++ /dev/null
@@ -1,21 +0,0 @@
-#include <torch/extension.h>
-
-void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
-void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
-
-void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
-    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
-}
-void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
-    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("forward", &forward, "wkv forward");
-    m.def("backward", &backward, "wkv backward");
-}
-
-TORCH_LIBRARY(wkv, m) {
-    m.def("forward", forward);
-    m.def("backward", backward);
-}
diff --git a/RWKV4/src/model.py b/RWKV4/src/model.py
deleted file mode 100644
index 63f7315c..00000000
--- a/RWKV4/src/model.py
+++ /dev/null
@@ -1,416 +0,0 @@
-# File from RWKV-v4 Repo - Small changes made for compatibility
-
-########################################################################################################
-# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
-########################################################################################################
-
-import math, os
-import numpy as np
-import logging
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-try:
-    from deepspeed.ops.adam import FusedAdam
-except:
-    pass # some poor windows users cant install deepspeed
-
-logger = logging.getLogger(__name__)
-
-RWKV_HEAD_QK_DIM = 0
-print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')
-
-class L2Wrap(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, loss, y):
-        ctx.save_for_backward(y)
-        return loss
-    @staticmethod
-    def backward(ctx, grad_output):
-        y = ctx.saved_tensors[0]
-        # to encourage the logits to be close to 0
-        factor = 1e-4 / (y.shape[0] * y.shape[1])
-        maxx, ids = torch.max(y, -1, keepdim=True)
-        gy = torch.zeros_like(y)
-        gy.scatter_(-1, ids, maxx * factor)
-        return (grad_output, gy)
-
-########################################################################################################
-# CUDA Kernel
-######################################################################################################## - -T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!] -# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice - -from torch.utils.cpp_extension import load -wkv_cuda = load(name="wkv", sources=["RWKV4/cuda/wkv_op.cpp", "RWKV4/cuda/wkv_cuda.cu"], - verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}']) - -class WKV(torch.autograd.Function): - @staticmethod - def forward(ctx, B, T, C, w, u, k, v): - ctx.B = B - ctx.T = T - ctx.C = C - assert T <= T_MAX - assert B * C % min(C, 1024) == 0 - if '32' in os.environ['RWKV_FLOAT_MODE']: - w = -torch.exp(w.contiguous()) - u = u.contiguous() - k = k.contiguous() - v = v.contiguous() - else: - w = -torch.exp(w.float().contiguous()) - u = u.float().contiguous() - k = k.float().contiguous() - v = v.float().contiguous() - ctx.save_for_backward(w, u, k, v) - y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format) - wkv_cuda.forward(B, T, C, w, u, k, v, y) - if '32' in os.environ['RWKV_FLOAT_MODE']: - return y - elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': - return y.half() - elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': - return y.bfloat16() - - @staticmethod - def backward(ctx, gy): - B = ctx.B - T = ctx.T - C = ctx.C - assert T <= T_MAX - assert B * C % min(C, 1024) == 0 - w, u, k, v = ctx.saved_tensors - gw = torch.zeros((B, C), device='cuda').contiguous() - gu = torch.zeros((B, C), device='cuda').contiguous() - gk = torch.zeros((B, T, C), device='cuda').contiguous() - gv = torch.zeros((B, T, C), device='cuda').contiguous() - if '32' in os.environ['RWKV_FLOAT_MODE']: - wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv) - else: - wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv) - gw = torch.sum(gw, dim=0) - gu = torch.sum(gu, dim=0) - if '32' in os.environ['RWKV_FLOAT_MODE']: - return (None, None, None, gw, gu, gk, gv) - elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': - return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) - elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': - return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16()) - -def RUN_CUDA(B, T, C, w, u, k, v): - return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda()) - -######################################################################################################## -# RWKV: RWKV Time-mix + RWKV Channel-mix -######################################################################################################## - -def RWKV_Init(model, args): # fancy initialization of all lin & emb layer in the model - print("\n[--> first run, init model params (very slow for large models) <--]") - print("[so you shall only do it for 1 single GPU and save the checkpt and load it when using multiple GPU]\n") - - for mm in model.modules(): - if "RecursiveScriptModule" in str(type(mm)): - if mm.original_name not in ["Linear"]: - continue - ww = None - for name, param in mm.named_parameters(): - if name == "weight": - ww = param - else: - m = mm - if not isinstance(m, (nn.Linear, nn.Embedding)): - continue - ww = m.weight - with torch.no_grad(): - name = "[unknown weight]" - for name, parameter in model.named_parameters(): # find the name of the weight - if id(ww) == id(parameter): - break - - shape = ww.shape - gain = 1.0 - 
scale = 1.0 # extra scale for gain - - if isinstance(m, nn.Embedding): - gain = math.sqrt(max(shape[0], shape[1])) - if shape[0] == args.vocab_size and shape[1] == args.n_embd: # token emb? - scale = 1e-4 - else: - scale = 0 - - if isinstance(m, nn.Linear): - if shape[0] > shape[1]: - gain = math.sqrt(shape[0] / shape[1]) - if shape[0] == args.vocab_size and shape[1] == args.n_embd: # final projection? - scale = 0.5 - - if hasattr(m, "scale_init"): - scale = m.scale_init - - # print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {name}") - - gain *= scale - if scale == -999: - nn.init.eye_(ww) - elif gain == 0: - # zero init is great for some RWKV matrices - nn.init.zeros_(ww) - elif gain > 0: - nn.init.orthogonal_(ww, gain=gain) - else: - nn.init.normal_(ww, mean=0.0, std=-scale) - - -class RWKV_TimeMix(torch.jit.ScriptModule): - def __init__(self, config, layer_id): - super().__init__() - self.layer_id = layer_id - self.ctx_len = config.ctx_len - self.n_embd = config.n_embd - - attn_sz = config.n_embd - - with torch.no_grad(): # fancy init - ratio_0_to_1 = (layer_id / (config.n_layer - 1)) # 0 to 1 - ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0 - - # fancy time_decay - decay_speed = torch.ones(attn_sz) - for h in range(attn_sz): - decay_speed[h] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1) - self.time_decay = nn.Parameter(decay_speed) - # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy()) - - # fancy time_first - zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5) - self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag) - - # fancy time_mix - x = torch.ones(1, 1, config.n_embd) - for i in range(config.n_embd): - x[0, 0, i] = i / config.n_embd - self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) - self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) - self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0)) - - self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) - - self.key = nn.Linear(config.n_embd, attn_sz, bias=False) - self.value = nn.Linear(config.n_embd, attn_sz, bias=False) - self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False) - - self.output = nn.Linear(attn_sz, config.n_embd, bias=False) - - self.key.scale_init = 0 - self.receptance.scale_init = 0 - self.output.scale_init = 0 - - @torch.jit.script_method - def jit_func(self, x): - - # Mix x with the previous timestep to produce xk, xv, xr - xx = self.time_shift(x) - xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) - xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) - xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) - - # Use xk, xv, xr to produce k, v, r - k = self.key(xk) - v = self.value(xv) - r = self.receptance(xr) - sr = torch.sigmoid(r) - - return sr, k, v - - def forward(self, x): - B, T, C = x.size() # x = (Batch,Time,Channel) - - sr, k, v = self.jit_func(x) - - rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v) - rwkv = self.output(rwkv) - return rwkv - - -class RWKV_ChannelMix(torch.jit.ScriptModule): - def __init__(self, config, layer_id): - super().__init__() - self.layer_id = layer_id - - self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) - - with torch.no_grad(): # fancy init of time_mix - ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0 - - x = torch.ones(1, 1, config.n_embd) - for i in range(config.n_embd): - x[0, 0, i] = i / config.n_embd 
- - self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) - self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) - - hidden_sz = 4 * config.n_embd - self.key = nn.Linear(config.n_embd, hidden_sz, bias=False) - self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False) - self.value = nn.Linear(hidden_sz, config.n_embd, bias=False) - - self.value.scale_init = 0 - self.receptance.scale_init = 0 - - @torch.jit.script_method - def forward(self, x): - xx = self.time_shift(x) - xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) - xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) - - k = self.key(xk) - k = torch.square(torch.relu(k)) - kv = self.value(k) - - rkv = torch.sigmoid(self.receptance(xr)) * kv - return rkv - -######################################################################################################## -# The GPT Model with our blocks -######################################################################################################## - - -class GPTConfig: - def __init__(self, vocab_size, ctx_len, **kwargs): - self.vocab_size = vocab_size - self.ctx_len = ctx_len - for k, v in kwargs.items(): - setattr(self, k, v) - - -class Block(nn.Module): - def __init__(self, config, layer_id): - super().__init__() - self.config = config - self.layer_id = layer_id - - self.ln1 = nn.LayerNorm(config.n_embd) - self.ln2 = nn.LayerNorm(config.n_embd) - - if self.layer_id == 0: - self.ln0 = nn.LayerNorm(config.n_embd) - - if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre': - self.ffnPre = RWKV_ChannelMix(config, 0) - else: - self.att = RWKV_TimeMix(config, layer_id) - - self.ffn = RWKV_ChannelMix(config, layer_id) - - def forward(self, x): - if self.layer_id == 0: - x = self.ln0(x) - if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre': - x = x + self.ffnPre(self.ln1(x)) # better in some cases - else: - x = x + self.att(self.ln1(x)) - x = x + self.ffn(self.ln2(x)) - return x - - -class GPT(nn.Module): - def __init__(self, config): - super().__init__() - self.step = 0 - self.config = config - - self.emb = nn.Embedding(config.vocab_size, config.n_embd) - - self.blocks = nn.Sequential(*[Block(config, i) - for i in range(config.n_layer)]) - - self.ln_out = nn.LayerNorm(config.n_embd) - self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - - if RWKV_HEAD_QK_DIM > 0: - self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False) - self.head_q.scale_init = 0 - self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False) - self.head_k.scale_init = 0.1 - self.register_buffer("copy_mask", torch.tril( - torch.ones(config.ctx_len, config.ctx_len))) - - self.ctx_len = config.ctx_len - - try: - if os.environ['RWKV_LOAD_MODEL'] == str(False): - RWKV_Init(self, config) - except: - pass - - logger.info("number of parameters: %e", sum(p.numel() - for p in self.parameters())) - - def get_ctx_len(self): - return self.ctx_len - - def _init_weights(self, module): - if isinstance(module, (nn.Linear)): - module.weight.data.normal_(mean=0.0, std=0.01) - if isinstance(module, (nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=1e-5) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - def configure_optimizers(self, train_config): - no_decay = set() - - for mn, m in self.named_modules(): # here we disable weight_decay - for pn, p in m.named_parameters(): - fpn = '%s.%s' % (mn, pn) if mn else pn # full param name - no_decay.add(fpn) - - param_dict = {pn: p for pn, p in 
self.named_parameters()} - optim_groups = [ - {"params": [param_dict[pn] - for pn in sorted(list(no_decay))], "weight_decay": 0.0}, - ] - - try: - optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False) - except: - print('\n\nDeepSpeed not found. Using torch optimizer instead (probably slower)\n\n') - optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps) - - return optimizer - - def forward(self, idx, targets=None): - idx = idx.to(self.emb.weight.device) - - self.step += 1 - B, T = idx.size() - assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len." - - x = self.emb(idx) - x = self.blocks(x) - x = self.ln_out(x) - - if RWKV_HEAD_QK_DIM > 0: - q = self.head_q(x)[:, :T, :] - k = self.head_k(x)[:, :T, :] - c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM) - c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) - - if '32' in os.environ['RWKV_FLOAT_MODE']: - c = c @ F.one_hot(idx, num_classes=self.config.vocab_size) - elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': - c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half() - elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': - c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).bfloat16() - - x = self.head(x) + c - else: - x = self.head(x) - - loss = None - if targets is not None: - loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1)) - - return L2Wrap.apply(loss, x) diff --git a/RWKV4/src/model_run.py b/RWKV4/src/model_run.py deleted file mode 100644 index 74c719d3..00000000 --- a/RWKV4/src/model_run.py +++ /dev/null @@ -1,392 +0,0 @@ -######################################################################################################## -# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM -######################################################################################################## - -import types -import copy -import torch -import math, os -from torch.nn import functional as F -import torch.nn as nn - -RWKV_HEAD_QK_DIM = 0 -print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n') - -DEBUG_TIME = False # True False - show trained time-coeffs - -######################################################################################################## -# CUDA Kernel -######################################################################################################## - -if os.environ['RWKV_RUN_DEVICE'] == 'cuda': - T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!] 
- # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice - - from torch.utils.cpp_extension import load - wkv_cuda = load(name="wkv", sources=["RWKV4/cuda/wkv_op.cpp", "RWKV4/cuda/wkv_cuda.cu"], - verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}']) - - class WKV(torch.autograd.Function): - @staticmethod - def forward(ctx, B, T, C, w, u, k, v): - ctx.B = B - ctx.T = T - ctx.C = C - assert T <= T_MAX - assert B * C % min(C, 1024) == 0 - if '32' in os.environ['RWKV_FLOAT_MODE']: - w = -torch.exp(w.contiguous()) - u = u.contiguous() - k = k.contiguous() - v = v.contiguous() - else: - w = -torch.exp(w.float().contiguous()) - u = u.float().contiguous() - k = k.float().contiguous() - v = v.float().contiguous() - ctx.save_for_backward(w, u, k, v) - y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format) - wkv_cuda.forward(B, T, C, w, u, k, v, y) - if '32' in os.environ['RWKV_FLOAT_MODE']: - return y - elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': - return y.half() - elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': - return y.bfloat16() - - @staticmethod - def backward(ctx, gy): - B = ctx.B - T = ctx.T - C = ctx.C - assert T <= T_MAX - assert B * C % min(C, 1024) == 0 - w, u, k, v = ctx.saved_tensors - gw = torch.zeros((B, C), device='cuda').contiguous() - gu = torch.zeros((B, C), device='cuda').contiguous() - gk = torch.zeros((B, T, C), device='cuda').contiguous() - gv = torch.zeros((B, T, C), device='cuda').contiguous() - if '32' in os.environ['RWKV_FLOAT_MODE']: - wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv) - else: - wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv) - gw = torch.sum(gw, dim=0) - gu = torch.sum(gu, dim=0) - if '32' in os.environ['RWKV_FLOAT_MODE']: - return (None, None, None, gw, gu, gk, gv) - elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': - return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half()) - elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': - return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16()) - - def RUN_CUDA(B, T, C, w, u, k, v): - return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda()) - -############################################################################################################ - -RWKV_CFG = types.SimpleNamespace() - -class RWKV_ChannelMix(nn.Module): - def __init__(self, layer_id): - super().__init__() - self.layer_id = layer_id - - self.time_shift = nn.ZeroPad2d((0,0,1,-1)) - self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd)) - self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd)) - - hidden_sz = 4 * RWKV_CFG.n_embd - self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False) - self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) - self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False) - - def forward(self, x): - xx = self.time_shift(x) - xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) - xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) - - k = self.key(xk) - k = torch.square(torch.relu(k)) - kv = self.value(k) - - rkv = torch.sigmoid(self.receptance(xr)) * kv - return rkv - -class RWKV_TimeMix(nn.Module): - def __init__(self, layer_id): - super().__init__() - self.layer_id = layer_id - self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd)) - self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd) * math.log(0.3)) - - 
self.time_shift = nn.ZeroPad2d((0,0,1,-1)) - self.time_mix_k = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd)) - self.time_mix_v = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd)) - self.time_mix_r = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd)) - - self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) - self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) - self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) - - self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False) - - def forward(self, x): - B, T, C = x.size() - - xx = self.time_shift(x) - xk = x * self.time_mix_k + xx * (1 - self.time_mix_k) - xv = x * self.time_mix_v + xx * (1 - self.time_mix_v) - xr = x * self.time_mix_r + xx * (1 - self.time_mix_r) - - k = self.key(xk) - v = self.value(xv) - r = self.receptance(xr) - - rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v) - - rwkv = self.output(rwkv) - return rwkv - -class Block(nn.Module): - def __init__(self, layer_id): - super().__init__() - self.layer_id = layer_id - - self.ln1 = nn.LayerNorm(RWKV_CFG.n_embd) - self.ln2 = nn.LayerNorm(RWKV_CFG.n_embd) - if self.layer_id == 0: - self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd) - - if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre': - self.ffnPre = RWKV_ChannelMix(layer_id+1000) - else: - self.att = RWKV_TimeMix(layer_id) - - self.ffn = RWKV_ChannelMix(layer_id) - - def forward(self, x): - if self.layer_id == 0: - x = self.ln0(x) - if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre': - x = x + self.ffnPre(self.ln1(x)) - else: - x = x + self.att(self.ln1(x)) - x = x + self.ffn(self.ln2(x)) - return x - -class RWKV_GPT(nn.Module): - def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len): - global RWKV_CFG - super().__init__() - - RWKV_CFG.RUN_DEVICE = RUN_DEVICE - RWKV_CFG.model_type = model_type - RWKV_CFG.vocab_size = vocab_size - RWKV_CFG.n_layer = n_layer - RWKV_CFG.n_embd = n_embd - RWKV_CFG.ctx_len = ctx_len - - print('\nloading RWKV-GPT', MODEL_NAME) - - self.emb = nn.Embedding(vocab_size, n_embd) - - self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)]) - - self.ln_out = nn.LayerNorm(n_embd) - self.head = nn.Linear(n_embd, vocab_size, bias=False) - - if RWKV_HEAD_QK_DIM > 0: - self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False) - self.head_q.scale_init = 0 - self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False) - self.head_k.scale_init = 0.1 - self.register_buffer("copy_mask", torch.tril( - torch.ones(ctx_len, ctx_len))) - - self.ctx_len = ctx_len - self.eval() - self.load_state_dict(torch.load(MODEL_NAME + '.pth')) - self.eval() - - def forward(self, idx): - B, T = idx.size() - assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len." 
- - x = self.emb(idx) - x = self.blocks(x) - x = self.ln_out(x) - - if RWKV_HEAD_QK_DIM > 0: - q = self.head_q(x)[:, :T, :] - k = self.head_k(x)[:, :T, :] - c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM) - c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0) - - if '32' in os.environ['RWKV_FLOAT_MODE']: - c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size) - elif os.environ['RWKV_FLOAT_MODE'] == 'fp16': - c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).half() - elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': - c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).bfloat16() - - x = self.head(x) + c - else: - x = self.head(x) - - return x - -############################################################################################################ - -class RWKV_RNN(): # this is running in FP32 at this moment - def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len): - self.RUN_DEVICE = RUN_DEVICE - self.model_type = model_type - self.n_layer = n_layer - self.n_embd = n_embd - self.ctx_len = ctx_len - - self.w = types.SimpleNamespace() - - w = torch.load(MODEL_NAME + '.pth', - map_location=torch.device(RUN_DEVICE)) - for x in w.keys(): - w[x] = w[x].float() - if '.time_' in x: - w[x] = w[x].squeeze() - if '.time_decay' in x: - w[x] = -torch.exp(w[x]) - if DEBUG_TIME and '.time_' in x: - print(x, w[x].squeeze().cpu().numpy()) - - xx = x.split('.') - here = self.w - for i in range(len(xx)): - if xx[i].isdigit(): - ii = int(xx[i]) - if ii not in here: - here[ii] = types.SimpleNamespace() - here = here[ii] - else: - if i == len(xx) - 1: - setattr(here, xx[i], w[x]) - elif not hasattr(here, xx[i]): - if xx[i+1].isdigit(): - setattr(here, xx[i], {}) - else: - setattr(here, xx[i], types.SimpleNamespace()) - here = getattr(here, xx[i]) - - self.clear() - - def clear(self): - self.xx = {} - self.aa = {} - self.bb = {} - self.pp = {} - self.hk = None - - def save(self, target): - target.xx = copy.deepcopy(self.xx) - target.aa = copy.deepcopy(self.aa) - target.bb = copy.deepcopy(self.bb) - target.pp = copy.deepcopy(self.pp) - target.hk = copy.deepcopy(self.hk) - - def load(self, target): - self.xx = copy.deepcopy(target.xx) - self.aa = copy.deepcopy(target.aa) - self.bb = copy.deepcopy(target.bb) - self.pp = copy.deepcopy(target.pp) - self.hk = copy.deepcopy(target.hk) - - def LN(self, xx, w): - return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias) - - def FF(self, xx, w, name): - if name not in self.xx: - self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k) - xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r) - self.xx[name] = xx - - r = torch.sigmoid(w.receptance.weight @ xr) - k = torch.square(torch.relu(w.key.weight @ xk)) - kv = w.value.weight @ k - - return r * kv - - def SA(self, xx, w, name): - if name not in self.xx: - self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30 - - xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k) - xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v) - xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r) - self.xx[name] = xx - - r = torch.sigmoid(w.receptance.weight @ xr) - - k = w.key.weight @ xk - v = w.value.weight @ xv - - pp = self.pp[name] - aa = self.aa[name] - bb = 
self.bb[name] - ww = w.time_first + k - p = torch.maximum(pp, ww) - e1 = torch.exp(pp - p) - e2 = torch.exp(ww - p) - a = e1 * aa + e2 * v - b = e1 * bb + e2 - ww = pp + w.time_decay - p = torch.maximum(ww, k) - e1 = torch.exp(ww - p) - e2 = torch.exp(k - p) - self.aa[name] = e1 * aa + e2 * v - self.bb[name] = e1 * bb + e2 - self.pp[name] = p - - rwkv = r * a / b - - return w.output.weight @ rwkv - - def run(self, ctx): - w = self.w - x = w.emb.weight[ctx[-1]] - - for i in range(self.n_layer): - if i == 0: - x = self.LN(x, w.blocks[i].ln0) - if i == 0 and self.model_type == 'RWKV-ffnPre': - x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}') - else: - x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}') - x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}') - - x = self.LN(x, w.ln_out) - - if RWKV_HEAD_QK_DIM > 0: - if self.hk == None: - self.hk = (w.head_k.weight @ x).unsqueeze(0) - else: - self.hk = torch.cat( - [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0) - if self.hk.shape[0] > self.ctx_len: - self.hk = self.hk[-self.ctx_len:, :] - - q = w.head_q.weight @ x - - x = w.head.weight @ x - x = x.cpu().numpy().tolist() - - c = (self.hk @ q) / RWKV_HEAD_QK_DIM - for i in range(len(c)): - x[ctx[i]] += c[i] - else: - x = w.head.weight @ x - x = x.cpu().numpy().tolist() - - return x diff --git a/RWKV4/src/utils.py b/RWKV4/src/utils.py deleted file mode 100644 index 3393fd85..00000000 --- a/RWKV4/src/utils.py +++ /dev/null @@ -1,95 +0,0 @@ -######################################################################################################## -# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM -######################################################################################################## - -import os -try: - NUM_GPUS = int(os.environ['RWKV_NUM_GPUS']) -except: - NUM_GPUS = 1 - -import json -import random -import numpy as np -import torch -from torch.nn import functional as F - -class TOKENIZER(): - def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): - if 'list' in str(type(WORD_NAME)): - self.charMode = False - if WORD_NAME[0] == WORD_NAME[1]: - from transformers import PreTrainedTokenizerFast - self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0]) - else: - from transformers import GPT2TokenizerFast - self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1]) - self.vocab_size = len(self.tokenizer) - else: - self.charMode = True - with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file: - self.word_table = json.load(result_file) - - self.vocab_size = len(self.word_table) - - self.stoi = {v: int(k) for k, v in self.word_table.items()} - self.itos = {int(k): v for k, v in self.word_table.items()} - - self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] - - def refine_context(self, context): - context = context.strip().split('\n') - for c in range(len(context)): - context[c] = context[c].strip().strip('\u3000').strip('\r') - context = list(filter(lambda c: c != '', context)) - context = '\n' + ('\n'.join(context)).strip() - if context == '': - context = '\n' - return context - - def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None): - # out[self.UNKNOWN_CHAR] = -float('Inf') - - lastChar = int(x[-1]) - - probs = F.softmax(torch.tensor(out), dim=-1) - - if self.charMode: - if self.itos[lastChar] == '\n': - top_p = top_p_newline - else: - top_p = top_p_usual - else: - top_p = top_p_usual - - sorted_probs, s_index = torch.sort(probs, 
descending=True) - - # for j in range(30): - # pp = sorted_probs[j].item() - # if pp < 0.005: - # break - # ss = self.itos[int(s_index[j])].replace('\n','_') - # print(f'{math.floor(pp*100):>3.0f}{ss}', end='') - # print('') - - cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy() - cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) - - probs[probs < cutoff] = 0 - # print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "") - - if temperature != 1.0: - probs = probs.pow(1.0 / temperature) - - return torch.multinomial(probs, num_samples=1)[0] - - -def to_float(x): - return x.cpu().detach().numpy().flatten()[0].astype(float) - - -def set_seed(seed): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed)
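
For reference: the removed kernel_forward in wkv_cuda.cu and RWKV_RNN.SA in model_run.py compute the same per-channel WKV recurrence, an exponentially decaying weighted average of past values v with per-channel decay w (stored as -exp(w)) and a one-time bonus u for the current token, kept numerically stable by tracking the largest exponent seen so far. The sketch below is a minimal NumPy rendering of that recurrence; it is not code from the removed files, and wkv_reference is an illustrative name.

# Minimal NumPy sketch of the per-channel WKV recurrence from the removed kernel_forward.
# p and q hold the running numerator and denominator scaled by exp(-o), where o tracks the
# largest exponent seen so far, so none of the exp() calls overflow. Illustrative only.
import numpy as np

def wkv_reference(w, u, k, v):
    # w, u: (C,) per-channel decay (already -exp(...)-transformed) and current-token bonus
    # k, v: (T, C) keys and values for one sequence
    T, C = k.shape
    y = np.empty((T, C))
    p = np.zeros(C)           # running sum of exp(weight) * v, scaled by exp(-o)
    q = np.zeros(C)           # running sum of exp(weight), scaled by exp(-o)
    o = np.full(C, -1e38)     # running max exponent (MIN_VALUE in the kernel)
    for t in range(T):
        # output for step t: include the current token with bonus u
        no = np.maximum(o, u + k[t])
        A = np.exp(o - no)
        B = np.exp(u + k[t] - no)
        y[t] = (A * p + B * v[t]) / (A * q + B)
        # fold step t into the state, decaying everything already stored by w
        no = np.maximum(w + o, k[t])
        A = np.exp(w + o - no)
        B = np.exp(k[t] - no)
        p = A * p + B * v[t]
        q = A * q + B
        o = no
    return y

RWKV_RNN.SA carries exactly this (p, q, o) state between tokens as self.aa, self.bb and self.pp, which is what lets the RNN mode process one token at a time instead of re-running the full context.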