feat(ml): round-robin device assignment (#13237)

* round-robin device assignment

* docs and tests

clarify doc
This commit is contained in:
Mert
2024-10-07 17:37:45 -04:00
committed by GitHub
parent 063969ca05
commit bd826b0b9b
8 changed files with 62 additions and 7 deletions
+4
View File
@@ -39,6 +39,10 @@ class Settings(BaseSettings):
case_sensitive = False
env_nested_delimiter = "__"
@property
def device_id(self) -> str:
    """Return the device id configured for this process, as a string.

    The value is read from the ``MACHINE_LEARNING_DEVICE_ID`` environment
    variable on every access rather than cached; ``"0"`` is returned when
    the variable is not set.
    """
    return os.getenv("MACHINE_LEARNING_DEVICE_ID", "0")
class LogSettings(BaseSettings):
immich_log_level: str = "info"
+4 -2
View File
@@ -86,11 +86,13 @@ class OrtSession:
provider_options = []
for provider in self.providers:
match provider:
case "CPUExecutionProvider" | "CUDAExecutionProvider":
case "CPUExecutionProvider":
options = {"arena_extend_strategy": "kSameAsRequested"}
case "CUDAExecutionProvider":
options = {"arena_extend_strategy": "kSameAsRequested", "device_id": settings.device_id}
case "OpenVINOExecutionProvider":
options = {
"device_type": "GPU",
"device_type": f"GPU.{settings.device_id}",
"precision": "FP32",
"cache_dir": (self.model_path.parent / "openvino").as_posix(),
}
+15 -1
View File
@@ -210,10 +210,24 @@ class TestOrtSession:
session = OrtSession(model_path, providers=["OpenVINOExecutionProvider", "CPUExecutionProvider"])
assert session.provider_options == [
{"device_type": "GPU", "precision": "FP32", "cache_dir": "/cache/ViT-B-32__openai/openvino"},
{"device_type": "GPU.0", "precision": "FP32", "cache_dir": "/cache/ViT-B-32__openai/openvino"},
{"arena_extend_strategy": "kSameAsRequested"},
]
def test_sets_device_id_for_openvino(self) -> None:
    """The OpenVINO ``device_type`` option must embed the configured device id."""
    # Remember the prior value so the override cannot leak into other tests
    # that run later in the same process.
    previous = os.environ.get("MACHINE_LEARNING_DEVICE_ID")
    os.environ["MACHINE_LEARNING_DEVICE_ID"] = "1"
    try:
        session = OrtSession("ViT-B-32__openai", providers=["OpenVINOExecutionProvider"])

        assert session.provider_options[0]["device_type"] == "GPU.1"
    finally:
        # Restore the environment exactly as it was found.
        if previous is None:
            os.environ.pop("MACHINE_LEARNING_DEVICE_ID", None)
        else:
            os.environ["MACHINE_LEARNING_DEVICE_ID"] = previous
def test_sets_device_id_for_cuda(self) -> None:
    """The CUDA provider options must carry the configured device id."""
    # Remember the prior value so the override cannot leak into other tests
    # that run later in the same process.
    previous = os.environ.get("MACHINE_LEARNING_DEVICE_ID")
    os.environ["MACHINE_LEARNING_DEVICE_ID"] = "1"
    try:
        session = OrtSession("ViT-B-32__openai", providers=["CUDAExecutionProvider"])

        assert session.provider_options[0]["device_id"] == "1"
    finally:
        # Restore the environment exactly as it was found.
        if previous is None:
            os.environ.pop("MACHINE_LEARNING_DEVICE_ID", None)
        else:
            os.environ["MACHINE_LEARNING_DEVICE_ID"] = previous
def test_sets_provider_options_kwarg(self) -> None:
session = OrtSession(
"ViT-B-32__openai",