@@ -63,7 +63,7 @@ julia> augment_batch_dim(A, 3)
6363```
6464"""
6565function augment_batch_dim (input:: AbstractArray{T,N} , n) where {T,N}
66- return repeat (input; inner= (ntuple (_ -> 1 , Val ( N - 1 ) )... , n))
66+ return repeat (input; inner= (ntuple (Returns ( 1 ), N - 1 )... , n))
6767end
6868
6969"""
7272Reduce augmented input batch by averaging the explanation for each augmented sample.
7373"""
7474function reduce_augmentation (input:: AbstractArray{T,N} , n) where {T<: AbstractFloat ,N}
75- return cat (
76- (
77- Iterators. map (1 : n: size (input, N)) do i
78- augmentation_range = ntuple (_ -> :, Val (N - 1 ))... , i: (i + n - 1 )
79- sum (view (input, augmentation_range... ); dims= N) / n
80- end
81- ). .. ; dims= N
82- ):: Array{T,N}
75+ # Allocate output array
76+ in_size = size (input)
77+ in_size[end ] % n != 0 &&
78+ throw (ArgumentError (" Can't reduce augmented batch size of $(in_size[end ]) by $n " ))
79+ out_size = (in_size[1 : (end - 1 )]. .. , div (in_size[end ], n))
80+ out = similar (input, eltype (input), out_size)
81+
82+ axs = axes (input, N)
83+ inds_before_N = ntuple (Returns (:), N - 1 )
84+ for (i, ax) in enumerate (first (axs): n: last (axs))
85+ view (out, inds_before_N... , i) .=
86+ sum (view (input, inds_before_N... , ax: (ax + n - 1 )); dims= N) / n
87+ end
88+ return out
8389end
8490"""
8591 augment_indices(indices, n)
0 commit comments