 
 #pragma once
 
+#include <cstddef>
 #include <exception>
 #include <stdexcept>
+#include <utility>
+#include <vector>
 
 #include <sycl/sycl.hpp>
 
 
 // dpctl tensor headers
 #include "kernels/alignment.hpp"
-// #include "kernels/dpctl_tensor_types.hpp"
 #include "utils/memory_overlap.hpp"
 #include "utils/offset_utils.hpp"
 #include "utils/output_validation.hpp"
 #include "utils/sycl_alloc_utils.hpp"
 #include "utils/type_dispatch.hpp"
 
-namespace py = pybind11;
-namespace td_ns = dpctl::tensor::type_dispatch;
-
 static_assert(std::is_same_v<py::ssize_t, dpctl::tensor::ssize_t>);
 
 namespace dpnp::extensions::py_internal
 {
+namespace py = pybind11;
+namespace td_ns = dpctl::tensor::type_dispatch;
 
 using dpctl::tensor::kernels::alignment_utils::is_aligned;
 using dpctl::tensor::kernels::alignment_utils::required_alignment;
@@ -108,10 +109,10 @@ std::pair<sycl::event, sycl::event>
     const py::ssize_t *src_shape = src.get_shape_raw();
     const py::ssize_t *dst_shape = dst.get_shape_raw();
     bool shapes_equal(true);
-    size_t src_nelems(1);
+    std::size_t src_nelems(1);
 
     for (int i = 0; i < src_nd; ++i) {
-        src_nelems *= static_cast<size_t>(src_shape[i]);
+        src_nelems *= static_cast<std::size_t>(src_shape[i]);
         shapes_equal = shapes_equal && (src_shape[i] == dst_shape[i]);
     }
     if (!shapes_equal) {
@@ -277,6 +278,262 @@ py::object py_unary_ufunc_result_type(const py::dtype &input_dtype,
     }
 }
 
+/**
+ * @brief Template implementing Python API for a unary elementwise function
+ * with two output arrays.
+ */
+template <typename output_typesT,
+          typename contig_dispatchT,
+          typename strided_dispatchT>
+std::pair<sycl::event, sycl::event>
+    py_unary_two_outputs_ufunc(const dpctl::tensor::usm_ndarray &src,
+                               const dpctl::tensor::usm_ndarray &dst1,
+                               const dpctl::tensor::usm_ndarray &dst2,
+                               sycl::queue &q,
+                               const std::vector<sycl::event> &depends,
+                               //
+                               const output_typesT &output_type_vec,
+                               const contig_dispatchT &contig_dispatch_vector,
+                               const strided_dispatchT &strided_dispatch_vector)
+{
+    int src_typenum = src.get_typenum();
+    int dst1_typenum = dst1.get_typenum();
+    int dst2_typenum = dst2.get_typenum();
+
+    const auto &array_types = td_ns::usm_ndarray_types();
+    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
+    int dst1_typeid = array_types.typenum_to_lookup_id(dst1_typenum);
+    int dst2_typeid = array_types.typenum_to_lookup_id(dst2_typenum);
+
+    std::pair<int, int> func_output_typeids = output_type_vec[src_typeid];
+
+    // check that types are supported
+    if (dst1_typeid != func_output_typeids.first ||
+        dst2_typeid != func_output_typeids.second)
+    {
+        throw py::value_error(
+            "One of destination arrays has unexpected elemental data type.");
+    }
+
+    // check that queues are compatible
+    if (!dpctl::utils::queues_are_compatible(q, {src, dst1, dst2})) {
+        throw py::value_error(
+            "Execution queue is not compatible with allocation queues");
+    }
+
+    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst1);
+    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst2);
+
+    // check that dimensions are the same
+    int src_nd = src.get_ndim();
+    if (src_nd != dst1.get_ndim() || src_nd != dst2.get_ndim()) {
+        throw py::value_error("Array dimensions are not the same.");
+    }
+
+    // check that shapes are the same
+    const py::ssize_t *src_shape = src.get_shape_raw();
+    const py::ssize_t *dst1_shape = dst1.get_shape_raw();
+    const py::ssize_t *dst2_shape = dst2.get_shape_raw();
+    bool shapes_equal(true);
+    std::size_t src_nelems(1);
+
+    for (int i = 0; i < src_nd; ++i) {
+        src_nelems *= static_cast<std::size_t>(src_shape[i]);
+        shapes_equal = shapes_equal && (src_shape[i] == dst1_shape[i]) &&
+                       (src_shape[i] == dst2_shape[i]);
+    }
+    if (!shapes_equal) {
+        throw py::value_error("Array shapes are not the same.");
+    }
+
+    // if nelems is zero, return
+    if (src_nelems == 0) {
+        return std::make_pair(sycl::event(), sycl::event());
+    }
+
+    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst1,
+                                                               src_nelems);
+    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst2,
+                                                               src_nelems);
+
+    // check memory overlap
+    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
+    auto const &same_logical_tensors =
+        dpctl::tensor::overlap::SameLogicalTensors();
+    if ((overlap(src, dst1) && !same_logical_tensors(src, dst1)) ||
+        (overlap(src, dst2) && !same_logical_tensors(src, dst2)) ||
+        (overlap(dst1, dst2) && !same_logical_tensors(dst1, dst2)))
+    {
+        throw py::value_error("Arrays index overlapping segments of memory");
+    }
+
+    const char *src_data = src.get_data();
+    char *dst1_data = dst1.get_data();
+    char *dst2_data = dst2.get_data();
+
+    // handle contiguous inputs
+    bool is_src_c_contig = src.is_c_contiguous();
+    bool is_src_f_contig = src.is_f_contiguous();
+
+    bool is_dst1_c_contig = dst1.is_c_contiguous();
+    bool is_dst1_f_contig = dst1.is_f_contiguous();
+
+    bool is_dst2_c_contig = dst2.is_c_contiguous();
+    bool is_dst2_f_contig = dst2.is_f_contiguous();
+
+    bool all_c_contig =
+        (is_src_c_contig && is_dst1_c_contig && is_dst2_c_contig);
+    bool all_f_contig =
+        (is_src_f_contig && is_dst1_f_contig && is_dst2_f_contig);
+
+    if (all_c_contig || all_f_contig) {
+        auto contig_fn = contig_dispatch_vector[src_typeid];
+
+        if (contig_fn == nullptr) {
+            throw std::runtime_error(
+                "Contiguous implementation is missing for src_typeid=" +
+                std::to_string(src_typeid));
+        }
+
+        auto comp_ev =
+            contig_fn(q, src_nelems, src_data, dst1_data, dst2_data, depends);
+        sycl::event ht_ev =
+            dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, {comp_ev});
+
+        return std::make_pair(ht_ev, comp_ev);
+    }
+
+    // simplify iteration space
+    //     if 1d with strides 1 - input is contig
+    //     dispatch to strided
+
+    auto const &src_strides = src.get_strides_vector();
+    auto const &dst1_strides = dst1.get_strides_vector();
+    auto const &dst2_strides = dst2.get_strides_vector();
+
+    using shT = std::vector<py::ssize_t>;
+    shT simplified_shape;
+    shT simplified_src_strides;
+    shT simplified_dst1_strides;
+    shT simplified_dst2_strides;
+    py::ssize_t src_offset(0);
+    py::ssize_t dst1_offset(0);
+    py::ssize_t dst2_offset(0);
+
+    int nd = src_nd;
+    const py::ssize_t *shape = src_shape;
+
+    simplify_iteration_space_3(
+        nd, shape, src_strides, dst1_strides, dst2_strides,
+        // output
+        simplified_shape, simplified_src_strides, simplified_dst1_strides,
+        simplified_dst2_strides, src_offset, dst1_offset, dst2_offset);
+
+    if (nd == 1 && simplified_src_strides[0] == 1 &&
+        simplified_dst1_strides[0] == 1 && simplified_dst2_strides[0] == 1)
+    {
+        // Special case of contiguous data
+        auto contig_fn = contig_dispatch_vector[src_typeid];
+
+        if (contig_fn == nullptr) {
+            throw std::runtime_error(
+                "Contiguous implementation is missing for src_typeid=" +
+                std::to_string(src_typeid));
+        }
+
+        int src_elem_size = src.get_elemsize();
+        int dst1_elem_size = dst1.get_elemsize();
+        int dst2_elem_size = dst2.get_elemsize();
+        auto comp_ev =
+            contig_fn(q, src_nelems, src_data + src_elem_size * src_offset,
+                      dst1_data + dst1_elem_size * dst1_offset,
+                      dst2_data + dst2_elem_size * dst2_offset, depends);
+
+        sycl::event ht_ev =
+            dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, {comp_ev});
+
+        return std::make_pair(ht_ev, comp_ev);
+    }
+
+    // Strided implementation
+    auto strided_fn = strided_dispatch_vector[src_typeid];
+
+    if (strided_fn == nullptr) {
+        throw std::runtime_error(
+            "Strided implementation is missing for src_typeid=" +
+            std::to_string(src_typeid));
+    }
+
+    using dpctl::tensor::offset_utils::device_allocate_and_pack;
+
+    std::vector<sycl::event> host_tasks{};
+    host_tasks.reserve(2);
+
+    auto ptr_size_event_triple_ = device_allocate_and_pack<py::ssize_t>(
+        q, host_tasks, simplified_shape, simplified_src_strides,
+        simplified_dst1_strides, simplified_dst2_strides);
+    auto shape_strides_owner = std::move(std::get<0>(ptr_size_event_triple_));
+    const auto &copy_shape_ev = std::get<2>(ptr_size_event_triple_);
+    const py::ssize_t *shape_strides = shape_strides_owner.get();
+
+    sycl::event strided_fn_ev = strided_fn(
+        q, src_nelems, nd, shape_strides, src_data, src_offset, dst1_data,
+        dst1_offset, dst2_data, dst2_offset, depends, {copy_shape_ev});
+
+    // async free of shape_strides temporary
+    sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
+        q, {strided_fn_ev}, shape_strides_owner);
+
+    host_tasks.push_back(tmp_cleanup_ev);
+
+    return std::make_pair(
+        dpctl::utils::keep_args_alive(q, {src, dst1, dst2}, host_tasks),
+        strided_fn_ev);
+}
+
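// A minimal usage sketch (illustrative only, not part of the changeset
// above): one way a calling extension might wire py_unary_two_outputs_ufunc
// into its pybind11 module. The Python name "_frexp" and the tables
// frexp_output_typeid_table, frexp_contig_dispatch_vector and
// frexp_strided_dispatch_vector are hypothetical placeholders assumed to be
// defined and populated by that extension.
inline void init_frexp_example(py::module_ m)
{
    m.def(
        "_frexp",
        [](const dpctl::tensor::usm_ndarray &src,
           const dpctl::tensor::usm_ndarray &dst1,
           const dpctl::tensor::usm_ndarray &dst2, sycl::queue &exec_q,
           const std::vector<sycl::event> &depends) {
            // forward to the two-output dispatcher defined above
            return py_unary_two_outputs_ufunc(
                src, dst1, dst2, exec_q, depends, frexp_output_typeid_table,
                frexp_contig_dispatch_vector, frexp_strided_dispatch_vector);
        },
        py::arg("src"), py::arg("out1"), py::arg("out2"),
        py::arg("sycl_queue"), py::arg("depends") = py::list());
}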
+/**
+ * @brief Template implementing Python API for querying of type support by
+ * a unary elementwise function with two output arrays.
+ */
+template <typename output_typesT>
+std::pair<py::object, py::object>
+    py_unary_two_outputs_ufunc_result_type(const py::dtype &input_dtype,
+                                           const output_typesT &output_types)
+{
+    int tn = input_dtype.num(); // NumPy type numbers are the same as in dpctl
+    int src_typeid = -1;
+
+    auto array_types = td_ns::usm_ndarray_types();
+
+    try {
+        src_typeid = array_types.typenum_to_lookup_id(tn);
+    } catch (const std::exception &e) {
+        throw py::value_error(e.what());
+    }
+
+    using type_utils::_result_typeid;
+    std::pair<int, int> dst_typeids = _result_typeid(src_typeid, output_types);
+    int dst1_typeid = dst_typeids.first;
+    int dst2_typeid = dst_typeids.second;
+
+    if (dst1_typeid < 0 || dst2_typeid < 0) {
+        auto res = py::none();
+        auto py_res = py::cast<py::object>(res);
+        return std::make_pair(py_res, py_res);
+    }
+    else {
+        using type_utils::_dtype_from_typenum;
+
+        auto dst1_typenum_t = static_cast<td_ns::typenum_t>(dst1_typeid);
+        auto dst2_typenum_t = static_cast<td_ns::typenum_t>(dst2_typeid);
+        auto dt1 = _dtype_from_typenum(dst1_typenum_t);
+        auto dt2 = _dtype_from_typenum(dst2_typenum_t);
+
+        return std::make_pair(py::cast<py::object>(dt1),
+                              py::cast<py::object>(dt2));
+    }
+}
+
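// Likewise (illustrative only): the type-support query could be exposed next
// to the compute function; "_frexp_result_type" and frexp_output_typeid_table
// remain hypothetical placeholder names.
inline void init_frexp_result_type_example(py::module_ m)
{
    m.def(
        "_frexp_result_type",
        [](const py::dtype &dtype) {
            // returns a pair of output dtypes, or (None, None) when the
            // input dtype is not supported
            return py_unary_two_outputs_ufunc_result_type(
                dtype, frexp_output_typeid_table);
        },
        py::arg("dtype"));
}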
 // ======================== Binary functions ===========================
 
 namespace
@@ -347,10 +604,10 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
     const py::ssize_t *src2_shape = src2.get_shape_raw();
     const py::ssize_t *dst_shape = dst.get_shape_raw();
     bool shapes_equal(true);
-    size_t src_nelems(1);
+    std::size_t src_nelems(1);
 
     for (int i = 0; i < dst_nd; ++i) {
-        src_nelems *= static_cast<size_t>(src1_shape[i]);
+        src_nelems *= static_cast<std::size_t>(src1_shape[i]);
         shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] &&
                                         src2_shape[i] == dst_shape[i]);
     }
@@ -456,7 +713,7 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
             std::initializer_list<py::ssize_t>{0, 1};
         static constexpr auto one_zero_strides =
             std::initializer_list<py::ssize_t>{1, 0};
-        constexpr py::ssize_t one{1};
+        static constexpr py::ssize_t one{1};
         // special case of C-contiguous matrix and a row
         if (isEqual(simplified_src2_strides, zero_one_strides) &&
             isEqual(simplified_src1_strides, {simplified_shape[1], one}) &&
@@ -477,8 +734,8 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
                     is_aligned<required_alignment>(
                         dst_data + dst_offset * dst_itemsize))
                 {
-                    size_t n0 = simplified_shape[0];
-                    size_t n1 = simplified_shape[1];
+                    std::size_t n0 = simplified_shape[0];
+                    std::size_t n1 = simplified_shape[1];
                     sycl::event comp_ev = matrix_row_broadcast_fn(
                         exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                         src2_data, src2_offset, dst_data, dst_offset,
@@ -511,8 +768,8 @@ std::pair<sycl::event, sycl::event> py_binary_ufunc(
                     is_aligned<required_alignment>(
                         dst_data + dst_offset * dst_itemsize))
                 {
-                    size_t n0 = simplified_shape[1];
-                    size_t n1 = simplified_shape[0];
+                    std::size_t n0 = simplified_shape[1];
+                    std::size_t n1 = simplified_shape[0];
                     sycl::event comp_ev = row_matrix_broadcast_fn(
                         exec_q, host_tasks, n0, n1, src1_data, src1_offset,
                         src2_data, src2_offset, dst_data, dst_offset,
@@ -655,10 +912,10 @@ std::pair<sycl::event, sycl::event>
     const py::ssize_t *rhs_shape = rhs.get_shape_raw();
     const py::ssize_t *lhs_shape = lhs.get_shape_raw();
     bool shapes_equal(true);
-    size_t rhs_nelems(1);
+    std::size_t rhs_nelems(1);
 
     for (int i = 0; i < lhs_nd; ++i) {
-        rhs_nelems *= static_cast<size_t>(rhs_shape[i]);
+        rhs_nelems *= static_cast<std::size_t>(rhs_shape[i]);
         shapes_equal = shapes_equal && (rhs_shape[i] == lhs_shape[i]);
     }
     if (!shapes_equal) {
@@ -749,7 +1006,7 @@ std::pair<sycl::event, sycl::event>
     if (nd == 2) {
         static constexpr auto one_zero_strides =
             std::initializer_list<py::ssize_t>{1, 0};
-        constexpr py::ssize_t one{1};
+        static constexpr py::ssize_t one{1};
         // special case of C-contiguous matrix and a row
         if (isEqual(simplified_rhs_strides, one_zero_strides) &&
             isEqual(simplified_lhs_strides, {one, simplified_shape[0]}))
@@ -758,8 +1015,8 @@ std::pair<sycl::event, sycl::event>
                 contig_row_matrix_broadcast_dispatch_table[rhs_typeid]
                                                           [lhs_typeid];
             if (row_matrix_broadcast_fn != nullptr) {
-                size_t n0 = simplified_shape[1];
-                size_t n1 = simplified_shape[0];
+                std::size_t n0 = simplified_shape[1];
+                std::size_t n1 = simplified_shape[0];
                 sycl::event comp_ev = row_matrix_broadcast_fn(
                     exec_q, host_tasks, n0, n1, rhs_data, rhs_offset,
                     lhs_data, lhs_offset, depends);
@@ -805,5 +1062,4 @@ std::pair<sycl::event, sycl::event>
         dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, host_tasks),
         strided_fn_ev);
 }
-
 } // namespace dpnp::extensions::py_internal