Skip to content

Commit fea1477

Browse files
committed
Add unary implementation for ufunc with two output arrays
1 parent c086482 commit fea1477

File tree

2 files changed

+290
-35
lines changed

2 files changed

+290
-35
lines changed

dpnp/backend/extensions/elementwise_functions/elementwise_functions.hpp

Lines changed: 275 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,11 @@
2828

2929
#pragma once
3030

31+
#include <cstddef>
3132
#include <exception>
3233
#include <stdexcept>
34+
#include <utility>
35+
#include <vector>
3336

3437
#include <sycl/sycl.hpp>
3538

@@ -43,20 +46,18 @@
4346

4447
// dpctl tensor headers
4548
#include "kernels/alignment.hpp"
46-
// #include "kernels/dpctl_tensor_types.hpp"
4749
#include "utils/memory_overlap.hpp"
4850
#include "utils/offset_utils.hpp"
4951
#include "utils/output_validation.hpp"
5052
#include "utils/sycl_alloc_utils.hpp"
5153
#include "utils/type_dispatch.hpp"
5254

53-
namespace py = pybind11;
54-
namespace td_ns = dpctl::tensor::type_dispatch;
55-
5655
static_assert(std::is_same_v<py::ssize_t, dpctl::tensor::ssize_t>);
5756

