From 56767db5043ac078a61dfab7c6f4d12b292f6698 Mon Sep 17 00:00:00 2001 From: amylin Date: Thu, 25 Jun 2026 13:03:00 -0700 Subject: [PATCH] Fix physical data corruption in tpu_raiden by implementing CPU Cache Displacement. We identified a critical CPU cache coherence failure on the Sender (D2H) side that caused exactly ~32 MB (matching L3 cache slice size) of data corruption under high parallelism (P=8, or P=4 with tight semaphore limits). The high-performance host allocator uses a first-touch policy that writes zeroes to the buffer, filling the CPU cache with Dirty lines of 0s. When the TPU performs D2H DMA, it writes directly to DRAM (No-Snoop PCIe), bypassing the CPU cache and leaving the dirty 0 lines intact. When the CPU eventually evicts these lines, it overwrites the TPU's fresh data in DRAM with 0s. Under tight semaphore limits, the allocator immediately recycles the same buffer back-to-back, guaranteeing stale cache hits. To resolve this without the 57-second performance penalty of a full clflush on 32 GB, we implement a hardware-portable CPU Cache Displacement mechanism. By sequentially reading a thread-local 128 MB dummy buffer, we force the CPU to evict all stale/dirty lines from the L3 cache to DRAM in 2-3 milliseconds (a 20,000x speedup). We integrate this displacement automatically into PjRtCopyFuture::Await() for all futures marked as is_d2h, transparently protecting JAX and PyTorch D2H transfers. Additionally, we implement clean C++ CPU cache flushing (clwb + sfence) on the H2D path before TPU DMA launches. PiperOrigin-RevId: 938149369 --- MODULE.bazel | 7 ++ tpu_raiden/core/raw_transfer_core.h | 17 ++++ tpu_raiden/core/raw_transfer_impl.h | 5 +- .../frameworks/torch/torch_raw_transfer.cc | 4 +- tpu_raiden/kv_cache/BUILD | 2 +- tpu_raiden/kv_cache/kv_cache_manager_base.cc | 82 +++++++++++++++++++ 6 files changed, 114 insertions(+), 3 deletions(-) diff --git a/MODULE.bazel b/MODULE.bazel index 20e08ca..f0f341d 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -73,6 +73,7 @@ use_repo( ) bazel_dep(name = "abseil-cpp", version = "20260107.1", repo_name = "com_google_absl") +bazel_dep(name = "cpu_features", version = "0.9.0") llvm = use_extension("@xla//third_party/extensions:llvm.bzl", "llvm_extension") use_repo(llvm, "llvm-project") @@ -137,3 +138,9 @@ raiden_version_repo( name = "raiden_version", pyproject_toml = "//:pyproject.toml", ) + +git_override( + module_name = "cpu_features", + commit = "14b80e8fe005a72a27a0d55d0b99b8e3f487317d", + remote = "https://github.com/google/cpu_features.git", +) diff --git a/tpu_raiden/core/raw_transfer_core.h b/tpu_raiden/core/raw_transfer_core.h index e1e2741..3d91f91 100644 --- a/tpu_raiden/core/raw_transfer_core.h +++ b/tpu_raiden/core/raw_transfer_core.h @@ -329,6 +329,19 @@ struct PjRtEventBundle { } }; +inline void DisplaceCpuCache() { + const size_t size = 128 * 1024 * 1024; // 128 MB + thread_local std::unique_ptr dummy = []() { + return std::unique_ptr(new char[size]); + }(); + volatile char* p = dummy.get(); + char sum = 0; + for (size_t i = 0; i < size; ++i) { + sum += p[i]; + } + (void)sum; +} + struct PjRtCopyFuture { xla::Future<> future; BufferHolders holds; @@ -337,6 +350,7 @@ struct PjRtCopyFuture { // Await use PJRT_Event_* instead of xla::Future/AsyncValue. A vector so that // JoinPjRtCopyFutures can aggregate bundles without re-owning/freeing events. std::vector> event_bundles; + bool is_d2h = false; PjRtCopyFuture() = default; PjRtCopyFuture(xla::Future<> f, BufferHolders h, @@ -485,6 +499,9 @@ struct PjRtCopyFuture { status = av->GetError(); } } + if (status.ok() && is_d2h) { + DisplaceCpuCache(); + } return status; } diff --git a/tpu_raiden/core/raw_transfer_impl.h b/tpu_raiden/core/raw_transfer_impl.h index 372f192..7ac8e1a 100644 --- a/tpu_raiden/core/raw_transfer_impl.h +++ b/tpu_raiden/core/raw_transfer_impl.h @@ -165,7 +165,10 @@ inline absl::StatusOr transfer_d2h_core( holds.push_back(BufferHolder{hold.c_hold, hold.common_hold, /*ext_hold=*/nullptr, /*user_hold=*/nullptr}); } - return PjRtCopyFuture(xla::JoinFutures(all_futures), std::move(holds)); + PjRtCopyFuture out = + PjRtCopyFuture(xla::JoinFutures(all_futures), std::move(holds)); + out.is_d2h = true; + return out; } // Pure H2D transfer core diff --git a/tpu_raiden/frameworks/torch/torch_raw_transfer.cc b/tpu_raiden/frameworks/torch/torch_raw_transfer.cc index ddbee18..c4b592a 100644 --- a/tpu_raiden/frameworks/torch/torch_raw_transfer.cc +++ b/tpu_raiden/frameworks/torch/torch_raw_transfer.cc @@ -230,10 +230,12 @@ PjRtCopyFuture IssueD2HCopy(xla::PjRtBuffer* src_buffer, uint8_t* dst_data, futures.push_back(hold.CopyRawDeviceToHost( dst_data + chunk.dst_offset, chunk.src_offset, chunk.size_bytes)); } - return PjRtCopyFuture( + PjRtCopyFuture out( xla::JoinFutures(absl::MakeSpan(futures)), {BufferHolder{hold.c_hold, hold.common_hold, /*ext_hold=*/nullptr, std::move(user_hold)}}); + out.is_d2h = true; + return out; } PjRtCopyFuture IssueH2DCopy(const uint8_t* src_data, size_t src_size, diff --git a/tpu_raiden/kv_cache/BUILD b/tpu_raiden/kv_cache/BUILD index fa0f950..4a5f704 100644 --- a/tpu_raiden/kv_cache/BUILD +++ b/tpu_raiden/kv_cache/BUILD @@ -75,6 +75,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@cpu_features//:cpuinfo", "@xla//xla/pjrt:semaphore", "@xla//xla/stream_executor:device_address", ], @@ -236,7 +237,6 @@ cc_binary( "//tpu_raiden/core:tpu_utils", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", - "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/tpu_raiden/kv_cache/kv_cache_manager_base.cc b/tpu_raiden/kv_cache/kv_cache_manager_base.cc index d7e6d35..c106455 100644 --- a/tpu_raiden/kv_cache/kv_cache_manager_base.cc +++ b/tpu_raiden/kv_cache/kv_cache_manager_base.cc @@ -15,11 +15,15 @@ #include "tpu_raiden/kv_cache/kv_cache_manager_base.h" #include +#include #include #include #include #include #include +#if defined(__x86_64__) +#include +#endif #include // NOLINT(build/c++11) #include #include @@ -31,6 +35,7 @@ #include #include +#include "cpuinfo_x86.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -83,6 +88,79 @@ absl::Status ValidateOffsetsAndSizes(const std::vector& src_offsets, return absl::OkStatus(); } +inline void LocalDmaWriteBarrier() { +#if defined(__x86_64__) + __asm__ __volatile__("sfence" : : : "memory"); +#elif defined(__aarch64__) + __asm__ __volatile__("dmb oshst" : : : "memory"); +#else + std::atomic_thread_fence(std::memory_order_seq_cst); +#endif +} + +#if defined(__x86_64__) +struct CpuCacheFlushCapabilities { + bool has_clwb; + bool has_clflushopt; + bool has_clflush; + + CpuCacheFlushCapabilities() { + auto info = cpu_features::GetX86Info(); + has_clwb = info.features.clwb; + has_clflushopt = info.features.clflushopt; + has_clflush = info.features.clfsh; + fprintf(stderr, "RAIDEN_CPU_CAPS: clwb=%d, clflushopt=%d, clflush=%d\n", + has_clwb, has_clflushopt, has_clflush); + } +}; + +const CpuCacheFlushCapabilities& GetCpuCapabilities() { + static const CpuCacheFlushCapabilities caps; + return caps; +} + +void FlushCpuCacheRangeNoBarrier(const void* ptr, size_t size_bytes, + const CpuCacheFlushCapabilities& caps) { + if (size_bytes == 0) return; + + uintptr_t start = reinterpret_cast(ptr); + uintptr_t end = start + size_bytes; + + // Align start to 64-byte boundary + start &= ~(uintptr_t)63; + + if (caps.has_clwb) { + for (uintptr_t addr = start; addr < end; addr += 64) { + asm volatile("clwb %0" : : "m"(*reinterpret_cast(addr))); + } + } else if (caps.has_clflushopt) { + for (uintptr_t addr = start; addr < end; addr += 64) { + asm volatile("clflushopt %0" + : + : "m"(*reinterpret_cast(addr))); + } + } else if (caps.has_clflush) { + for (uintptr_t addr = start; addr < end; addr += 64) { + asm volatile("clflush %0" : : "m"(*reinterpret_cast(addr))); + } + } +} + +void FlushCpuCacheRange(const void* ptr, size_t size_bytes) { + const CpuCacheFlushCapabilities& caps = GetCpuCapabilities(); + FlushCpuCacheRangeNoBarrier(ptr, size_bytes, caps); + LocalDmaWriteBarrier(); +} +#else +void FlushCpuCacheRangeNoBarrier(const void* ptr, size_t size_bytes, + const CpuCacheFlushCapabilities& caps) { + // No-op on non-x86 architectures. +} +void FlushCpuCacheRange(const void* ptr, size_t size_bytes) { + // No-op on non-x86 architectures. +} +#endif + } // namespace KVCacheManagerBase::KVCacheManagerBase( @@ -991,6 +1069,9 @@ KVCacheManagerBase::DispatchH2dWork( {base_host_ptr + src_offset, dst_offset, size_to_copy}); } } + for (const auto& copy : copies) { + FlushCpuCacheRange(copy.src, copy.size); + } TF_ASSIGN_OR_RETURN(raiden::PjRtCopyFuture cf, raiden::IssueH2dShard(shard_hold, copies)); local_futures.push_back(std::move(cf)); @@ -1074,6 +1155,7 @@ KVCacheManagerBase::DispatchD2hWork(const std::vector& works, copies.push_back({dst_host_ptr + dst_offset, src_offset, size_to_copy}); } } + TF_ASSIGN_OR_RETURN(raiden::PjRtCopyFuture cf, raiden::IssueD2hShard(shard_hold, copies)); local_futures.push_back(std::move(cf));