Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions sycl/source/detail/device_kernel_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ namespace sycl {
inline namespace _V1 {
namespace detail {

DeviceKernelInfo::DeviceKernelInfo(const CompileTimeKernelInfoTy &Info)
: CompileTimeKernelInfoTy(Info) {}
// Constructs a DeviceKernelInfo from compile-time kernel information and an
// optional kernel ID.
//   Info     - copied into the CompileTimeKernelInfoTy base subobject.
//   KernelID - taken by value and moved into MKernelID; the declaration
//              defaults it to std::nullopt, in which case no ID is stored.
DeviceKernelInfo::DeviceKernelInfo(const CompileTimeKernelInfoTy &Info,
std::optional<sycl::kernel_id> KernelID)
: CompileTimeKernelInfoTy{Info}, MKernelID{std::move(KernelID)} {}

template <typename OtherTy>
inline constexpr bool operator==(const CompileTimeKernelInfoTy &LHS,
Expand Down
16 changes: 12 additions & 4 deletions sycl/source/detail/device_kernel_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <sycl/detail/compile_time_kernel_info.hpp>
#include <sycl/detail/spinlock.hpp>
#include <sycl/detail/ur.hpp>
#include <sycl/kernel_bundle.hpp>

#include <mutex>
#include <optional>
Expand Down Expand Up @@ -84,12 +85,10 @@ struct FastKernelSubcacheT {
// information that is uniform between different submissions of the same
// kernel). Pointers to instances of this class are stored in header function
// templates as a static variable to avoid repeated runtime lookup overhead.
// TODO Currently this class duplicates information fetched from the program
// manager. Instead, we should merge all of this information
// into this structure and get rid of the other KernelName -> * maps.
class DeviceKernelInfo : public CompileTimeKernelInfoTy {
public:
DeviceKernelInfo(const CompileTimeKernelInfoTy &Info);
DeviceKernelInfo(const CompileTimeKernelInfoTy &Info,
std::optional<sycl::kernel_id> KernelID = std::nullopt);

void init(std::string_view KernelName);
void setCompileTimeInfoIfNeeded(const CompileTimeKernelInfoTy &Info);
Expand All @@ -100,6 +99,14 @@ class DeviceKernelInfo : public CompileTimeKernelInfoTy {
return MImplicitLocalArgPos;
}

// Returns a reference to the kernel ID held in MKernelID.
// Precondition: MKernelID contains a value. This is only checked via assert;
// the optional is dereferenced unconditionally afterwards, so callers must
// not invoke this on an instance constructed without a kernel ID.
const sycl::kernel_id &getKernelID() const {
// Expected to be called only for DeviceKernelInfo instances created by
// program manager (as opposed to allocated by sycl::kernel with
// origins other than SYCL offline compilation).
assert(MKernelID);
return *MKernelID;
}

// Implicit local argument position is used only for some backends, so this
// function allows setting it as more images are added.
void setImplicitLocalArgPos(int Pos);
Expand All @@ -109,6 +116,7 @@ class DeviceKernelInfo : public CompileTimeKernelInfoTy {

FastKernelSubcacheT MFastKernelSubcache;
std::optional<int> MImplicitLocalArgPos;
const std::optional<sycl::kernel_id> MKernelID;
};

} // namespace detail
Expand Down
90 changes: 45 additions & 45 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -667,8 +667,8 @@ ProgramManager::collectDeviceImageDepsForImportedSymbols(
"Cannot resolve external symbols, linking is unsupported "
"for the backend");

// Access to m_ExportedSymbolImages must be guarded by m_KernelIDsMutex.
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
// Access to m_ExportedSymbolImages must be guarded by m_ImgMapsMutex.
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);

while (!WorkList.empty()) {
std::string Symbol = WorkList.front();
Expand Down Expand Up @@ -748,8 +748,8 @@ ProgramManager::collectDependentDeviceImagesForVirtualFunctions(
if (!WorkList.empty()) {
// Guard read access to m_VFSet2BinImage:
// TODO: a better solution should be sought in the future, i.e. a different
// mutex than m_KernelIDsMutex, check lock check pattern, etc.
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
// mutex than m_ImgMapsMutex, a check-lock-check pattern, etc.
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);

while (!WorkList.empty()) {
std::string SetName = WorkList.front();
Expand Down Expand Up @@ -1311,11 +1311,12 @@ ProgramManager::getDeviceImage(std::string_view KernelName,

const RTDeviceBinaryImage *Img = nullptr;
{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
if (auto KernelId = m_KernelName2KernelIDs.find(KernelName);
KernelId != m_KernelName2KernelIDs.end()) {
Img = getBinImageFromMultiMap(m_KernelIDs2BinImage, KernelId->second,
ContextImpl, DeviceImpl);
std::lock_guard<std::mutex> Guard(m_DeviceKernelInfoMapMutex);
if (auto It = m_DeviceKernelInfoMap.find(KernelName);
It != m_DeviceKernelInfoMap.end()) {
Img = getBinImageFromMultiMap(m_KernelIDs2BinImage,
It->second.getKernelID(), ContextImpl,
DeviceImpl);
}
}

Expand Down Expand Up @@ -1347,7 +1348,7 @@ const RTDeviceBinaryImage &ProgramManager::getDeviceImage(
debugPrintBinaryImages();
}

std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
std::vector<sycl_device_binary> RawImgs(ImageSet.size());
auto ImageIterator = ImageSet.begin();
for (size_t i = 0; i < ImageSet.size(); i++, ImageIterator++)
Expand Down Expand Up @@ -1620,7 +1621,7 @@ void ProgramManager::addImage(sycl_device_binary RawImg,
}

// Fill maps for kernel bundles
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);

// A bfloat16 device library image doesn't include any kernels, device
// globals, or virtual functions, so just skip adding it to any related maps.
Expand Down Expand Up @@ -1694,31 +1695,31 @@ void ProgramManager::addImage(sycl_device_binary RawImg,
m_BinImg2KernelIDs[Img.get()];
KernelIDs.reset(new std::vector<kernel_id>);

std::lock_guard<std::mutex> DKIGuard(m_DeviceKernelInfoMapMutex);

for (sycl_offload_entry EntriesIt = EntriesB; EntriesIt != EntriesE;
EntriesIt = EntriesIt->Increment()) {

auto name = EntriesIt->GetName();

// Skip creating unique kernel ID if it is an exported device
// Skip creating device kernel information if it is an exported device
// function. Exported device functions appear in the offload entries
// among kernels, but are identifiable by being listed in properties.
if (m_ExportedSymbolImages.find(name) != m_ExportedSymbolImages.end())
continue;

// ... and create a unique kernel ID for the entry
auto It = m_KernelName2KernelIDs.find(name);
if (It == m_KernelName2KernelIDs.end()) {
auto It = m_DeviceKernelInfoMap.find(std::string_view(name));
if (It == m_DeviceKernelInfoMap.end()) {
sycl::kernel_id KernelID = detail::createSyclObjFromImpl<sycl::kernel_id>(
std::make_shared<detail::kernel_id_impl>(name));

It = m_KernelName2KernelIDs.emplace_hint(It, name, KernelID);
CompileTimeKernelInfoTy DefaultCompileTimeInfo{std::string_view(name)};
It = m_DeviceKernelInfoMap.emplace_hint(
It, std::piecewise_construct, std::forward_as_tuple(name),
std::forward_as_tuple(DefaultCompileTimeInfo, KernelID));
}
m_KernelIDs2BinImage.insert(std::make_pair(It->second, Img.get()));
KernelIDs->push_back(It->second);

CompileTimeKernelInfoTy DefaultCompileTimeInfo{std::string_view(name)};
m_DeviceKernelInfoMap.try_emplace(std::string_view(name),
DefaultCompileTimeInfo);
m_KernelIDs2BinImage.insert(
std::make_pair(It->second.getKernelID(), Img.get()));
KernelIDs->push_back(It->second.getKernelID());

// Keep track of image to kernel name reference count for cleanup.
m_KernelNameRefCount[name]++;
Expand Down Expand Up @@ -1777,7 +1778,7 @@ void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
if (DeviceBinary->NumDeviceBinaries == 0)
return;
// Acquire lock to read and modify maps for kernel bundles
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);

// Acquire lock to erase DeviceKernelInfoMap
std::lock_guard<std::mutex> Guard(m_DeviceKernelInfoMapMutex);
Expand Down Expand Up @@ -1846,9 +1847,10 @@ void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
continue;
}

auto Name2IDIt = m_KernelName2KernelIDs.find(Name);
if (Name2IDIt != m_KernelName2KernelIDs.end())
removeFromMultimapByVal(m_KernelIDs2BinImage, Name2IDIt->second, Img);
auto DKIIt = m_DeviceKernelInfoMap.find(Name);
assert(DKIIt != m_DeviceKernelInfoMap.end());
removeFromMultimapByVal(m_KernelIDs2BinImage, DKIIt->second.getKernelID(),
Img);

auto RefCountIt = m_KernelNameRefCount.find(Name);
assert(RefCountIt != m_KernelNameRefCount.end());
Expand All @@ -1860,10 +1862,8 @@ void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
if (--RefCount == 0) {
// TODO aggregate all these maps into a single one since their entries
// share lifetime.
m_DeviceKernelInfoMap.erase(Name);
m_DeviceKernelInfoMap.erase(DKIIt);
m_KernelNameRefCount.erase(RefCountIt);
if (Name2IDIt != m_KernelName2KernelIDs.end())
m_KernelName2KernelIDs.erase(Name2IDIt);
}
}

Expand Down Expand Up @@ -1971,7 +1971,7 @@ ProgramManager::getBinImageState(const RTDeviceBinaryImage *BinImage) {
}

bool ProgramManager::hasCompatibleImage(const device_impl &DeviceImpl) {
std::lock_guard<std::mutex> Guard(m_KernelIDsMutex);
std::lock_guard<std::mutex> Guard(m_ImgMapsMutex);

return std::any_of(
m_BinImg2KernelIDs.cbegin(), m_BinImg2KernelIDs.cend(),
Expand All @@ -1981,19 +1981,19 @@ bool ProgramManager::hasCompatibleImage(const device_impl &DeviceImpl) {
}

std::vector<kernel_id> ProgramManager::getAllSYCLKernelIDs() {
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> DKIGuard(m_DeviceKernelInfoMapMutex);

std::vector<sycl::kernel_id> AllKernelIDs;
AllKernelIDs.reserve(m_KernelName2KernelIDs.size());
for (std::pair<std::string_view, kernel_id> KernelID :
m_KernelName2KernelIDs) {
AllKernelIDs.push_back(KernelID.second);
AllKernelIDs.reserve(m_DeviceKernelInfoMap.size());
for (const std::pair<const std::string_view, DeviceKernelInfo> &Pair :
m_DeviceKernelInfoMap) {
AllKernelIDs.push_back(Pair.second.getKernelID());
}
return AllKernelIDs;
}

kernel_id ProgramManager::getBuiltInKernelID(std::string_view KernelName) {
std::lock_guard<std::mutex> BuiltInKernelIDsGuard(m_BuiltInKernelIDsMutex);
std::lock_guard<std::mutex> BuiltInImgMapsGuard(m_BuiltInKernelIDsMutex);

auto KernelID = m_BuiltInKernelIDs.find(KernelName);
if (KernelID == m_BuiltInKernelIDs.end()) {
Expand Down Expand Up @@ -2044,7 +2044,7 @@ ProgramManager::getKernelGlobalInfoDesc(const char *UniqueId) {
std::set<const RTDeviceBinaryImage *>
ProgramManager::getRawDeviceImages(const std::vector<kernel_id> &KernelIDs) {
std::set<const RTDeviceBinaryImage *> BinImages;
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
for (const kernel_id &KID : KernelIDs) {
auto Range = m_KernelIDs2BinImage.equal_range(KID);
for (auto It = Range.first, End = Range.second; It != End; ++It)
Expand Down Expand Up @@ -2099,7 +2099,7 @@ device_image_plain ProgramManager::getDeviceImageFromBinaryImage(
std::shared_ptr<std::vector<sycl::kernel_id>> KernelIDs;
// Collect kernel names for the image.
{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
KernelIDs = m_BinImg2KernelIDs[BinImage];
}

Expand Down Expand Up @@ -2129,7 +2129,7 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
}
BinImages = getRawDeviceImages(KernelIDs);
} else {
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
for (auto &ImageUPtr : m_BinImg2KernelIDs) {
BinImages.insert(ImageUPtr.first);
}
Expand Down Expand Up @@ -2188,7 +2188,7 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
ImgInfo.State = getBinImageState(BinImage);
// Collect kernel names for the image
{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
ImgInfo.KernelIDs = m_BinImg2KernelIDs[BinImage];
}
ImgInfo.Deps = collectDeviceImageDeps(*BinImage, Dev);
Expand Down Expand Up @@ -2285,7 +2285,7 @@ ProgramManager::createDependencyImage(const context &Ctx, devices_range Devs,
bundle_state DepState) {
std::shared_ptr<std::vector<sycl::kernel_id>> DepKernelIDs;
{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
// Device library images are not in m_BinImg2KernelIDs since they include
// no kernels.
auto DepIt = m_BinImg2KernelIDs.find(DepImage);
Expand Down Expand Up @@ -2408,7 +2408,7 @@ ProgramManager::getSYCLDeviceImages(const context &Ctx, devices_range Devs,
return {};

{
std::lock_guard<std::mutex> BuiltInKernelIDsGuard(m_BuiltInKernelIDsMutex);
std::lock_guard<std::mutex> BuiltInImgMapsGuard(m_BuiltInKernelIDsMutex);

for (auto &It : m_BuiltInKernelIDs) {
if (std::find(KernelIDs.begin(), KernelIDs.end(), It.second) !=
Expand Down Expand Up @@ -2838,7 +2838,7 @@ ur_kernel_handle_t ProgramManager::getCachedMaterializedKernel(
<< "KernelName: " << KernelName << "\n";

{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
if (auto KnownMaterializations = m_MaterializedKernels.find(KernelName);
KnownMaterializations != m_MaterializedKernels.end()) {
if constexpr (DbgProgMgr > 0)
Expand Down Expand Up @@ -2895,7 +2895,7 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
BuildProgram, KernelName.data(), &UrKernel);
ur_kernel_handle_t RawUrKernel = UrKernel;
{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
std::lock_guard<std::mutex> ImgMapsGuard(m_ImgMapsMutex);
m_MaterializedKernels[KernelName][SpecializationConsts] =
std::move(UrKernel);
}
Expand Down
Loading