Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dace-cartesian = [
'dace>=1.0.2' # refined in [tool.uv.sources]
]
dace-next = [
'dace==43!2026.04.27' # uses custom index at 'https://github.com/GridTools/pypi'
'dace==2.3.5' # uses custom index at 'https://github.com/GridTools/pypi'
]
dev = [
{include-group = 'build'},
Expand Down Expand Up @@ -486,7 +486,7 @@ url = 'https://gridtools.github.io/pypi/'
atlas4py = {index = "test.pypi"}
dace = [
{git = "https://github.com/GridTools/dace", branch = "romanc/stree-v2", group = "dace-cartesian"},
{index = "gridtools", group = "dace-next"}
{git = "https://github.com/philip-paul-mueller/dace", branch = "phimuell__new-gpu-codegen-dev", group = "dace-next"}
]

# -- versioningit --
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -709,12 +709,7 @@ def _visit_if_branch_arg(
inner_desc = arg_desc.clone()
inner_desc.transient = False
elif isinstance(arg.gt_dtype, ts.ScalarType):
if isinstance(arg, MemletExpr) and len(arg.gt_field.dims) == 1:
# TODO(edopao): we cannot use a scalar because of an issue in gpu codegen,
# which leads to compilation error: cannot convert 'const double' to 'const double*'
inner_desc = dace.data.Array(dtype=arg_desc.dtype, shape=(1,))
else:
inner_desc = dace.data.Scalar(arg_desc.dtype)
inner_desc = dace.data.Scalar(arg_desc.dtype)
else:
# for list of values, we retrieve the local size from the corresponding offset
local_dim = arg.gt_dtype.offset_type
Expand Down Expand Up @@ -837,16 +832,9 @@ def _visit_if_branch_result(
# If the result is currently written to a transient node, inside the nested SDFG,
# we need to allocate a non-transient data node.
result_desc = edge.result.dc_node.desc(sdfg)
if isinstance(sym.type, ts.ScalarType) and isinstance(result_desc, dace.data.Array):
# TODO(edopao): a scalar should not be represented as an array, but
# currently this can happen because of an issue workaround, see the
# todo comment above in `_visit_if_branch_arg()`.
assert len(result_desc.shape) == 1 and result_desc.shape[0] == 1
_, output_desc = sdfg.add_scalar(output_data, result_desc.dtype)
else:
output_desc = result_desc.clone()
output_desc.transient = False
output_data = sdfg.add_datadesc(output_data, output_desc, find_new_name=True)
output_desc = result_desc.clone()
output_desc.transient = False
output_data = sdfg.add_datadesc(output_data, output_desc, find_new_name=True)
output_node = state.add_access(output_data)
state.add_nedge(
edge.result.dc_node,
Expand Down
31 changes: 14 additions & 17 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading