From 381e1c9275972074b1aa3486b631c75c210f05ce Mon Sep 17 00:00:00 2001 From: somebody Date: Sat, 24 Sep 2022 20:51:03 -0500 Subject: [PATCH] Add RWKV-4 code --- models/RWKV4/LICENSE.txt | 204 +++++++++ models/RWKV4/cuda/wkv_cuda.cu | 125 ++++++ models/RWKV4/cuda/wkv_op.cpp | 21 + .../src/__pycache__/model_run.cpython-37.pyc | Bin 0 -> 11284 bytes .../src/__pycache__/model_run.cpython-38.pyc | Bin 0 -> 11173 bytes .../src/__pycache__/model_run.cpython-39.pyc | Bin 0 -> 11139 bytes .../src/__pycache__/utils.cpython-38.pyc | Bin 0 -> 5131 bytes .../src/__pycache__/utils.cpython-39.pyc | Bin 0 -> 5145 bytes models/RWKV4/src/model.py | 416 ++++++++++++++++++ models/RWKV4/src/model_run.py | 392 +++++++++++++++++ models/RWKV4/src/utils.py | 95 ++++ 11 files changed, 1253 insertions(+) create mode 100644 models/RWKV4/LICENSE.txt create mode 100644 models/RWKV4/cuda/wkv_cuda.cu create mode 100644 models/RWKV4/cuda/wkv_op.cpp create mode 100644 models/RWKV4/src/__pycache__/model_run.cpython-37.pyc create mode 100644 models/RWKV4/src/__pycache__/model_run.cpython-38.pyc create mode 100644 models/RWKV4/src/__pycache__/model_run.cpython-39.pyc create mode 100644 models/RWKV4/src/__pycache__/utils.cpython-38.pyc create mode 100644 models/RWKV4/src/__pycache__/utils.cpython-39.pyc create mode 100644 models/RWKV4/src/model.py create mode 100644 models/RWKV4/src/model_run.py create mode 100644 models/RWKV4/src/utils.py diff --git a/models/RWKV4/LICENSE.txt b/models/RWKV4/LICENSE.txt new file mode 100644 index 00000000..72cd2da2 --- /dev/null +++ b/models/RWKV4/LICENSE.txt @@ -0,0 +1,204 @@ +Code in this directory is taken from https://github.com/BlinkDL/RWKV-LM. +The license for this code is as follows: + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/models/RWKV4/cuda/wkv_cuda.cu b/models/RWKV4/cuda/wkv_cuda.cu new file mode 100644 index 00000000..6acd0f36 --- /dev/null +++ b/models/RWKV4/cuda/wkv_cuda.cu @@ -0,0 +1,125 @@ +#include +#include + +#define MIN_VALUE (-1e38) + +template +__global__ void kernel_forward(const int B, const int T, const int C, + const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, + F *__restrict__ const _y) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + F *__restrict__ const y = _y + _offset; + + F p = 0, q = 0, o = MIN_VALUE; + // p and q are running sums divided by exp(o) (to avoid overflows) + for (int i = 0; i < T; i++) { + const int ii = i * C; + + F no = max(o, u + k[ii]); + F A = exp(o - no); + F B = exp(u + k[ii] - no); + y[ii] = (A * p + B * v[ii]) / (A * q + B); + + no = max(w + o, k[ii]); + A = exp(w + o - no); + B = exp(k[ii] - no); + p = A * p + B * v[ii]; + q = A * q + B; + o = no; + } +} + +template +__global__ void kernel_backward(const int B, const int T, const int C, + const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy, + F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + const int _b = idx / C; + const int _c = idx % C; + const int _offset = _b * T * C + _c; + + F u = _u[_c]; + F w = _w[_c]; + const F *__restrict__ const k = _k + _offset; + const F *__restrict__ const v = _v + _offset; + const F *__restrict__ const gy = _gy + _offset; + + F *__restrict__ const gk = _gk + _offset; + F *__restrict__ const gv = _gv + _offset; + + 
F y[Tmax], z[Tmax], zexp[Tmax]; + + F gw = 0, gu = 0; + F p = 0, q = 0; + F dpdw = 0, dqdw = 0; + F o = MIN_VALUE; + for (int i = 0; i < T; i++) { + const int ii = i * C; + F no = max(o, k[ii] + u); + F A = exp(o - no); + F B = exp(k[ii] + u - no); + + F num = A * p + B * v[ii]; + F iden = 1 / (A * q + B); + + y[i] = num * iden; + z[i] = iden; + zexp[i] = k[ii] + u - no; + + gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A; + gu += gy[ii] * (v[ii] - y[i]) * B * iden; + + no = max(w + o, k[ii]); + A = exp(w + o - no); + B = exp(k[ii] - no); + dpdw = A * (p + dpdw); + dqdw = A * (q + dqdw); + p = A * p + B * v[ii]; + q = A * q + B; + o = no; + } + + F gp = 0, gq = 0; + o = MIN_VALUE; + for (int i = T - 1; i >= 0; i--) { + const int ii = i * C; + F A = gy[ii] * z[i] * exp(zexp[i]); + F B = exp(k[ii] + o); + gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq); + gv[ii] = A + B * gp; + + F no = max(w + o, zexp[i] - k[ii] - u); + A = exp(w + o - no); + B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no); + gp = A * gp + B; + gq = A * gq - B * y[i]; + o = no; + } + + // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass + const int _offsetBC = _b * C + _c; + _gw[_offsetBC] += gw * _w[_c]; + _gu[_offsetBC] += gu; +} + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_forward<<>>(B, T, C, w, u, k, v, y); +} + +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) { + dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance + assert(B * C % threadsPerBlock.x == 0); + dim3 numBlocks(B * C / threadsPerBlock.x); + kernel_backward<<>>(B, T, C, w, u, k, v, gy, gw, gu, 
gk, gv); +} diff --git a/models/RWKV4/cuda/wkv_op.cpp b/models/RWKV4/cuda/wkv_op.cpp new file mode 100644 index 00000000..efe56d8d --- /dev/null +++ b/models/RWKV4/cuda/wkv_op.cpp @@ -0,0 +1,21 @@ +#include + +void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); +void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv); + +void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { + cuda_forward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr()); +} +void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { + cuda_backward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &forward, "wkv forward"); + m.def("backward", &backward, "wkv backward"); +} + +TORCH_LIBRARY(wkv, m) { + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/models/RWKV4/src/__pycache__/model_run.cpython-37.pyc b/models/RWKV4/src/__pycache__/model_run.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e1e2973fd0f945d34887638336a493063f8a912 GIT binary patch literal 11284 zcmbtaTWllOd7c}G!{J2~m#fR}dLyqL$4s(X-{L0eX1(j(mF;9LJKnXOuyLmpXC#rL zNbZ@TEU}{|sGOuIx~ba&X)XvLA!&0d8lVLVw1r;^^s!G((T8~|&^$DKO8c0Hw%>Qo zkfOBOq(O!_KhOEkWzOZlpLxDmbTvHR`}(Ek!ynVMe`jIzov;N#WTD=-*`=<}I#thob`)APROyIO<(P9pmv;H|? 
z&ZU@nU>@?%19LvboB`&2{sJ%yDdsFN5BnE@xv*!{&-waot@g-IxMRyTJ#kkD-DcG4 zbsL>w#eMC~Prc#2c=Or~@6(_1ZoKsJLZZ7%$&^ID_WH7S#^o(u>q*ZCZG`; zk-4FV=OgQPb<~S(ydpn-SJ+2<;G2sK~2p_Y(Lm)HG`zM8Ep2%j9bvIhG!?=VgYIH~M?8#S&Lf430GeUz z{1x;wdPy(q4t}mtMVY^{K8IgPw`5IOALc>ji|I0&2K(bFB9)LNwSR}yuC=vY(0W_n z1+BLYP^vjW?M7@#YPXJ4JKEZxvvxDP_HGt5mTfzrJZse;v)wcHbD;IyiC8zqnrJOq zGi{xJ0KNcPcH4Q%Z}8Q>Muq(%Y5b6uW1Vk{bFmxe%0I_cRtxYjSM45=nF&Y(BR-Go80Z58u4Bv}_ zZrBr{m?thX3`8#!XW3|$V2)q`0349Ak2uGo84fnZ1=c)HSrKq z2tBGfdf}`frkMrc!+#d_?G#%JkQeoDuFT-G{ zQ)OmSp?-b6(d`DEms{Iofx^MKaQ=s11BkReEc-yyl_G`e+=uWupi<2sZ%5gE2P|ep zxwd;?N__qS_ySZET2&#SBC}m&J~n*I&)lP0E{$-+kFit4LSawh%94rf<(eh6lmtJa zzWT}k;t?Nbjz_vg<7lPT2y4?xCLHtwK@w^~P=jooPGct!UdvC6Za2x*!6~mc{HOfH zjas0waJ{t}CAn7{0u&boLNS!p>%x|Cq&aqd>4n7Zdco$(sLy6=+uKO{GcOYRR;wE{ zpyoD$ox~MEGw4T+E|g|wtI-(*H8ZioptCBjVz%|%ddv5Nt``n5@0{l~JB={(ya`Tv zj?IulYAQ{4b=xSAe+*Zzf|G`GYGUJ4T^y};e7v-hsK2ixgu_m^0^x2p^ zVl2yJs=0zQ5ib(FMDQsBDZK0uY%PMs5Dr6wiZCPsEoWfS(3m zSJWr7KrpzvlY6S5Ky3p2%ADV{O_ehnl- z4mtWa`mT#y7M_c~o0lF_{y>kQ4x}ZhweyfU1=@gnS;&k6+7@98>c|x)q(=$jvjrPx zywTz}v6u-#GTvNuj#Q&ke+o)b@fLvx0DFg`MBFAg>I#TE#Jvsx>9Qr>V2L6|T0V&` zAfDi498=sf(JpQ z`1L3Qy$|*-J;T8hNn8g&St0kBsjSH6u}dTD_uA*+JU`5F`d|IT)Cx**@&XV1m+jB>7R0& zssKt;)4Fv*3RS&qxW=#~_x66?Jp|i%M~3Prm^U_zDQXn?k!X1t{{9Kzr7k0U6BsPBg!N2I*{!86H5>A7F0CYaT`SIcfcC`wyYTd2nWK z7Iwh_g0Z8z)Ttd#y9`CPx@!xPAGu@L8~Ko|c`Bp}$rNko{Z?UDkDz?^^glB0TDO&S zgU;irA;M%VK9T6HLy;$`=x6TbPAiO>z0LlWf8bz>VTSLQ;uNnOv((~cNUugGyn@WA z$4KQ%v9b1zuk-o4=aO8r*WW=(AMTk;@%GK4gy5C8l5B~8r`)Zu~#T< zLlJOzyVLbx4~S2)V66uYKPk%{C+5*<4m`irj09!R$h3AD@4MwKUwawRI^gBXMaUd&5V>42V29M& zl?*RUUfL;tjg&^Z+%(oDBVG`HCF2ZimWtuP-B-zlShSIcESSFUD90*@j5n5!cR*aJ z}GvCDuC&; zG8*<{h=&p)Sb9`$Py5b2jJ2EXzBX#%t0+_Hcq>M4WlO}#)L6bN<5ADv*YcEViS<;u zo;V9FVRdPX$ujl{Y*@=L51+n{&_FL*NConT7gk{IBK*Gq!hXEfrJcHPWkIUSg_PbG zl_UN?^ucLkQyt*>>2pWITKVo^Q@OT5I1%1DEq)7aB{8X>);5d3O%N0O4#DRMz5tNq zaEH(h5q=GlJlPhJuhmwkBkr=rIMkKsOGy?kySLtplUro(z z;*EcWG*WX93u2C_ISR$AY(h_Ybd91eQ4Zrle#W@|5<-f&cau1;M)K`B 
zO>*efXbx@UcxtwLCgR;MM)M==9&T>J-(w5kp4G;0jSl;@G})(mh*9q*Hac@S4&VQ# zcK6nsRNH5%(xuwYE~hF_PYRojzDJ$O`bv!BJI!3mjO^ zVKQ~Qwb}0k%Lt!^{YEno?+0-xH;e*Y#=8>>aj{TEptA_~1Oc46GpJ468#kYS{RQvN zOVB)9M3~L~Ajv>>^mh=;>vvj_7)m@9ZMGHqt+f^mbLd+{i2|szNF>?yMyN7tDJIkG zARISx_voQLa2;yClmgNPqz)E{5%Qfz>$p)OLkItby6`H1Ca-lystz{4jA=U%1XORh z?t!SVfh!^JG8q*qDk=~nh~pWl0tAPRZJDthMSF_OEm@6%tVKprP?vc_LK7+ntR{~b z0qKldgwbdu#>z@!^!tgiz5%r@{oZ4C`aXu8$BRN2nZv~lOHwa$T5=o5N}2O9k>s2{ ziHr)Aw;3_gJtHkKPNaKgTB7pNAqr$jy$yClDU+6H=jxt4ntK-4-7wD*#}9%&bIK?G zHeMvYL-1z=e-4m3ABi0`1djZe%to*PS0oFE3!_KI->^7d7+d^;b?I_l5NUF@dRxW6 z)VlOG*F_O0T{XEb#>rJl+3$zz5_E>d_XzF~Ja|Ro`>a29Ok|@E{M^;hwrlca`0$@y zt|NtHvIxeK4A>Z}xMpsfNQms;qXA$n^Md#7ptZIhsSsI4{64xQu8Nv>=|R5$g!l^r z8ALM`Eya%Y z0S5T*7-RJf5QrOJt1!Lj> zmX^i`P(@Q7h9r_7qxbj#x(_;l6Y%(&;<9!E-aZNMBdUmF$g&OxfU!t|I>!Ynvi2$ z601P2fR?Z!U_Tcz;xzPZ5%zou`3mx-ITF#7EEUsIIi8C1m|r=@8qjYVq$mnRaQGYd>bMk0vhS~>#t}BUeQm5c$fuNLl8vaZaRmGTDYt8G z|8peenKUf_kaN*+LiS)VeRA~49z`Fn@nMe>-%y%(!xZ966YcL_kjAANz06rh)2WO7 zZy%T6oI})xQW>t83acIF- z1w?TX+T_D%h%l7Yq|z6bnwB=NioN_LbJ9OvXRbx?LxNR;Ho@NzY!Ivw$QbRRGfnlH zI@5#h$oQtpf~}1Ytuo=+7WVMB{kIu=E9(&F8>u_$mrsbwyv0 zxuLrMq&K_o$un=egABeunESJ2v#rjxo;lwk}yG~%lSa|G16^?B7} z5VblXKi2f{nIt~tfg`5oy?{oE(-=g(HGx>5vy^@{DnB$)cUphK!Is$0ep!ATBd;&y zwS|nR$|!-H`bjHJoXhHynkyL$A-~n(2Wj+LC<*Zynqk@a9L=Q9a?YrJ#K|~DcN;qX{4MeQY7S<8uIrqRkr?KrI>Tb8u8V`=SJat=1`fXf-~lFLhP zhAS;Q+d5e%Zr!SI3$)FF6uUhk1n7yj>G4a80{sIL6vh17{80C&v_BR8q3-ANX2>Pi zs~kW{yzjjCeaF1xd%t-tpSKnK-gZ~&=}#)kzq2s;%cAfslE0)XimSM4TdBrRwW`Xp zR@LNFuj=w?R1G|}cB*4m(`wu{)v>A;^VN)JySkV8rs5jk()U#0Qq`RGBC>t<`ggwx@UnoyQfu8yXp<4^3c27wY7>G*^NQB9yEL1T6}W&YnBhsofF6+o<;j-C+5t zM~5@#Q0F;~njbiwTCh3P&%OM_FnjLv{oqc`UuGeF?#fMc`&7krlO*11M+f%pDe)P*ALyEkNYBUD23Mk?13_Yw^O~{0Chx0qyN}bk+I%*?5WC$NPR9+Z$|3n zD22|KJ|7h>`MxJe+*gI@iAe7>yOGudop{~trs#E}RL~Rk%}Dp|^dr08>juq@L2uwk zsRoETh-UrTHfXUYob_6L>vm1JQOfJ|gPmyhu<>Xv7U`Z87p2idHa0eE?M7s+%h52! 
zOr+O?JEFuTpCBl6-RkW~9YpF@q;5y*&OJr^6tM3-d499k@y`1I>pgep{HtES=btAh zp1XEEes2Htd0*7ecY3bZcEq5&(%%th(f16JzXG6WrpkX#T~Z5bQMK^5wKB^57u5y) z3#uV&is~?n!B>*G*ERM}etA3#NRrI|fMl+;lwA;cOWg&Lw=|HbK0)SMXhar znqM?_Q@iGF8pM@uSs*#1p^?+>YWo=wdFEKGonUpemVBAC&fW)~0}R*H0 zexAgANW`Jaw}qL|4zpqIAg3zbM~-b(0Bap-iFeL~8CP@DcQrZRnb1y3_T+oP1{Yka zl<|8SX)r`GFE?h0?x@9n5%aMn%u1Leoa&L0EVMuF0R}@4!mNZjLiUi%Wlu1d^7PXkhUUiV_R}b z3_i{_j|0FYQ2dVpOcVZ`I+u|8!`dStc(OtX=`7MPN1`81a}7e?avW+E$BAsm0dEc3 z%;y~E)}Yo-YI1&13!3$g7i{)i$#e4k6LTUD3pp{`BwW2u#bK$MYARLsYqeZD&zu(@8C1(1^*wKmEa-xw%4S)Hy zE0ex#E*00i@8c&fG6#!DSp`=a z>&=>9nT=BZpzjHiRI}R*Kt@)(w&Mw>=|)<&8)d5CoL6e@<8EXJP0*ab*=z(+=9QWN z1qPmo8O-Q)Vcb~K3cI}ed}MYVud_bt(`nvuwvztLi^#mz?0PlmyDe`gvV~Xo`a!J= z?U~xHwFh2Bj||^yH^f!Uwwl>&x~|uC{2}I@ah!U)=KGE_!C6<>49Tb7QdC zcrC#j=&6S|^xafo>@%5R>VPAqP?KhU*96u=31$>)LKDZI4#_XtnmnqCEndJR#C3u- zf>#KnKr?-?y6_@Hc>fB@# zDS9p-_XQ-MbZM!%ydb8IH1#9b@%NLkb573qCW!VZFH>-gdHKIlIF6Tt6nVI1$-_~c z4Eg&I7ufq*a(fx}FL*SE+PtgDJoq(k2@#{q5n}EK#~zC>-UmMm4ll7k)0pp19E&X< zi#>TPb}o?;8&AmJcb^fT!WKx9TNRg(t;`FuKzxwoK(R{jA_1vSe45}Tf@=T?W4m6x zwi6ljh$6k++epxjrtkytJR6Uxc|sIy$CUjt>m_-UmL&;CW^WMm2f;l}lCl&9LK2Q7 zAxS+XwE)SdQkQ%Wem@iQKDeJ+bxs?eoFs(fS)Zl`JtfKde}W8*FXce_qQ0x`>M$Lt z=wkLAF?>o2f)T@;QV{461TkJu1VLKjVQMCLBf*pL;n_e8WRRnWqwdSO4OrD`saxIA^_ycws@T-3KePpM5=&5f{$@TZA)SB6}0kc zr73CwVwrY5#WVieSv-g3sY6PHi*Pu$JntiNqsHNw61g!ei1~k!8>#2oz53RaW}v{J z2x9Wz061P?v~Zvh2I-iglM;@)1+9TV%#|rT4nDze1}P|ZIND+}$cAvV=i%K`sC4ZB zF0afNSR->Yl6{Resomepa@)=|8r@ff*ZdF6uTngYd85oVYV|evq!Wy*x4VylSG$iB zJQ~@tn-=WAAfObE%y?Y6wUJ&6f=CBOY@+TsPW&C<{W3tTfoa9sIP9`o)NE~7klT7Z z`*D9wq{~Qt3jlkA&=_{43__1tkBk@NE(Tglm-}^*?6pt!Op!egROn>2Tb5XJ8Y^)O zDp-dtLk61kcqc^{udQTaF(Z(V_lYznC}KpPU_e0AK->bTB%(<&$|03X(R3E|e~jeQ zpW%i`Qv{n)R)_g`D}E%AchK{UR#~C{_Xjh~dM1B^nN(lBC6RNK4np=kWUR)ED$2BKioNnu+Gxvd-c^7Bg zns?@K(hIH?EbcEMoeD9-1E}W%%nf0kUu4#rVUMN@_!k}Rke1M61DDz7n$pnslMoOc(dJgpxXtN zeq?NVH8(2C^-Y9zBa?XNR@L^gN>4f`32rNzK{ys`hL^HFXf(z?NFbVRX`4ib2R9;` z;i))&oJCZ)6$UjNM&x)-GMYDt3XFD;K}W+%FN(vJr0(9Ak6*1@Qi(016~q 
z9KezgrGmBu`=g{;@ZaME)_>pBO_ad;vmZFfDMdpD91FsmAhN^`(P)P-TH8`(q!tQg zPertrZU*fVia@Ysl1s(8z=V)A5aL-yZJIcSWxj&jg-|DCzM(HNYD=+r|9Utti0fwP zG9PF~+%N>@DL6oMSN37GAc&QW$Z;DQ?j4}aU=3@#Y1a&LVE(iWa(y3SqJY4Z8WdZz zu5}kGzs!ZZ|u(WPZ2%h)Hd zVkNsaeBv^^{9dq}$l#AJufu#rSbrJB{YbM5yR)){*huG-D1iuV0B6z*&PR4bqvm*6Y>=LT= zBa%p(Y~Mu_|6u?Hq1aMltl8=^{tN0^tpMw~h_Ee!w3^iLF%wfr{YcGr8-xBi(nzyC zD2S;eP5mfYpGES?jiVco`2OQCg5#!K6PHpcH|<(@nr_Cm@l5lI3r~xeTzF=9(S@h& z&VmF=NwDJ8wY8}j7aakLiOYQJJ{gc24CD={LZk7z1ndG$2o2;&hY)cu+I&_k104bj z1jakeb}*ZU*eh8CR~lE8chuKk*tMX0tVLXGpbv5va7htbi2iOT`1}}eNc?YNEU*Ln z@R}v=6U^XBFYfo9pcszEL|?{M$t={sxWqW&GR{CdGremE<=|wP*@M`)qP+eOyIHsr zSka4 z-O~}Pej!*KVfS#|;(wVfTytI-zqNbVuc^pBSE{hVGC zAr1?ym+RE}4z(ihQ-j8_e_@iTT$)s;B3a(Y7MPm6n8xWNAPqnR7cPQUG~AB$nl`E0uiM*Ykr*BNHD2-8&R{7fA$XNaj3hd z*AdTGTBh{bsAq*Ic9_+gM-30@>i-w&{O14^c_TAYXtrv-Z-Uu*H^aocnI#AV6Qu(3 zHj`G8!k`2(QASGC9GGituS{7?;44@)374Hkrb!-(p(^v5ggO)pxQyp47&)(`gj|eP zt3}%SdZhLHk+!)7vx>IYQ5p7a;1=;J-$v$eQNx1N!JM4j&M5`VS?NfU|38gP9OG^= zw4*v&Qer?yb@Ze}MWaF%$Y^*AJcm*$Dbco79do>bdud!v!yt<+*YoI>$<;By9hYLu{cv@HhLQLy0-C!qknRw9=A%D!vNYD}-1^av`LQjughBx8N5)gM z2JV4dIuau8cd5~}HQvD9_L>`;fglZoQhp1M$d04gT{^zfr27uAaSJ^r8WL2@EYnkt zv}NM`|+@$Eoq^XMEV~ zYbUjFSW%M9_*nCyDG!^6TEG<$4jG54wPuh98N~WNq*+hX9BV?3!9%P9xg1);W`J3o z$B47gta+H}1>{S}7ZylHGqRLVO2u#{%wm4U5NklcS(I}uBhNls>%@4pJi?cw(MJ8_ z(X8V~i)V-(9FG~#6qo{X@$RO>S-h`wO4&Vux0SKZ6Gaj*VqhfTjsN_${ZD@3Vr4a^ z5_SV49829}nwEjlD2?xVnjKsYpJIn~f(?L5`ThjoVj~$MJ*4^yaZ<#O zCsWmKY8Yy!_S7u)GlEmIv7 z6BW(;9DGt+n_{FJm!=d1#WdxaWTgKgBSAvp`b%q3lEJ)Pg!#6M4=I2*q$&B&^r>Vp z?pc^+AUCu{5N8pG2U5W246^P`w6|Ox0lhcLnz;0uZfhe@H0qaiF(SW2J3t_GSMNS5 zdobcXIr?OeybF)_u*b1)C``Oz2H~TL_P3VDV;4E+U^ela{|&P)xv`PEjQO3w9_U5* zlF3a^sr%vJWDB7t1g?(M;@dj;@dZ?8iS?~{7k)8v3(_MGWFM5WFk;8PxNL#;T5Bl;GYO&VD%Pr+XQbBNS~ONP~sC0x+5E#Y6>PbzJE&jn>B>1 z-(ah^3Fz{Q-y`^a02og_hR4JoqEJc4y!sK=${AOgqfIQnM(}llBRr?$t0k9A_cSVL zL>6HS(+V!Bk86)2@irreYbHaTQ?&Oj*uSU2V)%AR+RVI+Dr&>{E|MlEhFhy-kG(Vf zHedfY1Y|=Yw@&Vr+@XIV?r?H4CrztwpyCHe^4=ETPd#F#id@VVl7 
zKZr(=RT~7o4S~>}wVHgsDL(;;@0)&?gT2Id=5_h`io9=-_Y5*XDx(2%>c_1(vM$6Q wyBTqKO3WaLivY}#WcgjM`Jpj6HO^2`c{6@N?BTIsq|3kROy-j1Ip`&y#N3J literal 0 HcmV?d00001 diff --git a/models/RWKV4/src/__pycache__/model_run.cpython-39.pyc b/models/RWKV4/src/__pycache__/model_run.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b15dc60ea8e8164e4e57a2b8cb3a68ae9ded4b4 GIT binary patch literal 11139 zcmb7KS&$pYd7cY{!2sCBg5X}H=oMv4f^9B!Sc>h?lqiy`!{ka9DIbWGF%~ns3oI`7 z3|3rZHkQ4zl_+7#NmUXbv+Tb!|86=C(99@6NchZo!>%i|z?8|83t!{vH+$_`Q17&>&3@5- z`E#HCtn=K}OIMuFeA>D4{PmScwb!CqNq+g2b?3^}&pv4L*v*$kF4Q^Nc6;{&cuH3|JpDG)X(XO_mb@))Hw%M$1`8B@qZR008b%%R5VX_RabjcIZ!JGaNRW| zo^n%IG&9n%JT(jbM^^?2Ir1`!zY3lU)lhMiPypYGgO9ZR@pbg6!HC3Xv%*xOw0dt`{PMZ zd#FtHPU`ZG5zf0tcsf)D*%0G~pbiaH3DS)mW~YVvrtYRt%e&?secuXZ2D34J)H}3> zl#R*xz8hLMl+pYt6<=3HF|-DSp)x_YB>lsn&ynt;lr!tX#L@|@oDyk|KbElC+O-!i z-E^+Mc;%`{p{1-vrt9t2YhILZd99Avb2d8y+z}a@-AA8{jE&7lpDY(6^_fV$8L5|} z6mGus`ABco+mY4*UwG}^y6Ci{RL~K%?MU}-cO$#jX$SSKUZ>|rsZB6(5Y743UGQE< zI2+a4&g)g-Mk%k=4fdkB!_K2;TjXesJ`C!+Pn_TGw7l~^z(&X2JO8rR?fB;@dFQTOkNY$>SyECydq z=3diye{$vVKZ7FK`uEA!N<-NPQ#aIoFm*!%OX?GBt%ZhUYvU+eqi^a}V?VWT?x(>_ zW+M$2Gd49!)*WqNfvGdcQtgDQqqmg9q<8in^cj{?B@D?SBh zn&;=$`Gm_K)E)+}QwoYGXHoh&^809->(J+x-@?PX0Bq}I{^ql@_1YVg$qWdOSo>)C;Z z?KLD!NoW!dvB`JU^|CHUi&E8Yx49Q9$MjFIq7$>_L?esF{PHVTF1?JJFGXJ!p9Ee* znY#3(Hf1K?QZ`j-C@pwV(nwZ}x8M#`XeLvM62H7%ZMVJV_4@58-NFI60Q~pA0T3t$ zc>STG#>NxoasWLtq#;cKH)-M^Z5rihs7reG9`qc{5_*-Os{*}|Cmw39;im4;6wi!k zQX)sxSyEx7DWxM?LCTZ4e%FweNrHE1quzNBm-rMhI6ul71j^W`SN-x_l=6FBPmrbR z?Ro$4$#_1n%)azAkqnb+!VuPP#a$J>i+;nlouP;J9drgp2%o>$f*!}pq- z;tFP4$!ynM*K0d|AM?&QPOVw>eaD%Qtjp|%;?ra)s;!z@j`F10Y6+6nUyLU<-r3dB z+fHpUxu15Mqxri4@HViyLlq_j+uBe|@Q_@IEh;5tL4i?WO&im-gl^zYJ;bK(rvhWZ zVuq<9M@pe3O?=-3)j};38f*zoY=1hWylCrktID=`7LyRy2(A;nNFWuO=|j|o7a789 z_QYpEhc{W@YIW+acnP(#B`6}1dOOl?vp^xiVSR`+ZD+Tli=FqWCz|9f6rXCAix@9p zL7L+J!qFv64Del3d<@;i69i8ZaQURdiSo%_)x2hNv|m!|++Y`}dVUtz7f^iirKRTb 
zf|MG{)TJl2pGKT>a>h5ov`1x`ieoIx|BcFVvK*u+!wpMrj^bn}--ooo9%L!)CAhkf z(HvUyt|oEFYupnmMwcVR(hrF}mR`6AeGU>{wR9}?`h*lEF77gUeN6YcQwh%QWXfvIERLdGYn9C+H)!QkoPmO z>_hr#Rp+(-$w@{yUi9f@Feu`H{+|HA?^1@!tNOmSufuDkp^L?LB=9Mz2u1>LN=2YU z5yW&oQ3Pp8hnJbqjf76YM-T%wkO9VEM%}l8r4gp6``N&PTF4I70CqrngGwU{MU$f+ zc#wvo$f0i@{-6phcS8}cV9MngD9t)No$)#g59A5eF<#qv8$ZuhT2E0Wa0xa6;Pp^> z2+HMABTX$UehvWoWLLb(+Cv2NetAtAgtVz#9+_$oU1bkG!a25OluJ=HUQ+FAVf zOH-SaC>Qbb=sz*qq2Bl+8hn;SZ_Ei&YA92ep45ICy^(sR*{SVJSq3T$4#ncX0dTy+ zXkbI34AQYcCpBz!16Bh$mn&0rY1e80f9V~O4|pz&ehDD9!1Q8mY<5X4Xtvg$k;{6#_;Gzrl*=f7 z0|0A-EEra#1V)cVkAfG|E(Tgdm+N(r;x(XnrYN366*gIImn7Akrb-&~2i9TFkU1qi zUP;lyW6POX&D;VWuah7~MiU(ZDjU%w=mL}z)g%Q)YNbQfL~Hs-C_Y0OE{JqRuoxw^ zpO2T~$1-^ncm5W~NSM4JNU5PrUF1q$KMj+|uG+IN-JHq`(F39y!(=gx7NWY5+5Gp#Hjc%Ycd|Ym`;FgGdZ%tB*u#eL}U#_{P4_touQ& z)9POQM|A%m6s7+msTuLmvFxl~fQG6z{fj_G{bSQDmBQAyzRBxvpNcZIPInLaNPp*D zW3?M>mp?F}9A3cSk$spR73B5E-o?-x4h>0^ZaXp!egf+%yf zwc)v38)TvWc*7nn`8Dw$fke0cvA3#Oqk&jY?*+#*& z?qIC_bo=E|4>!&k%^dB;q^s7?lke+UjCn1KrdK8FSRIIA zIlJC}{4yf^POy^b-;b_rz;#9Teg(|^aJ>zmv%GR~MH<7Egy&b|So;4MMqk7}(FEf9 z;Zui6DjX?f+r3sCoqGO6vZ^Gu(KlwO17cqzc#Ghd3En376@Vy%?0?%w0@aJMlv?Dm zHtWr%_yzVDC$1uOElMNccD6e~Y^5aYo@J*82!07*#Or$9oy_FYR7L8oI1*HAM--6s z)V_x<{(}GtlCQ$PaYShEs-Rkoaxo)2%w3QWTEf2d{V*Gk{2^tuSK zS&+J1tJ-yF5P2NfH&w#b(1b-WLG+p=I`c@rYHOwjUHy*52Afohk)sK;_gp#T?kt`d9Wnbq!4$ao| zd&KjV4k<%5npok9+eEecQGY{5`2U4A{}}*9p23W4nXOvynGklK#<1`-W*N%BL~RDR z&7u{hGAKe#lu!~22jLpKDpQUUguh=qO0E-=Q(r)_KPIx>w)Y4#Xw&1*iNb{vtb0Vz(LZN^kDFpvFD)m}$sRv9L@} zInu7J^F%y3UFgdI+acFnTd&c}A6<--#`F9`RHTch?5g5yfTe}{rOU>d5=SePDd(R; zGDz0xDzQG!8M75x;Hg$9TbXG_d^;K?P~os?4H}+Q0R{q@WOZPpWI@e=YLE&#d*xA2 z!-2A3xU3S-t`YBk6tj&L;!MnjOC->zjTcKQ7uw!9cO>&fFR?0D@w3D}M=(yU~O$`3z%!*mPo1n z0Y(^aLHk}?a11@gtvQBn9!GbP@xw->*@i8_HcV_7S<{-@&dib4%*1wTJnq`t_&TmSj1gHL?nVtFm* z-eAohRk~u<*`VuM(T#xVaKE5apTfOiU-s5v3M|hdRn89NHIasnT!2&X_jdY z3WwZ>BX zBn0JJV>!VQ`qS-YwrX%6K0M(b6e;q-97Mwm^x#ZrCJz1?>F zxd3xD{ z{RVrzO~8;>{4T-o0l<0cFf%6p0F`n&memikRnAzt$G2Ji2EjK8j>w#j@045$-IHjf 
z6Ip;SOfR^oKBhf_!b6N4PL~XI@jVOP?`a4azEzSQGf$ujT0cI3r0ame*2~#r4@&8+ zioYSC5DK|)a;4<@{0nJ^Qgfb%yVQO=k8fta zneX>KMmH;!62tQgvDAM26k}gfW&h)$at3euTM)qn@3AHu&4xT=T}xPtY=bxLTP&X8 zjDEsy+8?vn`Y7jhnvNOcHeK`eng#Rqn?Alyuh=g&r!a~Mw^#00nibA2GU18B111V# zWy5Y(Q1?X<_2P!rtcsGDdcc~~qAV)-&Io>;)vJ#n8*4Nx*KS? z#yZ!f&Mde7laGYF7b84FlugYt}nVoGHzC&TG&@D0A^$*yvh*Dd5{2@YQs z1^BzZa5iktREc=IQyW~k#~J#|==ac8{vFQf_x1+o>{mRl3jff&fBbLe-5{C1V#fGS zSfxm(y=g&)imIyV4D!-%>_8^RjPyl$>_%+tWzI7DbZ%V83YnL=rS%3Eyg-uauF z5mhlQX2g+A`D|%PH2(~7*xje5F}`&?I;>Ov%B{5VzTzjH57+&bY~bYwQ2U9$3}tB zc^z*$4-x@CpH@Gy#(Zo6nf%w*9p^fG$mBFSAKKbldR6mCTkYGX#gNCZoxGj&!;>k< zN+Rx`ycQ0U^dyof1)7wUhqf^c%D`IS~&(i;bR} zXkjS`qBv4Puym@2qwl4sm-Ytz54+%*mG5FtVmX85_EqPo3f~5h9gwBIpbKra9`wRk z&tg;U+f-atw}VauxM+utTHC3T+EMohp|+z~X)8%}A&l=v5?aeO7!YiE{1?qC|lH?swC1* zq>{8=L@n%fAw$|61?)v(bm9FN)=0w0w)C&)ti1v3haFmYaO%TRC z`dPDlhFAG4pT|q3GM}@`D3v^iJKV>wgxWLIj^AkJuu2c8j{r$%H!hK44a}J-ejKQL z$QyMlm)_s;+7G#W8BB+TEd->SAb1OXDNV=tXnJp5)7Lib)Q2aB9Ajhy>7PLq>+&sl z-=DLtjn@&@0z1aGt+6F+@|0*ROg!=jREGZrwsts$vuSS?RcV`#?f9B-4~7CddYVqJSlr$hMia!ktt`VJ_DPK&&QQ9t2m)p&<7oe<0TBKyL42cwBP zmoI?>+B$-d=BfM+zPXafykPN!Xk+oqsHZfsvG)^uN?HvSFy234^1GuR{TmOwoh!)H z>Q!yU1MSIHERw!1&=uI1#b~ZDVrnwNe@^pQaX9C9-{6dkXF|fd%)H zc(v~3;37vLXF(9zqjk;kcR$~Z*_PSSR1oY6{5Duq;%Clr*hCHQaq9)^7(d2GN55i& zjk-tX6}S3f5a?16^b;}cQMnuhcZRKA-XmYdjFLR0yg-D!@IGSv79Aw5YXPj z(+;#9cnIic@TM<Ds#6@4zn+Z zpwZc=g8X2iR_m+)2Qfatf2nGROhDq|+xuzK%B~ zgqx^6yA!n)PI+SzxeZw0F$wSw?TKx!X@2)GiYWShO7rcD#4qeB{{M`{hh@tpI~b3z zU{QIM$R4;2u1TUF@0ZA9^&@iEs8&ms| zpmFx{d3lQ_qXo?>%_$kC!=5s_J-i*VWa3Oo3Wn(D;Vtu7#LYQgwLH*yt8C4Kj%J@c z_yeJf06=Dc4sSXD$@muAW?NivgdeUf1%skcg_OI7L=ItUkr@sIo^9|78zzC)Q|!)h z1CV~kk@-|{ldMPnA&7`(!R(e?y-GrCX}8k(1eB*&MZefF4F4&ZYqpc{Pp zBp!UinvG#rL+5%1@_GBv5o!Nauhm}>tkoZuNGmD>(v}bhe>xBIaaWnRh?>E{MWOtq_3S@3ZjekG^_JRC^Zp`C?;$5h5 z3%+lg6k;LP!^?GDc?Q(W3X4!8RPuJ_jf;1dGk08!5%I@GWh2kP%}}_T7Cb^}JcSlN zE9G-+6}Q>AjN4gBO>LF4a#mb(w(}J$xHIDJl~vXlZfA}ttl4r4cd?2nWX{tJz{33z z{?DiIuxSZzk#11PE-3avp9qwF`7z2jP$`KpZR0M38%pjQNu}magU%^rtwGxA4|-u* 
zpU#6Zzn}|g(oy}^x;#s~?-C&d<(BABUaYaVTph^>bfJTDLhdEcpeADw?89hiPdt)L`zi^Hs=q*OlG@8L z%3m)U?N`AcW`HgJIbn78%4AY`tKD_sIL}@K6HT!g@KkqBOuj^fsTNh zwbLGomY&|%6tsHy*OB}QhVDDuTzaI~X4r4gnUXhTxEz)7%Fj_oNF%))Baq*u((6P{ z5Fts-S?v=9hk>MZDUK!SL3Tm(Ox~7Xt_7~7CjT?qGv{O$+9%hp(&2IQyz{h+c?-ykyZWj(%lcwr0Lpz2H0l0e-WU!vFvP literal 0 HcmV?d00001 diff --git a/models/RWKV4/src/__pycache__/utils.cpython-39.pyc b/models/RWKV4/src/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e250a3366d6b8e4e79ef0b6e1136f84ff9adb66 GIT binary patch literal 5145 zcmbVQ-H#hr6~A}BJoeh&&1SosO-nH-q;+XGNhz%;X=s|SmNZUC(uP_sBd_mxGhTbf zxpSRttTC#(K~zX9CLtu$0^TC(@C(0CIiJq8l8CKs$}2`(&bV+&h2css%s9^S4f zU|U~0nA#??{S;C0QO=>Hzk~ppm9cy4ylx>sOU&?9 zQDA=87tXrPnJUrT?$kOL?oo#EGR8ghm4Ayf`klSbIr{~Vj|u+(-#`90eAh{)FW?ye z3A+^Ov^OlsP*GJiok3pujUCA3n3cXLkKKrkz05h!KAIaBvO?x%?rD~}qVgIWyLbMo zW<*s?iy3if!~XfPv7h;yhtzEMaOR0wad^Ypg8a9D00%$%-gx#Dz9&6>(%^}+5Ae>< zU`N(D`g%H{jx`Vb0Y6WP`3FqRkyPi=gKv6DEIiRSFBsHzq@ic8Txkv7iAA^;iU}T? 
zrT4F%JNs5}ipGXLt7ghM4UAVhz<{ z9Nig)OKsx5{n)4wW^bUR^B@sW^-=XhYs|+M@XCK>-EnTPhfGdm@S&}(rIVUR+G^i6 zJ%&6!b>eo?4^N~Z%Za#u;(9np(i3EWFTHyuf8TxOL@L`ShAQf%ON0BmbXr7h)lT|@ zGp~U2KTz1{*yQ||f*^_`6$DGCdpPu7dS+>N`u`w5&#YYjU0LM}Hrmsiqb7VE`!-eR z1nLXA&{k_fFO2mpKBRq{sw?Vt(1`#M?a=XRJ5^FU>i!_qb`&dZC8;ii@!d#5Q@Pdx zf{JzrlE!WDXif#oQ5-A}JDpH!1Ce}=mnKpXYk~q;he|ui0Q={pvP90Ac@u61p#tu@`XwX4fr&42oK7J*%UZH;cMl%PMxleHfLPEH4 zjudrZ!c6g_K;lE*s9U-8{m@=vEbKAja?@qU3_QV22RDA_L5GA@a5&0mIaLMQw*IJc?6?=vWFGp5yo^v zEaQvpGvg7AChA;%6EmQ#LwISPO0wl#Nn}{CazeDR^7Ax8{6^31DrhBCz;}P2q3^7C z^lv=yZmu9Rt5>xZ540y+u}J#5KyMh9{5tV9ct1+Dr-_gaXdBG6*Gg%ze1>{lmB{vO z?J2m+1lHS2;+49W1B-wzXF(9(qczR(cR&9S^GW7wrh;Hc;MXxVrTNS`3VW!b9JP*H zNB9vwI{XD6*rVrV?jAy7m_&ni zP6@~`yJNncCm(3;dkExjqNLA*>|`+|0cbx!q9%~Rm^^06BvpKp$&_{R=GDs=8i;+@ z|4A(V_Q`si-rJAa7L9W#>Aygb#lUp)i~u=|6qM&6b_jbBh!M^rTeakQq)l$m?mJ_f zQx0TCg||WR*^O%$Enw6eI~aBLu3aZ?t}4JmxOtute=2<1K@~Bt2)jaNBFe}po*BEU zv^kZzTPBm)mr2Z=V(IVqVL1a-J{0Jqt z)8jiZ_&?+EL8)??_Q&LFwA~Mg>_Xk3oMidYwk-8Sa>u&0-@e_FR}v96%L-l~?IaTM zFqKy@TNif@S2k$%*NA+T$Z;ZM31s*!(Jb9dqzF`NxrZQ4UMMxrOuk4{%I`HUH?H1l z1n1s9dwt>~bn({J>*s^U*((?14~ZFVXm)8%$uJ%El+owG4c_WgYotRMl-rz6&jkC zh9t-5AAkIqSN36YJE$8>($yz}Z&XW{Z^k@PrxbmYr~{R_=z?C;-ev?? 
ze_T{Hat+)Qg}Y(F3zWuF=<%~s&SSH<#l~gaol0tIvz(Q);;OTi?^wYN5;w7|vdVB* zb3|d)mYcYlRYW0ko@M|RZjOz`yRKK+*-@-IWhazthF_!31Zmhlbu%8b5 zCSB2P2L`O<8<>LtFITWs+k@nup6bI0qd4hDxZYQD4cp03$qU4ek_0^kLq|T)Qn)Xn zkqm-CFia7DbtP}a;k_Q<*7SSu*5}}9D1%RyQ_3DNxmnJ0^}uvC3n^|YxH>jxAmFgq zQqkQoA0m+zlC6VSujJn!^pcfGr9H@$l3!+a*_wkhs3Ot!EXSHHAhcGrcMGj$5!@?|Wl z7vved=n@qQ11*C_K&ll39RYP~r#%!cJ-w$ZX!Y=qB>7`3-E+9P^hmSKvR`8`C2z=Z zK`NujPf$ll`wCvh2qc}T{1y=kP(~7SR(k}&VIX;3iepK7kX;ZxllSG9YJqF1$pek{ r%sH8j_Qo6|2*S@j{%_1Ex0p}mn^!;(XY(Q*35UR7k%eH1&f;+ literal 0 HcmV?d00001 diff --git a/models/RWKV4/src/model.py b/models/RWKV4/src/model.py new file mode 100644 index 00000000..085a0bcb --- /dev/null +++ b/models/RWKV4/src/model.py @@ -0,0 +1,416 @@ +# File from RWKV-v4 Repo - Small changes made for compatibility + +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import math, os +import numpy as np +import logging +import torch +import torch.nn as nn +from torch.nn import functional as F +try: + from deepspeed.ops.adam import FusedAdam +except: + pass # some poor windows users cant install deepspeed + +logger = logging.getLogger(__name__) + +RWKV_HEAD_QK_DIM = 0 +print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n') + +class L2Wrap(torch.autograd.Function): + @staticmethod + def forward(ctx, loss, y): + ctx.save_for_backward(y) + return loss + @staticmethod + def backward(ctx, grad_output): + y = ctx.saved_tensors[0] + # to encourage the logits to be close to 0 + factor = 1e-4 / (y.shape[0] * y.shape[1]) + maxx, ids = torch.max(y, -1, keepdim=True) + gy = torch.zeros_like(y) + gy.scatter_(-1, ids, maxx * factor) + return (grad_output, gy) + +######################################################################################################## +# CUDA Kernel 
########################################################################################################

T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

from torch.utils.cpp_extension import load
# JIT-compiles the custom WKV CUDA kernel on first import; Tmax is baked in at compile time.
wkv_cuda = load(name="wkv", sources=["models/RWKV4/cuda/wkv_op.cpp", "models/RWKV4/cuda/wkv_cuda.cu"],
                verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])

