// Streams through a large thread-local scratch buffer, reading one byte per
// 64-byte stride. PjRtCopyFuture::Await calls this after a successful
// device-to-host copy; presumably the goal is to push previously cached host
// data out of the CPU caches so later reads observe what the device wrote
// (NOTE(review): confirm intent with the author).
//
// The scratch buffer is allocated once per calling thread and reused on every
// later call from that thread.
inline void DisplaceCpuCache() {
  constexpr size_t kScratchBytes = size_t{128} * 1024 * 1024;  // 128 MiB
  // One read per stride is enough to pull a whole line in when the cache line
  // is at least this large; reading every byte only repeats hits.
  constexpr size_t kLineStrideBytes = 64;

  // make_unique value-initialises (zero-fills) the array. That matters twice:
  // reading a never-written `new char[]` buffer yields indeterminate values,
  // and on typical demand-paged systems untouched pages can all be backed by
  // a single shared zero page, in which case the walk would keep hitting the
  // same few physical lines and displace nothing.
  thread_local const std::unique_ptr<unsigned char[]> scratch =
      std::make_unique<unsigned char[]>(kScratchBytes);

  // volatile forces the compiler to emit every load even though the values
  // are discarded.
  const volatile unsigned char* cursor = scratch.get();
  for (size_t offset = 0; offset < kScratchBytes; offset += kLineStrideBytes) {
    const unsigned char observed = cursor[offset];
    static_cast<void>(observed);
  }
}
std::vector> event_bundles; + bool is_d2h = false; PjRtCopyFuture() = default; PjRtCopyFuture(xla::Future<> f, BufferHolders h, @@ -485,6 +499,9 @@ struct PjRtCopyFuture { status = av->GetError(); } } + if (status.ok() && is_d2h) { + DisplaceCpuCache(); + } return status; } diff --git a/tpu_raiden/core/raw_transfer_impl.h b/tpu_raiden/core/raw_transfer_impl.h index 372f192..7ac8e1a 100644 --- a/tpu_raiden/core/raw_transfer_impl.h +++ b/tpu_raiden/core/raw_transfer_impl.h @@ -165,7 +165,10 @@ inline absl::StatusOr transfer_d2h_core( holds.push_back(BufferHolder{hold.c_hold, hold.common_hold, /*ext_hold=*/nullptr, /*user_hold=*/nullptr}); } - return PjRtCopyFuture(xla::JoinFutures(all_futures), std::move(holds)); + PjRtCopyFuture out = + PjRtCopyFuture(xla::JoinFutures(all_futures), std::move(holds)); + out.is_d2h = true; + return out; } // Pure H2D transfer core diff --git a/tpu_raiden/frameworks/torch/torch_raw_transfer.cc b/tpu_raiden/frameworks/torch/torch_raw_transfer.cc index ddbee18..c4b592a 100644 --- a/tpu_raiden/frameworks/torch/torch_raw_transfer.cc +++ b/tpu_raiden/frameworks/torch/torch_raw_transfer.cc @@ -230,10 +230,12 @@ PjRtCopyFuture IssueD2HCopy(xla::PjRtBuffer* src_buffer, uint8_t* dst_data, futures.push_back(hold.CopyRawDeviceToHost( dst_data + chunk.dst_offset, chunk.src_offset, chunk.size_bytes)); } - return PjRtCopyFuture( + PjRtCopyFuture out( xla::JoinFutures(absl::MakeSpan(futures)), {BufferHolder{hold.c_hold, hold.common_hold, /*ext_hold=*/nullptr, std::move(user_hold)}}); + out.is_d2h = true; + return out; } PjRtCopyFuture IssueH2DCopy(const uint8_t* src_data, size_t src_size, diff --git a/tpu_raiden/kv_cache/BUILD b/tpu_raiden/kv_cache/BUILD index fa0f950..4a5f704 100644 --- a/tpu_raiden/kv_cache/BUILD +++ b/tpu_raiden/kv_cache/BUILD @@ -75,6 +75,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@cpu_features//:cpuinfo", "@xla//xla/pjrt:semaphore", 
"@xla//xla/stream_executor:device_address", ], @@ -236,7 +237,6 @@ cc_binary( "//tpu_raiden/core:tpu_utils", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", - "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/tpu_raiden/kv_cache/kv_cache_manager_base.cc b/tpu_raiden/kv_cache/kv_cache_manager_base.cc index d7e6d35..c106455 100644 --- a/tpu_raiden/kv_cache/kv_cache_manager_base.cc +++ b/tpu_raiden/kv_cache/kv_cache_manager_base.cc @@ -15,11 +15,15 @@ #include "tpu_raiden/kv_cache/kv_cache_manager_base.h" #include +#include #include #include #include #include #include +#if defined(__x86_64__) +#include +#endif #include // NOLINT(build/c++11) #include #include @@ -31,6 +35,7 @@ #include #include +#include "cpuinfo_x86.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -83,6 +88,79 @@ absl::Status ValidateOffsetsAndSizes(const std::vector& src_offsets, return absl::OkStatus(); } +inline void LocalDmaWriteBarrier() { +#if defined(__x86_64__) + __asm__ __volatile__("sfence" : : : "memory"); +#elif defined(__aarch64__) + __asm__ __volatile__("dmb oshst" : : : "memory"); +#else + std::atomic_thread_fence(std::memory_order_seq_cst); +#endif +} + +#if defined(__x86_64__) +struct CpuCacheFlushCapabilities { + bool has_clwb; + bool has_clflushopt; + bool has_clflush; + + CpuCacheFlushCapabilities() { + auto info = cpu_features::GetX86Info(); + has_clwb = info.features.clwb; + has_clflushopt = info.features.clflushopt; + has_clflush = info.features.clfsh; + fprintf(stderr, "RAIDEN_CPU_CAPS: clwb=%d, clflushopt=%d, clflush=%d\n", + has_clwb, has_clflushopt, has_clflush); + } +}; + +const CpuCacheFlushCapabilities& GetCpuCapabilities() { + static const CpuCacheFlushCapabilities caps; + return caps; +} + +void FlushCpuCacheRangeNoBarrier(const void* ptr, size_t size_bytes, + const CpuCacheFlushCapabilities& 
caps) { + if (size_bytes == 0) return; + + uintptr_t start = reinterpret_cast(ptr); + uintptr_t end = start + size_bytes; + + // Align start to 64-byte boundary + start &= ~(uintptr_t)63; + + if (caps.has_clwb) { + for (uintptr_t addr = start; addr < end; addr += 64) { + asm volatile("clwb %0" : : "m"(*reinterpret_cast(addr))); + } + } else if (caps.has_clflushopt) { + for (uintptr_t addr = start; addr < end; addr += 64) { + asm volatile("clflushopt %0" + : + : "m"(*reinterpret_cast(addr))); + } + } else if (caps.has_clflush) { + for (uintptr_t addr = start; addr < end; addr += 64) { + asm volatile("clflush %0" : : "m"(*reinterpret_cast(addr))); + } + } +} + +void FlushCpuCacheRange(const void* ptr, size_t size_bytes) { + const CpuCacheFlushCapabilities& caps = GetCpuCapabilities(); + FlushCpuCacheRangeNoBarrier(ptr, size_bytes, caps); + LocalDmaWriteBarrier(); +} +#else +void FlushCpuCacheRangeNoBarrier(const void* ptr, size_t size_bytes, + const CpuCacheFlushCapabilities& caps) { + // No-op on non-x86 architectures. +} +void FlushCpuCacheRange(const void* ptr, size_t size_bytes) { + // No-op on non-x86 architectures. +} +#endif + } // namespace KVCacheManagerBase::KVCacheManagerBase( @@ -991,6 +1069,9 @@ KVCacheManagerBase::DispatchH2dWork( {base_host_ptr + src_offset, dst_offset, size_to_copy}); } } + for (const auto& copy : copies) { + FlushCpuCacheRange(copy.src, copy.size); + } TF_ASSIGN_OR_RETURN(raiden::PjRtCopyFuture cf, raiden::IssueH2dShard(shard_hold, copies)); local_futures.push_back(std::move(cf)); @@ -1074,6 +1155,7 @@ KVCacheManagerBase::DispatchD2hWork(const std::vector& works, copies.push_back({dst_host_ptr + dst_offset, src_offset, size_to_copy}); } } + TF_ASSIGN_OR_RETURN(raiden::PjRtCopyFuture cf, raiden::IssueD2hShard(shard_hold, copies)); local_futures.push_back(std::move(cf));