From 792ef0186216a7bb654bc6a034c904a1a389db7c Mon Sep 17 00:00:00 2001 From: Googler Date: Thu, 25 Jun 2026 15:24:53 -0700 Subject: [PATCH] Add Dockerfile for BAP ML-Actions and configure benchmark registry to test it PiperOrigin-RevId: 938224023 --- Dockerfile | 10 + benchmarks/benchmark_registry.pbtxt | 105 ++++++ tpu_raiden/benchmarks/BUILD | 46 +++ .../benchmarks/multi_host_perf_test_oss.py | 305 ++++++++++++++++++ .../core/kv_cache_manager_with_transfer.cc | 119 +++++-- 5 files changed, 551 insertions(+), 34 deletions(-) create mode 100644 Dockerfile create mode 100644 benchmarks/benchmark_registry.pbtxt create mode 100644 tpu_raiden/benchmarks/BUILD create mode 100644 tpu_raiden/benchmarks/multi_host_perf_test_oss.py diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..0624fb9 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,10 @@ +# Use the official ML Build container as the base +FROM us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest + +# Switch to root to install system packages +USER root + +# Install clang and llvm which are required by XLA/tpu-raiden's bazel configuration +RUN apt-get update && apt-get install -y clang llvm && rm -rf /var/lib/apt/lists/* + +# The container will run as the default user inherited from ml-build diff --git a/benchmarks/benchmark_registry.pbtxt b/benchmarks/benchmark_registry.pbtxt new file mode 100644 index 0000000..f62f3ca --- /dev/null +++ b/benchmarks/benchmark_registry.pbtxt @@ -0,0 +1,105 @@ +# proto-file: https://github.com/google-ml-infra/actions/blob/main/benchmarking/proto/benchmark_registry.proto +# proto-message: BenchmarkSuite + +benchmarks { + name: "tpu_raiden_h2d_d2h_perf_test_fp32" + description: "FP32 Performance Test for TPU Raiden local H2D and D2H offloading/reloading" + owner: "raiden-dev" + + workload { + action: "./ml_actions/actions/workload_executors/bazel" + action_inputs { key: "target" value: "//tpu_raiden/benchmarks:multi_host_perf_test_oss" } + # Flags passed to the test runner (avoid hardcoding role/peer for multi-host dynamic execution) + action_inputs { key: "runtime_flags" value: "--num_blocks=512 --num_layers=8 --parallelism=1 --dtype=float32 --warmup=5 --iters=100" } + } + + environment_configs { + id: "tpu-v5e-single-node" + runner_label: "linux-x86-ct5lp-224-8tpu" + container_image: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" + workload_action_inputs { key: "bazel_run_flags" value: "-c opt" } + } + + metrics { + name: "d2h_time_sec" + unit: "s" + stats { + stat: MEAN + } + } + + metrics { + name: "h2d_time_sec" + unit: "s" + stats { + stat: MEAN + } + } + + metrics { + name: "d2h_throughput_gbps" + unit: "Gbps" + stats { + stat: MEAN + } + } + + metrics { + name: "h2d_throughput_gbps" + unit: "Gbps" + stats { + stat: MEAN + } + } +} + +benchmarks { + name: "tpu_raiden_h2d_d2h_perf_test_bf16" + description: "BF16 Performance Test for TPU Raiden local H2D and D2H offloading/reloading" + owner: "raiden-dev" + + workload { + action: "./ml_actions/actions/workload_executors/bazel" + action_inputs { key: "target" value: "//tpu_raiden/benchmarks:multi_host_perf_test_oss" } + # Flags passed to the test runner (avoid hardcoding role/peer for multi-host dynamic execution) + action_inputs { key: "runtime_flags" value: "--num_blocks=512 --num_layers=8 --parallelism=1 --dtype=bfloat16 --warmup=5 --iters=100" } + } + + environment_configs { + id: "tpu-v5e-single-node" + runner_label: "linux-x86-ct5lp-224-8tpu" + container_image: "us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:latest" + workload_action_inputs { key: "bazel_run_flags" value: "-c opt" } + } + + metrics { + name: "d2h_time_sec" + unit: "s" + stats { + stat: MEAN + } + } + + metrics { + name: "h2d_time_sec" + unit: "s" + stats { + stat: MEAN + } + } + + metrics { + name: "d2h_throughput_gbps" + unit: "Gbps" + stats { + stat: MEAN + } + } + + metrics { + name: "h2d_throughput_gbps" + unit: "Gbps" + stats { + stat: MEAN + } + } diff --git a/tpu_raiden/benchmarks/BUILD b/tpu_raiden/benchmarks/BUILD new file mode 100644 index 0000000..a5d0b5f --- /dev/null +++ b/tpu_raiden/benchmarks/BUILD @@ -0,0 +1,46 @@ +# Copyright 2026 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2026 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_python//python:defs.bzl", "py_binary") + +package(default_visibility = ["//visibility:public"]) + +exports_files(["multi_host_perf_test_oss.py"]) + +py_binary( + name = "multi_host_perf_test_oss", + testonly = True, + srcs = ["multi_host_perf_test_oss.py"], + deps = [ + "//tpu_raiden/frameworks/jax:_tpu_raiden_jax", + "@com_google_absl_py//absl:app", + "@com_google_absl_py//absl/flags", + "@jax//jax", + "@pypi//numpy", + ], +) diff --git a/tpu_raiden/benchmarks/multi_host_perf_test_oss.py b/tpu_raiden/benchmarks/multi_host_perf_test_oss.py new file mode 100644 index 0000000..b0bfa2a --- /dev/null +++ b/tpu_raiden/benchmarks/multi_host_perf_test_oss.py @@ -0,0 +1,305 @@ +# Copyright 2026 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import time +import json +from absl import app +from absl import flags +import jax +import jax.numpy as jnp +import numpy as np +from tpu_raiden.frameworks.jax import _tpu_raiden_jax as kv_cache_manager + +_NUM_BLOCKS = flags.DEFINE_integer( + 'num_blocks', 512, 'Number of global cache blocks to allocate.' +) +_BLOCK_SIZE = flags.DEFINE_integer('block_size', 2, 'Size of cache blocks.') +_NUM_LAYERS = flags.DEFINE_integer( + 'num_layers', 1, 'Number of transformer layers.' +) +_DTYPE = flags.DEFINE_string( + 'dtype', + 'float32', + 'Dataset type for the KV cache array: float32, bfloat16, float16.', +) +_WARMUP = flags.DEFINE_integer('warmup', 5, 'Number of warmup iterations.') +_ITERS = flags.DEFINE_integer('iters', 20, 'Number of benchmark iterations.') + + +def write_tensorboard_metrics( + d2h_time_sec: float, + h2d_time_sec: float, + d2h_gbps: float, + h2d_gbps: float, +): + """Logs local copy CPU-TPU transfer metrics to Tensorboard event logs for BAP.""" + tblog_dir = os.environ.get('TENSORBOARD_OUTPUT_DIR') + if not tblog_dir: + print('TENSORBOARD_OUTPUT_DIR is not set. Skipping TensorBoard logging.') + return + + print(f'Writing metrics to TensorBoard directory: {tblog_dir}') + try: + try: + # pylint: disable=g-import-not-at-top + import tensorboardX # pytype: disable=import-error + + writer = tensorboardX.SummaryWriter(log_dir=tblog_dir) + except ImportError: + # pylint: disable=g-import-not-at-top + import torch.utils.tensorboard # pytype: disable=import-error + + writer = torch.utils.tensorboard.SummaryWriter(log_dir=tblog_dir) + + # Log averages + writer.add_scalar('d2h_time_sec', d2h_time_sec, global_step=0) + writer.add_scalar('h2d_time_sec', h2d_time_sec, global_step=0) + writer.add_scalar('d2h_throughput_gbps', d2h_gbps, global_step=0) + writer.add_scalar('h2d_throughput_gbps', h2d_gbps, global_step=0) + writer.close() + print('Successfully wrote performance metrics to TensorBoard logs.') + except Exception as e: # pylint: disable=broad-exception-caught + print(f'WARNING: Failed to write TensorBoard logs: {e}', file=sys.stderr) + + +def setup_distributed_mesh(devices): + """Sets up a global JAX Mesh sharding the block count across processes.""" + process_id = jax.process_index() + num_processes = jax.process_count() + num_local_devices = len(jax.local_devices()) + + print( + 'Initializing JAX Distributed Mesh on Process' + f' {process_id}/{num_processes}' + ) + print(f'Local addressable devices seen: {jax.local_devices()}') + print(f'Global cluster devices seen: {devices}') + + # Reshape mesh to (num_processes, num_local_devices) + devices_array = np.array(devices).reshape((num_processes, num_local_devices)) + mesh = jax.sharding.Mesh(devices_array, ('host', 'device')) + + # Shard the first dimension (num_blocks) across the host axis! + spec = jax.sharding.PartitionSpec('host', None, None, None, None) + tpu_sharding = jax.sharding.NamedSharding(mesh, spec) + host_sharding = jax.sharding.NamedSharding( + mesh, spec, memory_kind='pinned_host' + ) + + return tpu_sharding, host_sharding + + +def verify_device_cache(tpu_cache) -> bool: + """Verifies local sharded tpu cache destination blocks match source blocks.""" + process_id = jax.process_index() + print( + f'[Process {process_id}] Verifying local sharded cache device' + ' consistency...' + ) + try: + for s in tpu_cache.addressable_shards: + # s.data is a process-local JAX array representing the blocks held by this + # shard + local_tpu_data = np.asarray(s.data) + local_blocks = local_tpu_data.shape[ + 0 + ] # Total blocks held locally by this shard + half = local_blocks // 2 # Half the blocks + + # Verify destination half matches source half + np.testing.assert_array_equal( + local_tpu_data[half:local_blocks], + local_tpu_data[0:half], + ) + except AssertionError as exc: + print(f'[Process {process_id}] Verification FAILED!') + print(exc) + return False + print( + f'[Process {process_id}] Device consistency verified successfully! 0%' + ' corruption.' + ) + return True + + +def main(_): + process_id = jax.process_index() + + devices = jax.devices('tpu') + tpu_sharding, _ = setup_distributed_mesh(devices) + + # Physical sharding shape: (num_blocks, head_count, 1, head_dim) + cache_shape = (_NUM_BLOCKS.value, 32, 1, 8, 128) + print(f'[Process {process_id}] Configured Cache Global Shape: {cache_shape}') + + # 1. Create a single large device array and initialize it with unique values + print( + f'[Process {process_id}] Initializing device cache with reference' + ' sequence...' + ) + dtype_map = { + 'float32': jnp.float32, + 'bfloat16': jnp.bfloat16, + 'float16': jnp.float16, + } + target_dtype = dtype_map.get(_DTYPE.value, jnp.float32) + print(f'[Process {process_id}] Selected benchmark data type: {_DTYPE.value}') + + base = jnp.arange(np.prod(cache_shape), dtype=target_dtype).reshape( + cache_shape + ) + tpu_cache = jax.device_put(base, tpu_sharding) + jax.block_until_ready(tpu_cache) + + num_processes = jax.process_count() + local_blocks = ( + _NUM_BLOCKS.value // num_processes + ) # 512 // 4 = 128 blocks locally + half_blocks = local_blocks // 2 # 128 // 2 = 64 blocks + + # 2. Create a kv cache manager and let it allocate internal host buffers for + # local blocks / 2 + print( + f'[Process {process_id}] Instantiating KVCacheManager with internal host' + f' buffers for {half_blocks} blocks...' + ) + manager = kv_cache_manager.KVCacheManager( + device_arrays=[tpu_cache], + host_blocks_to_allocate=half_blocks, + unsafe_skip_buffer_lock=True, + ) + + # Calculate data sizes for throughput + # cache_shape: (num_blocks, head_count, 1, head_dim) -> (num_blocks, 32, 1, 8, 128) + # Elements per block: 32 * 1 * 8 * 128 + elements_per_block = np.prod(cache_shape[1:]) + dtype_itemsize = jnp.dtype(target_dtype).itemsize + bytes_per_block = elements_per_block * dtype_itemsize + transferred_bytes_total = half_blocks * bytes_per_block + + # 3. Step A: Pull blocks 0:64 from device to internal host blocks 0:64 (D2H) + print( + f'[Process {process_id}] Executing D2H offloading (Local Blocks' + f' 0..{half_blocks} -> Host)...' + ) + src_offsets = list(range(0, half_blocks)) + dst_offsets = list(range(0, half_blocks)) + sizes = [1] * len(src_offsets) + + # Warmup + for i in range(_WARMUP.value): + manager.d2h( + src_offsets_major_dim=src_offsets, + dst_offsets_major_dim=dst_offsets, + copy_sizes_major_dim=sizes, + ).Await() + + # Benchmark Loop + d2h_total_time = 0.0 + d2h_times = [] + for _ in range(_ITERS.value): + start_time = time.perf_counter() + manager.d2h( + src_offsets_major_dim=src_offsets, + dst_offsets_major_dim=dst_offsets, + copy_sizes_major_dim=sizes, + ).Await() + elapsed = time.perf_counter() - start_time + d2h_times.append(elapsed) + d2h_total_time += elapsed + + d2h_time_mean = d2h_total_time / _ITERS.value + d2h_gbps = (transferred_bytes_total * 8) / (d2h_time_mean * 1e9) + print( + f'[Process {process_id}] D2H complete. Avg Time:' + f' {d2h_time_mean:.6f}s. Throughput: {d2h_gbps:.3f} Gbps' + ) + print(f'[Process {process_id}] D2H Individual times: {d2h_times}') + + # 4. Step B: Push blocks from internal host blocks 0:64 to local TPU blocks + # 64:128 (H2D) + print( + f'[Process {process_id}] Executing H2D reloading (Host -> Local TPU' + f' Blocks {half_blocks}..{local_blocks})...' + ) + src_offsets = list(range(0, half_blocks)) + dst_offsets = list(range(half_blocks, local_blocks)) + sizes = [1] * len(src_offsets) + + # Warmup + for i in range(_WARMUP.value): + manager.h2d( + src_offsets_major_dim=src_offsets, + dst_offsets_major_dim=dst_offsets, + copy_sizes_major_dim=sizes, + ).Await() + + # Benchmark Loop + h2d_total_time = 0.0 + h2d_times = [] + for _ in range(_ITERS.value): + start_time = time.perf_counter() + manager.h2d( + src_offsets_major_dim=src_offsets, + dst_offsets_major_dim=dst_offsets, + copy_sizes_major_dim=sizes, + ).Await() + elapsed = time.perf_counter() - start_time + h2d_times.append(elapsed) + h2d_total_time += elapsed + + h2d_time_mean = h2d_total_time / _ITERS.value + h2d_gbps = (transferred_bytes_total * 8) / (h2d_time_mean * 1e9) + print( + f'[Process {process_id}] H2D complete. Avg Time:' + f' {h2d_time_mean:.6f}s. Throughput: {h2d_gbps:.3f} Gbps' + ) + print(f'[Process {process_id}] H2D Individual times: {h2d_times}') + + # 5. Step C: Verify on device that blocks 256:512 match blocks 0:256 + success = verify_device_cache(tpu_cache) + if not success: + sys.exit(1) + + print( + f'[Process {process_id}] E2E Device-to-Device multihost verification' + ' completed successfully!' + ) + + if jax.process_index() == 0: + write_tensorboard_metrics(d2h_time_mean, h2d_time_mean, d2h_gbps, h2d_gbps) + + # Save raw times to artifacts directory for detailed analysis + artifact_dir = os.environ.get('WORKLOAD_ARTIFACTS_DIR') + if artifact_dir: + raw_results = { + 'd2h_times_sec': d2h_times, + 'h2d_times_sec': h2d_times, + 'd2h_time_mean': d2h_time_mean, + 'h2d_time_mean': h2d_time_mean, + 'transferred_bytes_total': transferred_bytes_total, + } + result_path = os.path.join(artifact_dir, 'raw_perf_results.json') + try: + with open(result_path, 'w') as f: + json.dump(raw_results, f, indent=2) + print(f'Saved raw performance results to {result_path}') + except Exception as e: # pylint: disable=broad-exception-caught + print(f'WARNING: Failed to write raw results artifact: {e}', file=sys.stderr) + + +if __name__ == '__main__': + app.run(main, flags_parser=lambda args: flags.FLAGS(args, known_only=True)) diff --git a/tpu_raiden/core/kv_cache_manager_with_transfer.cc b/tpu_raiden/core/kv_cache_manager_with_transfer.cc index 3cc2d75..517b8da 100644 --- a/tpu_raiden/core/kv_cache_manager_with_transfer.cc +++ b/tpu_raiden/core/kv_cache_manager_with_transfer.cc @@ -29,6 +29,7 @@ #include "tpu_raiden/core/kv_cache_manager_with_transfer.h" #include +#include #include #include #include @@ -152,11 +153,17 @@ int ConnectTcp(const std::string& endpoint) { auto [host, port] = SplitEndpoint(endpoint); int fd = socket(AF_INET, SOCK_STREAM, 0); if (fd < 0) { - throw std::runtime_error("socket() failed: " + - std::string(std::strerror(errno))); + throw std::runtime_error( + absl::StrCat("socket() failed: ", std::strerror(errno))); } - int opt = 1; - setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)); + + int fd = -1; + for (p = res; p != nullptr; p = p->ai_next) { + fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol); + if (fd == -1) continue; + + int opt = 1; + setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)); sockaddr_in addr; std::memset(&addr, 0, sizeof(addr)); @@ -164,29 +171,34 @@ int ConnectTcp(const std::string& endpoint) { addr.sin_port = htons(port); if (inet_pton(AF_INET, host.c_str(), &addr.sin_addr) != 1) { close(fd); - throw std::runtime_error("invalid IPv4 endpoint host: " + host); + throw std::runtime_error( + absl::StrCat("invalid IPv4 endpoint host: ", host)); } if (connect(fd, reinterpret_cast(&addr), sizeof(addr)) < 0) { std::string err = std::strerror(errno); close(fd); - throw std::runtime_error("connect(" + endpoint + ") failed: " + err); + throw std::runtime_error( + absl::StrCat("connect(", endpoint, ") failed: ", err)); } + return fd; } static std::string GetPeerIp(int fd) { - sockaddr_in addr; + struct sockaddr_storage addr; socklen_t len = sizeof(addr); - if (getpeername(fd, reinterpret_cast(&addr), &len) < 0) { + if (getpeername(fd, reinterpret_cast(&addr), &len) < 0) { throw std::runtime_error("getpeername() failed: " + std::string(std::strerror(errno))); } - char ip_buf[INET_ADDRSTRLEN]; - if (inet_ntop(AF_INET, &addr.sin_addr, ip_buf, sizeof(ip_buf)) == nullptr) { - throw std::runtime_error("inet_ntop() failed: " + - std::string(std::strerror(errno))); + + char host[NI_MAXHOST]; + int err = getnameinfo(reinterpret_cast(&addr), len, + host, sizeof(host), nullptr, 0, NI_NUMERICHOST); + if (err != 0) { + throw std::runtime_error(absl::StrCat("getnameinfo() failed: ", gai_strerror(err))); } - return std::string(ip_buf); + return std::string(host); } static void WriteBlockIds(int fd, const std::vector& block_ids) { if (block_ids.empty()) return; @@ -1058,37 +1070,72 @@ absl::Status KVCacheManagerWithTransfer::WaitForStagingBlockRead( } void KVCacheManagerWithTransfer::StartControlServer() { - control_fd_ = socket(AF_INET, SOCK_STREAM, 0); + // Try creating IPv6 socket first for dual-stack support + control_fd_ = socket(AF_INET6, SOCK_STREAM, 0); + bool is_v6 = true; + if (control_fd_ < 0) { - throw std::runtime_error("control socket() failed: " + - std::string(std::strerror(errno))); + // Fallback to IPv4 if IPv6 is not supported + control_fd_ = socket(AF_INET, SOCK_STREAM, 0); + is_v6 = false; + if (control_fd_ < 0) { + throw std::runtime_error("control socket() failed: " + + std::string(std::strerror(errno))); + } } + int opt = 1; setsockopt(control_fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); setsockopt(control_fd_, IPPROTO_TCP, TCP_NODELAY, &opt, sizeof(opt)); - sockaddr_in addr; - std::memset(&addr, 0, sizeof(addr)); - addr.sin_family = AF_INET; - addr.sin_addr.s_addr = INADDR_ANY; - addr.sin_port = htons(local_control_port_); - - if (bind(control_fd_, reinterpret_cast(&addr), sizeof(addr)) < 0) { - std::string err = std::strerror(errno); - close(control_fd_); - control_fd_ = -1; - throw std::runtime_error("control bind(" + - std::to_string(local_control_port_) + - ") failed: " + err); + if (is_v6) { + // Try to disable IPv6-only to support IPv4 connections mapped to IPv6 + int v6only = 0; + setsockopt(control_fd_, IPPROTO_IPV6, IPV6_V6ONLY, &v6only, sizeof(v6only)); + + struct sockaddr_in6 addr; + std::memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_addr = in6addr_any; + addr.sin6_port = htons(local_control_port_); + + if (bind(control_fd_, reinterpret_cast(&addr), sizeof(addr)) < 0) { + std::string err = std::strerror(errno); + close(control_fd_); + control_fd_ = -1; + throw std::runtime_error("control bind (IPv6) failed: " + err); + } + } else { + struct sockaddr_in addr; + std::memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = INADDR_ANY; + addr.sin_port = htons(local_control_port_); + + if (bind(control_fd_, reinterpret_cast(&addr), sizeof(addr)) < 0) { + std::string err = std::strerror(errno); + close(control_fd_); + control_fd_ = -1; + throw std::runtime_error("control bind (IPv4) failed: " + err); + } } - socklen_t len = sizeof(addr); - if (getsockname(control_fd_, reinterpret_cast(&addr), &len) < 0) { + + // Get the bound port (especially useful if local_control_port_ was 0) + struct sockaddr_storage bound_addr; + socklen_t len = sizeof(bound_addr); + if (getsockname(control_fd_, reinterpret_cast(&bound_addr), &len) < 0) { close(control_fd_); control_fd_ = -1; throw std::runtime_error("getsockname() failed: " + std::string(std::strerror(errno))); } - local_control_port_ = ntohs(addr.sin_port); + + if (bound_addr.ss_family == AF_INET6) { + local_control_port_ = ntohs(reinterpret_cast(&bound_addr)->sin6_port); + } else { + local_control_port_ = ntohs(reinterpret_cast(&bound_addr)->sin_port); + } + if (listen(control_fd_, 128) < 0) { close(control_fd_); control_fd_ = -1; @@ -1205,8 +1252,12 @@ void KVCacheManagerWithTransfer::ProcessPullStream( WriteExact(fd, &response, sizeof(response))); std::string peer_ip = GetPeerIp(fd); - std::string remote_data_endpoint = - peer_ip + ":" + std::to_string(req.consumer_data_port); + std::string remote_data_endpoint; + if (peer_ip.find(':') != std::string::npos) { + remote_data_endpoint = absl::StrCat("[", peer_ip, "]:", req.consumer_data_port); + } else { + remote_data_endpoint = absl::StrCat(peer_ip, ":", req.consumer_data_port); + } VLOG(1) << "ProcessPullStream (Hybrid Bridge) successfully acknowledged " "consumer. Intercepting and launching StartPushInternal to "