class WKV(torch.autograd.Function):
    """Autograd wrapper around the custom CUDA WKV kernel.

    The kernel itself always runs in fp32: fp16/bf16 inputs are up-cast on the
    way in, and the output/gradients are cast back to the precision selected by
    the RWKV_FLOAT_MODE environment variable.
    """
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        # B, T, C are plain Python ints; stash them for backward.
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            # fp32 mode: the kernel consumes -exp(time_decay), not time_decay itself.
            w = -torch.exp(w.contiguous())
            u = u.contiguous()
            k = k.contiguous()
            v = v.contiguous()
        else:
            # fp16/bf16 training: up-cast everything to fp32 for the kernel.
            w = -torch.exp(w.float().contiguous())
            u = u.float().contiguous()
            k = k.float().contiguous()
            v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            return y
        elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
            return y.half()
        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
            return y.bfloat16()

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        # Per-sample gradients for w/u; summed over the batch below.
        gw = torch.zeros((B, C), device='cuda').contiguous()
        gu = torch.zeros((B, C), device='cuda').contiguous()
        gk = torch.zeros((B, T, C), device='cuda').contiguous()
        gv = torch.zeros((B, T, C), device='cuda').contiguous()
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
        else:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
        # time_decay / time_first are shared across the batch -> reduce over dim 0.
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            return (None, None, None, gw, gu, gk, gv)
        elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())

