Skip to content

Commit 1aec78f

Browse files
authored
move require usage to extensions on 1.9+ (#1390)
* move require usage to extensions on 1.9+ * remove extra loads in tracker extension * fix an unexprted function
1 parent 108e5a1 commit 1aec78f

File tree

6 files changed

+68
-18
lines changed

6 files changed

+68
-18
lines changed

Project.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,24 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2626
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2727
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2828

29+
[weakdeps]
30+
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
31+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
32+
Tracker= "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
33+
34+
[extensions]
35+
ZygoteColorsExt = "Colors"
36+
ZygoteDistancesExt = "Distances"
37+
ZygoteTrackerExt = "Tracker"
38+
2939
[compat]
3040
AbstractFFTs = "1.3.1"
3141
ChainRules = "1.44.1"
3242
ChainRulesCore = "1.9"
3343
ChainRulesTestUtils = "1"
44+
Colors = "0.12"
3445
DiffRules = "1.4"
46+
Distances = "0.10"
3547
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
3648
ForwardDiff = "0.10"
3749
GPUArrays = "8.4.2"
@@ -43,17 +55,20 @@ NaNMath = "0.3, 1"
4355
Requires = "1.1"
4456
SnoopPrecompile = "1.0.3"
4557
SpecialFunctions = "1.6, 2"
58+
Tracker = "0.2"
4659
ZygoteRules = "0.2.1"
4760
julia = "1.6"
4861

4962
[extras]
63+
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
5064
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
5165
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
5266
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
5367
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
5468
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
5569
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
5670
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
71+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
5772

5873
[targets]
5974
test = ["ChainRulesTestUtils", "CUDA", "Distances", "FFTW", "FiniteDifferences", "PyCall", "Test"]

ext/ZygoteColorsExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module ZygoteColorsExt
2+
3+
if isdefined(Base, :get_extension)
4+
using Zygote
5+
using Colors
6+
else
7+
using ..Zygote
8+
using ..Colors
9+
end
10+
11+
Zygote.@non_differentiable Colors.ColorTypes._parameter_upper_bound(::Any...)
12+
13+
end
Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,16 @@
1-
using .Distances
1+
module ZygoteDistancesExt
2+
3+
if isdefined(Base, :get_extension)
4+
using Zygote
5+
using Distances
6+
using LinearAlgebra
7+
else
8+
using ..Zygote
9+
using ..Distances
10+
using ..LinearAlgebra
11+
end
12+
13+
using Zygote: @adjoint, @adjoint, AContext, _pullback
214

315
@adjoint function (::SqEuclidean)(x::AbstractVector, y::AbstractVector)
416
δ = x .- y
@@ -66,7 +78,7 @@ end
6678

6779
_sqrt_if_positive(d, δ) = d > δ ? sqrt(d) : zero(d)
6880

69-
function _pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
81+
function Zygote._pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
7082
kws::@NamedTuple{dims::Int}, ::typeof(pairwise), dist::Euclidean,
7183
X::AbstractMatrix, Y::AbstractMatrix)
7284
# Modify the forwards-pass slightly to ensure stability on the reverse.
@@ -77,11 +89,11 @@ function _pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
7789
return _sqrt_if_positive.(D2, δ)
7890
end
7991
res, back = _pullback(cx, _pairwise_euclidean, SqEuclidean(dist.thresh), X, Y)
80-
pairwise_Euclidean_pullback(Δ) = (nothing, nothing, back(unthunk_tangent(Δ))...)
92+
pairwise_Euclidean_pullback(Δ) = (nothing, nothing, back(Zygote.unthunk_tangent(Δ))...)
8193
return res, pairwise_Euclidean_pullback
8294
end
8395

84-
function _pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
96+
function Zygote._pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
8597
kws::@NamedTuple{dims::Int}, ::typeof(pairwise), dist::Euclidean,
8698
X::AbstractMatrix)
8799
# Modify the forwards-pass slightly to ensure stability on the reverse.
@@ -92,6 +104,8 @@ function _pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
92104
return _sqrt_if_positive.(D2, δ)
93105
end
94106
res, back = _pullback(cx, _pairwise_euclidean, SqEuclidean(dist.thresh), X)
95-
pairwise_Euclidean_pullback(Δ) = (nothing, nothing, back(unthunk_tangent(Δ))...)
107+
pairwise_Euclidean_pullback(Δ) = (nothing, nothing, back(Zygote.unthunk_tangent(Δ))...)
96108
return res, pairwise_Euclidean_pullback
97109
end
110+
111+
end

ext/ZygoteTrackerExt.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
module ZygoteTrackerExt
2+
3+
if isdefined(Base, :get_extension)
4+
using Zygote
5+
using Tracker: Tracker, TrackedArray, TrackedReal
6+
else
7+
using ..Zygote
8+
using ..Tracker: Tracker, TrackedArray, TrackedReal
9+
end
10+
11+
Zygote.unwrap(x::Union{TrackedArray,TrackedReal}) = Tracker.data(x)
12+
13+
Zygote.pullback(f, ps::Tracker.Params) = pullback(f, ZygtParams(ps))
14+
Tracker.forward(f, ps::Params) = Tracker.forward(f, Tracker.Params(ps))
15+
Tracker.gradient_(f, ps::Params) = Tracker.gradient_(f, Tracker.Params(ps))
16+
17+
end

src/Zygote.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ include("lib/forward.jl")
4343
include("lib/utils.jl")
4444
include("lib/range.jl")
4545
include("lib/logexpfunctions.jl")
46-
@init @require Distances="b4f34e82-e78d-54a5-968a-f98e89d6e8f7" include("lib/distances.jl")
4746

4847
# we need to define this late, so that the genfuncs see lib.jl
4948
# Move using statements out of this file to help with sysimage building
@@ -53,12 +52,11 @@ include("compiler/interface2.jl")
5352

5453
include("profiler/Profile.jl")
5554

56-
@init @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
57-
include("flux.jl")
58-
end
5955

60-
@init @require Colors="5ae59095-9a9b-59fe-a467-6f913c188581" begin
61-
@non_differentiable Colors.ColorTypes._parameter_upper_bound(::Any...)
56+
if !isdefined(Base, :get_extension)
57+
@init @require Distances="b4f34e82-e78d-54a5-968a-f98e89d6e8f7" include("../ext/ZygoteDistancesExt.jl")
58+
@init @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("../ext/ZygoteTrackerExt.jl")
59+
@init @require Colors="5ae59095-9a9b-59fe-a467-6f913c188581" include("../ext/ZygoteColorsExt.jl")
6260
end
6361

6462
using InteractiveUtils

src/flux.jl

Lines changed: 0 additions & 7 deletions
This file was deleted.

0 commit comments

Comments
 (0)