Skip to content

Commit b6e5f62

Browse files
authored
[feat] Add pcstore for enhanced PrefixCache performance (#393)
# Purpose

Add `pcstore` for enhanced PrefixCache performance.

# Modifications

- Data is paged out to Host and asynchronously written to SSD, freeing HBM space earlier.
- Reads and writes are aggregated at Block-level granularity to increase SSD I/O size and improve performance.
- For MLA models, data is loaded once from SSD and shared across Devices in DRAM, reducing SSD bandwidth pressure.

# Test

- `ucm/store/test/e2e/pcstore_embed.py`
- `ucm/store/test/e2e/pcstore_fetch.py`
1 parent c87d6ef commit b6e5f62

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+2253
-102
lines changed

ucm/shared/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
21
add_subdirectory(vendor)
32
add_subdirectory(trans)
43
add_subdirectory(test)

ucm/shared/trans/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ endif()
77
if(RUNTIME_ENVIRONMENT STREQUAL "simu")
88
add_subdirectory(simu)
99
endif()
10+
target_include_directories(trans PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/..)
1011

1112
file(GLOB_RECURSE UCMTRANS_CPY_SOURCE_FILES "./cpy/*.cc")
1213
pybind11_add_module(ucmtrans ${UCMTRANS_CPY_SOURCE_FILES})

ucm/shared/trans/ascend/ascend_buffer.cc

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -42,35 +42,15 @@ std::shared_ptr<void> Trans::AscendBuffer::MakeHostBuffer(size_t size)
4242
return nullptr;
4343
}
4444

45-
Status Trans::AscendBuffer::RegisterHostBuffer(void* ptr, size_t size)
45+
Status Buffer::RegisterHostBuffer(void* host, size_t size, void** pDevice)
4646
{
47-
if (registerOnHost_) {
48-
aclrtHostUnregister(registerOnHost_);
49-
registerOnHost_ = nullptr;
50-
registerOnDevice_ = nullptr;
51-
}
52-
auto ret = aclrtHostRegister(ptr, size, ACL_HOST_REGISTER_MAPPED, &registerOnDevice_);
53-
if (ret == ACL_SUCCESS) {
54-
registerOnHost_ = ptr;
55-
return Status::OK();
56-
}
57-
registerOnDevice_ = nullptr;
58-
return Status{ret, std::to_string(ret)};
59-
}
60-
61-
void Trans::AscendBuffer::UnregisterHostBuffer(void* ptr)
62-
{
63-
aclrtHostUnregister(ptr);
64-
if (registerOnHost_ == ptr) {
65-
registerOnHost_ = nullptr;
66-
registerOnDevice_ = nullptr;
67-
}
47+
void* device = nullptr;
48+
auto ret = aclrtHostRegister(host, size, ACL_HOST_REGISTER_MAPPED, &device);
49+
if (ret != ACL_SUCCESS) [[unlikely]] { return Status{ret, std::to_string(ret)}; }
50+
if (pDevice) { *pDevice = device; }
51+
return Status::OK();
6852
}
6953

70-
void* Trans::AscendBuffer::GetHostPtrOnDevice(void* ptr)
71-
{
72-
if (registerOnHost_ == ptr) { return registerOnDevice_; }
73-
return nullptr;
74-
}
54+
void Buffer::UnregisterHostBuffer(void* host) { aclrtHostUnregister(host); }
7555

7656
} // namespace UC::Trans

ucm/shared/trans/ascend/ascend_buffer.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,9 @@
2929
namespace UC::Trans {
3030

3131
class AscendBuffer : public ReservedBuffer {
32-
void* registerOnHost_{nullptr};
33-
void* registerOnDevice_{nullptr};
34-
3532
public:
3633
std::shared_ptr<void> MakeDeviceBuffer(size_t size) override;
3734
std::shared_ptr<void> MakeHostBuffer(size_t size) override;
38-
39-
Status RegisterHostBuffer(void* ptr, size_t size) override;
40-
void UnregisterHostBuffer(void* ptr) override;
41-
void* GetHostPtrOnDevice(void* ptr) override;
4235
};
4336

4437
} // namespace UC::Trans

