diff --git a/prompt_tuner.py b/prompt_tuner.py index af9e5443..c6db1bfb 100644 --- a/prompt_tuner.py +++ b/prompt_tuner.py @@ -692,7 +692,7 @@ class TrainerBase(abc.ABC): if breakmodel_gpulayers is None: breakmodel_gpulayers = [] if breakmodel_primary_device is None: - breakmodel_primary_device = 0 if breakmodel_gpulayers else "cpu" + breakmodel_primary_device = 0 if sum(x if x >= 0 else 0 for x in breakmodel_gpulayers) else "cpu" if self.data.params is not None and "max_batch_size" not in self.data.params: self.data.params["max_batch_size"] = 2048 @@ -730,6 +730,8 @@ class TrainerBase(abc.ABC): model_config = self._get_model_config() n_layers = utils.num_layers(model_config) + breakmodel_gpulayers = [x if x >= 0 else n_layers for x in breakmodel_gpulayers] + convert_to_float16 = True hascuda = torch.cuda.is_available() usegpu = hascuda and not breakmodel_disklayers and len(breakmodel_gpulayers) == 1 and breakmodel_gpulayers[0] == n_layers