Fix soft prompt loading code
This commit is contained in:
parent
4e3cc93020
commit
248e0bd24b
|
@ -2295,9 +2295,10 @@ def spRequest(filename):
|
||||||
|
|
||||||
z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
|
z, version, shape, fortran_order, dtype = fileops.checksp(filename, vars.modeldim)
|
||||||
assert isinstance(z, zipfile.ZipFile)
|
assert isinstance(z, zipfile.ZipFile)
|
||||||
|
z.close()
|
||||||
|
|
||||||
with z.open('tensor.npy') as f:
|
with np.load(fileops.sppath(filename), allow_pickle=False) as f:
|
||||||
tensor = np.load(f, allow_pickle=False)
|
tensor = f['tensor.npy']
|
||||||
|
|
||||||
# If the tensor is in bfloat16 format, convert it to float32
|
# If the tensor is in bfloat16 format, convert it to float32
|
||||||
if(tensor.dtype == 'V2'):
|
if(tensor.dtype == 'V2'):
|
||||||
|
@ -2305,7 +2306,8 @@ def spRequest(filename):
|
||||||
tensor = np.uint32(tensor) << 16
|
tensor = np.uint32(tensor) << 16
|
||||||
tensor.dtype = np.float32
|
tensor.dtype = np.float32
|
||||||
|
|
||||||
tensor = np.float16(tensor)
|
if(tensor.dtype != np.float16):
|
||||||
|
tensor = np.float32(tensor)
|
||||||
assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
|
assert not np.isinf(tensor).any() and not np.isnan(tensor).any()
|
||||||
|
|
||||||
vars.sp = torch.from_numpy(tensor)
|
vars.sp = torch.from_numpy(tensor)
|
||||||
|
|
Loading…
Reference in New Issue