Skip to content

Commit f731cb3

Browse files
authored
Merge pull request #40 from lorenzoh/lorenzoh/fix-rotate-type
Fix `Rotate` type instability
2 parents aa72774 + e104f23 commit f731cb3

File tree

4 files changed

+19
-7
lines changed

4 files changed

+19
-7
lines changed

src/projective/affine.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,14 @@ Rotate(γ) = Rotate(Uniform(-abs(γ), abs(γ)))
107107

108108
getrandstate(tfm::Rotate) = rand(tfm.dist)
109109

110-
function getprojection(tfm::Rotate, bounds; randstate = getrandstate(tfm))
110+
function getprojection(
111+
tfm::Rotate,
112+
bounds::AbstractArray{<:SVector{N, T}};
113+
randstate = getrandstate(tfm)) where {N, T}
111114
γ = randstate
112115
middlepoint = sum(bounds) ./ length(bounds)
113116
r = γ / 360 * 2pi
114-
return recenter(RotMatrix(r), middlepoint)
117+
return recenter(RotMatrix(convert(T, r)), middlepoint)
115118
end
116119

117120

src/projective/base.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ Default implementation falls back to `project`.
4949
function project!(bufitem, P, item, indices)
5050
titem = project(P, item, indices)
5151
copyitemdata!(bufitem, titem)
52+
5253
return bufitem
5354
end
5455

@@ -82,8 +83,8 @@ function apply!(
8283
bounds = getbounds(item)
8384
P = getprojection(tfm, bounds; randstate = randstate)
8485
indices = cropindices(tfm, P, bounds; randstate = randstate)
85-
project!(bufitem, P, item, indices)
86-
return bufitem
86+
res = project!(bufitem, P, item, indices)
87+
return res
8788
end
8889

8990

src/projective/warp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,5 @@ function threepointwarpaffine(
4242
c = (X \ Y)'
4343
A = SMatrix{2, 2, V}(c[:, 1:2])
4444
b = SVector{2, V}(c[:, 3])
45-
AffineMap(A, b)
45+
return AffineMap(A, b)
4646
end

test/projective/affine.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ include("../imports.jl")
9999
tfm = Rotate(10)
100100
image = Image(rand(RGB, 50, 50))
101101
@test_nowarn apply(tfm, image)
102-
102+
P = DataAugmentation.getprojection(tfm, getbounds(image))
103+
@test P isa AffineMap
104+
@test P.linear.mat[1] isa Float32
103105
end
104106

105107
@testset ExtendedTestSet "Rotate" begin
@@ -144,16 +146,20 @@ end
144146
Image(rand(RGB, sz)),
145147
Keypoints(rand(SVector{2, Float32}, 50), sz),
146148
MaskBinary(rand(Bool, sz)),
147-
MaskMulti(rand(1:8, sz)),
149+
MaskMulti(UInt8.(rand(1:8, sz)), 1:8),
148150
)
149151

150152
tfms = compose(
151153
Rotate(10),
152154
FlipX(), FlipY(),
153155
ScaleRatio((.8, .8)),
154156
RandomResizeCrop((50, 50)),
157+
WarpAffine(0.1),
158+
Zoom((1., 1.2))
155159
)
156160
@test_nowarn apply(tfms, items)
161+
titems = apply(tfms, items)
162+
@test all(typeof.(titems) == typeof.(items))
157163
end
158164

159165
@testset ExtendedTestSet "3D" begin
@@ -170,5 +176,7 @@ end
170176
RandomResizeCrop((25, 25, 25)),
171177
)
172178
@test_nowarn apply(tfms, items)
179+
titems = apply(tfms, items)
180+
@test all(typeof.(titems) == typeof.(items))
173181
end
174182
end

0 commit comments

Comments
 (0)