def RUN_CUDA(B, T, C, w, u, k, v):
    """Invoke the WKV kernel, forcing all tensor arguments onto the GPU."""
    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())

########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################

def RWKV_Init(model, args): # fancy initialization of all lin & emb layer in the model
    """Apply RWKV's custom initialization to every Linear/Embedding in `model`.

    Scale selection: embeddings get a tiny scale (or zero), tall Linear layers
    get an aspect-ratio gain, and modules carrying a `scale_init` attribute
    override everything. scale == -999 means identity init.
    """
    print("\n[--> first run, init model params (very slow for large models) <--]")
    print("[so you shall only do it for 1 single GPU and save the checkpt and load it when using multiple GPU]\n")

    for mm in model.modules():
        if "RecursiveScriptModule" in str(type(mm)):
            # TorchScript modules hide their class; match on original_name instead.
            if mm.original_name not in ["Linear"]:
                continue
            ww = None
            for name, param in mm.named_parameters():
                if name == "weight":
                    ww = param
            # NOTE(review): in this ScriptModule branch `m` is NOT rebound — the
            # isinstance checks below see the previous iteration's module (or an
            # unbound name on the first iteration). Upstream quirk; confirm
            # before changing.
        else:
            m = mm
            if not isinstance(m, (nn.Linear, nn.Embedding)):
                continue
            ww = m.weight
        with torch.no_grad():
            name = "[unknown weight]"
            for name, parameter in model.named_parameters(): # find the name of the weight
                if id(ww) == id(parameter):
                    break

            shape = ww.shape
            gain = 1.0
            scale = 1.0 # extra scale for gain

            if isinstance(m, nn.Embedding):
                gain = math.sqrt(max(shape[0], shape[1]))
                if shape[0] == args.vocab_size and shape[1] == args.n_embd: # token emb?
                    scale = 1e-4
                else:
                    scale = 0

            if isinstance(m, nn.Linear):
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                if shape[0] == args.vocab_size and shape[1] == args.n_embd: # final projection?
                    scale = 0.5

            if hasattr(m, "scale_init"):
                scale = m.scale_init

            # print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {name}")

            gain *= scale
            if scale == -999:
                nn.init.eye_(ww)
            elif gain == 0:
                # zero init is great for some RWKV matrices
                nn.init.zeros_(ww)
            elif gain > 0:
                nn.init.orthogonal_(ww, gain=gain)
            else:
                nn.init.normal_(ww, mean=0.0, std=-scale)


