Skip to content

Commit 0179425

Browse files
authored
Merge pull request #33 from lorenzoh/lo/one-of
add `OneOf` and `Maybe` transform wrappers
2 parents bf5c1c7 + cf4ef6c commit 0179425

File tree

12 files changed

+167
-18
lines changed

12 files changed

+167
-18
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
1212
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
1313
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
15+
MosaicViews = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389"
1516
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
1617
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
1718
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

docs/literate/colortransforms.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Color transformations
2+
3+
DataAugmentation.jl currently supports the following color transformations for augmentation:
4+
5+
- [`AdjustContrast`](#) randomly adjusts the contrast; and
6+
- [`AdjustBrightness`](#) randomly adjusts the brightness.
7+
8+
See the docstrings for examples.

docs/literate/stochastic.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Stochastic transformations
2+
3+
4+
When augmenting data, it is often useful to apply a transformation only with some probability or choose from a set of transformations. Unlike in other data augmentation libraries like *albumentations*, in DataAugmentation.jl you can use wrapper transformations for this functionality.
5+
6+
- [`Maybe`](#)`(tfm, p = 0.5)` applies a transformation with probability `p`; and
7+
- [`OneOf`](#)`([tfm1, tfm2])` randomly selects a transformation to apply.
8+
9+
## Example
10+
Let's say we have an image classification dataset. For most datasets, horizontally flipping the image does not change the label: a flipped image of a cat still shows a cat. So let's flip every image horizontally half of the time to improve the generalization of the model we might be training.
11+
12+
{cell=main}
13+
```julia
14+
using DataAugmentation, TestImages
15+
item = Image(testimage("lighthouse"))
16+
tfm = Maybe(FlipX())
17+
titems = [apply(tfm, item) for _ in 1:8]
18+
showgrid(titems; ncol = 4, npad = 16)
19+
```

docs/serve.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,5 @@ using Pkg
22
using Pollen
33
using DataAugmentation
44

5-
p = Pollen.documentationproject(DataAugmentation, executecode=true)
6-
push!(p.rewriters, Pollen.PackageWatcher([DataAugmentation]))
7-
server = Server(p)
8-
mode = ServeFilesLazy()
9-
runserver(server, mode)
5+
p = Pollen.documentationproject(DataAugmentation)
6+
Pollen.serve(p)

src/DataAugmentation.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@ module DataAugmentation
22

33
using ColorBlendModes
44
using CoordinateTransformations
5-
using Distributions: Sampleable, Uniform
5+
using Distributions: Sampleable, Uniform, Categorical
66
using ImageDraw
77
using Images
88
using Images: Colorant, permuteddimsview
99
using ImageTransformations
1010
using ImageTransformations: center, _center, box_extrapolation, warp!
1111
using Interpolations
12+
using MosaicViews: mosaicview
1213
using OffsetArrays: OffsetArray
1314
using LinearAlgebra: I
1415
using Parameters
@@ -34,6 +35,7 @@ include("./projective/compose.jl")
3435
include("./projective/crop.jl")
3536
include("./projective/affine.jl")
3637
include("./projective/warp.jl")
38+
include("./oneof.jl")
3739
include("./preprocessing.jl")
3840
include("./colortransforms.jl")
3941

@@ -69,6 +71,8 @@ export Item,
6971
BufferedThreadsafe,
7072
OneHot,
7173
Zoom,
74+
OneOf,
75+
Maybe,
7276
apply,
7377
Reflect,
7478
WarpAffine,
@@ -81,7 +85,8 @@ export Item,
8185
PadDivisible,
8286
ResizePadDivisible,
8387
onehot,
84-
showitems
88+
showitems,
89+
showgrid
8590

8691

8792
end # module

src/colortransforms.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@ from `f ∈ [1-δ, 1+δ]` by multiplying each color channel by `f`.
99
You can also pass any `Distributions.Sampleable` from which the
1010
factor is selected.
1111
12+
## Example
13+
1214
{cell=AdjustBrightness}
1315
```julia
1416
using DataAugmentation, TestImages
1517
16-
img = testimage("lighthouse")
18+
item = Image(testimage("lighthouse"))
1719
tfm = AdjustBrightness(0.2)
18-
apply(tfm, Image(img)) |> showitem
20+
titems = [apply(tfm, item) for _ in 1:8]
21+
showgrid(titems; ncol = 4, npad = 16)
1922
```
2023
"""
2124
struct AdjustBrightness{S<:Sampleable} <: Transform
@@ -67,13 +70,16 @@ of the image.
6770
You can also pass any `Distributions.Sampleable` from which the
6871
factor is selected.
6972
73+
## Example
74+
7075
{cell=AdjustBrightness}
7176
```julia
7277
using DataAugmentation, TestImages
7378
74-
img = testimage("lighthouse")
79+
item = Image(testimage("lighthouse"))
7580
tfm = AdjustContrast(0.2)
76-
apply(tfm, Image(img)) |> showitem
81+
titems = [apply(tfm, item) for _ in 1:8]
82+
showgrid(titems; ncol = 4, npad = 16)
7783
```
7884
"""
7985
struct AdjustContrast{S<:Sampleable} <: Transform

src/oneof.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""
2+
OneOf(tfms)
3+
OneOf(tfms, ps)
4+
5+
Apply one of `tfms` selected randomly with probability `ps` each
6+
or uniformly chosen if no `ps` is given.
7+
"""
8+
struct OneOf{T<:Transform, D<:Sampleable} <: Transform
9+
tfms::Vector{T}
10+
dist::D
11+
end
12+
13+
function OneOf(tfms::Vector{<:Transform}, ps = ones(length(tfms)) ./ length(tfms))
14+
@assert length(tfms) == length(ps)
15+
return OneOf(tfms, Categorical(ps))
16+
end
17+
18+
function getrandstate(oneof::OneOf)
19+
i = rand(oneof.dist)
20+
return i, getrandstate(oneof.tfms[i])
21+
end
22+
23+
24+
function apply(oneof::OneOf, item::AbstractItem; randstate = getrandstate(oneof))
25+
i, tfmrandstate = randstate
26+
return apply(oneof.tfms[i], item; randstate = randstate)
27+
end
28+
29+
30+
function makebuffer(oneof::OneOf, items)
31+
return [makebuffer(tfm, items) for tfm in oneof.tfms]
32+
end
33+
34+
function apply!(bufs, oneof::OneOf, item::AbstractItem; randstate = getrandstate(tfm))
35+
i, tfm, tfmrandstate = randstate
36+
buf = bufs[i]
37+
return apply!(buf, tfm, item; randstate = randstate)
38+
end
39+
40+
"""
41+
Maybe(tfm, p = 0.5) <: Transform
42+
43+
With probability `p`, apply transformation `tfm`.
44+
"""
45+
Maybe(tfm, p = 0.5) = OneOf([tfm, Identity()], [p, 1-p])
46+
47+
48+
struct OneOfProjective{T<:Transform, D<:Sampleable} <: ProjectiveTransform
49+
tfms::Vector{T}
50+
dist::D
51+
end
52+
53+
54+
function OneOf(tfms::Vector{<:ProjectiveTransform}, ps = ones(length(tfms)) ./ length(tfms))
55+
@assert length(tfms) == length(ps)
56+
return OneOfProjective(tfms, Categorical(ps))
57+
end
58+
59+
60+
function getrandstate(oneof::OneOfProjective)
61+
i = rand(oneof.dist)
62+
return i, getrandstate(oneof.tfms[i])
63+
end
64+
65+
66+
function Maybe(tfm::ProjectiveTransform, p = 1/2)
67+
return OneOf([tfm, Project(IdentityTransformation())], [p, 1-p])
68+
end
69+
70+
71+
function getprojection(oneof::OneOfProjective, bounds; randstate = getrandstate(oneof))
72+
i, tfmrandstate = randstate
73+
return getprojection(oneof.tfms[i], bounds; randstate = tfmrandstate)
74+
end
75+
76+
77+
function makebuffer(oneof::OneOfProjective, items)
78+
return [makebuffer(tfm, items) for tfm in oneof.tfms]
79+
end

src/projective/affine.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,17 +131,14 @@ angle is selected.
131131
tfm = Reflect(10)
132132
```
133133
"""
134-
struct Reflect{S<:Sampleable} <: ProjectiveTransform
135-
dist::S
134+
struct Reflect <: ProjectiveTransform
135+
γ
136136
end
137-
Reflect(γ) = Reflect(Uniform(-abs(γ), abs(γ)))
138137

139-
getrandstate(tfm::Reflect) = rand(tfm.dist)
140138

141139
function getprojection(tfm::Reflect, bounds; randstate = getrandstate(tfm))
142-
γ = randstate
143140
midpoint = sum(bounds) ./ length(bounds)
144-
r = γ / 360 * 2pi
141+
r = tfm.γ / 360 * 2pi
145142
return recenter(reflectionmatrix(r), midpoint)
146143
end
147144

src/sequence.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ getrandstate(seq::Sequence) = getrandstate.(seq.transforms)
2323

2424
compose(tfm1::Transform, tfm2::Transform) = Sequence(tfm1, tfm2)
2525
compose(seq::Sequence, tfm::Transform) = Sequence(seq.transforms..., tfm)
26-
compose(seq::Sequence, ::Identity) = seq
2726
compose(tfm::Transform, seq::Sequence) = compose(tfm, seq.transforms...)
27+
compose(::Identity, seq::Sequence) = seq
28+
compose(seq::Sequence, ::Identity) = seq
2829

2930

3031
function apply(seq::Sequence, items::Tuple; randstate = getrandstate(seq))

src/visualization.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ function showimage!(dst, img)
3434
end
3535

3636

37+
function showgrid(items; fillvalue = RGBA{N0f8}(0.,0.,0.,0.), kwargs...)
38+
imgs = [showitems(item) for item in items]
39+
mosaicview(imgs; fillvalue = fillvalue, kwargs...)
40+
end
41+
3742
showbounds(bounds) = showbounds!(zeros(RGBA{N0f8}, boundssize(bounds)), bounds)
3843

3944
function showbounds!(img, bounds::AbstractArray{<:SVector{2}})

0 commit comments

Comments
 (0)