Skip to content

Commit d4d4687

Browse files
authored
Merge pull request #45 from manikyabard/manikyabard/tabulartfms
Adds `TabularItem` and transforms for it: `NormalizeRow`, `FillMissing`, `Categorify`.
2 parents 377e704 + a03bc72 commit d4d4687

File tree

6 files changed

+194
-2
lines changed

6 files changed

+194
-2
lines changed

src/DataAugmentation.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ include("./sequence.jl")
2929
include("./items/arrayitem.jl")
3030
include("./projective/base.jl")
3131
include("./items/image.jl")
32+
include("./items/table.jl")
3233
include("./items/keypoints.jl")
3334
include("./items/mask.jl")
3435
include("./projective/compose.jl")
@@ -37,6 +38,7 @@ include("./projective/affine.jl")
3738
include("./projective/warp.jl")
3839
include("./oneof.jl")
3940
include("./preprocessing.jl")
41+
include("./rowtransforms.jl")
4042
include("./colortransforms.jl")
4143
include("testing.jl")
4244
include("./visualization.jl")
@@ -50,6 +52,7 @@ export Item,
5052
Sequence,
5153
Project,
5254
Image,
55+
TabularItem,
5356
Keypoints,
5457
Polygon,
5558
ToEltype,
@@ -89,7 +92,8 @@ export Item,
8992
onehot,
9093
showitems,
9194
showgrid,
92-
Bounds
95+
Bounds,
96+
getcategorypools
9397

9498

9599
end # module

src/items/table.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""
    TabularItem(data, columns)

Item that holds a single table row `data` together with the names of its `columns`.

The row transforms defined for `TabularItem` walk `columns` and the values of `data`
in lockstep, so `columns` is expected to list the column names in the same order in
which `data` yields its values.
"""
struct TabularItem{T, C} <: Item
    data::T
    # Parametrized instead of left untyped so that `item.columns` is concretely typed.
    columns::C
end

# Keeps the single-parameter spelling `TabularItem{T}(data, columns)` working.
TabularItem{T}(data, columns) where {T} = TabularItem{T, typeof(columns)}(data, columns)

src/rowtransforms.jl

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""
    NormalizeRow(dict, cols)

Transform that standardizes selected entries of the row stored in a `TabularItem`.

`cols` names the columns to normalize. `dict` maps each of those column names to a
`(mean, std)` tuple; a selected value `x` is replaced by `(x - mean) / std`. Columns
that are not listed in `cols` are passed through unchanged.

## Example

```julia
using DataAugmentation

cols = [:col1, :col2, :col3]
row = (; zip(cols, [1, 2, 3])...)
item = TabularItem(row, cols)
normdict = Dict(:col1 => (1, 1), :col2 => (2, 2))

tfm = NormalizeRow(normdict, [:col1, :col2])
apply(tfm, item)
```
"""
struct NormalizeRow{D, C} <: Transform
    dict::D
    cols::C
end
27+
28+
# Apply `NormalizeRow` to a `TabularItem`: every column listed in `tfm.cols` is
# rescaled with the `(mean, std)` pair stored for it in `tfm.dict`; all other columns
# are copied over as they are. `randstate` is accepted but not used.
function apply(tfm::NormalizeRow, item::TabularItem; randstate=nothing)
    # Standardize one value with the statistics registered for its column.
    function standardize(name, value)
        center, scale = tfm.dict[name]
        return (value - center)/scale
    end

    entries = (
        (name, name in tfm.cols ? standardize(name, value) : value)
        for (name, value) in zip(item.columns, item.data)
    )
    return TabularItem(NamedTuple(entries), item.columns)
end
38+
39+
"""
    FillMissing(dict, cols)

Transform that replaces `missing` entries of the row stored in a `TabularItem`.

`cols` names the columns to fill. `dict` maps each of those column names to the
value that is substituted when the entry is `missing`. Entries that are not
`missing`, and columns that are not listed in `cols`, are left unchanged.

## Example

```julia
using DataAugmentation

cols = [:col1, :col2, :col3]
row = (; zip(cols, [1, 2, 3])...)
item = TabularItem(row, cols)
fmdict = Dict(:col1 => 100, :col2 => 100)

tfm = FillMissing(fmdict, [:col1, :col2])
apply(tfm, item)
```
"""
struct FillMissing{D, C} <: Transform
    dict::D
    cols::C
