From 150463e629e1f9f73f5c010a0b24b0f0ff44b0da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E0=AE=AE=E0=AE=A9=E0=AF=8B=E0=AE=9C=E0=AF=8D=E0=AE=95?= =?UTF-8?q?=E0=AF=81=E0=AE=AE=E0=AE=BE=E0=AE=B0=E0=AF=8D=20=E0=AE=AA?= =?UTF-8?q?=E0=AE=B4=E0=AE=A9=E0=AE=BF=E0=AE=9A=E0=AF=8D=E0=AE=9A=E0=AE=BE?= =?UTF-8?q?=E0=AE=AE=E0=AE=BF?= Date: Sun, 5 Jan 2025 11:58:05 +0530 Subject: [PATCH] feat: Add GPU support (#6042) --- openhands/core/config/sandbox_config.py | 2 ++ openhands/runtime/impl/docker/docker_runtime.py | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/openhands/core/config/sandbox_config.py b/openhands/core/config/sandbox_config.py index 3c918444c4..a91c8b06a3 100644 --- a/openhands/core/config/sandbox_config.py +++ b/openhands/core/config/sandbox_config.py @@ -34,6 +34,7 @@ class SandboxConfig: platform: The platform on which the image should be built. Default is None. remote_runtime_resource_factor: Factor to scale the resource allocation for remote runtime. Must be one of [1, 2, 4, 8]. Will only be used if the runtime is remote. + enable_gpu: Whether to enable GPU. """ remote_runtime_api_url: str = 'http://localhost:8000' @@ -59,6 +60,7 @@ class SandboxConfig: platform: str | None = None close_delay: int = 15 remote_runtime_resource_factor: int = 1 + enable_gpu: bool = False def defaults_to_dict(self) -> dict: """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional.""" diff --git a/openhands/runtime/impl/docker/docker_runtime.py b/openhands/runtime/impl/docker/docker_runtime.py index 62623d53a2..852e50f617 100644 --- a/openhands/runtime/impl/docker/docker_runtime.py +++ b/openhands/runtime/impl/docker/docker_runtime.py @@ -266,6 +266,14 @@ class DockerRuntime(ActionExecutionClient): detach=True, environment=environment, volumes=volumes, + device_requests=( + [docker.types.DeviceRequest( + capabilities=[['gpu']], + count=-1 + )] + if self.config.sandbox.enable_gpu + else None + ), ) self.log('debug', f'Container started. Server url: {self.api_url}') self.send_status_message('STATUS$CONTAINER_STARTED')