Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions comms/ctran/Ctran.h
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,24 @@ commResult_t allGatherPExec(
commResult_t allGatherPDestroy(CtranPersistentRequest* request);
bool allGatherPSupport(CtranComm* comm);

/* Window-based allgather using the same CE+IB algorithm as allGatherP.
* Buffer metadata is sourced from CtranWin (post-exchange) instead of
* PersistArgs. The window must have been allocated and exchanged before init.
*/
// Creates the persistent request for a window-based allgather.
// Returns commInvalidArgument when the window's remote info does not hold one
// entry per rank of `comm` (i.e. the window was not exchanged). On success
// `request` points to a newly allocated CtranPersistentRequest.
commResult_t allGatherWinInit(
CtranWin* win,
CtranComm* comm,
cudaStream_t stream,
CtranPersistentRequest*& request);

// Runs one allgather of `count` elements of `datatype` from `sendbuff` using
// a request created by allGatherWinInit. The algorithm variant (ctdirect or
// ctpipeline) is selected by NCCL_ALLGATHER_P_ALGO. Returns
// commInvalidArgument for a null request or one of a different type.
commResult_t allGatherWinExec(
const void* sendbuff,
const size_t count,
commDataType_t datatype,
CtranPersistentRequest* request);

// Destroys the algorithm state attached to `request`. The request object
// itself is not deleted here; the caller remains responsible for it.
commResult_t allGatherWinDestroy(CtranPersistentRequest* request);

// All array in/out arguments are merely pointers without values at init time;
// the values will be updated at execution.
commResult_t allToAllvDedupInit(
Expand Down
3 changes: 3 additions & 0 deletions comms/ctran/algos/AllGatherP/AlgoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class AlgoImpl {
}
}

// Allocate pipeSync and other internal resources.
commResult_t initResources();

private:
// Wait till either the async initialization is done or hit async error.
// It is called before execution scheduling any CE copy to the stream.
Expand Down
17 changes: 11 additions & 6 deletions comms/ctran/algos/AllGatherP/AllGatherP.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ extern __global__ void ncclKernelAllGatherPInit(
int* flag,
CtranAlgoDeviceState* devState);

// Allocates the host memory backing resource_.pipeSync with cudaHostAlloc and
// constructs a GpeKernelSync with a single worker in it.
// Returns commSuccess, or propagates the failure reported by FB_CUDACHECK.
commResult_t AlgoImpl::initResources() {
  void* hostMem = nullptr;
  FB_CUDACHECK(
      cudaHostAlloc(&hostMem, sizeof(GpeKernelSync), cudaHostAllocDefault));

  // Placement-new hands back the typed pointer to the object it constructed,
  // so no separate cast of the raw allocation is needed.
  resource_.pipeSync = new (hostMem) GpeKernelSync(1 /* numWorkers */);
  return commSuccess;
}

commResult_t AlgoImpl::initialize() {
auto opCount = comm_->ctran_->getOpCount();
CTRAN_COLL_INFO(
Expand All @@ -80,12 +90,7 @@ commResult_t AlgoImpl::initialize() {
comm_,
stream_);

void* base = nullptr;
FB_CUDACHECK(
cudaHostAlloc(&base, sizeof(GpeKernelSync), cudaHostAllocDefault));

resource_.pipeSync = reinterpret_cast<GpeKernelSync*>(base);
new (resource_.pipeSync) GpeKernelSync(1 /* numWorkers */);
FB_COMMCHECK(initResources());

KernelConfig config = KernelConfig(
KernelConfig::KernelType::ALLGATHERP_INIT,
Expand Down
111 changes: 111 additions & 0 deletions comms/ctran/algos/AllGatherP/WinAllGather.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.

#include "comms/ctran/CtranComm.h"
#include "comms/ctran/algos/AllGatherP/AlgoImpl.h"
#include "comms/ctran/algos/CtranAlgo.h"
#include "comms/ctran/utils/Checks.h"
#include "comms/ctran/window/CtranWin.h"
#include "comms/utils/cvars/nccl_cvars.h"

using ctran::allgatherp::AlgoImpl;

// Validates a persistent request at the top of a window-allgather entry
// point. Makes the enclosing function return commInvalidArgument when the
// request is null or its type is not ALLGATHER_P_WIN.
// The macro argument is parenthesized at every use so that an arbitrary
// pointer expression (not just a plain identifier) expands correctly.
#define CHECK_VALID_PREQ(pReq)                                           \
  do {                                                                   \
    if (!(pReq)) {                                                       \
      FB_ERRORRETURN(                                                    \
          commInvalidArgument,                                           \
          "Null PersistentRequest passed to {}",                         \
          __func__);                                                     \
    }                                                                    \
    if ((pReq)->type != CtranPersistentRequest::Type::ALLGATHER_P_WIN) { \
      FB_ERRORRETURN(                                                    \
          commInvalidArgument,                                           \
          "Unexpected PersistentRequest type {} called into {}",         \
          (pReq)->type,                                                  \
          __func__);                                                     \
    }                                                                    \
  } while (0)

namespace ctran {

// Builds a persistent allgather request whose receive-side metadata comes from
// an already exchanged window.
//
// win     - window providing the local data pointer/registration handle and
//           the per-rank remote address and key (remWinInfo).
// comm    - communicator the request is bound to; its rank count must match
//           the number of remWinInfo entries.
// stream  - stream stored in the algorithm state and in the request.
// request - out: set to the newly allocated request on success; left
//           untouched on failure.
//
// Returns commInvalidArgument for a null `win`/`comm` or when remWinInfo is
// empty or does not hold one entry per rank; otherwise the result of
// AlgoImpl::initResources(), or commSuccess.
commResult_t allGatherWinInit(
    CtranWin* win,
    CtranComm* comm,
    cudaStream_t stream,
    CtranPersistentRequest*& request) {
  // Reject null handles up front instead of dereferencing them below,
  // consistent with how the Exec/Destroy entry points reject a null request.
  if (win == nullptr || comm == nullptr) {
    FB_ERRORRETURN(
        commInvalidArgument,
        "Null {} passed to {}",
        win == nullptr ? "window" : "communicator",
        __func__);
  }

  const auto statex = comm->statex_.get();
  const auto nRanks = statex->nRanks();

  if (win->remWinInfo.empty() ||
      static_cast<int>(win->remWinInfo.size()) != nRanks) {
    FB_ERRORRETURN(
        commInvalidArgument,
        "Window remWinInfo not populated (size {}). "
        "Was exchange() called?",
        win->remWinInfo.size());
  }

  // Owned by unique_ptr until handed to the request, so an early return from
  // initResources() does not leak the algorithm object.
  auto algo = std::make_unique<AlgoImpl>(comm, stream);

  // Populate pArgs from window remote info
  algo->pArgs.recvbuff = win->winDataPtr;
  algo->pArgs.recvHdl = win->dataRegHdl;
  algo->pArgs.remoteRecvBuffs.resize(nRanks);
  algo->pArgs.remoteAccessKeys.resize(nRanks);
  for (int r = 0; r < nRanks; r++) {
    algo->pArgs.remoteRecvBuffs[r] = win->remWinInfo[r].dataAddr;
    algo->pArgs.remoteAccessKeys[r] = win->remWinInfo[r].dataRkey;
  }
  // Window already exchanged remote info, mark as initialized
  algo->pArgs.initialized.store(true);

  FB_COMMCHECK(algo->initResources());

  request = new CtranPersistentRequest(
      CtranPersistentRequest::Type::ALLGATHER_P_WIN, comm, stream);
  request->algo = algo.release();

  return commSuccess;
}

// Executes one window-based allgather with a request created by
// allGatherWinInit.
//
// sendbuff - source buffer holding this rank's contribution.
// count    - number of elements of `datatype` to send.
// datatype - element type.
// request  - persistent request of type ALLGATHER_P_WIN.
//
// Returns commInvalidArgument when the request is null, has the wrong type,
// or no longer carries algorithm state (e.g. it was already passed to
// allGatherWinDestroy). Otherwise returns the result of the variant selected
// by NCCL_ALLGATHER_P_ALGO; an unrecognized value is reported as
// commInternalError through ErrorStackTraceUtil::log.
commResult_t allGatherWinExec(
    const void* sendbuff,
    const size_t count,
    commDataType_t datatype,
    CtranPersistentRequest* request) {
  CHECK_VALID_PREQ(request);

  auto* algo = reinterpret_cast<AlgoImpl*>(request->algo);
  // allGatherWinDestroy resets request->algo to nullptr; guard against
  // executing on such a request instead of dereferencing a null pointer.
  if (algo == nullptr) {
    FB_ERRORRETURN(
        commInvalidArgument,
        "PersistentRequest passed to {} has no algorithm state",
        __func__);
  }

  switch (NCCL_ALLGATHER_P_ALGO) {
    case NCCL_ALLGATHER_P_ALGO::ctdirect:
      return algo->execDirect(sendbuff, count, datatype);
    case NCCL_ALLGATHER_P_ALGO::ctpipeline:
      return algo->execPipeline(sendbuff, count, datatype);
    default:
      return ErrorStackTraceUtil::log(commInternalError);
  }
}

// Tears down the algorithm state attached to a window-allgather request.
// A request whose algorithm state is already null is treated as a no-op and
// returns commSuccess. The request object itself is not deleted here.
commResult_t allGatherWinDestroy(CtranPersistentRequest* request) {
  CHECK_VALID_PREQ(request);

  auto* impl = reinterpret_cast<AlgoImpl*>(request->algo);
  if (impl != nullptr) {
    // A failing destroy() returns early and leaves request->algo in place.
    FB_COMMCHECK(impl->destroy());
    delete impl;
    request->algo = nullptr;

    CLOGF_SUBSYS(
        INFO,
        INIT,
        "allGatherWinDestroy: rank {} destroyed request {}",
        request->comm_->statex_->rank(),
        static_cast<void*>(request));
  }

  return commSuccess;
}

} // namespace ctran
1 change: 1 addition & 0 deletions comms/ctran/algos/CtranAlgo.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ class CtranPersistentRequest {
ALLTOALL_DEDUP,
ALLTOALL_P,
ALLTOALLV_DEDUP,
ALLGATHER_P_WIN,
};

Type type;
Expand Down
165 changes: 165 additions & 0 deletions comms/ctran/tests/CtranWinAllGatherTest.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <folly/init/Init.h>

#include "comms/ctran/Ctran.h"
#include "comms/ctran/tests/CtranDistTestUtils.h"
#include "comms/ctran/window/CtranWin.h"
#include "comms/utils/cvars/nccl_cvars.h"

using namespace ctran;

// Distributed test fixture for the window-based allgather
// (allGatherWinInit / allGatherWinExec / allGatherWinDestroy).
class CtranWinAllGatherTest : public ctran::CtranDistTestFixture {
 public:
  CtranWinAllGatherTest() = default;

 protected:
  void SetUp() override {
    // Overwrite flag 0: an NCCL_CTRAN_ENABLE value already present in the
    // environment is kept.
    setenv("NCCL_CTRAN_ENABLE", "1", 0);
    CtranDistTestFixture::SetUp();
  }

  void TearDown() override {
    CtranDistTestFixture::TearDown();
  }

  // Copies each rank's chunk of the device receive buffer to the host and
  // checks that every element equals the value rank `r` sent in iteration
  // `iter` (r * 100 + iter).
  void verifyAllGather(
      void* recvBuf,
      size_t sendCount,
      size_t sendBytes,
      int nRanks,
      int myRank,
      int iter) {
    for (int r = 0; r < nRanks; r++) {
      std::vector<float> observed(sendCount);
      CUDACHECK_TEST(cudaMemcpy(
          observed.data(),
          static_cast<char*>(recvBuf) + r * sendBytes,
          sendBytes,
          cudaMemcpyDeviceToHost));

      const float expected = static_cast<float>(r * 100 + iter);
      for (size_t i = 0; i < sendCount; i++) {
        EXPECT_EQ(observed[i], expected)
            << "rank " << myRank << " iter " << iter << " chunk from rank " << r
            << " element " << i;
      }
    }
  }

  // Registers a device buffer as a window, runs `nIter` allgathers of
  // `sendCount` floats with the algorithm named by `algoStr`, verifies the
  // gathered data after each iteration, and releases all resources.
  // Skips the test when the window reports that persistent collectives are
  // not supported.
  void run(size_t sendCount, const std::string& algoStr) {
    SysEnvRAII algoEnv("NCCL_ALLGATHER_P_ALGO", algoStr);

    auto comm = makeCtranComm();
    ASSERT_NE(comm, nullptr);

    auto statex = comm->statex_.get();
    ASSERT_NE(statex, nullptr);

    const auto nRanks = statex->nRanks();
    const auto myRank = statex->rank();
    const commDataType_t dt = commFloat;
    const size_t sendBytes = sendCount * commTypeSize(dt);
    const size_t recvBytes = sendBytes * nRanks;

    cudaStream_t stream;
    CUDACHECK_TEST(cudaStreamCreate(&stream));

    // Allocate recv buffer and register it with a window
    void* winBase = nullptr;
    CUDACHECK_TEST(cudaMalloc(&winBase, recvBytes));
    CtranWin* win = nullptr;
    auto res = ctranWinRegister(winBase, recvBytes, comm.get(), &win);
    ASSERT_EQ(res, commSuccess);

    if (!win->persistentColSupported()) {
      // Release the window before the buffer it was registered on, matching
      // the teardown order used at the end of this function.
      ASSERT_EQ(ctranWinFree(win), commSuccess);
      CUDACHECK_TEST(cudaFree(winBase));
      CUDACHECK_TEST(cudaStreamDestroy(stream));
      GTEST_SKIP()
          << "Window persistent collectives not supported on this topology";
    }

    // Allocate separate send buffer
    void* sendbuf = nullptr;
    CUDACHECK_TEST(cudaMalloc(&sendbuf, sendBytes));

    // Initialize allgather state from the window
    CtranPersistentRequest* request = nullptr;
    ASSERT_EQ(
        ctran::allGatherWinInit(win, comm.get(), stream, request), commSuccess);
    ASSERT_NE(request, nullptr);

    // Sync stream to ensure any init work is complete
    CUDACHECK_TEST(cudaStreamSynchronize(stream));

    constexpr int nIter = 3;
    for (int iter = 0; iter < nIter; iter++) {
      // Fill sendbuf with rank+iter specific values
      const float sendVal = static_cast<float>(myRank * 100 + iter);
      std::vector<float> sendVals(sendCount, sendVal);
      CUDACHECK_TEST(cudaMemcpyAsync(
          sendbuf, sendVals.data(), sendBytes, cudaMemcpyHostToDevice, stream));

      // Clear recvbuf (window data buffer)
      CUDACHECK_TEST(cudaMemsetAsync(winBase, 0, recvBytes, stream));

      ASSERT_EQ(
          ctran::allGatherWinExec(sendbuf, sendCount, dt, request),
          commSuccess);
      // The stream sync also keeps sendVals alive until the async copy from
      // it has completed.
      CUDACHECK_TEST(cudaStreamSynchronize(stream));

      verifyAllGather(winBase, sendCount, sendBytes, nRanks, myRank, iter);
    }

    // Verify no GPE resource leak
    ASSERT_EQ(comm->ctran_->gpe->numInUseKernelElems(), 0);
    ASSERT_EQ(comm->ctran_->gpe->numInUseKernelFlags(), 0);

    // Destroy releases the algorithm state only; the request object is
    // deleted by the test.
    ASSERT_EQ(ctran::allGatherWinDestroy(request), commSuccess);
    delete request;

    CUDACHECK_TEST(cudaFree(sendbuf));
    ASSERT_EQ(ctranWinFree(win), commSuccess);
    CUDACHECK_TEST(cudaFree(winBase));
    CUDACHECK_TEST(cudaStreamDestroy(stream));
  }
};

// Parameterized variant of CtranWinAllGatherTest. Each instance receives a
// (send element count, algorithm name) tuple.
class CtranWinAllGatherTestParam
: public CtranWinAllGatherTest,
public ::testing::WithParamInterface<std::tuple<size_t, std::string>> {};

// Runs the window allgather round trip for one (count, algorithm) parameter.
TEST_P(CtranWinAllGatherTestParam, Basic) {
  const auto& params = GetParam();
  run(std::get<0>(params), std::get<1>(params));
}

// Instantiates the suite for every combination of element count and
// algorithm; test names have the form "count_<N>_<algo>".
INSTANTIATE_TEST_SUITE_P(
    WinAllGather,
    CtranWinAllGatherTestParam,
    ::testing::Combine(
        ::testing::Values(1024, 8192, 65536),
        ::testing::Values("ctdirect", "ctpipeline")),
    [](const ::testing::TestParamInfo<CtranWinAllGatherTestParam::ParamType>&
           info) {
      const std::string countPart = std::to_string(std::get<0>(info.param));
      const std::string& algoPart = std::get<1>(info.param);
      return "count_" + countPart + "_" + algoPart;
    });

// Global test environment: runs the base ctran environment setup, then
// defaults NCCL_DEBUG to WARN unless it is already set.
class CtranWinAllGatherTestEnv : public ctran::CtranEnvironmentBase {
 public:
  void SetUp() override {
    CtranEnvironmentBase::SetUp();
    // Overwrite flag 0 keeps any NCCL_DEBUG value already in the environment.
    constexpr int kKeepExisting = 0;
    setenv("NCCL_DEBUG", "WARN", kKeepExisting);
  }
};

// Test entry point: initializes GoogleTest, registers the global environment
// and runs every test.
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  // GoogleTest takes ownership of the registered environment object.
  auto* globalEnv = new CtranWinAllGatherTestEnv();
  ::testing::AddGlobalTestEnvironment(globalEnv);
  return RUN_ALL_TESTS();
}
5 changes: 5 additions & 0 deletions comms/ctran/window/CtranWin.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ struct CtranWin {
bufType_ == DevMemType::kCumem;
}

// Check whether this window's communicator supports persistent collectives.
// Returns true if ctran is initialized and all peers have configured
// backends.
bool persistentColSupported() const;

private:
DevMemType bufType_{DevMemType::kCumem};
// whether allocate window data buffer or provided by users
Expand Down
16 changes: 16 additions & 0 deletions comms/ctran/window/window.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,22 @@ commResult_t CtranWin::exchange() {
return commSuccess;
}

// Reports whether persistent collectives can run on this window's
// communicator: ctran must be initialized and every peer rank (all ranks
// except this one) must have a backend other than UNSET.
bool CtranWin::persistentColSupported() const {
  if (!ctranInitialized(comm)) {
    return false;
  }

  auto* state = comm->statex_.get();
  auto* backendMapper = comm->ctran_->mapper.get();
  const auto selfRank = state->rank();
  const auto rankCount = state->nRanks();

  for (int peer = 0; peer < rankCount; ++peer) {
    if (peer == selfRank) {
      continue; // the local rank needs no backend
    }
    if (backendMapper->getBackend(peer) == CtranMapperBackend::UNSET) {
      return false;
    }
  }
  return true;
}

commResult_t CtranWin::allocate(void* userBufPtr) {
auto statex = comm->statex_.get();
const auto myRank = statex->rank();
Expand Down
Loading
Loading