mirror of https://github.com/KoboldAI/KoboldAI-Client.git, synced 2025-06-05 21:59:24 +02:00
Add RWKV-4 code
204 models/RWKV4/LICENSE.txt Normal file
@@ -0,0 +1,204 @@
Code in this directory is taken from https://github.com/BlinkDL/RWKV-LM.

The license for this code is as follows:


                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
125 models/RWKV4/cuda/wkv_cuda.cu Normal file
@@ -0,0 +1,125 @@
#include <stdio.h>
#include <assert.h>

#define MIN_VALUE (-1e38)

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
                               const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                               F *__restrict__ const _y) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    F p = 0, q = 0, o = MIN_VALUE;
    // p and q are running sums divided by exp(o) (to avoid overflows)
    for (int i = 0; i < T; i++) {
        const int ii = i * C;

        F no = max(o, u + k[ii]);
        F A = exp(o - no);
        F B = exp(u + k[ii] - no);
        y[ii] = (A * p + B * v[ii]) / (A * q + B);

        no = max(w + o, k[ii]);
        A = exp(w + o - no);
        B = exp(k[ii] - no);
        p = A * p + B * v[ii];
        q = A * q + B;
        o = no;
    }
}

template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
                                F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const gy = _gy + _offset;

    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;

    F y[Tmax], z[Tmax], zexp[Tmax];

    F gw = 0, gu = 0;
    F p = 0, q = 0;
    F dpdw = 0, dqdw = 0;
    F o = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        F no = max(o, k[ii] + u);
        F A = exp(o - no);
        F B = exp(k[ii] + u - no);

        F num = A * p + B * v[ii];
        F iden = 1 / (A * q + B);

        y[i] = num * iden;
        z[i] = iden;
        zexp[i] = k[ii] + u - no;

        gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
        gu += gy[ii] * (v[ii] - y[i]) * B * iden;

        no = max(w + o, k[ii]);
        A = exp(w + o - no);
        B = exp(k[ii] - no);
        dpdw = A * (p + dpdw);
        dqdw = A * (q + dqdw);
        p = A * p + B * v[ii];
        q = A * q + B;
        o = no;
    }

    F gp = 0, gq = 0;
    o = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        F A = gy[ii] * z[i] * exp(zexp[i]);
        F B = exp(k[ii] + o);
        gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
        gv[ii] = A + B * gp;

        F no = max(w + o, zexp[i] - k[ii] - u);
        A = exp(w + o - no);
        B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
        gp = A * gp + B;
        gq = A * gq - B * y[i];
        o = no;
    }

    // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] += gw * _w[_c];
    _gu[_offsetBC] += gu;
}

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}

void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    assert(B * C % threadsPerBlock.x == 0);
    dim3 numBlocks(B * C / threadsPerBlock.x);
    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
}
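The forward kernel walks each (batch, channel) pair through time, keeping a running numerator p, denominator q, and shared exponent o so the exp() terms never overflow; the backward kernel replays the same recurrence to accumulate gradients, using the compile-time Tmax bound for its per-thread buffers. As a reference for what the kernel computes, a slow pure-PyTorch sketch of the forward recurrence follows; the shapes and sample sizes are illustrative assumptions, not part of the commit.

import torch

def wkv_forward_reference(w, u, k, v):
    # w, u: (C,), k, v: (B, T, C); w is the already-negated decay, i.e. -exp(time_decay)
    B, T, C = k.shape
    y = torch.empty_like(v)
    p = torch.zeros(B, C)           # running numerator, scaled by exp(-o)
    q = torch.zeros(B, C)           # running denominator, scaled by exp(-o)
    o = torch.full((B, C), -1e38)   # shared exponent, mirrors MIN_VALUE
    for t in range(T):
        no = torch.maximum(o, u + k[:, t])
        A = torch.exp(o - no)
        Bt = torch.exp(u + k[:, t] - no)
        y[:, t] = (A * p + Bt * v[:, t]) / (A * q + Bt)

        no = torch.maximum(w + o, k[:, t])
        A = torch.exp(w + o - no)
        Bt = torch.exp(k[:, t] - no)
        p = A * p + Bt * v[:, t]
        q = A * q + Bt
        o = no
    return y

# illustrative shapes only
B, T, C = 2, 8, 16
y = wkv_forward_reference(-torch.exp(torch.randn(C)), torch.randn(C),
                          torch.randn(B, T, C), torch.randn(B, T, C))

The log-sum-exp style rescaling by no is what lets the kernel stay in fp32 even though the unnormalized sums grow like exp(k).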
21 models/RWKV4/cuda/wkv_op.cpp Normal file
@@ -0,0 +1,21 @@
#include <torch/extension.h>

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);

void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
    cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &forward, "wkv forward");
    m.def("backward", &backward, "wkv backward");
}

TORCH_LIBRARY(wkv, m) {
    m.def("forward", forward);
    m.def("backward", backward);
}
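This binding simply forwards the tensors' float pointers to the two CUDA launchers; the Python side compiles both files on the fly with torch.utils.cpp_extension.load, exactly as model.py below does. A condensed sketch of loading and calling the op directly follows; the sizes are placeholders and a working CUDA toolchain is assumed.

import torch
from torch.utils.cpp_extension import load

T_MAX = 1024  # must match the -DTmax used for kernel_backward's local arrays
wkv_cuda = load(name="wkv",
                sources=["models/RWKV4/cuda/wkv_op.cpp", "models/RWKV4/cuda/wkv_cuda.cu"],
                verbose=True,
                extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math',
                                   '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])

B, T, C = 1, 16, 64  # illustrative sizes; B * C must be divisible by min(C, 32)
w = -torch.exp(torch.randn(C, device='cuda'))   # negated decay, as WKV.forward does
u = torch.randn(C, device='cuda')
k = torch.randn(B, T, C, device='cuda')
v = torch.randn(B, T, C, device='cuda')
y = torch.empty(B, T, C, device='cuda')
wkv_cuda.forward(B, T, C, w, u, k, v, y)        # fills y in place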
BIN models/RWKV4/src/__pycache__/model_run.cpython-37.pyc Normal file (binary file not shown)
BIN models/RWKV4/src/__pycache__/model_run.cpython-38.pyc Normal file (binary file not shown)
BIN models/RWKV4/src/__pycache__/model_run.cpython-39.pyc Normal file (binary file not shown)
BIN models/RWKV4/src/__pycache__/utils.cpython-38.pyc Normal file (binary file not shown)
BIN models/RWKV4/src/__pycache__/utils.cpython-39.pyc Normal file (binary file not shown)
416 models/RWKV4/src/model.py Normal file
@@ -0,0 +1,416 @@
# File from RWKV-v4 Repo - Small changes made for compatibility

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import math, os
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
try:
    from deepspeed.ops.adam import FusedAdam
except:
    pass # some poor windows users cant install deepspeed

logger = logging.getLogger(__name__)

RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')

class L2Wrap(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loss, y):
        ctx.save_for_backward(y)
        return loss
    @staticmethod
    def backward(ctx, grad_output):
        y = ctx.saved_tensors[0]
        # to encourage the logits to be close to 0
        factor = 1e-4 / (y.shape[0] * y.shape[1])
        maxx, ids = torch.max(y, -1, keepdim=True)
        gy = torch.zeros_like(y)
        gy.scatter_(-1, ids, maxx * factor)
        return (grad_output, gy)

########################################################################################################
# CUDA Kernel
########################################################################################################

T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
# it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

from torch.utils.cpp_extension import load
wkv_cuda = load(name="wkv", sources=["models/RWKV4/cuda/wkv_op.cpp", "models/RWKV4/cuda/wkv_cuda.cu"],
                verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])

class WKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx, B, T, C, w, u, k, v):
        ctx.B = B
        ctx.T = T
        ctx.C = C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            w = -torch.exp(w.contiguous())
            u = u.contiguous()
            k = k.contiguous()
            v = v.contiguous()
        else:
            w = -torch.exp(w.float().contiguous())
            u = u.float().contiguous()
            k = k.float().contiguous()
            v = v.float().contiguous()
        ctx.save_for_backward(w, u, k, v)
        y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
        wkv_cuda.forward(B, T, C, w, u, k, v, y)
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            return y
        elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
            return y.half()
        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
            return y.bfloat16()

    @staticmethod
    def backward(ctx, gy):
        B = ctx.B
        T = ctx.T
        C = ctx.C
        assert T <= T_MAX
        assert B * C % min(C, 1024) == 0
        w, u, k, v = ctx.saved_tensors
        gw = torch.zeros((B, C), device='cuda').contiguous()
        gu = torch.zeros((B, C), device='cuda').contiguous()
        gk = torch.zeros((B, T, C), device='cuda').contiguous()
        gv = torch.zeros((B, T, C), device='cuda').contiguous()
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
        else:
            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
        gw = torch.sum(gw, dim=0)
        gu = torch.sum(gu, dim=0)
        if '32' in os.environ['RWKV_FLOAT_MODE']:
            return (None, None, None, gw, gu, gk, gv)
        elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
            return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
        elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
            return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())

def RUN_CUDA(B, T, C, w, u, k, v):
    return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())

########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################

def RWKV_Init(model, args): # fancy initialization of all lin & emb layer in the model
    print("\n[--> first run, init model params (very slow for large models) <--]")
    print("[so you shall only do it for 1 single GPU and save the checkpt and load it when using multiple GPU]\n")

    for mm in model.modules():
        if "RecursiveScriptModule" in str(type(mm)):
            if mm.original_name not in ["Linear"]:
                continue
            ww = None
            for name, param in mm.named_parameters():
                if name == "weight":
                    ww = param
        else:
            m = mm
            if not isinstance(m, (nn.Linear, nn.Embedding)):
                continue
            ww = m.weight
        with torch.no_grad():
            name = "[unknown weight]"
            for name, parameter in model.named_parameters(): # find the name of the weight
                if id(ww) == id(parameter):
                    break

            shape = ww.shape
            gain = 1.0
            scale = 1.0 # extra scale for gain

            if isinstance(m, nn.Embedding):
                gain = math.sqrt(max(shape[0], shape[1]))
                if shape[0] == args.vocab_size and shape[1] == args.n_embd: # token emb?
                    scale = 1e-4
                else:
                    scale = 0

            if isinstance(m, nn.Linear):
                if shape[0] > shape[1]:
                    gain = math.sqrt(shape[0] / shape[1])
                if shape[0] == args.vocab_size and shape[1] == args.n_embd: # final projection?
                    scale = 0.5

            if hasattr(m, "scale_init"):
                scale = m.scale_init

            # print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {str(scale).ljust(4)} {name}")

            gain *= scale
            if scale == -999:
                nn.init.eye_(ww)
            elif gain == 0:
                # zero init is great for some RWKV matrices
                nn.init.zeros_(ww)
            elif gain > 0:
                nn.init.orthogonal_(ww, gain=gain)
            else:
                nn.init.normal_(ww, mean=0.0, std=-scale)


class RWKV_TimeMix(torch.jit.ScriptModule):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_embd = config.n_embd

        attn_sz = config.n_embd

        with torch.no_grad(): # fancy init
            ratio_0_to_1 = (layer_id / (config.n_layer - 1)) # 0 to 1
            ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0

            # fancy time_decay
            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                decay_speed[h] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed)
            # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())

            # fancy time_first
            zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5)
            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)

            # fancy time_mix
            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)

        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)

        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    @torch.jit.script_method
    def jit_func(self, x):

        # Mix x with the previous timestep to produce xk, xv, xr
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        # Use xk, xv, xr to produce k, v, r
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        sr = torch.sigmoid(r)

        return sr, k, v

    def forward(self, x):
        B, T, C = x.size() # x = (Batch,Time,Channel)

        sr, k, v = self.jit_func(x)

        rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
        rwkv = self.output(rwkv)
        return rwkv


class RWKV_ChannelMix(torch.jit.ScriptModule):
    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad(): # fancy init of time_mix
            ratio_1_to_almost0 = (1.0 - (layer_id / config.n_layer)) # 1 to ~0

            x = torch.ones(1, 1, config.n_embd)
            for i in range(config.n_embd):
                x[0, 0, i] = i / config.n_embd

            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))

        hidden_sz = 4 * config.n_embd
        self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)

        self.value.scale_init = 0
        self.receptance.scale_init = 0

    @torch.jit.script_method
    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)

        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv

########################################################################################################
# The GPT Model with our blocks
########################################################################################################


class GPTConfig:
    def __init__(self, vocab_size, ctx_len, **kwargs):
        self.vocab_size = vocab_size
        self.ctx_len = ctx_len
        for k, v in kwargs.items():
            setattr(self, k, v)


class Block(nn.Module):
    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(config.n_embd)

        if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
            self.ffnPre = RWKV_ChannelMix(config, 0)
        else:
            self.att = RWKV_TimeMix(config, layer_id)

        self.ffn = RWKV_ChannelMix(config, layer_id)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        if self.layer_id == 0 and self.config.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(self.ln1(x)) # better in some cases
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.step = 0
        self.config = config

        self.emb = nn.Embedding(config.vocab_size, config.n_embd)

        self.blocks = nn.Sequential(*[Block(config, i)
                                      for i in range(config.n_layer)])

        self.ln_out = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            self.head_q = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(config.n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            self.register_buffer("copy_mask", torch.tril(
                torch.ones(config.ctx_len, config.ctx_len)))

        self.ctx_len = config.ctx_len

        try:
            if os.environ['RWKV_LOAD_MODEL'] == str(False):
                RWKV_Init(self, config)
        except:
            pass

        logger.info("number of parameters: %e", sum(p.numel()
                    for p in self.parameters()))

    def get_ctx_len(self):
        return self.ctx_len

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=0.01)
        if isinstance(module, (nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=1e-5)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def configure_optimizers(self, train_config):
        no_decay = set()

        for mn, m in self.named_modules(): # here we disable weight_decay
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
                no_decay.add(fpn)

        param_dict = {pn: p for pn, p in self.named_parameters()}
        optim_groups = [
            {"params": [param_dict[pn]
                        for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]

        try:
            optimizer = FusedAdam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps, bias_correction=True, adam_w_mode=False, weight_decay=0, amsgrad=False)
        except:
            print('\n\nDeepSpeed not found. Using torch optimizer instead (probably slower)\n\n')
            optimizer = torch.optim.Adam(optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)

        return optimizer

    def forward(self, idx, targets=None):
        idx = idx.to(self.emb.weight.device)

        self.step += 1
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)

        if RWKV_HEAD_QK_DIM > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            if '32' in os.environ['RWKV_FLOAT_MODE']:
                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size)
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).half()
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                c = c @ F.one_hot(idx, num_classes=self.config.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1))

        return L2Wrap.apply(loss, x)
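model.py builds the trainable stack: each Block applies pre-LayerNorm TimeMix (the WKV op) and ChannelMix residually, and RWKV_FLOAT_MODE / RWKV_LOAD_MODEL are read from the environment, so they must be set before the module is imported (which also compiles the CUDA kernel). A minimal construction-and-training-step sketch under those assumptions; the hyperparameters and import path are illustrative and assume you run from the repository root with a CUDA toolchain available.

import os, torch

# model.py reads these at import time; the CUDA kernel is also compiled on import
os.environ['RWKV_FLOAT_MODE'] = 'fp32'
os.environ['RWKV_LOAD_MODEL'] = 'False'   # 'False' makes GPT.__init__ call RWKV_Init

from models.RWKV4.src.model import GPT, GPTConfig   # adjust the import to your layout

# illustrative hyperparameters only
config = GPTConfig(vocab_size=50277, ctx_len=1024,
                   model_type='RWKV', n_layer=6, n_embd=512)
model = GPT(config).cuda()

idx = torch.randint(0, config.vocab_size, (1, 128), device='cuda')
targets = torch.randint(0, config.vocab_size, (1, 128), device='cuda')
loss = model(idx, targets)   # cross-entropy wrapped by L2Wrap
loss.backward()

Note that GPT.forward returns the L2Wrap-wrapped loss when targets are given; without targets the wrapped value is None, so a training-style call is the clearest way to exercise it.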
392 models/RWKV4/src/model_run.py Normal file
@@ -0,0 +1,392 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import types
import copy
import torch
import math, os
from torch.nn import functional as F
import torch.nn as nn

RWKV_HEAD_QK_DIM = 0
print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM}\n')

DEBUG_TIME = False   # True False - show trained time-coeffs

########################################################################################################
# CUDA Kernel
########################################################################################################

if os.environ['RWKV_RUN_DEVICE'] == 'cuda':
    T_MAX = 1024 # increase this if your ctx_len is long [NOTE: TAKES LOTS OF VRAM!]
    # it's possible to go beyond CUDA limitations if you slice the ctx and pass the hidden state in each slice

    from torch.utils.cpp_extension import load
    wkv_cuda = load(name="wkv", sources=["models/RWKV4/cuda/wkv_op.cpp", "models/RWKV4/cuda/wkv_cuda.cu"],
                    verbose=True, extra_cuda_cflags=['-res-usage', '--maxrregcount 60', '--use_fast_math', '-O3', '-Xptxas -O3', f'-DTmax={T_MAX}'])

    class WKV(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, T, C, w, u, k, v):
            ctx.B = B
            ctx.T = T
            ctx.C = C
            assert T <= T_MAX
            assert B * C % min(C, 1024) == 0
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                w = -torch.exp(w.contiguous())
                u = u.contiguous()
                k = k.contiguous()
                v = v.contiguous()
            else:
                w = -torch.exp(w.float().contiguous())
                u = u.float().contiguous()
                k = k.float().contiguous()
                v = v.float().contiguous()
            ctx.save_for_backward(w, u, k, v)
            y = torch.empty((B, T, C), device='cuda', memory_format=torch.contiguous_format)
            wkv_cuda.forward(B, T, C, w, u, k, v, y)
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                return y
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                return y.half()
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                return y.bfloat16()

        @staticmethod
        def backward(ctx, gy):
            B = ctx.B
            T = ctx.T
            C = ctx.C
            assert T <= T_MAX
            assert B * C % min(C, 1024) == 0
            w, u, k, v = ctx.saved_tensors
            gw = torch.zeros((B, C), device='cuda').contiguous()
            gu = torch.zeros((B, C), device='cuda').contiguous()
            gk = torch.zeros((B, T, C), device='cuda').contiguous()
            gv = torch.zeros((B, T, C), device='cuda').contiguous()
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
            else:
                wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
            gw = torch.sum(gw, dim=0)
            gu = torch.sum(gu, dim=0)
            if '32' in os.environ['RWKV_FLOAT_MODE']:
                return (None, None, None, gw, gu, gk, gv)
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())

    def RUN_CUDA(B, T, C, w, u, k, v):
        return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())

############################################################################################################

RWKV_CFG = types.SimpleNamespace()

class RWKV_ChannelMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        self.time_mix_k = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1, 1, RWKV_CFG.n_embd))

        hidden_sz = 4 * RWKV_CFG.n_embd
        self.key = nn.Linear(RWKV_CFG.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, RWKV_CFG.n_embd, bias=False)

    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)

        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv

class RWKV_TimeMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.time_decay = nn.Parameter(torch.ones(RWKV_CFG.n_embd))
        self.time_first = nn.Parameter(torch.ones(RWKV_CFG.n_embd) * math.log(0.3))

        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        self.time_mix_k = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
        self.time_mix_v = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1,1,RWKV_CFG.n_embd))

        self.key = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.value = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)
        self.receptance = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)

        self.output = nn.Linear(RWKV_CFG.n_embd, RWKV_CFG.n_embd, bias=False)

    def forward(self, x):
        B, T, C = x.size()

        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)

        rwkv = torch.sigmoid(r) * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)

        rwkv = self.output(rwkv)
        return rwkv

class Block(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(RWKV_CFG.n_embd)
        self.ln2 = nn.LayerNorm(RWKV_CFG.n_embd)
        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(RWKV_CFG.n_embd)

        if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
            self.ffnPre = RWKV_ChannelMix(layer_id+1000)
        else:
            self.att = RWKV_TimeMix(layer_id)

        self.ffn = RWKV_ChannelMix(layer_id)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        if self.layer_id == 0 and RWKV_CFG.model_type == 'RWKV-ffnPre':
            x = x + self.ffnPre(self.ln1(x))
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

class RWKV_GPT(nn.Module):
    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, vocab_size, n_layer, n_embd, ctx_len):
        global RWKV_CFG
        super().__init__()

        RWKV_CFG.RUN_DEVICE = RUN_DEVICE
        RWKV_CFG.model_type = model_type
        RWKV_CFG.vocab_size = vocab_size
        RWKV_CFG.n_layer = n_layer
        RWKV_CFG.n_embd = n_embd
        RWKV_CFG.ctx_len = ctx_len

        print('\nloading RWKV-GPT', MODEL_NAME)

        self.emb = nn.Embedding(vocab_size, n_embd)

        self.blocks = nn.Sequential(*[Block(i) for i in range(n_layer)])

        self.ln_out = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

        if RWKV_HEAD_QK_DIM > 0:
            self.head_q = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_q.scale_init = 0
            self.head_k = nn.Linear(n_embd, RWKV_HEAD_QK_DIM, bias=False)
            self.head_k.scale_init = 0.1
            self.register_buffer("copy_mask", torch.tril(
                torch.ones(ctx_len, ctx_len)))

        self.ctx_len = ctx_len
        self.eval()
        self.load_state_dict(torch.load(MODEL_NAME + '.pth'))
        self.eval()

    def forward(self, idx):
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)

        if RWKV_HEAD_QK_DIM > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / RWKV_HEAD_QK_DIM)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            if '32' in os.environ['RWKV_FLOAT_MODE']:
                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size)
            elif os.environ['RWKV_FLOAT_MODE'] == 'fp16':
                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).half()
            elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
                c = c @ F.one_hot(idx, num_classes=RWKV_CFG.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        return x

############################################################################################################

class RWKV_RNN(): # this is running in FP32 at this moment
    def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len):
        self.RUN_DEVICE = RUN_DEVICE
        self.model_type = model_type
        self.n_layer = n_layer
        self.n_embd = n_embd
        self.ctx_len = ctx_len

        self.w = types.SimpleNamespace()

        w = torch.load(MODEL_NAME + '.pth',
                       map_location=torch.device(RUN_DEVICE))
        for x in w.keys():
            w[x] = w[x].float()
            if '.time_' in x:
                w[x] = w[x].squeeze()
            if '.time_decay' in x:
                w[x] = -torch.exp(w[x])
            if DEBUG_TIME and '.time_' in x:
                print(x, w[x].squeeze().cpu().numpy())

            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
                if xx[i].isdigit():
                    ii = int(xx[i])
                    if ii not in here:
                        here[ii] = types.SimpleNamespace()
                    here = here[ii]
                else:
                    if i == len(xx) - 1:
                        setattr(here, xx[i], w[x])
                    elif not hasattr(here, xx[i]):
                        if xx[i+1].isdigit():
                            setattr(here, xx[i], {})
                        else:
                            setattr(here, xx[i], types.SimpleNamespace())
                    here = getattr(here, xx[i])

        self.clear()

    def clear(self):
        self.xx = {}
        self.aa = {}
        self.bb = {}
        self.pp = {}
        self.hk = None

    def save(self, target):
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)
        target.pp = copy.deepcopy(self.pp)
        target.hk = copy.deepcopy(self.hk)

    def load(self, target):
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)
        self.pp = copy.deepcopy(target.pp)
        self.hk = copy.deepcopy(target.hk)

    def LN(self, xx, w):
        return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)
        k = torch.square(torch.relu(w.key.weight @ xk))
        kv = w.value.weight @ k

        return r * kv

    def SA(self, xx, w, name):
        if name not in self.xx:
            self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE)
            self.pp[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) - 1e30

        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)
        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)

        k = w.key.weight @ xk
        v = w.value.weight @ xv

        pp = self.pp[name]
        aa = self.aa[name]
        bb = self.bb[name]
        ww = w.time_first + k
        p = torch.maximum(pp, ww)
        e1 = torch.exp(pp - p)
        e2 = torch.exp(ww - p)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        ww = pp + w.time_decay
        p = torch.maximum(ww, k)
        e1 = torch.exp(ww - p)
        e2 = torch.exp(k - p)
        self.aa[name] = e1 * aa + e2 * v
        self.bb[name] = e1 * bb + e2
        self.pp[name] = p

        rwkv = r * a / b

        return w.output.weight @ rwkv

    def run(self, ctx):
        w = self.w
        x = w.emb.weight[ctx[-1]]

        for i in range(self.n_layer):
            if i == 0:
                x = self.LN(x, w.blocks[i].ln0)
            if i == 0 and self.model_type == 'RWKV-ffnPre':
                x = x + self.FF(self.LN(x, w.blocks[i].ln1), w.blocks[i].ffnPre, f'ffnPre.{i}')
            else:
                x = x + self.SA(self.LN(x, w.blocks[i].ln1), w.blocks[i].att, f'att.{i}')
            x = x + self.FF(self.LN(x, w.blocks[i].ln2), w.blocks[i].ffn, f'ffn.{i}')

        x = self.LN(x, w.ln_out)

        if RWKV_HEAD_QK_DIM > 0:
            if self.hk == None:
                self.hk = (w.head_k.weight @ x).unsqueeze(0)
            else:
                self.hk = torch.cat(
                    [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0)
            if self.hk.shape[0] > self.ctx_len:
                self.hk = self.hk[-self.ctx_len:, :]

            q = w.head_q.weight @ x

            x = w.head.weight @ x
            x = x.cpu().numpy().tolist()

            c = (self.hk @ q) / RWKV_HEAD_QK_DIM
            for i in range(len(c)):
                x[ctx[i]] += c[i]
        else:
            x = w.head.weight @ x
            x = x.cpu().numpy().tolist()

        return x
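model_run.py re-expresses the same network for inference: RWKV_GPT runs full sequences, while RWKV_RNN keeps the per-layer state (xx, aa, bb, pp) inside the object so that run(ctx) only has to process the newest token, ctx[-1]. A rough greedy-decoding sketch under assumed settings follows; the checkpoint path, layer count, and token ids are placeholders and must match whatever .pth you actually load.

import os
os.environ['RWKV_RUN_DEVICE'] = 'cpu'   # model_run.py checks this at import; 'cpu' skips kernel compilation

import torch
from models.RWKV4.src.model_run import RWKV_RNN   # adjust the import to your layout

# placeholder checkpoint and sizes
model = RWKV_RNN(MODEL_NAME='models/RWKV4/RWKV-4-Pile-169M',
                 RUN_DEVICE='cpu', model_type='RWKV',
                 n_layer=12, n_embd=768, ctx_len=1024)

generated = [510, 3158, 247]                 # illustrative prompt token ids
model.clear()
for i in range(len(generated)):
    logits = model.run(generated[:i + 1])    # run() only reads ctx[-1]; state lives in self.xx/aa/bb/pp

for _ in range(20):                          # greedy decoding, purely illustrative
    next_token = int(torch.tensor(logits).argmax())
    generated.append(next_token)
    logits = model.run(generated)            # advances the recurrent state by one token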
95 models/RWKV4/src/utils.py Normal file
@@ -0,0 +1,95 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os
try:
    NUM_GPUS = int(os.environ['RWKV_NUM_GPUS'])
except:
    NUM_GPUS = 1

import json
import random
import numpy as np
import torch
from torch.nn import functional as F

class TOKENIZER():
    def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
        if 'list' in str(type(WORD_NAME)):
            self.charMode = False
            if WORD_NAME[0] == WORD_NAME[1]:
                from transformers import PreTrainedTokenizerFast
                self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
            else:
                from transformers import GPT2TokenizerFast
                self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
            self.vocab_size = len(self.tokenizer)
        else:
            self.charMode = True
            with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
                self.word_table = json.load(result_file)

            self.vocab_size = len(self.word_table)

            self.stoi = {v: int(k) for k, v in self.word_table.items()}
            self.itos = {int(k): v for k, v in self.word_table.items()}

            self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

    def refine_context(self, context):
        context = context.strip().split('\n')
        for c in range(len(context)):
            context[c] = context[c].strip().strip('\u3000').strip('\r')
        context = list(filter(lambda c: c != '', context))
        context = '\n' + ('\n'.join(context)).strip()
        if context == '':
            context = '\n'
        return context

    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
        # out[self.UNKNOWN_CHAR] = -float('Inf')

        lastChar = int(x[-1])

        probs = F.softmax(torch.tensor(out), dim=-1)

        if self.charMode:
            if self.itos[lastChar] == '\n':
                top_p = top_p_newline
            else:
                top_p = top_p_usual
        else:
            top_p = top_p_usual

        sorted_probs, s_index = torch.sort(probs, descending=True)

        # for j in range(30):
        #     pp = sorted_probs[j].item()
        #     if pp < 0.005:
        #         break
        #     ss = self.itos[int(s_index[j])].replace('\n','_')
        #     print(f'{math.floor(pp*100):>3.0f}{ss}', end='')
        # print('')

        cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])

        probs[probs < cutoff] = 0
        # print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "")

        if temperature != 1.0:
            probs = probs.pow(1.0 / temperature)

        return torch.multinomial(probs, num_samples=1)[0]


def to_float(x):
    return x.cpu().detach().numpy().flatten()[0].astype(float)


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
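sample_logits implements top-p (nucleus) sampling: sort the softmax probabilities, find the cutoff where the cumulative mass first exceeds top_p, zero out everything below it, optionally apply a temperature power, and draw from what remains. The same cutoff logic in a self-contained sketch on dummy logits; no tokenizer files or checkpoint are needed, and the function name below is illustrative, not part of the repo.

import numpy as np
import torch
from torch.nn import functional as F

def sample_top_p(logits, temperature=1.0, top_p=0.9):
    # mirrors TOKENIZER.sample_logits: cutoff by cumulative probability, then sample
    probs = F.softmax(torch.tensor(logits, dtype=torch.float32), dim=-1)
    sorted_probs, _ = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()
    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
    probs[probs < cutoff] = 0                 # drop everything outside the nucleus
    if temperature != 1.0:
        probs = probs.pow(1.0 / temperature)  # multinomial renormalizes for us
    return int(torch.multinomial(probs, num_samples=1)[0])

# dummy 10-way distribution, purely illustrative
token = sample_top_p(np.random.randn(10), temperature=0.8, top_p=0.9)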