Skip to content

Commit 0aa57c6

Browse files
authored
feat(sdk): allow users to serve Tesseracts using multiple worker processes (#135)
#### Relevant issue or PR n/a #### Description of changes Adds `num_workers` argument to `$ tesseract-runtime serve`, `Tesseract.from_image`, and `engine.serve`. This allows users to serve using multiple processes, handling requests in parallel (at the expense of higher resource usage). The default is unchanged (1). #### Testing done CI, new test #### License - [x] By submitting this pull request, I confirm that my contribution is made under the terms of the [Apache 2.0 license](https://pasteurlabs.github.io/tesseract/LICENSE). - [x] I sign the Developer Certificate of Origin below by adding my name and email address to the `Signed-off-by` line. <details> <summary><b>Developer Certificate of Origin</b></summary> ```text Developer Certificate of Origin Version 1.1 Copyright (C) 2004, 2006 The Linux Foundation and its contributors. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Developer's Certificate of Origin 1.1 By making a contribution to this project, I certify that: (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it. (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved. ``` </details> Signed-off-by: Dion Häfner <dion.haefner@simulation.science>
1 parent 7513d1d commit 0aa57c6

File tree

5 files changed

+131
-54
lines changed

5 files changed

+131
-54
lines changed

tesseract_core/sdk/engine.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,7 @@ def serve(
496496
volumes: list[str] | None = None,
497497
gpus: list[str] | None = None,
498498
debug: bool = False,
499+
num_workers: int = 1,
499500
) -> str:
500501
"""Serve one or more Tesseract images.
501502
@@ -507,6 +508,7 @@ def serve(
507508
volumes: list of paths to mount in the Tesseract container.
508509
gpus: IDs of host Nvidia GPUs to make available to the Tesseracts.
509510
debug: whether to enable debug mode.
511+
num_workers: number of workers to use for serving the Tesseracts.
510512
511513
Returns:
512514
A string representing the Tesseract Project ID.
@@ -527,7 +529,9 @@ def serve(
527529
f"Number of ports ({len(ports)}) must match number of images ({len(image_ids)})"
528530
)
529531

530-
template = _create_docker_compose_template(image_ids, ports, volumes, gpus, debug)
532+
template = _create_docker_compose_template(
533+
image_ids, ports, volumes, gpus, num_workers, debug
534+
)
531535
compose_fname = _create_compose_fname()
532536

533537
with tempfile.NamedTemporaryFile(
@@ -548,6 +552,7 @@ def _create_docker_compose_template(
548552
ports: list[str] | None = None,
549553
volumes: list[str] | None = None,
550554
gpus: list[str] | None = None,
555+
num_workers: int = 1,
551556
debug: bool = False,
552557
) -> str:
553558
"""Create Docker Compose template."""
@@ -576,7 +581,7 @@ def _create_docker_compose_template(
576581

577582
services.append(service)
578583
template = ENV.get_template("docker-compose.yml")
579-
return template.render(services=services)
584+
return template.render(services=services, num_workers=num_workers)
580585

581586

582587
def _create_compose_service_id(image_id: str) -> str:

tesseract_core/sdk/templates/docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ services:
33
{{ service.name }}:
44
image: {{ service.image }}
55
restart: unless-stopped
6-
entrypoint: tesseract-runtime serve
6+
command: ["serve", "--num-workers", "{{ num_workers }}"]
77
ports:
88
- {{ service.port }}
99
{%- if service.volumes %}

tesseract_core/sdk/tesseract.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,12 @@
2424

2525
@dataclass
2626
class SpawnConfig:
27-
"""Configuration for spawning a Tesseract.
28-
29-
Attributes:
30-
image: The image to use.
31-
volumes: List of volumes to mount.
32-
gpus: List of GPUs to use.
33-
debug: Whether to run in debug mode.
34-
"""
27+
"""Configuration for spawning a Tesseract."""
3528

3629
image: str
3730
volumes: list[str] | None
3831
gpus: list[str] | None
32+
num_workers: int
3933
debug: bool
4034

4135

@@ -92,6 +86,7 @@ def from_image(
9286
*,
9387
volumes: list[str] | None = None,
9488
gpus: list[str] | None = None,
89+
num_workers: int = 1,
9590
) -> Tesseract:
9691
"""Create a Tesseract instance from a Docker image.
9792
@@ -108,8 +103,11 @@ def from_image(
108103
109104
Args:
110105
image: The Docker image to use.
111-
volumes: List of volumes to mount.
112-
gpus: List of GPUs to use.
106+
volumes: List of volumes to mount, e.g. ["/path/on/host:/path/in/container"].
107+
gpus: List of GPUs to use, e.g. ["0", "1"]. (default: no GPUs)
108+
num_workers: Number of worker processes to use. This determines how
109+
many requests can be handled in parallel. Higher values
110+
will increase throughput, but also increase resource usage.
113111
114112
Returns:
115113
A Tesseract instance.
@@ -119,6 +117,7 @@ def from_image(
119117
image=image,
120118
volumes=volumes,
121119
gpus=gpus,
120+
num_workers=num_workers,
122121
debug=True,
123122
)
124123
obj._serve_context = None
@@ -222,6 +221,7 @@ def serve(self, port: str | None = None) -> None:
222221
port=port,
223222
volumes=self._spawn_config.volumes,
224223
gpus=self._spawn_config.gpus,
224+
num_workers=self._spawn_config.num_workers,
225225
debug=self._spawn_config.debug,
226226
)
227227
self._serve_context = dict(
@@ -261,14 +261,20 @@ def _serve(
261261
volumes: list[str] | None = None,
262262
gpus: list[str] | None = None,
263263
debug: bool = False,
264+
num_workers: int = 1,
264265
) -> tuple[str, str, int]:
265266
if port is not None:
266267
ports = [port]
267268
else:
268269
ports = None
269270

270271
project_id = engine.serve(
271-
[image], ports=ports, volumes=volumes, gpus=gpus, debug=debug
272+
[image],
273+
ports=ports,
274+
volumes=volumes,
275+
gpus=gpus,
276+
debug=debug,
277+
num_workers=num_workers,
272278
)
273279

274280
command = ["docker", "compose", "-p", project_id, "ps", "--format", "json"]

tests/runtime_tests/test_serve.py

Lines changed: 100 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44
import base64
55
import json
66
import os
7+
import platform
78
import subprocess
89
import sys
910
import time
11+
from concurrent.futures import ThreadPoolExecutor
12+
from contextlib import contextmanager
1013
from textwrap import dedent
1114

1215
import numpy as np
@@ -23,6 +26,12 @@
2326
}
2427

2528

29+
def is_wsl():
30+
"""Check if the current environment is WSL."""
31+
kernel = platform.uname().release
32+
return "Microsoft" in kernel or "WSL" in kernel
33+
34+
2635
def array_from_json(json_data):
2736
encoding = json_data["data"]["encoding"]
2837
if encoding == "base64":
@@ -42,6 +51,47 @@ def model_to_json(model):
4251
return json.loads(model.model_dump_json())
4352

4453

54+
@contextmanager
55+
def serve_in_subprocess(api_file, port, num_workers=1, timeout=30.0):
56+
try:
57+
proc = subprocess.Popen(
58+
[
59+
sys.executable,
60+
"-c",
61+
"from tesseract_core.runtime.serve import serve; "
62+
f"serve(host='localhost', port={port}, num_workers={num_workers})",
63+
],
64+
env=dict(os.environ, TESSERACT_API_PATH=api_file),
65+
stdout=subprocess.PIPE,
66+
stderr=subprocess.PIPE,
67+
)
68+
69+
# wait for server to start
70+
while True:
71+
try:
72+
response = requests.get(f"http://localhost:{port}/health")
73+
except requests.exceptions.ConnectionError:
74+
pass
75+
else:
76+
if response.status_code == 200:
77+
break
78+
79+
time.sleep(0.1)
80+
timeout -= 0.1
81+
82+
if timeout < 0:
83+
raise TimeoutError("Server did not start in time")
84+
85+
yield f"http://localhost:{port}"
86+
87+
finally:
88+
proc.terminate()
89+
stdout, stderr = proc.communicate()
90+
print(stdout.decode())
91+
print(stderr.decode())
92+
proc.wait(timeout=5)
93+
94+
4595
@pytest.fixture
4696
def http_client(dummy_tesseract_module):
4797
"""A test HTTP client."""
@@ -157,6 +207,10 @@ def test_get_openapi_schema(http_client):
157207
assert response.json()["paths"]
158208

159209

210+
@pytest.mark.skipif(
211+
is_wsl(),
212+
reason="flaky on Windows",
213+
)
160214
def test_threading_sanity(tmpdir, free_port):
161215
"""Test with a Tesseract that requires to be run in the main thread.
162216
@@ -178,9 +232,6 @@ class OutputSchema(BaseModel):
178232
def apply(input: InputSchema) -> OutputSchema:
179233
assert threading.current_thread() == threading.main_thread()
180234
return OutputSchema()
181-
182-
def abstract_eval(abstract_inputs: dict) -> dict:
183-
pass
184235
"""
185236
)
186237