end
65+
66+
# Apply `FillMissing` to a `TabularItem`: a `missing` entry in a column listed in
# `tfm.cols` is replaced by the fill value stored for that column in `tfm.dict`;
# everything else is copied over unchanged. `randstate` is accepted but not used.
function apply(tfm::FillMissing, item::TabularItem; randstate=nothing)
    entries = (
        (name, (name in tfm.cols && ismissing(value)) ? tfm.dict[name] : value)
        for (name, value) in zip(item.columns, item.data)
    )
    return TabularItem(NamedTuple(entries), item.columns)
end
75+
76+
"""
    Categorify(dict, cols)

Transform that label-encodes selected entries of the row stored in a `TabularItem`.

`cols` names the columns to encode. `dict` maps each of those column names to the
collection of unique values the column can take. A value is encoded as its position
in that collection plus one; the index `1` is reserved for `missing` entries.

Any `missing` contained in the category collections of `dict` is dropped when the
transform is constructed, and a warning is logged for each affected column. The
dictionary passed in by the caller is left untouched; the cleaned version is
available as the `dict` field of the transform.

## Example

```julia
using DataAugmentation

cols = [:col1, :col2, :col3]
row = (; zip(cols, ["cat", 2, 3])...)
item = TabularItem(row, cols)
catdict = Dict(:col1 => ["dog", "cat"])

tfm = Categorify(catdict, [:col1])
apply(tfm, item)
```
"""
struct Categorify{T, S} <: Transform
    dict::T
    cols::S
    function Categorify{T, S}(dict::T, cols::S) where {T, S}
        # Copy lazily, and only if some column actually needs cleaning, so that the
        # caller's dictionary is never modified as a side effect of construction.
        cleaned = dict
        for (col, vals) in dict
            if any(ismissing, vals)
                cleaned === dict && (cleaned = copy(dict))
                cleaned[col] = filter(!ismissing, vals)
                @warn "There is a missing value present for category '$col' which will be removed from Categorify dict"
            end
        end
        return new{T, S}(cleaned, cols)
    end
end

Categorify(dict::T, cols::S) where {T, S} = Categorify{T, S}(dict, cols)
116+
117+
# Apply `Categorify` to a `TabularItem`: every column listed in `tfm.cols` is
# replaced by an integer label. `missing` maps to 1; any other value maps to its
# position in `tfm.dict[col]` plus one. Columns not listed in `tfm.cols` are copied
# over unchanged. `randstate` is accepted but not used.
#
# Throws an `ArgumentError` when a value is not among the categories stored for its
# column, instead of failing later with an opaque `nothing + 1` `MethodError`.
function apply(tfm::Categorify, item::TabularItem; randstate=nothing)
    x = NamedTuple(Iterators.map(item.columns, item.data) do col, val
        if col in tfm.cols
            if ismissing(val)
                val = 1
            else
                # `==(val)` compares `val` as a whole against each category, so values
                # that are themselves collections are not broadcast element-wise.
                idx = findfirst(==(val), tfm.dict[col])
                if idx === nothing
                    throw(ArgumentError(
                        "value $(repr(val)) of column '$col' is not among the " *
                        "categories known to Categorify"
                    ))
                end
                val = idx + 1
            end
        end
        (col, val)
    end)
    return TabularItem(x, item.columns)
end

test/imports.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using CoordinateTransformations
99
using DataAugmentation: Item, Transform, getrandstate, itemdata, setdata, ComposedProjectiveTransform,
1010
projectionbounds, getprojection, offsetcropbounds,
1111
CroppedProjectiveTransform, getbounds, project, project!, makebuffer, imagetotensor, imagetotensor!,
12-
normalize, normalize!, tensortoimage, denormalize, denormalize!
12+
normalize, normalize!, tensortoimage, denormalize, denormalize!,
13+
NormalizeRow, FillMissing, Categorify, TabularItem
1314
using DataAugmentation: testitem, testapply, testapply!, testprojective
1415
import DataAugmentation: apply, compose

