Fix soft prompt loading code

This commit is contained in:
Gnome Ann 2021-10-28 00:29:42 -04:00
parent 4e3cc93020
commit 248e0bd24b
1 changed files with 5 additions and 3 deletions

View File

@ -2295,9 +2295,10 @@ def spRequest(filename):
z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
assert isinstance(z, zipfile.ZipFile)
z.close()
with z.open('tensor.npy') as f:
tensor = np.load(f, allow_pickle=False)
with np.load(fileops.sppath(filename), allow_pickle=False) as f:
tensor = f['tensor.npy']
# If the tensor is in bfloat16 format, convert it to float32
if(tensor.dtype == 'V2'):
@ -2305,7 +2306,8 @@ def spRequest(filename):
tensor = np.uint32(tensor) << 16
tensor.dtype = np.float32
tensor = np.float16(tensor)
if(tensor.dtype != np.float16):
tensor = np.float32(tensor)
assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
vars.sp = torch.from_numpy(tensor)