2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
@@ -5,7 +5,7 @@ steps:
matrix:
setup:
version:
- "1.10"
- "1.12"
group:
- core
- neural_networks
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
@@ -34,8 +34,8 @@ jobs:
fail-fast: false
matrix:
version:
- "1.10"
- "1.11"
- "1.12"
# - "1.11"
# - 'nightly'
os:
- ubuntu-24.04
2 changes: 1 addition & 1 deletion .github/workflows/Documenter.yaml
@@ -37,7 +37,7 @@ jobs:
- uses: actions/checkout@v5
- uses: julia-actions/setup-julia@v2
with:
version: "1.11"
version: "1.12"
- uses: julia-actions/cache@v2
- name: Instantiate docs environment
shell: julia --color=yes --project=docs {0}
3 changes: 2 additions & 1 deletion src/Compiler.jl
@@ -702,9 +702,10 @@ function optimization_passes(
dus_to_concat::Bool=false,
recognize_comms::Bool=true,
lower_comms::Bool=true,
max_constant_threshold::Int=1024,
backend::String="gpu",
)
(; max_constant_threshold) = compile_options

transform_passes_list = [
"patterns=compare_op_canon<16>",
"transpose_transpose<16>",
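A note on the `optimization_passes` change above: `(; max_constant_threshold) = compile_options` is Julia's property-destructuring syntax, so the threshold is now read off the `compile_options` object instead of being threaded through as its own keyword argument. A minimal sketch of that syntax, with a hypothetical struct used purely for illustration:

struct ExampleOptions
    max_constant_threshold::Int
    backend::String
end

opts = ExampleOptions(1024, "gpu")
(; max_constant_threshold) = opts  # binds a local variable from the field of the same name
@assert max_constant_threshold == 1024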
161 changes: 147 additions & 14 deletions src/TestUtils.jl
@@ -1,6 +1,7 @@
module TestUtils

using ..Reactant: Reactant, TracedRArray
using ..Reactant: Reactant, TracedRArray, TracedRNumber, TracedUtils
using Reactant.Ops: @opcall
using ReactantCore: ReactantCore
using LinearAlgebra: LinearAlgebra

@@ -20,22 +21,154 @@ function construct_test_array(::Type{T}, dims::Int...) where {T}
return reshape(collect(T, 1:prod(dims)), dims...)
end

function finite_difference_gradient(
f, x::AbstractArray{T}; epsilon=eps(T)^(3 / 4)
) where {T}
onehot_matrix = Reactant.promote_to(
TracedRArray{Reactant.unwrapped_eltype(T),2}, LinearAlgebra.I(length(x))
)
perturbation = reshape(onehot_matrix .* epsilon, size(x)..., length(x))
f_input = cat(x .+ perturbation, x .- perturbation; dims=ndims(x) + 1)

f_evaluated = mapslices(f, f_input; dims=ntuple(identity, ndims(x)))
return ReactantCore.materialize_traced_array(
reshape(
(f_evaluated[1:length(x)] - f_evaluated[(length(x) + 1):end]) ./ (2 * epsilon),
size(x),
),
)
end

# https://github.com/JuliaDiff/FiniteDiff.jl/blob/3a8c3d8d87e59de78e2831787a3f54b12b7c2075/src/epsilons.jl#L133
function default_epsilon(::Val{fdtype}, ::Type{T}) where {fdtype,T}
if fdtype == :forward
return sqrt(eps(real(T)))
elseif fdtype == :central
return cbrt(eps(real(T)))
elseif fdtype == :hcentral
return eps(T)^(T(1 / 4))
else
return one(real(T))
end
end

function get_perturbation(x::AbstractArray{T}, epsilon) where {T}
onehot_matrix = Reactant.promote_to(
TracedRArray{Reactant.unwrapped_eltype(T),2},
LinearAlgebra.Diagonal(fill(epsilon, length(x)));
)
return permutedims(
reshape(onehot_matrix, size(x)..., length(x)), (ndims(x) + 1, 1:(ndims(x))...)
)
end

function generate_perturbed_array(::Val{:central}, x::AbstractArray{T}, epsilon) where {T}
perturbation = get_perturbation(x, epsilon)
x_ = reshape(x, 1, size(x)...)
return cat(x_ .+ perturbation, x_ .- perturbation; dims=1)
end

function generate_perturbed_array(::Val{:forward}, x::AbstractArray{T}, epsilon) where {T}
perturbation = get_perturbation(x, epsilon)
x_ = reshape(x, 1, size(x)...)
return cat(x_ .+ perturbation, x_; dims=1)
end

function finite_difference_gradient(
f::F, args...; method::Union{Val{:central},Val{:forward}}=Val(:central)
) where {F}
argprefix = gensym("finitediffarg")
resprefix = gensym("finitediffresult")
resargprefix = gensym("finitediffresarg")

# TODO: can we detect and prevent using functions that mutate their arguments?
mlir_fn_res = TracedUtils.make_mlir_fn(
f,
args,
(),
"finite_difference_gradient_fn",
false;
args_in_result=:none,
argprefix,
resprefix,
resargprefix,
)

seenargs = Reactant.OrderedIdDict()
Reactant.make_tracer(seenargs, f, (argprefix,), Reactant.TracedSetPath)
for (i, arg) in enumerate(args)
Reactant.make_tracer(seenargs, arg, (argprefix, i), Reactant.TracedSetPath)
end

linear_args = Reactant.TracedType[]
for (k, v) in seenargs
v isa Reactant.TracedType || continue
push!(linear_args, v)
end

if (
length(mlir_fn_res.linear_results) != 1 ||
!(mlir_fn_res.linear_results[1] isa TracedRNumber)
)
error("`finite_difference_gradient` only supports functions with a single scalar \
output. Received: $(mlir_fn_res.linear_results)")
end

gradient_results = TracedRArray[]
gradient_result_map_path = []
for i in 1:length(linear_args)
arg = linear_args[i]
if arg isa TracedRArray && TracedUtils.has_idx(arg, argprefix)
path = TracedUtils.get_idx(arg, argprefix)
if mlir_fn_res.fnwrapped && length(path) > 1 && path[2] == 1
continue
end

# We need the gradient wrt this argument
# we will naively insert the args here, cse will take care of the rest
new_arguments = TracedRArray[]

epsilon = default_epsilon(method, Reactant.unwrapped_eltype(arg))
perturbed_arg = generate_perturbed_array(method, arg, epsilon)

bsize = size(perturbed_arg, 1)
for j in 1:length(linear_args)
if i == j
new_arg = perturbed_arg
elseif linear_args[j] isa TracedRNumber
new_arg = @opcall broadcast_in_dim(
linear_args[j], Int64[], Int64[bsize]
)
else
new_arg = @opcall broadcast_in_dim(
linear_args[j],
collect(Int64, 2:(ndims(linear_args[j]) + 1)),
Int64[bsize, size(linear_args[j])...],
)
end
new_arg = @opcall transpose(new_arg, Int64[1, ((ndims(new_arg)):-1:2)...];)
push!(new_arguments, new_arg)
end

batched_res = @opcall batch(
new_arguments,
[
Reactant.MLIR.IR.TensorType(
Int64[bsize],
Reactant.MLIR.IR.Type(
Reactant.unwrapped_eltype(mlir_fn_res.linear_results[1])
),
),
],
Int64[bsize];
fn=mlir_fn_res.f,
)
batched_res = only(batched_res)

if method isa Val{:central}
diff = batched_res[1:(bsize ÷ 2)] - batched_res[((bsize ÷ 2) + 1):end]
grad_res = diff ./ (2 * epsilon)
elseif method isa Val{:forward}
diff = batched_res[1:(end - 1)] .- batched_res[end:end]
grad_res = diff ./ epsilon
end

push!(gradient_result_map_path, TracedUtils.get_idx(arg, argprefix))
push!(
gradient_results,
ReactantCore.materialize_traced_array(reshape(grad_res, size(arg))),
)
end
end

results = deepcopy(args)
for (path, grad_res) in zip(gradient_result_map_path, gradient_results)
TracedUtils.set!(results, path[2:end], grad_res.mlir_data)
end
length(args) == 1 && return results[1]
return results
end

end
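For reference, the rewritten `finite_difference_gradient` above batches the standard central-difference rule over every element of each array argument via `@opcall batch`, then reshapes the result back to the argument's shape. A plain, unbatched Julia sketch of the rule it is expected to reproduce for a scalar-output `f` (the function name here is hypothetical and not part of this PR):

# Unbatched central-difference gradient, for comparison only.
function reference_central_gradient(f, x::AbstractArray{T}) where {T}
    epsilon = cbrt(eps(real(T)))  # same default as default_epsilon(Val(:central), T)
    grad = similar(x, float(real(T)))
    for i in eachindex(x)
        xp = copy(x); xp[i] += epsilon  # perturb element i forward
        xm = copy(x); xm[i] -= epsilon  # perturb element i backward
        grad[i] = (f(xp) - f(xm)) / (2 * epsilon)
    end
    return grad
end

For example, `reference_central_gradient(sum, ones(2, 2))` should return a 2x2 matrix of ones, the exact gradient of `sum`.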
11 changes: 10 additions & 1 deletion src/TracedRArray.jl
@@ -33,6 +33,15 @@ Base.convert(T::Type{<:TracedRArray}, x::AbstractArray) = Reactant.promote_to(T,
Base.complex(x::TracedRArray{<:Real}) = complex.(x)
Base.complex(x::TracedRArray{<:Complex}) = x

function Base.deepcopy_internal(x::TracedRArray, stackdict::IdDict)
if haskey(stackdict, x)
return stackdict[x]::typeof(x)
end
y = copy(x)
stackdict[x] = y
return y
end

TracedRArray{T,N}(x::AbstractArray) where {T,N} = convert(TracedRArray{T,N}, x)

function maybe_assert_scalar_setindexing(
@@ -1109,7 +1118,7 @@ function Base.accumulate_pairwise!(op, A::AnyTracedRVector, B::AnyTracedRVector)
return accumulate!(op, A, B; dims=1)
end

if isdefined(Base, :_accumulate_promote_op)
@static if isdefined(Base, :_accumulate_promote_op)
function Base._accumulate_promote_op(op, A::AnyTracedRArray{T}; init=nothing) where {T}
if init !== nothing
init isa TracedRNumber && (init = zero(unwrapped_eltype(init)))
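For context on the `Base.deepcopy_internal` method added above: `deepcopy` threads an `IdDict` through `deepcopy_internal` so that repeated references to one object all map to one copy, and this override plugs `TracedRArray` into that protocol (copy once, memoize in `stackdict`). That is what lets the new `finite_difference_gradient` call `deepcopy(args)` on traced containers. A generic sketch of the same pattern with a plain wrapper type, purely illustrative:

struct AliasedWrapper
    data::Vector{Float64}
end

function Base.deepcopy_internal(w::AliasedWrapper, stackdict::IdDict)
    haskey(stackdict, w) && return stackdict[w]::AliasedWrapper
    copied = AliasedWrapper(copy(w.data))
    stackdict[w] = copied
    return copied
end

w = AliasedWrapper([1.0, 2.0])
pair = (w, w)                  # two references to the same object
pair2 = deepcopy(pair)
@assert pair2[1] === pair2[2]  # aliasing preserved because the copy was memoized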
1 change: 1 addition & 0 deletions src/TracedRNumber.jl
@@ -491,6 +491,7 @@ for (jlop, hloop) in (
(:(Base.log), :log),
(:(Base.log1p), :log_plus_one),
(:(Base.sqrt), :sqrt),
(:(Base.cbrt), :cbrt),
(:(Base.acos), :acos),
(:(Base.acosh), :acosh),
(:(Base.asin), :asin),
4 changes: 2 additions & 2 deletions src/stdlibs/LinearAlgebra.jl
@@ -273,7 +273,7 @@ function overloaded_mul!(
return C
end

if isdefined(LinearAlgebra, :_triu)
@static if isdefined(LinearAlgebra, :_triu)
function LinearAlgebra._triu(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T}
return overloaded_triu(materialize_traced_array(A), k)
end
@@ -284,7 +284,7 @@ if isdefined(LinearAlgebra, :_triu)
end
end

if isdefined(LinearAlgebra, :_tril)
@static if isdefined(LinearAlgebra, :_tril)
function LinearAlgebra._tril(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T}
return overloaded_tril(materialize_traced_array(A), k)
end
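The `@static if isdefined(...)` edits above (and the matching one in src/TracedRArray.jl) move the feature check to macro-expansion time, so only the branch valid for the running Julia version is ever lowered. A minimal sketch of the pattern, with a hypothetical symbol name:

# Only the selected branch is kept; the other is dropped during macro expansion.
@static if isdefined(Base, :some_possibly_missing_function)
    compat_helper(x) = Base.some_possibly_missing_function(x)
else
    compat_helper(x) = x  # fallback for Julia versions without the symbol
end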
41 changes: 41 additions & 0 deletions test/autodiff.jl
@@ -366,3 +366,44 @@ end

@test @jit(jvp_vjp_cubic(v_r, x_r, lambdas_r)) ≈ fill(6, (3, 2))
end

@testset "Finite Difference Gradient" begin
x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float16, 2, 2))
res = @jit Reactant.TestUtils.finite_difference_gradient(sum, x)
@test res isa Reactant.ConcreteRArray{Float16,2}
end

function fdiff_multiple_args(f, nt, x)
return sum(abs2, f(nt.y .+ x .- nt.x))
end

struct WrapperFunc{T}
x::T
end

(f::WrapperFunc)(x) = x .^ 3 .+ f.x

@testset "Finite Difference Gradient (non vector inputs)" begin
nt = (;
x=Reactant.TestUtils.construct_test_array(Float64, 3, 4),
y=Reactant.TestUtils.construct_test_array(Float64, 3, 4),
)
fn = WrapperFunc(Reactant.TestUtils.construct_test_array(Float64, 3, 4))
x = Reactant.TestUtils.construct_test_array(Float64, 3, 4)

nt_ra = Reactant.to_rarray(nt)
fn_ra = Reactant.to_rarray(fn)
x_ra = Reactant.to_rarray(x)

results_fd = @jit Reactant.TestUtils.finite_difference_gradient(
fdiff_multiple_args, fn_ra, nt_ra, x_ra
)
@test results_fd isa typeof((fn_ra, nt_ra, x_ra))

results_enz = @jit Enzyme.gradient(Reverse, fdiff_multiple_args, fn_ra, nt_ra, x_ra)

@test results_fd[1].x ≈ results_enz[1].x
@test results_fd[2].x ≈ results_enz[2].x
@test results_fd[2].y ≈ results_enz[2].y
@test results_fd[3] ≈ results_enz[3]
end