5857
namespace dpnp::extensions::py_internal
5958
{
59+
namespace py = pybind11;
60+
namespace td_ns = dpctl::tensor::type_dispatch;
6061

6162
using dpctl::tensor::kernels::alignment_utils::is_aligned;
6263
using dpctl::tensor::kernels::alignment_utils::required_alignment;
@@ -108,10 +109,10 @@ std::pair<sycl::event, sycl::event>
108109
const py::ssize_t *src_shape = src.get_shape_raw();
109110
const py::ssize_t *dst_shape = dst.get_shape_raw();
110111
bool shapes_equal(true);
111-
size_t src_nelems(1);
112+
std::size_t src_nelems(1);
112113

113114
for (int i = 0; i < src_nd; ++i) {
114-
src_nelems *= static_cast<size_t>(src_shape[i]);
115+
src_nelems *= static_cast<std::size_t>(src_shape[i]);
115116
shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
116117
}
117118
if (!shapes_equal) {
@@ -277,6 +278,262 @@ py::object py_unary_ufunc_result_type(const py::dtype &input_dtype,
277278
}
278279
}
279280

281+
/**
282+
* @brief Template implementing Python API for a unary elementwise function
283+
* with two output arrays.
284+
*/
285+
template <typename output_typesT,
286+
typename contig_dispatchT,
287+
typename strided_dispatchT>
288+
std::pair<sycl::event, sycl::event>
289+
py_unary_two_outputs_ufunc(const dpctl::tensor::usm_ndarray &src,
290+
const dpctl::tensor::usm_ndarray &dst1,
291+
const dpctl::tensor::usm_ndarray &dst2,
292+
sycl::queue &q,
293+
const std::vector<sycl::event> &depends,
294+
//
295+
const output_typesT &output_type_vec,
296+
const contig_dispatchT &contig_dispatch_vector,
297+
const strided_dispatchT &strided_dispatch_vector)
298+
{
299+
int src_typenum = src.get_typenum();
300+
int dst1_typenum = dst1.get_typenum();
301+
int dst2_typenum = dst2.get_typenum();
302+
303+
const auto &array_types = td_ns::usm_ndarray_types();
304+
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
305+
int dst1_typeid = array_types.typenum_to_lookup_id(dst1_typenum);
306+
int dst2_typeid = array_types.typenum_to_lookup_id(dst2_typenum);
307+
308+
std::pair<int, int> func_output_typeids = output_type_vec[src_typeid];
309+
310+
// check that types are supported
311+
if (dst1_typeid != func_output_typeids.first ||
312+
dst2_typeid != func_output_typeids.second)
313+
{
314+
throw py::value_error(
315+
"One of destination arrays has unexpected elemental data type.");
316+
}
317+
318+
// check that queues are compatible
319+
if (!dpctl::utils::queues_are_compatible(q, {src, dst1, dst2})) {
320+
throw py::value_error(
321+
"Execution queue is not compatible with allocation queues");
322+
}
323+
324+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst1);
325+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst2);
326+
327+
// check that dimensions are the same
328+
int src_nd = src.get_ndim();
329+
if (src_nd != dst1.get_ndim() || src_nd != dst2.get_ndim()) {
330+
throw py::value_error("Array dimensions are not the same.");
331+
}
332+
333+
// check that shapes are the same
334+
const py::ssize_t *src_shape = src.get_shape_raw();
335+
const py::ssize_t *dst1_shape = dst1.get_shape_raw();
336+
const py::ssize_t *dst2_shape = dst2.get_shape_raw();
337+
bool shapes_equal(true);
338+
std::size_t src_nelems(1);
339+
340+
for (int i = 0; i < src_nd; ++i) {
341+
src_nelems *= static_cast<std::size_t>(src_shape[i]);
342+
shapes_equal = shapes_equal && (src_shape[i] == dst1_shape[i]) &&
343+
(src_shape[i] == dst2_shape[i]);
344+
}
345+
if (!shapes_equal) {
346+
throw py::value_error("Array shapes are not the same.");
347+
}
348+
349+
// if nelems is zero, return
350+
if (src_nelems == 0) {
351+
return std::make_pair(sycl::event(), sycl::event());
352+
}
353+
354+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst1,
355+
src_nelems);
356+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst2,
357+
src_nelems);
358+
359+
// check memory overlap
360+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
361+
auto const &same_logical_tensors =
362+
dpctl::tensor::overlap::SameLogicalTensors();
363+
if ((overlap(src, dst1) && !same_logical_tensors(src, dst1)) ||
364+
(overlap(src, dst2) && !same_logical_tensors(src, dst2)) ||
365+
(overlap(dst1, dst2) && !same_logical_tensors(dst1, dst2)))
366+
{
367+
throw py::value_error("Arrays index overlapping segments of memory");
368+
}
369+
370+
const char *src_data = src.get_data();
371+
char *dst1_data = dst1.get_data();
372+
char *dst2_data = dst2.get_data();
373+
374+
// handle contiguous inputs
375+
bool is_src_c_contig = src.is_c_contiguous();
376+
bool is_src_f_contig = src.is_f_contiguous();
377+
378+
bool is_dst1_c_contig = dst1.is_c_contiguous();
379+
bool is_dst1_f_contig = dst1.is_f_contiguous();
380+
381+
bool is_dst2_c_contig = dst2.is_c_contiguous();
382+
bool is_dst2_f_contig = dst2.is_f_contiguous();
383+
384+
bool all_c_contig =
385+
(is_src_c_contig && is_dst1_c_contig && is_dst2_c_contig);
386+
bool all_f_contig =
387+
(is_src_f_contig && is_dst1_f_contig && is_dst2_f_contig);
388+
389+
if (all_c_contig || all_f_contig) {
390+
auto contig_fn = contig_dispatch_vector[src_typeid];
391+
392+
if (contig_fn == nullptr) {
393+
throw std::runtime_error(
394+
"Contiguous implementation is missing for src_typeid=" +
395+
std::to_string(src_typeid));
396+
}
397+
398+
auto comp_ev =
399+
contig_fn(q, src_nelems, src_data, dst1_data, dst2_data, depends);
400+
sycl::event ht_ev =
401+
dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, {comp_ev});
402+
403+
return std::make_pair(ht_ev, comp_ev);
404+
}
405+
406+
// simplify iteration space
407+
// if 1d with strides 1 - input is contig
408+
// dispatch to strided
409+
410+
auto const &src_strides = src.get_strides_vector();
411+
auto const &dst1_strides = dst1.get_strides_vector();
412+
auto const &dst2_strides = dst2.get_strides_vector();
413+
414+
using shT = std::vector<py::ssize_t>;
415+
shT simplified_shape;
416+
shT simplified_src_strides;
417+
shT simplified_dst1_strides;
418+
shT simplified_dst2_strides;
419+
py::ssize_t src_offset(0);
420+
py::ssize_t dst1_offset(0);
421+
py::ssize_t dst2_offset(0);
422+
423+
int nd = src_nd;
424+
const py::ssize_t *shape = src_shape;
425+
426+
simplify_iteration_space_3(
427+
nd, shape, src_strides, dst1_strides, dst2_strides,
428+
// output
429+
simplified_shape, simplified_src_strides, simplified_dst1_strides,
430+
simplified_dst2_strides, src_offset, dst1_offset, dst2_offset);
431+
432+
if (nd == 1 && simplified_src_strides[0] == 1 &&
433+
simplified_dst1_strides[0] == 1 && simplified_dst2_strides[0] == 1)
434+
{
435+
// Special case of contiguous data
436+
auto contig_fn = contig_dispatch_vector[src_typeid];
437+
438+
if (contig_fn == nullptr) {
439+
throw std::runtime_error(
440+
"Contiguous implementation is missing for src_typeid=" +
441+
std::to_string(src_typeid));
442+
}
443+
444+
int src_elem_size = src.get_elemsize();
445+
int dst1_elem_size = dst1.get_elemsize();
446+
int dst2_elem_size = dst2.get_elemsize();
447+
auto comp_ev =
448+
contig_fn(q, src_nelems, src_data + src_elem_size * src_offset,
449+
dst1_data + dst1_elem_size * dst1_offset,
450+
dst2_data + dst2_elem_size * dst2_offset, depends);
451+
452+
sycl::event ht_ev =
453+
dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, {comp_ev});
454+
455+
return std::make_pair(ht_ev, comp_ev);
456+
}
457+
458+
// Strided implementation
459+
auto strided_fn = strided_dispatch_vector[src_typeid];
460+
461+
if (strided_fn == nullptr) {
462+
throw std::runtime_error(
463+
"Strided implementation is missing for src_typeid=" +
464+
std::to_string(src_typeid));
465+
}
466+
467+
using dpctl::tensor::offset_utils::device_allocate_and_pack;
468+
469+
std::vector<sycl::event> host_tasks{};
470+
host_tasks.reserve(2);
471+
472+
auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
473+
q, host_tasks, simplified_shape, simplified_src_strides,
474+
simplified_dst1_strides, simplified_dst2_strides);
475+
auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
476+
const auto &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
477+
const py::ssize_t *shape_strides = shape_strides_owner.get();
478+
479+
sycl::event strided_fn_ev = strided_fn(
480+
q, src_nelems, nd, shape_strides, src_data, src_offset, dst1_data,
481+
dst1_offset, dst2_data, dst2_offset, depends, {copy_shape_ev});
482+
483+
// async free of shape_strides temporary
484+
sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
485+
q, {strided_fn_ev}, shape_strides_owner);
486+
487+
host_tasks.push_back(tmp_cleanup_ev);
488+
489+
return std::make_pair(
490+
dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, host_tasks),
491+
strided_fn_ev);
492+
}
493+
494+
/**
495+
* @brief Template implementing Python API for querying of type support by
496+
* a unary elementwise function with two output arrays.
497+
*/
498+
template <typename output_typesT>
499+
std::pair<py::object, py::object>
500+
py_unary_two_outputs_ufunc_result_type(const py::dtype &input_dtype,
501+
const output_typesT &output_types)
502+
{
503+
int tn = input_dtype.num(); // NumPy type numbers are the same as in dpctl
504+
int src_typeid = -1;
505+
506+
auto array_types = td_ns::usm_ndarray_types();
507+
508+
try {
509+
src_typeid = array_types.typenum_to_lookup_id(tn);
510+
} catch (const std::exception &e) {
511+
throw py::value_error(e.what());
512+
}
513+
514+
using type_utils::_result_typeid;
515+
std::pair<int, int> dst_typeids = _result_typeid(src_typeid, output_types);
516+
int dst1_typeid = dst_typeids.first;
517+
int dst2_typeid = dst_typeids.second;
518+
519+
if (dst1_typeid < 0 || dst2_typeid < 0) {
520+
auto res = py::none();
521+
auto py_res = py::cast<py::object>(res);
522+
return std::make_pair(py_res, py_res);
523+
}
524+
else {
525+
using type_utils::_dtype_from_typenum;
526+
527+
auto dst1_typenum_t = static_cast<td_ns::typenum_t>(dst1_typeid);
528+
auto dst2_typenum_t = static_cast<td_ns::typenum_t>(dst2_typeid);
529+
auto dt1 = _dtype_from_typenum(dst1_typenum_t);
530+
auto dt2 = _dtype_from_typenum(dst2_typenum_t);
531+
532+
return std::make_pair(py::cast<py::object>(dt1),
533+
py::cast<py::object>(dt2));
534+
}
535+
}
536+
280537
// ======================== Binary functions ===========================
281538

