Fix soft prompt loading code
parent 4e3cc93020
commit 248e0bd24b
@@ -2295,9 +2295,10 @@ def spRequest(filename):
     z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
     assert isinstance(z, zipfile.ZipFile)
     z.close()
 
-    with z.open('tensor.npy') as f:
-        tensor = np.load(f, allow_pickle=False)
+    with np.load(fileops.sppath(filename), allow_pickle=False) as f:
+        tensor = f['tensor.npy']
 
     # If the tensor is in bfloat16 format, convert it to float32
     if(tensor.dtype == 'V2'):
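
The first hunk replaces the z.open('tensor.npy') read, which went through the zipfile handle that z.close() has already closed, with a fresh np.load() on the soft prompt path. np.load recognizes the zip archive and returns an NpzFile, so the stored tensor can be pulled out by member name. A minimal sketch of that loading path, with a hypothetical path standing in for fileops.sppath(filename):

import numpy as np

# Hypothetical path standing in for fileops.sppath(filename); the soft prompt
# file is assumed to be a zip archive containing a "tensor.npy" member.
sp_path = "softprompts/example_softprompt.zip"

# np.load on a zip archive returns an NpzFile; indexing it by member name
# decodes the stored .npy array, with no separate zipfile handle to manage.
with np.load(sp_path, allow_pickle=False) as f:
    tensor = f['tensor.npy']

print(tensor.shape, tensor.dtype)

NpzFile accepts the member name with or without the .npy suffix, so f['tensor.npy'] and f['tensor'] refer to the same array.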
@@ -2305,7 +2306,8 @@ def spRequest(filename):
         tensor = np.uint32(tensor) << 16
         tensor.dtype = np.float32
 
-    tensor = np.float16(tensor)
+    if(tensor.dtype != np.float16):
+        tensor = np.float32(tensor)
     assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
 
     vars.sp = torch.from_numpy(tensor)
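
The second hunk adjusts the dtype handling. numpy has no native bfloat16, so a bfloat16 tensor comes out of the file as a 2-byte void dtype ('V2'); because bfloat16 is just the upper 16 bits of an IEEE 754 float32, widening the raw 16-bit pattern to 32 bits and shifting it into the high half reconstructs the float32 value, which is what the uint32 << 16 lines do. The changed lines then leave a tensor that is already float16 alone and cast anything else to float32, instead of forcing everything to float16. A self-contained sketch of the same bit trick on synthetic values (not the repository's code):

import numpy as np

# Synthetic float32 data, truncated to bfloat16 by keeping only the upper
# 16 bits of each value's bit pattern.
f32 = np.array([1.0, -2.5, 3.140625], dtype=np.float32)
bf16_bits = (f32.view(np.uint32) >> 16).astype(np.uint16)

# Reverse the truncation, mirroring the conversion in the diff: widen the
# 16-bit pattern to 32 bits, shift it into the high half, then reinterpret
# the buffer in place as float32.
restored = np.uint32(bf16_bits) << 16
restored.dtype = np.float32

print(restored)  # recovers 1.0, -2.5, 3.140625 exactly

Reassigning .dtype reinterprets the existing buffer rather than copying it, which is why the conversion in the diff needs no extra allocation beyond the uint32 widening.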