diff --git a/modeling/inference_models/hf.py b/modeling/inference_models/hf.py index f795395e..eae4bb2d 100644 --- a/modeling/inference_models/hf.py +++ b/modeling/inference_models/hf.py @@ -25,15 +25,22 @@ class HFInferenceModel(InferenceModel): """ if self.model_name in ["NeoCustom", "GPT2Custom", "TPUMeshTransformerGPTJ", "TPUMeshTransformerGPTNeoX"]: - assert utils.koboldai_vars.custmodpth + model_path = utils.koboldai_vars.custmodpth + assert model_path + + # Path can be absolute or relative to models directory + if os.path.exists(model_path): + return model_path + + model_path = os.path.join("models", model_path) try: - assert os.path.exists(utils.koboldai_vars.custmodpth) + assert os.path.exists(model_path) except AssertionError: - logger.error(f"Custom model at '{utils.koboldai_vars.custmodpth}' doesn't seem to exist") + logger.error(f"Custom model does not exist at '{utils.koboldai_vars.custmodpth}' or '{model_path}'.") raise - return utils.koboldai_vars.custmodpth + return model_path basename = utils.koboldai_vars.model.replace("/", "_") if legacy: