Skip to content

Commit 917ba3e

Browse files
committed
Add OneHot transform for MaskMulti items
1 parent b2f1c11 commit 917ba3e

File tree

3 files changed

+50
-1
lines changed

3 files changed

+50
-1
lines changed

src/DataAugmentation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ export Item,
6161
CenterResizeCrop,
6262
Buffered,
6363
BufferedThreadsafe,
64+
OneHot,
6465
apply,
6566
Reflect,
6667
FlipX,

src/preprocessing.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ denormalize(a, means, stds) = denormalize!(copy(a), means, stds)
114114
"""
115115
NormalizeIntensity()
116116
117-
Normalizes the pixels of an array based on calculated mean and std.
117+
Normalizes the pixels of an array based on calculated mean and std.
118118
"""
119119

120120
struct NormalizeIntensity <: Transform end
@@ -183,6 +183,38 @@ end
183183
tensortoimage(a::AbstractArray{T, 3}) where T = colorview(RGB, permuteddimsview(a, (3, 1, 2)))
184184
tensortoimage(a::AbstractArray{T, 2}) where T = colorview(Gray, a)
185185

186+
187+
# OneHot encoding
188+
189+
struct OneHot{T} <: Transform end
190+
OneHot() = OneHot{Float32}()
191+
192+
function apply(tfm::OneHot{T}, item::MaskMulti; randstate = nothing) where T
193+
mask = itemdata(item)
194+
a = zeros(T, size(mask)..., length(item.classes))
195+
for I in CartesianIndices(mask)
196+
a[I, mask[I]] = one(T)
197+
end
198+
199+
return ArrayItem(a)
200+
end
201+
202+
203+
function apply!(buf, tfm::OneHot{T}, item::MaskMulti; randstate = nothing) where T
204+
mask = itemdata(item)
205+
a = itemdata(buf)
206+
@show a[1:6]
207+
fill!(a, zero(T))
208+
@show a[1:6]
209+
210+
for I in CartesianIndices(mask)
211+
a[I, mask[I]] = one(T)
212+
end
213+
214+
return buf
215+
end
216+
217+
186218
function onehot(T, x::Int, n::Int)
187219
v = fill(zero(T), n)
188220
v[x] = one(T)

test/preprocessing.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,22 @@ end
5555
end
5656
end
5757

58+
@testset ExtendedTestSet "OneHot" begin
59+
tfm = OneHot()
60+
mask = rand(1:4, 10, 10)
61+
item = MaskMulti(mask, 1:4)
62+
@test_nowarn apply(tfm, item)
63+
aitem = apply(tfm, item)
64+
@test size(itemdata(aitem)) == (10, 10, 4)
65+
66+
item2 = MaskMulti(rand(1:3, 10, 10), 1:4)
67+
buf = itemdata(aitem)
68+
bufcopy = copy(buf)
69+
apply!(aitem, tfm, item2)
70+
@test itemdata(item) == itemdata(item2) || itemdata(aitem) != bufcopy
71+
72+
end
73+
5874
@testset ExtendedTestSet "Image pipeline" begin
5975
image = Image(rand(RGB, 150, 150))
6076

0 commit comments

Comments
 (0)