ucm/shared/trans/buffer.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,8 @@ class Buffer {
4141
virtual Status MakeHostBuffers(size_t size, size_t number) = 0;
4242
virtual std::shared_ptr<void> GetHostBuffer(size_t size) = 0;
4343

44-
virtual Status RegisterHostBuffer(void* ptr, size_t size) = 0;
45-
virtual void UnregisterHostBuffer(void* ptr) = 0;
46-
virtual void* GetHostPtrOnDevice(void* ptr) = 0;
44+
static Status RegisterHostBuffer(void* host, size_t size, void** pDevice = nullptr);
45+
static void UnregisterHostBuffer(void* host);
4746
};
4847

4948
} // namespace UC::Trans

ucm/shared/trans/cuda/cuda_buffer.cc

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,17 @@ std::shared_ptr<void> CudaBuffer::MakeHostBuffer(size_t size)
4242
return nullptr;
4343
}
4444

45-
Status CudaBuffer::RegisterHostBuffer(void* ptr, size_t size)
45+
Status Buffer::RegisterHostBuffer(void* host, size_t size, void** pDevice)
4646
{
47-
auto ret = cudaHostRegister(ptr, size, cudaHostRegisterDefault);
48-
if (ret == cudaSuccess) { return Status::OK(); }
49-
return Status{ret, cudaGetErrorString(ret)};
47+
auto ret = cudaHostRegister(host, size, cudaHostRegisterDefault);
48+
if (ret != cudaSuccess) [[unlikely]] { return Status{ret, cudaGetErrorString(ret)}; }
49+
if (pDevice) {
50+
ret = cudaHostGetDevicePointer(pDevice, host, 0);
51+
if (ret != cudaSuccess) [[unlikely]] { return Status{ret, cudaGetErrorString(ret)}; }
52+
}
53+
return Status::OK();
5054
}
5155

52-
void CudaBuffer::UnregisterHostBuffer(void* ptr) { cudaHostUnregister(ptr); }
53-
54-
void* CudaBuffer::GetHostPtrOnDevice(void* ptr)
55-
{
56-
void* device = nullptr;
57-
auto ret = cudaHostGetDevicePointer(&device, ptr, 0);
58-
if (ret == cudaSuccess) { return nullptr; }
59-
return device;
60-
}
56+
void Buffer::UnregisterHostBuffer(void* host) { cudaHostUnregister(host); }
6157

6258
} // namespace UC::Trans

ucm/shared/trans/cuda/cuda_buffer.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,6 @@ class CudaBuffer : public ReservedBuffer {
3232
public:
3333
std::shared_ptr<void> MakeDeviceBuffer(size_t size) override;
3434
std::shared_ptr<void> MakeHostBuffer(size_t size) override;
35-
36-
Status RegisterHostBuffer(void* ptr, size_t size) override;
37-
void UnregisterHostBuffer(void* ptr) override;
38-
void* GetHostPtrOnDevice(void* ptr) override;
3935
};
4036

4137
} // namespace UC::Trans

ucm/shared/trans/simu/simu_buffer.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,12 @@ std::shared_ptr<void> SimuBuffer::MakeHostBuffer(size_t size)
6767
return std::shared_ptr<void>(device, FreeMemory);
6868
}
6969

70-
Status SimuBuffer::RegisterHostBuffer(void* ptr, size_t size) { return Status::OK(); }
71-
72-
void SimuBuffer::UnregisterHostBuffer(void* ptr) {}
70+
Status Buffer::RegisterHostBuffer(void* host, size_t size, void** pDevice)
71+
{
72+
if (pDevice) { *pDevice = host; }
73+
return Status::OK();
74+
}
7375

74-
void* SimuBuffer::GetHostPtrOnDevice(void* ptr) { return ptr; }
76+
void Buffer::UnregisterHostBuffer(void* host) {}
7577

7678
} // namespace UC::Trans

ucm/shared/trans/simu/simu_buffer.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,6 @@ class SimuBuffer : public ReservedBuffer {
3232
public:
3333
std::shared_ptr<void> MakeDeviceBuffer(size_t size) override;
3434
std::shared_ptr<void> MakeHostBuffer(size_t size) override;
35-
36-
Status RegisterHostBuffer(void* ptr, size_t size) override;
37-
void UnregisterHostBuffer(void* ptr) override;
38-
void* GetHostPtrOnDevice(void* ptr) override;
3935
};
4036

4137
} // namespace UC::Trans

