Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ use_repo(
)

bazel_dep(name = "abseil-cpp", version = "20260107.1", repo_name = "com_google_absl")
bazel_dep(name = "cpu_features", version = "0.9.0")

llvm = use_extension("@xla//third_party/extensions:llvm.bzl", "llvm_extension")
use_repo(llvm, "llvm-project")
Expand Down Expand Up @@ -137,3 +138,9 @@ raiden_version_repo(
name = "raiden_version",
pyproject_toml = "//:pyproject.toml",
)

# Fetch the cpu_features module from a pinned upstream git commit instead of
# the release that its bazel_dep declaration (version 0.9.0) would otherwise
# resolve to.
# NOTE(review): the reason this particular commit is required is not recorded
# here; confirm before bumping or removing the override.
git_override(
module_name = "cpu_features",
commit = "14b80e8fe005a72a27a0d55d0b99b8e3f487317d",
remote = "https://github.com/google/cpu_features.git",
)
17 changes: 17 additions & 0 deletions tpu_raiden/core/raw_transfer_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,19 @@ struct PjRtEventBundle {
}
};

// Best-effort displacement of previously cached data from the calling
// thread's CPU caches: streams reads through a private 128 MB scratch buffer
// so the lines it pulls in push out whatever was cached before.
//
// The scratch buffer is allocated on a thread's first call and is kept until
// that thread exits, so every thread that calls this holds 128 MB.
inline void DisplaceCpuCache() {
  constexpr size_t kScratchBytes = 128 * 1024 * 1024;  // 128 MB
  // One read per stride is enough to bring a whole cache line in. 64 bytes is
  // assumed to be no larger than the line size of any supported CPU, so every
  // line of the buffer is still touched.
  constexpr size_t kStrideBytes = 64;

  // make_unique value-initializes (zero-fills) the array. This matters twice:
  //  * reading bytes that were never written is undefined behavior for plain
  //    `char`, and
  //  * on systems that lazily back untouched anonymous memory with a shared
  //    zero page, reads of never-written pages would keep hitting the same
  //    physical page instead of walking 128 MB of distinct memory.
  thread_local const std::unique_ptr<char[]> scratch =
      std::make_unique<char[]>(kScratchBytes);

  // volatile keeps the compiler from eliding the otherwise unused reads.
  const volatile char* cursor = scratch.get();
  char sink = 0;
  for (size_t offset = 0; offset < kScratchBytes; offset += kStrideBytes) {
    sink += cursor[offset];
  }
  (void)sink;
}

struct PjRtCopyFuture {
xla::Future<> future;
BufferHolders holds;
Expand All @@ -337,6 +350,7 @@ struct PjRtCopyFuture {
// Await use PJRT_Event_* instead of xla::Future/AsyncValue. A vector so that
// JoinPjRtCopyFutures can aggregate bundles without re-owning/freeing events.
std::vector<std::shared_ptr<PjRtEventBundle>> event_bundles;
bool is_d2h = false;

PjRtCopyFuture() = default;
PjRtCopyFuture(xla::Future<> f, BufferHolders h,
Expand Down Expand Up @@ -485,6 +499,9 @@ struct PjRtCopyFuture {
status = av->GetError();
}
}
if (status.ok() && is_d2h) {
DisplaceCpuCache();
}
return status;
}

Expand Down
5 changes: 4 additions & 1 deletion tpu_raiden/core/raw_transfer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ inline absl::StatusOr<PjRtCopyFuture> transfer_d2h_core(
holds.push_back(BufferHolder{hold.c_hold, hold.common_hold,
/*ext_hold=*/nullptr, /*user_hold=*/nullptr});
}
return PjRtCopyFuture(xla::JoinFutures(all_futures), std::move(holds));
PjRtCopyFuture out =
PjRtCopyFuture(xla::JoinFutures(all_futures), std::move(holds));
out.is_d2h = true;
return out;
}

// Pure H2D transfer core
Expand Down
4 changes: 3 additions & 1 deletion tpu_raiden/frameworks/torch/torch_raw_transfer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,12 @@ PjRtCopyFuture IssueD2HCopy(xla::PjRtBuffer* src_buffer, uint8_t* dst_data,
futures.push_back(hold.CopyRawDeviceToHost(
dst_data + chunk.dst_offset, chunk.src_offset, chunk.size_bytes));
}
return PjRtCopyFuture(
PjRtCopyFuture out(
xla::JoinFutures(absl::MakeSpan(futures)),
{BufferHolder{hold.c_hold, hold.common_hold, /*ext_hold=*/nullptr,
std::move(user_hold)}});
out.is_d2h = true;
return out;
}

PjRtCopyFuture IssueH2DCopy(const uint8_t* src_data, size_t src_size,
Expand Down
2 changes: 1 addition & 1 deletion tpu_raiden/kv_cache/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@cpu_features//:cpuinfo",
"@xla//xla/pjrt:semaphore",
"@xla//xla/stream_executor:device_address",
],
Expand Down Expand Up @@ -236,7 +237,6 @@ cc_binary(
"//tpu_raiden/core:tpu_utils",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand Down
82 changes: 82 additions & 0 deletions tpu_raiden/kv_cache/kv_cache_manager_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@
#include "tpu_raiden/kv_cache/kv_cache_manager_base.h"

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#if defined(__x86_64__)
#include <immintrin.h>
#endif
#include <future> // NOLINT(build/c++11)
#include <map>
#include <memory>
Expand All @@ -31,6 +35,7 @@
#include <utility>
#include <vector>

#if defined(__x86_64__)
#include "cpuinfo_x86.h"
#endif
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -83,6 +88,79 @@ absl::Status ValidateOffsetsAndSizes(const std::vector<int64_t>& src_offsets,
return absl::OkStatus();
}

// Store barrier issued after flushing host memory.
//   * AArch64: `dmb oshst`.
//   * x86-64:  `sfence`.
//   * other:   a sequentially consistent thread fence as a portable fallback.
// The "memory" clobber on the asm variants additionally prevents the compiler
// from moving memory accesses across the barrier.
inline void LocalDmaWriteBarrier() {
#if defined(__aarch64__)
  __asm__ __volatile__("dmb oshst" : : : "memory");
#elif defined(__x86_64__)
  __asm__ __volatile__("sfence" : : : "memory");
#else
  std::atomic_thread_fence(std::memory_order_seq_cst);
#endif
}

#if defined(__x86_64__)
// Which x86 cache-line flush / write-back instructions the running CPU
// reports, as detected through cpu_features. Constructing an instance performs
// the detection and prints one summary line to stderr.
struct CpuCacheFlushCapabilities {
  bool has_clwb;
  bool has_clflushopt;
  bool has_clflush;

  CpuCacheFlushCapabilities() {
    // Query once and read all three feature bits from the same snapshot.
    const auto x86_info = cpu_features::GetX86Info();
    const auto& feature_bits = x86_info.features;

    has_clwb = feature_bits.clwb;
    has_clflushopt = feature_bits.clflushopt;
    has_clflush = feature_bits.clfsh;

    fprintf(stderr, "RAIDEN_CPU_CAPS: clwb=%d, clflushopt=%d, clflush=%d\n",
            has_clwb, has_clflushopt, has_clflush);
  }
};

// Returns the shared capability snapshot. The function-local static is
// constructed on the first call, so detection (and its stderr line) happens
// once; later calls return the same object.
const CpuCacheFlushCapabilities& GetCpuCapabilities() {
  static const CpuCacheFlushCapabilities detected_once{};
  return detected_once;
}

// Issues a cache-line flush/write-back instruction for every 64-byte line that
// overlaps [ptr, ptr + size_bytes). Picks the first instruction `caps` reports
// as available, in the order clwb, clflushopt, clflush; does nothing if none
// is available or if size_bytes is 0.
//
// No fence is issued here: callers that need the flushes ordered must follow
// up with a store barrier themselves (FlushCpuCacheRange does this).
void FlushCpuCacheRangeNoBarrier(const void* ptr, size_t size_bytes,
                                 const CpuCacheFlushCapabilities& caps) {
  if (size_bytes == 0) return;

  constexpr uintptr_t kCacheLineBytes = 64;
  const uintptr_t range_begin = reinterpret_cast<uintptr_t>(ptr);
  const uintptr_t range_end = range_begin + size_bytes;
  // Round the start down so a range beginning mid-line still covers that line.
  const uintptr_t first_line = range_begin & ~(kCacheLineBytes - 1);

  // Each asm names only one byte as its operand, but the instruction acts on
  // the whole cache line. The "memory" clobber tells the compiler so, keeping
  // it from sinking earlier stores to the rest of the line below the flush.
  if (caps.has_clwb) {
    for (uintptr_t addr = first_line; addr < range_end;
         addr += kCacheLineBytes) {
      asm volatile("clwb %0"
                   :
                   : "m"(*reinterpret_cast<const char*>(addr))
                   : "memory");
    }
  } else if (caps.has_clflushopt) {
    for (uintptr_t addr = first_line; addr < range_end;
         addr += kCacheLineBytes) {
      asm volatile("clflushopt %0"
                   :
                   : "m"(*reinterpret_cast<const char*>(addr))
                   : "memory");
    }
  } else if (caps.has_clflush) {
    for (uintptr_t addr = first_line; addr < range_end;
         addr += kCacheLineBytes) {
      asm volatile("clflush %0"
                   :
                   : "m"(*reinterpret_cast<const char*>(addr))
                   : "memory");
    }
  }
}

void FlushCpuCacheRange(const void* ptr, size_t size_bytes) {
const CpuCacheFlushCapabilities& caps = GetCpuCapabilities();
FlushCpuCacheRangeNoBarrier(ptr, size_bytes, caps);
LocalDmaWriteBarrier();
}
#else
// Non-x86 stand-in for the capability struct. The x86 definition lives in the
// other preprocessor branch, so without this the stub signature below would
// name an undeclared type and fail to compile. Every capability is reported
// as absent.
struct CpuCacheFlushCapabilities {
  bool has_clwb = false;
  bool has_clflushopt = false;
  bool has_clflush = false;
};

// No-op on non-x86 architectures. Parameters are intentionally unnamed to
// avoid unused-parameter warnings.
void FlushCpuCacheRangeNoBarrier(const void* /*ptr*/, size_t /*size_bytes*/,
                                 const CpuCacheFlushCapabilities& /*caps*/) {}

// No-op on non-x86 architectures.
void FlushCpuCacheRange(const void* /*ptr*/, size_t /*size_bytes*/) {}
#endif

} // namespace

KVCacheManagerBase::KVCacheManagerBase(
Expand Down Expand Up @@ -991,6 +1069,9 @@ KVCacheManagerBase::DispatchH2dWork(
{base_host_ptr + src_offset, dst_offset, size_to_copy});
}
}
for (const auto& copy : copies) {
FlushCpuCacheRange(copy.src, copy.size);
}
TF_ASSIGN_OR_RETURN(raiden::PjRtCopyFuture cf,
raiden::IssueH2dShard(shard_hold, copies));
local_futures.push_back(std::move(cf));
Expand Down Expand Up @@ -1074,6 +1155,7 @@ KVCacheManagerBase::DispatchD2hWork(const std::vector<CopyWork>& works,
copies.push_back({dst_host_ptr + dst_offset, src_offset, size_to_copy});
}
}

TF_ASSIGN_OR_RETURN(raiden::PjRtCopyFuture cf,
raiden::IssueD2hShard(shard_hold, copies));
local_futures.push_back(std::move(cf));
Expand Down