Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/TagBot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,4 @@ jobs:
- uses: JuliaRegistries/TagBot@v1
with:
token: ${{ secrets.TAGBOT_TOKEN }}
# Edit the following line to reflect the actual name of the GitHub Secret containing your private key
registry: HolyLab/HolyLabRegistry
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "CachedInterpolations"
uuid = "b9709bfb-d23d-5560-8276-8c35c4b76823"
version = "1.0.0"
version = "1.0.1"
authors = ["Tim Holy <tim.holy@gmail.com>"]

[deps]
Expand All @@ -11,7 +11,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Aqua = "0.8"
Documenter = "1"
ExplicitImports = "1"
Interpolations = "0.14 - 1"
Interpolations = "0.15 - 1"
StaticArrays = "1"
Test = "1"
julia = "1.10"
Expand Down
27 changes: 20 additions & 7 deletions src/CachedInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ julia> itps_c[1](0.0) # peak via centered coordinates
0.75
```
"""
function cachedinterpolators(parent::Array{T, M}, N, origin = ntuple(d -> 0, N)) where {T, M}
function cachedinterpolators(parent::Array{T, M}, N::Integer, origin = ntuple(d -> 0, N)) where {T, M}
0 <= N <= M || error("N must be between 0 and $M")
length(origin) == N || throw(DimensionMismatch("length(origin) = $(length(origin)) is inconsistent with $N interpolating dimensions"))
sz3 = ntuple(d -> d <= N ? 3 : size(parent, d), Val(M))
Expand Down Expand Up @@ -172,13 +172,26 @@ end
return icoefs[wis...]
end

# FIXME: this function cheats dangerously, because it does _not_
# update the cache. This is equivalent to the assumption you've called
# getindex for the current `(x_1, x_2, ...)` location before calling
# gradient!. If this is not true, you'll get wrong answers.
"""
Interpolations.gradient(itp::CachedInterpolation, ys...)

Return the gradient of `itp` evaluated at coordinates `ys` as an `SVector`.
The cache is updated automatically, so this can be called without a preceding
`itp(ys...)`.
"""
@inline function Interpolations.gradient(itp::CachedInterpolation{T, N, M, O, K}, ys::Vararg{Number, N}) where {T, N, M, O, K}
coefs, tileindex = itp.coefs, itp.tileindex
xs = ys .- round.(Int, ys) .+ 2
coefs, parent, center, tileindex = itp.coefs, itp.parent, itp.center, itp.tileindex
iys = round.(Int, ys)
xs = ys .- iys .+ 2
newcenter = iys .+ O
sz3 = ntuple(d -> 3, Val(N))
if newcenter != center
offset = CartesianIndex(newcenter .- 2)
for i in CartesianIndices(sz3)
coefs[i, tileindex] = parent[i + offset, tileindex]
end
itp.center = newcenter
end
itpinfo = (ntuple(d -> BSpline(Quadratic(InPlace(OnCell()))), Val(N))..., ntuple(d -> NoInterp(), Val(K))...)
wis = weightedindexes((value_weights, gradient_weights), itpinfo, axes(coefs), (xs..., Tuple(tileindex)...))
icoefs = InterpGetindex(coefs)
Expand Down
Loading