From ada247fedafce79d192b44477a4c843fd85958ac Mon Sep 17 00:00:00 2001 From: Seth Zegelstein Date: Fri, 10 Oct 2025 21:50:48 +0000 Subject: [PATCH 01/11] transport: Deprecate enforce_cst CUDA 11.3 released cuFlushGPUDirectRDMAWrites API which takes the place of the host transport enforce_cst api. NVSHMEM no longer supports CUDA 11, so these legacy API's can be removed. Signed-off-by: Seth Zegelstein --- src/host/proxy/proxy.cpp | 50 +++--------- .../internal/host_transport/transport.h | 1 - src/modules/transport/ibdevx/ibdevx.cpp | 41 ---------- src/modules/transport/ibgda/ibgda.cpp | 1 - src/modules/transport/ibrc/ibrc.cpp | 1 - src/modules/transport/libfabric/libfabric.cpp | 79 ------------------- src/modules/transport/libfabric/libfabric.h | 5 -- src/modules/transport/ucx/ucx.cpp | 62 --------------- test/unit/mem/transport/remote_unit_tests.cpp | 1 - 9 files changed, 11 insertions(+), 230 deletions(-) diff --git a/src/host/proxy/proxy.cpp b/src/host/proxy/proxy.cpp index 2fcaa7fc..a1504c7d 100644 --- a/src/host/proxy/proxy.cpp +++ b/src/host/proxy/proxy.cpp @@ -687,51 +687,23 @@ int process_channel_amo(proxy_state_t *state, proxy_channel_t *ch, int *is_proce } void enforce_cst(proxy_state_t *proxy_state) { -#if defined(NVSHMEM_X86_64) - nvshmemi_state_t *state = proxy_state->nvshmemi_state; -#endif - int status = 0; if (nvshmemi_options.BYPASS_FLUSH) return; - if (proxy_state->is_consistency_api_supported) { - if (CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER > proxy_state->gdr_device_native_ordering && - CUPFN(nvshmemi_cuda_syms, cuFlushGPUDirectRDMAWrites)) { - status = - CUPFN(nvshmemi_cuda_syms, - cuFlushGPUDirectRDMAWrites(CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TARGET_CURRENT_CTX, - CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER)); - /** We would want to use cudaFlushGPUDirectRDMAWritesToAllDevices when we enable - consistent access of data on any GPU (and not just self GPU) with - wait_until, quiet, barrier, etc. **/ - if (status != CUDA_SUCCESS) { - NVSHMEMI_ERROR_EXIT("cuFlushGPUDirectRDMAWrites() failed in the proxy thread \n"); - } - } - return; - } -#if defined(NVSHMEM_PPC64LE) - status = cudaEventRecord(proxy_state->cuev, proxy_state->stream); - if (unlikely(status != CUDA_SUCCESS)) { - NVSHMEMI_ERROR_EXIT("cuEventRecord() failed in the proxy thread \n"); - } -#elif defined(NVSHMEM_X86_64) - for (int i = 0; i < state->num_initialized_transports; i++) { - if (!((state->transport_bitmap) & (1 << i))) continue; - struct nvshmem_transport *tcurr = state->transports[i]; - if (!tcurr->host_ops.enforce_cst) continue; - - // assuming the transport is connected - IB RC - if (tcurr->attr & NVSHMEM_TRANSPORT_ATTR_CONNECTED) { - status = tcurr->host_ops.enforce_cst(tcurr); - if (status) { - NVSHMEMI_ERROR_PRINT("aborting due to error in progress_cst \n"); - exit(-1); - } + if (CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER > proxy_state->gdr_device_native_ordering && + CUPFN(nvshmemi_cuda_syms, cuFlushGPUDirectRDMAWrites)) { + status = + CUPFN(nvshmemi_cuda_syms, + cuFlushGPUDirectRDMAWrites(CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TARGET_CURRENT_CTX, + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER)); + /** We would want to use cudaFlushGPUDirectRDMAWritesToAllDevices when we enable + consistent access of data on any GPU (and not just self GPU) with + wait_until, quiet, barrier, etc. **/ + if (status != CUDA_SUCCESS) { + NVSHMEMI_ERROR_EXIT("cuFlushGPUDirectRDMAWrites() failed in the proxy thread \n"); } } -#endif } inline void quiet_ack_channels(proxy_state_t *proxy_state) { diff --git a/src/include/internal/host_transport/transport.h b/src/include/internal/host_transport/transport.h index f3fc7c14..f36b9959 100644 --- a/src/include/internal/host_transport/transport.h +++ b/src/include/internal/host_transport/transport.h @@ -148,7 +148,6 @@ struct nvshmem_transport_host_ops { fence_handle fence; quiet_handle quiet; put_signal_handle put_signal; - int (*enforce_cst)(struct nvshmem_transport *transport); int (*enforce_cst_at_target)(struct nvshmem_transport *transport); int (*add_device_remote_mem_handles)(struct nvshmem_transport *transport, int transport_stride, nvshmem_mem_handle_t *mem_handles, uint64_t heap_offset, diff --git a/src/modules/transport/ibdevx/ibdevx.cpp b/src/modules/transport/ibdevx/ibdevx.cpp index edc11086..b65d4740 100644 --- a/src/modules/transport/ibdevx/ibdevx.cpp +++ b/src/modules/transport/ibdevx/ibdevx.cpp @@ -1440,46 +1440,6 @@ int nvshmemt_ibdevx_amo(struct nvshmem_transport *tcurr, int pe, void *curetptr, return status; } -int nvshmemt_ibdevx_enforce_cst_at_target(struct nvshmem_transport *tcurr) { - nvshmemt_ib_common_state_t ibdevx_state = (nvshmemt_ib_common_state_t)tcurr->state; - struct ibdevx_ep *ep = (struct ibdevx_ep *)ibdevx_state->cst_ep; - struct ibdevx_rw_wqe *wqe; - - int status = 0; - - uintptr_t wqe_bb_idx_64 = ep->wqe_bb_idx; - uint32_t wqe_bb_idx_32 = ep->wqe_bb_idx; - size_t wqe_size; - - wqe = (struct ibdevx_rw_wqe *)((char *)ep->wq_buf + - ((wqe_bb_idx_64 % get_ibdevx_qp_depth(ibdevx_state)) - << NVSHMEMT_IBDEVX_WQE_BB_SHIFT)); - wqe_size = sizeof(struct ibdevx_rw_wqe); - memset(wqe, 0, sizeof(struct ibdevx_rw_wqe)); - - wqe->ctrl.fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; - wqe->ctrl.qpn_ds = - htobe32((uint32_t)(wqe_size / NVSHMEMT_IBDEVX_MLX5_SEND_WQE_DS) | ep->qpid << 8); - wqe->ctrl.opmod_idx_opcode = htobe32(MLX5_OPCODE_RDMA_READ | (wqe_bb_idx_32 << 8)); - - wqe->raddr.raddr = htobe64((uintptr_t)local_dummy_mr.mr->addr); - wqe->raddr.rkey = htobe32(local_dummy_mr.rkey); - - wqe->data.data_seg.byte_count = htobe32((uint32_t)4); - wqe->data.data_seg.lkey = htobe32(local_dummy_mr.lkey); - wqe->data.data_seg.addr = htobe64((uintptr_t)local_dummy_mr.mr->addr); - - assert(wqe_size <= MLX5_SEND_WQE_BB); - ep->wqe_bb_idx++; - nvshmemt_ibdevx_post_send(ep, (void *)wqe, 1); - - status = nvshmemt_ib_common_check_poll_avail(tcurr, ep, NVSHMEMT_IB_COMMON_WAIT_ALL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "check_poll failed \n"); - -out: - return status; -} - // Using common fence and quiet functions from transport_ib_common int nvshmemt_ibdevx_ep_create(struct ibdevx_ep **ep, int devid, nvshmem_transport_t t, @@ -1922,7 +1882,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.finalize = nvshmemt_ibdevx_finalize; transport->host_ops.show_info = nvshmemt_ibdevx_show_info; transport->host_ops.progress = nvshmemt_ibdevx_progress; - transport->host_ops.enforce_cst = nvshmemt_ibdevx_enforce_cst_at_target; transport->host_ops.put_signal = nvshmemt_put_signal; transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED; diff --git a/src/modules/transport/ibgda/ibgda.cpp b/src/modules/transport/ibgda/ibgda.cpp index 115f6b66..d73232ed 100644 --- a/src/modules/transport/ibgda/ibgda.cpp +++ b/src/modules/transport/ibgda/ibgda.cpp @@ -4903,7 +4903,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.amo = NULL; transport->host_ops.fence = NULL; transport->host_ops.quiet = NULL; - transport->host_ops.enforce_cst = NULL; transport->host_ops.add_device_remote_mem_handles = nvshmemt_ibgda_add_device_remote_mem_handles; transport->host_ops.put_signal = NULL; diff --git a/src/modules/transport/ibrc/ibrc.cpp b/src/modules/transport/ibrc/ibrc.cpp index f7c9ce06..b0fdddf1 100644 --- a/src/modules/transport/ibrc/ibrc.cpp +++ b/src/modules/transport/ibrc/ibrc.cpp @@ -1810,7 +1810,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.progress = nvshmemt_ibrc_progress; transport->host_ops.put_signal = nvshmemt_put_signal; - transport->host_ops.enforce_cst = nvshmemt_ibrc_enforce_cst_at_target; #if !defined(NVSHMEM_PPC64LE) && !defined(NVSHMEM_AARCH64) if (!use_gdrcopy) #endif diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index 70bce5a5..3b227b85 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -1155,71 +1155,6 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v return status; } -static int nvshmemt_libfabric_enforce_cst(struct nvshmem_transport *tcurr) { - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; - uint64_t num_retries = 0; - int status; - int target_ep; - int mype = tcurr->my_pe; - -#ifdef NVSHMEM_USE_GDRCOPY - if (use_gdrcopy) { - if (libfabric_state->provider != NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { - int temp; - nvshmemt_libfabric_memhandle_info_t *mem_handle_info; - - mem_handle_info = - (nvshmemt_libfabric_memhandle_info_t *)nvshmemt_mem_handle_cache_get_by_idx( - libfabric_state->cache, 0); - if (!mem_handle_info) { - goto skip; - } - gdrcopy_ftable.copy_from_mapping(mem_handle_info->mh, &temp, mem_handle_info->cpu_ptr, - sizeof(int)); - } - } - -skip: -#endif - - target_ep = mype * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - do { - struct fi_msg_rma msg; - struct iovec l_iov; - struct fi_rma_iov r_iov; - void *desc = libfabric_state->local_mr_desc[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; - uint64_t flags = 0; - - memset(&msg, 0, sizeof(struct fi_msg_rma)); - memset(&l_iov, 0, sizeof(struct iovec)); - memset(&r_iov, 0, sizeof(struct fi_rma_iov)); - - l_iov.iov_base = libfabric_state->local_mem_ptr; - l_iov.iov_len = 8; - - r_iov.addr = 0; // Zero offset - r_iov.len = 8; - r_iov.key = libfabric_state->local_mr_key[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; - - msg.msg_iov = &l_iov; - msg.desc = &desc; - msg.iov_count = 1; - msg.rma_iov = &r_iov; - msg.rma_iov_count = 1; - msg.context = NULL; - msg.data = 0; - - if (libfabric_state->prov_info->caps & FI_FENCE) flags |= FI_FENCE; - - status = - fi_readmsg(libfabric_state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX].endpoint, &msg, flags); - } while (try_again(tcurr, &status, &num_retries, - NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_ENFORCE_CST)); - - libfabric_state->eps[target_ep].submitted_ops++; - return status; -} - static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handle, nvshmem_transport_t t) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)t->state; @@ -1261,9 +1196,6 @@ static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handl max_reg = 1; for (int i = 0; i < max_reg; i++) { - if (libfabric_state->local_mr[i] == fabric_handle->hdls[i].mr) - libfabric_state->local_mr[i] = NULL; - int status = fi_close(&fabric_handle->hdls[i].mr->fid); if (status) { NVSHMEMI_WARN_PRINT("Error releasing mem handle idx %d (%d): %s\n", i, status, @@ -1443,15 +1375,6 @@ static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, v } while (curr_ptr < (char *)buf + length); } - if (libfabric_state->local_mr[0] == NULL && !local_only) { - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - libfabric_state->local_mr[i] = fabric_handle->hdls[i].mr; - libfabric_state->local_mr_key[i] = fabric_handle->hdls[i].key; - libfabric_state->local_mr_desc[i] = fabric_handle->hdls[i].local_desc; - } - libfabric_state->local_mem_ptr = buf; - } - out: if (status) { if (handle_info) { @@ -2185,8 +2108,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.finalize = nvshmemt_libfabric_finalize; transport->host_ops.show_info = nvshmemt_libfabric_show_info; transport->host_ops.progress = nvshmemt_libfabric_progress; - transport->host_ops.enforce_cst = nvshmemt_libfabric_enforce_cst; - transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED; transport->is_successfully_initialized = true; diff --git a/src/modules/transport/libfabric/libfabric.h b/src/modules/transport/libfabric/libfabric.h index 8a889a69..bb238a11 100644 --- a/src/modules/transport/libfabric/libfabric.h +++ b/src/modules/transport/libfabric/libfabric.h @@ -396,11 +396,6 @@ typedef struct { struct fid_domain *domain; struct fid_av *addresses[NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS]; nvshmemt_libfabric_endpoint_t *eps; - /* local_mr is used only for consistency ops. */ - struct fid_mr *local_mr[2]; - uint64_t local_mr_key[2]; - void *local_mr_desc[2]; - void *local_mem_ptr; nvshmemt_libfabric_domain_name_t *domain_names; int num_domains; nvshmemt_libfabric_provider provider; diff --git a/src/modules/transport/ucx/ucx.cpp b/src/modules/transport/ucx/ucx.cpp index 271ed69d..4959d0b4 100644 --- a/src/modules/transport/ucx/ucx.cpp +++ b/src/modules/transport/ucx/ucx.cpp @@ -1180,67 +1180,6 @@ int nvshmemt_ucx_finalize(nvshmem_transport_t transport) { return 0; } -int nvshmemt_ucx_enforce_cst_at_target(struct nvshmem_transport *tcurr) { - transport_ucx_state_t *ucx_state = (transport_ucx_state_t *)tcurr->state; - nvshmemt_ucx_mem_handle_info_t *mem_handle_info; - - mem_handle_info = - (nvshmemt_ucx_mem_handle_info_t *)nvshmemt_mem_handle_cache_get_by_idx(ucx_state->cache, 0); - - if (!mem_handle_info) return 0; -#ifdef NVSHMEM_USE_GDRCOPY - if (use_gdrcopy) { - int temp; - gdrcopy_ftable.copy_from_mapping(mem_handle_info->mh, &temp, mem_handle_info->cpu_ptr, - sizeof(int)); - return 0; - } -#endif - int mype = tcurr->my_pe; - int ep_index = (ucx_state->ep_count * mype + ucx_state->proxy_ep_idx); - ucp_ep_h ep = ucx_state->endpoints[ep_index]; - ucp_request_param_t param; - ucs_status_ptr_t ucs_ptr_rc = NULL; - ucs_status_t ucs_rc; - nvshmemt_ucx_mem_handle_t *mem_handle; - ucp_rkey_h rkey; - int local_int; - - mem_handle = mem_handle_info->mem_handle; - if (unlikely(mem_handle->ep_rkey_host == NULL)) { - ucs_rc = ucp_ep_rkey_unpack(ep, mem_handle->rkey_packed_buf, &mem_handle->ep_rkey_host); - if (ucs_rc != UCS_OK) { - NVSHMEMI_ERROR_EXIT("Unable to unpack rkey in UCS transport! Exiting.\n"); - } - } - rkey = mem_handle->ep_rkey_host; - - param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK; - param.cb.send = nvshmemt_ucx_send_request_cb; - - ucs_ptr_rc = - ucp_get_nbx(ep, &local_int, sizeof(int), (uint64_t)mem_handle_info->ptr, rkey, ¶m); - - /* Wait for completion of get. */ - if (ucs_ptr_rc != NULL) { - if (UCS_PTR_IS_ERR(ucs_ptr_rc)) { - NVSHMEMI_ERROR_PRINT("UCX CST request completed with error.\n"); - return NVSHMEMX_ERROR_INTERNAL; - } else { - do { - ucs_rc = ucp_request_check_status(ucs_ptr_rc); - ucp_worker_progress(ucx_state->worker_context); - } while (ucs_rc == UCS_INPROGRESS); - if (ucs_rc != UCS_OK) { - NVSHMEMI_ERROR_PRINT("UCX CST request completed with error.\n"); - return NVSHMEMX_ERROR_INTERNAL; - } - } - } - - return 0; -} - int nvshmemt_ucx_show_info(struct nvshmem_transport *transport, int style) { NVSHMEMI_ERROR_PRINT("UCX show info not implemented"); return 0; @@ -1446,7 +1385,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.finalize = nvshmemt_ucx_finalize; transport->host_ops.show_info = nvshmemt_ucx_show_info; transport->host_ops.progress = nvshmemt_ucx_progress; - transport->host_ops.enforce_cst = nvshmemt_ucx_enforce_cst_at_target; transport->host_ops.enforce_cst_at_target = NULL; transport->host_ops.put_signal = nvshmemt_put_signal; transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED; diff --git a/test/unit/mem/transport/remote_unit_tests.cpp b/test/unit/mem/transport/remote_unit_tests.cpp index 2d057e05..eb518858 100644 --- a/test/unit/mem/transport/remote_unit_tests.cpp +++ b/test/unit/mem/transport/remote_unit_tests.cpp @@ -258,7 +258,6 @@ nvshmem_transport_host_ops initialize_nvshmem_transport_host_ops() { .fence = NULL, .quiet = NULL, .put_signal = NULL, - .enforce_cst = NULL, .enforce_cst_at_target = NULL, .add_device_remote_mem_handles = &add_device_remote_mem_handles}; From 11c49f2e259d87f23092cd2eeb109153c7e49753 Mon Sep 17 00:00:00 2001 From: Seth Zegelstein Date: Mon, 20 Oct 2025 18:40:31 +0000 Subject: [PATCH 02/11] transport/libfabric: Rename is_proxy to qp_index The previous is_proxy variable equals qp_index. Change the name everywhere for consistency. Signed-off-by: Seth Zegelstein --- src/modules/transport/libfabric/libfabric.cpp | 58 +++++-------------- 1 file changed, 15 insertions(+), 43 deletions(-) diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index 3b227b85..34763dca 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -629,10 +629,9 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, static int nvshmemt_libfabric_quiet(struct nvshmem_transport *tcurr, int pe, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; nvshmemt_libfabric_endpoint_t *ep; - int is_proxy = qp_index != NVSHMEMX_QP_HOST; int status = 0; - if (is_proxy) { + if (qp_index) { ep = &libfabric_state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; } else { ep = &libfabric_state->eps[NVSHMEMT_LIBFABRIC_HOST_EP_IDX]; @@ -687,7 +686,7 @@ static int nvshmemt_libfabric_show_info(struct nvshmem_transport *transport, int static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, rma_verb_t verb, rma_memdesc_t *remote, rma_memdesc_t *local, - rma_bytesdesc_t bytesdesc, int is_proxy, + rma_bytesdesc_t bytesdesc, int qp_index, uint32_t *imm_data) { nvshmemt_libfabric_mem_handle_ep_t *remote_handle, *local_handle = NULL; void *local_mr_desc = NULL; @@ -707,12 +706,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, memset(&p_op_msg, 0, sizeof(struct fi_msg_rma)); memset(&p_op_r_iov, 0, sizeof(struct fi_rma_iov)); - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - + ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; ep = &libfabric_state->eps[ep_idx]; target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; @@ -824,13 +818,13 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, static int nvshmemt_libfabric_rma(struct nvshmem_transport *tcurr, int pe, rma_verb_t verb, rma_memdesc_t *remote, rma_memdesc_t *local, - rma_bytesdesc_t bytesdesc, int is_proxy) { - return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, is_proxy, NULL); + rma_bytesdesc_t bytesdesc, int qp_index) { + return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, qp_index, NULL); } static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int pe, void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, - amo_bytesdesc_t bytesdesc, int is_proxy) { + amo_bytesdesc_t bytesdesc, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_endpoint_t *ep; nvshmemt_libfabric_gdr_op_ctx_t *amo; @@ -838,12 +832,7 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p int target_ep, ep_idx; int status = 0; - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - + ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; ep = &libfabric_state->eps[ep_idx]; target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; @@ -884,7 +873,7 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, amo_bytesdesc_t bytesdesc, - int is_proxy) { + int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_mem_handle_ep_t *remote_handle = NULL, *local_handle = NULL; nvshmemt_libfabric_endpoint_t *ep; @@ -906,12 +895,7 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v memset(&fi_ret_iov, 0, sizeof(struct fi_ioc)); memset(&fi_remote_iov, 0, sizeof(struct fi_rma_ioc)); - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - + ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; ep = &libfabric_state->eps[ep_idx]; target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; @@ -1037,7 +1021,7 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, int pe, void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, - amo_bytesdesc_t bytesdesc, int is_proxy, + amo_bytesdesc_t bytesdesc, int qp_index, uint32_t sequence_count, uint16_t num_writes) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_endpoint_t *ep; @@ -1047,12 +1031,7 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in int target_ep, ep_idx; int status = 0; - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - + ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; ep = &libfabric_state->eps[ep_idx]; target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; @@ -1097,18 +1076,11 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v std::vector &write_local, std::vector &write_bytes_desc, amo_verb_t sig_verb, amo_memdesc_t *sig_target, - amo_bytesdesc_t sig_bytes_desc, int is_proxy) { + amo_bytesdesc_t sig_bytes_desc, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; int status; uint32_t sequence_count = 0; - int ep_idx; - - if (is_proxy) { - ep_idx = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; - } else { - ep_idx = NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - } - + int ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; nvshmemt_libfabric_endpoint_t &ep = libfabric_state->eps[ep_idx]; /* Get sequence number for this put-signal, with retry */ @@ -1134,7 +1106,7 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v for (size_t i = 0; i < write_remote.size(); i++) { status = nvshmemt_libfabric_rma_impl(tcurr, pe, write_verb, &write_remote[i], &write_local[i], - write_bytes_desc[i], is_proxy, &sequence_count); + write_bytes_desc[i], qp_index, &sequence_count); if (unlikely(status)) { NVSHMEMI_ERROR_PRINT( "Error in nvshmemt_put_signal_unordered, could not submit write #%lu\n", i); @@ -1144,7 +1116,7 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v assert(use_staged_atomics == true); status = nvshmemt_libfabric_gdr_signal(tcurr, pe, NULL, sig_verb, sig_target, sig_bytes_desc, - is_proxy, sequence_count, (uint16_t)write_remote.size()); + qp_index, sequence_count, (uint16_t)write_remote.size()); out: if (status) { NVSHMEMI_ERROR_PRINT( From 0775ad6085d4618aa0772a50303ae6f7133f78ca Mon Sep 17 00:00:00 2001 From: Seth Zegelstein Date: Thu, 9 Oct 2025 05:29:48 +0000 Subject: [PATCH 03/11] transport/libfabric: Optimize Progress Attempt to request FI_PROGRESS_AUTO to see if the libfabric provider supports it, if it doesn't fall back to FI_PROGRESS_MANUAL. FI_PROGRESS_AUTO means that we do not need to call into the progress engine for submitted operations to complete. This means that we can remove the host endpoint from the progress call, and we only need to progress the host endpoint when user calls nvshmem_quiet() from the host. This allows us to set the threading model as FI_THREAD_COMPELTION because the host only progress the host EP, and the proxy only progresses the proxy EP, leading to compliance with FI_THREAD_COMPLETION. An edge case exists here where the user calls nvshmem_quiet() on the host QP_IDX from a GPU kernel, but this is illegial because the user shouldn't be calling QP API's on QP's not provided to them via the qp creation API's. This patch should offer a performance improvement because it reduces the number of EP's that are progressed in the critical path, and it allows the libfabric provider to reduce locking b/c of threading model FI_THREAD_COMPLETION. Signed-off-by: Seth Zegelstein --- src/modules/transport/libfabric/libfabric.cpp | 278 +++++++++++------- src/modules/transport/libfabric/libfabric.h | 1 + 2 files changed, 175 insertions(+), 104 deletions(-) diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index 34763dca..d2187212 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -66,6 +66,8 @@ static bool use_gdrcopy = false; sizeof(nvshmemt_libfabric_gdr_op_ctx_t) - sizeof(struct fi_context2) - sizeof(fi_addr_t) static bool use_staged_atomics = false; +static bool use_auto_progress = false; + threadSafeOpQueue nvshmemtLibfabricOpQueue; std::recursive_mutex gdrRecvMutex; @@ -88,8 +90,8 @@ typedef enum { NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_ENFORCE_CST } nvshmemt_libfabric_try_again_call_site_t; -int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport); -int nvshmemt_libfabric_gdr_process_amos_ack(nvshmem_transport_t transport); +int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport, int qp_index); +int nvshmemt_libfabric_gdr_process_amos_ack(nvshmem_transport_t transport, int qp_index); int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t *ep, struct fi_cq_data_entry *entry, fi_addr_t *addr); @@ -163,91 +165,135 @@ static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transpo return status; } -static int nvshmemt_libfabric_process_completions(nvshmem_transport_t transport) { - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; +static int nvshmemt_libfabric_single_ep_progress(nvshmem_transport_t transport, + nvshmemt_libfabric_endpoint_t *ep) { + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; + char buf[MAX_COMPLETIONS_PER_CQ_POLL * sizeof(struct fi_cq_data_entry)]; + fi_addr_t src_addr[MAX_COMPLETIONS_PER_CQ_POLL]; + fi_addr_t *addr; + ssize_t qstatus; + struct fi_cq_data_entry *entry; + uint64_t cnt; int status = 0; - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - uint64_t cnt = fi_cntr_readerr(libfabric_state->eps[i].counter); - - if (cnt > 0) { - NVSHMEMI_WARN_PRINT("Nonzero error count progressing EP %d (%" PRIu64 ")\n", i, cnt); - - struct fi_cq_err_entry err; - memset(&err, 0, sizeof(struct fi_cq_err_entry)); - ssize_t nerr = fi_cq_readerr(libfabric_state->eps[i].cq, &err, 0); - if (nerr > 0) { - char str[100] = "\0"; - const char *err_str = fi_cq_strerror(libfabric_state->eps[i].cq, err.prov_errno, - err.err_data, str, 100); - NVSHMEMI_WARN_PRINT( - "CQ reported error (%d): %s\n\tProvider error: %s\n\tSupplemental error " - "info: %s\n", - err.err, fi_strerror(err.err), err_str ? err_str : "none", - strlen(str) ? str : "none"); - } else if (nerr == -FI_EAGAIN) { - NVSHMEMI_WARN_PRINT("fi_cq_readerr returned -FI_EAGAIN\n"); - } else { - NVSHMEMI_WARN_PRINT("fi_cq_readerr returned %zd: %s\n", nerr, - fi_strerror(-1 * nerr)); - } - return err.err; + cnt = fi_cntr_readerr(ep->counter); + if (cnt > 0) { + NVSHMEMI_WARN_PRINT("Nonzero error count progressing EP (%" PRIu64 ")\n", cnt); + struct fi_cq_err_entry err; + memset(&err, 0, sizeof(struct fi_cq_err_entry)); + ssize_t nerr = fi_cq_readerr(ep->cq, &err, 0); + + if (nerr > 0) { + char str[100] = "\0"; + const char *err_str = fi_cq_strerror(ep->cq, err.prov_errno, err.err_data, str, 100); + NVSHMEMI_WARN_PRINT( + "CQ reported error (%d): %s\n\tProvider error: %s\n\tSupplemental error " + "info: %s\n", + err.err, fi_strerror(err.err), err_str ? err_str : "none", + strlen(str) ? str : "none"); + } else if (nerr == -FI_EAGAIN) { + NVSHMEMI_WARN_PRINT("fi_cq_readerr returned -FI_EAGAIN\n"); + } else { + NVSHMEMI_WARN_PRINT("fi_cq_readerr returned %zd: %s\n", nerr, fi_strerror(-1 * nerr)); } + return NVSHMEMX_ERROR_INTERNAL; + } - { - char buf[MAX_COMPLETIONS_PER_CQ_POLL * sizeof(struct fi_cq_data_entry)]; - fi_addr_t src_addr[MAX_COMPLETIONS_PER_CQ_POLL]; - ssize_t qstatus; - nvshmemt_libfabric_endpoint_t *ep = &libfabric_state->eps[i]; - do { - qstatus = fi_cq_readfrom(ep->cq, buf, MAX_COMPLETIONS_PER_CQ_POLL, src_addr); - /* Note - EFA provider does not support selective completions */ - if (qstatus > 0) { - if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - struct fi_cq_data_entry *entry = (struct fi_cq_data_entry *)buf; - fi_addr_t *addr = src_addr; - for (int i = 0; i < qstatus; i++, entry++, addr++) { - status = nvshmemt_libfabric_gdr_process_completion(transport, ep, entry, - addr); - if (status) return NVSHMEMX_ERROR_INTERNAL; - } - } else { - NVSHMEMI_WARN_PRINT("Got %zd unexpected events on EP %d\n", qstatus, i); - } + do { + qstatus = fi_cq_readfrom(ep->cq, buf, MAX_COMPLETIONS_PER_CQ_POLL, src_addr); + /* Note - EFA provider does not support selective completions */ + if (qstatus > 0) { + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + entry = (struct fi_cq_data_entry *)buf; + addr = src_addr; + for (int i = 0; i < qstatus; i++, entry++, addr++) { + status = nvshmemt_libfabric_gdr_process_completion(transport, ep, entry, addr); + if (status) return NVSHMEMX_ERROR_INTERNAL; } - } while (qstatus > 0); - if (qstatus < 0 && qstatus != -FI_EAGAIN) { - NVSHMEMI_WARN_PRINT("Error progressing CQ (%zd): %s\n", qstatus, - fi_strerror(qstatus * -1)); - return NVSHMEMX_ERROR_INTERNAL; + } else { + NVSHMEMI_WARN_PRINT("Got %zd unexpected events on EP\n", qstatus); } } + } while (qstatus > 0); + if (qstatus < 0 && qstatus != -FI_EAGAIN) { + NVSHMEMI_WARN_PRINT("Error progressing CQ (%zd): %s\n", qstatus, fi_strerror(qstatus * -1)); + return NVSHMEMX_ERROR_INTERNAL; } + + return 0; +} + +static int nvshmemt_libfabric_auto_progress(nvshmem_transport_t transport, int qp_index) { + int status; + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; + nvshmemt_libfabric_endpoint_t *ep; + + if (qp_index) + ep = &state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; + else + ep = &state->eps[NVSHMEMT_LIBFABRIC_HOST_EP_IDX]; + + status = nvshmemt_libfabric_single_ep_progress(transport, ep); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error in progress: %d.\n", status); + +out: + return status; +} + +static int nvshmemt_libfabric_progress(nvshmem_transport_t transport, int qp_index); +static int nvshmemt_libfabric_auto_proxy_progress(nvshmem_transport_t transport) { + return nvshmemt_libfabric_progress(transport, NVSHMEMT_LIBFABRIC_PROXY_EP_IDX); +} + +static int nvshmemt_libfabric_manual_progress(nvshmem_transport_t transport) { + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; + int status; + for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { + status = nvshmemt_libfabric_single_ep_progress(transport, &state->eps[i]); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error in progress: %d.\n", + status); + } + +out: + return status; +} + +static int nvshmemt_libfabric_process_completions(nvshmem_transport_t transport, int qp_index) { + int status = 0; + + if (use_auto_progress) + status = nvshmemt_libfabric_auto_progress(transport, qp_index); + else + status = nvshmemt_libfabric_manual_progress(transport); + + nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - status = nvshmemt_libfabric_gdr_process_amos_ack(transport); + status = nvshmemt_libfabric_gdr_process_amos_ack(transport, qp_index); if (status) { return NVSHMEMX_ERROR_INTERNAL; } } - return 0; + + return status; } -static int nvshmemt_libfabric_progress(nvshmem_transport_t transport) { + +static int nvshmemt_libfabric_progress(nvshmem_transport_t transport, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; int status; - status = nvshmemt_libfabric_process_completions(transport); + status = nvshmemt_libfabric_process_completions(transport, qp_index); if (status) { return NVSHMEMX_ERROR_INTERNAL; } if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { if (gdrRecvMutex.try_lock()) { - status = nvshmemt_libfabric_gdr_process_amos(transport); - gdrRecvMutex.unlock(); + status = nvshmemt_libfabric_gdr_process_amos(transport, qp_index); if (status) { return NVSHMEMX_ERROR_INTERNAL; } + gdrRecvMutex.unlock(); } } @@ -255,6 +301,7 @@ static int nvshmemt_libfabric_progress(nvshmem_transport_t transport) { } static inline int try_again(nvshmem_transport_t transport, int *status, uint64_t *num_retries, + int qp_index, nvshmemt_libfabric_try_again_call_site_t call_site, bool completions_only = false) { if (likely(*status == 0)) { @@ -270,9 +317,9 @@ static inline int try_again(nvshmem_transport_t transport, int *status, uint64_t } (*num_retries)++; if (completions_only) { - *status = nvshmemt_libfabric_process_completions(transport); + *status = nvshmemt_libfabric_process_completions(transport, qp_index); } else { - *status = nvshmemt_libfabric_progress(transport); + *status = nvshmemt_libfabric_progress(transport, qp_index); } } @@ -303,7 +350,7 @@ int gdrcopy_amo_ack(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t status = fi_writedata(ep->endpoint, resp_op, 0, fi_mr_desc(libfabric_state->mr), imm_data, dest_addr, (uint64_t)libfabric_state->remote_addr_staged_amo_ack[pe], libfabric_state->rkey_staged_amo_ack[pe], &resp_op->ofi_context); - } while (try_again(transport, &status, &num_retries, + } while (try_again(transport, &status, &num_retries, ep->qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDRCOPY_AMO_ACK, true)); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to write atomic ack.\n"); @@ -414,7 +461,7 @@ int perform_gdrcopy_amo(nvshmem_transport_t transport, nvshmemt_libfabric_gdr_op do { status = fi_send(ep->endpoint, (void *)resp_op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, fi_mr_desc(libfabric_state->mr), src_addr, &resp_op->ofi_context); - } while (try_again(transport, &status, &num_retries, + } while (try_again(transport, &status, &num_retries, ep->qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_PERFORM_GDRCOPY_AMO_SEND, true)); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to respond to atomic request.\n"); @@ -477,7 +524,7 @@ int nvshmemt_libfabric_gdr_process_ack(nvshmem_transport_t transport, return 0; } -int nvshmemt_libfabric_gdr_process_amos_ack(nvshmem_transport_t transport) { +int nvshmemt_libfabric_gdr_process_amos_ack(nvshmem_transport_t transport, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_gdr_op_ctx_t *op; nvshmemt_libfabric_gdr_op_ctx_t *send_elems[2]; @@ -487,7 +534,7 @@ int nvshmemt_libfabric_gdr_process_amos_ack(nvshmem_transport_t transport) { do { status = nvshmemtLibfabricOpQueue.getNextAmoOps(send_elems, &op, NVSHMEMT_LIBFABRIC_RECV_TYPE_ACK); - } while (try_again(transport, &status, &num_retries, + } while (try_again(transport, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_PROCESS_AMOS_GET_NEXT_ACK, true)); num_retries = 0; @@ -507,7 +554,7 @@ int nvshmemt_libfabric_gdr_process_amos_ack(nvshmem_transport_t transport) { return status; } -int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport) { +int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport, int qp_index) { nvshmemt_libfabric_gdr_op_ctx_t *op; nvshmemt_libfabric_gdr_op_ctx_t *send_elems[2]; size_t num_retries = 0; @@ -517,7 +564,7 @@ int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport) { do { status = nvshmemtLibfabricOpQueue.getNextAmoOps(send_elems, &op, NVSHMEMT_LIBFABRIC_RECV_TYPE_NOT_ACK); - } while (try_again(transport, &status, &num_retries, + } while (try_again(transport, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_PROCESS_AMOS_GET_NEXT_NOT_ACK, true)); num_retries = 0; @@ -628,30 +675,23 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, static int nvshmemt_libfabric_quiet(struct nvshmem_transport *tcurr, int pe, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; + uint64_t completed; nvshmemt_libfabric_endpoint_t *ep; int status = 0; - if (qp_index) { + if (qp_index) ep = &libfabric_state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; - } else { + else ep = &libfabric_state->eps[NVSHMEMT_LIBFABRIC_HOST_EP_IDX]; - } - if (likely(libfabric_state->prov_info->domain_attr->control_progress == FI_PROGRESS_MANUAL) || - (libfabric_state->prov_info->domain_attr->data_progress == FI_PROGRESS_MANUAL) || - (use_staged_atomics == true) -#ifdef NVSHMEM_USE_GDRCOPY - || (use_gdrcopy == true) -#endif - ) { - uint64_t submitted, completed; + if (use_staged_atomics) { for (;;) { completed = fi_cntr_read(ep->counter); - submitted = ep->submitted_ops; - if (completed + ep->completed_staged_atomics == submitted) + if (completed + ep->completed_staged_atomics == ep->submitted_ops) { break; - else { - if (nvshmemt_libfabric_progress(tcurr)) { + } else { + status = nvshmemt_libfabric_progress(tcurr, qp_index); + if (status) { status = NVSHMEMX_ERROR_INTERNAL; break; } @@ -714,7 +754,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, nvshmemt_libfabric_gdr_op_ctx_t *gdr_ctx; do { status = nvshmemtLibfabricOpQueue.getNextSends((void **)(&gdr_ctx), 1); - } while (try_again(tcurr, &status, &num_retries, + } while (try_again(tcurr, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_RMA_IMPL_GET_NEXT_SENDS)); NVSHMEMI_NULL_ERROR_JMP(gdr_ctx, status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to get context buffer for put request.\n"); @@ -741,7 +781,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, status = fi_write(ep->endpoint, &p_buf->p_op.value, op_size, fi_mr_desc(libfabric_state->mr), target_ep, (uintptr_t)remote->ptr, remote_handle->key, context); - } while (try_again(tcurr, &status, &num_retries, + } while (try_again(tcurr, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_RMA_IMPL_OP_P_EFA)); } else { p_op_msg.msg_iov = &p_op_l_iov; @@ -766,7 +806,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, */ do { status = fi_writemsg(ep->endpoint, &p_op_msg, FI_INJECT); - } while (try_again(tcurr, &status, &num_retries, + } while (try_again(tcurr, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_RMA_IMPL_OP_P_NON_EFA)); } } else if (verb.desc == NVSHMEMI_OP_PUT) { @@ -784,7 +824,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, else status = fi_write(ep->endpoint, local->ptr, op_size, local_mr_desc, target_ep, remote_addr, remote_handle->key, context); - } while (try_again(tcurr, &status, &num_retries, + } while (try_again(tcurr, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_RMA_IMPL_OP_PUT)); } else if (verb.desc == NVSHMEMI_OP_G || verb.desc == NVSHMEMI_OP_GET) { assert( @@ -798,7 +838,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, do { status = fi_read(ep->endpoint, local->ptr, op_size, local_mr_desc, target_ep, remote_addr, remote_handle->key, context); - } while (try_again(tcurr, &status, &num_retries, + } while (try_again(tcurr, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_RMA_IMPL_OP_GET)); } else { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INVALID_VALUE, out, @@ -838,7 +878,7 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p do { status = nvshmemtLibfabricOpQueue.getNextSends((void **)(&amo), 1); - } while (try_again(transport, &status, &num_retries, + } while (try_again(transport, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_AMO_GET_NEXT_SENDS)); NVSHMEMI_NULL_ERROR_JMP(amo, status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to retrieve AMO operation."); @@ -857,7 +897,7 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p do { status = fi_send(ep->endpoint, (void *)amo, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, fi_mr_desc(libfabric_state->mr), target_ep, &amo->ofi_context); - } while (try_again(transport, &status, &num_retries, + } while (try_again(transport, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_AMO_SEND)); if (status) { @@ -1004,7 +1044,7 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v status = fi_fetch_atomicmsg(ep->endpoint, &amo_msg, &fi_ret_iov, &local_handle->local_desc, 1, FI_INJECT); } - } while (try_again(transport, &status, &num_retries, + } while (try_again(transport, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_AMO_ATOMICMSG)); if (status) goto out; // Status set by try_again @@ -1039,7 +1079,7 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in sizeof(nvshmemt_libfabric_gdr_signal_op_t)); do { status = nvshmemtLibfabricOpQueue.getNextSends((void **)(&context), 1); - } while (try_again(transport, &status, &num_retries, + } while (try_again(transport, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_SIGNAL_GET_NEXT_SENDS)); NVSHMEMI_NULL_ERROR_JMP(context, status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to retrieve signal operation buffer."); @@ -1057,7 +1097,7 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in do { status = fi_send(ep->endpoint, (void *)signal, sizeof(nvshmemt_libfabric_gdr_signal_op_t), fi_mr_desc(libfabric_state->mr), target_ep, &context->ofi_context); - } while (try_again(transport, &status, &num_retries, + } while (try_again(transport, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_SIGNAL_SEND)); if (status) { @@ -1093,7 +1133,7 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v sequence_count = seq_num; status = 0; } - } while (try_again(tcurr, &status, &num_retries, + } while (try_again(tcurr, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_PUT_SIGNAL_UNORDERED_SEQ)); if (unlikely(status)) { @@ -1568,6 +1608,7 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele "Unable to allocate array of endpoint names."); for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { + state->eps[i].qp_index = i; status = fi_endpoint(state->domain, state->prov_info, &state->eps[i].endpoint, NULL); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to allocate endpoint: %d: %s\n", status, @@ -1895,7 +1936,8 @@ static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { return 0; } -static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabric_state_t *state) { +static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabric_state_t *state, + struct nvshmemi_options_s *options) { struct fi_info info; struct fi_tx_attr tx_attr; struct fi_rx_attr rx_attr; @@ -1946,20 +1988,44 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr info.mode |= FI_CONTEXT2; } - /* Be thread safe at the level of the endpoint completion context. */ - domain_attr.threading = FI_THREAD_SAFE; - + ep_attr.type = FI_EP_RDM; /* Reliable datagrams */ /* Require completion RMA completion at target for correctness of quiet */ info.tx_attr->op_flags = FI_DELIVERY_COMPLETE; - ep_attr.type = FI_EP_RDM; // Reliable datagrams + /* nvshmemt_libfabric_auto_progress relaxes threading requirement */ + domain_attr.threading = FI_THREAD_COMPLETION; + info.domain_attr->data_progress = FI_PROGRESS_AUTO; status = fi_getinfo(FI_VERSION(NVSHMEMT_LIBFABRIC_MAJ_VER, NVSHMEMT_LIBFABRIC_MIN_VER), NULL, NULL, 0, &info, &returned_fabrics); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "No providers matched fi_getinfo query: %d: %s\n", status, - fi_strerror(status * -1)); + /* + * 1. Ensure that at least one fabric was returned + * 2. Make sure returned fabric matches the name of selected provider + * + * This has an assumption that the provided fabric option + * options.LIBFABRIC_PROVIDER will be a substr of the returned fabric + * name + */ + if (!status && strstr(returned_fabrics->fabric_attr->name, options->LIBFABRIC_PROVIDER)) { + use_auto_progress = true; + } else { + fi_freeinfo(returned_fabrics); + + /* + * Fallback to FI_PROGRESS_MANUAL path + * nvshmemt_libfabric_slow_progress requires FI_THREAD_SAFE + */ + domain_attr.threading = FI_THREAD_SAFE; + info.domain_attr->data_progress = FI_PROGRESS_MANUAL; + status = fi_getinfo(FI_VERSION(NVSHMEMT_LIBFABRIC_MAJ_VER, NVSHMEMT_LIBFABRIC_MIN_VER), + NULL, NULL, 0, &info, &returned_fabrics); + + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "No providers matched fi_getinfo query: %d: %s\n", status, + fi_strerror(status * -1)); + } + state->all_prov_info = returned_fabrics; for (current_fabric = returned_fabrics; current_fabric != NULL; current_fabric = current_fabric->next) { @@ -2079,7 +2145,6 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, transport->host_ops.quiet = nvshmemt_libfabric_quiet; transport->host_ops.finalize = nvshmemt_libfabric_finalize; transport->host_ops.show_info = nvshmemt_libfabric_show_info; - transport->host_ops.progress = nvshmemt_libfabric_progress; transport->attr = NVSHMEM_TRANSPORT_ATTR_CONNECTED; transport->is_successfully_initialized = true; @@ -2195,12 +2260,17 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, #undef NVSHMEMI_SET_ENV_VAR /* Prepare fabric state information. */ - status = nvshmemi_libfabric_init_state(transport, libfabric_state); + status = nvshmemi_libfabric_init_state(transport, libfabric_state, &options); if (status) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out_clean, "Failed to initialize the libfabric state.\n"); } + if (use_auto_progress) + transport->host_ops.progress = nvshmemt_libfabric_auto_proxy_progress; + else + transport->host_ops.progress = nvshmemt_libfabric_manual_progress; + *t = transport; out: if (status) { diff --git a/src/modules/transport/libfabric/libfabric.h b/src/modules/transport/libfabric/libfabric.h index bb238a11..eac754cf 100644 --- a/src/modules/transport/libfabric/libfabric.h +++ b/src/modules/transport/libfabric/libfabric.h @@ -196,6 +196,7 @@ typedef struct { nvshmemt_libfabric_endpoint_seq_counter_t put_signal_seq_counter; std::unordered_map> *proxy_put_signal_comp_map; + int qp_index; } nvshmemt_libfabric_endpoint_t; typedef struct nvshmemt_libfabric_gdr_send_p_op { From 92b0ac99a5203255701590892b842580d26b3e30 Mon Sep 17 00:00:00 2001 From: Seth Zegelstein Date: Thu, 9 Oct 2025 17:33:11 +0000 Subject: [PATCH 04/11] transport/libfabric: Implement multi-rail This change implements multi-rail support for the libfabric host proxy transport. The transport changes from having 1 domain with 2 EP's to having 1 host domain on NIC 1 and one proxy domain per NIC. Splitting the host EP and proxy EP into seperate domains was done for simplicity of the code. Every domain resource (including AV) was bound on a 1-1 basis per EP so this change should be a functional no-op. In the future when one implements the QP API on the libfabric host proxy transport, N EP's per domain can be easily extended on this. This code uses a round robin based load balancer to assign messages to NIC's. One NIC will be used for the entire operation call into the libfabric transport (including put-signal), but not including messages that are segmented due to size or MR boundaries. The number of NIC's (domains) per PE are limited by the size of the struct nvshmemt_libfabric_mem_handle_t. A new env variable NVSHMEM_LIBFABRIC_MAX_NIC_PER_PE controls the max number of NIC's per PE. Thank you Justin for contributing an initial implementation of multi-rail which I built on top of. Co-authored-by: Justin Chui Signed-off-by: Seth Zegelstein --- src/modules/transport/common/env_defs.h | 3 + src/modules/transport/libfabric/libfabric.cpp | 993 +++++++++--------- src/modules/transport/libfabric/libfabric.h | 45 +- 3 files changed, 544 insertions(+), 497 deletions(-) diff --git a/src/modules/transport/common/env_defs.h b/src/modules/transport/common/env_defs.h index 086dc016..f7d2cc38 100644 --- a/src/modules/transport/common/env_defs.h +++ b/src/modules/transport/common/env_defs.h @@ -98,6 +98,9 @@ NVSHMEMI_ENV_DEF(DISABLE_LOCAL_ONLY_PROXY, bool, false, NVSHMEMI_ENV_CAT_TRANSPO NVSHMEMI_ENV_DEF(LIBFABRIC_PROVIDER, string, "cxi", NVSHMEMI_ENV_CAT_TRANSPORT, "Set the feature set provider for the libfabric transport: cxi, efa, verbs") +NVSHMEMI_ENV_DEF(LIBFABRIC_MAX_NIC_PER_PE, int, 16, NVSHMEMI_ENV_CAT_TRANSPORT, + "Set the maximum number of NIC's per PE to use for libfabric provider") + #if defined(NVSHMEM_IBGDA_SUPPORT) || defined(NVSHMEM_ENV_ALL) /** GPU-initiated communication **/ NVSHMEMI_ENV_DEF(IBGDA_ENABLE_MULTI_PORT, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT, diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index d2187212..90ee4c09 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -65,10 +65,11 @@ static bool use_gdrcopy = false; #define NVSHMEM_STAGED_AMO_WIREDATA_SIZE \ sizeof(nvshmemt_libfabric_gdr_op_ctx_t) - sizeof(struct fi_context2) - sizeof(fi_addr_t) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + static bool use_staged_atomics = false; static bool use_auto_progress = false; -threadSafeOpQueue nvshmemtLibfabricOpQueue; std::recursive_mutex gdrRecvMutex; typedef enum { @@ -101,6 +102,27 @@ static nvshmemt_libfabric_imm_cq_data_hdr_t nvshmemt_get_write_with_imm_hdr(uint NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT); } +static inline nvshmemt_libfabric_endpoint_t *nvshmemt_libfabric_get_next_ep( + nvshmemt_libfabric_state_t *state, int qp_index) { + int selected_ep; + + if (qp_index == NVSHMEMX_QP_HOST) { + selected_ep = 0; + } else { + /* + * Return the current EP, and increment the next EP in round robin fashion + * between 1 and state->num_selected_domains - 1. state->cur_proxy_ep_index + * is initialized to 1. This round-robin goes through the proxy EP's and + * ignores the host EP. + */ + selected_ep = state->cur_proxy_ep_index; + state->cur_proxy_ep_index = (state->cur_proxy_ep_index + 1) % state->num_selected_domains; + if (!state->cur_proxy_ep_index) state->cur_proxy_ep_index = 1; + } + + return state->eps[selected_ep]; +} + static void nvshmemt_libfabric_put_signal_ack_completion(nvshmemt_libfabric_endpoint_t *ep, struct fi_cq_data_entry *entry) { uint32_t seq_num = entry->data & NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_MASK; @@ -118,6 +140,7 @@ static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transpo fi_addr_t *addr) { int status = 0; nvshmemt_libfabric_gdr_op_ctx_t *op; + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; /* Write w/imm doesn't have op->op_context, must be checked first */ if (entry->flags & FI_REMOTE_CQ_DATA) { @@ -142,19 +165,19 @@ static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transpo op->src_addr = *addr; if (entry->flags & FI_SEND) { - nvshmemtLibfabricOpQueue.putToSend(op); + state->op_queue[ep->domain_index]->putToSend(op); } else if (entry->flags & FI_RMA) { /* inlined p ops or atomic responses */ - nvshmemtLibfabricOpQueue.putToSend(op); + state->op_queue[ep->domain_index]->putToSend(op); } else if (op->type == NVSHMEMT_LIBFABRIC_MATCH) { /* Must happen after entry->flags & FI_SEND to avoid send completions */ status = nvshmemt_libfabric_put_signal_completion(transport, ep, entry, addr); } else if (entry->flags & FI_RECV) { op->ep = ep; if (op->type == NVSHMEMT_LIBFABRIC_ACK) { - nvshmemtLibfabricOpQueue.putToRecv(op, NVSHMEMT_LIBFABRIC_RECV_TYPE_ACK); + state->op_queue[ep->domain_index]->putToRecv(op, NVSHMEMT_LIBFABRIC_RECV_TYPE_ACK); } else { - nvshmemtLibfabricOpQueue.putToRecv(op, NVSHMEMT_LIBFABRIC_RECV_TYPE_NOT_ACK); + state->op_queue[ep->domain_index]->putToRecv(op, NVSHMEMT_LIBFABRIC_RECV_TYPE_NOT_ACK); } } else { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INVALID_VALUE, out, @@ -226,15 +249,20 @@ static int nvshmemt_libfabric_single_ep_progress(nvshmem_transport_t transport, static int nvshmemt_libfabric_auto_progress(nvshmem_transport_t transport, int qp_index) { int status; nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; - nvshmemt_libfabric_endpoint_t *ep; + int end_iter; - if (qp_index) - ep = &state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; - else - ep = &state->eps[NVSHMEMT_LIBFABRIC_HOST_EP_IDX]; + if (qp_index == NVSHMEMX_QP_HOST) { + end_iter = NVSHMEMX_QP_HOST + 1; + } else { + end_iter = state->eps.size(); + qp_index = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; + } - status = nvshmemt_libfabric_single_ep_progress(transport, ep); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error in progress: %d.\n", status); + for (int i = qp_index; i < end_iter; i++) { + status = nvshmemt_libfabric_single_ep_progress(transport, state->eps[i]); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error in progress: %d.\n", + status); + } out: return status; @@ -248,8 +276,8 @@ static int nvshmemt_libfabric_auto_proxy_progress(nvshmem_transport_t transport) static int nvshmemt_libfabric_manual_progress(nvshmem_transport_t transport) { nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; int status; - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - status = nvshmemt_libfabric_single_ep_progress(transport, &state->eps[i]); + for (size_t i = 0; i < state->eps.size(); i++) { + status = nvshmemt_libfabric_single_ep_progress(transport, state->eps[i]); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error in progress: %d.\n", status); } @@ -268,7 +296,8 @@ static int nvshmemt_libfabric_process_completions(nvshmem_transport_t transport, nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - status = nvshmemt_libfabric_gdr_process_amos_ack(transport, qp_index); + int progress_qp_index = (use_auto_progress ? qp_index : NVSHMEMX_QP_ALL); + status = nvshmemt_libfabric_gdr_process_amos_ack(transport, progress_qp_index); if (status) { return NVSHMEMX_ERROR_INTERNAL; } @@ -287,9 +316,11 @@ static int nvshmemt_libfabric_progress(nvshmem_transport_t transport, int qp_ind return NVSHMEMX_ERROR_INTERNAL; } + int progress_qp_index = (use_auto_progress ? qp_index : NVSHMEMX_QP_ALL); + if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { if (gdrRecvMutex.try_lock()) { - status = nvshmemt_libfabric_gdr_process_amos(transport, qp_index); + status = nvshmemt_libfabric_gdr_process_amos(transport, progress_qp_index); if (status) { return NVSHMEMX_ERROR_INTERNAL; } @@ -341,16 +372,18 @@ int gdrcopy_amo_ack(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t uint64_t num_retries = 0; int status; uint64_t imm_data = 0; + uint64_t rkey_index = pe * libfabric_state->num_selected_domains + ep->domain_index; resp_op = send_elems[0]; imm_data = (NVSHMEMT_LIBFABRIC_IMM_STAGED_ATOMIC_ACK << NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT) | sequence_count; do { - status = fi_writedata(ep->endpoint, resp_op, 0, fi_mr_desc(libfabric_state->mr), imm_data, - dest_addr, (uint64_t)libfabric_state->remote_addr_staged_amo_ack[pe], - libfabric_state->rkey_staged_amo_ack[pe], &resp_op->ofi_context); - } while (try_again(transport, &status, &num_retries, ep->qp_index, + status = fi_writedata( + ep->endpoint, resp_op, 0, fi_mr_desc(libfabric_state->mr[ep->domain_index]), imm_data, + dest_addr, (uint64_t)libfabric_state->remote_addr_staged_amo_ack[pe], + libfabric_state->rkey_staged_amo_ack[rkey_index], &resp_op->ofi_context); + } while (try_again(transport, &status, &num_retries, ep->domain_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDRCOPY_AMO_ACK, true)); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to write atomic ack.\n"); @@ -447,7 +480,7 @@ int perform_gdrcopy_amo(nvshmem_transport_t transport, nvshmemt_libfabric_gdr_op /* Post recv before posting TX operations to avoid deadlocks */ status = fi_recv(op->ep->endpoint, (void *)op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(libfabric_state->mr), FI_ADDR_UNSPEC, &op->ofi_context); + fi_mr_desc(libfabric_state->mr[op->ep->domain_index]), FI_ADDR_UNSPEC, &op->ofi_context); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to re-post recv.\n"); if (is_fetch_amo) { @@ -459,9 +492,10 @@ int perform_gdrcopy_amo(nvshmem_transport_t transport, nvshmemt_libfabric_gdr_op resp_op->type = NVSHMEMT_LIBFABRIC_ACK; do { - status = fi_send(ep->endpoint, (void *)resp_op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(libfabric_state->mr), src_addr, &resp_op->ofi_context); - } while (try_again(transport, &status, &num_retries, ep->qp_index, + status = fi_send(op->ep->endpoint, (void *)resp_op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, + fi_mr_desc(libfabric_state->mr[op->ep->domain_index]), op->src_addr, + &resp_op->ofi_context); + } while (try_again(transport, &status, &num_retries, op->ep->domain_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_PERFORM_GDRCOPY_AMO_SEND, true)); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to respond to atomic request.\n"); @@ -528,67 +562,100 @@ int nvshmemt_libfabric_gdr_process_amos_ack(nvshmem_transport_t transport, int q nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_gdr_op_ctx_t *op; nvshmemt_libfabric_gdr_op_ctx_t *send_elems[2]; - size_t num_retries = 0; + int end_iter; int status = 0; - do { + + if (qp_index == NVSHMEMX_QP_HOST) { + end_iter = NVSHMEMX_QP_HOST + 1; + } else if (qp_index == NVSHMEMX_QP_ALL) { + qp_index = 0; + end_iter = libfabric_state->eps.size(); + } else { + end_iter = libfabric_state->eps.size(); + qp_index = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; + } + + for (int i = qp_index; i < end_iter; i++) { + + size_t num_retries = 0; do { - status = nvshmemtLibfabricOpQueue.getNextAmoOps(send_elems, &op, - NVSHMEMT_LIBFABRIC_RECV_TYPE_ACK); - } while (try_again(transport, &status, &num_retries, qp_index, - NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_PROCESS_AMOS_GET_NEXT_ACK, - true)); - num_retries = 0; - - if (op) { - status = nvshmemt_libfabric_gdr_process_ack(transport, op); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to process atomic.\n"); - status = fi_recv(op->ep->endpoint, (void *)op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(libfabric_state->mr), FI_ADDR_UNSPEC, &op->ofi_context); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to re-post recv.\n"); - } - } while (op); + do { + status = libfabric_state->op_queue[i]->getNextAmoOps(send_elems, &op, + NVSHMEMT_LIBFABRIC_RECV_TYPE_ACK); + } while (try_again(transport, &status, &num_retries, i, + NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_PROCESS_AMOS_GET_NEXT_ACK, + true)); + num_retries = 0; + + if (op) { + status = nvshmemt_libfabric_gdr_process_ack(transport, op); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to process atomic.\n"); + status = fi_recv(op->ep->endpoint, (void *)op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, + fi_mr_desc(libfabric_state->mr[op->ep->domain_index]), FI_ADDR_UNSPEC, &op->ofi_context); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to re-post recv.\n"); + } + } while (op); + + } out: return status; } int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport, int qp_index) { + nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_gdr_op_ctx_t *op; nvshmemt_libfabric_gdr_op_ctx_t *send_elems[2]; size_t num_retries = 0; int status = 0; + int end_iter; + + if (qp_index == NVSHMEMX_QP_HOST) { + end_iter = NVSHMEMX_QP_HOST + 1; + } else if (qp_index == NVSHMEMX_QP_ALL) { + qp_index = 0; + end_iter = libfabric_state->eps.size(); + } else { + end_iter = libfabric_state->eps.size(); + qp_index = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; + } + + for (int i = qp_index; i < end_iter; i++) { - do { do { - status = nvshmemtLibfabricOpQueue.getNextAmoOps(send_elems, &op, - NVSHMEMT_LIBFABRIC_RECV_TYPE_NOT_ACK); - } while (try_again(transport, &status, &num_retries, qp_index, - NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_PROCESS_AMOS_GET_NEXT_NOT_ACK, - true)); - num_retries = 0; - - if (op) { - if (op->type == NVSHMEMT_LIBFABRIC_SEND) { - assert(send_elems[0] != NULL); - assert(send_elems[1] != NULL); - status = nvshmemt_libfabric_gdr_process_amo(transport, op, send_elems, - NVSHMEM_STAGED_AMO_SEQ_NUM); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to process atomic.\n"); - /* Reposts recv in perform_gdrcopy_amo() */ - } else if (op->type == NVSHMEMT_LIBFABRIC_MATCH) { - assert(send_elems[0] != NULL); - assert(send_elems[1] != NULL); - status = nvshmemt_libfabric_gdr_process_amo(transport, op, send_elems, - op->send_amo.sequence_count); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to process atomic.\n"); - /* Reposts recv in perform_gdrcopy_amo() */ + do { + status = libfabric_state->op_queue[i]->getNextAmoOps(send_elems, &op, + NVSHMEMT_LIBFABRIC_RECV_TYPE_NOT_ACK); + } while (try_again(transport, &status, &num_retries, i, + NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_PROCESS_AMOS_GET_NEXT_NOT_ACK, + true)); + num_retries = 0; + + if (op) { + if (op->type == NVSHMEMT_LIBFABRIC_SEND) { + assert(send_elems[0] != NULL); + assert(send_elems[1] != NULL); + status = nvshmemt_libfabric_gdr_process_amo(transport, op, send_elems, + NVSHMEM_STAGED_AMO_SEQ_NUM); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to process atomic.\n"); + /* Reposts recv in perform_gdrcopy_amo() */ + } else if (op->type == NVSHMEMT_LIBFABRIC_MATCH) { + assert(send_elems[0] != NULL); + assert(send_elems[1] != NULL); + status = nvshmemt_libfabric_gdr_process_amo(transport, op, send_elems, + op->send_amo.sequence_count); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to process atomic.\n"); + /* Reposts recv in perform_gdrcopy_amo() */ + } } - } - } while (op); + } while (op); + + } + out: return status; } @@ -618,6 +685,7 @@ nvshmemt_libfabric_gdr_op_ctx_t *nvshmemt_inplace_copy_sig_op_to_gdr_op( int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t *ep, struct fi_cq_data_entry *entry, fi_addr_t *addr) { + nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_gdr_signal_op *sig_op = NULL; nvshmemt_libfabric_gdr_op_ctx_t *op = NULL; bool is_write_comp = entry->flags & FI_REMOTE_CQ_DATA; @@ -665,7 +733,7 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, op = iter->second.first; } - nvshmemtLibfabricOpQueue.putToRecv(op, NVSHMEMT_LIBFABRIC_RECV_TYPE_NOT_ACK); + libfabric_state->op_queue[ep->domain_index]->putToRecv(op, NVSHMEMT_LIBFABRIC_RECV_TYPE_NOT_ACK); ep->proxy_put_signal_comp_map->erase(iter); } @@ -674,38 +742,47 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, } static int nvshmemt_libfabric_quiet(struct nvshmem_transport *tcurr, int pe, int qp_index) { - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)tcurr->state; uint64_t completed; - nvshmemt_libfabric_endpoint_t *ep; + bool all_nics_quieted; int status = 0; + int end_iter; - if (qp_index) - ep = &libfabric_state->eps[NVSHMEMT_LIBFABRIC_PROXY_EP_IDX]; - else - ep = &libfabric_state->eps[NVSHMEMT_LIBFABRIC_HOST_EP_IDX]; + if (qp_index == NVSHMEMX_QP_HOST) { + end_iter = NVSHMEMX_QP_HOST + 1; + } else { + end_iter = state->eps.size(); + qp_index = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; + } if (use_staged_atomics) { for (;;) { - completed = fi_cntr_read(ep->counter); - if (completed + ep->completed_staged_atomics == ep->submitted_ops) { - break; - } else { - status = nvshmemt_libfabric_progress(tcurr, qp_index); - if (status) { - status = NVSHMEMX_ERROR_INTERNAL; - break; + all_nics_quieted = true; + for (int i = qp_index; i < end_iter; i++) { + completed = fi_cntr_read(state->eps[i]->counter) + + state->eps[i]->completed_staged_atomics; + if (state->eps[i]->submitted_ops != completed) { + all_nics_quieted = false; + if (nvshmemt_libfabric_progress(tcurr, qp_index)) { + status = NVSHMEMX_ERROR_INTERNAL; + break; + } } } + if (status || all_nics_quieted) break; } } else { - status = fi_cntr_wait(ep->counter, ep->submitted_ops, NVSHMEMT_LIBFABRIC_QUIET_TIMEOUT_MS); - if (status) { - /* note - Status is negative for this function in error cases but - * fi_strerror only accepts positive values. - */ - NVSHMEMI_ERROR_PRINT("Error in quiet operation (%d): %s.\n", status, - fi_strerror(status * -1)); - status = NVSHMEMX_ERROR_INTERNAL; + for (int i = qp_index; i < end_iter; i++) { + status = fi_cntr_wait(state->eps[i]->counter, state->eps[i]->submitted_ops, + NVSHMEMT_LIBFABRIC_QUIET_TIMEOUT_MS); + if (status) { + /* note - Status is negative for this function in error cases but + * fi_strerror only accepts positive values. + */ + NVSHMEMI_ERROR_PRINT("Error in quiet operation (%d): %s.\n", status, + fi_strerror(status * -1)); + status = NVSHMEMX_ERROR_INTERNAL; + } } } @@ -726,34 +803,33 @@ static int nvshmemt_libfabric_show_info(struct nvshmem_transport *transport, int static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, rma_verb_t verb, rma_memdesc_t *remote, rma_memdesc_t *local, - rma_bytesdesc_t bytesdesc, int qp_index, - uint32_t *imm_data) { + rma_bytesdesc_t bytesdesc, int qp_index, uint32_t *imm_data, + nvshmemt_libfabric_endpoint_t *ep) { nvshmemt_libfabric_mem_handle_ep_t *remote_handle, *local_handle = NULL; void *local_mr_desc = NULL; nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; struct iovec p_op_l_iov; struct fi_msg_rma p_op_msg; struct fi_rma_iov p_op_r_iov; - nvshmemt_libfabric_endpoint_t *ep; size_t op_size; uint64_t num_retries = 0; int status = 0; int target_ep; - int ep_idx = 0; void *context = NULL; memset(&p_op_l_iov, 0, sizeof(struct iovec)); memset(&p_op_msg, 0, sizeof(struct fi_msg_rma)); memset(&p_op_r_iov, 0, sizeof(struct fi_rma_iov)); - ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - ep = &libfabric_state->eps[ep_idx]; - target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; + /* put_signal passes in EP to ensure that both operations go through same EP */ + if (!ep) ep = nvshmemt_libfabric_get_next_ep(libfabric_state, qp_index); + + target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { nvshmemt_libfabric_gdr_op_ctx_t *gdr_ctx; do { - status = nvshmemtLibfabricOpQueue.getNextSends((void **)(&gdr_ctx), 1); + status = libfabric_state->op_queue[ep->domain_index]->getNextSends((void **)(&gdr_ctx), 1); } while (try_again(tcurr, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_RMA_IMPL_GET_NEXT_SENDS)); NVSHMEMI_NULL_ERROR_JMP(gdr_ctx, status, NVSHMEMX_ERROR_INTERNAL, out, @@ -762,12 +838,12 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, /* local->handle may be NULL for small operations (P ops) sent by value/inline */ if (likely(local->handle != NULL)) { - local_handle = &((nvshmemt_libfabric_mem_handle_t *)local->handle)->hdls[ep_idx]; + local_handle = &((nvshmemt_libfabric_mem_handle_t *)local->handle)->hdls[ep->domain_index]; local_mr_desc = local_handle->local_desc; } } - remote_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->handle)->hdls[ep_idx]; + remote_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->handle)->hdls[ep->domain_index]; op_size = bytesdesc.elembytes * bytesdesc.nelems; if (verb.desc == NVSHMEMI_OP_P) { @@ -779,7 +855,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, do { p_buf->p_op.value = *(uint64_t *)local->ptr; status = fi_write(ep->endpoint, &p_buf->p_op.value, op_size, - fi_mr_desc(libfabric_state->mr), target_ep, + fi_mr_desc(libfabric_state->mr[ep->domain_index]), target_ep, (uintptr_t)remote->ptr, remote_handle->key, context); } while (try_again(tcurr, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_RMA_IMPL_OP_P_EFA)); @@ -794,7 +870,8 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, p_op_l_iov.iov_base = local->ptr; p_op_l_iov.iov_len = op_size; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_VIRT_ADDR) + if (libfabric_state->prov_infos[ep->domain_index]->domain_attr->mr_mode & + FI_MR_VIRT_ADDR) p_op_r_iov.addr = (uintptr_t)remote->ptr; else p_op_r_iov.addr = (uintptr_t)remote->offset; @@ -811,7 +888,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, } } else if (verb.desc == NVSHMEMI_OP_PUT) { uintptr_t remote_addr; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_VIRT_ADDR) + if (libfabric_state->prov_infos[ep->domain_index]->domain_attr->mr_mode & FI_MR_VIRT_ADDR) remote_addr = (uintptr_t)remote->ptr; else remote_addr = (uintptr_t)remote->offset; @@ -830,7 +907,7 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, assert( !imm_data); // Write w/ imm not suppored with NVSHMEMI_OP_G/GET on Libfabric transport uintptr_t remote_addr; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_VIRT_ADDR) + if (libfabric_state->prov_infos[ep->domain_index]->domain_attr->mr_mode & FI_MR_VIRT_ADDR) remote_addr = (uintptr_t)remote->ptr; else remote_addr = (uintptr_t)remote->offset; @@ -859,7 +936,8 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, static int nvshmemt_libfabric_rma(struct nvshmem_transport *tcurr, int pe, rma_verb_t verb, rma_memdesc_t *remote, rma_memdesc_t *local, rma_bytesdesc_t bytesdesc, int qp_index) { - return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, qp_index, NULL); + return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, qp_index, NULL, + NULL); } static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int pe, void *curetptr, @@ -869,15 +947,14 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p nvshmemt_libfabric_endpoint_t *ep; nvshmemt_libfabric_gdr_op_ctx_t *amo; uint64_t num_retries = 0; - int target_ep, ep_idx; + int target_ep; int status = 0; - ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - ep = &libfabric_state->eps[ep_idx]; - target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; + ep = nvshmemt_libfabric_get_next_ep(libfabric_state, qp_index); + target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; do { - status = nvshmemtLibfabricOpQueue.getNextSends((void **)(&amo), 1); + status = libfabric_state->op_queue[ep->domain_index]->getNextSends((void **)(&amo), 1); } while (try_again(transport, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_AMO_GET_NEXT_SENDS)); NVSHMEMI_NULL_ERROR_JMP(amo, status, NVSHMEMX_ERROR_INTERNAL, out, @@ -896,7 +973,8 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p num_retries = 0; do { status = fi_send(ep->endpoint, (void *)amo, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(libfabric_state->mr), target_ep, &amo->ofi_context); + fi_mr_desc(libfabric_state->mr[ep->domain_index]), target_ep, + &amo->ofi_context); } while (try_again(transport, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_AMO_SEND)); @@ -927,7 +1005,6 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v uint64_t num_retries = 0; int target_ep; int status = 0; - int ep_idx; memset(&amo_msg, 0, sizeof(struct fi_msg_atomic)); memset(&fi_local_iov, 0, sizeof(struct fi_ioc)); @@ -935,14 +1012,14 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v memset(&fi_ret_iov, 0, sizeof(struct fi_ioc)); memset(&fi_remote_iov, 0, sizeof(struct fi_rma_ioc)); - ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - ep = &libfabric_state->eps[ep_idx]; - target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; + ep = nvshmemt_libfabric_get_next_ep(libfabric_state, qp_index); + target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; remote_handle = - &((nvshmemt_libfabric_mem_handle_t *)remote->remote_memdesc.handle)->hdls[ep_idx]; + &((nvshmemt_libfabric_mem_handle_t *)remote->remote_memdesc.handle)->hdls[ep->domain_index]; if (verb.desc > NVSHMEMI_AMO_END_OF_NONFETCH) { - local_handle = &((nvshmemt_libfabric_mem_handle_t *)remote->ret_handle)->hdls[ep_idx]; + local_handle = + &((nvshmemt_libfabric_mem_handle_t *)remote->ret_handle)->hdls[ep->domain_index]; } if (bytesdesc.elembytes == 8) { @@ -1009,7 +1086,7 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v amo_msg.addr = target_ep; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_VIRT_ADDR) + if (libfabric_state->prov_infos[ep->domain_index]->domain_attr->mr_mode & FI_MR_VIRT_ADDR) fi_remote_iov.addr = (uintptr_t)remote->remote_memdesc.ptr; else fi_remote_iov.addr = (uintptr_t)remote->remote_memdesc.offset; @@ -1062,23 +1139,22 @@ static int nvshmemt_libfabric_amo(struct nvshmem_transport *transport, int pe, v static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, int pe, void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, amo_bytesdesc_t bytesdesc, int qp_index, - uint32_t sequence_count, uint16_t num_writes) { + uint32_t sequence_count, uint16_t num_writes, + nvshmemt_libfabric_endpoint_t *ep) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; - nvshmemt_libfabric_endpoint_t *ep; + nvshmemt_libfabric_gdr_op_ctx_t *context; nvshmemt_libfabric_gdr_signal_op_t *signal; uint64_t num_retries = 0; - int target_ep, ep_idx; + int target_ep; int status = 0; - ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - ep = &libfabric_state->eps[ep_idx]; - target_ep = pe * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS + ep_idx; + target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; static_assert(sizeof(nvshmemt_libfabric_gdr_op_ctx) >= sizeof(nvshmemt_libfabric_gdr_signal_op_t)); do { - status = nvshmemtLibfabricOpQueue.getNextSends((void **)(&context), 1); + status = libfabric_state->op_queue[ep->domain_index]->getNextSends((void **)(&context), 1); } while (try_again(transport, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_SIGNAL_GET_NEXT_SENDS)); NVSHMEMI_NULL_ERROR_JMP(context, status, NVSHMEMX_ERROR_INTERNAL, out, @@ -1096,7 +1172,7 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in num_retries = 0; do { status = fi_send(ep->endpoint, (void *)signal, sizeof(nvshmemt_libfabric_gdr_signal_op_t), - fi_mr_desc(libfabric_state->mr), target_ep, &context->ofi_context); + fi_mr_desc(libfabric_state->mr[ep->domain_index]), target_ep, &context->ofi_context); } while (try_again(transport, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_SIGNAL_SEND)); @@ -1117,16 +1193,15 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v std::vector &write_bytes_desc, amo_verb_t sig_verb, amo_memdesc_t *sig_target, amo_bytesdesc_t sig_bytes_desc, int qp_index) { - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; + nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)tcurr->state; int status; uint32_t sequence_count = 0; - int ep_idx = qp_index ? NVSHMEMT_LIBFABRIC_PROXY_EP_IDX : NVSHMEMT_LIBFABRIC_HOST_EP_IDX; - nvshmemt_libfabric_endpoint_t &ep = libfabric_state->eps[ep_idx]; + nvshmemt_libfabric_endpoint_t *ep = nvshmemt_libfabric_get_next_ep(state, qp_index); /* Get sequence number for this put-signal, with retry */ uint64_t num_retries = 0; do { - int32_t seq_num = ep.put_signal_seq_counter.next_seq_num(); + int32_t seq_num = ep->put_signal_seq_counter.next_seq_num(); if (seq_num < 0) { status = -EAGAIN; } else { @@ -1146,7 +1221,7 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v for (size_t i = 0; i < write_remote.size(); i++) { status = nvshmemt_libfabric_rma_impl(tcurr, pe, write_verb, &write_remote[i], &write_local[i], - write_bytes_desc[i], qp_index, &sequence_count); + write_bytes_desc[i], qp_index, &sequence_count, ep); if (unlikely(status)) { NVSHMEMI_ERROR_PRINT( "Error in nvshmemt_put_signal_unordered, could not submit write #%lu\n", i); @@ -1155,8 +1230,9 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v } assert(use_staged_atomics == true); - status = nvshmemt_libfabric_gdr_signal(tcurr, pe, NULL, sig_verb, sig_target, sig_bytes_desc, - qp_index, sequence_count, (uint16_t)write_remote.size()); + status = + nvshmemt_libfabric_gdr_signal(tcurr, pe, NULL, sig_verb, sig_target, sig_bytes_desc, + qp_index, sequence_count, (uint16_t)write_remote.size(), ep); out: if (status) { NVSHMEMI_ERROR_PRINT( @@ -1172,7 +1248,7 @@ static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handl nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)t->state; nvshmemt_libfabric_mem_handle_t *fabric_handle; void *curr_ptr; - int max_reg, status = 0; + int status = 0; assert(mem_handle != NULL); fabric_handle = (nvshmemt_libfabric_mem_handle_t *)mem_handle; @@ -1202,15 +1278,10 @@ static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handl } } - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_ENDPOINT) - max_reg = NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - else - max_reg = 1; - - for (int i = 0; i < max_reg; i++) { + for (size_t i = 0; i < libfabric_state->domains.size(); i++) { int status = fi_close(&fabric_handle->hdls[i].mr->fid); if (status) { - NVSHMEMI_WARN_PRINT("Error releasing mem handle idx %d (%d): %s\n", i, status, + NVSHMEMI_WARN_PRINT("Error releasing mem handle idx %zu (%d): %s\n", i, status, fi_strerror(status * -1)); } } @@ -1219,6 +1290,7 @@ static int nvshmemt_libfabric_release_mem_handle(nvshmem_mem_handle_t *mem_handl return status; } +static_assert(sizeof(nvshmemt_libfabric_mem_handle_t) < sizeof(nvshmem_mem_handle_t)); static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, void *buf, size_t length, nvshmem_transport_t t, bool local_only) { @@ -1254,6 +1326,7 @@ static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, v assert(mem_handle != NULL); fabric_handle = (nvshmemt_libfabric_mem_handle_t *)mem_handle; + fabric_handle->buf = buf; status = cudaPointerGetAttributes(&attr, buf); if (status != cudaSuccess) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, @@ -1284,40 +1357,15 @@ static int nvshmemt_libfabric_get_mem_handle(nvshmem_mem_handle_t *mem_handle, v mr_attr.iface = FI_HMEM_SYSTEM; } - fabric_handle->buf = buf; - if (libfabric_state->prov_info->domain_attr->mr_mode & FI_MR_ENDPOINT) { - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - status = - fi_mr_regattr(libfabric_state->domain, &mr_attr, 0, &fabric_handle->hdls[i].mr); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Error registering memory region: %s\n", - fi_strerror(status * -1)); - - status = - fi_mr_bind(fabric_handle->hdls[i].mr, &libfabric_state->eps[i].endpoint->fid, 0); - - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Error binding MR to EP %d: %s\n", i, fi_strerror(status * -1)); - - status = fi_mr_enable(fabric_handle->hdls[i].mr); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error enabling MR: %s\n", - fi_strerror(status * -1)); - - fabric_handle->hdls[i].key = fi_mr_key(fabric_handle->hdls[i].mr); - fabric_handle->hdls[i].local_desc = fi_mr_desc(fabric_handle->hdls[i].mr); - } - } else { + for (size_t i = 0; i < libfabric_state->domains.size(); i++) { struct fid_mr *mr; - - status = fi_mr_regattr(libfabric_state->domain, &mr_attr, 0, &mr); + status = fi_mr_regattr(libfabric_state->domains[i], &mr_attr, 0, &mr); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Error registering memory region: %s\n", fi_strerror(status * -1)); - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - fabric_handle->hdls[i].mr = mr; - fabric_handle->hdls[i].key = fi_mr_key(mr); - fabric_handle->hdls[i].local_desc = fi_mr_desc(mr); - } + fabric_handle->hdls[i].mr = mr; + fabric_handle->hdls[i].key = fi_mr_key(mr); + fabric_handle->hdls[i].local_desc = fi_mr_desc(mr); } if (!local_only && libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { @@ -1457,174 +1505,186 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)t->state; nvshmemt_libfabric_ep_name_t *all_ep_names = NULL; nvshmemt_libfabric_ep_name_t *local_ep_names = NULL; - struct fi_info *current_fabric; + struct fi_info *current_info; + struct fid_fabric *fabric; + struct fid_domain *domain; + struct fid_av *address; + struct fid_mr *mr; struct fi_av_attr av_attr; struct fi_cq_attr cq_attr; struct fi_cntr_attr cntr_attr; size_t ep_namelen = NVSHMEMT_LIBFABRIC_EP_LEN; int status = 0; int total_num_eps; - size_t num_recvs_per_pe = 0; + size_t num_recvs_per_ep = 0; int n_pes = t->n_pes; - - if (state->eps) { - NVSHMEMI_WARN_PRINT( - "Device already selected. libfabric only supports one NIC per PE and doesn't support " - "additional QPs.\n"); - goto out_already_connected; + size_t num_sends; + size_t num_recvs; + size_t elem_size; + uint64_t flags; + state->num_selected_devs = MIN(num_selected_devs, state->max_nic_per_pe); + + if (state->eps.size()) { + NVSHMEMI_ERROR_PRINT("PE has previously called connect_endpoints()\n"); + return NVSHMEMX_ERROR_INTERNAL; } - state->eps = (nvshmemt_libfabric_endpoint_t *)calloc(NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS, - sizeof(nvshmemt_libfabric_endpoint_t)); - NVSHMEMI_NULL_ERROR_JMP(state->eps, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, - "Unable to allocate EPs."); - - current_fabric = state->all_prov_info; - do { - if (!strncmp(current_fabric->nic->device_attr->name, - state->domain_names[selected_dev_ids[0]].name, - NVSHMEMT_LIBFABRIC_DOMAIN_LEN)) { - break; - } - current_fabric = current_fabric->next; - } while (current_fabric != NULL); - NVSHMEMI_NULL_ERROR_JMP(current_fabric, status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to find the selected fabric.\n"); - - state->prov_info = fi_dupinfo(current_fabric); - - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA && - strcmp(state->prov_info->fabric_attr->name, "efa-direct")) + if (state->num_selected_devs > NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE) { + state->num_selected_devs = NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE; NVSHMEMI_WARN_PRINT( - "Libfabric transport is using efa fabric instead of efa-direct, " - "use libfabric v2.1.0 or newer for improved performance\n"); - - status = fi_fabric(state->prov_info->fabric_attr, &state->fabric, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to allocate fabric: %d: %s\n", status, fi_strerror(status * -1)); - - status = fi_domain(state->fabric, state->prov_info, &state->domain, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to allocate domain: %d: %s\n", status, fi_strerror(status * -1)); - - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - state->num_sends = current_fabric->tx_attr->size * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - state->num_recvs = current_fabric->rx_attr->size * NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - size_t elem_size = sizeof(nvshmemt_libfabric_gdr_op_ctx_t); - - num_recvs_per_pe = state->num_recvs / NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - - state->recv_buf = calloc(state->num_sends + state->num_recvs, elem_size); - NVSHMEMI_NULL_ERROR_JMP(state->recv_buf, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, - "Unable to allocate EFA msg buffer.\n"); - state->send_buf = (char *)state->recv_buf + (elem_size * state->num_recvs); - - status = fi_mr_reg(state->domain, state->recv_buf, - (state->num_sends + state->num_recvs) * elem_size, FI_SEND | FI_RECV, 0, - 0, 0, &state->mr, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to register EFA msg buffer: %d: %s\n", status, - fi_strerror(status * -1)); - - nvshmemtLibfabricOpQueue.putToSendBulk((char *)state->send_buf, elem_size, - state->num_sends); - } - - t->max_op_len = state->prov_info->ep_attr->max_msg_size; - av_attr.type = FI_AV_TABLE; - av_attr.rx_ctx_bits = 0; - av_attr.count = NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS * n_pes; - av_attr.ep_per_node = NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; - av_attr.name = NULL; - av_attr.map_addr = NULL; - av_attr.flags = 0; - - /* Note - This is needed because EFA will only bind AVs to EPs on a 1:1 basis. - * If EFA ever lifts this requirement, we can reduce the number of AVs required. - */ - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - status = fi_av_open(state->domain, &av_attr, &state->addresses[i], NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to allocate address vector: %d: %s\n", status, - fi_strerror(status * -1)); + "PE selected %d devices, but the libfabric transport only supports a max of %d " + "devices. Continuing using %d devices.\n", + state->num_selected_devs, NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE, + NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE); } + state->num_selected_domains = state->num_selected_devs + 1; - INFO(state->log_level, "Selected provider %s, fabric %s, nic %s, hmem %s", - state->prov_info->fabric_attr->prov_name, state->prov_info->fabric_attr->name, - state->prov_info->nic->device_attr->name, state->prov_info->caps & FI_HMEM ? "yes" : "no"); - - assert(state->eps); + /* Initialize configuration which only need to be set once */ + t->max_op_len = UINT64_MAX; /* Set as sential value */ + state->cur_proxy_ep_index = 1; memset(&cq_attr, 0, sizeof(struct fi_cq_attr)); - memset(&cntr_attr, 0, sizeof(struct fi_cntr_attr)); - - state->prov_info->ep_attr->tx_ctx_cnt = 0; - state->prov_info->caps = FI_RMA; - if ((state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) || - (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_VERBS)) { - state->prov_info->caps |= FI_ATOMIC; - } else { - state->prov_info->caps |= FI_MSG; - state->prov_info->caps |= FI_SOURCE; - } - state->prov_info->tx_attr->op_flags = 0; - state->prov_info->tx_attr->mode = 0; - state->prov_info->rx_attr->mode = 0; - - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - state->prov_info->mode = FI_CONTEXT2; - } else { - state->prov_info->mode = 0; - } - - state->prov_info->tx_attr->op_flags = FI_DELIVERY_COMPLETE; - - cntr_attr.events = FI_CNTR_EVENTS_COMP; - cntr_attr.wait_obj = FI_WAIT_UNSPEC; - cntr_attr.wait_set = NULL; - cntr_attr.flags = 0; - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { cq_attr.size = 16; /* CQ is only used to capture error events */ cq_attr.format = FI_CQ_FORMAT_UNSPEC; cq_attr.wait_obj = FI_WAIT_NONE; - } - - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { cq_attr.format = FI_CQ_FORMAT_DATA; cq_attr.wait_obj = FI_WAIT_NONE; cq_attr.size = 32768; } - local_ep_names = (nvshmemt_libfabric_ep_name_t *)calloc(NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS, + memset(&av_attr, 0, sizeof(struct fi_av_attr)); + av_attr.type = FI_AV_TABLE; + av_attr.count = state->num_selected_domains * n_pes; + + memset(&cntr_attr, 0, sizeof(struct fi_cntr_attr)); + cntr_attr.events = FI_CNTR_EVENTS_COMP; + cntr_attr.wait_obj = FI_WAIT_UNSPEC; + + /* Find fabric info for each selected device */ + for (int dev_idx = 0; dev_idx < state->num_selected_devs; dev_idx++) { + current_info = state->all_prov_info; + do { + if (!strncmp(current_info->nic->device_attr->name, + state->domain_names[selected_dev_ids[dev_idx]].name, + NVSHMEMT_LIBFABRIC_DOMAIN_LEN)) { + break; + } + current_info = current_info->next; + } while (current_info != NULL); + NVSHMEMI_NULL_ERROR_JMP(current_info, status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to find fabric for device %d.\n", dev_idx); + + /* + * Create two domains (host/proxy domain) for the first NIC. + */ + if (state->prov_infos.size() == 0) state->prov_infos.push_back(current_info); + + state->prov_infos.push_back(current_info); + + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA && + strcmp(current_info->fabric_attr->name, "efa-direct")) + NVSHMEMI_WARN_PRINT( + "Libfabric transport is using efa fabric instead of efa-direct, " + "use libfabric v2.1.0 or newer for improved performance\n"); + } + + /* Allocate out of band AV name exchange buffers */ + local_ep_names = (nvshmemt_libfabric_ep_name_t *)calloc(state->num_selected_domains, sizeof(nvshmemt_libfabric_ep_name_t)); NVSHMEMI_NULL_ERROR_JMP(local_ep_names, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "Unable to allocate array of endpoint names."); - total_num_eps = NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS * n_pes; + total_num_eps = n_pes * state->num_selected_domains; all_ep_names = (nvshmemt_libfabric_ep_name_t *)calloc(total_num_eps, sizeof(nvshmemt_libfabric_ep_name_t)); NVSHMEMI_NULL_ERROR_JMP(all_ep_names, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "Unable to allocate array of endpoint names."); - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - state->eps[i].qp_index = i; - status = fi_endpoint(state->domain, state->prov_info, &state->eps[i].endpoint, NULL); + /* Create Resources For Each Selected Device */ + for (size_t i = 0; i < state->prov_infos.size(); i++) { + INFO(state->log_level, + "Selected provider %s, fabric %s, nic %s, hmem %s multi-rail %zu/%d\n", + state->prov_infos[i]->fabric_attr->prov_name, state->prov_infos[i]->fabric_attr->name, + state->prov_infos[i]->nic->device_attr->name, + state->prov_infos[i]->caps & FI_HMEM ? "yes" : "no", i + 1, num_selected_devs); + + if (state->prov_infos[i]->ep_attr->max_msg_size < t->max_op_len) + t->max_op_len = state->prov_infos[i]->ep_attr->max_msg_size; + + status = fi_fabric(state->prov_infos[i]->fabric_attr, &fabric, NULL); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to allocate endpoint: %d: %s\n", status, + "Failed to allocate fabric: %d: %s\n", status, + fi_strerror(status * -1)); + state->fabrics.push_back(fabric); + + status = fi_domain(fabric, state->prov_infos[i], &domain, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Failed to allocate domain: %d: %s\n", status, fi_strerror(status * -1)); + state->domains.push_back(domain); + + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + num_sends = state->prov_infos[i]->tx_attr->size; + num_recvs = state->prov_infos[i]->rx_attr->size; + elem_size = sizeof(nvshmemt_libfabric_gdr_op_ctx_t); + num_recvs_per_ep = num_recvs; + + state->recv_buf.push_back(calloc(num_sends + num_recvs, elem_size)); + NVSHMEMI_NULL_ERROR_JMP(state->recv_buf[i], status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, + "Unable to allocate EFA msg buffer.\n"); + state->send_buf.push_back((char *)state->recv_buf[i] + (elem_size * num_recvs)); + + status = fi_mr_reg(domain, state->recv_buf[i], (num_sends + num_recvs) * elem_size, + FI_SEND | FI_RECV | FI_WRITE, 0, 0, 0, &mr, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Failed to register EFA msg buffer: %d: %s\n", status, + fi_strerror(status * -1)); + state->mr.push_back(mr); + + state->op_queue.push_back(new threadSafeOpQueue); + state->op_queue[i]->putToSendBulk((char *)state->send_buf[i], elem_size, num_sends); + } + + status = fi_av_open(domain, &av_attr, &address, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Failed to allocate address vector: %d: %s\n", status, + fi_strerror(status * -1)); + state->addresses.push_back(address); + + /* Create nvshmemt_libfabric_endpoint_t resources */ + state->eps.push_back( + (nvshmemt_libfabric_endpoint_t *)calloc(1, sizeof(nvshmemt_libfabric_endpoint_t))); + NVSHMEMI_NULL_ERROR_JMP(state->eps[i], status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, + "Unable to alloc libfabric_tx_progress_group struct.\n"); + state->eps[i]->domain_index = i; /* Initialize per-endpoint proxy_put_signal_comp_map */ - state->eps[i].proxy_put_signal_comp_map = + state->eps[i]->proxy_put_signal_comp_map = new std::unordered_map>(); - state->eps[i].put_signal_seq_counter.reset(); - state->eps[i].completed_staged_atomics = 0; + state->eps[i]->put_signal_seq_counter.reset(); + state->eps[i]->completed_staged_atomics = 0; + + status = fi_cq_open(domain, &cq_attr, &state->eps[i]->cq, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to open completion queue for endpoint: %d: %s\n", status, + fi_strerror(status * -1)); + + status = fi_cntr_open(domain, &cntr_attr, &state->eps[i]->counter, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to open counter for endpoint: %d: %s\n", status, + fi_strerror(status * -1)); + status = fi_endpoint(domain, state->prov_infos[i], &state->eps[i]->endpoint, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to allocate endpoint: %d: %s\n", status, + fi_strerror(status * -1)); /* FI_OPT_CUDA_API_PERMITTED was introduced in libfabric 1.18.0 */ if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { bool prohibit_cuda_api = false; - status = fi_setopt(&state->eps[i].endpoint->fid, FI_OPT_ENDPOINT, + status = fi_setopt(&state->eps[i]->endpoint->fid, FI_OPT_ENDPOINT, FI_OPT_CUDA_API_PERMITTED, &prohibit_cuda_api, sizeof(bool)); if (status == -FI_ENOPROTOOPT) { NVSHMEMI_WARN_PRINT( @@ -1638,112 +1698,90 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele } } - status = fi_cq_open(state->domain, &cq_attr, &state->eps[i].cq, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to open completion queue for endpoint: %d: %s\n", status, - fi_strerror(status * -1)); - - status = fi_cntr_open(state->domain, &cntr_attr, &state->eps[i].counter, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to open counter for endpoint: %d: %s\n", status, - fi_strerror(status * -1)); - - status = fi_ep_bind(state->eps[i].endpoint, &state->addresses[i]->fid, 0); + /* Bind Resources To EP */ + status = fi_ep_bind(state->eps[i]->endpoint, &state->addresses[i]->fid, 0); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to bind endpoint to address vector: %d: %s\n", status, fi_strerror(status * -1)); - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_VERBS) { - status = fi_ep_bind(state->eps[i].endpoint, &state->eps[i].cq->fid, - FI_SELECTIVE_COMPLETION | FI_TRANSMIT | FI_RECV); - } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { - status = fi_ep_bind(state->eps[i].endpoint, &state->eps[i].cq->fid, - FI_SELECTIVE_COMPLETION | FI_TRANSMIT); - } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_VERBS) + flags = FI_SELECTIVE_COMPLETION | FI_TRANSMIT | FI_RECV; + else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) + flags = FI_SELECTIVE_COMPLETION | FI_TRANSMIT; + else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) /* EFA is documented as not supporting FI_SELECTIVE_COMPLETION */ - status = - fi_ep_bind(state->eps[i].endpoint, &state->eps[i].cq->fid, FI_TRANSMIT | FI_RECV); - } else { + flags = FI_TRANSMIT | FI_RECV; + else { NVSHMEMI_ERROR_PRINT( "Invalid provider identified. This should be impossible. " "Possible memory corruption in the state pointer?"); status = NVSHMEMX_ERROR_INTERNAL; goto out; } + + status = fi_ep_bind(state->eps[i]->endpoint, &state->eps[i]->cq->fid, flags); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to bind endpoint to completion queue: %d: %s\n", status, fi_strerror(status * -1)); -#ifdef NVSHMEM_USE_GDRCOPY - if (use_gdrcopy) { - status = fi_ep_bind(state->eps[i].endpoint, &state->eps[i].counter->fid, - FI_READ | FI_WRITE | FI_SEND); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to bind endpoint to completion counter: %d: %s\n", status, - fi_strerror(status * -1)); - } else -#endif - { - int flags = FI_READ | FI_WRITE; - if (use_staged_atomics) { - flags |= FI_SEND; - } - status = fi_ep_bind(state->eps[i].endpoint, &state->eps[i].counter->fid, flags); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to bind endpoint to completion counter: %d: %s\n", status, - fi_strerror(status * -1)); - } + flags = FI_READ | FI_WRITE; + if (use_staged_atomics) flags |= FI_SEND; + status = fi_ep_bind(state->eps[i]->endpoint, &state->eps[i]->counter->fid, flags); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to bind endpoint to completion counter: %d: %s\n", status, + fi_strerror(status * -1)); - status = fi_enable(state->eps[i].endpoint); + status = fi_enable(state->eps[i]->endpoint); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to enable endpoint: %d: %s\n", status, fi_strerror(status * -1)); if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - for (size_t j = 0; j < num_recvs_per_pe; j++) { - nvshmemt_libfabric_gdr_op_ctx_t *op; - op = (nvshmemt_libfabric_gdr_op_ctx_t *)state->recv_buf; - op = op + ((num_recvs_per_pe * i) + j); + nvshmemt_libfabric_gdr_op_ctx_t *op; + op = (nvshmemt_libfabric_gdr_op_ctx_t *)state->recv_buf[i]; + for (size_t j = 0; j < num_recvs_per_ep; j++, op++) { assert(op != NULL); - status = fi_recv(state->eps[i].endpoint, op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, - fi_mr_desc(state->mr), FI_ADDR_UNSPEC, &op->ofi_context); + status = fi_recv(state->eps[i]->endpoint, op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, + fi_mr_desc(state->mr[i]), FI_ADDR_UNSPEC, &op->ofi_context); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to post recv to ep. Error: %d: %s\n", status, + fi_strerror(status * -1)); } - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to post recv to ep. Error: %d: %s\n", status, - fi_strerror(status * -1)); } - status = fi_getname(&state->eps[i].endpoint->fid, local_ep_names[i].name, &ep_namelen); + status = fi_getname(&state->eps[i]->endpoint->fid, local_ep_names[i].name, &ep_namelen); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to get name for endpoint: %d: %s\n", status, fi_strerror(status * -1)); - if (ep_namelen > NVSHMEMT_LIBFABRIC_EP_LEN) { + if (ep_namelen > NVSHMEMT_LIBFABRIC_EP_LEN) NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Name of EP is too long."); - } } + /* Perform out of band address exchange */ status = t->boot_handle->allgather( local_ep_names, all_ep_names, - NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS * sizeof(nvshmemt_libfabric_ep_name_t), t->boot_handle); + state->num_selected_domains * sizeof(nvshmemt_libfabric_ep_name_t), t->boot_handle); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Failed to gather endpoint names.\n"); /* We need to insert one at a time since each buffer is larger than the address. */ - for (int j = 0; j < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; j++) { - for (int i = 0; i < total_num_eps; i++) { - status = fi_av_insert(state->addresses[j], &all_ep_names[i], 1, NULL, 0, NULL); + for (int i = 0; i < state->num_selected_domains; i++) { + for (int j = 0; j < total_num_eps; j++) { + status = fi_av_insert(state->addresses[i], &all_ep_names[j], 1, NULL, 0, NULL); if (status < 1) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to insert ep names in address vector: %d: %s\n", status, fi_strerror(status * -1)); } - status = NVSHMEMX_SUCCESS; } } + /* Out of bounds exchange a pre-registered write w/imm target for staged_amo acks */ if (use_staged_atomics) { state->remote_addr_staged_amo_ack = (void **)calloc(sizeof(void *), t->n_pes); + state->rkey_staged_amo_ack = + (uint64_t *)calloc(sizeof(uint64_t), t->n_pes * state->num_selected_domains); NVSHMEMI_NULL_ERROR_JMP(state->remote_addr_staged_amo_ack, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "Unable to allocate remote address array for staged atomic ack.\n"); @@ -1752,13 +1790,15 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to allocate CUDA memory for staged atomic ack.\n"); - status = fi_mr_reg(state->domain, state->remote_addr_staged_amo_ack[t->my_pe], sizeof(int), - FI_REMOTE_WRITE, 0, 0, 0, &state->mr_staged_amo_ack, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Failed to register EFA msg buffer: %d: %s\n", status, - fi_strerror(status * -1)); - state->rkey_staged_amo_ack = (uint64_t *)calloc(sizeof(uint64_t), t->n_pes); - state->rkey_staged_amo_ack[t->my_pe] = fi_mr_key(state->mr_staged_amo_ack); + for (size_t i = 0; i < state->domains.size(); i++) { + status = fi_mr_reg(state->domains[i], state->remote_addr_staged_amo_ack[t->my_pe], + sizeof(int), FI_REMOTE_WRITE, 0, 0, 0, &mr, NULL); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Failed to register EFA msg buffer: %d: %s\n", status, + fi_strerror(status * -1)); + state->rkey_staged_amo_ack[t->my_pe * state->num_selected_domains + i] = fi_mr_key(mr); + state->mr_staged_amo_ack.push_back(mr); + } status = t->boot_handle->allgather(&state->remote_addr_staged_amo_ack[t->my_pe], state->remote_addr_staged_amo_ack, sizeof(void *), @@ -1766,9 +1806,10 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Failed to gather remote addresses.\n"); - status = - t->boot_handle->allgather(&state->rkey_staged_amo_ack[t->my_pe], - state->rkey_staged_amo_ack, sizeof(uint64_t), t->boot_handle); + status = t->boot_handle->allgather( + &state->rkey_staged_amo_ack[t->my_pe * state->num_selected_domains], + state->rkey_staged_amo_ack, sizeof(uint64_t) * state->num_selected_domains, + t->boot_handle); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Failed to gather remote keys.\n"); } @@ -1781,30 +1822,28 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele free(state->remote_addr_staged_amo_ack); } if (state->rkey_staged_amo_ack) free(state->rkey_staged_amo_ack); - if (state->mr_staged_amo_ack) fi_close(&state->mr_staged_amo_ack->fid); - if (state->eps) { - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - if (state->eps[i].proxy_put_signal_comp_map) - delete state->eps[i].proxy_put_signal_comp_map; - if (state->eps[i].endpoint) { - fi_close(&state->eps[i].endpoint->fid); - state->eps[i].endpoint = NULL; - } - if (state->eps[i].cq) { - fi_close(&state->eps[i].cq->fid); - state->eps[i].cq = NULL; - } - if (state->eps[i].counter) { - fi_close(&state->eps[i].counter->fid); - state->eps[i].counter = NULL; - } + for (size_t i = 0; i < state->mr_staged_amo_ack.size(); i++) + fi_close(&state->mr_staged_amo_ack[i]->fid); + for (size_t i = 0; i < state->eps.size(); i++) { + if (state->eps[i]->proxy_put_signal_comp_map) + delete state->eps[i]->proxy_put_signal_comp_map; + if (state->eps[i]->endpoint) { + fi_close(&state->eps[i]->endpoint->fid); + state->eps[i]->endpoint = NULL; + } + if (state->eps[i]->cq) { + fi_close(&state->eps[i]->cq->fid); + state->eps[i]->cq = NULL; } - free(state->eps); - state->eps = NULL; + if (state->eps[i]->counter) { + fi_close(&state->eps[i]->counter->fid); + state->eps[i]->counter = NULL; + } + free(state->eps[i]); + state->eps[i] = NULL; } } -out_already_connected: free(local_ep_names); free(all_ep_names); @@ -1812,12 +1851,12 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele } static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { - nvshmemt_libfabric_state_t *libfabric_state; + nvshmemt_libfabric_state_t *state; int status; assert(transport); - libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; + state = (nvshmemt_libfabric_state_t *)transport->state; if (transport->device_pci_paths) { for (int i = 0; i < transport->n_devices; i++) { @@ -1829,19 +1868,19 @@ static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { size_t mem_handle_cache_size; nvshmemt_libfabric_memhandle_info_t *handle_info = NULL, *previous_handle_info = NULL; - if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - mem_handle_cache_size = nvshmemt_mem_handle_cache_get_size(libfabric_state->cache); + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + mem_handle_cache_size = nvshmemt_mem_handle_cache_get_size(state->cache); for (size_t i = 0; i < mem_handle_cache_size; i++) { handle_info = (nvshmemt_libfabric_memhandle_info_t *)nvshmemt_mem_handle_cache_get_by_idx( - libfabric_state->cache, i); + state->cache, i); if (handle_info && handle_info != previous_handle_info) { free(handle_info); } previous_handle_info = handle_info; } - nvshmemt_mem_handle_cache_fini(libfabric_state->cache); + nvshmemt_mem_handle_cache_fini(state->cache); #ifdef NVSHMEM_USE_GDRCOPY if (use_gdrcopy) { nvshmemt_gdrcopy_ftable_fini(&gdrcopy_ftable, &gdr_desc, &gdrcopy_handle); @@ -1849,88 +1888,88 @@ static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { #endif } - if (libfabric_state->prov_info) { - fi_freeinfo(libfabric_state->prov_info); - } + /* + * Since fi_dupinfo() is not called, we don't need to clean + * we do not need to clean prov_infos + */ + if (state->all_prov_info) fi_freeinfo(state->all_prov_info); - if (libfabric_state->eps) { - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - if (libfabric_state->eps[i].proxy_put_signal_comp_map) - delete libfabric_state->eps[i].proxy_put_signal_comp_map; - if (libfabric_state->eps[i].endpoint) { - status = fi_close(&libfabric_state->eps[i].endpoint->fid); - if (status) { - NVSHMEMI_WARN_PRINT("Unable to close fabric endpoint.: %d: %s\n", status, - fi_strerror(status * -1)); - } + for (size_t i = 0; i < state->eps.size(); i++) { + if (state->eps[i]->proxy_put_signal_comp_map) + delete state->eps[i]->proxy_put_signal_comp_map; + if (state->eps[i]->endpoint) { + status = fi_close(&state->eps[i]->endpoint->fid); + if (status) { + NVSHMEMI_WARN_PRINT("Unable to close fabric endpoint.: %d: %s\n", status, + fi_strerror(status * -1)); } - if (libfabric_state->eps[i].cq) { - status = fi_close(&libfabric_state->eps[i].cq->fid); - if (status) { - NVSHMEMI_WARN_PRINT("Unable to close fabric cq: %d: %s\n", status, - fi_strerror(status * -1)); - } + } + if (state->eps[i]->cq) { + status = fi_close(&state->eps[i]->cq->fid); + if (status) { + NVSHMEMI_WARN_PRINT("Unable to close fabric cq: %d: %s\n", status, + fi_strerror(status * -1)); } - if (libfabric_state->eps[i].counter) { - status = fi_close(&libfabric_state->eps[i].counter->fid); - if (status) { - NVSHMEMI_WARN_PRINT("Unable to close fabric counter: %d: %s\n", status, - fi_strerror(status * -1)); - } + } + if (state->eps[i]->counter) { + status = fi_close(&state->eps[i]->counter->fid); + if (status) { + NVSHMEMI_WARN_PRINT("Unable to close fabric counter: %d: %s\n", status, + fi_strerror(status * -1)); } } - free(libfabric_state->eps); + free(state->eps[i]); } - if (libfabric_state->remote_addr_staged_amo_ack) { - if (libfabric_state->remote_addr_staged_amo_ack[transport->my_pe]) - cudaFree(libfabric_state->remote_addr_staged_amo_ack[transport->my_pe]); - free(libfabric_state->remote_addr_staged_amo_ack); + if (state->remote_addr_staged_amo_ack) { + if (state->remote_addr_staged_amo_ack[transport->my_pe]) + cudaFree(state->remote_addr_staged_amo_ack[transport->my_pe]); + free(state->remote_addr_staged_amo_ack); } - if (libfabric_state->rkey_staged_amo_ack) free(libfabric_state->rkey_staged_amo_ack); - if (libfabric_state->mr_staged_amo_ack) { - status = fi_close(&libfabric_state->mr_staged_amo_ack->fid); + if (state->rkey_staged_amo_ack) free(state->rkey_staged_amo_ack); + for (size_t i = 0; i < state->mr_staged_amo_ack.size(); i++) { + status = fi_close(&state->mr_staged_amo_ack[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close staged atomic ack MR: %d: %s\n", status, fi_strerror(status * -1)); } } - if (libfabric_state->mr) { - status = fi_close(&libfabric_state->mr->fid); + for (size_t i = 0; i < state->mr.size(); i++) { + status = fi_close(&state->mr[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close fabric MR: %d: %s\n", status, fi_strerror(status * -1)); } } - if (libfabric_state->recv_buf) free(libfabric_state->recv_buf); - for (int i = 0; i < NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS; i++) { - status = fi_close(&libfabric_state->addresses[i]->fid); + for (size_t i = 0; i < state->recv_buf.size(); i++) free(state->recv_buf[i]); + + for (size_t i = 0; i < state->addresses.size(); i++) { + status = fi_close(&state->addresses[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close fabric address vector: %d: %s\n", status, fi_strerror(status * -1)); } } - if (libfabric_state->domain) { - status = fi_close(&libfabric_state->domain->fid); + for (size_t i = 0; i < state->domains.size(); i++) { + status = fi_close(&state->domains[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close fabric domain: %d: %s\n", status, fi_strerror(status * -1)); } } - if (libfabric_state->fabric) { - status = fi_close(&libfabric_state->fabric->fid); + for (size_t i = 0; i < state->fabrics.size(); i++) { + status = fi_close(&state->fabrics[i]->fid); if (status) { NVSHMEMI_WARN_PRINT("Unable to close fabric: %d: %s\n", status, fi_strerror(status * -1)); } } - free(libfabric_state); - + free(state); free(transport); return 0; @@ -1938,7 +1977,7 @@ static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabric_state_t *state, struct nvshmemi_options_s *options) { - struct fi_info info; + struct fi_info hints; struct fi_tx_attr tx_attr; struct fi_rx_attr rx_attr; struct fi_ep_attr ep_attr; @@ -1946,58 +1985,58 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr struct fi_fabric_attr fabric_attr; struct fid_nic nic; struct fi_av_attr av_attr; - struct fi_info *returned_fabrics, *current_fabric; + struct fi_info *all_infos, *current_info; int num_fabrics_returned = 0; int status = 0; memset(&ep_attr, 0, sizeof(struct fi_ep_attr)); memset(&av_attr, 0, sizeof(struct fi_av_attr)); - memset(&info, 0, sizeof(struct fi_info)); + memset(&hints, 0, sizeof(struct fi_info)); memset(&tx_attr, 0, sizeof(struct fi_tx_attr)); memset(&rx_attr, 0, sizeof(struct fi_rx_attr)); memset(&domain_attr, 0, sizeof(struct fi_domain_attr)); memset(&fabric_attr, 0, sizeof(struct fi_fabric_attr)); memset(&nic, 0, sizeof(struct fid_nic)); - info.tx_attr = &tx_attr; - info.rx_attr = &rx_attr; - info.ep_attr = &ep_attr; - info.domain_attr = &domain_attr; - info.fabric_attr = &fabric_attr; - info.nic = &nic; + hints.tx_attr = &tx_attr; + hints.rx_attr = &rx_attr; + hints.ep_attr = &ep_attr; + hints.domain_attr = &domain_attr; + hints.fabric_attr = &fabric_attr; + hints.nic = &nic; - info.addr_format = FI_FORMAT_UNSPEC; - info.caps = FI_RMA | FI_HMEM; + hints.addr_format = FI_FORMAT_UNSPEC; + hints.caps = FI_RMA | FI_HMEM; if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_VERBS) { - info.caps |= FI_ATOMIC; + hints.caps |= FI_ATOMIC; domain_attr.mr_mode = FI_MR_VIRT_ADDR | FI_MR_ALLOCATED | FI_MR_PROV_KEY; } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_SLINGSHOT) { /* TODO: Use FI_FENCE to optimize put_with_signal */ - info.caps |= FI_FENCE | FI_ATOMIC; + hints.caps |= FI_FENCE | FI_ATOMIC; domain_attr.mr_mode = FI_MR_ENDPOINT | FI_MR_ALLOCATED | FI_MR_PROV_KEY; } else if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { domain_attr.mr_mode = FI_MR_LOCAL | FI_MR_VIRT_ADDR | FI_MR_ALLOCATED | FI_MR_PROV_KEY | FI_MR_HMEM; - info.caps |= FI_MSG; - info.caps |= FI_SOURCE; + hints.caps |= FI_MSG; + hints.caps |= FI_SOURCE; } if (use_staged_atomics) { - info.mode |= FI_CONTEXT2; + hints.mode |= FI_CONTEXT2; } ep_attr.type = FI_EP_RDM; /* Reliable datagrams */ /* Require completion RMA completion at target for correctness of quiet */ - info.tx_attr->op_flags = FI_DELIVERY_COMPLETE; + hints.tx_attr->op_flags = FI_DELIVERY_COMPLETE; /* nvshmemt_libfabric_auto_progress relaxes threading requirement */ domain_attr.threading = FI_THREAD_COMPLETION; - info.domain_attr->data_progress = FI_PROGRESS_AUTO; + hints.domain_attr->data_progress = FI_PROGRESS_AUTO; status = fi_getinfo(FI_VERSION(NVSHMEMT_LIBFABRIC_MAJ_VER, NVSHMEMT_LIBFABRIC_MIN_VER), NULL, - NULL, 0, &info, &returned_fabrics); + NULL, 0, &hints, &all_infos); /* * 1. Ensure that at least one fabric was returned @@ -2007,28 +2046,27 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr * options.LIBFABRIC_PROVIDER will be a substr of the returned fabric * name */ - if (!status && strstr(returned_fabrics->fabric_attr->name, options->LIBFABRIC_PROVIDER)) { + if (!status && strstr(all_infos->fabric_attr->name, options->LIBFABRIC_PROVIDER)) { use_auto_progress = true; } else { - fi_freeinfo(returned_fabrics); + fi_freeinfo(all_infos); /* * Fallback to FI_PROGRESS_MANUAL path * nvshmemt_libfabric_slow_progress requires FI_THREAD_SAFE */ domain_attr.threading = FI_THREAD_SAFE; - info.domain_attr->data_progress = FI_PROGRESS_MANUAL; + hints.domain_attr->data_progress = FI_PROGRESS_MANUAL; status = fi_getinfo(FI_VERSION(NVSHMEMT_LIBFABRIC_MAJ_VER, NVSHMEMT_LIBFABRIC_MIN_VER), - NULL, NULL, 0, &info, &returned_fabrics); + NULL, NULL, 0, &hints, &all_infos); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "No providers matched fi_getinfo query: %d: %s\n", status, fi_strerror(status * -1)); } - state->all_prov_info = returned_fabrics; - for (current_fabric = returned_fabrics; current_fabric != NULL; - current_fabric = current_fabric->next) { + state->all_prov_info = all_infos; + for (current_info = all_infos; current_info != NULL; current_info = current_info->next) { num_fabrics_returned++; } @@ -2039,53 +2077,51 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr /* Only select unique devices. */ state->num_domains = 0; - for (current_fabric = returned_fabrics; current_fabric != NULL; - current_fabric = current_fabric->next) { - if (!current_fabric->nic) { + for (current_info = all_infos; current_info != NULL; current_info = current_info->next) { + if (!current_info->nic) { INFO(state->log_level, "Interface did not return NIC structure to fi_getinfo. Skipping.\n"); continue; } - if (!current_fabric->tx_attr) { + if (!current_info->tx_attr) { INFO(state->log_level, "Interface did not return TX_ATTR structure to fi_getinfo. Skipping.\n"); continue; } TRACE(state->log_level, "fi_getinfo returned provider %s, fabric %s, nic %s", - current_fabric->fabric_attr->prov_name, current_fabric->fabric_attr->name, - current_fabric->nic->device_attr->name); + current_info->fabric_attr->prov_name, current_info->fabric_attr->name, + current_info->nic->device_attr->name); if (state->provider != NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - if (current_fabric->tx_attr->inject_size < NVSHMEMT_LIBFABRIC_INJECT_BYTES) { + if (current_info->tx_attr->inject_size < NVSHMEMT_LIBFABRIC_INJECT_BYTES) { INFO(state->log_level, "Disabling interface due to insufficient inject data size. reported %lu, " "expected " "%u", - current_fabric->tx_attr->inject_size, NVSHMEMT_LIBFABRIC_INJECT_BYTES); + current_info->tx_attr->inject_size, NVSHMEMT_LIBFABRIC_INJECT_BYTES); continue; } } - if ((current_fabric->domain_attr->mr_mode & FI_MR_PROV_KEY) == 0) { + if ((current_info->domain_attr->mr_mode & FI_MR_PROV_KEY) == 0) { INFO(state->log_level, "Disabling interface due to FI_MR_PROV_KEY support"); continue; } for (int i = 0; i <= state->num_domains; i++) { - if (!strncmp(current_fabric->nic->device_attr->name, state->domain_names[i].name, + if (!strncmp(current_info->nic->device_attr->name, state->domain_names[i].name, NVSHMEMT_LIBFABRIC_DOMAIN_LEN)) { break; } else if (i == state->num_domains) { - size_t name_len = strlen(current_fabric->nic->device_attr->name); + size_t name_len = strlen(current_info->nic->device_attr->name); if (name_len >= NVSHMEMT_LIBFABRIC_DOMAIN_LEN) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to copy domain name for libfabric transport."); } (void)strncpy(state->domain_names[state->num_domains].name, - current_fabric->nic->device_attr->name, - NVSHMEMT_LIBFABRIC_DOMAIN_LEN); + current_info->nic->device_attr->name, NVSHMEMT_LIBFABRIC_DOMAIN_LEN); state->num_domains++; break; } @@ -2107,8 +2143,6 @@ static int nvshmemi_libfabric_init_state(nvshmem_transport_t t, nvshmemt_libfabr nvshmemt_libfabric_finalize(t); } - free(info.fabric_attr->name); - return status; } @@ -2157,6 +2191,7 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, "Unable to initialize env options."); libfabric_state->log_level = nvshmemt_common_get_log_level(&options); + libfabric_state->max_nic_per_pe = options.LIBFABRIC_MAX_NIC_PER_PE; if (strcmp(options.LIBFABRIC_PROVIDER, "verbs") == 0) { libfabric_state->provider = NVSHMEMT_LIBFABRIC_PROVIDER_VERBS; diff --git a/src/modules/transport/libfabric/libfabric.h b/src/modules/transport/libfabric/libfabric.h index eac754cf..3df05486 100644 --- a/src/modules/transport/libfabric/libfabric.h +++ b/src/modules/transport/libfabric/libfabric.h @@ -31,12 +31,9 @@ #define NVSHMEMT_LIBFABRIC_DOMAIN_LEN 32 #define NVSHMEMT_LIBFABRIC_PROVIDER_LEN 32 #define NVSHMEMT_LIBFABRIC_EP_LEN 128 - -/* one EP for all proxy ops, one for host ops */ -#define NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS 2 +/* Constrainted by memhandle size */ +#define NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE 16 #define NVSHMEMT_LIBFABRIC_PROXY_EP_IDX 1 -#define NVSHMEMT_LIBFABRIC_HOST_EP_IDX 0 - #define NVSHMEMT_LIBFABRIC_QUIET_TIMEOUT_MS 20 /* Maximum size of inject data. Currently @@ -196,7 +193,7 @@ typedef struct { nvshmemt_libfabric_endpoint_seq_counter_t put_signal_seq_counter; std::unordered_map> *proxy_put_signal_comp_map; - int qp_index; + int domain_index; } nvshmemt_libfabric_endpoint_t; typedef struct nvshmemt_libfabric_gdr_send_p_op { @@ -261,6 +258,10 @@ class threadSafeOpQueue { std::deque other_recv; public: + threadSafeOpQueue() = default; + threadSafeOpQueue(const threadSafeOpQueue &) = delete; + threadSafeOpQueue &operator=(const threadSafeOpQueue &) = delete; + int getNextSends(void **elems, size_t num_elems = 1) { send_mutex.lock(); if (send.size() < num_elems) { @@ -391,26 +392,34 @@ class threadSafeOpQueue { }; typedef struct { - struct fi_info *prov_info; struct fi_info *all_prov_info; - struct fid_fabric *fabric; - struct fid_domain *domain; - struct fid_av *addresses[NVSHMEMT_LIBFABRIC_DEFAULT_NUM_EPS]; - nvshmemt_libfabric_endpoint_t *eps; + std::vector prov_infos; + std::vector fabrics; + std::vector domains; + std::vector addresses; + std::vector eps; + nvshmemt_libfabric_domain_name_t *domain_names; int num_domains; nvshmemt_libfabric_provider provider; int log_level; struct nvshmemi_cuda_fn_table *table; - size_t num_sends; - void *send_buf; - size_t num_recvs; - void *recv_buf; - struct fid_mr *mr; struct transport_mem_handle_info_cache *cache; + + /* Required for multi-rail */ + int max_nic_per_pe; + int num_selected_devs; + int num_selected_domains; + int cur_proxy_ep_index; + + /* Required for staged_amo */ + std::vector op_queue; + std::vector mr; + std::vector send_buf; + std::vector recv_buf; + std::vector mr_staged_amo_ack; void **remote_addr_staged_amo_ack; uint64_t *rkey_staged_amo_ack; - struct fid_mr *mr_staged_amo_ack; } nvshmemt_libfabric_state_t; typedef struct { @@ -431,7 +440,7 @@ typedef struct { typedef struct { void *buf; - nvshmemt_libfabric_mem_handle_ep_t hdls[2]; + nvshmemt_libfabric_mem_handle_ep_t hdls[1 + NVSHMEMT_LIBFABRIC_MAX_NIC_PER_PE]; } nvshmemt_libfabric_mem_handle_t; /* Wire data for put-signal gdr staged atomics From 49c21f84020a091c10d27c783688829203e2f5eb Mon Sep 17 00:00:00 2001 From: Eric Raut Date: Tue, 24 Feb 2026 19:10:02 -0800 Subject: [PATCH 05/11] topo: Do load balancing across all GPUs The current topology code has the problem that NIC-GPU load balancing is applied only to PEs. If an application initializes NVSHMEM such that there is only one PE per node per rank, load balancing will not be applied between GPUs. Then, PEs/GPUs from separate ranks will share NICs, leading to contention and reduced performance. In this patch, fix the issue by iterating over all GPUs in the system, listed in `/sys/bus/pci/drivers/nvidia`, and do load balancing over all of them. Signed-off-by: Eric Raut --- src/host/topo/topo.cpp | 280 ++++++++++++++++++------------- src/host/topo/topo.h | 1 + src/host/transport/transport.cpp | 7 +- 3 files changed, 169 insertions(+), 119 deletions(-) diff --git a/src/host/topo/topo.cpp b/src/host/topo/topo.cpp index ac86a9eb..753bbe9d 100644 --- a/src/host/topo/topo.cpp +++ b/src/host/topo/topo.cpp @@ -9,6 +9,7 @@ #include // for CUDA_SUCCESS #include // for cudaDevice... #include // for cudaDevice... +#include // for opendir, readdir #include // for PATH_MAX #include // for NULL, fclose #include // for free, calloc @@ -49,6 +50,8 @@ enum pci_distance { static const int pci_distance_perf[PATH_COUNT] = {4, 4, 3, 2, 1}; static const char *pci_distance_string[PATH_COUNT] = {"PIX", "PXB", "PHB", "NODE", "SYS"}; +#define NVIDIA_DRIVER_PATH "/sys/bus/pci/drivers/nvidia" + static int get_cuda_bus_id(int cuda_dev, char *bus_id) { int status = NVSHMEMX_SUCCESS; cudaError_t err; @@ -106,6 +109,71 @@ static int get_device_path(char *bus_id, char **path) { return status; } +static int is_pci_addr(const char *name) { + // Match XXXX:XX:XX.X pattern + return strlen(name) == 12 && name[4] == ':' && name[7] == ':' && name[10] == '.'; +} + +int get_nvidia_gpu_count(void) { + DIR *dir = opendir(NVIDIA_DRIVER_PATH); + if (!dir) return 0; + int count = 0; + struct dirent *ent; + while ((ent = readdir(dir)) != NULL) { + if (is_pci_addr(ent->d_name)) count++; + } + closedir(dir); + return count; +} + +static int get_gpu_paths_and_index(int cuda_device_id, char **cuda_device_paths, + int *out_mygpu_index) { + int status = NVSHMEMX_SUCCESS; + char my_bus_id[MAX_BUSID_SIZE]; + DIR *nvidia_dir = NULL; + + status = get_cuda_bus_id(cuda_device_id, my_bus_id); + if (status != NVSHMEMX_SUCCESS) return status; + for (int k = 0; k < MAX_BUSID_SIZE; k++) + my_bus_id[k] = tolower(my_bus_id[k]); + + nvidia_dir = opendir(NVIDIA_DRIVER_PATH); + if (!nvidia_dir) { + NVSHMEMI_ERROR_PRINT("Failed to open " NVIDIA_DRIVER_PATH "\n"); + return NVSHMEMX_ERROR_INTERNAL; + } + + int gpu_id = 0; + *out_mygpu_index = -1; + struct dirent *ent; + while ((ent = readdir(nvidia_dir)) != NULL) { + if (!is_pci_addr(ent->d_name)) continue; + char bus_id[MAX_BUSID_SIZE]; + strncpy(bus_id, ent->d_name, MAX_BUSID_SIZE - 1); + bus_id[MAX_BUSID_SIZE - 1] = '\0'; + + status = get_device_path(bus_id, &cuda_device_paths[gpu_id]); + if (status != NVSHMEMX_SUCCESS) { + NVSHMEMI_ERROR_PRINT("get cuda path failed\n"); + closedir(nvidia_dir); + return status; + } + + if (strncmp(my_bus_id, bus_id, MAX_BUSID_SIZE) == 0) + *out_mygpu_index = gpu_id; + + gpu_id++; + } + closedir(nvidia_dir); + + if (*out_mygpu_index < 0) { + NVSHMEMI_ERROR_PRINT("Could not find current GPU in sysfs\n"); + return NVSHMEMX_ERROR_INTERNAL; + } + + return NVSHMEMX_SUCCESS; +} + static enum pci_distance get_pci_distance(char *cuda_path, char *mlx_path) { int score = 0; int depth = 0; @@ -130,7 +198,7 @@ static enum pci_distance get_pci_distance(char *cuda_path, char *mlx_path) { } typedef struct nvshmemi_path_pair_info { - int pe_idx; + int gpu_idx; int dev_idx; enum pci_distance pcie_distance; } nvshmemi_path_pair_info_t; @@ -146,37 +214,32 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, char gpu_bus_id[MAX_BUSID_SIZE]; } gpu_info, *gpu_info_all = NULL; - std::list pe_dev_pairs; + std::list gpu_dev_pairs; std::list::iterator pairs_iter; int ndev = tcurr->n_devices; int mype = nvshmemi_state->mype; int n_pes = nvshmemi_state->npes; int n_pes_node = nvshmemi_state->npes_node; - CUdevice gpu_device_id; char **cuda_device_paths = NULL; - int *pe_selected_devices = NULL; - enum pci_distance *pe_device_distance = NULL; + int *gpu_selected_devices = NULL; + enum pci_distance *gpu_device_distance = NULL; int *used_devs = NULL; + int n_gpus_node = 0; - int mype_array_index = -1, mydev_index = -1; - int i, dev_id, pe_id, pe_pair_index; + int mygpu_index = -1, mydev_index = -1; + int i, dev_id, gpu_id, gpu_pair_index; int devices_assigned = 0; - int mype_device_count = 0; + int mygpu_device_count = 0; int status = NVSHMEMX_ERROR_INTERNAL; + int mygpu_array_index; if (ndev <= 0) { NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "transport devices (setup_connections) failed \n"); } - status = CUPFN(nvshmemi_cuda_syms, cuCtxGetDevice(&gpu_device_id)); - if (status != CUDA_SUCCESS) { - status = NVSHMEMX_ERROR_INTERNAL; - goto out; - } - /* Allocate data structures start */ /* Array of dev_info structures of size # of local NICs */ dev_info_all = (struct dev_info *)calloc(ndev, sizeof(struct dev_info)); @@ -188,66 +251,45 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, NVSHMEMI_NULL_ERROR_JMP(gpu_info_all, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "gpu_info_all allocation failed \n"); - /* array linking each GPU on our node with it's pcie path */ - cuda_device_paths = (char **)calloc(n_pes_node, sizeof(char *)); - NVSHMEMI_NULL_ERROR_JMP(cuda_device_paths, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, - "Unable to allocate memory for PE/NIC Mapping.\n"); - - /* Array of size n_pes_node * max_dev_per_pe storing the accepted mappings of PE to Dev(s) */ - pe_selected_devices = (int *)calloc(n_pes_node * max_dev_per_pe, sizeof(int)); - NVSHMEMI_NULL_ERROR_JMP(pe_selected_devices, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, - "Unable to allocate memory for PE/NIC Mapping.\n"); - for (pe_id = 0; pe_id < n_pes_node; pe_id++) { - for (dev_id = 0; dev_id < max_dev_per_pe; dev_id++) { - pe_selected_devices[pe_id * max_dev_per_pe + dev_id] = -1; - } - } - - pe_device_distance = - (enum pci_distance *)calloc(n_pes_node * max_dev_per_pe, sizeof(enum pci_distance)); - NVSHMEMI_NULL_ERROR_JMP(pe_device_distance, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, - "Unable to allocate memory for PE/NIC Mapping.\n"); - for (pe_id = 0; pe_id < n_pes_node; pe_id++) { - for (dev_id = 0; dev_id < max_dev_per_pe; dev_id++) { - pe_device_distance[pe_id * max_dev_per_pe + dev_id] = PATH_SYS; - } - } - used_devs = (int *)calloc(ndev, sizeof(int)); NVSHMEMI_NULL_ERROR_JMP(used_devs, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "Unable to allocate memory for PE/NIC Mapping.\n"); /* Allocate data structures end */ /* Gather GPU and NIC paths start */ - status = get_cuda_bus_id(gpu_device_id, gpu_info.gpu_bus_id); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "get cuda busid failed \n"); - - status = nvshmemi_boot_handle.allgather((void *)&gpu_info, (void *)gpu_info_all, - sizeof(struct gpu_info), &nvshmemi_boot_handle); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "allgather of gpu_info failed \n"); - - pe_id = 0; - for (i = 0; i < n_pes; i++) { - if (nvshmemi_state->pe_info[i].hostHash != nvshmemi_state->pe_info[mype].hostHash) { - continue; - } + n_gpus_node = get_nvidia_gpu_count(); + if (n_gpus_node <= 0) { + NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "No NVIDIA GPUs found in " NVIDIA_DRIVER_PATH "\n"); + } - status = get_device_path(gpu_info_all[i].gpu_bus_id, &cuda_device_paths[pe_id]); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "get cuda path failed \n"); - /* to get back to our PE after the algorithm finishes. */ - if (i == mype) { - mype_array_index = pe_id * max_dev_per_pe; - } + cuda_device_paths = (char **)calloc(n_gpus_node, sizeof(char *)); + NVSHMEMI_NULL_ERROR_JMP(cuda_device_paths, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, + "Unable to allocate memory for GPU/NIC Mapping.\n"); - pe_id++; - if (pe_id == n_pes_node) { - break; + status = get_gpu_paths_and_index(nvshmemi_state->device_id, cuda_device_paths, &mygpu_index); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "get_gpu_paths_and_index failed\n"); + mygpu_array_index = mygpu_index * max_dev_per_pe; + + /* Allocate GPU-based arrays */ + gpu_selected_devices = (int *)calloc(n_gpus_node * max_dev_per_pe, sizeof(int)); + NVSHMEMI_NULL_ERROR_JMP(gpu_selected_devices, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, + "Unable to allocate memory for GPU/NIC Mapping.\n"); + for (gpu_id = 0; gpu_id < n_gpus_node; gpu_id++) { + for (dev_id = 0; dev_id < max_dev_per_pe; dev_id++) { + gpu_selected_devices[gpu_id * max_dev_per_pe + dev_id] = -1; } } - if (pe_id != n_pes_node || mype_array_index == -1) { - NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Number of PEs found doesn't match the PE node count.\n"); + gpu_device_distance = + (enum pci_distance *)calloc(n_gpus_node * max_dev_per_pe, sizeof(enum pci_distance)); + NVSHMEMI_NULL_ERROR_JMP(gpu_device_distance, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, + "Unable to allocate memory for GPU/NIC Mapping.\n"); + for (gpu_id = 0; gpu_id < n_gpus_node; gpu_id++) { + for (dev_id = 0; dev_id < max_dev_per_pe; dev_id++) { + gpu_device_distance[gpu_id * max_dev_per_pe + dev_id] = PATH_SYS; + } } for (i = 0; i < ndev; i++) { @@ -257,37 +299,37 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, /* Gather GPU and NIC paths end */ /* Get path distances start */ - /* construct a n_pes_node * ndev array of distance measurements */ - for (pe_id = 0; pe_id < n_pes_node; pe_id++) { + /* construct a n_gpus_node * ndev array of distance measurements */ + for (gpu_id = 0; gpu_id < n_gpus_node; gpu_id++) { for (dev_id = 0; dev_id < ndev; dev_id++) { enum pci_distance distance_compare; distance_compare = - get_pci_distance(cuda_device_paths[pe_id], dev_info_all[dev_id].dev_path); - if (unlikely(pe_dev_pairs.empty())) { - pe_dev_pairs.push_front({pe_id, dev_id, distance_compare}); + get_pci_distance(cuda_device_paths[gpu_id], dev_info_all[dev_id].dev_path); + if (unlikely(gpu_dev_pairs.empty())) { + gpu_dev_pairs.push_front({gpu_id, dev_id, distance_compare}); } else { - for (pairs_iter = pe_dev_pairs.begin(); pairs_iter != pe_dev_pairs.end(); + for (pairs_iter = gpu_dev_pairs.begin(); pairs_iter != gpu_dev_pairs.end(); pairs_iter++) { if (distance_compare < (*pairs_iter).pcie_distance) { break; } } - INFO(NVSHMEM_TOPO, "PE %d: %s dev %d: %s distance: %d\n", pe_id, - cuda_device_paths[pe_id], dev_id, dev_info_all[dev_id].dev_path, + INFO(NVSHMEM_TOPO, "GPU %d: %s dev %d: %s distance: %d\n", gpu_id, + cuda_device_paths[gpu_id], dev_id, dev_info_all[dev_id].dev_path, distance_compare); - pe_dev_pairs.insert(pairs_iter, {pe_id, dev_id, distance_compare}); + gpu_dev_pairs.insert(pairs_iter, {gpu_id, dev_id, distance_compare}); } } } /* Get path distances end */ /* loop one, do initial assignments of NIC(s) to each GPU */ - for (pairs_iter = pe_dev_pairs.begin(); pairs_iter != pe_dev_pairs.end(); pairs_iter++) { + for (pairs_iter = gpu_dev_pairs.begin(); pairs_iter != gpu_dev_pairs.end(); pairs_iter++) { bool need_more_assignments = 0; - int pe_base_index = (*pairs_iter).pe_idx * max_dev_per_pe; + int gpu_base_index = (*pairs_iter).gpu_idx * max_dev_per_pe; /* skip pairs where the GPU already has a partner in the first loop */ - for (pe_pair_index = 0; pe_pair_index < max_dev_per_pe; pe_pair_index++) - if (pe_selected_devices[pe_base_index + pe_pair_index] == PE_DEVICE_NOT_ASSIGNED) { + for (gpu_pair_index = 0; gpu_pair_index < max_dev_per_pe; gpu_pair_index++) + if (gpu_selected_devices[gpu_base_index + gpu_pair_index] == PE_DEVICE_NOT_ASSIGNED) { need_more_assignments = 1; break; } @@ -297,13 +339,13 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, } if (pci_distance_perf[(*pairs_iter).pcie_distance] < - pci_distance_perf[pe_device_distance[pe_base_index]]) { + pci_distance_perf[gpu_device_distance[gpu_base_index]]) { /* This NIC and all subsequent ones are less optimal than the already selected NICs * They can be safely ignored and we assign -2 to indicate that there are no more * optimal NICs for this GPU. */ - for (; pe_pair_index < max_dev_per_pe; pe_pair_index++) { - pe_selected_devices[pe_base_index + pe_pair_index] = + for (; gpu_pair_index < max_dev_per_pe; gpu_pair_index++) { + gpu_selected_devices[gpu_base_index + gpu_pair_index] = PE_DEVICE_NO_OPTIMAL_ASSIGNMENT; /* While not technically assigned, we need to account for these NICs to make * forward progress. @@ -312,61 +354,61 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, } } else { /* This NIC is optimal for this GPU. */ - INFO(NVSHMEM_TOPO, "Pairing PE %d with device %d at distance %d\n", - (*pairs_iter).pe_idx, (*pairs_iter).dev_idx, (*pairs_iter).pcie_distance); - pe_selected_devices[pe_base_index + pe_pair_index] = (*pairs_iter).dev_idx; - pe_device_distance[pe_base_index + pe_pair_index] = (*pairs_iter).pcie_distance; + INFO(NVSHMEM_TOPO, "Pairing GPU %d with device %d at distance %d\n", + (*pairs_iter).gpu_idx, (*pairs_iter).dev_idx, (*pairs_iter).pcie_distance); + gpu_selected_devices[gpu_base_index + gpu_pair_index] = (*pairs_iter).dev_idx; + gpu_device_distance[gpu_base_index + gpu_pair_index] = (*pairs_iter).pcie_distance; used_devs[(*pairs_iter).dev_idx]++; devices_assigned++; } - if (devices_assigned == n_pes_node * max_dev_per_pe) { + if (devices_assigned == n_gpus_node * max_dev_per_pe) { break; } } /* loop two, load balance the NICs. */ - for (pe_id = 0; pe_id < n_pes_node; pe_id++) { + for (gpu_id = 0; gpu_id < n_gpus_node; gpu_id++) { for (dev_id = 0; dev_id < max_dev_per_pe; dev_id++) { - int pe_pair_idx = pe_id * max_dev_per_pe + dev_id; + int gpu_pair_idx = gpu_id * max_dev_per_pe + dev_id; int nic_density; - if (pe_selected_devices[pe_pair_idx] < 0) { + if (gpu_selected_devices[gpu_pair_idx] < 0) { continue; } - nic_density = used_devs[pe_selected_devices[pe_pair_idx]]; + nic_density = used_devs[gpu_selected_devices[gpu_pair_idx]]; /* Can't find a less populated NIC if ours is only assigned to 1 gpu. */ if (nic_density < 2) { continue; } - /* Calculate PE Index from nic_id. Each PE gets max_dev_per_pe assigned to them. - * If there are 8 NIC's and 4 PE's, the nic -> PE mapping looks like + /* Calculate GPU Index from nic_id. Each GPU gets max_dev_per_pe assigned to them. + * If there are 8 NIC's and 4 GPU's, the nic -> GPU mapping looks like * nic_id: 0 1 2 3 4 5 6 7 - * pe_idx: 0 0 1 1 2 2 3 3 + * gpu_idx: 0 0 1 1 2 2 3 3 */ - int pe_idx = (pe_pair_idx - (pe_pair_idx % max_dev_per_pe)) / max_dev_per_pe; - for (pairs_iter = pe_dev_pairs.begin(); pairs_iter != pe_dev_pairs.end(); + int gpu_idx = (gpu_pair_idx - (gpu_pair_idx % max_dev_per_pe)) / max_dev_per_pe; + for (pairs_iter = gpu_dev_pairs.begin(); pairs_iter != gpu_dev_pairs.end(); pairs_iter++) { /* Never change for a less optimal NIC. */ - if ((*pairs_iter).pe_idx != pe_idx) { + if ((*pairs_iter).gpu_idx != gpu_idx) { continue; } if (pci_distance_perf[(*pairs_iter).pcie_distance] < - pci_distance_perf[pe_device_distance[pe_pair_idx]]) { + pci_distance_perf[gpu_device_distance[gpu_pair_idx]]) { break; } if ((nic_density - used_devs[(*pairs_iter).dev_idx]) >= 2) { - INFO(NVSHMEM_TOPO, "Re-Pairing PE %d with device %d at distance %d\n", - (*pairs_iter).pe_idx, (*pairs_iter).dev_idx, (*pairs_iter).pcie_distance); - used_devs[pe_selected_devices[pe_pair_idx]]--; + INFO(NVSHMEM_TOPO, "Re-Pairing GPU %d with device %d at distance %d\n", + (*pairs_iter).gpu_idx, (*pairs_iter).dev_idx, (*pairs_iter).pcie_distance); + used_devs[gpu_selected_devices[gpu_pair_idx]]--; used_devs[(*pairs_iter).dev_idx]++; nic_density = used_devs[(*pairs_iter).dev_idx]; - pe_selected_devices[pe_pair_idx] = (*pairs_iter).dev_idx; - pe_device_distance[pe_pair_idx] = (*pairs_iter).pcie_distance; + gpu_selected_devices[gpu_pair_idx] = (*pairs_iter).dev_idx; + gpu_device_distance[gpu_pair_idx] = (*pairs_iter).pcie_distance; if (nic_density < 2) { break; } @@ -374,32 +416,34 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, } } - for (pe_pair_index = 0; pe_pair_index < max_dev_per_pe; pe_pair_index++) { - if (pe_selected_devices[mype_array_index + pe_pair_index] >= 0) { - mydev_index = pe_selected_devices[mype_array_index + pe_pair_index]; - device_arr[pe_pair_index] = mydev_index; - mype_device_count++; - INFO(NVSHMEM_TOPO, "Our PE is sharing its NIC at index %d with %d other PEs.\n", - used_devs[mydev_index], mype_device_count); + for (gpu_pair_index = 0; gpu_pair_index < max_dev_per_pe; gpu_pair_index++) { + if (gpu_selected_devices[mygpu_array_index + gpu_pair_index] >= 0) { + mydev_index = gpu_selected_devices[mygpu_array_index + gpu_pair_index]; + device_arr[gpu_pair_index] = mydev_index; + mygpu_device_count++; + INFO(NVSHMEM_TOPO, "Our GPU is sharing its NIC at index %d with %d other GPUs.\n", + used_devs[mydev_index], mygpu_device_count); } } - if (mype_device_count == 0) { + if (mygpu_device_count == 0) { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "No NICs were assigned to our PE.\n"); + "No NICs were assigned to our GPU.\n"); } /* No need to report this in a loop - All Devices will have the same perf characteristics. */ - if (pci_distance_perf[pe_device_distance[mype_array_index]] < pci_distance_perf[PATH_PIX]) { + if (pci_distance_perf[gpu_device_distance[mygpu_array_index]] < pci_distance_perf[PATH_PIX]) { nvshmemi_state->are_nics_ll128_compliant = false; INFO(NVSHMEM_TOPO, - "Our PE is connected to a NIC with pci distance %s." + "Our GPU is connected to a NIC with pci distance %s." "this will provide less than optimal performance.\n", - pci_distance_string[pe_device_distance[mype_array_index]]); + pci_distance_string[gpu_device_distance[mygpu_array_index]]); } } + status = NVSHMEMX_SUCCESS; + out: if (dev_info_all) { free(dev_info_all); @@ -410,7 +454,7 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, } if (cuda_device_paths) { - for (i = 0; i < n_pes_node; i++) { + for (i = 0; i < n_gpus_node; i++) { if (cuda_device_paths[i]) { free(cuda_device_paths[i]); } @@ -418,18 +462,18 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, free(cuda_device_paths); } - pe_dev_pairs.clear(); + gpu_dev_pairs.clear(); - if (pe_selected_devices) { - free(pe_selected_devices); + if (gpu_selected_devices) { + free(gpu_selected_devices); } if (used_devs) { free(used_devs); } - if (pe_device_distance) { - free(pe_device_distance); + if (gpu_device_distance) { + free(gpu_device_distance); } return status; diff --git a/src/host/topo/topo.h b/src/host/topo/topo.h index cdf07cd4..475c6250 100644 --- a/src/host/topo/topo.h +++ b/src/host/topo/topo.h @@ -10,6 +10,7 @@ int nvshmemi_get_devices_by_distance(int *device_arr, int max_dev_per_pe, struct nvshmem_transport *tcurr); +int get_nvidia_gpu_count(void); int nvshmemi_detect_same_device(nvshmemi_state_t *state); int nvshmemi_build_transport_map(nvshmemi_state_t *state); diff --git a/src/host/transport/transport.cpp b/src/host/transport/transport.cpp index 67077e7e..26e9b499 100644 --- a/src/host/transport/transport.cpp +++ b/src/host/transport/transport.cpp @@ -384,7 +384,12 @@ int nvshmemi_setup_connections(nvshmemi_state_t *state) { continue; } - int devices_temp = tcurr->n_devices / state->npes_node; + int n_gpus_node = get_nvidia_gpu_count(); + if (n_gpus_node <= 0) { + n_gpus_node = state->npes_node; + } + + int devices_temp = tcurr->n_devices / n_gpus_node; if (devices_temp == 0) devices_temp = 1; const int max_devices_per_pe = devices_temp; int selected_devices[max_devices_per_pe]; From c6e158a702f5649a2e4ba565a41a91ba9d8866ab Mon Sep 17 00:00:00 2001 From: Eric Raut Date: Mon, 9 Mar 2026 14:59:03 -0700 Subject: [PATCH 06/11] libfabric: Add quick-out from quiet if no new ops submitted Cache counter value to avoid potentially expensive counter read operation if no new operations have been submitted since the last call to fence/quiet. Signed-off-by: Eric Raut --- src/modules/transport/libfabric/libfabric.cpp | 8 ++++++++ src/modules/transport/libfabric/libfabric.h | 1 + 2 files changed, 9 insertions(+) diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index 90ee4c09..6f10c0e9 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -759,8 +759,16 @@ static int nvshmemt_libfabric_quiet(struct nvshmem_transport *tcurr, int pe, int for (;;) { all_nics_quieted = true; for (int i = qp_index; i < end_iter; i++) { + + /* Quick out if the endpoint is still quiet since last time */ + if (state->eps[i]->submitted_ops == state->eps[i]->completed_ctr) { + continue; + } + completed = fi_cntr_read(state->eps[i]->counter) + state->eps[i]->completed_staged_atomics; + state->eps[i]->completed_ctr = completed; + if (state->eps[i]->submitted_ops != completed) { all_nics_quieted = false; if (nvshmemt_libfabric_progress(tcurr, qp_index)) { diff --git a/src/modules/transport/libfabric/libfabric.h b/src/modules/transport/libfabric/libfabric.h index 3df05486..e6bf410d 100644 --- a/src/modules/transport/libfabric/libfabric.h +++ b/src/modules/transport/libfabric/libfabric.h @@ -189,6 +189,7 @@ typedef struct { struct fid_cq *cq; struct fid_cntr *counter; uint64_t submitted_ops; + uint64_t completed_ctr; uint64_t completed_staged_atomics; nvshmemt_libfabric_endpoint_seq_counter_t put_signal_seq_counter; std::unordered_map> From 7f7e1b40dc1d8930df4c31dd31f2e04aff9842c7 Mon Sep 17 00:00:00 2001 From: Eric Raut Date: Tue, 10 Feb 2026 15:05:52 -0800 Subject: [PATCH 07/11] libfabric: add ordered delivery of signals for staged atomics All signals initiated by the same peer PE will be applied in-order based on sequence number. This change also orders signals with respect to puts, so that signal delivery guarantees previous puts have arrived. For this, puts are sent with a sequence number that orders them with the signals. A periodic ack is sent for puts to prevent sequence overflow errors. Ordering is applied across all rails, per peer PE. Signed-off-by: Eric Raut --- src/modules/transport/libfabric/libfabric.cpp | 349 +++++++++++++++--- src/modules/transport/libfabric/libfabric.h | 113 +++++- 2 files changed, 400 insertions(+), 62 deletions(-) diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index 6f10c0e9..4edc0921 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -123,17 +123,51 @@ static inline nvshmemt_libfabric_endpoint_t *nvshmemt_libfabric_get_next_ep( return state->eps[selected_ep]; } -static void nvshmemt_libfabric_put_signal_ack_completion(nvshmemt_libfabric_endpoint_t *ep, - struct fi_cq_data_entry *entry) { +static inline int convert_addr_to_pe(nvshmemt_libfabric_state_t *state, + nvshmemt_libfabric_endpoint_t *ep, + fi_addr_t addr) +{ + // addr = pe * libfabric_state->num_selected_domains + ep->domain_index + // so + // pe = (addr - ep->domain_index) / (libfabric_state->num_selected_domains) + int base_ep_index = addr - ep->domain_index; + assert((base_ep_index % state->num_selected_domains) == 0); + + return base_ep_index / state->num_selected_domains; +} + +static void nvshmemt_libfabric_put_signal_ack_completion(nvshmemt_libfabric_state_t *state, + nvshmemt_libfabric_endpoint_t *ep, + struct fi_cq_data_entry *entry, + fi_addr_t addr) { uint32_t seq_num = entry->data & NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_MASK; if (seq_num != NVSHMEM_STAGED_AMO_SEQ_NUM) { - ep->put_signal_seq_counter.return_acked_seq_num(seq_num); + /* Use host_signal_state for eps[0], proxy_signal_state for eps[1+] */ + nvshmemt_libfabric_signal_state_t *signal_state = + (ep->domain_index == 0) ? &state->host_signal_state : &state->proxy_signal_state; + + int pe = convert_addr_to_pe(state, ep, addr); + nvshmemt_libfabric_imm_cq_data_hdr_t imm_header = + nvshmemt_get_write_with_imm_hdr(entry->data); + + if (imm_header == NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_ACK) { + (*signal_state->put_signal_seq_counter_per_pe)[pe] + .return_acked_seq_num_range_for_put(seq_num); + } else { + (*signal_state->put_signal_seq_counter_per_pe)[pe] + .return_acked_seq_num(seq_num); + } } ep->completed_staged_atomics++; } +static inline bool is_signal_only_op(nvshmemi_amo_t op) { + return (op == NVSHMEMI_AMO_SIGNAL || op == NVSHMEMI_AMO_SIGNAL_SET || + op == NVSHMEMI_AMO_SIGNAL_ADD); +} + static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t *ep, struct fi_cq_data_entry *entry, @@ -146,11 +180,14 @@ static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transpo if (entry->flags & FI_REMOTE_CQ_DATA) { nvshmemt_libfabric_imm_cq_data_hdr_t imm_header = nvshmemt_get_write_with_imm_hdr(entry->data); - if (NVSHMEMT_LIBFABRIC_IMM_PUT_SIGNAL_SEQ == imm_header) { + if (NVSHMEMT_LIBFABRIC_IMM_PUT_SIGNAL_SEQ == imm_header || + NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT == imm_header || + NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_WITH_ACK_REQ == imm_header) { status = nvshmemt_libfabric_put_signal_completion(transport, ep, entry, addr); goto out; - } else if (NVSHMEMT_LIBFABRIC_IMM_STAGED_ATOMIC_ACK == imm_header) { - nvshmemt_libfabric_put_signal_ack_completion(ep, entry); + } else if (NVSHMEMT_LIBFABRIC_IMM_STAGED_ATOMIC_ACK == imm_header || + NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_ACK == imm_header) { + nvshmemt_libfabric_put_signal_ack_completion(state, ep, entry, *addr); goto out; } else { NVSHMEMI_ERROR_JMP(status, NVSHMEMX_ERROR_INVALID_VALUE, out, @@ -169,7 +206,7 @@ static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transpo } else if (entry->flags & FI_RMA) { /* inlined p ops or atomic responses */ state->op_queue[ep->domain_index]->putToSend(op); - } else if (op->type == NVSHMEMT_LIBFABRIC_MATCH) { + } else if ((op->type == NVSHMEMT_LIBFABRIC_MATCH) && (entry->flags & FI_RECV)) { /* Must happen after entry->flags & FI_SEND to avoid send completions */ status = nvshmemt_libfabric_put_signal_completion(transport, ep, entry, addr); } else if (entry->flags & FI_RECV) { @@ -364,9 +401,30 @@ static inline int try_again(nvshmem_transport_t transport, int *status, uint64_t return 1; } +static inline int get_next_seq_num_with_retry(nvshmem_transport_t transport, + nvshmemt_libfabric_endpoint_seq_counter_t &seq_counter, + uint32_t *sequence_count, + int qp_index, + nvshmemt_libfabric_try_again_call_site_t call_site) { + uint64_t num_retries = 0; + int status; + do { + int32_t seq_num = seq_counter.next_seq_num(); + if (seq_num < 0) { + status = -EAGAIN; + } else { + *sequence_count = seq_num; + status = 0; + } + } while (try_again(transport, &status, &num_retries, qp_index, call_site)); + + return status; +} + int gdrcopy_amo_ack(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t *ep, fi_addr_t dest_addr, uint32_t sequence_count, int pe, - nvshmemt_libfabric_gdr_op_ctx_t **send_elems) { + nvshmemt_libfabric_gdr_op_ctx_t **send_elems, + nvshmemt_libfabric_imm_cq_data_hdr_t ack_header) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_gdr_op_ctx_t *resp_op = NULL; uint64_t num_retries = 0; @@ -375,9 +433,7 @@ int gdrcopy_amo_ack(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t uint64_t rkey_index = pe * libfabric_state->num_selected_domains + ep->domain_index; resp_op = send_elems[0]; - imm_data = (NVSHMEMT_LIBFABRIC_IMM_STAGED_ATOMIC_ACK - << NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT) | - sequence_count; + imm_data = (ack_header << NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT) | sequence_count; do { status = fi_writedata( ep->endpoint, resp_op, 0, fi_mr_desc(libfabric_state->mr[ep->domain_index]), imm_data, @@ -504,7 +560,8 @@ int perform_gdrcopy_amo(nvshmem_transport_t transport, nvshmemt_libfabric_gdr_op } status = gdrcopy_amo_ack(transport, ep, src_addr, sequence_count, src_pe, - &send_elems[send_elems_index]); + &send_elems[send_elems_index], + NVSHMEMT_LIBFABRIC_IMM_STAGED_ATOMIC_ACK); out: return status; } @@ -689,9 +746,15 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, nvshmemt_libfabric_gdr_signal_op *sig_op = NULL; nvshmemt_libfabric_gdr_op_ctx_t *op = NULL; bool is_write_comp = entry->flags & FI_REMOTE_CQ_DATA; - int status = 0, progress_count; + int status = 0, progress_count, pe; uint64_t map_key; - std::unordered_map>::iterator iter; + bool is_standalone_put = false; + std::unordered_map::iterator iter; + + /* Use host_signal_state for eps[0], proxy_signal_state for eps[1+] */ + nvshmemt_libfabric_signal_state_t *signal_state = + (ep->domain_index == 0) ? &libfabric_state->host_signal_state + : &libfabric_state->proxy_signal_state; if (unlikely(*addr == FI_ADDR_NOTAVAIL)) { status = -1; @@ -699,13 +762,19 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, "Write w/imm returned with invalid src address.\n"); } + pe = convert_addr_to_pe(libfabric_state, ep, *addr); + if (is_write_comp) { - map_key = *addr << 32 | (uint32_t)entry->data; + nvshmemt_libfabric_imm_cq_data_hdr_t imm_header = + nvshmemt_get_write_with_imm_hdr(entry->data); + is_standalone_put = (imm_header == NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT || + imm_header == NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_WITH_ACK_REQ); + map_key = (((uint64_t)pe) << 32) | ((uint32_t)entry->data & NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_MASK); progress_count = -1; } else { sig_op = (nvshmemt_libfabric_gdr_signal_op *)container_of( entry->op_context, nvshmemt_libfabric_gdr_op_ctx_t, ofi_context); - map_key = *addr << 32 | sig_op->sequence_count; + map_key = (((uint64_t)pe) << 32) | sig_op->sequence_count; progress_count = (int)sig_op->num_writes; /* The EFA provider has an inline send size of 32 bytes. @@ -718,23 +787,81 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, op = nvshmemt_inplace_copy_sig_op_to_gdr_op(sig_op, ep); } - iter = ep->proxy_put_signal_comp_map->find(map_key); - if (iter != ep->proxy_put_signal_comp_map->end()) { - if (!is_write_comp) iter->second.first = op; - iter->second.second += progress_count; + if (is_write_comp && nvshmemt_get_write_with_imm_hdr(entry->data) == NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_WITH_ACK_REQ) { + nvshmemt_libfabric_comp_entry_t ack_comp_entry; + ack_comp_entry.type = NVSHMEMT_LIBFABRIC_COMP_ENTRY_PUT_ACK; + ack_comp_entry.ack_entry.src_addr = *addr; + ack_comp_entry.ack_entry.ep = ep; + signal_state->proxy_put_signal_comp_map->insert(std::make_pair(map_key, ack_comp_entry)); } else { - iter = ep->proxy_put_signal_comp_map - ->insert(std::make_pair(map_key, std::make_pair(op, progress_count))) - .first; - } + iter = signal_state->proxy_put_signal_comp_map->find(map_key); + if (iter != signal_state->proxy_put_signal_comp_map->end()) { + if (!is_write_comp) iter->second.signal_entry.op = op; + iter->second.signal_entry.progress_count += progress_count; + } else { + nvshmemt_libfabric_comp_entry_t sig_comp_entry; + sig_comp_entry.type = NVSHMEMT_LIBFABRIC_COMP_ENTRY_SIGNAL; + if (is_standalone_put) { + sig_comp_entry.signal_entry.op = nullptr; + sig_comp_entry.signal_entry.progress_count = 0; + } else { + sig_comp_entry.signal_entry.op = op; + sig_comp_entry.signal_entry.progress_count = progress_count; + } + signal_state->proxy_put_signal_comp_map->insert(std::make_pair(map_key, sig_comp_entry)); + iter = signal_state->proxy_put_signal_comp_map->find(map_key); + } - if (!iter->second.second) { - if (is_write_comp) { - op = iter->second.first; + if (iter->second.signal_entry.progress_count != 0) { + goto out; } + } + + { + fi_addr_t src_addr = *addr; + // operator[] will default-construct (initialize to 0) if src_addr doesn't exist + uint32_t &next_seq = (*signal_state->next_expected_seq)[pe]; + + while (true) { + // Skip reserved sequence number + if (next_seq == NVSHMEM_STAGED_AMO_SEQ_NUM) { + next_seq = (next_seq + 1) & nvshmemt_libfabric_endpoint_seq_counter_t::sequence_mask; + continue; + } + + uint64_t key = (((uint64_t)pe) << 32) | next_seq; + auto it = signal_state->proxy_put_signal_comp_map->find(key); + + if (it == signal_state->proxy_put_signal_comp_map->end()) break; + + if (it->second.type == NVSHMEMT_LIBFABRIC_COMP_ENTRY_SIGNAL) { + if (it->second.signal_entry.progress_count != 0) break; + + if (it->second.signal_entry.op != NULL) { + auto op_ep = it->second.signal_entry.op->ep; + libfabric_state->op_queue[op_ep->domain_index]->putToRecv( + it->second.signal_entry.op, NVSHMEMT_LIBFABRIC_RECV_TYPE_NOT_ACK); + } + } else { + nvshmemt_libfabric_endpoint_t *ack_ep = it->second.ack_entry.ep; + nvshmemt_libfabric_gdr_op_ctx_t *send_elem; + uint64_t num_retries = 0; + do { + status = libfabric_state->op_queue[ack_ep->domain_index]->getNextSends( + (void **)(&send_elem), 1); + } while (try_again(transport, &status, &num_retries, ack_ep->domain_index, + NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDRCOPY_AMO_ACK, true)); + + if (status == 0) { + status = gdrcopy_amo_ack(transport, ack_ep, it->second.ack_entry.src_addr, next_seq, pe, + &send_elem, NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_ACK); + } + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "gdrcopy_amo_ack failed\n"); + } - libfabric_state->op_queue[ep->domain_index]->putToRecv(op, NVSHMEMT_LIBFABRIC_RECV_TYPE_NOT_ACK); - ep->proxy_put_signal_comp_map->erase(iter); + signal_state->proxy_put_signal_comp_map->erase(it); + next_seq = (next_seq + 1) & nvshmemt_libfabric_endpoint_seq_counter_t::sequence_mask; + } } out: @@ -855,16 +982,16 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, op_size = bytesdesc.elembytes * bytesdesc.nelems; if (verb.desc == NVSHMEMI_OP_P) { - assert(!imm_data); // Write w/ imm not suppored with NVSHMEMI_OP_P on Libfabric transport if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { nvshmemt_libfabric_gdr_op_ctx_t *p_buf = container_of(context, nvshmemt_libfabric_gdr_op_ctx_t, ofi_context); num_retries = 0; + p_buf->p_op.value = *(uint64_t *)local->ptr; + assert(imm_data); // EFA provider requires immediate data for p/put do { - p_buf->p_op.value = *(uint64_t *)local->ptr; - status = fi_write(ep->endpoint, &p_buf->p_op.value, op_size, - fi_mr_desc(libfabric_state->mr[ep->domain_index]), target_ep, - (uintptr_t)remote->ptr, remote_handle->key, context); + status = fi_writedata(ep->endpoint, &p_buf->p_op.value, op_size, + fi_mr_desc(libfabric_state->mr[ep->domain_index]), *imm_data, target_ep, + (uintptr_t)remote->ptr, remote_handle->key, context); } while (try_again(tcurr, &status, &num_retries, qp_index, NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_RMA_IMPL_OP_P_EFA)); } else { @@ -944,16 +1071,58 @@ static int nvshmemt_libfabric_rma_impl(struct nvshmem_transport *tcurr, int pe, static int nvshmemt_libfabric_rma(struct nvshmem_transport *tcurr, int pe, rma_verb_t verb, rma_memdesc_t *remote, rma_memdesc_t *local, rma_bytesdesc_t bytesdesc, int qp_index) { - return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, qp_index, NULL, - NULL); + nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)tcurr->state; + uint32_t imm_data_val = 0; + uint32_t *imm_data = NULL; + int status; + nvshmemt_libfabric_endpoint_t *ep = nullptr; + + // Generate sequence number for P and PUT operations when ordering is needed + if (use_staged_atomics && + (verb.desc == NVSHMEMI_OP_P || verb.desc == NVSHMEMI_OP_PUT)) { + ep = nvshmemt_libfabric_get_next_ep(libfabric_state, qp_index); + + /* Use host_signal_state for qp_index 0, proxy_signal_state otherwise */ + nvshmemt_libfabric_signal_state_t *signal_state = + (qp_index == NVSHMEMX_QP_HOST) ? &libfabric_state->host_signal_state + : &libfabric_state->proxy_signal_state; + auto &seq_counter = (*signal_state->put_signal_seq_counter_per_pe)[pe]; + uint32_t sequence_count; + + status = get_next_seq_num_with_retry(tcurr, seq_counter, &sequence_count, qp_index, + NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_RMA_IMPL_OP_PUT); + if (status) return status; + + seq_counter.put_count++; + + nvshmemt_libfabric_imm_cq_data_hdr_t header; + if (seq_counter.put_count >= NVSHMEM_STAGED_AMO_PUT_ACK_FREQ) { + header = NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_WITH_ACK_REQ; + seq_counter.put_count = 0; + ep->submitted_ops++; // Account for incoming ack + } else { + header = NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT; + } + + imm_data_val = (header << NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT) | sequence_count; + imm_data = &imm_data_val; + } + + return nvshmemt_libfabric_rma_impl(tcurr, pe, verb, remote, local, bytesdesc, qp_index, imm_data, + ep); } +static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, int pe, + void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, + amo_bytesdesc_t bytesdesc, int qp_index, + uint32_t sequence_count, uint16_t num_writes, + nvshmemt_libfabric_endpoint_t *ep); + static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int pe, void *curetptr, amo_verb_t verb, amo_memdesc_t *remote, amo_bytesdesc_t bytesdesc, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_endpoint_t *ep; - nvshmemt_libfabric_gdr_op_ctx_t *amo; uint64_t num_retries = 0; int target_ep; int status = 0; @@ -961,6 +1130,27 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p ep = nvshmemt_libfabric_get_next_ep(libfabric_state, qp_index); target_ep = pe * libfabric_state->num_selected_domains + ep->domain_index; + /* Signal-only operations use gdr_signal path with num_writes=0 */ + if (is_signal_only_op(verb.desc)) { + /* Use host_signal_state for qp_index 0, proxy_signal_state otherwise */ + nvshmemt_libfabric_signal_state_t *signal_state = + (qp_index == NVSHMEMX_QP_HOST) ? &libfabric_state->host_signal_state + : &libfabric_state->proxy_signal_state; + auto &seq_counter = (*signal_state->put_signal_seq_counter_per_pe)[pe]; + uint32_t sequence_count; + status = get_next_seq_num_with_retry(transport, seq_counter, &sequence_count, qp_index, + NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_GDR_AMO_GET_NEXT_SENDS); + if (status) goto out; + + seq_counter.put_count = 0; + + status = nvshmemt_libfabric_gdr_signal(transport, pe, curetptr, verb, remote, + bytesdesc, qp_index, sequence_count, 0, ep); + goto out; + } + + /* Fetch operations use full gdr_op_ctx_t */ + nvshmemt_libfabric_gdr_op_ctx_t *amo; do { status = libfabric_state->op_queue[ep->domain_index]->getNextSends((void **)(&amo), 1); } while (try_again(transport, &status, &num_retries, qp_index, @@ -975,8 +1165,8 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p amo->send_amo.swap_add = remote->val; amo->send_amo.size = bytesdesc.elembytes; amo->send_amo.src_pe = transport->my_pe; - amo->type = NVSHMEMT_LIBFABRIC_SEND; amo->send_amo.comp = remote->cmp; + amo->type = NVSHMEMT_LIBFABRIC_SEND; num_retries = 0; do { @@ -1206,24 +1396,23 @@ int nvshmemt_put_signal_unordered(struct nvshmem_transport *tcurr, int pe, rma_v uint32_t sequence_count = 0; nvshmemt_libfabric_endpoint_t *ep = nvshmemt_libfabric_get_next_ep(state, qp_index); - /* Get sequence number for this put-signal, with retry */ - uint64_t num_retries = 0; - do { - int32_t seq_num = ep->put_signal_seq_counter.next_seq_num(); - if (seq_num < 0) { - status = -EAGAIN; - } else { - sequence_count = seq_num; - status = 0; - } - } while (try_again(tcurr, &status, &num_retries, qp_index, - NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_PUT_SIGNAL_UNORDERED_SEQ)); + /* Get or create sequence counter for this destination fi_addr_t */ + /* Use host_signal_state for qp_index 0, proxy_signal_state otherwise */ + nvshmemt_libfabric_signal_state_t *signal_state = + (qp_index == NVSHMEMX_QP_HOST) ? &state->host_signal_state + : &state->proxy_signal_state; + auto &seq_counter = (*signal_state->put_signal_seq_counter_per_pe)[pe]; + /* Get sequence number for this put-signal, with retry */ + status = get_next_seq_num_with_retry(tcurr, seq_counter, &sequence_count, qp_index, + NVSHMEMT_LIBFABRIC_TRY_AGAIN_CALL_SITE_PUT_SIGNAL_UNORDERED_SEQ); if (unlikely(status)) { NVSHMEMI_ERROR_PRINT("Error in nvshmemt_put_signal_unordered while waiting for category\n"); goto out; } + seq_counter.put_count = 0; + assert(write_remote.size() == write_local.size() && write_local.size() == write_bytes_desc.size()); for (size_t i = 0; i < write_remote.size(); i++) { @@ -1507,6 +1696,34 @@ static int get_pci_path(int dev, char **pci_path, nvshmem_transport_t t) { return status; } +static void nvshmemt_libfabric_cleanup_signal_ordering_state(nvshmemt_libfabric_state_t *state) +{ + if (state->host_signal_state.put_signal_seq_counter_per_pe) { + delete state->host_signal_state.put_signal_seq_counter_per_pe; + state->host_signal_state.put_signal_seq_counter_per_pe = nullptr; + } + if (state->host_signal_state.proxy_put_signal_comp_map) { + delete state->host_signal_state.proxy_put_signal_comp_map; + state->host_signal_state.proxy_put_signal_comp_map = nullptr; + } + if (state->host_signal_state.next_expected_seq) { + delete state->host_signal_state.next_expected_seq; + state->host_signal_state.next_expected_seq = nullptr; + } + if (state->proxy_signal_state.put_signal_seq_counter_per_pe) { + delete state->proxy_signal_state.put_signal_seq_counter_per_pe; + state->proxy_signal_state.put_signal_seq_counter_per_pe = nullptr; + } + if (state->proxy_signal_state.proxy_put_signal_comp_map) { + delete state->proxy_signal_state.proxy_put_signal_comp_map; + state->proxy_signal_state.proxy_put_signal_comp_map = nullptr; + } + if (state->proxy_signal_state.next_expected_seq) { + delete state->proxy_signal_state.next_expected_seq; + state->proxy_signal_state.next_expected_seq = nullptr; + } +} + static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *selected_dev_ids, int num_selected_devs, int *out_qp_indices, int num_qps) { @@ -1610,6 +1827,21 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele NVSHMEMI_NULL_ERROR_JMP(all_ep_names, status, NVSHMEMX_ERROR_OUT_OF_MEMORY, out, "Unable to allocate array of endpoint names."); + /* Initialize state-level signal ordering state */ + state->host_signal_state.put_signal_seq_counter_per_pe = + new std::unordered_map(); + state->host_signal_state.proxy_put_signal_comp_map = + new std::unordered_map(); + state->host_signal_state.next_expected_seq = + new std::unordered_map(); + + state->proxy_signal_state.put_signal_seq_counter_per_pe = + new std::unordered_map(); + state->proxy_signal_state.proxy_put_signal_comp_map = + new std::unordered_map(); + state->proxy_signal_state.next_expected_seq = + new std::unordered_map(); + /* Create Resources For Each Selected Device */ for (size_t i = 0; i < state->prov_infos.size(); i++) { INFO(state->log_level, @@ -1668,11 +1900,6 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele "Unable to alloc libfabric_tx_progress_group struct.\n"); state->eps[i]->domain_index = i; - /* Initialize per-endpoint proxy_put_signal_comp_map */ - state->eps[i]->proxy_put_signal_comp_map = - new std::unordered_map>(); - - state->eps[i]->put_signal_seq_counter.reset(); state->eps[i]->completed_staged_atomics = 0; status = fi_cq_open(domain, &cq_attr, &state->eps[i]->cq, NULL); @@ -1689,6 +1916,7 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to allocate endpoint: %d: %s\n", status, fi_strerror(status * -1)); + /* FI_OPT_CUDA_API_PERMITTED was introduced in libfabric 1.18.0 */ if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { bool prohibit_cuda_api = false; @@ -1832,9 +2060,11 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele if (state->rkey_staged_amo_ack) free(state->rkey_staged_amo_ack); for (size_t i = 0; i < state->mr_staged_amo_ack.size(); i++) fi_close(&state->mr_staged_amo_ack[i]->fid); + + /* Cleanup state-level signal ordering state */ + nvshmemt_libfabric_cleanup_signal_ordering_state(state); + for (size_t i = 0; i < state->eps.size(); i++) { - if (state->eps[i]->proxy_put_signal_comp_map) - delete state->eps[i]->proxy_put_signal_comp_map; if (state->eps[i]->endpoint) { fi_close(&state->eps[i]->endpoint->fid); state->eps[i]->endpoint = NULL; @@ -1902,9 +2132,10 @@ static int nvshmemt_libfabric_finalize(nvshmem_transport_t transport) { */ if (state->all_prov_info) fi_freeinfo(state->all_prov_info); + /* Cleanup state-level signal ordering state */ + nvshmemt_libfabric_cleanup_signal_ordering_state(state); + for (size_t i = 0; i < state->eps.size(); i++) { - if (state->eps[i]->proxy_put_signal_comp_map) - delete state->eps[i]->proxy_put_signal_comp_map; if (state->eps[i]->endpoint) { status = fi_close(&state->eps[i]->endpoint->fid); if (status) { diff --git a/src/modules/transport/libfabric/libfabric.h b/src/modules/transport/libfabric/libfabric.h index e6bf410d..9df18dc1 100644 --- a/src/modules/transport/libfabric/libfabric.h +++ b/src/modules/transport/libfabric/libfabric.h @@ -67,6 +67,13 @@ typedef struct nvshmemt_libfabric_gdr_op_ctx nvshmemt_libfabric_gdr_op_ctx_t; #define NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_MASK \ ((1U << NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_SHIFT) - 1) +/** + * Frequency at which we send an ack for puts (without signal). For puts-only, + * we don't need the ack for every message for semantic reasons, we only need + * an occasional ack to handle sequence number overflow correctly. + */ +#define NVSHMEM_STAGED_AMO_PUT_ACK_FREQ 64 + /** * The last sequence number is reserved for atomic-only operations. * This will not be returned by the sequence counter. @@ -95,6 +102,11 @@ struct nvshmemt_libfabric_endpoint_seq_counter_t { constexpr static uint32_t num_index_bits = (num_sequence_bits - num_category_bits); constexpr static uint32_t index_mask = ((1U << num_index_bits) - 1); + + /* Assert that index_mask is large enough to simplify some ranged ack return + logic. */ + static_assert((index_mask + 1) >= (2 * NVSHMEM_STAGED_AMO_PUT_ACK_FREQ), + "Number of indexes should be >= 2 * put_ack_freq"); constexpr static uint32_t category_mask = (1U << num_index_bits); constexpr static uint32_t sequence_mask = NVSHMEM_STAGED_AMO_PUT_SIGNAL_SEQ_CNTR_BIT_MASK; @@ -117,6 +129,14 @@ struct nvshmemt_libfabric_endpoint_seq_counter_t { uint32_t sequence_counter; uint32_t pending_acks[num_categories]; + uint32_t put_count; + + /** + * Default constructor - initializes counter to zero + */ + nvshmemt_libfabric_endpoint_seq_counter_t() { + reset(); + } /** * Reset counter and pending acks to zero @@ -124,6 +144,7 @@ struct nvshmemt_libfabric_endpoint_seq_counter_t { void reset() { sequence_counter = 0; memset(pending_acks, 0, sizeof(pending_acks)); + put_count = 0; } /** @@ -171,6 +192,55 @@ struct nvshmemt_libfabric_endpoint_seq_counter_t { assert(pending_acks[category] > 0); --pending_acks[category]; } + + /** + * Mark a range of sequence numbers as complete, resulting from reciving a + * put ack. The sequence range ends with end_seq. + * + * We send an ack for every NVSHMEM_STAGED_AMO_PUT_ACK_FREQ puts. Therefore, + * a put ack for is an acknowledgement sequence numbers (end_seq - + * NVSHMEM_STAGED_AMO_PUT_ACK_FREQ + 1) to end_seq, inclusive. The + * wraparound case is also handled. + * + * This code assumes the sequence range spans at most two categories. This + * will be true as long as the index space is sufficiently larger than the + * put ack frequency, as static asserted above. + */ + void return_acked_seq_num_range_for_put(uint32_t end_seq) { + assert(end_seq != NVSHMEM_STAGED_AMO_SEQ_NUM); + + uint32_t start_seq = (end_seq - NVSHMEM_STAGED_AMO_PUT_ACK_FREQ + 1) & sequence_mask; + + /* Note: in the wraparound case, the (start_seq, end_seq) range will + include NVSHMEM_STAGED_AMO_SEQ_NUM, which is not used. The logic + below handles this correctly, as long as `start_category` is correct + (which is true as long as the index space is sufficiently large that + we can only span two categories, as static-asserted above.) */ + + uint32_t start_category = get_category(start_seq); + uint32_t end_category = get_category(end_seq); + + uint32_t num_indexes; + if (end_seq >= start_seq) { + num_indexes = end_seq - start_seq + 1; + } else { + num_indexes = (NVSHMEM_STAGED_AMO_SEQ_NUM - start_seq + 1) + (end_seq + 1); + } + + if (start_category == end_category) { + assert(pending_acks[start_category] >= num_indexes); + pending_acks[start_category] -= num_indexes; + } else { + uint32_t count_in_start_cat = (index_mask + 1) - get_index(start_seq); + uint32_t count_in_end_cat = get_index(end_seq) + 1; + + assert(pending_acks[start_category] >= count_in_start_cat); + assert(pending_acks[end_category] >= count_in_end_cat); + + pending_acks[start_category] -= count_in_start_cat; + pending_acks[end_category] -= count_in_end_cat; + } + } }; typedef enum { @@ -191,12 +261,36 @@ typedef struct { uint64_t submitted_ops; uint64_t completed_ctr; uint64_t completed_staged_atomics; - nvshmemt_libfabric_endpoint_seq_counter_t put_signal_seq_counter; - std::unordered_map> - *proxy_put_signal_comp_map; int domain_index; } nvshmemt_libfabric_endpoint_t; +// Entry types for completion map +enum nvshmemt_libfabric_comp_entry_type { + NVSHMEMT_LIBFABRIC_COMP_ENTRY_SIGNAL, + NVSHMEMT_LIBFABRIC_COMP_ENTRY_PUT_ACK +}; + +// Entry for signal operations (put-signal, atomic) +struct nvshmemt_libfabric_signal_comp_entry { + nvshmemt_libfabric_gdr_op_ctx_t *op; + int progress_count; +}; + +// Entry for puts that need acknowledgment +struct nvshmemt_libfabric_put_ack_entry { + fi_addr_t src_addr; + nvshmemt_libfabric_endpoint_t *ep; +}; + +// Tagged union for completion entries +struct nvshmemt_libfabric_comp_entry_t { + nvshmemt_libfabric_comp_entry_type type; + union { + nvshmemt_libfabric_signal_comp_entry signal_entry; + nvshmemt_libfabric_put_ack_entry ack_entry; + }; +}; + typedef struct nvshmemt_libfabric_gdr_send_p_op { uint64_t value; } nvshmemt_libfabric_gdr_send_p_op_t; @@ -247,6 +341,9 @@ typedef enum { typedef enum { NVSHMEMT_LIBFABRIC_IMM_PUT_SIGNAL_SEQ = 0, NVSHMEMT_LIBFABRIC_IMM_STAGED_ATOMIC_ACK, + NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT, + NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_WITH_ACK_REQ, + NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_ACK, } nvshmemt_libfabric_imm_cq_data_hdr_t; class threadSafeOpQueue { @@ -392,6 +489,12 @@ class threadSafeOpQueue { } }; +typedef struct { + std::unordered_map *put_signal_seq_counter_per_pe; + std::unordered_map *proxy_put_signal_comp_map; + std::unordered_map *next_expected_seq; +} nvshmemt_libfabric_signal_state_t; + typedef struct { struct fi_info *all_prov_info; std::vector prov_infos; @@ -421,6 +524,10 @@ typedef struct { std::vector mr_staged_amo_ack; void **remote_addr_staged_amo_ack; uint64_t *rkey_staged_amo_ack; + + /* Signal ordering state */ + nvshmemt_libfabric_signal_state_t host_signal_state; + nvshmemt_libfabric_signal_state_t proxy_signal_state; } nvshmemt_libfabric_state_t; typedef struct { From 27a9b02222642b4c6a4edbf76dd31a60752e7647 Mon Sep 17 00:00:00 2001 From: Xuan-1998 Date: Thu, 5 Mar 2026 18:43:18 -0800 Subject: [PATCH 08/11] transport/libfabric: Make op_queue locking conditional on threading model Introduce conditional_mutex that skips locking when FI_THREAD_COMPLETION is active (FI_PROGRESS_AUTO path). With auto progress, the provider guarantees thread-safe completion processing per endpoint, and the host thread and proxy thread operate on separate endpoints (eps[0] vs eps[1+]), so their op_queues are disjoint and no synchronization is needed. With FI_PROGRESS_MANUAL, we fall back to FI_THREAD_SAFE where manual_progress() iterates all EPs from both threads, requiring full mutex synchronization. Signed-off-by: Xuan Jiang --- src/modules/transport/libfabric/libfabric.cpp | 1 + src/modules/transport/libfabric/libfabric.h | 33 +++++++++++++++++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index 4edc0921..27f74487 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -1885,6 +1885,7 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele state->op_queue.push_back(new threadSafeOpQueue); state->op_queue[i]->putToSendBulk((char *)state->send_buf[i], elem_size, num_sends); + state->op_queue[i]->set_auto_progress(use_auto_progress); } status = fi_av_open(domain, &av_attr, &address, NULL); diff --git a/src/modules/transport/libfabric/libfabric.h b/src/modules/transport/libfabric/libfabric.h index 9df18dc1..4c84f7bc 100644 --- a/src/modules/transport/libfabric/libfabric.h +++ b/src/modules/transport/libfabric/libfabric.h @@ -346,11 +346,31 @@ typedef enum { NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_ACK, } nvshmemt_libfabric_imm_cq_data_hdr_t; +/* + * Conditional lock: skips locking when FI_THREAD_COMPLETION is active. + * + * With FI_PROGRESS_AUTO, we request FI_THREAD_COMPLETION from the provider, + * meaning the host thread and proxy thread each operate on separate endpoints + * (eps[0] vs eps[1+]), so their op_queues are disjoint and no synchronization + * is needed. With FI_PROGRESS_MANUAL, we use FI_THREAD_SAFE because + * manual_progress() iterates all EPs from both threads, requiring locking. + */ +class conditional_mutex { + std::mutex mtx; + bool needs_lock; + + public: + conditional_mutex() : needs_lock(true) {} + void set_needs_lock(bool v) { needs_lock = v; } + void lock() { if (needs_lock) mtx.lock(); } + void unlock() { if (needs_lock) mtx.unlock(); } +}; + class threadSafeOpQueue { private: - std::mutex send_mutex; - std::mutex ack_recv_mutex; - std::mutex other_recv_mutex; + conditional_mutex send_mutex; + conditional_mutex ack_recv_mutex; + conditional_mutex other_recv_mutex; std::vector send; std::deque ack_recv; std::deque other_recv; @@ -360,6 +380,13 @@ class threadSafeOpQueue { threadSafeOpQueue(const threadSafeOpQueue &) = delete; threadSafeOpQueue &operator=(const threadSafeOpQueue &) = delete; + /* Disable locking when FI_THREAD_COMPLETION keeps host/proxy EPs disjoint. */ + void set_auto_progress(bool auto_progress) { + send_mutex.set_needs_lock(!auto_progress); + ack_recv_mutex.set_needs_lock(!auto_progress); + other_recv_mutex.set_needs_lock(!auto_progress); + } + int getNextSends(void **elems, size_t num_elems = 1) { send_mutex.lock(); if (send.size() < num_elems) { From c9845880d32e263672629922af3962bb831636f8 Mon Sep 17 00:00:00 2001 From: Anshuman Goswami Date: Fri, 6 Mar 2026 21:20:15 +0000 Subject: [PATCH 09/11] transport/libfabric: Limit queue drain to configurable batch size Cap the three queue-draining loops in the proxy progress path (CQ poll, AMO ack processing, AMO processing) to a maximum number of items per iteration, preventing any single queue from starving the others. The limit is controlled by the NVSHMEM_LIBFABRIC_PROXY_REQUEST_BATCH_MAX environment variable (default: 32). Signed-off-by: Anshuman Goswami --- src/modules/transport/common/env_defs.h | 4 ++ src/modules/transport/libfabric/libfabric.cpp | 45 +++++++++++-------- src/modules/transport/libfabric/libfabric.h | 3 ++ 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/src/modules/transport/common/env_defs.h b/src/modules/transport/common/env_defs.h index f7d2cc38..654d8dfc 100644 --- a/src/modules/transport/common/env_defs.h +++ b/src/modules/transport/common/env_defs.h @@ -101,6 +101,10 @@ NVSHMEMI_ENV_DEF(LIBFABRIC_PROVIDER, string, "cxi", NVSHMEMI_ENV_CAT_TRANSPORT, NVSHMEMI_ENV_DEF(LIBFABRIC_MAX_NIC_PER_PE, int, 16, NVSHMEMI_ENV_CAT_TRANSPORT, "Set the maximum number of NIC's per PE to use for libfabric provider") +NVSHMEMI_ENV_DEF(LIBFABRIC_PROXY_REQUEST_BATCH_MAX, int, 32, NVSHMEMI_ENV_CAT_TRANSPORT, + "Maximum number of requests that the libfabric transport processes per queue " + "in a single iteration of the progress loop.") + #if defined(NVSHMEM_IBGDA_SUPPORT) || defined(NVSHMEM_ENV_ALL) /** GPU-initiated communication **/ NVSHMEMI_ENV_DEF(IBGDA_ENABLE_MULTI_PORT, bool, false, NVSHMEMI_ENV_CAT_TRANSPORT, diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index 27f74487..c06c6ebc 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -62,6 +62,7 @@ static bool use_gdrcopy = false; #endif #define MAX_COMPLETIONS_PER_CQ_POLL 300 +#define MAX_COMPLETIONS_PER_CQ_POLL_EFA 32 #define NVSHMEM_STAGED_AMO_WIREDATA_SIZE \ sizeof(nvshmemt_libfabric_gdr_op_ctx_t) - sizeof(struct fi_context2) - sizeof(fi_addr_t) @@ -228,8 +229,11 @@ static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transpo static int nvshmemt_libfabric_single_ep_progress(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t *ep) { nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)transport->state; - char buf[MAX_COMPLETIONS_PER_CQ_POLL * sizeof(struct fi_cq_data_entry)]; - fi_addr_t src_addr[MAX_COMPLETIONS_PER_CQ_POLL]; + int max_per_poll = (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) + ? MAX_COMPLETIONS_PER_CQ_POLL_EFA + : MAX_COMPLETIONS_PER_CQ_POLL; + char buf[max_per_poll * sizeof(struct fi_cq_data_entry)]; + fi_addr_t src_addr[max_per_poll]; fi_addr_t *addr; ssize_t qstatus; struct fi_cq_data_entry *entry; @@ -259,23 +263,20 @@ static int nvshmemt_libfabric_single_ep_progress(nvshmem_transport_t transport, return NVSHMEMX_ERROR_INTERNAL; } - do { - qstatus = fi_cq_readfrom(ep->cq, buf, MAX_COMPLETIONS_PER_CQ_POLL, src_addr); - /* Note - EFA provider does not support selective completions */ - if (qstatus > 0) { - if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - entry = (struct fi_cq_data_entry *)buf; - addr = src_addr; - for (int i = 0; i < qstatus; i++, entry++, addr++) { - status = nvshmemt_libfabric_gdr_process_completion(transport, ep, entry, addr); - if (status) return NVSHMEMX_ERROR_INTERNAL; - } - } else { - NVSHMEMI_WARN_PRINT("Got %zd unexpected events on EP\n", qstatus); + qstatus = fi_cq_readfrom(ep->cq, buf, max_per_poll, src_addr); + /* Note - EFA provider does not support selective completions */ + if (qstatus > 0) { + if (state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { + entry = (struct fi_cq_data_entry *)buf; + addr = src_addr; + for (int i = 0; i < qstatus; i++, entry++, addr++) { + status = nvshmemt_libfabric_gdr_process_completion(transport, ep, entry, addr); + if (status) return NVSHMEMX_ERROR_INTERNAL; } + } else { + NVSHMEMI_WARN_PRINT("Got %zd unexpected events on EP\n", qstatus); } - } while (qstatus > 0); - if (qstatus < 0 && qstatus != -FI_EAGAIN) { + } else if (qstatus < 0 && qstatus != -FI_EAGAIN) { NVSHMEMI_WARN_PRINT("Error progressing CQ (%zd): %s\n", qstatus, fi_strerror(qstatus * -1)); return NVSHMEMX_ERROR_INTERNAL; } @@ -633,6 +634,7 @@ int nvshmemt_libfabric_gdr_process_amos_ack(nvshmem_transport_t transport, int q } for (int i = qp_index; i < end_iter; i++) { + int ops_processed = 0; size_t num_retries = 0; do { @@ -645,6 +647,7 @@ int nvshmemt_libfabric_gdr_process_amos_ack(nvshmem_transport_t transport, int q num_retries = 0; if (op) { + ops_processed++; status = nvshmemt_libfabric_gdr_process_ack(transport, op); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to process atomic.\n"); @@ -653,7 +656,7 @@ int nvshmemt_libfabric_gdr_process_amos_ack(nvshmem_transport_t transport, int q NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to re-post recv.\n"); } - } while (op); + } while (op && ops_processed < libfabric_state->proxy_request_batch_max); } @@ -680,6 +683,7 @@ int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport, int qp_in } for (int i = qp_index; i < end_iter; i++) { + int ops_processed = 0; do { do { @@ -691,6 +695,7 @@ int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport, int qp_in num_retries = 0; if (op) { + ops_processed++; if (op->type == NVSHMEMT_LIBFABRIC_SEND) { assert(send_elems[0] != NULL); assert(send_elems[1] != NULL); @@ -709,7 +714,7 @@ int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport, int qp_in /* Reposts recv in perform_gdrcopy_amo() */ } } - } while (op); + } while (op && ops_processed < libfabric_state->proxy_request_batch_max); } @@ -2541,6 +2546,8 @@ int nvshmemt_init(nvshmem_transport_t *t, struct nvshmemi_cuda_fn_table *table, "Failed to initialize the libfabric state.\n"); } + libfabric_state->proxy_request_batch_max = options.LIBFABRIC_PROXY_REQUEST_BATCH_MAX; + if (use_auto_progress) transport->host_ops.progress = nvshmemt_libfabric_auto_proxy_progress; else diff --git a/src/modules/transport/libfabric/libfabric.h b/src/modules/transport/libfabric/libfabric.h index 4c84f7bc..502c45d5 100644 --- a/src/modules/transport/libfabric/libfabric.h +++ b/src/modules/transport/libfabric/libfabric.h @@ -555,6 +555,9 @@ typedef struct { /* Signal ordering state */ nvshmemt_libfabric_signal_state_t host_signal_state; nvshmemt_libfabric_signal_state_t proxy_signal_state; + + /* Max ops per progress iteration */ + int proxy_request_batch_max; } nvshmemt_libfabric_state_t; typedef struct { From e834424b0ffbc6c44e290e8b8fc428b039176bf5 Mon Sep 17 00:00:00 2001 From: Amit Radzi Date: Sat, 14 Mar 2026 00:57:45 +0000 Subject: [PATCH 10/11] transport/libfabric: Replace counter-based completion tracking with CQ-based tracking Remove fi_cntr usage from the libfabric transport and track operation completions entirely through completion queue (CQ) entries. The endpoint now maintains a completed_ops counter that is incremented in the CQ completion handler, and quiet polls the CQ instead of reading or waiting on a counter. This eliminates the overhead of counter reads in the progress. The error handling path is also moved from a counter-error pre-check into the existing CQ error path (fi_cq_readerr), consolidating error reporting. Signed-off-by: Amit Radzi --- src/modules/transport/libfabric/libfabric.cpp | 113 +++++------------- src/modules/transport/libfabric/libfabric.h | 2 +- 2 files changed, 34 insertions(+), 81 deletions(-) diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index c06c6ebc..003e1301 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -204,9 +204,11 @@ static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transpo if (entry->flags & FI_SEND) { state->op_queue[ep->domain_index]->putToSend(op); + ep->completed_ops++; } else if (entry->flags & FI_RMA) { /* inlined p ops or atomic responses */ state->op_queue[ep->domain_index]->putToSend(op); + ep->completed_ops++; } else if ((op->type == NVSHMEMT_LIBFABRIC_MATCH) && (entry->flags & FI_RECV)) { /* Must happen after entry->flags & FI_SEND to avoid send completions */ status = nvshmemt_libfabric_put_signal_completion(transport, ep, entry, addr); @@ -237,31 +239,8 @@ static int nvshmemt_libfabric_single_ep_progress(nvshmem_transport_t transport, fi_addr_t *addr; ssize_t qstatus; struct fi_cq_data_entry *entry; - uint64_t cnt; int status = 0; - - cnt = fi_cntr_readerr(ep->counter); - if (cnt > 0) { - NVSHMEMI_WARN_PRINT("Nonzero error count progressing EP (%" PRIu64 ")\n", cnt); - struct fi_cq_err_entry err; - memset(&err, 0, sizeof(struct fi_cq_err_entry)); - ssize_t nerr = fi_cq_readerr(ep->cq, &err, 0); - - if (nerr > 0) { - char str[100] = "\0"; - const char *err_str = fi_cq_strerror(ep->cq, err.prov_errno, err.err_data, str, 100); - NVSHMEMI_WARN_PRINT( - "CQ reported error (%d): %s\n\tProvider error: %s\n\tSupplemental error " - "info: %s\n", - err.err, fi_strerror(err.err), err_str ? err_str : "none", - strlen(str) ? str : "none"); - } else if (nerr == -FI_EAGAIN) { - NVSHMEMI_WARN_PRINT("fi_cq_readerr returned -FI_EAGAIN\n"); - } else { - NVSHMEMI_WARN_PRINT("fi_cq_readerr returned %zd: %s\n", nerr, fi_strerror(-1 * nerr)); - } - return NVSHMEMX_ERROR_INTERNAL; - } + int ret = 0; qstatus = fi_cq_readfrom(ep->cq, buf, max_per_poll, src_addr); /* Note - EFA provider does not support selective completions */ @@ -277,7 +256,22 @@ static int nvshmemt_libfabric_single_ep_progress(nvshmem_transport_t transport, NVSHMEMI_WARN_PRINT("Got %zd unexpected events on EP\n", qstatus); } } else if (qstatus < 0 && qstatus != -FI_EAGAIN) { - NVSHMEMI_WARN_PRINT("Error progressing CQ (%zd): %s\n", qstatus, fi_strerror(qstatus * -1)); + /* On call to fi_cq_readerr, Libfabric requires some members of + * err_entry to be zero-initialized or point to valid data. For + * simplicity, just zero out the whole struct. + */ + struct fi_cq_err_entry err_entry = {}; + + ret = fi_cq_readerr(ep->cq, &err_entry, 0); + if (ret == -FI_EAGAIN) { + return 0; + } else if (ret < 0) { + NVSHMEMI_WARN_PRINT("Unable to read from fi_cq_readerr. RC: %d. Error: %s\n", ret, fi_strerror(-ret)); + return NVSHMEMX_ERROR_INTERNAL; + } + + NVSHMEMI_WARN_PRINT("Received a CQE with error. RC: %d. Error: %d (%s)", err_entry.err, err_entry.prov_errno, + fi_cq_strerror(ep->cq, err_entry.prov_errno, err_entry.err_data, NULL, 0)); return NVSHMEMX_ERROR_INTERNAL; } @@ -875,7 +869,6 @@ int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, static int nvshmemt_libfabric_quiet(struct nvshmem_transport *tcurr, int pe, int qp_index) { nvshmemt_libfabric_state_t *state = (nvshmemt_libfabric_state_t *)tcurr->state; - uint64_t completed; bool all_nics_quieted; int status = 0; int end_iter; @@ -887,45 +880,21 @@ static int nvshmemt_libfabric_quiet(struct nvshmem_transport *tcurr, int pe, int qp_index = NVSHMEMT_LIBFABRIC_PROXY_EP_IDX; } - if (use_staged_atomics) { - for (;;) { - all_nics_quieted = true; - for (int i = qp_index; i < end_iter; i++) { - - /* Quick out if the endpoint is still quiet since last time */ - if (state->eps[i]->submitted_ops == state->eps[i]->completed_ctr) { - continue; - } - - completed = fi_cntr_read(state->eps[i]->counter) + - state->eps[i]->completed_staged_atomics; - state->eps[i]->completed_ctr = completed; - - if (state->eps[i]->submitted_ops != completed) { - all_nics_quieted = false; - if (nvshmemt_libfabric_progress(tcurr, qp_index)) { - status = NVSHMEMX_ERROR_INTERNAL; - break; - } - } - } - if (status || all_nics_quieted) break; - } - } else { + for (;;) { + all_nics_quieted = true; for (int i = qp_index; i < end_iter; i++) { - status = fi_cntr_wait(state->eps[i]->counter, state->eps[i]->submitted_ops, - NVSHMEMT_LIBFABRIC_QUIET_TIMEOUT_MS); - if (status) { - /* note - Status is negative for this function in error cases but - * fi_strerror only accepts positive values. - */ - NVSHMEMI_ERROR_PRINT("Error in quiet operation (%d): %s.\n", status, - fi_strerror(status * -1)); - status = NVSHMEMX_ERROR_INTERNAL; + if (state->eps[i]->submitted_ops != state->eps[i]->completed_ops) { + all_nics_quieted = false; + if (nvshmemt_libfabric_progress(tcurr, qp_index)) { + status = NVSHMEMX_ERROR_INTERNAL; + break; + } } } + if (status || all_nics_quieted) break; } + return status; } @@ -1104,7 +1073,6 @@ static int nvshmemt_libfabric_rma(struct nvshmem_transport *tcurr, int pe, rma_v if (seq_counter.put_count >= NVSHMEM_STAGED_AMO_PUT_ACK_FREQ) { header = NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT_WITH_ACK_REQ; seq_counter.put_count = 0; - ep->submitted_ops++; // Account for incoming ack } else { header = NVSHMEMT_LIBFABRIC_IMM_STANDALONE_PUT; } @@ -1185,7 +1153,7 @@ static int nvshmemt_libfabric_gdr_amo(struct nvshmem_transport *transport, int p NVSHMEMI_ERROR_PRINT("Received an error when trying to post an AMO operation.\n"); status = NVSHMEMX_ERROR_INTERNAL; } else { - ep->submitted_ops += 2; + ep->submitted_ops++; } out: @@ -1383,7 +1351,7 @@ static int nvshmemt_libfabric_gdr_signal(struct nvshmem_transport *transport, in NVSHMEMI_ERROR_PRINT("Received an error when trying to post a signal operation.\n"); status = NVSHMEMX_ERROR_INTERNAL; } else { - ep->submitted_ops += 2; + ep->submitted_ops++; } out: @@ -1742,7 +1710,6 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele struct fid_mr *mr; struct fi_av_attr av_attr; struct fi_cq_attr cq_attr; - struct fi_cntr_attr cntr_attr; size_t ep_namelen = NVSHMEMT_LIBFABRIC_EP_LEN; int status = 0; int total_num_eps; @@ -1788,10 +1755,6 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele av_attr.type = FI_AV_TABLE; av_attr.count = state->num_selected_domains * n_pes; - memset(&cntr_attr, 0, sizeof(struct fi_cntr_attr)); - cntr_attr.events = FI_CNTR_EVENTS_COMP; - cntr_attr.wait_obj = FI_WAIT_UNSPEC; - /* Find fabric info for each selected device */ for (int dev_idx = 0; dev_idx < state->num_selected_devs; dev_idx++) { current_info = state->all_prov_info; @@ -1907,17 +1870,14 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele state->eps[i]->domain_index = i; state->eps[i]->completed_staged_atomics = 0; + state->eps[i]->submitted_ops = 0; + state->eps[i]->completed_ops = 0; status = fi_cq_open(domain, &cq_attr, &state->eps[i]->cq, NULL); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to open completion queue for endpoint: %d: %s\n", status, fi_strerror(status * -1)); - status = fi_cntr_open(domain, &cntr_attr, &state->eps[i]->counter, NULL); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to open counter for endpoint: %d: %s\n", status, - fi_strerror(status * -1)); - status = fi_endpoint(domain, state->prov_infos[i], &state->eps[i]->endpoint, NULL); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to allocate endpoint: %d: %s\n", status, @@ -1966,13 +1926,6 @@ static int nvshmemt_libfabric_connect_endpoints(nvshmem_transport_t t, int *sele "Unable to bind endpoint to completion queue: %d: %s\n", status, fi_strerror(status * -1)); - flags = FI_READ | FI_WRITE; - if (use_staged_atomics) flags |= FI_SEND; - status = fi_ep_bind(state->eps[i]->endpoint, &state->eps[i]->counter->fid, flags); - NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, - "Unable to bind endpoint to completion counter: %d: %s\n", status, - fi_strerror(status * -1)); - status = fi_enable(state->eps[i]->endpoint); NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, "Unable to enable endpoint: %d: %s\n", status, diff --git a/src/modules/transport/libfabric/libfabric.h b/src/modules/transport/libfabric/libfabric.h index 502c45d5..392516a5 100644 --- a/src/modules/transport/libfabric/libfabric.h +++ b/src/modules/transport/libfabric/libfabric.h @@ -259,7 +259,7 @@ typedef struct { struct fid_cq *cq; struct fid_cntr *counter; uint64_t submitted_ops; - uint64_t completed_ctr; + uint64_t completed_ops; uint64_t completed_staged_atomics; int domain_index; } nvshmemt_libfabric_endpoint_t; From 9a630ee2198d41cfebfdd6896d1ce2c28ade25b2 Mon Sep 17 00:00:00 2001 From: Amit Radzi Date: Sat, 14 Mar 2026 01:00:12 +0000 Subject: [PATCH 11/11] transport/libfabric: Process AMO acks inline in CQ completion handler Move AMO ack processing out of the batched op_queue drain (nvshmemt_libfabric_gdr_process_amos_ack) and into the CQ completion handler directly. When an ACK recv completion arrives, the ack is processed and the recv buffer is re-posted immediately, eliminating the intermediate queue and the separate per-progress-call drain loop. This removes a level of indirection and the associated queue put/get overhead from the AMO ack path, and removes the nvshmemt_libfabric_gdr_process_amos_ack call from the main progress function. Signed-off-by: Amit Radzi --- src/modules/transport/libfabric/libfabric.cpp | 68 +++++++++---------- 1 file changed, 33 insertions(+), 35 deletions(-) diff --git a/src/modules/transport/libfabric/libfabric.cpp b/src/modules/transport/libfabric/libfabric.cpp index 003e1301..80b362b0 100644 --- a/src/modules/transport/libfabric/libfabric.cpp +++ b/src/modules/transport/libfabric/libfabric.cpp @@ -93,7 +93,6 @@ typedef enum { } nvshmemt_libfabric_try_again_call_site_t; int nvshmemt_libfabric_gdr_process_amos(nvshmem_transport_t transport, int qp_index); -int nvshmemt_libfabric_gdr_process_amos_ack(nvshmem_transport_t transport, int qp_index); int nvshmemt_libfabric_put_signal_completion(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t *ep, struct fi_cq_data_entry *entry, fi_addr_t *addr); @@ -169,6 +168,31 @@ static inline bool is_signal_only_op(nvshmemi_amo_t op) { op == NVSHMEMI_AMO_SIGNAL_ADD); } +inline int nvshmemt_libfabric_gdr_process_ack(nvshmem_transport_t transport, + nvshmemt_libfabric_gdr_op_ctx_t *op) { + nvshmemt_libfabric_gdr_ret_amo_op_t *ret = &op->ret_amo; + nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; + nvshmemt_libfabric_memhandle_info_t *handle_info; + g_elem_t *elem; + void *valid_cpu_ptr; + + handle_info = (nvshmemt_libfabric_memhandle_info_t *)nvshmemt_mem_handle_cache_get( + transport, libfabric_state->cache, ret->ret_addr); + if (!handle_info) { + NVSHMEMI_ERROR_PRINT("Unable to get handle info for atomic response.\n"); + return NVSHMEMX_ERROR_INTERNAL; + } + + valid_cpu_ptr = + (void *)((char *)handle_info->cpu_ptr + ((char *)ret->ret_addr - (char *)handle_info->ptr)); + assert(valid_cpu_ptr); + elem = (g_elem_t *)valid_cpu_ptr; + elem->data = ret->elem.data; + elem->flag = ret->elem.flag; + + return 0; +} + static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transport, nvshmemt_libfabric_endpoint_t *ep, struct fi_cq_data_entry *entry, @@ -215,7 +239,14 @@ static int nvshmemt_libfabric_gdr_process_completion(nvshmem_transport_t transpo } else if (entry->flags & FI_RECV) { op->ep = ep; if (op->type == NVSHMEMT_LIBFABRIC_ACK) { - state->op_queue[ep->domain_index]->putToRecv(op, NVSHMEMT_LIBFABRIC_RECV_TYPE_ACK); + status = nvshmemt_libfabric_gdr_process_ack(transport, op); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to process atomic.\n"); + + status = fi_recv(op->ep->endpoint, (void *)op, NVSHMEM_STAGED_AMO_WIREDATA_SIZE, + fi_mr_desc(state->mr[op->ep->domain_index]), FI_ADDR_UNSPEC, &op->ofi_context); + NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out, + "Unable to re-post recv.\n"); } else { state->op_queue[ep->domain_index]->putToRecv(op, NVSHMEMT_LIBFABRIC_RECV_TYPE_NOT_ACK); } @@ -326,15 +357,6 @@ static int nvshmemt_libfabric_process_completions(nvshmem_transport_t transport, else status = nvshmemt_libfabric_manual_progress(transport); - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; - if (libfabric_state->provider == NVSHMEMT_LIBFABRIC_PROVIDER_EFA) { - int progress_qp_index = (use_auto_progress ? qp_index : NVSHMEMX_QP_ALL); - status = nvshmemt_libfabric_gdr_process_amos_ack(transport, progress_qp_index); - if (status) { - return NVSHMEMX_ERROR_INTERNAL; - } - } - return status; } @@ -586,30 +608,6 @@ int nvshmemt_libfabric_gdr_process_amo(nvshmem_transport_t transport, return status; } -int nvshmemt_libfabric_gdr_process_ack(nvshmem_transport_t transport, - nvshmemt_libfabric_gdr_op_ctx_t *op) { - nvshmemt_libfabric_gdr_ret_amo_op_t *ret = &op->ret_amo; - nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; - nvshmemt_libfabric_memhandle_info_t *handle_info; - g_elem_t *elem; - void *valid_cpu_ptr; - - handle_info = (nvshmemt_libfabric_memhandle_info_t *)nvshmemt_mem_handle_cache_get( - transport, libfabric_state->cache, ret->ret_addr); - if (!handle_info) { - NVSHMEMI_ERROR_PRINT("Unable to get handle info for atomic response.\n"); - return NVSHMEMX_ERROR_INTERNAL; - } - - valid_cpu_ptr = - (void *)((char *)handle_info->cpu_ptr + ((char *)ret->ret_addr - (char *)handle_info->ptr)); - assert(valid_cpu_ptr); - elem = (g_elem_t *)valid_cpu_ptr; - elem->data = ret->elem.data; - elem->flag = ret->elem.flag; - return 0; -} - int nvshmemt_libfabric_gdr_process_amos_ack(nvshmem_transport_t transport, int qp_index) { nvshmemt_libfabric_state_t *libfabric_state = (nvshmemt_libfabric_state_t *)transport->state; nvshmemt_libfabric_gdr_op_ctx_t *op;