class RWKV_TimeMix(torch.jit.ScriptModule):
    """RWKV time-mix block (the attention analogue), training version.

    Mixes each token with its time-shifted predecessor, projects to r/k/v,
    and runs the WKV CUDA recurrence.
    """
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_embd = config.n_embd

        attn_sz = config.n_embd

        with torch.no_grad(): # fancy init
            # Depth-dependent ratios: deeper layers decay slower / mix less.
            ratio_0_to_1 = (layer_id / (config.n_layer - 1)) # 0 to 1
            ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0

            # fancy time_decay
            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                decay_speed[h] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed)
            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())

            # fancy time_first: small per-channel zigzag around log(0.3)
            zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5)
            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)

            # fancy time_mix: per-channel ramp raised to a depth-dependent power
            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))

        # Shifts the sequence one step back in time (pads first step with zeros).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)

        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)

        # Zero-init flags consumed by RWKV_Init above.
        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    @torch.jit.script_method
    def jit_func(self, x):

        # Mix x with the previous timestep to produce xk, xv, xr
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        # Use xk, xv, xr to produce k, v, r
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        sr = torch.sigmoid(r)

        return sr, k, v

    def forward(self, x):
        B, T, C = x.size() # x = (Batch,Time,Channel)

        sr, k, v = self.jit_func(x)

        # Gate the WKV recurrence output with the sigmoid receptance.
        rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
        rwkv = self.output(rwkv)
        return rwkv


