diff --git a/modeling/ipex/hijacks.py b/modeling/ipex/hijacks.py index 9645b03b..c50547bb 100644 --- a/modeling/ipex/hijacks.py +++ b/modeling/ipex/hijacks.py @@ -33,7 +33,7 @@ def check_device(device): return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) def return_xpu(device): - return str(f"xpu:{device[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu") + return f"xpu:{device[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" def ipex_no_cuda(orig_func, *args, **kwargs): torch.cuda.is_available = lambda: False