282539
namespace
@@ -347,10 +604,10 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
347604
const py::ssize_t *src2_shape = src2.get_shape_raw();
348605
const py::ssize_t *dst_shape = dst.get_shape_raw();
349606
bool shapes_equal(true);
350-
size_t src_nelems(1);
607+
std::size_t src_nelems(1);
351608

352609
for (int i = 0; i < dst_nd; ++i) {
353-
src_nelems *= static_cast<size_t>(src1_shape[i]);
610+
src_nelems *= static_cast<std::size_t>(src1_shape[i]);
354611
shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] &&
355612
src2_shape[i] == dst_shape[i]);
356613
}
@@ -456,7 +713,7 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
456713
std::initializer_list<py::ssize_t>{0, 1};
457714
static constexpr auto one_zero_strides =
458715
std::initializer_list<py::ssize_t>{1, 0};
459-
constexpr py::ssize_t one{1};
716+
static constexpr py::ssize_t one{1};
460717
// special case of C-contiguous matrix and a row
461718
if (isEqual(simplified_src2_strides, zero_one_strides) &&
462719
isEqual(simplified_src1_strides, {simplified_shape[1], one}) &&
@@ -477,8 +734,8 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
477734
is_aligned<required_alignment>(
478735
dst_data + dst_offset * dst_itemsize))
479736
{
480-
size_t n0 = simplified_shape[0];
481-
size_t n1 = simplified_shape[1];
737+
std::size_t n0 = simplified_shape[0];
738+
std::size_t n1 = simplified_shape[1];
482739
sycl::event comp_ev = matrix_row_broadcast_fn(
483740
exec_q, host_tasks, n0, n1, src1_data, src1_offset,
484741
src2_data, src2_offset, dst_data, dst_offset,
@@ -511,8 +768,8 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
511768
is_aligned<required_alignment>(
512769
dst_data + dst_offset * dst_itemsize))
513770
{
514-
size_t n0 = simplified_shape[1];
515-
size_t n1 = simplified_shape[0];
771+
std::size_t n0 = simplified_shape[1];
772+
std::size_t n1 = simplified_shape[0];
516773
sycl::event comp_ev = row_matrix_broadcast_fn(
517774
exec_q, host_tasks, n0, n1, src1_data, src1_offset,
518775
src2_data, src2_offset, dst_data, dst_offset,
@@ -655,10 +912,10 @@ std::pair<sycl::event, sycl::event>
655912
const py::ssize_t *rhs_shape = rhs.get_shape_raw();
656913
const py::ssize_t *lhs_shape = lhs.get_shape_raw();
657914
bool shapes_equal(true);
658-
size_t rhs_nelems(1);
915+
std::size_t rhs_nelems(1);
659916

660917
for (int i = 0; i < lhs_nd; ++i) {
661-
rhs_nelems *= static_cast<size_t>(rhs_shape[i]);
918+
rhs_nelems *= static_cast<std::size_t>(rhs_shape[i]);
662919
shapes_equal = shapes_equal && (rhs_shape[i] == lhs_shape[i]);
663920
}
664921
if (!shapes_equal) {
@@ -749,7 +1006,7 @@ std::pair<sycl::event, sycl::event>
7491006
if (nd == 2) {
7501007
static constexpr auto one_zero_strides =
7511008
std::initializer_list<py::ssize_t>{1, 0};
752-
constexpr py::ssize_t one{1};
1009+
static constexpr py::ssize_t one{1};
7531010
// special case of C-contiguous matrix and a row
7541011
if (isEqual(simplified_rhs_strides, one_zero_strides) &&
7551012
isEqual(simplified_lhs_strides, {one, simplified_shape[0]}))
@@ -758,8 +1015,8 @@ std::pair<sycl::event, sycl::event>
7581015
contig_row_matrix_broadcast_dispatch_table[rhs_typeid]
7591016
[lhs_typeid];
7601017
if (row_matrix_broadcast_fn != nullptr) {
761-
size_t n0 = simplified_shape[1];
762-
size_t n1 = simplified_shape[0];
1018+
std::size_t n0 = simplified_shape[1];
1019+
std::size_t n1 = simplified_shape[0];
7631020
sycl::event comp_ev = row_matrix_broadcast_fn(
7641021
exec_q, host_tasks, n0, n1, rhs_data, rhs_offset,
7651022
lhs_data, lhs_offset, depends);
@@ -805,5 +1062,4 @@ std::pair<sycl::event, sycl::event>
8051062
dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, host_tasks),
8061063
strided_fn_ev);
8071064
}
808-
8091065
} // namespace dpnp::extensions::py_internal

0 commit comments

Comments
 (0)