Skip to content

Commit 871f377

Browse files
angela-koxalelax
andauthored
feat: Add Volume class to docker client and --user flag to cli (#241)
#### Description of changes - Add volume class to docker_client and docker_cleanup - Add user flag to docker container run and serve pipelines - Add tests to explore volume permission and user id switching #### Testing done Unit Testing --------- Co-authored-by: Alessandro Angioi <alessandro.angioi@simulation.science>
1 parent 257b8c1 commit 871f377

File tree

7 files changed

+343
-3
lines changed

7 files changed

+343
-3
lines changed

tesseract_core/sdk/cli.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,15 @@ def serve(
510510
),
511511
),
512512
] = None,
513+
user: Annotated[
514+
str | None,
515+
typer.Option(
516+
"--user",
517+
help=(
518+
"User to run the Tesseracts as e.g. '1000' or '1000:1000' (uid:gid)."
519+
),
520+
),
521+
] = None,
513522
) -> None:
514523
"""Serve one or more Tesseract images.
515524
@@ -555,6 +564,7 @@ def serve(
555564
num_workers,
556565
no_compose,
557566
service_names_list,
567+
user,
558568
)
559569
except RuntimeError as ex:
560570
raise UserError(
@@ -830,6 +840,13 @@ def run_container(
830840
),
831841
),
832842
] = None,
843+
user: Annotated[
844+
str | None,
845+
typer.Option(
846+
"--user",
847+
help=("User to run the Tesseract as e.g. '1000' or '1000:1000' (uid:gid)."),
848+
),
849+
] = None,
833850
) -> None:
834851
"""Execute a command in a Tesseract.
835852
@@ -873,7 +890,7 @@ def run_container(
873890

874891
try:
875892
result_out, result_err = engine.run_tesseract(
876-
tesseract_image, cmd, args, volumes=volume, gpus=gpus
893+
tesseract_image, cmd, args, volumes=volume, gpus=gpus, user=user
877894
)
878895

879896
except ImageNotFound as e:

tesseract_core/sdk/docker_client.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,13 +512,15 @@ def run(
512512
ports: dict | None = None,
513513
stdout: bool = True,
514514
stderr: bool = False,
515+
user: str | None = None,
515516
) -> Container | tuple[bytes, bytes] | bytes:
516517
"""Run a command in a container from an image.
517518
518519
Params:
519520
image: The image name or id to run the command in.
520521
command: The command to run in the container.
521522
volumes: A dict of volumes to mount in the container.
523+
user: String of user information to run command as in the format "uid:(optional)gid".
522524
device_requests: A list of device requests for the container.
523525
detach: If True, run the container in detached mode. Detach must be set to
524526
True if we wish to retrieve the container id of the running container,
@@ -556,6 +558,9 @@ def run(
556558
)
557559
optional_args.extend(volume_args)
558560

561+
if user:
562+
optional_args.extend(["-u", user])
563+
559564
if device_requests:
560565
gpus_str = ",".join(device_requests)
561566
optional_args.extend(["--gpus", f'"device={gpus_str}"'])
@@ -774,6 +779,116 @@ def _update_projects(include_stopped: bool = False) -> dict[str, list_[str]]:
774779
return project_container_map
775780

776781

782+
@dataclass
783+
class Volume:
784+
"""Volume class to wrap Docker volumes."""
785+
786+
name: str
787+
attrs: dict
788+
789+
@classmethod
790+
def from_dict(cls, json_dict: dict) -> "Volume":
791+
"""Create an Image object from a json dictionary.
792+
793+
Params:
794+
json_dict: The json dictionary to create the object from.
795+
796+
Returns:
797+
The created volume object.
798+
"""
799+
return cls(
800+
name=json_dict.get("Name", None),
801+
attrs=json_dict,
802+
)
803+
804+
def remove(self, force: bool = False) -> None:
805+
"""Remove a Docker volume.
806+
807+
Params:
808+
force: If True, force the removal of the volume.
809+
"""
810+
docker = _get_executable("docker")
811+
try:
812+
_ = subprocess.run(
813+
[*docker, "volume", "rm", "--force" if force else "", self.name],
814+
check=True,
815+
capture_output=True,
816+
text=True,
817+
)
818+
except subprocess.CalledProcessError as ex:
819+
raise NotFound(f"Error removing volume {self.name}: {ex}") from ex
820+
821+
822+
class Volumes:
823+
"""Volume class to wrap Docker volumes."""
824+
825+
@staticmethod
826+
def create(name: str) -> Volume:
827+
"""Create a Docker volume.
828+
829+
Params:
830+
name: The name of the volume to create.
831+
832+
Returns:
833+
The created volume object.
834+
"""
835+
docker = _get_executable("docker")
836+
try:
837+
_ = subprocess.run(
838+
[*docker, "volume", "create", name],
839+
check=True,
840+
capture_output=True,
841+
text=True,
842+
)
843+
return Volumes.get(name)
844+
except subprocess.CalledProcessError as ex:
845+
raise NotFound(f"Error creating volume {name}: {ex}") from ex
846+
847+
@staticmethod
848+
def get(name: str) -> Volume:
849+
"""Get a Docker volume.
850+
851+
Params:
852+
name: The name of the volume to get.
853+
854+
Returns:
855+
The volume object.
856+
"""
857+
docker = _get_executable("docker")
858+
try:
859+
result = subprocess.run(
860+
[*docker, "volume", "inspect", name],
861+
check=True,
862+
capture_output=True,
863+
text=True,
864+
)
865+
json_dict = json.loads(result.stdout)
866+
except subprocess.CalledProcessError as ex:
867+
raise NotFound(f"Volume {name} not found: {ex}") from ex
868+
if not json_dict:
869+
raise NotFound(f"Volume {name} not found.")
870+
return Volume.from_dict(json_dict[0])
871+
872+
@staticmethod
873+
def list() -> list[str]:
874+
"""List all Docker volumes.
875+
876+
Returns:
877+
List of volume names.
878+
"""
879+
docker = _get_executable("docker")
880+
try:
881+
result = subprocess.run(
882+
[*docker, "volume", "ls", "-q"],
883+
check=True,
884+
capture_output=True,
885+
text=True,
886+
)
887+
except subprocess.CalledProcessError as ex:
888+
raise APIError(f"Error listing volumes: {ex}") from ex
889+
return result.stdout.strip().split("\n")
890+
891+
777892
class DockerException(Exception):
778893
"""Base class for Docker CLI exceptions."""
779894

@@ -847,6 +962,7 @@ def __init__(self) -> None:
847962
self.containers = Containers()
848963
self.images = Images()
849964
self.compose = Compose()
965+
self.volumes = Volumes()
850966

851967
@staticmethod
852968
def info() -> tuple:

tesseract_core/sdk/engine.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,7 @@ def serve(
549549
num_workers: int = 1,
550550
no_compose: bool = False,
551551
service_names: list[str] | None = None,
552+
user: str | None = None,
552553
) -> str:
553554
"""Serve one or more Tesseract images.
554555
@@ -566,6 +567,7 @@ def serve(
566567
num_workers: number of workers to use for serving the Tesseracts.
567568
no_compose: if True, do not use Docker Compose to serve the Tesseracts.
568569
service_names: list of service names under which to expose each Tesseract container on the shared network.
570+
user: user to run the Tesseracts as, e.g. '1000' or '1000:1000' (uid:gid).
569571
570572
Returns:
571573
A string representing the Tesseract project ID.
@@ -642,6 +644,7 @@ def serve(
642644
ports=port_mappings,
643645
detach=True,
644646
volumes=volumes,
647+
user=user,
645648
)
646649
# wait for server to start
647650
timeout = 30
@@ -671,6 +674,7 @@ def serve(
671674
gpus,
672675
num_workers,
673676
debug=debug,
677+
user=user,
674678
)
675679
compose_fname = f"docker-compose-{_id_generator()}.yml"
676680

@@ -696,6 +700,7 @@ def _create_docker_compose_template(
696700
gpus: list[str] | None = None,
697701
num_workers: int = 1,
698702
debug: bool = False,
703+
user: str | None = None,
699704
) -> str:
700705
"""Create Docker Compose template."""
701706
services = []
@@ -744,6 +749,7 @@ def _create_docker_compose_template(
744749
for i, image_id in enumerate(image_ids):
745750
service = {
746751
"name": service_names[i],
752+
"user": user,
747753
"image": image_id,
748754
"port": f"{ports[i]}:8000",
749755
"volumes": volumes,
@@ -756,8 +762,28 @@ def _create_docker_compose_template(
756762
}
757763

758764
services.append(service)
765+
766+
docker_volumes = {} # Dictionary of volume names mapped to whether or not they already exist
767+
if volumes:
768+
for volume in volumes:
769+
source = volume.split(":")[0]
770+
# Check if source exists to determine if specified volume is a docker volume
771+
if not Path(source).exists():
772+
# Check if volume exists
773+
if not docker_client.volumes.get(source):
774+
if "/" not in source:
775+
docker_volumes[source] = False
776+
else:
777+
raise ValueError(
778+
f"Volume/Path {source} does not already exist, "
779+
"and new volume cannot be created due to '/' in name."
780+
)
781+
else:
782+
# Docker volume is external
783+
docker_volumes[source] = True
784+
759785
template = ENV.get_template("docker-compose.yml")
760-
return template.render(services=services)
786+
return template.render(services=services, docker_volumes=docker_volumes)
761787

762788

763789
def _id_generator(
@@ -817,6 +843,7 @@ def run_tesseract(
817843
volumes: list[str] | None = None,
818844
gpus: list[int | str] | None = None,
819845
ports: dict[str, str] | None = None,
846+
user: str | None = None,
820847
) -> tuple[str, str]:
821848
"""Start a Tesseract and execute a given command.
822849
@@ -828,6 +855,7 @@ def run_tesseract(
828855
gpus: list of GPUs, as indices or names, to passthrough the container.
829856
ports: dictionary of ports to bind to the host. Key is the host port,
830857
value is the container port.
858+
user: user to run the Tesseract as, e.g. '1000' or '1000:1000' (uid:gid).
831859
832860
Returns:
833861
Tuple with the stdout and stderr of the Tesseract.
@@ -896,6 +924,7 @@ def run_tesseract(
896924
detach=False,
897925
remove=True,
898926
stderr=True,
927+
user=user,
899928
)
900929
stdout = stdout.decode("utf-8")
901930
stderr = stderr.decode("utf-8")

tesseract_core/sdk/templates/docker-compose.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ services:
44
image: {{ service.image }}
55
restart: unless-stopped
66
command: ["serve", "--host", "0.0.0.0", "--num-workers", "{{ service.num_workers }}"]
7+
{%- if service.user %}
8+
user: "{{ service.user }}"
9+
{%- endif %}
710
ports:
811
- {{ service.port }}
912
{% if service.environment.TESSERACT_DEBUG == "1" %}
@@ -39,5 +42,15 @@ services:
3942
{% endif %}
4043
{% endfor %}
4144

45+
{%- if docker_volumes %}
46+
volumes:
47+
{%- for name, external in docker_volumes.items() %}
48+
{{ name }}:
49+
{%- if external %}
50+
external: true
51+
{% endif %}
52+
{% endfor %}
53+
{% endif %}
54+
4255
networks:
4356
multi-tesseract-network:

tests/conftest.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,20 @@ def docker_client():
207207
return docker_client_module.CLIDockerClient()
208208

209209

210+
@pytest.fixture
211+
def docker_volume(docker_client):
212+
# Create the Docker volume
213+
volume = docker_client.volumes.create(name="docker_client_test_volume")
214+
try:
215+
yield volume
216+
finally:
217+
try:
218+
volume.remove()
219+
except Exception:
220+
# already removed
221+
pass
222+
223+
210224
@pytest.fixture(scope="module")
211225
def docker_cleanup_module(docker_client, request):
212226
"""Clean up all tesseracts created by the tests after the module exits."""
@@ -222,7 +236,7 @@ def docker_cleanup(docker_client, request):
222236
def _docker_cleanup(docker_client, request):
223237
"""Clean up all tesseracts created by the tests."""
224238
# Shared object to track what objects need to be cleaned up in each test
225-
context = {"images": [], "project_ids": [], "containers": []}
239+
context = {"images": [], "project_ids": [], "containers": [], "volumes": []}
226240

227241
def pprint_exc(e: BaseException) -> str:
228242
"""Pretty print exception."""
@@ -268,6 +282,18 @@ def cleanup_func():
268282
except Exception as e:
269283
failures.append(f"Failed to remove image {image}: {pprint_exc(e)}")
270284

285+
# Remove volumes
286+
for volume in context["volumes"]:
287+
try:
288+
if isinstance(volume, str):
289+
volume_obj = docker_client.volumes.get(volume)
290+
else:
291+
volume_obj = volume
292+
293+
volume_obj.remove(force=True)
294+
except Exception as e:
295+
failures.append(f"Failed to remove volume {volume}: {pprint_exc(e)}")
296+
271297
if failures:
272298
raise RuntimeError(
273299
"Failed to clean up some Docker objects during test teardown:\n"

0 commit comments

Comments
 (0)