ucm/sparse/gsa/prefetch/include/kvcache_pre.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ namespace ucmprefetch
3737
int bsIndex;
3838
} PrefetchReqInfo;
3939

40-
class ThreadPool
40+
class ThreadPool
4141
{
4242
public:
43-
static ThreadPool *GetInst()
43+
static ThreadPool *GetInst()
4444
{
4545
static ThreadPool pool(1);
4646
return &pool;
@@ -52,7 +52,7 @@ namespace ucmprefetch
5252
auto enqueue(F&& f, Args&&... args) -> std::future<typename std::result_of<F(Args...)>::type>;
5353

5454
size_t GetActiveThreads() const;
55-
55+
5656
private:
5757
explicit ThreadPool(size_t threadCount);
5858
std::vector<std::thread> workers;
@@ -66,7 +66,7 @@ namespace ucmprefetch
6666

6767
void MutliBSThreadFun(void *args);
6868

69-
class __attribute__((visibility("hidden"))) GSAPrefetchEngineC
69+
class __attribute__((visibility("hidden"))) GSAPrefetchEngineC
7070
{
7171
private:
7272
std::map<std::string, std::vector<std::map<int, int>>> mDocsTables;
@@ -95,7 +95,7 @@ namespace ucmprefetch
9595
std::map<std::string, std::vector<std::vector<int>>> allNeedLoadBlock;
9696
std::map<std::string, std::vector<std::vector<int>>> allMissIdxs;
9797
std::map<std::string, int> mPromptLen;
98-
UC::CCStore *mStore = nullptr;
98+
UC::CCStore<> *mStore = nullptr;
9999
std::vector<torch::Tensor> mKvCaches;
100100
uint32_t mBlockSize = 128;
101101
uint32_t mTensorElemSize = 2; // fp16
@@ -108,21 +108,21 @@ namespace ucmprefetch
108108
public:
109109
std::mutex mMutex;
110110
bool mStopPrefetch = false;
111-
111+
112112
private:
113113
void LoadKVToHBM(std::vector<int> loadNPUBlockIDs,
114114
std::vector<int> missIdxs, int layerID, std::string reqID);
115-
115+
116116
void GetHitAndMissBlock(PrefetchReqInfo oneBsInfo,
117117
std::unordered_set<int> &hitBlocks,
118118
std::map<int, int> &hitBlocksIdx,
119119
std::vector<int> &missIdxs);
120-
120+
121121
void RunPrefetchH2D(PrefetchReqInfo oneBsInfo,
122122
std::unordered_set<int> &hitBlocks,
123123
std::map<int, int> &hitBlocksIdx,
124124
std::vector<int> &missIdxs);
125-
125+
126126
void RunOneBsPrefetch(std::string reqID, int topkLen,
127127
int bsIndex, int topkIndex);
128128

@@ -144,7 +144,7 @@ namespace ucmprefetch
144144
void SetBlocksMap(std::string reqID, std::vector<int> &blockTableList,
145145
std::vector<int> &selectIndex, std::vector<std::string> &blocksHash,
146146
int maxIdx);
147-
147+
148148
void SetBlocksMapMultiLayer(std::string reqID,
149149
std::vector<std::map<int, int>> &remainMap,
150150
std::vector<std::map<int, int>> &prefetchMap,
@@ -168,7 +168,7 @@ namespace ucmprefetch
168168
std::vector<int> &bsIndexInput,
169169
std::vector<torch::Tensor> &kvCaches,
170170
void *storePtr);
171-
171+
172172
int CallPrefetchProcessFun();
173173

174174
void PrintMap(std::string reqID, int i);
@@ -178,7 +178,7 @@ namespace ucmprefetch
178178
void SetPrefetchStatus(bool flag);
179179

180180
void SetModelRunningStatus(bool flag);
181-
181+
182182
size_t GetOffset(uint32_t layerID, bool isV);
183183

184184
std::map<std::string, std::vector<std::vector<int>>> ObtainLoadBlocks();
@@ -189,7 +189,7 @@ namespace ucmprefetch
189189

190190
std::map<std::string, std::vector<std::map<int, int>>> ObtainDocsMap();
191191
};
192-
192+
193193
} // namespace ucmprefetch
194194

195195
#endif

0 commit comments

Comments
 (0)