class RWKV_ChannelMix(torch.jit.ScriptModule):
    """RWKV channel-mix block (the FFN analogue), training version."""
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad(): # fancy init of time_mix
            ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0

            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd

            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))

        hidden_sz = 4 * config.n_embd
        self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)

        self.value.scale_init = 0
        self.receptance.scale_init = 0

    @torch.jit.script_method
    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        # squared ReLU "k" channel, then project back down.
        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)

        # sigmoid receptance gates the result.
        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv

########################################################################################################
# The GPT Model with our blocks
########################################################################################################


class GPTConfig:
    """Bag-of-attributes configuration for the RWKV GPT model.

    vocab_size and ctx_len are required; any extra keyword arguments
    (n_layer, n_embd, model_type, ...) are attached verbatim as attributes.
    """
    def __init__(self, vocab_size, ctx_len, **kwargs):
        self.vocab_size = vocab_size
        self.ctx_len = ctx_len
        for k, v in kwargs.items():
            setattr(self, k, v)


class Block(nn.Module):
    """One RWKV residual block: time-mix ("attention") + channel-mix ("FFN")."""

    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        # Extra LayerNorm applied to the raw embeddings in the first block only.
        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(config.n_embd)

        # 'RWKV-ffnPre' swaps the first block's time-mix for a second channel-mix.
        if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
            self.ffnPre = RWKV_ChannelMix(config, 0)
        else:
            self.att = RWKV_TimeMix(config, layer_id)

        self.ffn = RWKV_ChannelMix(config, layer_id)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(self.ln1(x))  # better in some cases
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class GPT(nn.Module):
    """RWKV language model with a GPT-style interface (token ids in, logits out)."""

    def __init__(self, config):
        super().__init__()
        self.step = 0
        self.config = config

        self.emb = nn.Embedding(config.vocab_size, config.n_embd)

        self.blocks = nn.Sequential(*[Block(config, i)
                                      for i in range(config.n_layer)])

        self.ln_out = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            # Optional "copy head": lets the model point back at its own input tokens.
            self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            self.register_buffer("copy_mask", torch.tril(
                torch.ones(config.ctx_len, config.ctx_len)))

        self.ctx_len = config.ctx_len

        # FIX: the original wrapped this check in `try: ... except: pass`, which
        # swallowed the KeyError for an unset RWKV_LOAD_MODEL but also silently
        # hid any real failure inside RWKV_Init. Use .get() so a missing env var
        # simply skips the fancy init while genuine init errors still propagate.
        if os.environ.get('RWKV_LOAD_MODEL') == str(False):
            RWKV_Init(self, config)

        logger.info("number of parameters: %e", sum(p.numel()
                    for p in self.parameters()))

    def get_ctx_len(self):
        """Return the maximum context length the model was built with."""
        return self.ctx_len

    def _init_weights(self, module):
        # Plain (non-fancy) init, used when RWKV_Init is skipped.
        if isinstance(module, (nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=0.01)
        if isinstance(module, (nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=1e-5)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def configure_optimizers(self, train_config):
        """Build the optimizer: all parameters, weight decay disabled.

        Prefers DeepSpeed's FusedAdam; falls back to torch.optim.Adam when
        DeepSpeed is unavailable.
        """
        no_decay = set()

        for mn, m in self.named_modules():  # here we disable weight_decay
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
                no_decay.add(fpn)

        param_dict = {pn: p for pn, p in self.named_parameters()}
        optim_groups = [
            {"params": [param_dict[pn]
                        for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]

        try:
            optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
        except Exception:  # FIX: was a bare `except:` (also caught KeyboardInterrupt/SystemExit)
            print('\n\nDeepSpeed not found. Using torch optimizer instead (probably slower)\n\n')
            optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)

        return optimizer

    def forward(self, idx, targets=None):
        """Run the model on token ids `idx` (B, T).

        Returns L2Wrap.apply(loss, logits): the cross-entropy loss when
        `targets` is given. NOTE(review): with targets=None the wrapped loss is
        None — callers appear to use this path for training only; confirm
        before relying on it for inference.
        """
        idx = idx.to(self.emb.weight.device)

        self.step += 1
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)

        if RWKV_HEAD_QK_DIM > 0:
            # Copy-head: add a (masked) similarity-to-input term onto the logits.
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            if '32' in os.environ['RWKV_FLOAT_MODE']:
                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size)
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half()
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1))

        return L2Wrap.apply(loss, x)

