diff --git a/aiserver.py b/aiserver.py index e6f9d798..0bbe17e4 100644 --- a/aiserver.py +++ b/aiserver.py @@ -207,13 +207,16 @@ def device_config(model): if(args.layers is not None): try: breakmodel.gpu_blocks = list(map(int, args.layers.split(','))) - assert len(gpu_blocks) <= torch.cuda.device_count() - assert sum(gpu_blocks) <= n_layers + assert len(breakmodel.gpu_blocks) <= torch.cuda.device_count() + assert sum(breakmodel.gpu_blocks) <= n_layers + n_layers -= sum(breakmodel.gpu_blocks) except: print("WARNING: --layers is malformatted. Please use the --help option to see correct usage of --layers. Defaulting to all layers on device 0.", file=sys.stderr) breakmodel.gpu_blocks = [n_layers] + n_layers = 0 elif(args.breakmodel_layers is not None): breakmodel.gpu_blocks = [n_layers - max(0, min(n_layers, args.breakmodel_layers))] + n_layers -= sum(breakmodel.gpu_blocks) else: device_count = torch.cuda.device_count() if(device_count > 1):