Skip to content

Commit bce0c65

Browse files
committed
try to make init_marginal_gpu faster
1 parent 912ad7e commit bce0c65

File tree

1 file changed

+26
-9
lines changed

1 file changed

+26
-9
lines changed

src/queries/marginal_flow.jl

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -261,23 +261,40 @@ end
261261
Initialize values from the gpu data
262262
"""
263263
function init_marginal_gpu(data, reuse, num_nodes; Float=Float32)
    # Flow matrix lives on the GPU; `similar!` reuses `reuse` when its
    # type and size already match, avoiding a fresh allocation.
    flowtype = CuMatrix{Float}
    values = similar!(reuse, flowtype, num_examples(data), num_nodes)

    # Constant-leaf columns: log(1) = 0 for TRUE, log(0) = -Inf for FALSE.
    @views values[:, LogicCircuits.TRUE_BITS] .= log(one(Float))
    @views values[:, LogicCircuits.FALSE_BITS] .= log(zero(Float))

    nfeatures = num_features(data)

    # Fill positive/negative literal marginals one feature column at a time,
    # broadcasting on the GPU. `coalesce` maps missing entries to probability
    # one (log-marginal 0) so missing features do not constrain the result.
    # TODO: a single custom CUDA kernel over all features would avoid
    # launching two broadcasts per feature, but a DataFrame cannot currently
    # be passed to a kernel, so the per-column broadcast is the fastest
    # workable option (still much faster than round-tripping via the CPU).
    for i = 1:nfeatures
        @views values[:, 2 + i] .= log.(coalesce.(data[:, i], one(Float)))
        # Use `one(Float)` (not the literal `1.0`) so the subtraction stays
        # in `Float` (e.g. Float32) instead of promoting every element to
        # Float64 before the log. Missing again coalesces to log(1) = 0.
        @views values[:, 2 + i + nfeatures] .=
            log.(coalesce.(one(Float) .- data[:, i], one(Float)))
    end

    return values
end
280-
281298
# upward pass helpers on CPU
282299

283300
"Compute marginals on the CPU (SIMD & multi-threaded)"

0 commit comments

Comments
 (0)