diff --git a/aiserver.py b/aiserver.py index 8f04cd07..0583c5b8 100644 --- a/aiserver.py +++ b/aiserver.py @@ -630,7 +630,9 @@ model_backends = {} model_backend_module_names = {} model_backend_type_crosswalk = {} -PRIORITIZED_BACKEND_MODULES = ["generic_hf_torch"] +PRIORITIZED_BACKEND_MODULES = { + "generic_hf_torch": 1 +} for module in os.listdir("./modeling/inference_models"): if module == '__pycache__': @@ -666,10 +668,15 @@ for module in os.listdir("./modeling/inference_models"): model_backend_module_names[backend_name] = module if backend_type in model_backend_type_crosswalk: - if module in PRIORITIZED_BACKEND_MODULES: - model_backend_type_crosswalk[backend_type].insert(0, backend_name) - else: - model_backend_type_crosswalk[backend_type].append(backend_name) + model_backend_type_crosswalk[backend_type].append(backend_name) + model_backend_type_crosswalk[backend_type] = list(sorted( + model_backend_type_crosswalk[backend_type], + key=lambda name: PRIORITIZED_BACKEND_MODULES.get( + [mod for b_name, mod in model_backend_module_names.items() if b_name == name][0], + 0 + ), + reverse=True + )) else: model_backend_type_crosswalk[backend_type] = [backend_name]