[Draft] [ATOM_MESH] PD disaggregation router with multi-node support#502
Conversation
…-node 1P1D demo - Move eval_gsm8k_standalone.py to evaluation/common/eval_gsm8k.py as shared evaluator - Update single-node and multi-node scripts to reference common/ path - Add vLLM 1P1D multi-node demo scripts (DeepSeek-R1, 2x8 GPUs, Mooncake RDMA)
…SMG build Replace Dockerfile_OOT_vLLM and build_OOT_vLLM.sh with simplified mesh build that assumes base image (rocm/atom-dev) already contains vLLM. Focuses on: - ATOM source install from Jasen2201/ATOM pd_distributed branch - RDMA rdma-core v39 library overlay for Broadcom bnxt_re ABI compat - Mooncake TransferEngine (HIP build + Python wheel) - SMG (sgl-model-gateway) Rust binary Also hold rccl/rccl-dev/rocm-hip before Mooncake dependencies.sh to prevent apt version conflicts with base image's custom rccl build.
Add step [5/5] to install sglang from source with ROCm sgl-kernel. Uses backup/restore strategy to preserve ROCm torch after pip install. Verified with GSM8K eval (94% accuracy) on Qwen3-235B-A22B-FP8.
Add evaluation scripts for SGLang prefill-decode disaggregation using ATOM OOT plugin on Qwen3-235B-A22B-FP8 with TP=4/EP=4 per role. - 1_start_prefill.sh: prefill server on GPU 0-3 - 2_start_decode.sh: decode server on GPU 4-7 - 3_start_proxy_smg.sh: SMG PD proxy with --backend sglang - 4_eval_gsm8k.sh: GSM8K evaluation via proxy - 5_start_standalone.sh: standalone baseline (no PD) Also update eval_gsm8k.py to support --dataset-path for offline use, and add gsm8k_test_50.jsonl for environments without network access. Tested: 85.4% accuracy (41/48) via PD proxy, ~246 tokens/s.
…river configs - build_mesh.sh: auto-detect host libibverbs/librdmacm version instead of hardcoding v39; copy all provider plugins with symlink dereferencing; add --ulimit nofile=65535:65535 for cargo build - Dockerfile_mesh: use wildcard copy for RDMA libs; auto-generate /etc/libibverbs.d/*.driver configs from provider .so files; switch default MOONCAKE_REPO to Jasen2201/Mooncake (1GB buffer patch for ionic) - Add RDMA test scripts: loopback/cross-node connectivity tests and multi-device parallel bandwidth benchmark using Mooncake TransferEngine
…imeout - Add sglang_atomoot_deepseek_r1_standalone/ scripts for DeepSeek-R1 671B MoE serving with ATOM OOT plugin (TP=8/EP=1 and TP=4/EP=1) - Add apache-tvm-ffi install to Dockerfile_mesh (required by SGLang JIT kernels for overlap scheduling) - Add --timeout CLI arg to eval_gsm8k.py for long-running inference - Disable fp8 prefill attention (SGLANG_AITER_FP8_PREFILL_ATTN=0) as aiter lacks fp8 prefill ASM kernel for gqa=16
…mark - Add sglang_atomoot_1p1d_multi_node/ with prefill/decode/proxy/eval/bench scripts for 2-node PD disaggregation (node09 prefill, node07 decode) - Include patch_sglang_ib.py for per-TP-rank ionic device binding via JSON map (required for multi-subnet Pensando RDMA topology) - Add ib_device_map.json for TP rank -> ionic NIC mapping - Add standalone 3_bench_serving.sh for DeepSeek-R1 performance benchmark - Update standalone 1_start_server.sh: bind 0.0.0.0, add chunked prefill and max running requests config, use cuda-graph-bs range
… compat - build_mesh.sh: default SGLANG_BRANCH from main to v0.5.9 - patch_sglang_ib.py: rewrite as anchor-based insertion patch, idempotent, compatible with both v0.5.9 and main/latest
…bench - Rewrite 0_setup_docker.sh as simple local docker run (no SSH orchestration) - Replace InferenceMAX benchmark with sglang.bench_serving in 5_bench_serving.sh - Remove patch_sglang_ib.py and ib_device_map.json (no longer needed) - Migrate node config from g32/g17 to g05/g07 (10.28.104.181/183) - Add docker cp sync step in run_pd_test.sh (no NFS dependency) - Rewrite README for manual workflow without patching steps
…CTOR, drop --attention-backend)
Update summary tables to reflect 9 revised decisions (40 keep, 21 delete, 10 discuss). Add per-feature opinions with rationale for each item.
Pull request overview
Note: Copilot was unable to run its full agentic suite in this review.
Adds multi-node PD (prefill/decode) disaggregation support and ships new Python + Go bindings/CLI and tooling to launch, test, and benchmark the router and FFI layer.
Changes:
- Introduce the sglang-router Python package (CLI entrypoints, router wrapper, launchers) with unit-test scaffolding and packaging metadata.
- Add a Rust FFI surface for Go (pre/post-processing, streaming, tool parsing) plus Go bindings, examples, and an OpenAI-compatible sample server.
- Add Mesh benchmarks and developer tooling (Makefiles, pre-commit, cargo config, docs/scripts).
Reviewed changes
Copilot reviewed 81 out of 431 changed files in this pull request and generated 15 comments.
| File | Description |
|---|---|
| mesh/bindings/python/tests/conftest.py | Adds pytest marker configuration |
| mesh/bindings/python/src/sglang_router/version.py | Defines Python package version |
| mesh/bindings/python/src/sglang_router/router.py | Python Router wrapper + config conversion |
| mesh/bindings/python/src/sglang_router/launch_server.py | Launcher for router + DP servers |
| mesh/bindings/python/src/sglang_router/launch_router.py | Router launcher + CLI arg parsing |
| mesh/bindings/python/src/sglang_router/cli.py | Top-level CLI (smg) dispatcher |
| mesh/bindings/python/src/sglang_router/__main__.py | Enables python -m sglang_router |
| mesh/bindings/python/src/sglang_router/__init__.py | Exposes __version__ |
| mesh/bindings/python/setup.py | Optional setuptools-rust build hook |
| mesh/bindings/python/pyproject.toml | Python packaging metadata + deps |
| mesh/bindings/python/README.md | Python bindings documentation |
| mesh/bindings/python/MANIFEST.in | sdist file inclusion rules |
| mesh/bindings/python/Cargo.toml | PyO3 bindings crate metadata |
| mesh/bindings/python/.coveragerc | Python coverage config |
| mesh/bindings/golang/src/utils.rs | FFI utility + placeholder constraints API |
| mesh/bindings/golang/src/tool_parser.rs | Tool parser FFI implementation |
| mesh/bindings/golang/src/stream.rs | Streaming FFI read/free support |
| mesh/bindings/golang/src/preprocessor.rs | FFI preprocessing (template/tokenize/tools) |
| mesh/bindings/golang/src/memory.rs | FFI memory free helpers |
| mesh/bindings/golang/src/lib.rs | Re-exports for Go/C consumers |
| mesh/bindings/golang/src/error.rs | FFI error codes + message helpers |
| mesh/bindings/golang/src/client.rs | FFI gRPC client + stream creation |
| mesh/bindings/golang/internal/ffi/preprocessor.go | Go wrapper for preprocessing FFI |
| mesh/bindings/golang/internal/ffi/postprocessor.go | Go wrapper for postprocessing FFI |
| mesh/bindings/golang/internal/ffi/grpc_converter.go | Go wrapper for response converter FFI |
| mesh/bindings/golang/internal/ffi/client.go | Go wrapper for client/stream FFI |
| mesh/bindings/golang/internal/ffi/batch_postprocessor.go | Go-side batching to reduce FFI calls |
| mesh/bindings/golang/integration_test.go | Integration tests (tagged) |
| mesh/bindings/golang/examples/streaming/run.sh | Streaming example runner script |
| mesh/bindings/golang/examples/streaming/main.go | Streaming example program |
| mesh/bindings/golang/examples/simple/run.sh | Simple example runner script |
| mesh/bindings/golang/examples/simple/main.go | Simple example program |
| mesh/bindings/golang/examples/oai_server/utils/utils.go | OpenAI server utility helpers |
| mesh/bindings/golang/examples/oai_server/service/sglang.go | Service wrapper around Go client |
| mesh/bindings/golang/examples/oai_server/scripts/profile_tpot.sh | TPOT profiling script |
| mesh/bindings/golang/examples/oai_server/scripts/pprof_test.sh | Load generation script |
| mesh/bindings/golang/examples/oai_server/scripts/pprof_quick.sh | Quick pprof collection script |
| mesh/bindings/golang/examples/oai_server/scripts/pprof_analysis.sh | Full pprof analysis script |
| mesh/bindings/golang/examples/oai_server/run.sh | OpenAI server runner script |
| mesh/bindings/golang/examples/oai_server/models/chat.go | OpenAI request model |
| mesh/bindings/golang/examples/oai_server/main.go | OpenAI-compatible HTTP server |
| mesh/bindings/golang/examples/oai_server/logger/logger.go | Zap logger initialization |
| mesh/bindings/golang/examples/oai_server/handlers/models.go | /v1/models + model info handlers |
| mesh/bindings/golang/examples/oai_server/handlers/health.go | /health handler |
| mesh/bindings/golang/examples/oai_server/docs/benchmark_result.md | Benchmark results documentation |
| mesh/bindings/golang/examples/oai_server/config/config.go | Env-driven config loader |
| mesh/bindings/golang/examples/oai_server/README.md | OpenAI server documentation |
| mesh/bindings/golang/examples/oai_server/Makefile | Build/run/e2e targets |
| mesh/bindings/golang/client_test.go | Go unit tests / benchmarks |
| mesh/bindings/golang/Makefile | Build/export Rust FFI + Go targets |
| mesh/bindings/golang/Cargo.toml | Golang FFI crate metadata |
| mesh/bindings/golang/.gitignore | Ignores build artifacts |
| mesh/benches/wasm_middleware_latency.rs | Criterion bench for wasm middleware |
| mesh/benches/router_registry_bench.rs | Bench registry optimizations |
| mesh/benches/manual_policy_benchmark.rs | Bench manual policy performance |
| mesh/benches/consistent_hash_bench.rs | Bench consistent hash ring |
| mesh/Makefile | Dev/build/release helper targets |
| mesh/LICENSE | Adds Apache 2.0 license text |
| mesh/Cargo.toml | Crate deps + benches registration |
| mesh/.pre-commit-config.yaml | Pre-commit hooks configuration |
| mesh/.isort.cfg | isort config |
| mesh/.gitignore | Repo ignore rules |
| mesh/.codespellrc | codespell config |
| mesh/.claude/commands/pd-test.md | PD test runbook/command doc |
| mesh/.cargo/config.toml | Cargo build config (macOS flags etc.) |
```python
def policy_from_str(policy_str: Optional[str]) -> PolicyType:
    """Convert policy string to PolicyType enum."""
    if policy_str is None:
        return None
    policy_map = {
        "random": PolicyType.Random,
        "round_robin": PolicyType.RoundRobin,
        "cache_aware": PolicyType.CacheAware,
        "power_of_two": PolicyType.PowerOfTwo,
        "bucket": PolicyType.Bucket,
        "manual": PolicyType.Manual,
        "consistent_hashing": PolicyType.ConsistentHashing,
        "prefix_hash": PolicyType.PrefixHash,
    }
    return policy_map[policy_str]
```
policy_from_str will raise a KeyError for unknown policies and is also case-sensitive, which is inconsistent with backend_from_str (which normalizes + raises ValueError). Consider lowercasing policy_str, validating membership, and raising ValueError with valid options (and update the return annotation if None is a valid return).
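A minimal sketch of the fix the reviewer describes, with a stand-in `PolicyType` enum so the snippet is self-contained (the real enum comes from the Rust bindings), normalizing the input and raising `ValueError` with the valid options, mirroring `backend_from_str`:

```python
from enum import Enum
from typing import Optional


class PolicyType(Enum):
    """Stand-in for the real PolicyType exposed by the Rust bindings."""
    Random = "random"
    RoundRobin = "round_robin"
    CacheAware = "cache_aware"
    PowerOfTwo = "power_of_two"
    Bucket = "bucket"
    Manual = "manual"
    ConsistentHashing = "consistent_hashing"
    PrefixHash = "prefix_hash"


_POLICY_MAP = {member.value: member for member in PolicyType}


def policy_from_str(policy_str: Optional[str]) -> Optional[PolicyType]:
    """Convert a policy string to PolicyType; None passes through."""
    if policy_str is None:
        return None
    normalized = policy_str.strip().lower()  # accept e.g. "Round_Robin"
    if normalized not in _POLICY_MAP:
        valid = ", ".join(sorted(_POLICY_MAP))
        raise ValueError(f"Unknown policy {policy_str!r}; valid options: {valid}")
    return _POLICY_MAP[normalized]
```

Note the return annotation becomes `Optional[PolicyType]` to match the `None` passthrough the original code already performs.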
```python
# Build API key entries
py_api_keys = []
for key_tuple in api_keys:
    # Tuple format: (id, name, key, role)
```
This destructuring will raise a ValueError if any entry in control_plane_api_keys doesn't have exactly 4 elements, and the resulting error won’t be very actionable. Consider validating tuple length and raising a clear ValueError describing the expected format (id, name, key, role) and the offending entry.
Suggested change:

```python
# Tuple format: (id, name, key, role)
if len(key_tuple) != 4:
    raise ValueError(
        "Each entry in control_plane_api_keys must have format "
        "(id, name, key, role); offending entry: "
        f"{key_tuple!r}"
    )
```
```python
def launch_router(args: argparse.Namespace) -> Optional[Router]:
    """
    Launch the SGLang router with the configuration from parsed arguments.

    Args:
        args: Namespace object containing router configuration
            Can be either raw argparse.Namespace or converted RouterArgs

    Returns:
        Router instance if successful, None if failed
    """
    setproctitle.setproctitle("sglang::router")
    try:
        # Convert to RouterArgs if needed
        if not isinstance(args, RouterArgs):
            router_args = RouterArgs.from_cli_args(args)
        else:
            router_args = args

        if router_args.mini_lb:
            mini_lb = MiniLoadBalancer(router_args)
            mini_lb.start()
        else:
            if Router is None:
                raise RuntimeError("Rust Router is not installed")
            router_args._validate_router_args()
            router = Router.from_args(router_args)
            router.start()

    except Exception as e:
        logger.error(f"Error starting router: {e}")
        raise e
```
launch_router is annotated/documented to return Optional[Router] but it never returns anything (always returns None implicitly). Also, raise e discards the original traceback; use bare raise after logging. Proposed fix: return router when starting the Rust router and return None when starting mini_lb, or adjust the signature/docstring to match the actual behavior.
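One way to reconcile the annotation with the behavior, sketched here with the router and mini-LB classes injected as parameters so the control flow stands alone (the real module imports `Router` and `MiniLoadBalancer` directly):

```python
import logging
from typing import Optional

logger = logging.getLogger("sglang::router")


def launch_router(router_args, router_cls, mini_lb_cls) -> Optional[object]:
    """Return the started Rust router, or None when running the mini LB."""
    try:
        if router_args.mini_lb:
            mini_lb_cls(router_args).start()
            return None  # mini load balancer owns its own lifecycle
        if router_cls is None:
            raise RuntimeError("Rust Router is not installed")
        router = router_cls.from_args(router_args)
        router.start()
        return router  # now matches the Optional[Router] annotation
    except Exception:
        logger.exception("Error starting router")
        raise  # bare raise preserves the original traceback
```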
```python
def main():
    # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
    mp.set_start_method("spawn")
```
mp.set_start_method("spawn") will raise RuntimeError: context has already been set if main() is invoked in a process where the start method was already configured (e.g., in some test runners or embedded scenarios). Consider wrapping in try/except RuntimeError or using mp.set_start_method("spawn", force=True) if that’s acceptable for this entrypoint.
Suggested change:

```python
try:
    mp.set_start_method("spawn")
except RuntimeError:
    if mp.get_start_method(allow_none=True) != "spawn":
        raise
```
```python
def find_available_ports(base_port: int, count: int) -> List[int]:
    """Find consecutive available ports starting from base_port."""
    available_ports = []
    current_port = base_port

    while len(available_ports) < count:
        if is_port_available(current_port):
            available_ports.append(current_port)
        current_port += random.randint(100, 1000)
```
The docstring says "Find consecutive available ports" but the implementation intentionally jumps by a random amount (random.randint(100, 1000)), producing non-consecutive ports. Update the docstring (or the logic) so behavior and documentation match—this also affects predictability for firewall rules and port-forwarding setups.
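If truly consecutive ports are what callers need, one possible shape is a sliding-window search. This is a hedged sketch, not the project's helper: `is_port_available` is re-implemented here with a plain bind probe, and the real one may differ.

```python
import socket
from typing import List


def is_port_available(port: int) -> bool:
    """Return True if we can bind the TCP port on all interfaces."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            s.bind(("", port))
            return True
        except OSError:
            return False


def find_consecutive_ports(base_port: int, count: int) -> List[int]:
    """Find `count` consecutive free ports, sliding past any conflict."""
    start = base_port
    while start + count <= 65536:
        candidate = list(range(start, start + count))
        if all(is_port_available(p) for p in candidate):
            return candidate
        start += 1  # restart the window just past the conflicting port
    raise RuntimeError(f"No {count} consecutive free ports above {base_port}")
```

Consecutive blocks make firewall rules and port-forwarding predictable; the random-jump variant in the PR avoids clustering but should then drop "consecutive" from its docstring.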
```rust
let handle_ref = &*handle;
let parser = Arc::clone(&handle_ref.parser);
let model = handle_ref.model.clone();
```
history_tool_calls_count is read to influence tool-call ID generation, but it’s never updated after producing tool calls. That means multiple parses on the same handle can reuse IDs (or generate IDs with an incorrect offset). Update history_tool_calls_count after parse_complete (e.g., += tool_calls.len()) and after incremental completion when new calls are finalized.
Suggested change:

```rust
let handle_ref = &mut *handle;
let parser = Arc::clone(&handle_ref.parser);
let model = handle_ref.model.clone();
handle_ref.history_tool_calls_count = handle_ref
    .history_tool_calls_count
    .max(handle_ref.tool_index_to_id.len());
```
mesh/bindings/golang/src/client.rs (outdated)
```rust
let converter = sgl_grpc_response_converter_create(
    tokenizer_handle,
    CString::new(chat_request.model.clone()).unwrap().as_ptr(),
    CString::new(request_id.clone()).unwrap().as_ptr(),
    tools_json.unwrap_or(ptr::null_mut()),
```
This passes pointers derived from temporary CStrings created inline. Those pointers are only guaranteed valid for the duration of the statement; if the FFI layer stores these pointers instead of copying them, it can become a use-after-free. Prefer binding these CStrings to local variables (so their lifetimes clearly cover the entire FFI call) and pass as_ptr() from those locals.
```go
// Convert stop_token_ids to JSON string
stopTokenIDsJSON := ""
if len(stopTokenIDs) > 0 {
    stopTokenIDsJSON = fmt.Sprintf("[%d", stopTokenIDs[0])
    for i := 1; i < len(stopTokenIDs); i++ {
        stopTokenIDsJSON += fmt.Sprintf(",%d", stopTokenIDs[i])
    }
    stopTokenIDsJSON += "]"
}
```
Manually building JSON via string concatenation is easy to get wrong and is relatively inefficient due to repeated allocations. Use json.Marshal(stopTokenIDs) (or a strings.Builder) to produce a correct JSON array string with simpler code and fewer allocation surprises.
```python
def from_args(args: RouterArgs) -> "Router":
    """Create a router from a RouterArgs instance."""

    args_dict = vars(args)
    # Convert RouterArgs to _Router parameters
    args_dict["worker_urls"] = (
        []
        if args_dict["service_discovery"] or args_dict["pd_disaggregation"]
        else args_dict["worker_urls"]
    )
    args_dict["policy"] = policy_from_str(args_dict["policy"])
    args_dict["prefill_urls"] = (
        args_dict["prefill_urls"] if args_dict["pd_disaggregation"] else None
    )
    args_dict["decode_urls"] = (
        args_dict["decode_urls"] if args_dict["pd_disaggregation"] else None
    )
    args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"])
    args_dict["decode_policy"] = policy_from_str(args_dict["decode_policy"])

    # Convert backend
    args_dict["backend"] = backend_from_str(args_dict.get("backend"))
```
from_args contains substantial argument massaging (PD mode URL handling, backend/policy parsing, history backend config objects, auth config conversion, and field stripping). Given the amount of logic and potential edge cases, it would be good to add focused unit tests covering: unknown policy/backend strings, PD-disaggregation vs service discovery worker_urls behavior, and each history backend config path (Memory/None/Oracle/Postgres/Redis) to prevent regressions.
Suggested change:

```python
@staticmethod
def _normalize_worker_urls(args_dict: dict) -> None:
    """Normalize worker URLs for service discovery and PD-disaggregation modes."""
    args_dict["worker_urls"] = (
        []
        if args_dict["service_discovery"] or args_dict["pd_disaggregation"]
        else args_dict["worker_urls"]
    )

@staticmethod
def _normalize_pd_urls(args_dict: dict) -> None:
    """Keep PD URLs only when PD-disaggregation is enabled."""
    args_dict["prefill_urls"] = (
        args_dict["prefill_urls"] if args_dict["pd_disaggregation"] else None
    )
    args_dict["decode_urls"] = (
        args_dict["decode_urls"] if args_dict["pd_disaggregation"] else None
    )

@staticmethod
def _normalize_policies(args_dict: dict) -> None:
    """Convert policy strings into PolicyType values."""
    args_dict["policy"] = policy_from_str(args_dict["policy"])
    args_dict["prefill_policy"] = policy_from_str(args_dict["prefill_policy"])
    args_dict["decode_policy"] = policy_from_str(args_dict["decode_policy"])

@staticmethod
def _normalize_backend(args_dict: dict) -> None:
    """Convert backend string into BackendType value."""
    args_dict["backend"] = backend_from_str(args_dict.get("backend"))

@staticmethod
def from_args(args: RouterArgs) -> "Router":
    """Create a router from a RouterArgs instance."""
    args_dict = vars(args)
    # Convert RouterArgs to _Router parameters
    Router._normalize_worker_urls(args_dict)
    Router._normalize_policies(args_dict)
    Router._normalize_pd_urls(args_dict)
    # Convert backend
    Router._normalize_backend(args_dict)
```
```python
for field in fields_to_remove:
    args_dict.pop(field, None)

return Router(_Router(**args_dict))
```
Standalone gRPC SDK for Go clients, not needed for PD disaggregation.
…3.7) Non-core routing strategies not needed for PD disaggregation. CacheAware + PrefixHash cover all cache-aware routing needs.
Remove manual worker targeting via x-smg-target-worker header and routing key tracking via x-smg-routing-key header. Simplify WorkerLoadGuard to remove routing key state management.
Remove the legacy text completion endpoint, route_completion trait method, and all implementations across HTTP router, PD router, and router manager. Remove associated test infrastructure and mock handlers.
Remove rerank endpoints, route_rerank trait method, all router implementations, mock handlers, and rerank spec tests.
Remove classify endpoint, route_classify trait method, all router implementations, and gRPC classify pipeline initialization.
Remove CORS layer, --cors-allowed-origins CLI arg, config field, builder methods, and tower-http cors feature.
Remove server-level auth middleware that validates Bearer tokens on incoming requests. Keep api_key config field for worker connection authentication (different use case).
Remove unused QueueMetrics struct and unused `header` import from middleware.rs. Add docs/middleware_flow.md documenting the full request lifecycle through all middleware layers.
Pull request overview
Copilot reviewed 87 out of 292 changed files in this pull request and generated 4 comments.
mesh/e2e_test/infra/gpu_monitor.py (outdated)
| """GPU utilization monitoring for benchmarks. | ||
|
|
||
| This module provides a low-impact GPU monitor that runs in a separate process | ||
| and collects utilization samples using NVML. | ||
| """ |
The monitor is NVML-based (pynvml), which typically won’t work on AMD ROCm-only systems (the PR description and report focus on MI355X). As written, GPU utilization thresholds can silently degrade into 'no samples' behavior. Consider either (a) adding a ROCm path (e.g., rocm-smi/amdsmi bindings) or (b) failing fast when thresholds are requested but NVML is unavailable, so benchmark gating is reliable.
mesh/e2e_test/infra/gpu_monitor.py (outdated)
```python
try:
    import pynvml

    pynvml.nvmlInit()
except Exception as e:
    logger.warning("Failed to initialize NVML: %s", e)
    _write_empty_result(output_path)
    return
```
| "GIT_BRANCH", | ||
| git_branch().unwrap_or_else(|| "unknown".into()) | ||
| ); | ||
| set_env!( | ||
| "GIT_COMMIT", | ||
| git_commit().unwrap_or_else(|| "unknown".into()) | ||
| ); | ||
| set_env!( | ||
| "GIT_STATUS", | ||
| git_status().unwrap_or_else(|| "unknown".into()) | ||
| ); |
The build script embeds git metadata but only invalidates on Cargo.toml changes. This can produce stale GIT_BRANCH/GIT_COMMIT/GIT_STATUS values unless something else triggers a rebuild. Consider adding additional rebuild triggers (e.g., cargo:rerun-if-changed=.git/HEAD, cargo:rerun-if-changed=.git/index when available) and/or an opt-in env flag to avoid surprising 'unknown' or stale values in release builds.
| echo "[launch] Starting Decode server (TP=${TP_SIZE}, attention=${ATTENTION_BACKEND})..." | ||
| TORCHINDUCTOR_COMPILE_THREADS=128 python3 -m sglang.launch_server \ | ||
| --model-path "${MODEL}" \ | ||
| --host 0.0.0.0 \ | ||
| --port "${DECODE_PORT}" \ | ||
| --trust-remote-code \ | ||
| --tp-size "${TP_SIZE}" \ | ||
| --kv-cache-dtype "${KV_CACHE_DTYPE}" \ | ||
| --mem-fraction-static "${MEM_FRACTION}" \ | ||
| --page-size "${PAGE_SIZE}" \ | ||
| --max-running-requests "${MAX_RUNNING_REQUESTS}" \ | ||
| --cuda-graph-bs $(seq ${CUDA_GRAPH_BS_START} ${CUDA_GRAPH_BS_END}) \ | ||
| --disable-radix-cache \ | ||
| --log-level "${LOG_LEVEL}" \ | ||
| --watchdog-timeout "${WATCHDOG_TIMEOUT}" \ | ||
| --disaggregation-mode decode \ | ||
| --disaggregation-transfer-backend "${TRANSFER_BACKEND}" \ | ||
| --disaggregation-bootstrap-port "${BOOTSTRAP_PORT}" \ | ||
| --disaggregation-ib-device "${IB_DEVICE}" \ | ||
| 2>&1 | tee "${LOG_DIR}/${LOG_FILE}" |
Unlike the prefill script (which passes --attention-backend), the decode script prints ATTENTION_BACKEND but never passes it to sglang.launch_server. If decode-side attention backend selection matters for correctness/perf parity, add --attention-backend "${ATTENTION_BACKEND}" (or drop the variable/echo to avoid implying it is applied).
Replace ASCII art with Mermaid sequence/graph diagrams for GitHub rendering. Includes request lifecycle, token lifecycle, and layer order diagrams.
Pull request overview
Copilot reviewed 90 out of 292 changed files in this pull request and generated 2 comments.
```rust
use std::process::Command;

const DEFAULT_VERSION: &str = "0.0.0";
const DEFAULT_PROJECT_NAME: &str = "sgl-model-gateway";
```
The build metadata hardcodes PROJECT_NAME to "sgl-model-gateway" even though the crate/package is now atom-mesh and docs refer to “ATOM Mesh”. If this env var is user-visible (e.g., --version-verbose output), it will be misleading. Consider updating DEFAULT_PROJECT_NAME to atom-mesh (or deriving it from CARGO_PKG_NAME) to keep branding/version output consistent.
```rust
let target = std::env::var("TARGET").unwrap_or_else(|_| get_rustc_host().unwrap_or_default());
let profile = std::env::var("PROFILE").unwrap_or_default();

set_env!("PROJECT_NAME", DEFAULT_PROJECT_NAME);
```
Two new documents for the entrypoints/mesh/worker refactor: - mesh_feature_inventory.md: complete inventory of all mesh/src/ features - three_layer_architecture.md: three-layer design with Mermaid diagrams
…riant Remove 13 compiler warnings: unused imports across 6 files, dead RequestType::Responses variant and its scaffolding (for_responses, responses_request_arc, match arms), unused streaming methods (finalize, emit_reasoning_item, send_event_best_effort, emit_error), Reasoning variant, validate_worker_availability, and ExtractedToolCall helpers.
Pull request overview
Copilot reviewed 83 out of 294 changed files in this pull request and generated 7 comments.
mesh/e2e_test/infra/process_utils.py (outdated)
```python
if line and not line.startswith("HCA") and not line.startswith("---"):
    devices.append(line)
```
ibv_devinfo -l commonly outputs lines like hca_id: mlx5_0. Appending the full line to devices will later call ibv_devinfo -d 'hca_id: mlx5_0', which fails and prevents device detection. Parse the device name (text after :) before appending, and only add non-empty names.
Suggested change:

```python
if not line or line.startswith("HCA") or line.startswith("---"):
    continue
_, sep, device = line.partition(":")
if sep:
    device = device.strip()
else:
    device = line
if device:
    devices.append(device)
```
```python
# Convert backend
args_dict["backend"] = backend_from_str(args_dict.get("backend"))

# Remove fields that shouldn't be passed to Rust Router constructor
# (deleted features, internal-only fields, or fields handled separately)
fields_to_remove = [
```
The code converts backend into a BackendType, but then unconditionally removes backend from args_dict, so the Rust _Router(**args_dict) never receives the backend selection. If Rust Router expects this parameter, the Python bindings will silently ignore user backend input. Fix by not removing backend (or only removing the original string field before conversion, not the converted value).
```python
# Handled via backend_from_str or not needed
"backend",
```
Suggested change (delete these lines):

```python
# Handled via backend_from_str or not needed
"backend",
```
| print(f"[client] Looking up remote first buffer address for {remote_session}...") | ||
| remote_buf = te.get_first_buffer_address(remote_session) | ||
| print(f"[client] Remote buffer address: {hex(remote_buf)}") | ||
|
|
||
| if remote_buf == 0: | ||
| print("[client] WARNING: get_first_buffer_address returned 0, using 0 as dst") | ||
|
|
||
| # Write our data to remote | ||
| test_pattern = bytes(range(256)) * (buf_size // 256 + 1) | ||
| ctypes.memmove(buf, test_pattern[:buf_size], buf_size) | ||
|
|
||
| print(f"[client] Attempting transfer_sync_write to {remote_session}...") | ||
| start = time.time() | ||
| ret = te.transfer_sync_write(remote_session, buf, remote_buf, buf_size) | ||
| elapsed = time.time() - start |
If get_first_buffer_address() returns 0, the script continues and attempts an RDMA write to address 0 on the remote, which can cause undefined behavior or hard-to-debug failures. In this case, the safer behavior is to treat it as a hard error (exit non-zero) or implement an explicit fallback handshake to obtain a valid remote address before issuing transfers.
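A hedged sketch of the fail-fast behavior the reviewer suggests; `te` stands for the Mooncake TransferEngine object used in the script, and the helper name is hypothetical:

```python
import sys


def checked_remote_buffer(te, remote_session: str) -> int:
    """Fail fast instead of issuing an RDMA write to remote address 0."""
    remote_buf = te.get_first_buffer_address(remote_session)
    if remote_buf == 0:
        # Writing to address 0 on the remote is undefined behavior; abort.
        print(f"[client] ERROR: no registered buffer for {remote_session}",
              file=sys.stderr)
        raise SystemExit(1)
    return remote_buf
```

The non-zero exit makes the failure visible to wrapper scripts, which otherwise see the transfer "succeed" against a bogus destination.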
mesh/evaluation/common/eval_gsm8k.py
Outdated
```python
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--model", type=str, default="qwen3-235b")
parser.add_argument("--num-questions", type=int, default=50)
parser.add_argument("--max-tokens", type=int, default=512)
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--workers", type=int, default=4)
parser.add_argument("--timeout", type=int, default=120, help="API request timeout in seconds")
parser.add_argument("--save-results", type=str, default=None)
parser.add_argument(
    "--dataset-path",
    type=str,
    default=None,
    help="Path to local GSM8K JSONL file (bypasses download)",
)
args = parser.parse_args()

base_url = f"{args.host}:{args.port}"
```
Building base_url via simple string concatenation can easily produce invalid URLs (e.g., if --host already contains a port, or if it ends with /). Consider parsing --host as a URL (or accepting a single --base-url) and only injecting the port when it’s missing, to avoid malformed requests like http://127.0.0.1:8000:8000 or http://127.0.0.1/:8000.
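A possible stdlib-only fix, parsing `--host` and injecting the port only when it is missing; the helper name is illustrative:

```python
from urllib.parse import urlsplit

def build_base_url(host: str, port: int) -> str:
    """Only append :port when --host does not already carry a port."""
    # Tolerate a bare host like "127.0.0.1" by assuming http://.
    parts = urlsplit(host if "//" in host else f"http://{host}")
    netloc = parts.netloc
    if parts.port is None:
        netloc = f"{parts.hostname}:{port}"
    # Drop a trailing "/" so later f"{base_url}/v1/..." joins stay clean.
    return f"{parts.scheme}://{netloc}{parts.path.rstrip('/')}"
```

This avoids malformed results such as `http://127.0.0.1:8000:8000` when the user passes `--host http://127.0.0.1:8000`.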
mesh/e2e_test/conftest.py
Outdated
```python
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter(fmt, datefmt))

for logger_name in ("e2e_test", "infra", "fixtures"):
    log = logging.getLogger(logger_name)
    log.setLevel(logging.INFO)
    log.addHandler(handler)
```
This unconditionally adds a new handler to each named logger. If _setup_logging() runs more than once (e.g., re-imports in certain pytest/plugin scenarios), logs can become duplicated. A robust pattern is to check whether a compatible handler is already attached (or clear existing handlers for these loggers) before adding the new one.
```diff
-handler = logging.StreamHandler(sys.stdout)
-handler.setFormatter(logging.Formatter(fmt, datefmt))
-for logger_name in ("e2e_test", "infra", "fixtures"):
-    log = logging.getLogger(logger_name)
-    log.setLevel(logging.INFO)
-    log.addHandler(handler)
+formatter = logging.Formatter(fmt, datefmt)
+for logger_name in ("e2e_test", "infra", "fixtures"):
+    log = logging.getLogger(logger_name)
+    log.setLevel(logging.INFO)
+    has_compatible_handler = any(
+        isinstance(existing_handler, logging.StreamHandler)
+        and getattr(existing_handler, "stream", None) is sys.stdout
+        and existing_handler.formatter is not None
+        and existing_handler.formatter._fmt == formatter._fmt
+        and existing_handler.formatter.datefmt == formatter.datefmt
+        for existing_handler in log.handlers
+    )
+    if not has_compatible_handler:
+        handler = logging.StreamHandler(sys.stdout)
+        handler.setFormatter(formatter)
+        log.addHandler(handler)
```
mesh/e2e_test/benchmarks/conftest.py
Outdated
```python
try:
    stdout, stderr = proc.communicate(timeout=timeout)
except subprocess.TimeoutExpired:
    proc.kill()
    stdout, stderr = proc.communicate()
    logger.error("genai-bench timed out after %ds", timeout)

# Log output if process failed or for debugging
if proc.returncode != 0:
```
On TimeoutExpired, the code kills the process and logs an error, but it doesn’t fail the test immediately. Depending on whether partial results exist, the benchmark test could proceed and potentially pass or fail later with a less actionable error. Consider explicitly failing (e.g., raising an AssertionError/pytest.fail) after the timeout so the test outcome clearly reflects the timeout condition.
These directories are kept locally but removed from version control as they are not needed for compilation and can be maintained separately.
Integration tests kept locally but removed from version control.
Pull request overview
Copilot reviewed 70 out of 152 changed files in this pull request and generated 8 comments.
```rust
pub async fn get_engine_metrics(
    worker_registry: &WorkerRegistry,
    client: &reqwest::Client,
) -> EngineMetricsResult {
    let workers = worker_registry.get_all();

    if workers.is_empty() {
        return EngineMetricsResult::Err("No available workers".to_string());
    }

    let responses = fan_out(&workers, client, "metrics", reqwest::Method::GET).await;
```
get_engine_metrics currently fans out to all workers, including non-HTTP ones. If gRPC workers are registered, these requests will reliably fail and can add avoidable latency (up to REQUEST_TIMEOUT per attempt) and noise. Filter to ConnectionMode::Http (similar to flush_cache_all) before calling fan_out, and consider reporting how many workers were skipped due to non-HTTP mode.
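A simplified, std-only sketch of the suggested filter. `ConnectionMode::Http` comes from the gateway source; the `Worker` struct and `http_workers` helper here are stand-ins for the real trait objects:

```rust
// Hypothetical, simplified stand-ins for the real worker types.
#[derive(Clone, Copy, PartialEq, Debug)]
enum ConnectionMode {
    Http,
    Grpc,
}

struct Worker {
    url: &'static str,
    mode: ConnectionMode,
}

/// Keep only HTTP workers for the /metrics fan-out, reporting how many
/// workers were skipped due to a non-HTTP connection mode.
fn http_workers(workers: &[Worker]) -> (Vec<&Worker>, usize) {
    let http: Vec<&Worker> = workers
        .iter()
        .filter(|w| w.mode == ConnectionMode::Http)
        .collect();
    let skipped = workers.len() - http.len();
    (http, skipped)
}
```

Filtering before `fan_out` keeps gRPC workers from eating a full `REQUEST_TIMEOUT` each, and the skipped count can be surfaced in the metrics response for observability.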
```rust
let futures: Vec<_> = workers
    .iter()
    .map(|worker| {
        let url = worker.url().to_string();
        let api_key = worker.api_key().clone();
        let worker_type = match worker.worker_type() {
            WorkerType::Regular => None,
            WorkerType::Prefill { .. } => Some("prefill".to_string()),
            WorkerType::Decode => Some("decode".to_string()),
        };
        let is_http = matches!(worker.connection_mode(), ConnectionMode::Http);
        let client = client.clone();

        async move {
            let load = if is_http {
                Self::parse_load_response(&client, &url, api_key.as_deref()).await
            } else {
                -1
            };
            WorkerLoadInfo {
                worker: url,
                worker_type,
                load,
            }
        }
    })
    .collect();

let loads = future::join_all(futures).await;
```
join_all will spawn requests for every worker concurrently, which can become a thundering herd with large fleets (file descriptor pressure, bursty DNS/conn attempts, and increased tail latency). Consider using a bounded concurrency approach (e.g., stream::iter(...).buffer_unordered(MAX_CONCURRENT)) similar to fan_out, and/or reusing fan_out to keep request concurrency consistent across worker-management operations.
```rust
interval: Duration,
tx: watch::Sender<HashMap<String, isize>>,
rx: watch::Receiver<HashMap<String, isize>>,
monitor_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
```
LoadMonitor uses tokio::sync::Mutex for monitor_handle, but Drop can only try_lock(). If another task holds the lock during drop (e.g., concurrent start() / stop()), the handle may not be aborted and the background task can keep running after the struct is dropped. A concrete fix is to switch monitor_handle to a sync mutex (parking_lot::Mutex/std::sync::Mutex) so Drop can always lock, or store a CancellationToken/watch flag that the loop checks so the task terminates even if the handle can’t be acquired in Drop.
```rust
impl Drop for LoadMonitor {
    fn drop(&mut self) {
        if let Ok(mut handle_guard) = self.monitor_handle.try_lock() {
            if let Some(handle) = handle_guard.take() {
                handle.abort();
            }
        }
    }
}
```
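A std-only sketch of the combined fix: a sync mutex that `Drop` can always lock, plus a stop flag the background loop polls, so the task terminates even if the handle were unreachable. `Monitor` here is a thread-based stand-in for the tokio-based `LoadMonitor`:

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;

/// Simplified stand-in for LoadMonitor: the stop flag means Drop never
/// depends on winning a lock race to terminate the background task.
struct Monitor {
    stop: Arc<AtomicBool>,
    handle: Mutex<Option<JoinHandle<()>>>, // sync Mutex: Drop can always lock
}

impl Monitor {
    fn start() -> Self {
        let stop = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&stop);
        let handle = std::thread::spawn(move || {
            // The monitoring loop checks the flag each iteration.
            while !flag.load(Ordering::Relaxed) {
                std::thread::sleep(std::time::Duration::from_millis(5));
            }
        });
        Monitor { stop, handle: Mutex::new(Some(handle)) }
    }
}

impl Drop for Monitor {
    fn drop(&mut self) {
        self.stop.store(true, Ordering::Relaxed); // loop exits on its own
        if let Some(h) = self.handle.lock().unwrap().take() {
            let _ = h.join();
        }
    }
}
```

In the tokio version the same shape works with `tokio_util::sync::CancellationToken` (or a `watch` channel) in place of the `AtomicBool`.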
```rust
let wait_time = {
    let inner = self.inner.lock();
    let tokens_needed = tokens - inner.tokens;
    let wait_secs = (tokens_needed / self.refill_rate).max(0.0);
    Duration::from_secs_f64(wait_secs)
};

debug!(
    "Token bucket: waiting {:?} for {} tokens",
    wait_time, tokens
);

tokio::time::timeout(wait_time, async {
```
When refill_rate > 0, wait_time can be computed as 0 (e.g., if tokens become available between the initial failed try_acquire and this calculation). tokio::time::timeout(Duration::ZERO, ...) can immediately elapse, returning an error even though the bucket could satisfy the request. Fix by short-circuiting when tokens_needed <= 0.0, or enforcing a small minimum timeout (or recomputing/refilling tokens under the same logic used by try_acquire_sync before deciding to time out).
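A minimal sketch of the short-circuit, with the bucket state flattened into plain arguments for illustration (the real code reads them under `self.inner.lock()`):

```rust
use std::time::Duration;

/// Compute how long to wait for `requested` tokens; return None when the
/// bucket already holds enough (tokens_needed <= 0), so the caller acquires
/// immediately instead of racing a zero-length tokio::time::timeout.
fn wait_for_tokens(available: f64, requested: f64, refill_rate: f64) -> Option<Duration> {
    let needed = requested - available;
    if needed <= 0.0 {
        return None; // enough tokens now: no timeout needed at all
    }
    Some(Duration::from_secs_f64(needed / refill_rate))
}
```

With this shape, the caller only enters `tokio::time::timeout(...)` when `wait_for_tokens` returns `Some`, so a refill between the failed `try_acquire` and the wait computation can no longer produce a spurious timeout error.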
```rust
// Track unique model IDs we've updated policies for
let mut updated_models = Vec::new();
```
updated_models.contains(&model_id) makes this loop O(n²) in the number of workers. Switching updated_models to a HashSet<String> avoids repeated scans and makes uniqueness tracking O(1) average-case, while preserving the same behavior.
```rust
if !updated_models.contains(&model_id) {
    updated_models.push(model_id);
}
```
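The replacement is a one-line change; a sketch of the `HashSet` version (the iterator source is illustrative, the real loop walks the worker registry):

```rust
use std::collections::HashSet;

/// Dedupe model IDs in O(1) average time per worker instead of
/// scanning a Vec on every insertion.
fn unique_models<'a>(model_ids: impl Iterator<Item = &'a str>) -> HashSet<String> {
    let mut updated_models = HashSet::new();
    for model_id in model_ids {
        // insert() is a no-op for repeats, so no contains() pre-check needed
        updated_models.insert(model_id.to_string());
    }
    updated_models
}
```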
```rust
/// Load monitoring service that periodically fetches worker loads
pub struct LoadMonitor {
    worker_registry: Arc<WorkerRegistry>,
    policy_registry: Arc<PolicyRegistry>,
    client: reqwest::Client,
    interval: Duration,
    tx: watch::Sender<HashMap<String, isize>>,
    rx: watch::Receiver<HashMap<String, isize>>,
    monitor_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
}
```
This introduces a long-running background loop with start/stop/drop semantics and periodic updates to PowerOfTwo policies, but there are no unit tests covering: (1) start() idempotency, (2) stop() actually halting updates, and (3) drop behavior aborting the task. Adding focused tests (e.g., using tokio::time::pause() + a stub PolicyRegistry / small interval) would help prevent regressions in production task lifecycle behavior.
Remove Python bindings directory from version control (kept locally). Remove Python package install, Python launcher, and Python dev references from README.
Tests, e2e_test, and Python bindings have been removed from git tracking, so remove cargo test and pre-commit hooks sections.
Remove Kubernetes service discovery, OpenTelemetry tracing, and multi-backend (vLLM/TRT-LLM) references that no longer exist in src/.
Pull request overview
Copilot reviewed 59 out of 130 changed files in this pull request and generated 8 comments.
```rust
let task = PeriodicTask::spawn(interval_secs, "InFlightRequestSampler", move || {
    tracker.sample_and_record();
});
self.sampler.set(task).unwrap();
```
start_sampler() will panic if called more than once because OnceLock::set() returns an error on subsequent calls and this code unwraps it. Consider making this idempotent (e.g., return early if already set) or returning a Result so callers can handle the “already started” case gracefully.
```diff
-let task = PeriodicTask::spawn(interval_secs, "InFlightRequestSampler", move || {
-    tracker.sample_and_record();
-});
-self.sampler.set(task).unwrap();
+self.sampler.get_or_init(|| {
+    PeriodicTask::spawn(interval_secs, "InFlightRequestSampler", move || {
+        tracker.sample_and_record();
+    })
+});
```
```rust
/// Thread-safe: uses entry API to avoid race conditions.
pub fn get_or_register(
    &self,
    label_value: &str,
) -> dashmap::mapref::one::Ref<'_, Arc<str>, (GaugeHistogramHandle, Vec<usize>)> {
    // Fast path: already cached
    if let Some(entry) = self.cache.get(label_value) {
        return entry;
    }

    // Slow path: use entry API to handle concurrent inserts atomically
    self.cache.entry(Arc::from(label_value)).or_insert_with(|| {
        let handle = self.histogram.register(&[(self.label_key, label_value)]);
        let counts_buf = vec![0usize; self.histogram.bounds.bucket_count()];
        (handle, counts_buf)
    });

    self.cache.get(label_value).unwrap()
```
get_or_register() does a second lookup followed by unwrap(). If another thread calls remove() between the entry().or_insert_with(...) and the final get(), this can panic. To make this race-free, return the entry reference directly from the entry API (adjusting the return type if needed), or otherwise avoid the second get(...).unwrap().
```diff
-/// Thread-safe: uses entry API to avoid race conditions.
-pub fn get_or_register(
-    &self,
-    label_value: &str,
-) -> dashmap::mapref::one::Ref<'_, Arc<str>, (GaugeHistogramHandle, Vec<usize>)> {
-    // Fast path: already cached
-    if let Some(entry) = self.cache.get(label_value) {
-        return entry;
-    }
-    // Slow path: use entry API to handle concurrent inserts atomically
-    self.cache.entry(Arc::from(label_value)).or_insert_with(|| {
-        let handle = self.histogram.register(&[(self.label_key, label_value)]);
-        let counts_buf = vec![0usize; self.histogram.bounds.bucket_count()];
-        (handle, counts_buf)
-    });
-    self.cache.get(label_value).unwrap()
+/// Thread-safe: retries if a concurrent remove happens after insertion.
+pub fn get_or_register(
+    &self,
+    label_value: &str,
+) -> dashmap::mapref::one::Ref<'_, Arc<str>, (GaugeHistogramHandle, Vec<usize>)> {
+    loop {
+        // Fast path: already cached
+        if let Some(entry) = self.cache.get(label_value) {
+            return entry;
+        }
+        // Slow path: use entry API to handle concurrent inserts atomically
+        self.cache.entry(Arc::from(label_value)).or_insert_with(|| {
+            let handle = self.histogram.register(&[(self.label_key, label_value)]);
+            let counts_buf = vec![0usize; self.histogram.bounds.bucket_count()];
+            (handle, counts_buf)
+        });
+        // If another thread removed the entry before we could fetch it,
+        // retry instead of panicking on unwrap.
+    }
```
```rust
let job_queue = self.get_job_queue()?;

if let Some(worker) = self.worker_registry.get(&worker_id) {
    let worker_url = worker.url().to_string();
    let mut worker_info = worker_to_info(&worker);
    worker_info.id = worker_id.as_str().to_string();
    if let Some(status) = job_queue.get_status(&worker_url) {
        worker_info.job_status = Some(status);
    }
    return Ok(GetWorkerResponse(worker_info));
}

if let Some(worker_url) = self.worker_registry.get_url_by_id(&worker_id) {
    if let Some(status) = job_queue.get_status(&worker_url) {
        return Ok(GetWorkerResponse(WorkerInfo::pending(
            worker_id.as_str(),
            worker_url,
            Some(status),
        )));
    }
```
get_worker() currently fails with QueueNotInitialized even when the worker exists in the registry, because it requires job_queue before checking the registry. This breaks read-only inspection when the queue isn’t configured. A safer approach is: first check worker_registry.get(...), and only consult the job queue for optional job status if it’s available.
```diff
-let job_queue = self.get_job_queue()?;
-if let Some(worker) = self.worker_registry.get(&worker_id) {
-    let worker_url = worker.url().to_string();
-    let mut worker_info = worker_to_info(&worker);
-    worker_info.id = worker_id.as_str().to_string();
-    if let Some(status) = job_queue.get_status(&worker_url) {
-        worker_info.job_status = Some(status);
-    }
-    return Ok(GetWorkerResponse(worker_info));
-}
-if let Some(worker_url) = self.worker_registry.get_url_by_id(&worker_id) {
-    if let Some(status) = job_queue.get_status(&worker_url) {
-        return Ok(GetWorkerResponse(WorkerInfo::pending(
-            worker_id.as_str(),
-            worker_url,
-            Some(status),
-        )));
-    }
+let job_queue = self.get_job_queue().ok();
+if let Some(worker) = self.worker_registry.get(&worker_id) {
+    let worker_url = worker.url().to_string();
+    let mut worker_info = worker_to_info(&worker);
+    worker_info.id = worker_id.as_str().to_string();
+    if let Some(status) = job_queue
+        .as_ref()
+        .and_then(|job_queue| job_queue.get_status(&worker_url))
+    {
+        worker_info.job_status = Some(status);
+    }
+    return Ok(GetWorkerResponse(worker_info));
+}
+if let Some(worker_url) = self.worker_registry.get_url_by_id(&worker_id) {
+    let status = job_queue
+        .as_ref()
+        .and_then(|job_queue| job_queue.get_status(&worker_url));
+    return Ok(GetWorkerResponse(WorkerInfo::pending(
+        worker_id.as_str(),
+        worker_url,
+        status,
+    )));
```
```rust
// Check if worker already exists
if app_context
    .worker_registry
    .get_by_url(&config.url)
    .is_some()
{
    return Err(WorkflowError::StepFailed {
        step_id: StepId::new("create_worker"),
        message: format!("Worker {} already exists", config.url),
    });
}
```
This existence check only looks up config.url exactly. For DP-aware workers, the registry typically stores per-rank URLs (e.g., base@0, base@1, …), so get_by_url(base) may return None and allow duplicate DP-aware registrations. Consider using a prefix check when config.dp_aware is true (e.g., detect any registered worker whose URL starts with normalized_url + "@").
```rust
pub(crate) fn find_workers_by_url(
    registry: &WorkerRegistry,
    url: &str,
    dp_aware: bool,
) -> Vec<Arc<dyn Worker>> {
    if dp_aware {
        let worker_url_prefix = format!("{}@", url);
        registry
            .get_all()
            .iter()
            .filter(|worker| worker.url().starts_with(&worker_url_prefix))
            .cloned()
            .collect()
    } else {
        match registry.get_by_url(url) {
            Some(worker) => vec![worker],
            None => Vec::new(),
        }
    }
}
```
DP-aware matching appends @ unconditionally. If url already contains a rank (e.g., base@0), this builds a prefix like base@0@, which won’t match any workers and will break update/remove flows that use ranked URLs. Fix by handling both forms: if url already contains @, either (a) treat it as an exact match for a single worker, or (b) strip the rank portion to derive the base URL before doing a prefix match—depending on the intended semantics for that call site.
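A sketch of option (b), stripping a numeric rank suffix before building the prefix; `dp_prefix` is a hypothetical helper, and the rank-detection rule (suffix after the last `@` is all digits) is an assumption about the URL scheme:

```rust
/// Derive the DP-aware match prefix for a URL. If the caller passed a ranked
/// URL like "http://w0:30000@1", strip the "@rank" suffix first so the
/// resulting prefix "http://w0:30000@" still matches registered per-rank
/// workers, instead of the dead prefix "http://w0:30000@1@".
fn dp_prefix(url: &str) -> String {
    let base = match url.rsplit_once('@') {
        // Only treat the suffix as a rank when it is purely numeric.
        Some((head, rank)) if !rank.is_empty() && rank.chars().all(|c| c.is_ascii_digit()) => head,
        _ => url,
    };
    format!("{}@", base)
}
```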
```rust
"Job queue full: {} jobs pending (capacity: {})",
queue_depth,
self.tx.max_capacity()
```
For tokio::sync::mpsc::Sender::send(...).await, the error path indicates the receiver is closed (dispatcher stopped), not that the channel is “full” (since send().await applies backpressure until there is capacity). The current error message is misleading. Consider returning an error like “job queue closed/shutting down” here, and if you need a “full” error, use try_send() (or another explicit overload) instead.
```diff
-"Job queue full: {} jobs pending (capacity: {})",
-queue_depth,
-self.tx.max_capacity()
+"Job queue closed/shutting down: dispatcher stopped with {} jobs pending",
+queue_depth
```
```rust
let _ = tracing_subscriber::registry()
    .with(env_filter)
    .with(layers)
    .try_init();
```
The result of try_init() is ignored, which can silently disable expected logging configuration (e.g., if logging was already initialized elsewhere or initialization fails). Consider handling the error explicitly (at least eprintln! it), or change the API to return a Result<LogGuard, _> so callers can decide how to handle initialization failures.
```rust
const DEFAULT_PROJECT_NAME: &str = "sgl-model-gateway";

/// Set a compile-time environment variable with the SGL_MODEL_GATEWAY_ prefix
macro_rules! set_env {
    ($name:expr, $value:expr) => {
        println!("cargo:rustc-env=SGL_MODEL_GATEWAY_{}={}", $name, $value);
```
The build metadata still hard-codes the legacy project name (sgl-model-gateway) even though the crate/README are now ATOM Mesh. This can leak into --version output and diagnostics via the exported env vars. Consider updating DEFAULT_PROJECT_NAME (and potentially the env var prefix) to match the new project identity to avoid confusion in ops/debugging.
```diff
-const DEFAULT_PROJECT_NAME: &str = "sgl-model-gateway";
-/// Set a compile-time environment variable with the SGL_MODEL_GATEWAY_ prefix
-macro_rules! set_env {
-    ($name:expr, $value:expr) => {
-        println!("cargo:rustc-env=SGL_MODEL_GATEWAY_{}={}", $name, $value);
+const DEFAULT_PROJECT_NAME: &str = "atom-mesh";
+/// Set a compile-time environment variable with the ATOM_MESH_ prefix
+macro_rules! set_env {
+    ($name:expr, $value:expr) => {
+        println!("cargo:rustc-env=ATOM_MESH_{}={}", $name, $value);
```
Most targets depend on deleted directories (bindings, tests, e2e_test). Remaining useful targets are trivial cargo commands.
119 git-tracked files in mesh/src/ documented with Chinese annotations covering purpose, key types, and architectural role of each file.
…README" This reverts commit 3a455ff.
Remove non-existent /v1/embeddings endpoint and add the actual /v1/completions endpoint that is implemented in both HTTP routers.
Pull request overview
Copilot reviewed 62 out of 128 changed files in this pull request and generated 4 comments.
```rust
let healthy_workers: Vec<(usize, &Arc<dyn Worker>)> = workers
    .iter()
    .enumerate()
    .filter(|(_, w)| w.is_healthy())
```
PrefixHashPolicy only filters by is_healthy() and ignores circuit breaker state, while other selection helpers (e.g., get_healthy_worker_indices) also require w.circuit_breaker().can_execute(). This can route traffic to a worker whose circuit breaker is open. Update the filter to include the circuit breaker check consistently.
```diff
-.filter(|(_, w)| w.is_healthy())
+.filter(|(_, w)| w.is_healthy() && w.circuit_breaker().can_execute())
```
```rust
// Find worker using ring with load balancing
self.find_worker_with_load_balance(workers, info, prefix_hash)
```
Unlike CacheAwarePolicy and PowerOfTwoPolicy, PrefixHashPolicy never calls workers[idx].increment_processed() on successful selection. If processed-count metrics (or downstream accounting) depend on this, prefix-hash will underreport/behave inconsistently. Consider incrementing the processed counter when returning a selected worker index (mandatory if other policies rely on it for observability/behavior).
```rust
// we must degrade BOTH to request counts to ensure fairness.
let (load1, load2) = match (load1_tokens, load2_tokens) {
    (Some(t1), Some(t2)) => {
        // Both have token data. Compare Tokens.
        (t1, t2)
    }
    _ => {
        // If One or both are missing token data.
        // Fallback to local request counts for BOTH.
```
LoadMonitor can populate the cached load map with sentinel values like -1 (e.g., non-HTTP workers or parse failures). The (Some(t1), Some(t2)) branch treats negative values as valid “token loads”, which can bias routing (a -1 will always win). Treat negative cached values as “missing” and force the fallback to request counts unless both cached loads are present and non-negative.
```diff
-// we must degrade BOTH to request counts to ensure fairness.
-let (load1, load2) = match (load1_tokens, load2_tokens) {
-    (Some(t1), Some(t2)) => {
-        // Both have token data. Compare Tokens.
-        (t1, t2)
-    }
-    _ => {
-        // If One or both are missing token data.
-        // Fallback to local request counts for BOTH.
+// or has an invalid negative sentinel value, we must degrade BOTH
+// to request counts to ensure fairness.
+let (load1, load2) = match (load1_tokens, load2_tokens) {
+    (Some(t1), Some(t2)) if t1 >= 0 && t2 >= 0 => {
+        // Both have valid non-negative token data. Compare tokens.
+        (t1, t2)
+    }
+    _ => {
+        // If one or both are missing token data, or either cached
+        // value is invalid, fallback to local request counts for BOTH.
```
```rust
match self.tx.send(job).await {
    Ok(_) => {
        let (queue_depth, available_permits) = self.get_load_info();
        debug!(
            "Job submitted: type={}, worker={}, queue_depth={}, available_slots={}",
            job_type, worker_url, queue_depth, available_permits
        );
        Ok(())
    }
    Err(_) => {
        self.status_map.remove(&worker_url);
        let (queue_depth, _) = self.get_load_info();
        Err(format!(
            "Job queue full: {} jobs pending (capacity: {})",
            queue_depth,
            self.tx.max_capacity()
        ))
    }
}
```
mpsc::Sender::send(...).await doesn’t fail because the queue is full (it waits); it fails only if the receiver is closed. Returning "Job queue full" here is misleading and will complicate debugging/shutdown behavior. Use an error message indicating the dispatcher/receiver is closed, and consider keeping the failed status (instead of removing it) so callers can observe the failure.
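The distinction also holds for std's bounded channel, which makes it easy to demonstrate without tokio: a non-blocking `try_send` is what reports "full", while a blocking `send` only errors once the receiver is gone. A std-only sketch (the tokio `mpsc` types behave analogously):

```rust
use std::sync::mpsc::{sync_channel, TrySendError};

/// Returns (saw_full, saw_disconnected, blocking_send_failed_when_closed).
fn channel_error_semantics() -> (bool, bool, bool) {
    let (tx, rx) = sync_channel::<u32>(1);
    tx.try_send(1).unwrap(); // fills the 1-slot queue

    // Non-blocking send on a full queue: this is the real "queue full" signal.
    let saw_full = matches!(tx.try_send(2), Err(TrySendError::Full(_)));

    drop(rx); // receiver (dispatcher) shut down

    // With the receiver gone, errors mean "closed", not "full".
    let saw_disconnected = matches!(tx.try_send(3), Err(TrySendError::Disconnected(_)));
    let blocking_send_failed = tx.send(4).is_err(); // blocking send never reports "full"

    (saw_full, saw_disconnected, blocking_send_failed)
}
```

So if the job queue genuinely wants a "full" error, `try_send` (tokio's `Sender::try_send`) is the right call; the awaited `send` error path should say the dispatcher has stopped.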
RFC: ATOM Mesh — High-Performance Model Gateway for Prefill-Decode Disaggregation
1. Summary
ATOM Mesh is a high-performance model routing gateway written in Rust, purpose-built for Prefill-Decode (PD) disaggregated LLM inference on the AMD ROCm platform. It serves as both the control plane and data plane for orchestrating fleets of heterogeneous LLM workers, enabling independent scaling and optimized GPU utilization for the prefill and decode phases of autoregressive inference.
Forked from sgl-model-gateway v0.3.2 and extended with PD-specific routing, gRPC pipeline support, cache-aware load balancing, and out-of-band KV cache transfer coordination via the Mooncake Transfer Engine.
2. Motivation
LLM inference has two phases with opposite compute profiles: prefill is compute-bound (parallel matrix ops), while decode is memory-bandwidth-bound (sequential token generation). Coupling them on the same GPU wastes resources — prefill bursts starve decode, and decode underutilizes ALUs.
ATOM Mesh solves this by separating them into independent worker pools that scale and optimize independently, with KV cache transferred between pools via RDMA/TCP (Mooncake). This is the AMD ROCm counterpart to NVIDIA Dynamo's PD disaggregation.
3. Architecture Overview
```mermaid
graph TB
    Client["Client (OpenAI API)"]
    subgraph Gateway["ATOM Mesh Gateway"]
        Server["Axum Server<br/>(HTTP / HTTPS)"]
        Router["Router Layer<br/>(HTTP / gRPC)<br/>(Regular / PD)"]
        Policy["Policy Engine<br/>(CacheAware, PowerOfTwo,<br/>PrefixHash, ...)"]
        WR["Worker Registry<br/>(DashMap + HashRing)"]
        PR["Policy Registry<br/>(prefill_policy +<br/>decode_policy)"]
    end
    P0["Prefill Worker 0"]
    P1["Prefill Worker 1"]
    D0["Decode Worker 0"]
    D1["Decode Worker 1"]
    Client --> Server
    Server --> Router
    Router --> Policy
    Policy --> WR
    Policy --> PR
    WR --> P0
    WR --> P1
    WR --> D0
    WR --> D1
    P0 -- "KV Cache Transfer<br/>(Mooncake RDMA/TCP)" --> D0
    P1 -- "KV Cache Transfer<br/>(Mooncake RDMA/TCP)" --> D1
    style Gateway fill:#f0f4ff,stroke:#4a6fa5,stroke-width:2px
    style P0 fill:#e8f5e9,stroke:#388e3c
    style P1 fill:#e8f5e9,stroke:#388e3c
    style D0 fill:#fff3e0,stroke:#f57c00
    style D1 fill:#fff3e0,stroke:#f57c00
```

Component Summary
4. PD Disaggregation Design
4.1 Routing Mode Configuration
The gateway supports two routing modes:
CLI usage:
```shell
mesh --pd-disaggregation \
  --prefill http://prefill-0:30000 8998 \
  --prefill http://prefill-1:30000 8998 \
  --decode http://decode-0:30000 \
  --decode http://decode-1:30000 \
  --prefill-policy cache_aware \
  --decode-policy power_of_two
```

4.2 Request Lifecycle in PD Mode
```mermaid
sequenceDiagram
    participant C as Client
    participant G as Mesh Gateway
    participant P as Prefill Worker
    participant D as Decode Worker
    C->>G: POST /v1/chat/completions
    Note over G: 1. Select PD Pair<br/>(prefill_policy + decode_policy)
    Note over G: 2. Inject Bootstrap Metadata<br/>bootstrap_host, bootstrap_port,<br/>bootstrap_room (random u64)
    par 3. Simultaneous Dual Dispatch
        G->>P: POST /generate (with bootstrap metadata)
        G->>D: POST /generate (with bootstrap metadata)
    end
    Note over P: 4. Compute KV Cache
    P-->>D: 5. Mooncake RDMA/TCP KV Cache Transfer<br/>(out-of-band, using bootstrap_room as session ID)
    Note over D: 6. Begin Autoregressive Decoding
    D-->>G: SSE Stream (generated tokens)
    G-->>C: SSE Stream (generated tokens)
```

Step-by-step:
Worker Pair Selection — The prefill policy selects a prefill worker, and the decode policy independently selects a decode worker. Each policy runs against its respective worker pool filtered by health and circuit breaker state.
Bootstrap Metadata Injection — The gateway injects three fields into the request body:
- `bootstrap_host` — The prefill worker's address (where Mooncake listens)
- `bootstrap_port` — The Mooncake transfer engine port (default 8998)
- `bootstrap_room` — A random u64 session ID in `[0, 2^63)` to isolate concurrent transfers

Simultaneous Dual Dispatch — Both workers receive the annotated request at the same time via `tokio::join!()`. The decode worker can prepare internal state while waiting for the KV cache, avoiding sequential latency. The KV cache transfer happens out-of-band — the gateway never touches KV cache bytes (which can be hundreds of MB).

Prefill Computation — The prefill worker processes the input prompt and computes the KV cache.

KV Cache Transfer — The Mooncake Transfer Engine transfers the KV cache from prefill to decode via RDMA or TCP. The `bootstrap_room` ensures concurrent requests to the same prefill worker do not collide.

Decode Generation — The decode worker receives the KV cache and begins autoregressive token generation, streaming results back through the gateway to the client.
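The metadata-injection step above can be sketched in a few lines, assuming the request body is a JSON-like dict; the field names and port default come from this RFC, while the helper itself is illustrative rather than the gateway's actual (Rust) implementation:

```python
import random

def inject_bootstrap(body: dict, prefill_host: str, bootstrap_port: int = 8998) -> dict:
    """Annotate a request body with PD bootstrap metadata (section 4.2)."""
    annotated = dict(body)  # leave the caller's body untouched
    annotated["bootstrap_host"] = prefill_host
    annotated["bootstrap_port"] = bootstrap_port
    # Random u64 in [0, 2^63): session ID isolating concurrent KV transfers.
    annotated["bootstrap_room"] = random.randrange(0, 2**63)
    return annotated
```

Both the prefill and decode workers then receive the same annotated body, so they can rendezvous on `bootstrap_room` for the out-of-band Mooncake transfer.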
5. Load Balancing Policies
PD disaggregation allows independent policies for prefill and decode pools, reflecting their different optimization targets:
```mermaid
graph LR
    subgraph Prefill Pool
        PP["Prefill Policy<br/>(CacheAware / PrefixHash)"]
        PW0["Prefill Worker 0"]
        PW1["Prefill Worker 1"]
    end
    subgraph Decode Pool
        DP["Decode Policy<br/>(PowerOfTwo / RoundRobin)"]
        DW0["Decode Worker 0"]
        DW1["Decode Worker 1"]
    end
    Request["Incoming Request"] --> PP
    Request --> DP
    PP --> PW0
    PP --> PW1
    DP --> DW0
    DP --> DW1
    style PP fill:#e8f5e9,stroke:#388e3c
    style DP fill:#fff3e0,stroke:#f57c00
```

6. Deployment Topology
```mermaid
graph TB
    LB["Load Balancer"]
    subgraph Mesh["ATOM Mesh Gateway (Port 30050)"]
        MESH["mesh binary"]
    end
    subgraph Prefill["Prefill Pool"]
        P0["Prefill-0<br/>GPU 0-3<br/>Port 30000<br/>Mooncake: 8998"]
        P1["Prefill-1<br/>GPU 4-7<br/>Port 30000<br/>Mooncake: 8998"]
    end
    subgraph Decode["Decode Pool"]
        D0["Decode-0<br/>GPU 8-11<br/>Port 30000"]
        D1["Decode-1<br/>GPU 12-15<br/>Port 30000"]
    end
    LB --> MESH
    MESH --> P0
    MESH --> P1
    MESH --> D0
    MESH --> D1
    P0 -. "RDMA KV Transfer" .-> D0
    P0 -. "RDMA KV Transfer" .-> D1
    P1 -. "RDMA KV Transfer" .-> D0
    P1 -. "RDMA KV Transfer" .-> D1
    style Mesh fill:#f0f4ff,stroke:#4a6fa5,stroke-width:2px
    style Prefill fill:#e8f5e9,stroke:#388e3c,stroke-width:2px
    style Decode fill:#fff3e0,stroke:#f57c00,stroke-width:2px
```