@@ -166,23 +166,58 @@ function apply!(buf, ::ImageToTensor, image::Image; randstate = nothing)
166166end
167167
168168function imagetotensor (image:: AbstractArray{C, N} , T = Float32) where {C<: Color , N}
169- T .(permuteddimsview (channelview (image), ((i for i in 2 : N+ 1 ). .. , 1 )))
169+ T .(permuteddimsview (_channelview (image), ((i for i in 2 : N+ 1 ). .. , 1 )))
170170end
171171
172+ #=
172173function imagetotensor(image::AbstractArray{C, N}, T = Float32) where {TC, C<:Color{TC, 1}, N}
173- return T .(channelview (image))
174+ return T.(_channelview (image))
174175end
176+ =#
175177
176178
177- function imagetotensor! (buf, image:: AbstractArray{<:AbstractRGB, N} ) where N
179+ # TODO : relax color type constraint, implement for other colors
180+ # single-channel colors need a `channelview` that also expands the array
181+ function imagetotensor! (buf, image:: AbstractArray{<:Color, N} ) where N
178182 permutedims! (
179183 buf,
180- channelview (image),
181- (2 , 3 , 1 ))
184+ _channelview (image),
185+ (2 : N + 1 ... , 1 ))
182186end
183- tensortoimage (a:: AbstractArray{T, 3} ) where T = colorview (RGB, permuteddimsview (a, (3 , 1 , 2 )))
184- tensortoimage (a:: AbstractArray{T, 2} ) where T = colorview (Gray, a)
185187
188+ function tensortoimage (a:: AbstractArray )
189+ nchannels = size (a)[end ]
190+ if nchannels == 3
191+ return tensortoimage (RGB, a)
192+ elseif nchannels == 1
193+ return tensortoimage (Gray, a)
194+ else
195+ error (" Found image tensor with $nchannels color channels. Pass in color type
196+ explicitly." )
197+ end
198+ end
199+
200+ function tensortoimage (C:: Type{<:Color} , a:: AbstractArray{T, N} ) where {T, N}
201+ perm = (N, 1 : N- 1 ... )
202+ return _colorview (C, permuteddimsview (a, perm))
203+ end
204+
205+
206+ function _channelview (img)
207+ chview = channelview (img)
208+ # for single-channel colors, expand the color dimension anyway
209+ if size (img) == size (chview)
210+ chview = reshape (chview, 1 , size (chview)... )
211+ end
212+ return chview
213+ end
214+
215+ function _colorview (C:: Type{<:Color} , img) where T
216+ if size (img, 1 ) == 1
217+ img = reshape (img, size (img)[2 : end ])
218+ end
219+ return colorview (C, img)
220+ end
186221
187222# OneHot encoding
188223
0 commit comments