Basic backend module prioritization

not secure; we're loading these modules so they can obviously execute
code that manipulates the prioritization
This commit is contained in:
somebody
2023-07-12 19:03:49 -05:00
parent f67cb7fa05
commit 8549c7c896

View File

@@ -628,18 +628,30 @@ import importlib
model_backend_code = {}
model_backends = {}
model_backend_type_crosswalk = {}
PRIORITIZED_BACKEND_MODULES = ["generic_hf_torch"]
for module in os.listdir("./modeling/inference_models"):
if not os.path.isfile(os.path.join("./modeling/inference_models",module)) and module != '__pycache__':
try:
model_backend_code[module] = importlib.import_module('modeling.inference_models.{}.class'.format(module))
model_backends[model_backend_code[module].model_backend_name] = model_backend_code[module].model_backend()
if 'disable' in vars(model_backends[model_backend_code[module].model_backend_name]) and model_backends[model_backend_code[module].model_backend_name].disable:
del model_backends[model_backend_code[module].model_backend_name]
else:
if model_backend_code[module].model_backend_type in model_backend_type_crosswalk:
model_backend_type_crosswalk[model_backend_code[module].model_backend_type].append(model_backend_code[module].model_backend_name)
backend_code = importlib.import_module('modeling.inference_models.{}.class'.format(module))
backend_name = backend_code.model_backend_name
backend_type = backend_code.model_backend_type
backend_object = backend_code.model_backend()
if "disable" in vars(backend_object) and backend_object.disable:
continue
model_backends[backend_name] = backend_object
model_backend_code[module] = backend_code
if backend_type in model_backend_type_crosswalk:
if module in PRIORITIZED_BACKEND_MODULES:
model_backend_type_crosswalk[backend_type].insert(0, backend_name)
else:
model_backend_type_crosswalk[model_backend_code[module].model_backend_type] = [model_backend_code[module].model_backend_name]
model_backend_type_crosswalk[backend_type].append(backend_name)
else:
model_backend_type_crosswalk[backend_type] = [backend_name]
except Exception:
logger.error("Model Backend {} failed to load".format(module))