# ----------------------------------------------------------------------------------------------------
# (original patch file boundary: models/RWKV4/src/model_run.py follows)
# ----------------------------------------------------------------------------------------------------

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import types
import copy
import torch
import math, os
from torch.nn import functional as F
import torch.nn as nn

RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')

DEBUG_TIME = False  # True False - show trained time-coeffs

########################################################################################################
# CUDA Kernel
########################################################################################################

# FIX: use .get() — the original indexed os.environ directly and crashed with a
# KeyError at import time whenever RWKV_RUN_DEVICE was unset (e.g. CPU-only runs).
if os.environ.get('RWKV_RUN_DEVICE') == 'cuda':
    T_MAX = 1024  # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
    # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

    from torch.utils.cpp_extension import load
    # JIT-compile the custom WKV CUDA kernel on first import (CUDA runs only).
    wkv_cuda = load(name="wkv", sources=["models/RWKV4/cuda/wkv_op.cpp", "models/RWKV4/cuda/wkv_cuda.cu"],
                    verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])

    class WKV(torch.autograd.Function):
        # Inference-side copy of the WKV autograd wrapper in model.py; the
        # kernel always runs in fp32 and results are cast per RWKV_FLOAT_MODE.
        @staticmethod
        def forward(ctx, B, T, C, w, u, k, v):
            ctx.B = B
            ctx.T = T
            ctx.C = C
            assert T <= T_MAX
            assert B * C % min(C, 1024) == 0
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                # kernel consumes -exp(time_decay), not time_decay itself
                w = -torch.exp(w.contiguous())
                u = u.contiguous()
                k = k.contiguous()
                v = v.contiguous()
            else:
                w = -torch.exp(w.float().contiguous())
                u = u.float().contiguous()
                k = k.float().contiguous()
                v = v.float().contiguous()
            ctx.save_for_backward(w, u, k, v)
            y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
            wkv_cuda.forward(B, T, C, w, u, k, v, y)
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                return y
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                return y.half()
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                return y.bfloat16()

        @staticmethod
        def backward(ctx, gy):
            B = ctx.B
            T = ctx.T
            C = ctx.C
            assert T <= T_MAX
            assert B * C % min(C, 1024) == 0
            w, u, k, v = ctx.saved_tensors
            # per-sample grads for w/u; reduced over the batch below
            gw = torch.zeros((B, C), device='cuda').contiguous()
            gu = torch.zeros((B, C), device='cuda').contiguous()
            gk = torch.zeros((B, T, C), device='cuda').contiguous()
            gv = torch.zeros((B, T, C), device='cuda').contiguous()
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
            else:
                wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
            gw = torch.sum(gw, dim=0)
            gu = torch.sum(gu, dim=0)
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                return (None, None, None, gw, gu, gk, gv)
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())

    def RUN_CUDA(B, T, C, w, u, k, v):
        # Invoke the WKV kernel, forcing all tensor arguments onto the GPU.
        return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())

