Skip to content

Commit 0350dae

Browse files
[SYCL] Move kernel id into device kernel info struct
1 parent 29dfd03 commit 0350dae

File tree

6 files changed: 82 additions and 88 deletions.

sycl/source/detail/device_kernel_info.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ namespace sycl {
1212
inline namespace _V1 {
1313
namespace detail {
1414

15-
DeviceKernelInfo::DeviceKernelInfo(const CompileTimeKernelInfoTy &Info)
16-
: CompileTimeKernelInfoTy(Info) {}
15+
// Constructs a DeviceKernelInfo by initializing the CompileTimeKernelInfoTy
// base from Info and storing the optional kernel ID in MKernelID.
// KernelID is taken by value and moved into the member, so the caller's
// argument is consumed at most once.
DeviceKernelInfo::DeviceKernelInfo(const CompileTimeKernelInfoTy &Info,
16+
std::optional<sycl::kernel_id> KernelID)
17+
: CompileTimeKernelInfoTy{Info}, MKernelID{std::move(KernelID)} {}
1718

1819
template <typename OtherTy>
1920
inline constexpr bool operator==(const CompileTimeKernelInfoTy &LHS,

sycl/source/detail/device_kernel_info.hpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <sycl/detail/compile_time_kernel_info.hpp>
1414
#include <sycl/detail/spinlock.hpp>
1515
#include <sycl/detail/ur.hpp>
16+
#include <sycl/kernel_bundle.hpp>
1617

1718
#include <mutex>
1819
#include <optional>
@@ -84,12 +85,10 @@ struct FastKernelSubcacheT {
8485
// information that is uniform between different submissions of the same
8586
// kernel). Pointers to instances of this class are stored in header function
8687
// templates as a static variable to avoid repeated runtime lookup overhead.
87-
// TODO Currently this class duplicates information fetched from the program
88-
// manager. Instead, we should merge all of this information
89-
// into this structure and get rid of the other KernelName -> * maps.
9088
class DeviceKernelInfo : public CompileTimeKernelInfoTy {
9189
public:
92-
DeviceKernelInfo(const CompileTimeKernelInfoTy &Info);
90+
DeviceKernelInfo(const CompileTimeKernelInfoTy &Info,
91+
std::optional<sycl::kernel_id> KernelID = std::nullopt);
9392

9493
void init(std::string_view KernelName);
9594
void setCompileTimeInfoIfNeeded(const CompileTimeKernelInfoTy &Info);
@@ -100,6 +99,11 @@ class DeviceKernelInfo : public CompileTimeKernelInfoTy {
10099
return MImplicitLocalArgPos;
101100
}
102101

102+
// Returns a reference to the kernel ID stored in MKernelID.
// Precondition: MKernelID holds a value. This is verified only by assert(),
// so when assertions are compiled out an empty optional is dereferenced
// without any check. The constructor's KernelID parameter defaults to
// std::nullopt, so instances created without an ID must not reach this
// accessor.
const sycl::kernel_id &getKernelID() const {
103+
assert(MKernelID);
104+
return *MKernelID;
105+
}
106+
103107
// Implicit local argument position is used only for some backends, so this
104108
// function allows setting it as more images are added.
105109
void setImplicitLocalArgPos(int Pos);
@@ -109,6 +113,7 @@ class DeviceKernelInfo : public CompileTimeKernelInfoTy {
109113

110114
FastKernelSubcacheT MFastKernelSubcache;
111115
std::optional<int> MImplicitLocalArgPos;
116+
const std::optional<sycl::kernel_id> MKernelID;
112117
};
113118

114119
} // namespace detail

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -689,8 +689,8 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
689689
"Cannot resolve external symbols, linking is unsupported "
690690
"for the backend");
691691

692-
// Access to m_ExportedSymbolImages must be guarded by m_KernelIDsMutex.
693-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
692+
// Access to m_ExportedSymbolImages must be guarded by m_ImgMapsMutex.
693+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
694694

695695
while (!WorkList.empty()) {
696696
std::string Symbol = WorkList.front();
@@ -770,8 +770,8 @@ ProgramManager::collectDependentDeviceImagesForVirtualFunctions(
770770
if (!WorkList.empty()) {
771771
// Guard read access to m_VFSet2BinImage:
772772
// TODO: a better solution should be sought in the future, i.e. a different
773-
// mutex than m_KernelIDsMutex, check lock check pattern, etc.
774-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
773+
// mutex than m_ImgMapsMutex, check lock check pattern, etc.
774+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
775775

776776
while (!WorkList.empty()) {
777777
std::string SetName = WorkList.front();
@@ -1333,11 +1333,12 @@ ProgramManager::getDeviceImage(std::string_view KernelName,
13331333

13341334
const RTDeviceBinaryImage *Img = nullptr;
13351335
{
1336-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
1337-
if (auto KernelId = m_KernelName2KernelIDs.find(KernelName);
1338-
KernelId != m_KernelName2KernelIDs.end()) {
1339-
Img = getBinImageFromMultiMap(m_KernelIDs2BinImage, KernelId->second,
1340-
ContextImpl, DeviceImpl);
1336+
std::lock_guard<std::mutex> Guard(m_DeviceKernelInfoMapMutex);
1337+
if (auto It = m_DeviceKernelInfoMap.find(KernelName);
1338+
It != m_DeviceKernelInfoMap.end()) {
1339+
Img = getBinImageFromMultiMap(m_KernelIDs2BinImage,
1340+
It->second.getKernelID(), ContextImpl,
1341+
DeviceImpl);
13411342
}
13421343
}
13431344

@@ -1369,7 +1370,7 @@ const RTDeviceBinaryImage &ProgramManager::getDeviceImage(
13691370
debugPrintBinaryImages();
13701371
}
13711372

1372-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
1373+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
13731374
std::vector<sycl_device_binary> RawImgs(ImageSet.size());
13741375
auto ImageIterator = ImageSet.begin();
13751376
for (size_t i = 0; i < ImageSet.size(); i++, ImageIterator++)
@@ -1642,7 +1643,7 @@ void ProgramManager::addImage(sycl_device_binary RawImg,
16421643
}
16431644

16441645
// Fill maps for kernel bundles
1645-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
1646+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
16461647

16471648
// For bfloat16 device library image, it doesn't include any kernel, device
16481649
// global, virtual function, so just skip adding it to any related maps.
@@ -1716,31 +1717,31 @@ void ProgramManager::addImage(sycl_device_binary RawImg,
17161717
m_BinImg2KernelIDs[Img.get()];
17171718
KernelIDs.reset(new std::vector<kernel_id>);
17181719

1720+
std::lock_guard<std::mutex> DKIGuard(m_DeviceKernelInfoMapMutex);
1721+
17191722
for (sycl_offload_entry EntriesIt = EntriesB; EntriesIt != EntriesE;
17201723
EntriesIt = EntriesIt->Increment()) {
17211724

17221725
auto name = EntriesIt->GetName();
17231726

1724-
// Skip creating unique kernel ID if it is an exported device
1727+
// Skip creating device kernel information if it is an exported device
17251728
// function. Exported device functions appear in the offload entries
17261729
// among kernels, but are identifiable by being listed in properties.
17271730
if (m_ExportedSymbolImages.find(name) != m_ExportedSymbolImages.end())
17281731
continue;
17291732

1730-
// ... and create a unique kernel ID for the entry
1731-
auto It = m_KernelName2KernelIDs.find(name);
1732-
if (It == m_KernelName2KernelIDs.end()) {
1733+
auto It = m_DeviceKernelInfoMap.find(std::string_view(name));
1734+
if (It == m_DeviceKernelInfoMap.end()) {
17331735
sycl::kernel_id KernelID = detail::createSyclObjFromImpl<sycl::kernel_id>(
17341736
std::make_shared<detail::kernel_id_impl>(name));
1735-
1736-
It = m_KernelName2KernelIDs.emplace_hint(It, name, KernelID);
1737+
CompileTimeKernelInfoTy DefaultCompileTimeInfo{std::string_view(name)};
1738+
It = m_DeviceKernelInfoMap.emplace_hint(
1739+
It, std::piecewise_construct, std::forward_as_tuple(name),
1740+
std::forward_as_tuple(DefaultCompileTimeInfo, KernelID));
17371741
}
1738-
m_KernelIDs2BinImage.insert(std::make_pair(It->second, Img.get()));
1739-
KernelIDs->push_back(It->second);
1740-
1741-
CompileTimeKernelInfoTy DefaultCompileTimeInfo{std::string_view(name)};
1742-
m_DeviceKernelInfoMap.try_emplace(std::string_view(name),
1743-
DefaultCompileTimeInfo);
1742+
m_KernelIDs2BinImage.insert(
1743+
std::make_pair(It->second.getKernelID(), Img.get()));
1744+
KernelIDs->push_back(It->second.getKernelID());
17441745

17451746
// Keep track of image to kernel name reference count for cleanup.
17461747
m_KernelNameRefCount[name]++;
@@ -1831,7 +1832,7 @@ void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
18311832
if (DeviceBinary->NumDeviceBinaries == 0)
18321833
return;
18331834
// Acquire lock to read and modify maps for kernel bundles
1834-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
1835+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
18351836

18361837
// Acquire lock to erase DeviceKernelInfoMap
18371838
std::lock_guard<std::mutex> Guard(m_DeviceKernelInfoMapMutex);
@@ -1919,9 +1920,10 @@ void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
19191920
continue;
19201921
}
19211922

1922-
auto Name2IDIt = m_KernelName2KernelIDs.find(Name);
1923-
if (Name2IDIt != m_KernelName2KernelIDs.end())
1924-
removeFromMultimapByVal(m_KernelIDs2BinImage, Name2IDIt->second, Img);
1923+
auto DKIIt = m_DeviceKernelInfoMap.find(Name);
1924+
assert(DKIIt != m_DeviceKernelInfoMap.end());
1925+
removeFromMultimapByVal(m_KernelIDs2BinImage, DKIIt->second.getKernelID(),
1926+
Img);
19251927

19261928
auto RefCountIt = m_KernelNameRefCount.find(Name);
19271929
assert(RefCountIt != m_KernelNameRefCount.end());
@@ -1933,10 +1935,8 @@ void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
19331935
if (--RefCount == 0) {
19341936
// TODO aggregate all these maps into a single one since their entries
19351937
// share lifetime.
1936-
m_DeviceKernelInfoMap.erase(Name);
1938+
m_DeviceKernelInfoMap.erase(DKIIt);
19371939
m_KernelNameRefCount.erase(RefCountIt);
1938-
if (Name2IDIt != m_KernelName2KernelIDs.end())
1939-
m_KernelName2KernelIDs.erase(Name2IDIt);
19401940
}
19411941
}
19421942

@@ -2045,7 +2045,7 @@ ProgramManager::getBinImageState(const RTDeviceBinaryImage *BinImage) {
20452045
}
20462046

20472047
bool ProgramManager::hasCompatibleImage(const device_impl &DeviceImpl) {
2048-
std::lock_guard<std::mutex> Guard(m_KernelIDsMutex);
2048+
std::lock_guard<std::mutex> Guard(m_ImgMapsMutex);
20492049

20502050
return std::any_of(
20512051
m_BinImg2KernelIDs.cbegin(), m_BinImg2KernelIDs.cend(),
@@ -2055,19 +2055,19 @@ bool ProgramManager::hasCompatibleImage(const device_impl &DeviceImpl) {
20552055
}
20562056

20572057
// Returns a vector holding one kernel ID per entry of the kernel map.
// After this change the IDs are read from m_DeviceKernelInfoMap while holding
// m_DeviceKernelInfoMapMutex; the removed lines read them from
// m_KernelName2KernelIDs under m_KernelIDsMutex.
std::vector<kernel_id> ProgramManager::getAllSYCLKernelIDs() {
2058-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2058+
std::lock_guard<std::mutex> DKIGuard(m_DeviceKernelInfoMapMutex);
20592059

20602060
std::vector<sycl::kernel_id> AllKernelIDs;
2061-
AllKernelIDs.reserve(m_KernelName2KernelIDs.size());
2062-
for (std::pair<std::string_view, kernel_id> KernelID :
2063-
m_KernelName2KernelIDs) {
2064-
AllKernelIDs.push_back(KernelID.second);
2061+
// Reserve once up front: exactly one ID is appended per map entry.
AllKernelIDs.reserve(m_DeviceKernelInfoMap.size());
2062+
// Entries are bound by const reference, so no DeviceKernelInfo is copied.
for (const std::pair<const std::string_view, DeviceKernelInfo> &Pair :
2063+
m_DeviceKernelInfoMap) {
2064+
// NOTE(review): getKernelID() asserts that the entry holds a kernel ID, and
// the DeviceKernelInfo constructor defaults the ID to std::nullopt. Confirm
// that every entry inserted into this map is created with an ID.
AllKernelIDs.push_back(Pair.second.getKernelID());
20652065
}
20662066
return AllKernelIDs;
20672067
}
20682068

20692069
kernel_id ProgramManager::getBuiltInKernelID(std::string_view KernelName) {
2070-
std::lock_guard<std::mutex> BuiltInKernelIDsGuard(m_BuiltInKernelIDsMutex);
2070+
std::lock_guard<std::mutex> BuiltInImgMapsGuard(m_BuiltInKernelIDsMutex);
20712071

20722072
auto KernelID = m_BuiltInKernelIDs.find(KernelName);
20732073
if (KernelID == m_BuiltInKernelIDs.end()) {
@@ -2118,7 +2118,7 @@ ProgramManager::getKernelGlobalInfoDesc(const char *UniqueId) {
21182118
std::set<const RTDeviceBinaryImage *>
21192119
ProgramManager::getRawDeviceImages(const std::vector<kernel_id> &KernelIDs) {
21202120
std::set<const RTDeviceBinaryImage *> BinImages;
2121-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2121+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
21222122
for (const kernel_id &KID : KernelIDs) {
21232123
auto Range = m_KernelIDs2BinImage.equal_range(KID);
21242124
for (auto It = Range.first, End = Range.second; It != End; ++It)
@@ -2204,7 +2204,7 @@ device_image_plain ProgramManager::getDeviceImageFromBinaryImage(
22042204
std::shared_ptr<std::vector<sycl::kernel_id>> KernelIDs;
22052205
// Collect kernel names for the image.
22062206
{
2207-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2207+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
22082208
KernelIDs = m_BinImg2KernelIDs[BinImage];
22092209
}
22102210

@@ -2234,7 +2234,7 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
22342234
}
22352235
BinImages = getRawDeviceImages(KernelIDs);
22362236
} else {
2237-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2237+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
22382238
for (auto &ImageUPtr : m_BinImg2KernelIDs) {
22392239
BinImages.insert(ImageUPtr.first);
22402240
}
@@ -2293,7 +2293,7 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
22932293
ImgInfo.State = getBinImageState(BinImage);
22942294
// Collect kernel names for the image
22952295
{
2296-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2296+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
22972297
ImgInfo.KernelIDs = m_BinImg2KernelIDs[BinImage];
22982298
}
22992299
ImgInfo.Deps = collectDeviceImageDeps(*BinImage, Dev);
@@ -2390,7 +2390,7 @@ ProgramManager::createDependencyImage(const context &Ctx, devices_range Devs,
23902390
bundle_state DepState) {
23912391
std::shared_ptr<std::vector<sycl::kernel_id>> DepKernelIDs;
23922392
{
2393-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2393+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
23942394
// For device library images, they are not in m_BinImg2KernelIDs since
23952395
// no kernel is included.
23962396
auto DepIt = m_BinImg2KernelIDs.find(DepImage);
@@ -2513,7 +2513,7 @@ ProgramManager::getSYCLDeviceImages(const context &Ctx, devices_range Devs,
25132513
return {};
25142514

25152515
{
2516-
std::lock_guard<std::mutex> BuiltInKernelIDsGuard(m_BuiltInKernelIDsMutex);
2516+
std::lock_guard<std::mutex> BuiltInImgMapsGuard(m_BuiltInKernelIDsMutex);
25172517

25182518
for (auto &It : m_BuiltInKernelIDs) {
25192519
if (std::find(KernelIDs.begin(), KernelIDs.end(), It.second) !=
@@ -2943,7 +2943,7 @@ ur_kernel_handle_t ProgramManager::getCachedMaterializedKernel(
29432943
<< "KernelName: " << KernelName << "\n";
29442944

29452945
{
2946-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2946+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
29472947
if (auto KnownMaterializations = m_MaterializedKernels.find(KernelName);
29482948
KnownMaterializations != m_MaterializedKernels.end()) {
29492949
if constexpr (DbgProgMgr > 0)
@@ -3000,7 +3000,7 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
30003000
BuildProgram, KernelName.data(), &UrKernel);
30013001
ur_kernel_handle_t RawUrKernel = UrKernel;
30023002
{
3003-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
3003+
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
30043004
m_MaterializedKernels[KernelName][SpecializationConsts] =
30053005
std::move(UrKernel);
30063006
}

0 commit comments

Comments
 (0)