diff --git a/comms/ctran/algos/AllGatherP/AlgoImpl.h b/comms/ctran/algos/AllGatherP/AlgoImpl.h index d0b04033d..2fb0832c1 100644 --- a/comms/ctran/algos/AllGatherP/AlgoImpl.h +++ b/comms/ctran/algos/AllGatherP/AlgoImpl.h @@ -52,6 +52,9 @@ class AlgoImpl { } } + // Allocate pipeSync and other internal resources. + commResult_t initResources(); + private: // Wait till either the async initialization is done or hit async error. // It is called before execution scheduling any CE copy to the stream. diff --git a/comms/ctran/algos/AllGatherP/AllGatherP.cc b/comms/ctran/algos/AllGatherP/AllGatherP.cc index dacb0c545..5d04720d4 100644 --- a/comms/ctran/algos/AllGatherP/AllGatherP.cc +++ b/comms/ctran/algos/AllGatherP/AllGatherP.cc @@ -68,6 +68,16 @@ extern __global__ void ncclKernelAllGatherPInit( int* flag, CtranAlgoDeviceState* devState); +commResult_t AlgoImpl::initResources() { + void* base = nullptr; + FB_CUDACHECK( + cudaHostAlloc(&base, sizeof(GpeKernelSync), cudaHostAllocDefault)); + + resource_.pipeSync = reinterpret_cast(base); + new (resource_.pipeSync) GpeKernelSync(1 /* numWorkers */); + return commSuccess; +} + commResult_t AlgoImpl::initialize() { auto opCount = comm_->ctran_->getOpCount(); CTRAN_COLL_INFO( @@ -80,12 +90,7 @@ commResult_t AlgoImpl::initialize() { comm_, stream_); - void* base = nullptr; - FB_CUDACHECK( - cudaHostAlloc(&base, sizeof(GpeKernelSync), cudaHostAllocDefault)); - - resource_.pipeSync = reinterpret_cast(base); - new (resource_.pipeSync) GpeKernelSync(1 /* numWorkers */); + FB_COMMCHECK(initResources()); KernelConfig config = KernelConfig( KernelConfig::KernelType::ALLGATHERP_INIT, diff --git a/comms/ncclx/v2_28/meta/rma/tests/RMATest.cc b/comms/ncclx/v2_28/meta/rma/tests/RMATest.cc index fa2e13c54..e6cc5d3d1 100644 --- a/comms/ncclx/v2_28/meta/rma/tests/RMATest.cc +++ b/comms/ncclx/v2_28/meta/rma/tests/RMATest.cc @@ -23,12 +23,7 @@ class RMATest : public NcclxBaseTest { NcclxBaseTest::SetUp(); this->comm = createNcclComm( - this->globalRank, - this->numRanks, - this->localRank, - false, - nullptr, - server.get()); + this->globalRank, this->numRanks, this->localRank, bootstrap_.get()); ASSERT_NE(this->comm, nullptr); } void TearDown() override { @@ -174,6 +169,47 @@ TEST_P(MultiWindowTestParam, multiWindow) { class RMATestParam : public RMATest, public ::testing::WithParamInterface< std::tuple> { + protected: + // Share a single communicator across all parameterized test cases. + // createNcclComm is expensive in multi-node configs (cross-node transport + // setup via socket bootstrap). Reusing it avoids per-test-case overhead. + static inline ncclComm_t shared_comm_ = nullptr; + + void SetUp() override { + setenv("NCCL_CTRAN_ENABLE", "1", 0); + setenv("NCCL_CTRAN_IB_EPOCH_LOCK_ENFORCE_CHECK", "true", 0); + NcclxBaseTest::SetUp(); + + if (shared_comm_ == nullptr) { + shared_comm_ = createNcclComm( + this->globalRank, + this->numRanks, + this->localRank, + bootstrap_.get()); + ASSERT_NE(shared_comm_, nullptr); + } + this->comm = shared_comm_; + } + + void TearDown() override { + // Barrier to sync all ranks before moving to the next test case. + // Ensures no rank starts TearDownTestSuite (ncclCommDestroy) while + // another is still using the shared comm. + if (shared_comm_ != nullptr) { + this->barrier(shared_comm_, nullptr); + } + // Don't destroy the shared comm — TearDownTestSuite handles it. + this->comm = nullptr; + EXPECT_TRUE(segments.empty()) << "Not all memory segments were freed"; + NcclxBaseTest::TearDown(); + } + + static void TearDownTestSuite() { + if (shared_comm_ != nullptr) { + ncclCommDestroy(shared_comm_); + shared_comm_ = nullptr; + } + } }; TEST_P(RMATestParam, winPutWait) { @@ -505,7 +541,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine( // kNumElements, kNumIters, ctranAllReduce, bufType, userBuf ::testing::Values(8192, 8 * 1024 * 1024), - ::testing::Values(500), + ::testing::Values(50), ::testing::Values(true, false), ::testing::Values( MemAllocType::kMemNcclMemAlloc, @@ -537,7 +573,20 @@ INSTANTIATE_TEST_SUITE_P( class NvlEnabledTestParam : public RMATest, public ::testing::WithParamInterface< - std::tuple, bool>> {}; + std::tuple, bool>> { + protected: + // Skip fixture comm creation — this test creates its own comm with + // parameterized backends and never uses the fixture's this->comm. + void SetUp() override { + setenv("NCCL_CTRAN_ENABLE", "1", 0); + setenv("NCCL_CTRAN_IB_EPOCH_LOCK_ENFORCE_CHECK", "true", 0); + NcclxBaseTest::SetUp(); + } + void TearDown() override { + EXPECT_TRUE(segments.empty()) << "Not all memory segments were freed"; + NcclxBaseTest::TearDown(); + } +}; TEST_P(NvlEnabledTestParam, ncclWinGetAttributes) { const auto& [backends, expectNvlEnabled] = GetParam(); @@ -548,9 +597,7 @@ TEST_P(NvlEnabledTestParam, ncclWinGetAttributes) { this->globalRank, this->numRanks, this->localRank, - false, - nullptr, - server.get()); + bootstrap_.get()); ASSERT_NE(comm, nullptr); auto statex = comm->ctranComm_->statex_.get();