diff --git a/aiserver.py b/aiserver.py
index 19bd9be9..e78211ae 100644
--- a/aiserver.py
+++ b/aiserver.py
@@ -1028,7 +1028,7 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
 
     # Lazy loader
     import torch_lazy_loader
-    def get_lazy_load_callback(n_layers):
+    def get_lazy_load_callback(n_layers, convert_to_float16=True):
         if not vars.lazy_load:
             return
 
@@ -1072,7 +1072,10 @@ if(not vars.model in ["InferKit", "Colab", "OAI", "ReadOnly", "TPUMeshTransforme
                         f.seek(model_dict[key].seek_offset - current_offset, 1)
                     device = device_map[key]
                     #print(f"Transferring <{key}> to {'(CPU)' if device == 'cpu' else '[device ' + str(device) + ']'} ... ", end="", flush=True)
-                    model_dict[key] = model_dict[key].materialize(f, map_location=torch.device(device))
+                    model_dict[key] = model_dict[key].materialize(f, map_location="cpu")
+                    if convert_to_float16 and model_dict[key].dtype is torch.float32:
+                        model_dict[key] = model_dict[key].to(torch.float16)
+                    model_dict[key] = model_dict[key].to(device)
                     #print("OK", flush=True)
                 finally:
                     if isinstance(f, zipfile.ZipExtFile):
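
The pattern the added lines apply can be sketched in plain PyTorch, independent of the lazy loader: materialize each tensor on the CPU, downcast float32 to float16 there, and only then move it to its target device, so the host-to-device copy (and the final device-memory footprint) is half the size. The load_param helper below is hypothetical, for illustration only; it is not part of aiserver.py or torch_lazy_loader.

import torch

def load_param(storage: torch.Tensor, device: str, convert_to_float16: bool = True) -> torch.Tensor:
    # Stand-in for model_dict[key].materialize(f, map_location="cpu"):
    # the tensor starts out resident on the CPU.
    tensor = storage
    if convert_to_float16 and tensor.dtype is torch.float32:
        # Downcast while still on the CPU, before the device transfer,
        # mirroring the diff's new convert_to_float16 branch.
        tensor = tensor.to(torch.float16)
    # Only now ship the (possibly halved) tensor to its assigned device.
    return tensor.to(device)

weight = load_param(torch.randn(1024, 1024), "cuda" if torch.cuda.is_available() else "cpu")
print(weight.dtype, weight.device)

Note the ordering: materializing straight onto the GPU with map_location=torch.device(device), as the removed line did, would transfer the full float32 storage first and convert afterwards; doing the cast on the CPU avoids ever holding the float32 copy in device memory.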