############################################################################################################

# Module-level config namespace: filled in by RWKV_GPT.__init__ before the
# blocks are constructed (the submodules read it at construction time).
RWKV_CFG = types.SimpleNamespace()

class RWKV_ChannelMix(nn.Module):
    """RWKV channel-mix block (FFN analogue), inference version."""
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        # shifts the sequence one step back in time (first step padded with zeros)
        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))

        hidden_sz = 4 * RWKV_CFG.n_embd
        self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False)

    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        # squared-ReLU "k" channel, projected back down and gated by sigmoid(r)
        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)

        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv

class RWKV_TimeMix(nn.Module):
    """RWKV time-mix block (attention analogue), inference version.

    Parameters are placeholders here; real values come from load_state_dict.
    """
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd))
        self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd) * math.log(0.3))

        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        self.time_mix_k = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
        self.time_mix_v = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))

        self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)

        self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)

    def forward(self, x):
        B, T, C = x.size()

        # mix each token with its time-shifted predecessor
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)

        # sigmoid receptance gates the WKV recurrence output
        rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)

        rwkv = self.output(rwkv)
        return rwkv

class Block(nn.Module):
    """One RWKV residual block: time-mix + channel-mix (inference version)."""
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(RWKV_CFG.n_embd)
        self.ln2 = nn.LayerNorm(RWKV_CFG.n_embd)
        if self.layer_id == 0:
            # extra LayerNorm on the raw embeddings, first block only
            self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd)

        if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
            # +1000 keeps this extra channel-mix's id distinct from layer 0's ffn
            self.ffnPre = RWKV_ChannelMix(layer_id+1000)
        else:
            self.att = RWKV_TimeMix(layer_id)

        self.ffn = RWKV_ChannelMix(layer_id)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(self.ln1(x))
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