test/rowtransforms.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
include("imports.jl")
2+
3+
@testset ExtendedTestSet "`NormalizeRow`" begin
    cols = [:col1, :col2, :col3]
    item = TabularItem(NamedTuple(zip(cols, [1, "a", 10])), cols)
    cols_to_normalize = [:col1, :col3]
    col1_mean, col1_std = 10, 100
    col3_mean, col3_std = 100, 10
    normdict = Dict(:col1 => (col1_mean, col1_std), :col3 => (col3_mean, col3_std))

    tfm = NormalizeRow(normdict, cols_to_normalize)
    testapply(tfm, item)
    titem = apply(tfm, item)
    @test titem.data[:col1] == (item.data[:col1] - col1_mean)/col1_std
    @test titem.data[:col3] == (item.data[:col3] - col3_mean)/col3_std
    # :col2 is not selected for normalization and must pass through untouched.
    @test titem.data[:col2] == item.data[:col2]
end
17+
18+
@testset ExtendedTestSet "`FillMissing`" begin
    cols = [:col1, :col2, :col3]
    item = TabularItem(NamedTuple(zip(cols, [1, missing, missing])), cols)
    col1_fmval = 1000.
    col2_fmval = "d"
    col3_fmval = 1000.
    fmdict = Dict(:col1 => col1_fmval, :col2 => col2_fmval, :col3 => col3_fmval)

    # Only :col1 and :col3 are selected: :col1 already holds a value and must keep
    # it, :col3 is filled, and the unselected :col2 must stay missing.
    tfm1 = FillMissing(fmdict, [:col1, :col3])
    @test_nowarn apply(tfm1, item)
    titem = apply(tfm1, item)
    @test titem.data[:col1] == coalesce(item.data[:col1], col1_fmval)
    @test titem.data[:col3] == coalesce(item.data[:col3], col3_fmval)
    @test ismissing(titem.data[:col2])

    # Build a fresh column list instead of `push!`-ing onto the vector `tfm1` holds,
    # which would change `tfm1` as well through aliasing.
    tfm2 = FillMissing(fmdict, [:col1, :col3, :col2])
    testapply(tfm2, item)
    titem2 = apply(tfm2, item)
    @test titem2.data[:col2] == coalesce(item.data[:col2], col2_fmval)
end
40+
41+
@testset ExtendedTestSet "`Categorify`" begin
    cols = [:col1, :col2, :col3, :col4]
    item = TabularItem(NamedTuple(zip(cols, [1, "a", "A", missing])), cols)
    cols_to_categorify = [:col2, :col3, :col4]

    categorydict = Dict(:col2 => ["a", "b", "c"], :col3 => ["C", "B", "A"], :col4 => [missing, 10, 20])
    # Constructing the transform must warn about the `missing` listed for :col4 and
    # strip it from the stored categories; `@test_logs` returns the constructed value.
    tfm = @test_logs (:warn, r"missing value present for category 'col4'") Categorify(
        categorydict, cols_to_categorify
    )
    @test !any(ismissing, tfm.dict[:col4])
    @test_nowarn apply(tfm, item)
    testapply(tfm, item)
    titem = apply(tfm, item)
    # Labels are the position in the category list plus one; 1 is used for `missing`.
    @test titem.data[:col2] == 2
    @test titem.data[:col3] == 4
    @test titem.data[:col4] == 1
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,7 @@ include("./imports.jl")
4444
@testset ExtendedTestSet "visualization.jl" begin
4545
include("visualization.jl")
4646
end
47+
@testset ExtendedTestSet "rowtransforms.jl" begin
48+
include("rowtransforms.jl")
49+
end
4750
end

0 commit comments

Comments
 (0)