@@ -191,47 +242,57 @@ def abstract_eval(abstract_inputs: dict) -> dict:
191242

192243
# We can't run the server in the same process because it will use threading under the hood
193244
# so we need to spawn a new process instead
194-
try:
195-
proc = subprocess.Popen(
196-
[
197-
sys.executable,
198-
"-c",
199-
"from tesseract_core.runtime.serve import serve; "
200-
f"serve(host='localhost', port={free_port}, num_workers=1)",
201-
],
202-
env=dict(os.environ, TESSERACT_API_PATH=api_file),
203-
stdout=subprocess.PIPE,
204-
stderr=subprocess.PIPE,
205-
)
245+
with serve_in_subprocess(api_file, free_port) as url:
246+
response = requests.post(f"{url}/apply", json={"inputs": {}})
247+
assert response.status_code == 200, response.text
206248

207-
# wait for server to start
208-
timeout = 30.0
209-
while True:
210-
try:
211-
response = requests.get(f"http://localhost:{free_port}/health")
212-
except requests.exceptions.ConnectionError:
213-
pass
214-
else:
215-
if response.status_code == 200:
216-
break
217249

218-
time.sleep(0.1)
219-
timeout -= 0.1
250+
@pytest.mark.skipif(
251+
is_wsl(),
252+
reason="flaky on Windows",
253+
)
254+
def test_multiple_workers(tmpdir, free_port):
255+
"""Test that the server can be run with multiple worker processes."""
256+
TESSERACT_API = dedent(
257+
"""
258+
import time
259+
import multiprocessing
260+
from pydantic import BaseModel
220261
221-
if timeout < 0:
222-
raise TimeoutError("Server did not start in time")
262+
class InputSchema(BaseModel):
263+
pass
223264
224-
response = requests.post(
225-
f"http://localhost:{free_port}/apply", json={"inputs": {}}
226-
)
227-
assert response.status_code == 200, response.text
265+
class OutputSchema(BaseModel):
266+
pid: int
228267
229-
finally:
230-
proc.terminate()
231-
stdout, stderr = proc.communicate()
232-
print(stdout.decode())
233-
print(stderr.decode())
234-
proc.wait(timeout=5)
268+
def apply(input: InputSchema) -> OutputSchema:
269+
return OutputSchema(pid=multiprocessing.current_process().pid)
270+
"""
271+
)
272+
273+
api_file = tmpdir / "tesseract_api.py"
274+
275+
with open(api_file, "w") as f:
276+
f.write(TESSERACT_API)
277+
278+
with serve_in_subprocess(api_file, free_port, num_workers=2) as url:
279+
# Fire back-to-back requests to the server and check that they are handled
280+
# by different workers (i.e. different PIDs)
281+
post_request = lambda _: requests.post(f"{url}/apply", json={"inputs": {}})
282+
283+
with ThreadPoolExecutor(max_workers=4) as executor:
284+
# Fire a lot of requests in parallel
285+
futures = executor.map(post_request, range(100))
286+
responses = list(futures)
287+
288+
# Check that all responses are 200
289+
for response in responses:
290+
assert response.status_code == 200, response.text
291+
292+
# Check that not all pids are the same
293+
# (i.e. the requests were handled by different workers)
294+
pids = set(response.json()["pid"] for response in responses)
295+
assert len(pids) > 1, "All requests were handled by the same worker"
235296

236297

237298
def test_debug_mode(dummy_tesseract_module, monkeypatch):

tests/sdk_tests/test_tesseract.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,12 @@ def test_serve_lifecycle(mock_serving, mock_clients):
107107
pass
108108

109109
mock_serving["serve_mock"].assert_called_with(
110-
["sometesseract:0.2.3"], ports=None, volumes=None, gpus=None, debug=True
110+
["sometesseract:0.2.3"],
111+
ports=None,
112+
volumes=None,
113+
gpus=None,
114+
debug=True,
115+
num_workers=1,
111116
)
112117

113118
mock_serving["teardown_mock"].assert_called_with("proj-id-123")

0 commit comments

Comments
 (0)