class RWKV_GPT(nn.Module):
    """Batch (parallel) inference model; loads weights from MODEL_NAME + '.pth'."""
    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len):
        global RWKV_CFG
        super().__init__()

        # Publish the config globally so the Block/Mix constructors can read it.
        RWKV_CFG.RUN_DEVICE = RUN_DEVICE
        RWKV_CFG.model_type = model_type
        RWKV_CFG.vocab_size = vocab_size
        RWKV_CFG.n_layer = n_layer
        RWKV_CFG.n_embd = n_embd
        RWKV_CFG.ctx_len = ctx_len

        print('\nloading RWKV-GPT', MODEL_NAME)

        self.emb = nn.Embedding(vocab_size, n_embd)

        self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])

        self.ln_out = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            # optional "copy head" (disabled while RWKV_HEAD_QK_DIM == 0)
            self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            self.register_buffer("copy_mask", torch.tril(
                torch.ones(ctx_len, ctx_len)))

        self.ctx_len = ctx_len
        self.eval()
        self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
        self.eval()

    def forward(self, idx):
        """Return logits (B, T, vocab) for a batch of token-id sequences."""
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)

        if RWKV_HEAD_QK_DIM > 0:
            # copy-head: add a masked similarity-to-input term onto the logits
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            if '32' in os.environ['RWKV_FLOAT_MODE']:
                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size)
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).half()
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        return x

############################################################################################################

class RWKV_RNN(): # this is running in FP32 at this moment
    """Token-by-token (recurrent) inference.

    Keeps per-layer WKV state in self.aa/self.bb/self.pp (numerator,
    denominator and a log-space normalizer to avoid overflow) and the previous
    token's activations in self.xx.
    """
    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
        self.RUN_DEVICE = RUN_DEVICE
        self.model_type = model_type
        self.n_layer = n_layer
        self.n_embd = n_embd
        self.ctx_len = ctx_len

        self.w = types.SimpleNamespace()

        w = torch.load(MODEL_NAME + '.pth',
                       map_location=torch.device(RUN_DEVICE))
        for x in w.keys():
            # everything runs in fp32 here
            w[x] = w[x].float()
            if '.time_' in x:
                w[x] = w[x].squeeze()
            if '.time_decay' in x:
                # pre-apply -exp so the recurrence below can just add it
                w[x] = -torch.exp(w[x])
            if DEBUG_TIME and '.time_' in x:
                print(x, w[x].squeeze().cpu().numpy())

            # Re-hang the flat state-dict key (e.g. "blocks.0.att.key.weight")
            # as a nested namespace/dict tree under self.w.
            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
                if xx[i].isdigit():
                    ii = int(xx[i])
                    if ii not in here:
                        here[ii] = types.SimpleNamespace()
                    here = here[ii]
                else:
                    if i == len(xx) - 1:
                        setattr(here, xx[i], w[x])
                    elif not hasattr(here, xx[i]):
                        if xx[i+1].isdigit():
                            setattr(here, xx[i], {})
                        else:
                            setattr(here, xx[i], types.SimpleNamespace())
                    here = getattr(here, xx[i])

        self.clear()

    def clear(self):
        """Reset all recurrent state (start of a fresh context)."""
        self.xx = {}
        self.aa = {}
        self.bb = {}
        self.pp = {}
        self.hk = None

    def save(self, target):
        """Deep-copy this model's recurrent state into `target`."""
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)
        target.pp = copy.deepcopy(self.pp)
        target.hk = copy.deepcopy(self.hk)

    def load(self, target):
        """Restore recurrent state previously stored with save()."""
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)
        self.pp = copy.deepcopy(target.pp)
        self.hk = copy.deepcopy(target.hk)

    def LN(self, xx, w):
        """LayerNorm with the loaded weight/bias namespace `w`."""
        return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):
        """Channel-mix step for one token; `name` keys the per-layer state."""
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)
        k = torch.square(torch.relu(w.key.weight @ xk))
        kv = w.value.weight @ k

        return r * kv

    def SA(self, xx, w, name):
        """Time-mix (WKV) step for one token, computed in log-space.

        pp tracks max exponent seen so far; e1/e2 rescale the running
        numerator aa and denominator bb so exp() never overflows.
        """
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30

        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)

        k = w.key.weight @ xk
        v = w.value.weight @ xv

        pp = self.pp[name]
        aa = self.aa[name]
        bb = self.bb[name]
        # output uses the current token with the time_first bonus...
        ww = w.time_first + k
        p = torch.maximum(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        # ...while the carried state decays by time_decay (already -exp'd)
        ww = pp + w.time_decay
        p = torch.maximum(ww, k)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k - p)
        self.aa[name] = e1 * aa + e2 * v
        self.bb[name] = e1 * bb + e2
        self.pp[name] = p

        rwkv = r * a / b

        return w.output.weight @ rwkv

    def run(self, ctx):
        """Feed the last token of `ctx` through the recurrence; return logits as a list."""
        w = self.w
        x = w.emb.weight[ctx[-1]]

        for i in range(self.n_layer):
            if i == 0:
                x = self.LN(x, w.blocks[i].ln0)
            if i == 0 and self.model_type == 'RWKV-ffnPre':
                x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
            else:
                x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
            x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')

        x = self.LN(x, w.ln_out)

        if RWKV_HEAD_QK_DIM > 0:
            # copy-head: accumulate per-step keys, then bias logits toward
            # tokens already seen in the context.
            # NOTE(review): `== None` should be `is None` (left unchanged here).
            if self.hk == None:
                self.hk = (w.head_k.weight @ x).unsqueeze(0)
            else:
                self.hk = torch.cat(
                    [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
            if self.hk.shape[0] > self.ctx_len:
                self.hk = self.hk[-self.ctx_len:, :]

            q = w.head_q.weight @ x

            x = w.head.weight @ x
            x = x.cpu().numpy().tolist()

            c = (self.hk @ q) / RWKV_HEAD_QK_DIM
            for i in range(len(c)):
                x[ctx[i]] += c[i]
        else:
            x = w.head.weight @ x
            x = x.cpu().numpy().tolist()

        return x

# ----------------------------------------------------------------------------------------------------
# (original patch file boundary: models/RWKV4/src/utils.py follows)
# ----------------------------------------------------------------------------------------------------

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os
# NOTE(review): bare `except:` below — narrow to (KeyError, ValueError) when touched.
try:
    NUM_GPUS = int(os.environ['RWKV_NUM_GPUS'])
except:
    NUM_GPUS = 1

import json
import random
import numpy as np
import torch
from torch.nn import functional as F

class TOKENIZER():
    """Wraps either a HuggingFace tokenizer (word mode) or a char-level table.

    WORD_NAME as a [vocab, merges] list selects word mode (charMode=False);
    a plain string selects char mode and loads `WORD_NAME + '.json'`
    (a {id: char} table saved as UTF-16).
    """
    def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
        # FIX: was `if 'list' in str(type(WORD_NAME))` — fragile string sniffing;
        # isinstance expresses the same intent directly.
        if isinstance(WORD_NAME, list):
            self.charMode = False
            # Same file given twice => a single tokenizer.json ("fast" format).
            if WORD_NAME[0] == WORD_NAME[1]:
                from transformers import PreTrainedTokenizerFast
                self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
            else:
                from transformers import GPT2TokenizerFast
                self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
            self.vocab_size = len(self.tokenizer)
        else:
            self.charMode = True
            with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
                self.word_table = json.load(result_file)

            self.vocab_size = len(self.word_table)

            # word_table maps id -> char; build both directions.
            self.stoi = {v: int(k) for k, v in self.word_table.items()}
            self.itos = {int(k): v for k, v in self.word_table.items()}

            self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

    def refine_context(self, context):
        """Normalize a prompt: trim each line, drop empty lines, ensure a leading newline."""
        context = context.strip().split('\n')
        for c in range(len(context)):
            # also strips ideographic spaces (U+3000) and stray carriage returns
            context[c] = context[c].strip().strip('\u3000').strip('\r')
        context = list(filter(lambda c: c != '', context))
        context = '\n' + ('\n'.join(context)).strip()
        if context == '':
            context = '\n'
        return context

    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
        """Top-p (nucleus) sample a token id from raw logits `out`.

        In char mode a separate top_p is used right after a newline
        (top_p_newline); `x` is the context, only its last token is inspected.
        """
        lastChar = int(x[-1])

        probs = F.softmax(torch.tensor(out), dim=-1)

        if self.charMode:
            if self.itos[lastChar] == '\n':
                top_p = top_p_newline
            else:
                top_p = top_p_usual
        else:
            top_p = top_p_usual

        sorted_probs, s_index = torch.sort(probs, descending=True)

        # Smallest probability still inside the top-p nucleus; everything
        # below it is zeroed out before sampling.
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])

        probs[probs < cutoff] = 0

        if temperature != 1.0:
            # multinomial renormalizes, so the pow alone implements temperature
            probs = probs.pow(1.0 / temperature)

        return torch.multinomial(probs, num_samples=1)[0]


def to_float(x):
    """Extract a single scalar from a tensor as a Python float."""
    return x.cpu().detach().numpy().flatten()[0].astype(float)


def set_seed(seed):
    """Seed python, numpy and torch (all devices) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)