@@ -261,23 +261,40 @@ end
261261Initialize values from the gpu data
262262"""
function init_marginal_gpu(data, reuse, num_nodes; Float=Float32)
    # Purpose: build the `num_examples(data) × num_nodes` matrix of initial log-marginals
    # on the GPU. Only the constant and literal columns are filled here.
    #
    # Arguments:
    #   data      - input examples; indexed as `data[:, i]` per feature
    #   reuse     - buffer handed to `similar!` (presumably recycled when compatible;
    #               TODO confirm against `similar!`)
    #   num_nodes - number of columns of the result
    #   Float     - element type of the result (keyword, defaults to `Float32`)
    #
    # Returns the filled `CuMatrix{Float}`.
    flowtype = CuMatrix{Float}
    values = similar!(reuse, flowtype, num_examples(data), num_nodes)

    # Constant columns: log(1) == 0 and log(0) == -Inf.
    @views values[:, LogicCircuits.TRUE_BITS] .= log(one(Float))
    @views values[:, LogicCircuits.FALSE_BITS] .= log(zero(Float))

    # TODO: replace this per-feature loop with a single custom CUDA kernel that extracts
    # the marginals in one launch. That is not possible yet because `data` (a data frame)
    # cannot be passed to a CUDA kernel directly. The previous fallback of copying
    # everything to the CPU and back was slower still.
    nfeatures = num_features(data)
    for i in 1:nfeatures
        # Column 2 + i holds the log-marginal of feature i; column 2 + nfeatures + i
        # holds that of its complement. A missing entry is coalesced to one, so it
        # contributes log(1) == 0 to both columns.
        @views values[:, 2 + i] .= log.(coalesce.(data[:, i], one(Float)))
        # `one(Float)` rather than the literal `1.0` keeps the broadcast in `Float`
        # precision instead of promoting it to Float64 on the device.
        @views values[:, 2 + nfeatures + i] .=
            log.(coalesce.(one(Float) .- data[:, i], one(Float)))
    end
    return values
end
280-
281298# upward pass helpers on CPU
282299
283300" Compute marginals on the CPU (SIMD & multi-threaded)"
0 commit comments