Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 10 additions & 21 deletions src/compiler/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,16 @@ GPUCompiler.method_table(@nospecialize(::HIPCompilerJob)) = AMDGPU.method_table

GPUCompiler.kernel_state_type(@nospecialize(::HIPCompilerJob)) = AMDGPU.KernelState

function GPUCompiler.link_libraries!(
@nospecialize(job::HIPCompilerJob), mod::LLVM.Module,
undefined_fns::Vector{String},
)
function GPUCompiler.link_libraries!(@nospecialize(job::HIPCompilerJob), mod::LLVM.Module)
invoke(GPUCompiler.link_libraries!,
Tuple{CompilerJob{GCNCompilerTarget}, typeof(mod), typeof(undefined_fns)},
job, mod, undefined_fns)
Tuple{CompilerJob{GCNCompilerTarget},typeof(mod)}, job, mod)

# Detect global hostcalls here, before optimizations & cleanup occur.
_global_hostcalls[hash(job)] = find_global_hostcalls(mod)

# Link only if there are undefined functions.
# Everything else was loaded in `finish_module!` stage.
link_device_libs!(
job.config.target, mod, undefined_fns;
wavefrontsize64=job.config.params.wavefrontsize64,
only_undefined=true)
job.config.target, mod;
wavefrontsize64=job.config.params.wavefrontsize64)
end

function GPUCompiler.finish_module!(
Expand All @@ -55,16 +48,12 @@ function GPUCompiler.finish_module!(
Tuple{CompilerJob{GCNCompilerTarget}, typeof(mod), typeof(entry)},
job, mod, entry)

# Link libraries early to include options libraries in the runtime.
# Otherwise we get wave64 specific instructions on wave32 hardware
# which results in ICE.
undefined_fns = GPUCompiler.decls(mod)
if !isempty(undefined_fns)
link_device_libs!(
job.config.target, mod, LLVM.name.(undefined_fns);
wavefrontsize64=job.config.params.wavefrontsize64,
only_undefined=false)
end
# Re-link device libs to resolve references introduced by the GPUCompiler
# runtime (e.g. boxing → malloc → hostcall → __ockl_hsa_signal*) which are
# added after link_libraries! has already run.
link_device_libs!(
job.config.target, mod;
wavefrontsize64=job.config.params.wavefrontsize64)

# Set kernel target cpu and features.
if LLVM.callconv(entry) == LLVM.API.LLVMAMDGPUKERNELCallConv
Expand Down
35 changes: 6 additions & 29 deletions src/compiler/device_libs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,27 @@ mutable struct DevLib
name::String
path::String
data::Vector{UInt8}
fn_names::Set{String}

DevLib(name::String, path::String) = new(name, path, read(path), Set{String}())
DevLib(name::String, ::Nothing) = new(name, "", UInt8[], Set{String}())
DevLib(name::String, path::String) = new(name, path, read(path))
DevLib(name::String, ::Nothing) = new(name, "", UInt8[])
end

const DEVICE_LIBS::Dict{String, DevLib} = Dict{String, DevLib}()

function link_device_libs!(
target::GCNCompilerTarget, mod::LLVM.Module, undefined_fns::Vector{String};
wavefrontsize64::Bool, only_undefined::Bool,
target::GCNCompilerTarget, mod::LLVM.Module;
wavefrontsize64::Bool,
)
isnothing(libdevice_libs) && return
isempty(undefined_fns) && return

# 1. Load other libraries.
lib_names = ("hc", "hip", "irif", "ockl", "opencl", "ocml")
for lib_name in lib_names
devlib = get!(DEVICE_LIBS, lib_name) do
DevLib(lib_name, locate_lib(lib_name))
end
load_and_link!(devlib, mod, undefined_fns)
load_and_link!(devlib, mod)
end
only_undefined && return

# 2. Load OCLC library.
devlib = get!(DEVICE_LIBS, "oclc") do
Expand Down Expand Up @@ -72,28 +69,15 @@ function link_device_libs!(
end
end

function load_and_link!(
devlib::DevLib, mod::LLVM.Module, undefined_fns::Vector{String} = String[],
)
function load_and_link!(devlib::DevLib, mod::LLVM.Module)
isempty(devlib.path) && return

fill_fn_names = isempty(devlib.fn_names)
do_linking = false

if !fill_fn_names && !isempty(undefined_fns)
for undef_fn in undefined_fns
undef_fn ∈ devlib.fn_names && (do_linking = true; break)
end
do_linking || return
end

lib = parse(LLVM.Module, devlib.data)
inline_attr = EnumAttribute("alwaysinline")
noinline_attr = EnumAttribute("noinline")

for f in LLVM.functions(lib)
fn_name = LLVM.name(f)
fill_fn_names && push!(devlib.fn_names, fn_name)

# FIXME: We should be able to inline this, that we can't means
# we are inserting calls to it late.
Expand All @@ -110,13 +94,6 @@ function load_and_link!(
inline && push!(attrs, inline_attr)
end

if !do_linking && !isempty(undefined_fns)
for undef_fn in undefined_fns
undef_fn ∈ devlib.fn_names && (do_linking = true; break)
end
do_linking || return
end

# override triple and datalayout to avoid warnings
triple!(lib, triple(mod))
datalayout!(lib, datalayout(mod))
Expand Down