From 166c9b7909c45b207fea3b9b8cab8468981dbc19 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Wed, 18 Mar 2026 14:00:36 +0000 Subject: [PATCH 01/22] Some starting to def use chains update --- .../psyir/tools/definition_use_chains.py | 107 +++++++++++------- ...tion_use_chains_forward_dependence_test.py | 28 +++++ 2 files changed, 96 insertions(+), 39 deletions(-) diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index d74f782103..06f31faad6 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -67,8 +67,8 @@ class DefinitionUseChain: """The DefinitionUseChain class is used to find nodes in a tree that have data dependencies on the provided reference. - :param reference: The Reference for which the dependencies will be - computed. + :param references: The References for which the dependencies will be + computed. :param control_flow_region: Optional region to search for data dependencies. Default is the parent Routine or the root of the tree's children if no ancestor @@ -79,25 +79,47 @@ class DefinitionUseChain: dependency search. :raises TypeError: If one of the arguments is the wrong type. - + :raises InternalError: If not all of the references have the same parent. """ def __init__( self, - reference: Reference, + references: list[Reference], control_flow_region: Iterable[Node] = (), start_point: Optional[int] = None, stop_point: Optional[int] = None, ): - if not isinstance(reference, Reference): + if not isinstance(references, list): raise TypeError( - f"The 'reference' argument passed into a DefinitionUseChain " - f"must be a Reference but found " - f"'{type(reference).__name__}'." + f"The 'references' argument passed into a DefinitionUseChain " + f"must be a list of References but found " + f"{type(references).__name__}'." ) - self._reference = reference - # Store the absolute position for later. - self._reference_abs_pos = reference.abs_position + for ref in references: + if not isinstance(ref, Reference): + raise TypeError( + f"The 'references' argument passed into a " + f"DefinitionUseChain must be a list of References " + f"but found '{type(ref).__name__}' in the list." + ) + # We need all the references to have the same parent. + parent = references[0].parent + for ref in references: + if ref.parent is not parent: + raise InternalError( + f"All references provided into a DefinitionUseChain " + f"must have the same parent." + ) + self._references = references + # Store the absolute positions and signatures for later. + self._reference_signatures = [] + self._references_abs_pos = {} + self._references[0].compute_cached_abs_positions() + for ref in references: + sig, _ = ref.get_signature_and_indices() + self._reference_signatures.append(sig) + self._references_abs_pos[sig] = ref.abs_position + # To enable loops to work correctly we can set the start/stop point # and not just use base it on the reference's absolute position if start_point and not isinstance(start_point, int): @@ -115,9 +137,9 @@ def __init__( self._start_point = start_point self._stop_point = stop_point if not control_flow_region: - self._scope = [reference.ancestor(Routine)] + self._scope = [references[0].ancestor(Routine)] if self._scope[0] is None: - self._scope = reference.root.children[:] + self._scope = references[0].root.children[:] else: # We need a list of regions for control flow. if not isinstance(control_flow_region, list): @@ -135,34 +157,34 @@ def __init__( self._scope = control_flow_region # The uses, defsout and killed sets as defined for each basic block. - self._uses = [] - self._defsout = [] - self._killed = [] + self._uses = {} + self._defsout = {} + self._killed = {} # The output map, mapping between nodes and the reach of that node. - self._reaches = [] + self._reaches = {} @property - def uses(self) -> list[Node]: + def uses(self) -> dict[list[Node]]: """ - :returns: the list of nodes using the value that the referenced symbol + :returns: the lists of nodes using the value that the referenced symbols has before it is reassigned. """ return self._uses @property - def defsout(self) -> list[Node]: + def defsout(self) -> dict[list[Node]]: """ - :returns: the list of nodes that reach the end of the block without + :returns: the lists of nodes that reach the end of the block without being killed, and therefore can have dependencies outside of this block. """ return self._defsout @property - def killed(self) -> list[Node]: + def killed(self) -> dict[list[Node]]: """ - :returns: the list of nodes that represent the last use of an assigned + :returns: the lists of nodes that represent the last use of an assigned variable. Calling next_access on any of these nodes will find a write that reassigns it's value. """ @@ -208,7 +230,10 @@ def find_forward_accesses(self) -> list[Node]: # If there is no set start point, then we look for all # accesses after the Reference. if self._start_point is None: - self._start_point = self._reference_abs_pos + # Find the highest abs position, as all of these are + # contained in the same parent. + self._start_point = max(self._reference_abs_pos, + key=self._reference_abs_pos.get) # If there is no set stop point, then any Reference after # the start point can potentially be a forward access. if self._stop_point is None: @@ -229,10 +254,10 @@ def find_forward_accesses(self) -> list[Node]: # called but thats hard to otherwise track. if ( isinstance(self._scope[0], Routine) - or self._scope[0] is self._reference.root + or self._scope[0] is self._references[0].root ): # Check if there is an ancestor Loop/WhileLoop. - ancestor = self._reference.ancestor((Loop, WhileLoop)) + ancestor = self._references[0].ancestor((Loop, WhileLoop)) while ancestor is not None: # Create a basic block for the ancestor Loop. body = ancestor.loop_body.children[:] @@ -240,7 +265,7 @@ def find_forward_accesses(self) -> list[Node]: # Find the stop point - this needs to be the node after # the ancestor statement. sub_stop_point = ( - self._reference.ancestor(Statement) + self._references[0].ancestor(Statement) .walk(Node)[-1] .abs_position + 1 @@ -253,7 +278,7 @@ def find_forward_accesses(self) -> list[Node]: # node to avoid handling the special cases based on # the parents of the reference. chain = DefinitionUseChain( - self._reference.copy(), + [ref.copy() for ref in self._references], body, start_point=ancestor.abs_position, stop_point=sub_stop_point, @@ -265,7 +290,7 @@ def find_forward_accesses(self) -> list[Node]: control_flow_nodes.insert(0, None) sub_stop_point = ancestor.loop_body.abs_position chain = DefinitionUseChain( - self._reference.copy(), + [ref.copy() for ref in self._references], [ancestor.condition], start_point=ancestor.abs_position, stop_point=sub_stop_point, @@ -274,10 +299,13 @@ def find_forward_accesses(self) -> list[Node]: ancestor = ancestor.ancestor((Loop, WhileLoop)) # Check if there is an ancestor Assignment. - ancestor = self._reference.ancestor(Assignment) + ancestor = self._references[0].ancestor(Assignment) if ancestor is not None: # If the reference is the lhs then we can ignore the RHS. - if ancestor.lhs is self._reference: + # This can only be the case if we only have a single + # reference input. + if (ancestor.lhs is self._references[0] + and len(self._references) == 1): # Find the last node in the assignment last_node = ancestor.walk(Node)[-1] # Modify the start_point to only include the node after @@ -287,7 +315,7 @@ def find_forward_accesses(self) -> list[Node]: # Add the lhs as a potential basic block with # different start and stop positions. chain = DefinitionUseChain( - self._reference, + [ref.copy() for ref in self._references], [ancestor.lhs], start_point=ancestor.lhs.abs_position - 1, stop_point=ancestor.lhs.abs_position + 1, @@ -306,7 +334,7 @@ def find_forward_accesses(self) -> list[Node]: if len(block) == 0: continue chain = DefinitionUseChain( - self._reference, + [ref.copy() for ref in self._references], block, start_point=self._start_point, stop_point=self._stop_point, @@ -320,12 +348,13 @@ def find_forward_accesses(self) -> list[Node]: if cfn is None: # We're outside a control flow region, updating the reaches # here is to find all the reached nodes. - for ref in chain._reaches: - # Add unique references to reaches. Since we're not - # in a control flow region, we can't have added - # these references into the reaches array yet so - # they're guaranteed to be unique. - self._reaches.append(ref) + for sig in chain._reaches: # TODO + for ref in chain._reaches[sig]: + # Add unique references to reaches. Since we're + # not in a control flow region, we can't have + # added these references into the reaches array + # yet so they're guaranteed to be unique. + self._reaches[sig].append(ref) # If we have a defsout in the chain then we can stop as we # will never get past the write as its not conditional. if len(chain.defsout) > 0: diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py index d4d1d6f3c8..e1952f55a9 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py @@ -1181,3 +1181,31 @@ def test_forward_accesses_nested_loop(fortran_reader): chains = DefinitionUseChain(lhs) reaches = chains.find_forward_accesses() assert len(reaches) == 1 + + + +def test_forward_accesses_multiple_elements(fortran_reader): + """Test that if we have multiple inputs we get multiple outputs as + expected.""" + code = """subroutine x + integer :: i,j,k,l,m + + i = j + k + j = l + m = k + k = i + end subroutine x""" + psyir = fortran_reader.psyir_from_source(code) + routine = psyir.walk(Routine)[0] + assigns = routine.walk(Assignment) + rhs = assigns[0].rhs + chains = DefinitionUseChains([rhs.children[0], rhs.children[1]], start_point = ???) + reaches = chains.find_forward_accesses() + sig0, _ = rhs.children[0].get_signature_and_indices() + sig1, _ = rhs.children[1].get_signature_and_indices() + + assert len(reaches[sig0]) == 1 + assert reaches[sig0][0] is assigns[1].lhs + assert len(reaches[sig1]) == 2 + assert reaches[sig1][0] is assigns[2].rhs + assert reaches[sig1][1] is assigns[3].lhs From 23ddd688492a03fc3a571633ac06769239ed4b54 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Wed, 18 Mar 2026 14:20:24 +0000 Subject: [PATCH 02/22] [skip-ci] more duc improvements --- .../psyir/tools/definition_use_chains.py | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index 06f31faad6..21a0940627 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -110,7 +110,8 @@ def __init__( f"All references provided into a DefinitionUseChain " f"must have the same parent." ) - self._references = references + # Make a copy of the list so we can modify it. + self._references = [ref for ref in references] # Store the absolute positions and signatures for later. self._reference_signatures = [] self._references_abs_pos = {} @@ -348,16 +349,28 @@ def find_forward_accesses(self) -> list[Node]: if cfn is None: # We're outside a control flow region, updating the reaches # here is to find all the reached nodes. - for sig in chain._reaches: # TODO + for sig in chain._reaches: for ref in chain._reaches[sig]: - # Add unique references to reaches. Since we're - # not in a control flow region, we can't have - # added these references into the reaches array - # yet so they're guaranteed to be unique. - self._reaches[sig].append(ref) - # If we have a defsout in the chain then we can stop as we - # will never get past the write as its not conditional. - if len(chain.defsout) > 0: + # Add unique references to reaches. Since we could + # have multiple references to the same symbol in + # the input they're not unique, so we need to + # check for uniqueness. As nodes can be == but + # not the same object, this has to be done + # using a loop and `is`. + for ref2 in self._reaches[sig] + if ref2 is ref: + break + else: + self._reaches[sig].append(ref) + # If we have a defsout in the chain then we can stop for + # that reference as we will never get past the write + # as its not conditional. + for i, sig in enumerate(self._reference_signatures): + if len(chains.defsout[sig]) > 0: + self._references.pop(i) + # If we have found an end point for all references then + # we can finish. + if len(self._references) == 0: # Reset the start and stop points before returning # the result. self._start_point = save_start_position From e0655412b0052e004c3c8dde10c1a8a834a6f4ef Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Wed, 18 Mar 2026 14:21:25 +0000 Subject: [PATCH 03/22] [skip-ci] linting --- src/psyclone/psyir/tools/definition_use_chains.py | 2 +- .../tools/definition_use_chains_forward_dependence_test.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index 21a0940627..5d7eb663ef 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -357,7 +357,7 @@ def find_forward_accesses(self) -> list[Node]: # check for uniqueness. As nodes can be == but # not the same object, this has to be done # using a loop and `is`. - for ref2 in self._reaches[sig] + for ref2 in self._reaches[sig]: if ref2 is ref: break else: diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py index e1952f55a9..30380e1105 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py @@ -1199,7 +1199,9 @@ def test_forward_accesses_multiple_elements(fortran_reader): routine = psyir.walk(Routine)[0] assigns = routine.walk(Assignment) rhs = assigns[0].rhs - chains = DefinitionUseChains([rhs.children[0], rhs.children[1]], start_point = ???) + chains = DefinitionUseChains([rhs.children[0], + rhs.children[1]], + start_point = 0) #FIXME reaches = chains.find_forward_accesses() sig0, _ = rhs.children[0].get_signature_and_indices() sig1, _ = rhs.children[1].get_signature_and_indices() From eddd698ffcbaed6687c3864d7eff6c2e0cbd7e63 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Tue, 7 Apr 2026 15:29:41 +0100 Subject: [PATCH 04/22] [skip-ci] More improvements to DUC to enable multi-reference usage --- .../psyir/tools/definition_use_chains.py | 110 +++++++++++------- 1 file changed, 69 insertions(+), 41 deletions(-) diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index 5d7eb663ef..cdb4a5383d 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -161,6 +161,10 @@ def __init__( self._uses = {} self._defsout = {} self._killed = {} + for sig in self.._reference_signatures: + self._uses[sig] = [] + self._defsout[sig] = [] + self._killed[sig] = [] # The output map, mapping between nodes and the reach of that node. self._reaches = {} @@ -368,6 +372,7 @@ def find_forward_accesses(self) -> list[Node]: for i, sig in enumerate(self._reference_signatures): if len(chains.defsout[sig]) > 0: self._references.pop(i) + self._reference_signatures.pop(i) # If we have found an end point for all references then # we can finish. if len(self._references) == 0: @@ -386,28 +391,38 @@ def find_forward_accesses(self) -> list[Node]: # if the variable is the same symbol as the _reference. if isinstance(cfn, Loop): cfn_abs_pos = cfn.abs_position - if ( - cfn.variable == self._reference.symbol - and cfn_abs_pos >= self._start_point - and cfn_abs_pos < self._stop_point - ): - # The loop variable is always written to and so - # we're done if its reached. - self._reaches.append(cfn) - self._start_point = save_start_position - self._stop_point = save_stop_position - return self._reaches + for i, ref in self._references: + if ( + cfn.variable == ref.symbol + and cfn_abs_pos >= self._start_point + and cfn_abs_pos < self._stop_point + ): + # The loop variable is always written to. + sig = self._reference_signatures[i] + self._reaches[sig].append(cfn) + # This reference is killed by this access. + self._references.pop(i) + self._reference_signatures.pop(i) + # If we have found an end point for all + # references then we can finish. + if len(self._references) == 0: + self._start_point = save_start_position + self._stop_point = save_stop_position + return self._reaches - for ref in chain._reaches: - # Add node to the _reaches list if it is not already - # contained. Note that a "not in" check is not - # sufficient as it checks for equality, but this needs - # identity uniqueness. - for ref2 in self._reaches: - if ref2 is ref: - break - else: - self._reaches.append(ref) + for sig in chain._reaches: + for ref in chain._reaches[sig]: + # Add unique references to reaches. Since we could + # have multiple references to the same symbol in + # the input they're not unique, so we need to + # check for uniqueness. As nodes can be == but + # not the same object, this has to be done + # using a loop and `is`. + for ref2 in self._reaches[sig]: + if ref2 is ref: + break + else: + self._reaches[sig].append(ref) else: # Check if there is an ancestor Assignment. ancestor = self._reference.ancestor(Assignment) @@ -438,17 +453,25 @@ def find_forward_accesses(self) -> list[Node]: ) # Find any forward_accesses in the lhs. chain.find_forward_accesses() - for ref in chain._reaches: - self._reaches.append(ref) + for sig in chain._reaches: + for ref in chain._reaches[sig]: + self._reaches[sig].append(ref) # If we have a defsout in the chain then we can stop as we # will never get past the write as its not conditional. - if len(chain.defsout) > 0: + for i, sig in enumerate(self._reference_signatures): + if len(chains.defsout[sig]) > 0: + self._references.pop(i) + self._reference_signatures.pop(i) + # If we have found an end point for all references then + # we can finish. + if len(self._references) == 0: # Reset the start and stop points before returning # the result. self._start_point = save_start_position self._stop_point = save_stop_position return self._reaches # We can compute the rest of the accesses + # FIXME Here onwards won't work yet. self._compute_forward_uses(self._scope) for ref in self._uses: self._reaches.append(ref) @@ -485,9 +508,10 @@ def _compute_forward_uses(self, basic_block_list: list[Node]): :raises NotImplementedError: If a GOTO statement is found in the code region. """ - sig, _ = self._reference.get_signature_and_indices() - # For a basic block we will only ever have one defsout - defs_out = None + # For a basic block we will only ever have one defsout per reference. + defs_out = {} + for sig in self._reference_signatures: + defs_out[sig] = None for region in basic_block_list: for reference in region.walk((Reference, Call, CodeBlock, Return)): # Store the position instead of computing it twice. @@ -497,8 +521,9 @@ def _compute_forward_uses(self, basic_block_list: list[Node]): if isinstance(reference, Return): # When we find a return statement any following statements # can be ignored so we can return. - if defs_out is not None: - self._defsout.append(defs_out) + for sig in self._reference_signatures: + if defs_out[sig] is not None: + self._defsout[sig].append(defs_out) return # If its parent is an inquiry function then its neither # a read nor write if its the first argument. @@ -520,19 +545,22 @@ def _compute_forward_uses(self, basic_block_list: list[Node]): if isinstance( reference._fp2_nodes[0], (Exit_Stmt, Cycle_Stmt) ): - if defs_out is not None: - self._defsout.append(defs_out) + for sig in self._reference_signatures: + if defs_out[sig] is not None: + self._defsout[sig].append(defs_out) return - if ( - self._reference.symbol.name - in reference.get_symbol_names() - ): - # Assume the worst for a CodeBlock and we count them - # as killed and defsout and uses. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference - continue + for i, ref in self._references: + if ( + ref.symbol.name + in reference.get_symbol_names() + ): + # Assume the worst for a CodeBlock and we count + # them as killed and defsout and uses. + sig = self._reference_signatures[i] + if defs_out[sig] is not None: + self._killed[sig].append(defs_out[sig]) + defs_out[sig] = reference + # FIXME from here elif isinstance(reference, Call): # If its a local variable we can ignore it as we'll catch # the Reference later if its passed into the Call. From fdcd6ad252777c9ea57c282c1d4c9c1f21406390 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Wed, 8 Apr 2026 15:37:25 +0100 Subject: [PATCH 05/22] forward accesses first draft completed --- .../psyir/tools/definition_use_chains.py | 98 ++++++++++--------- 1 file changed, 50 insertions(+), 48 deletions(-) diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index cdb4a5383d..76dd3d1da1 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -471,20 +471,20 @@ def find_forward_accesses(self) -> list[Node]: self._stop_point = save_stop_position return self._reaches # We can compute the rest of the accesses - # FIXME Here onwards won't work yet. self._compute_forward_uses(self._scope) - for ref in self._uses: - self._reaches.append(ref) - # If this block doesn't kill any accesses, then we add - # the defsout into the reaches array. - if len(self.killed) == 0: - for ref in self._defsout: - self._reaches.append(ref) - else: - # If this block killed any accesses, then the first element - # of the killed writes is the access access that we're - # dependent with. - self._reaches.append(self.killed[0]) + for sig in self._reference_signatures: + for ref in self._uses[sig]: + self._reaches[sig].append(ref) + # If this block doesn't kill any accesses, then we add + # the defsout into the reaches array. + if len(self.killed[sig]) == 0: + for ref in self._defsout[sig]: + self._reaches[sig].append(ref) + else: + # If this block killed any accesses, then the first element + # of the killed writes is the access access that we're + # dependent with. + self._reaches[sig].append(self.killed[sig][0]) # Reset the start and stop points before returning the result. self._start_point = save_start_position @@ -560,30 +560,31 @@ def _compute_forward_uses(self, basic_block_list: list[Node]): if defs_out[sig] is not None: self._killed[sig].append(defs_out[sig]) defs_out[sig] = reference - # FIXME from here elif isinstance(reference, Call): # If its a local variable we can ignore it as we'll catch # the Reference later if its passed into the Call. - if self._reference.symbol.is_automatic: - continue - if isinstance(reference, IntrinsicCall): - # IntrinsicCall can only do stuff to arguments, these - # will be caught by Reference walk already. - # Note that this assumes two symbols are not - # aliases of each other. - continue - if reference.is_pure: - # Pure subroutines only touch their arguments, so we'll - # catch the arguments that are passed into the call - # later as References. - continue - # For now just assume calls are bad if we have a non-local - # variable and we treat them as though they were a write. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference - continue - elif reference.get_signature_and_indices()[0] == sig: + for i, ref in self._references: + if ref.symbol.is_automatic: + continue + if isinstance(reference, IntrinsicCall): + # IntrinsicCall can only do stuff to arguments, these + # will be caught by Reference walk already. + # Note that this assumes two symbols are not + # aliases of each other. + continue + if reference.is_pure: + # Pure subroutines only touch their arguments, so we'll + # catch the arguments that are passed into the call + # later as References. + continue + # For now just assume calls are bad if we have a non-local + # variable and we treat them as though they were a write. + sig = self._reference_signatures[i] + if defs_out[sig] is not None: + self._killed[sig].append(defs_out[sig]) + defs_out[sig] = reference + elif reference.get_signature_and_indices()[0] in self._reference_signatures: + sig = reference.get_signature_and_indices()[0] # Work out if its read only or not. assign = reference.ancestor(Assignment) if assign is not None: @@ -591,12 +592,12 @@ def _compute_forward_uses(self, basic_block_list: list[Node]): # This is a write to the reference, so kill the # previous defs_out and set this to be the # defs_out. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference + if defs_out[sig] is not None: + self._killed[sig].append(defs_out[sig]) + defs_out[sig] = reference elif ( assign.lhs is defs_out - and len(self._killed) == 0 + and len(self._killed[sig]) == 0 and assign.lhs.get_signature_and_indices()[0] == sig and assign.lhs is not self._reference @@ -605,28 +606,29 @@ def _compute_forward_uses(self, basic_block_list: list[Node]): # a = a + 1. Since the PSyIR tree walk accesses # the lhs of an assignment before the rhs of an # assignment we need to not ignore these accesses. - self._uses.append(reference) + self._uses[sig].append(reference) else: # Read only, so if we've not yet set written to # this variable this is a use. NB. We need to # check the if the write is the LHS of the parent # assignment and if so check if we killed any # previous assignments. - if defs_out is None: - self._uses.append(reference) + if defs_out[sig] is None: + self._uses[sig].append(reference) elif reference.ancestor(Call): # Otherwise we assume read/write access for now. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference + if defs_out[sig] is not None: + self._killed[sig].append(defs_out[sig]) + defs_out[sig] = reference else: # Reference outside an Assignment - read only # This could be References inside a While loop # condition for example. - if defs_out is None: - self._uses.append(reference) - if defs_out is not None: - self._defsout.append(defs_out) + if defs_out[sig] is None: + self._uses[sig].append(reference) + for sig in self._reference_signatures: + if defs_out[sig] is not None: + self._defsout[sig].append(defs_out[sig]) def _find_basic_blocks( self, nodelist: list[Node] From 010811a1c90dc8cb0b81c06bb1355336a8f4cd07 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Thu, 9 Apr 2026 14:36:09 +0100 Subject: [PATCH 06/22] Allow DUCs to handle multiple input References --- src/psyclone/psyir/nodes/reference.py | 10 +- .../psyir/tools/definition_use_chains.py | 386 ++++++++++-------- .../psyir/transformations/hoist_trans.py | 4 +- ...ion_use_chains_backward_dependence_test.py | 203 +++++---- ...tion_use_chains_forward_dependence_test.py | 299 ++++++++------ 5 files changed, 519 insertions(+), 383 deletions(-) diff --git a/src/psyclone/psyir/nodes/reference.py b/src/psyclone/psyir/nodes/reference.py index 6d70fdb560..006911655c 100644 --- a/src/psyclone/psyir/nodes/reference.py +++ b/src/psyclone/psyir/nodes/reference.py @@ -245,8 +245,9 @@ def previous_accesses(self): # Avoid circular import # pylint: disable=import-outside-toplevel from psyclone.psyir.tools import DefinitionUseChain - chain = DefinitionUseChain(self) - return chain.find_backward_accesses() + chain = DefinitionUseChain([self]) + sig = self.get_signature_and_indices()[0] + return chain.find_backward_accesses()[sig] def next_accesses(self): ''' @@ -258,8 +259,9 @@ def next_accesses(self): # Avoid circular import # pylint: disable=import-outside-toplevel from psyclone.psyir.tools import DefinitionUseChain - chain = DefinitionUseChain(self) - return chain.find_forward_accesses() + chain = DefinitionUseChain([self]) + sig = self.get_signature_and_indices()[0] + return chain.find_forward_accesses()[sig] def escapes_scope( self, scope: Node, visited_nodes: Optional[set] = None diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index 76dd3d1da1..ee56b3c95a 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -44,6 +44,7 @@ Goto_Stmt, ) +from psyclone.errors import InternalError from psyclone.psyir.nodes import ( Assignment, Call, @@ -93,7 +94,7 @@ def __init__( raise TypeError( f"The 'references' argument passed into a DefinitionUseChain " f"must be a list of References but found " - f"{type(references).__name__}'." + f"'{type(references).__name__}'." ) for ref in references: if not isinstance(ref, Reference): @@ -103,12 +104,13 @@ def __init__( f"but found '{type(ref).__name__}' in the list." ) # We need all the references to have the same parent. + # FIXME Same parent or same ancestor statement? Latter more useful. parent = references[0].parent for ref in references: if ref.parent is not parent: raise InternalError( - f"All references provided into a DefinitionUseChain " - f"must have the same parent." + "All references provided into a DefinitionUseChain " + "must have the same parent." ) # Make a copy of the list so we can modify it. self._references = [ref for ref in references] @@ -161,19 +163,21 @@ def __init__( self._uses = {} self._defsout = {} self._killed = {} - for sig in self.._reference_signatures: - self._uses[sig] = [] - self._defsout[sig] = [] - self._killed[sig] = [] # The output map, mapping between nodes and the reach of that node. self._reaches = {} + # Initialise the maps. + for sig in self._reference_signatures: + self._uses[sig] = [] + self._defsout[sig] = [] + self._killed[sig] = [] + self._reaches[sig] = [] @property def uses(self) -> dict[list[Node]]: """ - :returns: the lists of nodes using the value that the referenced symbols - has before it is reassigned. + :returns: the lists of nodes using the value that the referenced + symbols has before it is reassigned. """ return self._uses @@ -227,7 +231,7 @@ def find_forward_accesses(self) -> list[Node]: # Compute the abs position caches as we'll use these a lot. # The compute_cached_abs_position will only do this if needed # so we don't need to check here. - self._reference.compute_cached_abs_positions() + self._references[0].compute_cached_abs_positions() # Setup the start and stop positions save_start_position = self._start_point @@ -237,8 +241,7 @@ def find_forward_accesses(self) -> list[Node]: if self._start_point is None: # Find the highest abs position, as all of these are # contained in the same parent. - self._start_point = max(self._reference_abs_pos, - key=self._reference_abs_pos.get) + self._start_point = max(list(self._references_abs_pos.values())) # If there is no set stop point, then any Reference after # the start point can potentially be a forward access. if self._stop_point is None: @@ -361,16 +364,16 @@ def find_forward_accesses(self) -> list[Node]: # check for uniqueness. As nodes can be == but # not the same object, this has to be done # using a loop and `is`. - for ref2 in self._reaches[sig]: - if ref2 is ref: - break - else: - self._reaches[sig].append(ref) + for ref2 in self._reaches[sig]: + if ref2 is ref: + break + else: + self._reaches[sig].append(ref) # If we have a defsout in the chain then we can stop for # that reference as we will never get past the write # as its not conditional. for i, sig in enumerate(self._reference_signatures): - if len(chains.defsout[sig]) > 0: + if len(chain.defsout[sig]) > 0: self._references.pop(i) self._reference_signatures.pop(i) # If we have found an end point for all references then @@ -391,7 +394,7 @@ def find_forward_accesses(self) -> list[Node]: # if the variable is the same symbol as the _reference. if isinstance(cfn, Loop): cfn_abs_pos = cfn.abs_position - for i, ref in self._references: + for i, ref in enumerate(self._references[:]): if ( cfn.variable == ref.symbol and cfn_abs_pos >= self._start_point @@ -418,21 +421,21 @@ def find_forward_accesses(self) -> list[Node]: # check for uniqueness. As nodes can be == but # not the same object, this has to be done # using a loop and `is`. - for ref2 in self._reaches[sig]: - if ref2 is ref: - break - else: - self._reaches[sig].append(ref) + for ref2 in self._reaches[sig]: + if ref2 is ref: + break + else: + self._reaches[sig].append(ref) else: # Check if there is an ancestor Assignment. - ancestor = self._reference.ancestor(Assignment) + ancestor = self._references[0].ancestor(Assignment) if ancestor is not None: - # If we get here to check the start part of a loop we need - # to handle this differently. - if self._start_point != self._reference_abs_pos: - pass - # If the reference is the lhs then we can ignore the RHS. - if ancestor.lhs is self._reference: + # If any of the references is the lhs then we can ignore the + # RHS. + # FIXME original if statement is + # "ancestor.lhs is self._reference" + # FIXME Is that true? + if any([ancestor.lhs is ref for ref in self._references]): # Find the last node in the assignment last_node = ancestor.walk(Node)[-1] # Modify the start_point to only include the node after @@ -446,7 +449,7 @@ def find_forward_accesses(self) -> list[Node]: # Add the lhs as a potential basic block with different # start and stop positions. chain = DefinitionUseChain( - self._reference, + [ref for ref in self._references], [ancestor.lhs], start_point=ancestor.lhs.abs_position - 1, stop_point=ancestor.lhs.abs_position + 1, @@ -459,7 +462,7 @@ def find_forward_accesses(self) -> list[Node]: # If we have a defsout in the chain then we can stop as we # will never get past the write as its not conditional. for i, sig in enumerate(self._reference_signatures): - if len(chains.defsout[sig]) > 0: + if len(chain.defsout[sig]) > 0: self._references.pop(i) self._reference_signatures.pop(i) # If we have found an end point for all references then @@ -521,9 +524,9 @@ def _compute_forward_uses(self, basic_block_list: list[Node]): if isinstance(reference, Return): # When we find a return statement any following statements # can be ignored so we can return. - for sig in self._reference_signatures: + for sig in self._reference_signatures: if defs_out[sig] is not None: - self._defsout[sig].append(defs_out) + self._defsout[sig].append(defs_out[sig]) return # If its parent is an inquiry function then its neither # a read nor write if its the first argument. @@ -545,11 +548,11 @@ def _compute_forward_uses(self, basic_block_list: list[Node]): if isinstance( reference._fp2_nodes[0], (Exit_Stmt, Cycle_Stmt) ): - for sig in self._reference_signatures: + for sig in self._reference_signatures: if defs_out[sig] is not None: - self._defsout[sig].append(defs_out) + self._defsout[sig].append(defs_out[sig]) return - for i, ref in self._references: + for i, ref in enumerate(self._references[:]): if ( ref.symbol.name in reference.get_symbol_names() @@ -563,27 +566,29 @@ def _compute_forward_uses(self, basic_block_list: list[Node]): elif isinstance(reference, Call): # If its a local variable we can ignore it as we'll catch # the Reference later if its passed into the Call. - for i, ref in self._references: + for i, ref in enumerate(self._references[:]): if ref.symbol.is_automatic: continue if isinstance(reference, IntrinsicCall): - # IntrinsicCall can only do stuff to arguments, these - # will be caught by Reference walk already. + # IntrinsicCall can only do stuff to arguments, + # these will be caught by Reference walk already. # Note that this assumes two symbols are not # aliases of each other. continue if reference.is_pure: - # Pure subroutines only touch their arguments, so we'll - # catch the arguments that are passed into the call - # later as References. + # Pure subroutines only touch their arguments, so + # we'll catch the arguments that are passed into + # the call later as References. continue - # For now just assume calls are bad if we have a non-local - # variable and we treat them as though they were a write. + # For now just assume calls are bad if we have a + # non-local variable and we treat them as though they + # were a write. sig = self._reference_signatures[i] if defs_out[sig] is not None: self._killed[sig].append(defs_out[sig]) defs_out[sig] = reference - elif reference.get_signature_and_indices()[0] in self._reference_signatures: + elif (reference.get_signature_and_indices()[0] in + self._reference_signatures): sig = reference.get_signature_and_indices()[0] # Work out if its read only or not. assign = reference.ancestor(Assignment) @@ -596,11 +601,14 @@ def _compute_forward_uses(self, basic_block_list: list[Node]): self._killed[sig].append(defs_out[sig]) defs_out[sig] = reference elif ( - assign.lhs is defs_out + assign.lhs is defs_out[sig] and len(self._killed[sig]) == 0 and assign.lhs.get_signature_and_indices()[0] == sig - and assign.lhs is not self._reference + and any( + [assign.lhs is not ref for + ref in self._references] + ) ): # reference is on the rhs of an assignment such as # a = a + 1. Since the PSyIR tree walk accesses @@ -721,9 +729,12 @@ def _find_basic_blocks( if node.else_body: refs = node.else_body.walk(Reference) for ref in refs: - if ref is self._reference: - # If its in the else_body we don't add the if_body - in_else_body = True + # If its in the else_body we don't add the if_body + for ref2 in self._references: + if ref is ref2: + in_else_body = True + break + if in_else_body: break if not in_else_body: control_flow_nodes.append(node) @@ -732,8 +743,11 @@ def _find_basic_blocks( in_if_body = False refs = node.if_body.walk(Reference) for ref in refs: - if ref is self._reference: - in_if_body = True + for ref2 in self._references: + if ref is ref2: + in_if_body = True + break + if in_if_body: break if node.else_body and not in_if_body: control_flow_nodes.append(node) @@ -787,9 +801,10 @@ def _compute_backward_uses(self, basic_block_list: list[Node]): :raises NotImplementedError: If a GOTO statement is found in the code region. """ - sig, _ = self._reference.get_signature_and_indices() - # For a basic block we will only ever have one defsout - defs_out = None + # For a basic block we will only ever have one defsout per reference. + defs_out = {} + for sig in self._reference_signatures: + defs_out[sig] = None # Working backwards so reverse the basic_block_list basic_block_list.reverse() stop_position = self._stop_point @@ -836,43 +851,48 @@ def _compute_backward_uses(self, basic_block_list: list[Node]): "DefinitionUseChains can't handle code containing" " GOTO statements." ) - if ( - self._reference.symbol.name - in reference.get_symbol_names() - ): - # Assume the worst for a CodeBlock and we count them - # as killed and defsout and uses. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference - continue + for i, ref in enumerate(self._references[:]): + if ( + ref.symbol.name + in reference.get_symbol_names() + ): + # Assume the worst for a CodeBlock and we count + # them as killed and defsout and uses. + sig = self._reference_signatures[i] + if defs_out[sig] is not None: + self._killed[sig].append(defs_out[sig]) + defs_out[sig] = reference elif isinstance(reference, Call): # If its a local variable we can ignore it as we'll catch # the Reference later if its passed into the Call. - if self._reference.symbol.is_automatic: - continue - # If the call is an ancestor of the Reference then - # we skip it for backwards accesses. - if self._reference.is_descendant_of(reference): - continue - if isinstance(reference, IntrinsicCall): - # IntrinsicCall can only do stuff to arguments, these - # will be caught by Reference walk already. - # Note that this assumes two symbols are not - # aliases of each other. - continue - if reference.is_pure: - # Pure subroutines only touch their arguments, so we'll - # catch the arguments that are passed into the call - # later as References. - continue - # For now just assume calls are bad if we have a non-local - # variable and we treat them as though they were a write. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference - continue - elif reference.get_signature_and_indices()[0] == sig: + for i, ref in enumerate(self._references): + if ref.symbol.is_automatic: + continue + # If the call is an ancestor of the Reference then + # we skip it for backwards accesses. + if ref.is_descendant_of(reference): + continue + if isinstance(reference, IntrinsicCall): + # IntrinsicCall can only do stuff to arguments, + # these will be caught by Reference walk already. + # Note that this assumes two symbols are not + # aliases of each other. + continue + if reference.is_pure: + # Pure subroutines only touch their arguments, so + # we'll catch the arguments that are passed into + # the call later as References. + continue + # For now just assume calls are bad if we have a + # non-local variable and we treat them as though + # they were a write. + sig = self._reference_signatures[i] + if defs_out[sig] is not None: + self._killed[sig].append(defs_out[sig]) + defs_out[sig] = reference + elif (reference.get_signature_and_indices()[0] in + self._reference_signatures): + sig = reference.get_signature_and_indices()[0] # Work out if its read only or not. assign = reference.ancestor(Assignment) # RHS reads occur "before" LHS writes, so if we @@ -880,13 +900,15 @@ def _compute_backward_uses(self, basic_block_list: list[Node]): # a dependency to the value used from the LHS. if assign is not None: if assign.lhs is reference: - # Check if the RHS contains the self._reference. + # Check if the RHS contains the self._references. # Can't use in since equality is not what we want # here. found = False for ref in assign.rhs.walk(Reference): if ( - ref is self._reference + any([ref is ref2 for + ref2 in self._references]) + # FIXME What does this and check? and self._stop_point == ref.abs_position ): found = True @@ -897,12 +919,15 @@ def _compute_backward_uses(self, basic_block_list: list[Node]): # This is a write to the reference, so kill the # previous defs_out and set this to be the # defs_out. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference + if defs_out[sig] is not None: + self._killed[sig].append(defs_out[sig]) + defs_out[sig] = reference elif ( assign.lhs.get_signature_and_indices()[0] == sig - and assign.lhs is not self._reference + and any( + [assign.lhs is not ref for + ref in self._references] + ) ): # Reference is on the rhs of an assignment such as # a = a + 1. Since we're looping through the tree @@ -918,21 +943,22 @@ def _compute_backward_uses(self, basic_block_list: list[Node]): # check the if the write is the LHS of the parent # assignment and if so check if we killed any # previous assignments. - if defs_out is None: - self._uses.append(reference) + if defs_out[sig] is None: + self._uses[sig].append(reference) elif reference.ancestor(Call): # Otherwise we assume read/write access for now. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference + if defs_out[sig] is not None: + self._killed[sig].append(defs_out[sig]) + defs_out[sig] = reference else: # Reference outside an Assignment - read only # This could be References inside a While loop # condition for example. - if defs_out is None: - self._uses.append(reference) - if defs_out is not None: - self._defsout.append(defs_out) + if defs_out[sig] is None: + self._uses[sig].append(reference) + for sig in self._reference_signatures: + if defs_out[sig] is not None: + self._defsout[sig].append(defs_out[sig]) def find_backward_accesses(self) -> list[Node]: """ @@ -951,7 +977,7 @@ def find_backward_accesses(self) -> list[Node]: # Compute the abs position caches as we'll use these a lot. # The compute_cached_abs_position will only do this if needed # so we don't need to check here. - self._reference.compute_cached_abs_positions() + self._references[0].compute_cached_abs_positions() # Setup the start and stop positions save_start_position = self._start_point @@ -959,7 +985,9 @@ def find_backward_accesses(self) -> list[Node]: # If there is no set start point, then we look for all # accesses after the Reference. if self._stop_point is None: - self._stop_point = self._reference_abs_pos + # Find the max abs position, as all of these are + # contained in the same parent. + self._stop_point = max(list(self._references_abs_pos.values())) # If there is no set stop point, then any Reference after # the start point can potentially be a forward access. if self._start_point is None: @@ -979,8 +1007,10 @@ def find_backward_accesses(self) -> list[Node]: # statement. if len(block) == 0: continue + # Create a copy of the list as it can modify elements + # in the list. chain = DefinitionUseChain( - self._reference, + [ref for ref in self._references], block, start_point=self._start_point, stop_point=self._stop_point, @@ -995,10 +1025,10 @@ def find_backward_accesses(self) -> list[Node]: # called but thats hard to otherwise track. if ( isinstance(self._scope[0], Routine) - or self._scope[0] is self._reference.root + or self._scope[0] is self._references[0].root ): # Check if there is an ancestor Loop/WhileLoop. - ancestor = self._reference.ancestor((Loop, WhileLoop)) + ancestor = self._references[0].ancestor((Loop, WhileLoop)) while ancestor is not None: # Create a basic block for the ancestor Loop. body = ancestor.loop_body.children[:] @@ -1008,18 +1038,20 @@ def find_backward_accesses(self) -> list[Node]: # We make a copy of the reference to have a detached # node to avoid handling the special cases based on # the parents of the reference. - if self._reference.ancestor(Assignment) is not None: - sub_start_point = self._reference.ancestor( + if self._references[0].ancestor(Assignment) is not None: + sub_start_point = self._references[0].ancestor( Assignment ).abs_position else: - sub_start_point = self._reference.abs_position + sub_start_point = max(list( + self._references_abs_pos.values() + )) # If we have a basic block with no children then skip it, # e.g. for an if block with no code before the else # statement. if len(body) > 0: chain = DefinitionUseChain( - self._reference.copy(), + [ref.copy() for ref in self._references], body, start_point=sub_start_point, stop_point=sub_stop_point, @@ -1032,7 +1064,7 @@ def find_backward_accesses(self) -> list[Node]: control_flow_nodes.append(None) sub_stop_point = ancestor.loop_body.abs_position chain = DefinitionUseChain( - self._reference.copy(), + [ref.copy() for ref in self._references], [ancestor.condition], start_point=ancestor.abs_position, stop_point=sub_stop_point, @@ -1041,16 +1073,16 @@ def find_backward_accesses(self) -> list[Node]: ancestor = ancestor.ancestor((Loop, WhileLoop)) # Check if there is an ancestor Assignment. - ancestor = self._reference.ancestor(Assignment) + ancestor = self._references[0].ancestor(Assignment) if ancestor is not None: # If the reference is not the lhs then we can ignore # the RHS. - if ancestor.lhs is self._reference: + if any([ancestor.lhs is ref for ref in self._references]): end = ancestor.walk(Node)[-1] # Add the rhs as a potential basic block with # different start and stop positions. chain = DefinitionUseChain( - self._reference.copy(), + [ref.copy() for ref in self._references], ancestor.rhs.children[:], start_point=ancestor.rhs.abs_position, stop_point=end.abs_position, @@ -1072,21 +1104,28 @@ def find_backward_accesses(self) -> list[Node]: if cfn is None: # We're outside a control flow region, updating the reaches # here is to find all the reached nodes. - for ref in chain._reaches: - # Add unique references to reaches. Since we're not - # in a control flow region, we can't have added - # these references into the reaches array yet so - # they're guaranteed to be unique. - found = False - for ref2 in self._reaches: - if ref is ref2: - found = True - break - if not found: - self._reaches.append(ref) + for sig in chain._reaches: + for ref in chain._reaches[sig]: + # Add unique references to reaches. Since we're not + # in a control flow region, we can't have added + # these references into the reaches array yet so + # they're guaranteed to be unique. + found = False + for ref2 in self._reaches[sig]: + if ref is ref2: + found = True + break + if not found: + self._reaches[sig].append(ref) # If we have a defsout in the chain then we can stop as we # will never get past the write as its not conditional. - if len(chain.defsout) > 0: + for i, sig in enumerate(self._reference_signatures): + if len(chain.defsout[sig]) > 0: + self._references.pop(i) + self._reference_signatures.pop(i) + # If we have found an end point for all references then + # we can stop. + if len(self._references) == 0: # Reset the start and stop points before returning # the result. self._start_point = save_start_position @@ -1102,34 +1141,41 @@ def find_backward_accesses(self) -> list[Node]: # if the variable is the same symbol as the _reference. if isinstance(cfn, Loop): cfn_abs_pos = cfn.abs_position - if ( - cfn.variable == self._reference.symbol - and cfn_abs_pos >= self._start_point - and cfn_abs_pos < self._stop_point - ): - # The loop variable is always written to and so - # we're done if its reached. - self._reaches.append(cfn) - self._start_point = save_start_position - self._stop_point = save_stop_position - return self._reaches - - for ref in chain._reaches: - found = False - for ref2 in self._reaches: - if ref is ref2: - found = True - break - if not found: - self._reaches.append(ref) + for i, ref in enumerate(self._references[:]): + if ( + cfn.variable == ref.symbol + and cfn_abs_pos >= self._start_point + and cfn_abs_pos < self._stop_point + ): + # The loop variable is always written + sig = self._reference_signatures[i] + self._reaches[sig].append(cfn) + # This reference is killed by this access. + self._references.pop(i) + self._reference_signatures.pop(i) + # If we have found an end point for all + # references then we can finish. + if len(self._references) == 0: + self._start_point = save_start_position + self._stop_point = save_stop_position + return self._reaches + for sig in chain._reaches: + for ref in chain._reaches[sig]: + found = False + for ref2 in self._reaches[sig]: + if ref is ref2: + found = True + break + if not found: + self._reaches[sig].append(ref) else: # Check if there is an ancestor Assignment. - ancestor = self._reference.ancestor(Assignment) + ancestor = self._references[0].ancestor(Assignment) if ancestor is not None: # If we get here to check the start part of a loop we need # to handle this differently. # If the reference is the lhs then we can ignore the RHS. - if ancestor.lhs is not self._reference: + if all([ancestor.lhs is not ref for ref in self._references]): pass elif ancestor.rhs is self._scope[0] and len(self._scope) == 1: # If the ancestor RHS is the scope of this chain then we @@ -1139,30 +1185,32 @@ def find_backward_accesses(self) -> list[Node]: # Add the rhs as a potential basic block with different # start and stop positions. chain = DefinitionUseChain( - self._reference, + [ref for ref in self._references], [ancestor.rhs], start_point=ancestor.rhs.abs_position, stop_point=sys.maxsize, ) # Find any backward_accesses in the rhs. chain.find_backward_accesses() - for ref in chain._reaches: - self._reaches.append(ref) + for sig in chain._reaches: + for ref in chain._reaches[sig]: + self._reaches[sig].append(ref) # We can compute the rest of the accesses self._compute_backward_uses(self._scope) - for ref in self._uses: - self._reaches.append(ref) - # If this block doesn't kill any accesses, then we add - # the defsout into the reaches array. - if len(self.killed) == 0: - for ref in self._defsout: - self._reaches.append(ref) - else: - # If this block killed any accesses, then the first element - # of the killed writes is the access access that we're - # dependent with. - self._reaches.append(self.killed[0]) + for sig in self._reference_signatures: + for ref in self._uses[sig]: + self._reaches[sig].append(ref) + # If this block doesn't kill any accesses, then we add + # the defsout into the reaches array. + if len(self.killed[sig]) == 0: + for ref in self._defsout[sig]: + self._reaches[sig].append(ref) + else: + # If this block killed any accesses, then the first element + # of the killed writes is the access access that we're + # dependent with. + self._reaches[sig].append(self.killed[sig][0]) # Reset the start and stop points before returning the result. self._start_point = save_start_position diff --git a/src/psyclone/psyir/transformations/hoist_trans.py b/src/psyclone/psyir/transformations/hoist_trans.py index bbcd7eb4b8..279f877d7a 100644 --- a/src/psyclone/psyir/transformations/hoist_trans.py +++ b/src/psyclone/psyir/transformations/hoist_trans.py @@ -239,9 +239,9 @@ def _validate_dependencies(self, statement, parent_loop): written_node = accesses_in_statement[0].node accesses_in_loop = all_loop_vars[written_sig] chains = DefinitionUseChain( - written_node, parent_loop.children[:] + [written_node], parent_loop.children[:] ) - if chains.find_backward_accesses(): + if chains.find_backward_accesses()[written_sig]: code = statement.debug_string().strip() raise TransformationError(f"The statement '{code}' can't be " f"hoisted as variable " diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_backward_dependence_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_backward_dependence_test.py index d7940c0df5..98c9ffc62b 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_backward_dependence_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_backward_dependence_test.py @@ -65,9 +65,9 @@ def test_definition_use_chain_compute_backward_uses(fortran_reader): a_3 = psyir.walk(Reference)[3] # Check this is the lhs of the assignment assert a_3 is psyir.walk(Assignment)[1].rhs - + sig = a_3.get_signature_and_indices()[0] duc = DefinitionUseChain( - a_3, control_flow_region=[routine] + [a_3], control_flow_region=[routine] ) basic_block_list = routine.children[:] # Need to set the start point and stop points similar to what @@ -75,8 +75,9 @@ def test_definition_use_chain_compute_backward_uses(fortran_reader): duc._start_point = routine.children[0].abs_position duc._stop_point = a_3.abs_position-1 duc._compute_backward_uses(basic_block_list) - assert len(duc.defsout) == 1 - assert duc.defsout[0] is psyir.walk(Reference)[0] # The lhs of a = a + 1 + assert len(duc.defsout[sig]) == 1 + # The lhs of a = a + 1 + assert duc.defsout[sig][0] is psyir.walk(Reference)[0] # Next we test a Reference with a write then a read - we should only get # the write, which should be in uses and defsout. @@ -91,9 +92,9 @@ def test_definition_use_chain_compute_backward_uses(fortran_reader): psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] a_3 = psyir.walk(Reference)[4] - + sig = a_3.get_signature_and_indices()[0] duc = DefinitionUseChain( - a_3, control_flow_region=[routine] + [a_3], control_flow_region=[routine] ) basic_block_list = routine.children[:] # Need to set the start point and stop points similar to what @@ -101,10 +102,10 @@ def test_definition_use_chain_compute_backward_uses(fortran_reader): duc._start_point = routine.children[0].abs_position duc._stop_point = a_3.abs_position - 1 duc._compute_backward_uses(basic_block_list) - assert len(duc.uses) == 0 - assert len(duc.defsout) == 1 - assert len(duc.killed) == 0 - assert duc.defsout[0] is psyir.walk(Reference)[2] # The lhs of a = 2 + assert len(duc.uses[sig]) == 0 + assert len(duc.defsout[sig]) == 1 + assert len(duc.killed[sig]) == 0 + assert duc.defsout[sig][0] is psyir.walk(Reference)[2] # The lhs of a = 2 def test_definition_use_chain_find_backward_accesses_basic_example( @@ -140,10 +141,12 @@ def test_definition_use_chain_find_backward_accesses_basic_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Creating use chain for the a in a = 2 + ref = routine.walk(Assignment)[8].lhs + sig = ref.get_signature_and_indices()[0] chains = DefinitionUseChain( - routine.walk(Assignment)[8].lhs, [routine] + [ref], [routine] ) - reaches = chains.find_backward_accesses() + reaches = chains.find_backward_accesses()[sig] # We find 2 results # the a in e = a**3 # The call bar(c, b) as a isn't local and we can't guarantee its behaviour. @@ -152,9 +155,10 @@ def test_definition_use_chain_find_backward_accesses_basic_example( assert reaches[1] is routine.walk(Call)[1] # Create use chain for c in b = c + d - chains = DefinitionUseChain(routine.walk(Assignment)[5].rhs.children[0], - [routine]) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[5].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref], [routine]) + reaches = chains.find_backward_accesses()[sig] # We should find 2 results # C = d * a # d = C + 2.0 @@ -177,8 +181,10 @@ def test_definition_use_chain_find_backward_accesses_assignment( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start chain from A = a * a - chains = DefinitionUseChain(routine.walk(Assignment)[1].lhs) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] # We should find 3 results, both 3 references in # a = A * A # and A = 1 @@ -209,8 +215,10 @@ def test_definition_use_chain_find_backward_accesses_ifelse_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from b = A + d. - chains = DefinitionUseChain(routine.walk(Assignment)[4].rhs.children[0]) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[4].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] # TODO #2760 For now the if statement doesn't kill the accesses, # even though it will always be written to. assert len(reaches) == 4 @@ -222,8 +230,9 @@ def test_definition_use_chain_find_backward_accesses_ifelse_example( # Also check that a = 4 backward access is not a = 3. a_3 = routine.walk(Assignment)[2].lhs a_4 = routine.walk(Assignment)[3].lhs - chains = DefinitionUseChain(a_4) - reaches = chains.find_backward_accesses() + sig = a_4.get_signature_and_indices()[0] + chains = DefinitionUseChain([a_4]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 2 assert reaches[0] is not a_3 assert reaches[1] is not a_3 @@ -257,8 +266,10 @@ def test_definition_use_chain_find_backward_accesses_psy_data_node_example( # because it needs to deal with the current_block p_trans.apply(routine[1:]) # Start the chain from b = A + d. - chains = DefinitionUseChain(routine.walk(Assignment)[4].rhs.children[0]) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[4].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] # Inside the ProfileRegion the DUC has to work as before, the 'a' has # 4 backwards accesses as shown in the previous test. assert len(reaches) == 4 @@ -284,10 +295,12 @@ def test_definition_use_chain_find_backward_accesses_loop_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from A = a + i. + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] chains = DefinitionUseChain( - routine.walk(Assignment)[1].lhs + [ref] ) - reaches = chains.find_backward_accesses() + reaches = chains.find_backward_accesses()[sig] # We should have 4? reaches # First b = A + 2 # Second a = A + i @@ -320,10 +333,10 @@ def test_definition_use_chain_find_backward_accesses_loop_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from I = 1231. - chains = DefinitionUseChain( - routine.walk(Assignment)[2].lhs - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[2].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] # We should have 1 reaches # It should be the loop assert len(reaches) == 1 @@ -348,8 +361,10 @@ def test_definition_use_chain_find_backward_accesses_while_loop_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from A = a + 3. - chains = DefinitionUseChain(routine.children[2].loop_body.children[0].lhs) - reaches = chains.find_backward_accesses() + ref = routine.children[2].loop_body.children[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 4 assert reaches[0] is routine.walk(WhileLoop)[0].condition.children[0] @@ -381,8 +396,10 @@ def test_definition_use_chain_backward_accesses_nested_loop_example( routine = psyir.walk(Routine)[0] # Start the chain from b = b + A. loops = routine.walk(WhileLoop) - chains = DefinitionUseChain(loops[1].walk(Assignment)[0].rhs.children[1]) - reaches = chains.find_backward_accesses() + ref = loops[1].walk(Assignment)[0].rhs.children[1] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] # TODO #2760 The backwards accesses should not continue past a = a + 3 as # to reach the b = b + a statement we must have passed through the # a = a + 3 statement. @@ -410,8 +427,10 @@ def test_definition_use_chain_find_backward_accesses_structure_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[2].lhs) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[2].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Assignment)[0].lhs @@ -429,8 +448,10 @@ def test_definition_use_chain_find_backward_accesses_no_control_flow_example( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Assignment)[0].rhs.children[0] @@ -449,8 +470,10 @@ def test_definition_use_chain_find_backward_accesses_codeblock( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[1].lhs) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(CodeBlock)[0] @@ -470,8 +493,10 @@ def test_definition_use_chain_find_backward_accesses_codeblock_and_call_nlocal( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].rhs.children[0]) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[0].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 # Result is the argument of the call assert reaches[0] is routine.walk(Call)[0].children[1] @@ -495,8 +520,10 @@ def test_definition_use_chain_find_backward_accesses_codeblock_and_call_cflow( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].rhs.children[0]) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[0].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 2 assert reaches[0] is routine.walk(Call)[1].children[1] assert reaches[1] is routine.walk(Call)[0] @@ -518,8 +545,10 @@ def test_definition_use_chain_find_backward_accesses_codeblock_and_call_local( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].rhs.children[0]) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[0].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(CodeBlock)[0] @@ -539,8 +568,10 @@ def test_definition_use_chain_find_backward_accesses_call_and_codeblock_nlocal( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Call)[0] @@ -558,7 +589,7 @@ def test_definition_use_chains_goto_statement( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) + chains = DefinitionUseChain([routine.walk(Assignment)[0].lhs]) with pytest.raises(NotImplementedError) as excinfo: chains.find_backward_accesses() assert ("DefinitionUseChains can't handle code containing GOTO statements" @@ -586,10 +617,10 @@ def test_definition_use_chains_exit_statement( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from a = A +i. - chains = DefinitionUseChain( - routine.walk(Assignment)[1].rhs.children[0] - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[1].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] # We should have 2 reaches # First is A = a + i # Second is A = 1 @@ -623,10 +654,10 @@ def test_definition_use_chains_cycle_statement( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from a = A +i. - chains = DefinitionUseChain( - routine.walk(Assignment)[1].rhs.children[0] - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[1].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] # We should have 2 reaches # A = b * 4 # A = 1 @@ -657,10 +688,10 @@ def test_definition_use_chains_return_statement( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from a = A +i. - chains = DefinitionUseChain( - routine.walk(Assignment)[1].rhs.children[0] - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[1].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] # We should have 2 reaches # A = b * 4 # A = 1 @@ -689,10 +720,10 @@ def test_definition_use_chains_backward_accesses_multiple_routines( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[1] - chains = DefinitionUseChain( - routine.walk(Assignment)[0].rhs - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[0].rhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 0 @@ -716,10 +747,10 @@ def test_definition_use_chains_backward_accesses_nonassign_reference_in_loop( end subroutine x""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain( - routine.walk(Call)[0].children[1] - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Call)[0].children[1] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] # TODO #2760 The backwards accesses should not continue past a = a + i # when searching backwards in the loop, or to a = 1 assert len(reaches) == 3 @@ -749,10 +780,10 @@ def test_definition_use_chains_backward_accesses_empty_schedules( """ psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain( - routine.walk(Assignment)[1].lhs - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 3 assert reaches[0] is routine.walk(Assignment)[1].rhs.children[1] assert reaches[1] is routine.walk(Assignment)[1].rhs.children[0] @@ -776,10 +807,10 @@ def test_definition_use_chains_backward_accesses_inquiry_func( """ psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain( - routine.walk(Assignment)[1].lhs - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 0 @@ -806,8 +837,9 @@ def test_definition_use_chain_find_backward_accesses_pure_call( # Find the a in the rhs of the second assignment assign2 = routine.walk(Assignment)[1] rhs_a = assign2.rhs.children[0] - chains = DefinitionUseChain(rhs_a) - reaches = chains.find_backward_accesses() + sig = rhs_a.get_signature_and_indices()[0] + chains = DefinitionUseChain([rhs_a]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 # Result is lhs of the first assignment lhs_assign1 = routine.walk(Assignment)[0].lhs @@ -815,8 +847,9 @@ def test_definition_use_chain_find_backward_accesses_pure_call( # Get the lhs of the b = 1 assignment lhs_assign3 = routine.walk(Assignment)[2].lhs - chains = DefinitionUseChain(lhs_assign3) - reaches = chains.find_backward_accesses() + sig = lhs_assign3.get_signature_and_indices()[0] + chains = DefinitionUseChain([lhs_assign3]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 # We should find the argument in the pure subroutine call assert reaches[0] is routine.walk(Call)[0].children[1] @@ -843,8 +876,9 @@ def test_definition_use_chain_find_backward_accesses_ancestor_call( routine = psyir.find_routine_psyir("foo") call = psyir.walk(Call)[0] arg = call.arguments[1] - chain = DefinitionUseChain(arg) - all_prev = chain.find_backward_accesses() + sig = arg.get_signature_and_indices()[0] + chain = DefinitionUseChain([arg]) + all_prev = chain.find_backward_accesses()[sig] # Check that the ancestor call of b isn't a backward access. assert not isinstance(all_prev[0], Call) # The correct previous access should be the Reference to b in @@ -869,6 +903,7 @@ def test_backward_accesses_nested_loop(fortran_reader): psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] lhs = routine.walk(Assignment)[0].lhs - chains = DefinitionUseChain(lhs) - reaches = chains.find_backward_accesses() + sig = lhs.get_signature_and_indices()[0] + chains = DefinitionUseChain([lhs]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py index 30380e1105..4c25b87589 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py @@ -74,37 +74,49 @@ def test_definition_use_chain_init_and_properties(fortran_reader): a_1 = references[0] # Check the basic initialisation with default setup. - duc = DefinitionUseChain(a_1) - assert duc._reference is a_1 + duc = DefinitionUseChain([a_1]) + assert len(duc._references) == 1 + assert duc._references[0] is a_1 + sig = a_1.get_signature_and_indices()[0] + assert duc._reference_signatures[0] == sig assert duc._start_point is None assert duc._stop_point is None - assert duc._reference_abs_pos == a_1.abs_position + assert duc._references_abs_pos[sig] == a_1.abs_position assert len(duc._scope) == 1 assert duc._scope[0] is routine # Test the control_flow_region setter - duc = DefinitionUseChain(a_1, routine.children[0:2]) + duc = DefinitionUseChain([a_1], routine.children[0:2]) assert len(duc._scope) == 2 assert duc._scope[0] is routine.children[0] assert duc._scope[1] is routine.children[1] # Test the start and stop point setters - duc = DefinitionUseChain(a_1, start_point=0, stop_point=1) + duc = DefinitionUseChain([a_1], start_point=0, stop_point=1) assert duc._start_point == 0 assert duc._stop_point == 1 - assert len(duc.uses) == 0 - assert len(duc.defsout) == 0 - assert len(duc.killed) == 0 + assert len(duc.uses) == 1 + assert len(duc.uses[sig]) == 0 + assert len(duc.defsout) == 1 + assert len(duc.defsout[sig]) == 0 + assert len(duc.killed) == 1 + assert len(duc.killed[sig]) == 0 - # Test exception when passed a non_list + # Test exceptions when passed a non_list for various inputs. with pytest.raises(TypeError) as excinfo: - duc = DefinitionUseChain(a_1, control_flow_region=2) + duc = DefinitionUseChain(a_1) + assert ("The 'references' argument passed into a DefinitionUseChain " + "must be a list of References but found 'Reference'" + in str(excinfo.value)) + + with pytest.raises(TypeError) as excinfo: + duc = DefinitionUseChain([a_1], control_flow_region=2) assert ("The control_flow_region passed into a DefinitionUseChain " "must be a list but found 'int'." in str(excinfo.value)) with pytest.raises(TypeError) as excinfo: - duc = DefinitionUseChain(a_1, control_flow_region=[2]) + duc = DefinitionUseChain([a_1], control_flow_region=[2]) assert ("Each element of the control_flow_region passed into a " "DefinitionUseChain must be a Node but found a non-Node " "element. Full input is " in str(excinfo.value)) @@ -114,21 +126,17 @@ def test_definition_use_chain_init_and_properties(fortran_reader): r1 = Reference(sym) r2 = Reference(sym) assign = Assignment.create(r1, r2) - duc = DefinitionUseChain(r1) + duc = DefinitionUseChain([r1]) assert duc._scope[0] is assign.lhs assert duc._scope[1] is assign.rhs # Test remaining TypeErrors with pytest.raises(TypeError) as excinfo: - duc = DefinitionUseChain("123") - assert ("The 'reference' argument passed into a DefinitionUseChain must " - "be a Reference but found 'str'." in str(excinfo.value)) - with pytest.raises(TypeError) as excinfo: - duc = DefinitionUseChain(r1, start_point="123") + duc = DefinitionUseChain([r1], start_point="123") assert ("The start_point passed into a DefinitionUseChain must be an " "int but found 'str'." in str(excinfo.value)) with pytest.raises(TypeError) as excinfo: - duc = DefinitionUseChain(r1, stop_point="123") + duc = DefinitionUseChain([r1], stop_point="123") assert ("The stop_point passed into a DefinitionUseChain must be an " "int but found 'str'." in str(excinfo.value)) @@ -166,17 +174,17 @@ def test_definition_use_chain_is_basic_block(fortran_reader): reference = psyir.walk(Reference)[0] # The full routine is not a basic block as it contains control flow. - duc1 = DefinitionUseChain(reference, control_flow_region=block1) + duc1 = DefinitionUseChain([reference], control_flow_region=block1) assert not duc1.is_basic_block # The if_body of the if statement is a basic block as it contains no # control flow. - duc2 = DefinitionUseChain(reference, control_flow_region=block2) + duc2 = DefinitionUseChain([reference], control_flow_region=block2) assert duc2.is_basic_block # The else_body of the if statement is not a basic block as it contains # control flow. - duc3 = DefinitionUseChain(reference, control_flow_region=block3) + duc3 = DefinitionUseChain([reference], control_flow_region=block3) assert not duc3.is_basic_block # Test that regiondirectives (e.g. OMPParallelDirective) don't count. @@ -194,7 +202,7 @@ def test_definition_use_chain_is_basic_block(fortran_reader): par_trans.apply(psyir.walk(Routine)[0].children[:]) reference = psyir.walk(Reference)[0] parallel = psyir.walk(OMPParallelDirective)[0] - duc = DefinitionUseChain(reference, control_flow_region=[parallel]) + duc = DefinitionUseChain([reference], control_flow_region=[parallel]) assert not duc.is_basic_block @@ -215,16 +223,17 @@ def test_definition_use_chain_compute_forward_uses(fortran_reader): assert a_1 is psyir.walk(Assignment)[0].lhs duc = DefinitionUseChain( - a_1, control_flow_region=[routine] + [a_1], control_flow_region=[routine] ) + sig = a_1.get_signature_and_indices()[0] basic_block_list = routine.children[:] # Need to set the start point and stop points similar to what # forward_accesses would do duc._start_point = a_1.ancestor(Assignment).walk(Node)[-1].abs_position duc._stop_point = 100000000 duc._compute_forward_uses(basic_block_list) - assert len(duc.uses) == 1 - assert duc.uses[0] is psyir.walk(Reference)[3] # The rhs of b=a + assert len(duc.uses[sig]) == 1 + assert duc.uses[sig][0] is psyir.walk(Reference)[3] # The rhs of b=a # Next we test a Reference with a write then a read - we should only get # the write, which should be in uses and defsout. @@ -241,18 +250,19 @@ def test_definition_use_chain_compute_forward_uses(fortran_reader): a_1 = psyir.walk(Reference)[1] duc = DefinitionUseChain( - a_1, control_flow_region=[routine] + [a_1], control_flow_region=[routine] ) + sig = a_1.get_signature_and_indices()[0] basic_block_list = routine.children[:] # Need to set the start point and stop points similar to what # forward_accesses would do duc._start_point = a_1.ancestor(Assignment).walk(Node)[-1].abs_position duc._stop_point = 100000000 duc._compute_forward_uses(basic_block_list) - assert len(duc.uses) == 0 - assert len(duc.defsout) == 1 - assert len(duc.killed) == 0 - assert duc.defsout[0] is psyir.walk(Reference)[2] # The lhs of a = 2 + assert len(duc.uses[sig]) == 0 + assert len(duc.defsout[sig]) == 1 + assert len(duc.killed[sig]) == 0 + assert duc.defsout[sig][0] is psyir.walk(Reference)[2] # The lhs of a = 2 # Finally test a Reference with a write then another write. # The defsout should be the final write and the first write should be @@ -270,19 +280,20 @@ def test_definition_use_chain_compute_forward_uses(fortran_reader): a_1 = psyir.walk(Reference)[1] duc = DefinitionUseChain( - a_1, control_flow_region=[routine] + [a_1], control_flow_region=[routine] ) + sig = a_1.get_signature_and_indices()[0] basic_block_list = routine.children[:] # Need to set the start point and stop points similar to what # forward_accesses would do duc._start_point = a_1.ancestor(Assignment).walk(Node)[-1].abs_position duc._stop_point = 100000000 duc._compute_forward_uses(basic_block_list) - assert len(duc.uses) == 0 - assert len(duc.defsout) == 1 - assert len(duc.killed) == 1 - assert duc.defsout[0] is psyir.walk(Reference)[5] # The lhs of a = 3 - assert duc.killed[0] is psyir.walk(Reference)[2] # The lhs of a = 2 + assert len(duc.uses[sig]) == 0 + assert len(duc.defsout[sig]) == 1 + assert len(duc.killed[sig]) == 1 + assert duc.defsout[sig][0] is psyir.walk(Reference)[5] # The lhs of a = 3 + assert duc.killed[sig][0] is psyir.walk(Reference)[2] # The lhs of a = 2 def test_definition_use_chain_find_basic_blocks(fortran_reader): @@ -309,7 +320,7 @@ def test_definition_use_chain_find_basic_blocks(fortran_reader): psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] duc = DefinitionUseChain( - routine.walk(Reference)[0], control_flow_region=[routine] + [routine.walk(Reference)[0]], control_flow_region=[routine] ) # Find the basic blocks. cfn, blocks = duc._find_basic_blocks(routine.children[:]) @@ -442,7 +453,7 @@ def test_definition_use_chain_find_basic_blocks_inside_loops(fortran_reader): aref = psyir.walk(ArrayReference)[0] assert aref.symbol.name == "ztmp" duc = DefinitionUseChain( - aref, control_flow_region=[routine] + [aref], control_flow_region=[routine] ) cfn, blocks = duc._find_basic_blocks(routine.walk(Loop)[0].children[:]) # The ifblock has to be in cfn twice, once with the contents of the if @@ -486,38 +497,44 @@ def test_definition_use_chain_find_forward_accesses_basic_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Creating use chain for the a in c = a + 1.0 + ref = routine.children[0].children[1].children[0] + sig = ref.get_signature_and_indices()[0] chains = DefinitionUseChain( - routine.children[0].children[1].children[0], [routine] + [ref], [routine] ) reaches = chains.find_forward_accesses() # We find 3 results # the a in e = a**2 (assignment 2) # the a in c = d * a (assignment 4) # The call bar(c, b) as a isn't local and we can't guarantee its behaviour. - assert len(reaches) == 3 - assert reaches[0] is routine.walk(Assignment)[1].rhs.children[0] - assert reaches[1] is routine.walk(Assignment)[4].rhs.children[1] - assert reaches[2] is routine.walk(Call)[1] + assert len(reaches[sig]) == 3 + assert reaches[sig][0] is routine.walk(Assignment)[1].rhs.children[0] + assert reaches[sig][1] is routine.walk(Assignment)[4].rhs.children[1] + assert reaches[sig][2] is routine.walk(Call)[1] # Create use chain for d in d = c + 2.0 - chains = DefinitionUseChain(routine.children[3].lhs, [routine]) + ref = routine.children[3].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref], [routine]) reaches = chains.find_forward_accesses() # We should find 2 results # c = D * a (Assignment 5) # b = c + D (Assignment 6) - assert reaches[0] is routine.walk(Assignment)[4].rhs.children[0] - assert reaches[1] is routine.walk(Assignment)[5].rhs.children[1] - assert len(reaches) == 2 + assert reaches[sig][0] is routine.walk(Assignment)[4].rhs.children[0] + assert reaches[sig][1] is routine.walk(Assignment)[5].rhs.children[1] + assert len(reaches[sig]) == 2 # Create use chain for c in c = d * a (Assignment 5) - chains = DefinitionUseChain(routine.walk(Assignment)[4].lhs, [routine]) + ref = routine.walk(Assignment)[4].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref], [routine]) reaches = chains.find_forward_accesses() # 2 results: # b = C + d (Assignment 6) # call bar(c, d) (The second Call) - assert len(reaches) == 2 - assert reaches[0] is routine.walk(Assignment)[5].rhs.children[0] - assert reaches[1] is routine.walk(Call)[1].arguments[0] + assert len(reaches[sig]) == 2 + assert reaches[sig][0] is routine.walk(Assignment)[5].rhs.children[0] + assert reaches[sig][1] is routine.walk(Call)[1].arguments[0] def test_definition_use_chain_find_forward_accesses_assignment( @@ -534,15 +551,17 @@ def test_definition_use_chain_find_forward_accesses_assignment( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start chain from A = 1 - chains = DefinitionUseChain(routine.children[0].lhs) + ref = routine.children[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) reaches = chains.find_forward_accesses() # We should find 3 results, all 3 references in # A = A * A - assert len(reaches) == 3 assignment = routine.walk(Assignment)[1] - assert reaches[0] is assignment.rhs.children[0] - assert reaches[1] is assignment.rhs.children[1] - assert reaches[2] is assignment.lhs + assert len(reaches[sig]) == 3 + assert reaches[sig][0] is assignment.rhs.children[0] + assert reaches[sig][1] is assignment.rhs.children[1] + assert reaches[sig][2] is assignment.lhs def test_definition_use_chain_find_forward_accesses_ifelse_example( @@ -566,23 +585,26 @@ def test_definition_use_chain_find_forward_accesses_ifelse_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from a = 1. - chains = DefinitionUseChain(routine.children[0].lhs) + ref = routine.children[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) reaches = chains.find_forward_accesses() # TODO #2760 For now the if statement doesn't kill the accesses, # even though it will always be written to. - assert len(reaches) == 4 - assert reaches[0] is routine.walk(Assignment)[1].rhs.children[0] - assert reaches[1] is routine.walk(Assignment)[2].lhs - assert reaches[2] is routine.walk(Assignment)[3].lhs - assert reaches[3] is routine.walk(Assignment)[4].rhs.children[0] + assert len(reaches[sig]) == 4 + assert reaches[sig][0] is routine.walk(Assignment)[1].rhs.children[0] + assert reaches[sig][1] is routine.walk(Assignment)[2].lhs + assert reaches[sig][2] is routine.walk(Assignment)[3].lhs + assert reaches[sig][3] is routine.walk(Assignment)[4].rhs.children[0] # Also check that a = 3 forward access is not a = 4. a_3 = routine.children[2].if_body.children[0].lhs + sig = a_3.get_signature_and_indices()[0] a_4 = routine.children[2].else_body.children[0].rhs - chains = DefinitionUseChain(a_3) + chains = DefinitionUseChain([a_3]) reaches = chains.find_forward_accesses() - assert len(reaches) == 1 - assert reaches[0] is not a_4 + assert len(reaches[sig]) == 1 + assert reaches[sig][0] is not a_4 def test_definition_use_chain_find_forward_accesses_loop_example( @@ -605,10 +627,12 @@ def test_definition_use_chain_find_forward_accesses_loop_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from b = A +2. + ref = routine.children[1].loop_body.children[1].rhs.children[0] + sig = ref.get_signature_and_indices()[0] chains = DefinitionUseChain( - routine.children[1].loop_body.children[1].rhs.children[0] + [ref] ) - reaches = chains.find_forward_accesses() + reaches = chains.find_forward_accesses()[sig] # We should have 3 reaches # First two are A = A + i # Second is c = a + b @@ -635,10 +659,12 @@ def test_definition_use_chain_find_forward_accesses_loop_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from I = 1231. + ref = routine.children[0].lhs + sig = ref.get_signature_and_indices()[0] chains = DefinitionUseChain( - routine.children[0].lhs + [ref] ) - reaches = chains.find_forward_accesses() + reaches = chains.find_forward_accesses()[sig] # We should have 1 reaches # It should be the loop assert len(reaches) == 1 @@ -663,8 +689,10 @@ def test_definition_use_chain_find_forward_accesses_while_loop_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from A = a + 3. - chains = DefinitionUseChain(routine.walk(Assignment)[2].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[2].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 3 assert reaches[0] is routine.walk(WhileLoop)[0].condition.children[0] @@ -695,8 +723,10 @@ def test_definition_use_chain_forward_accesses_nested_loop_example( routine = psyir.walk(Routine)[0] # Start the chain from b = b + A. loops = routine.walk(WhileLoop) - chains = DefinitionUseChain(loops[1].loop_body.children[0].rhs.children[1]) - reaches = chains.find_forward_accesses() + ref = loops[1].loop_body.children[0].rhs.children[1] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] # Results should be A = A + 3 and the a < i condition assert reaches[0] is loops[0].condition.children[0] assert reaches[1] is loops[0].walk(Assignment)[0].rhs.children[0] @@ -720,8 +750,10 @@ def test_definition_use_chain_find_forward_accesses_structure_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Assignment)[2].lhs @@ -739,8 +771,10 @@ def test_definition_use_chain_find_forward_accesses_no_control_flow_example( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].rhs.children[0]) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Assignment)[0].lhs @@ -760,8 +794,10 @@ def test_definition_use_chain_find_forward_accesses_no_control_flow_example2( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].rhs.children[0]) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Assignment)[0].lhs @@ -780,8 +816,10 @@ def test_definition_use_chain_find_forward_accesses_codeblock( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(CodeBlock)[0] @@ -801,8 +839,10 @@ def test_definition_use_chain_find_forward_accesses_codeblock_and_call_nlocal( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(CodeBlock)[0] @@ -825,8 +865,10 @@ def test_definition_use_chain_find_forward_accesses_codeblock_and_call_cflow( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 2 assert reaches[0] is routine.walk(CodeBlock)[0] assert reaches[1] is routine.walk(Call)[1] @@ -848,8 +890,10 @@ def test_definition_use_chain_find_forward_accesses_codeblock_and_call_local( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(CodeBlock)[0] @@ -869,8 +913,10 @@ def test_definition_use_chain_find_forward_accesses_call_and_codeblock_nlocal( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Call)[0] @@ -888,7 +934,7 @@ def test_definition_use_chains_goto_statement( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) + chains = DefinitionUseChain([routine.walk(Assignment)[0].lhs]) with pytest.raises(NotImplementedError) as excinfo: chains.find_forward_accesses() assert ("DefinitionUseChains can't handle code containing GOTO statements" @@ -916,10 +962,10 @@ def test_definition_use_chains_exit_statement( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from A = a +i. - chains = DefinitionUseChain( - routine.walk(Assignment)[1].lhs - ) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] # We should have 3 reaches # First two are A = A + i # Second is c = a + b @@ -956,10 +1002,10 @@ def test_definition_use_chains_cycle_statement( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from A = a +i. - chains = DefinitionUseChain( - routine.walk(Assignment)[1].lhs - ) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] # We should have 4 reaches # First two are A = A + i # Then A = b * 4 @@ -995,10 +1041,10 @@ def test_definition_use_chains_return_statement( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from A = a +i. - chains = DefinitionUseChain( - routine.walk(Assignment)[1].lhs - ) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] # We should have 4 reaches # First two are A = A + i # Then A = b + 4 @@ -1032,10 +1078,10 @@ def test_definition_use_chains_forward_accesses_multiple_routines( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain( - routine.walk(Assignment)[0].lhs - ) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 0 @@ -1060,10 +1106,10 @@ def test_definition_use_chains_forward_accesses_empty_schedules( """ psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain( - routine.walk(Assignment)[0].lhs - ) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 3 assert reaches[0] is routine.walk(Assignment)[1].rhs.children[0] assert reaches[1] is routine.walk(Assignment)[1].rhs.children[1] @@ -1087,8 +1133,10 @@ def test_definition_use_chains_backward_accesses_inquiry_func( """ psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 0 @@ -1113,8 +1161,10 @@ def test_definition_use_chains_multiple_ancestor_loops( end subroutine test""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[1].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain([ref]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 3 assert reaches[0] is routine.walk(Assignment)[0].lhs assert reaches[1] is routine.walk(Assignment)[1].lhs @@ -1142,8 +1192,9 @@ def test_definition_use_chain_find_forward_accesses_pure_call( routine = psyir.walk(Routine)[1] # Start from the lhs of the first assignment lhs_assign1 = routine.walk(Assignment)[0].lhs - chains = DefinitionUseChain(lhs_assign1) - reaches = chains.find_forward_accesses() + sig = lhs_assign1.get_signature_and_indices()[0] + chains = DefinitionUseChain([lhs_assign1]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 2 # Should find the a in the rhs of a = a + 2 and the lhs. rhs_assign3 = routine.walk(Assignment)[2].rhs.children[0] @@ -1153,8 +1204,9 @@ def test_definition_use_chain_find_forward_accesses_pure_call( # Start from lhs of b = 1 lhs_assign2 = routine.walk(Assignment)[1].lhs - chains = DefinitionUseChain(lhs_assign2) - reaches = chains.find_forward_accesses() + sig = lhs_assign2.get_signature_and_indices()[0] + chains = DefinitionUseChain([lhs_assign2]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 # result is first argument of the pure subroutine call argument = routine.walk(Call)[0].children[1] @@ -1178,12 +1230,12 @@ def test_forward_accesses_nested_loop(fortran_reader): psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] lhs = routine.walk(Assignment)[0].lhs - chains = DefinitionUseChain(lhs) - reaches = chains.find_forward_accesses() + sig = lhs.get_signature_and_indices()[0] + chains = DefinitionUseChain([lhs]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 - def test_forward_accesses_multiple_elements(fortran_reader): """Test that if we have multiple inputs we get multiple outputs as expected.""" @@ -1199,9 +1251,8 @@ def test_forward_accesses_multiple_elements(fortran_reader): routine = psyir.walk(Routine)[0] assigns = routine.walk(Assignment) rhs = assigns[0].rhs - chains = DefinitionUseChains([rhs.children[0], - rhs.children[1]], - start_point = 0) #FIXME + chains = DefinitionUseChain([rhs.children[0], + rhs.children[1]]) reaches = chains.find_forward_accesses() sig0, _ = rhs.children[0].get_signature_and_indices() sig1, _ = rhs.children[1].get_signature_and_indices() From d12e7db8c0ae96d464c092d52211172f60873dae Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Thu, 9 Apr 2026 15:13:28 +0100 Subject: [PATCH 07/22] fix tutorial --- .../training/transformation/3.6-sympy/solution/dataflow.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tutorial/training/transformation/3.6-sympy/solution/dataflow.py b/tutorial/training/transformation/3.6-sympy/solution/dataflow.py index 35511e50b0..4c7ea7ffd1 100755 --- a/tutorial/training/transformation/3.6-sympy/solution/dataflow.py +++ b/tutorial/training/transformation/3.6-sympy/solution/dataflow.py @@ -110,8 +110,9 @@ stop_position = node.ancestor(Statement).abs_position else: stop_position = node.abs_position - chain = DefinitionUseChain(node, stop_point=stop_position) - all_prev = chain.find_backward_accesses() + chain = DefinitionUseChain([node], stop_point=stop_position) + sig = node.get_signature_and_indices()[0] + all_prev = chain.find_backward_accesses()[sig] # Keep track if a write was found (if not, we will add the # variable as a node by itself) From fbdab621d3def72336b464ade0de41c3bd8f15c3 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Thu, 9 Apr 2026 15:42:43 +0100 Subject: [PATCH 08/22] Remove unnecessary uniqueness check --- .../psyir/tools/definition_use_chains.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index ee56b3c95a..1ec8a5eea4 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -358,17 +358,11 @@ def find_forward_accesses(self) -> list[Node]: # here is to find all the reached nodes. for sig in chain._reaches: for ref in chain._reaches[sig]: - # Add unique references to reaches. Since we could - # have multiple references to the same symbol in - # the input they're not unique, so we need to - # check for uniqueness. As nodes can be == but - # not the same object, this has to be done - # using a loop and `is`. - for ref2 in self._reaches[sig]: - if ref2 is ref: - break - else: - self._reaches[sig].append(ref) + # Add unique references to reaches. Since we're not + # in a control flow region, we can't have added + # these references into the reaches array yet so + # they're guaranteed to be unique. + self._reaches[sig].append(ref) # If we have a defsout in the chain then we can stop for # that reference as we will never get past the write # as its not conditional. From 576210d501c1140e2b6f9b3e36e712ddf1c654b6 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Thu, 9 Apr 2026 15:49:47 +0100 Subject: [PATCH 09/22] Remaining coverage misses --- .../definition_use_chains_forward_dependence_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py index 4c25b87589..f479c6aeb2 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py @@ -37,6 +37,7 @@ forward_accesses routine.''' import pytest +from psyclone.errors import InternalError from psyclone.psyir.nodes import ( ArrayReference, Assignment, @@ -139,6 +140,15 @@ def test_definition_use_chain_init_and_properties(fortran_reader): duc = DefinitionUseChain([r1], stop_point="123") assert ("The stop_point passed into a DefinitionUseChain must be an " "int but found 'str'." in str(excinfo.value)) + with pytest.raises(TypeError) as excinfo: + duc = DefinitionUseChain(["a", "b"]) + assert ("The 'references' argument passed into a DefinitionUseChain must " + "be a list of References but found 'str' in the list." + in str(excinfo.value)) + with pytest.raises(InternalError) as excinfo: + duc = DefinitionUseChain([r1, Reference(sym)]) + assert ("All references provided into a DefinitionUseChain " + "must have the same parent." in str(excinfo.value)) def test_definition_use_chain_is_basic_block(fortran_reader): From 3f9da44139f9eea9bd06f0ab1681a9b19cf1766d Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Thu, 9 Apr 2026 18:24:41 +0100 Subject: [PATCH 10/22] allow all nodes in one statement --- src/psyclone/psyir/tools/definition_use_chains.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index 1ec8a5eea4..8998216f76 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -103,11 +103,10 @@ def __init__( f"DefinitionUseChain must be a list of References " f"but found '{type(ref).__name__}' in the list." ) - # We need all the references to have the same parent. - # FIXME Same parent or same ancestor statement? Latter more useful. - parent = references[0].parent + # We need all the references to have the same ancestor Statement. + parent = references[0].ancestor(Statement) for ref in references: - if ref.parent is not parent: + if ref.ancestor(Statement) is not parent: raise InternalError( "All references provided into a DefinitionUseChain " "must have the same parent." From 63ef6b27302b4ef60eb756aa80a7190fdb75c579 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Fri, 10 Apr 2026 14:26:59 +0100 Subject: [PATCH 11/22] Update tests and add assignment.previous/next_accesses --- src/psyclone/psyir/nodes/assignment.py | 39 +++++- src/psyclone/psyir/nodes/reference.py | 17 ++- .../psyir/tools/definition_use_chains.py | 37 +++--- .../tests/psyir/nodes/assignment_test.py | 120 ++++++++++++++++++ .../tests/psyir/nodes/reference_test.py | 4 + ...tion_use_chains_forward_dependence_test.py | 12 +- 6 files changed, 207 insertions(+), 22 deletions(-) diff --git a/src/psyclone/psyir/nodes/assignment.py b/src/psyclone/psyir/nodes/assignment.py index a8409034d4..574756f12c 100644 --- a/src/psyclone/psyir/nodes/assignment.py +++ b/src/psyclone/psyir/nodes/assignment.py @@ -39,13 +39,14 @@ ''' This module contains the Assignment node implementation.''' -from psyclone.core import VariablesAccessMap, AccessType +from psyclone.core import VariablesAccessMap, AccessType, Signature from psyclone.errors import InternalError from psyclone.psyir.nodes.literal import Literal from psyclone.psyir.nodes.array_reference import ArrayReference from psyclone.psyir.nodes.datanode import DataNode from psyclone.psyir.nodes.intrinsic_call import ( IntrinsicCall, REDUCTION_INTRINSICS) +from psyclone.psyir.nodes.node import Node from psyclone.psyir.nodes.ranges import Range from psyclone.psyir.nodes.reference import Reference from psyclone.psyir.nodes.statement import Statement @@ -243,3 +244,39 @@ def is_literal_assignment(self): ''' return isinstance(self.rhs, Literal) + + def previous_accesses(self) -> dict[Signature, list[Node]]: + ''' + :returns: the nodes accessing the same symbols directly before this + after this. It can be multiple nodes for each symbol if + the control flow diverges and there are multiple + possible accesses. + ''' + # Find all of the read/write References in this assignment. + refs = [] + for ref in self.walk(Reference): + if ref.is_read or ref.is_write: + refs.append(ref) + # Avoid circular import + # pylint: disable=import-outside-toplevel + from psyclone.psyir.tools import DefinitionUseChain + chain = DefinitionUseChain(refs) + return chain.find_backward_accesses() + + def next_accesses(self) -> dict[Signature, list[Node]]: + ''' + :returns: the nodes accessing the same symbols directly after this + after this. It can be multiple nodes for each symbol if + the control flow diverges and there are multiple + possible accesses. + ''' + # Find all of the read/write References in this assignment. + refs = [] + for ref in self.walk(Reference): + if ref.is_read or ref.is_write: + refs.append(ref) + # Avoid circular import + # pylint: disable=import-outside-toplevel + from psyclone.psyir.tools import DefinitionUseChain + chain = DefinitionUseChain(refs) + return chain.find_forward_accesses() diff --git a/src/psyclone/psyir/nodes/reference.py b/src/psyclone/psyir/nodes/reference.py index 006911655c..793d2574e5 100644 --- a/src/psyclone/psyir/nodes/reference.py +++ b/src/psyclone/psyir/nodes/reference.py @@ -45,7 +45,12 @@ # We cannot import from 'nodes' directly due to circular import from psyclone.psyir.nodes.datanode import DataNode from psyclone.psyir.nodes.node import Node -from psyclone.psyir.symbols import Symbol, AutomaticInterface +from psyclone.psyir.symbols import ( + Symbol, + AutomaticInterface, + RoutineSymbol, + IntrinsicSymbol +) from psyclone.psyir.symbols.datatypes import UnresolvedType @@ -95,6 +100,12 @@ def is_read(self): # pylint: disable=import-outside-toplevel from psyclone.psyir.nodes.assignment import Assignment from psyclone.psyir.nodes.intrinsic_call import IntrinsicCall + + # If the symbol is a RoutineSymbol or IntrinsicSymbol we don't read + # the symbol. + if isinstance(self.symbol, (RoutineSymbol, IntrinsicSymbol)): + return False + parent = self.parent if isinstance(parent, Assignment): if parent.lhs is self: @@ -235,7 +246,7 @@ def datatype(self): return super().datatype return self.symbol.datatype - def previous_accesses(self): + def previous_accesses(self) -> list[Node]: ''' :returns: the nodes accessing the same symbol directly before this reference. It can be multiple nodes if the control flow @@ -249,7 +260,7 @@ def previous_accesses(self): sig = self.get_signature_and_indices()[0] return chain.find_backward_accesses()[sig] - def next_accesses(self): + def next_accesses(self) -> list[Node]: ''' :returns: the nodes accessing the same symbol directly after this reference. It can be multiple nodes if the control flow diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index 8998216f76..b3994fa427 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -44,6 +44,7 @@ Goto_Stmt, ) +from psyclone.core import Signature from psyclone.errors import InternalError from psyclone.psyir.nodes import ( Assignment, @@ -103,14 +104,19 @@ def __init__( f"DefinitionUseChain must be a list of References " f"but found '{type(ref).__name__}' in the list." ) - # We need all the references to have the same ancestor Statement. - parent = references[0].ancestor(Statement) - for ref in references: - if ref.ancestor(Statement) is not parent: - raise InternalError( - "All references provided into a DefinitionUseChain " - "must have the same parent." - ) + # We need all the references to have the same ancestor Schedule and + # belong to the same child of the Schedule. + # Skip this check if we only have 1 input. + if len(references) > 1: + parent = references[0].ancestor(Schedule) + parent_path = references[0].path_from(parent)[0] + for ref in references: + if (ref.ancestor(Schedule) is not parent or + ref.path_from(parent)[0] != parent_path): + raise InternalError( + "All references provided into a DefinitionUseChain " + "must have the same parent in the ancestor Schedule." + ) # Make a copy of the list so we can modify it. self._references = [ref for ref in references] # Store the absolute positions and signatures for later. @@ -213,9 +219,9 @@ def is_basic_block(self) -> bool: return False return True - def find_forward_accesses(self) -> list[Node]: + def find_forward_accesses(self) -> dict[Signature, list[Node]]: """ - Find all the forward accesses for the reference defined in this + Find all the forward accesses for the references defined in this DefinitionUseChain. Forward accesses are all of the References or Calls that read or write to the symbol of the reference up to the point that a @@ -224,7 +230,7 @@ def find_forward_accesses(self) -> list[Node]: that occur inside control flow do not end the forward access chain. - :returns: the forward accesses of the reference given to this + :returns: the forward accesses of the references given to this DefinitionUseChain """ # Compute the abs position caches as we'll use these a lot. @@ -425,9 +431,6 @@ def find_forward_accesses(self) -> list[Node]: if ancestor is not None: # If any of the references is the lhs then we can ignore the # RHS. - # FIXME original if statement is - # "ancestor.lhs is self._reference" - # FIXME Is that true? if any([ancestor.lhs is ref for ref in self._references]): # Find the last node in the assignment last_node = ancestor.walk(Node)[-1] @@ -953,7 +956,7 @@ def _compute_backward_uses(self, basic_block_list: list[Node]): if defs_out[sig] is not None: self._defsout[sig].append(defs_out[sig]) - def find_backward_accesses(self) -> list[Node]: + def find_backward_accesses(self) -> dict[Signature, list[Node]]: """ Find all the backward accesses for the reference defined in this DefinitionUseChain. @@ -1167,8 +1170,8 @@ def find_backward_accesses(self) -> list[Node]: if ancestor is not None: # If we get here to check the start part of a loop we need # to handle this differently. - # If the reference is the lhs then we can ignore the RHS. - if all([ancestor.lhs is not ref for ref in self._references]): + # If any reference is the lhs then we can ignore the RHS. + if any([ancestor.lhs is ref for ref in self._references]): pass elif ancestor.rhs is self._scope[0] and len(self._scope) == 1: # If the ancestor RHS is the scope of this chain then we diff --git a/src/psyclone/tests/psyir/nodes/assignment_test.py b/src/psyclone/tests/psyir/nodes/assignment_test.py index e12650813f..97215aef4c 100644 --- a/src/psyclone/tests/psyir/nodes/assignment_test.py +++ b/src/psyclone/tests/psyir/nodes/assignment_test.py @@ -379,3 +379,123 @@ def test_reference_accesses(fortran_reader): assert accesses[Signature('g')][0].access_type == AccessType.READ assert accesses[Signature('g')][1].node == assigns[1].lhs assert accesses[Signature('g')][1].access_type == AccessType.WRITE + + +def test_next_accesses(fortran_reader): + ''' Test the assignment.next_accesses function.''' + # Start with a basic test where the assignment just has a single + # Reference. + psyir = fortran_reader.psyir_from_source( + """program test + integer :: a + a = 1 + a = a + 1 + end program""" + ) + assigns = psyir.walk(Assignment) + + reaches = assigns[0].next_accesses() + sig = assigns[0].lhs.get_signature_and_indices()[0] + assert len(reaches) == 1 + assert len(reaches[sig]) == 2 + assert reaches[sig][0] is assigns[1].rhs.children[0] + assert reaches[sig][1] is assigns[1].lhs + + # Next test multiple References to different symbols in the + # Assignment + psyir = fortran_reader.psyir_from_source( + """program test + integer :: a + integer :: b + a = b + a = a + 1 + b = 2 + end program""" + ) + assigns = psyir.walk(Assignment) + reaches = assigns[0].next_accesses() + a_sig = assigns[0].lhs.get_signature_and_indices()[0] + b_sig = assigns[0].rhs.get_signature_and_indices()[0] + assert len(reaches) == 2 + assert len(reaches[a_sig]) == 2 + assert reaches[a_sig][0] is assigns[1].rhs.children[0] + assert reaches[a_sig][1] is assigns[1].lhs + assert len(reaches[b_sig]) == 1 + assert reaches[b_sig][0] is assigns[2].lhs + + # Test References inside an inquiry function are ignored + psyir = fortran_reader.psyir_from_source( + """program test + integer :: a + integer :: b + integer, dimension(100) :: c + a = b + size(c) + c = 1 + end program""" + ) + assigns = psyir.walk(Assignment) + reaches = assigns[0].next_accesses() + a_sig = assigns[0].lhs.get_signature_and_indices()[0] + b_sig = assigns[0].rhs.children[0].get_signature_and_indices()[0] + assert len(reaches) == 2 + assert len(reaches[a_sig]) == 0 + assert len(reaches[b_sig]) == 0 + + +def test_previous_accesses(fortran_reader): + ''' Test the assignment.previous_accesses function.''' + # Start with a basic test where the assignment just has a single + # Reference. + psyir = fortran_reader.psyir_from_source( + """program test + integer :: a + a = a + 1 + a = 1 + end program""" + ) + assigns = psyir.walk(Assignment) + + reaches = assigns[1].previous_accesses() + sig = assigns[0].lhs.get_signature_and_indices()[0] + assert len(reaches) == 1 + assert len(reaches[sig]) == 1 + assert reaches[sig][0] is assigns[0].lhs + + # Next test multiple References to different symbols in the + # Assignment + psyir = fortran_reader.psyir_from_source( + """program test + integer :: a + integer :: b + b = 2 + a = a + 1 + a = b + end program""" + ) + assigns = psyir.walk(Assignment) + reaches = assigns[2].previous_accesses() + a_sig = assigns[2].lhs.get_signature_and_indices()[0] + b_sig = assigns[2].rhs.get_signature_and_indices()[0] + assert len(reaches) == 2 + assert len(reaches[a_sig]) == 1 + assert reaches[a_sig][0] is assigns[1].lhs + assert len(reaches[b_sig]) == 1 + assert reaches[b_sig][0] is assigns[0].lhs + + # Test References inside an inquiry function are ignored + psyir = fortran_reader.psyir_from_source( + """program test + integer :: a + integer :: b + integer, dimension(100) :: c + c = 1 + a = b + size(c) + end program""" + ) + assigns = psyir.walk(Assignment) + reaches = assigns[1].previous_accesses() + a_sig = assigns[1].lhs.get_signature_and_indices()[0] + b_sig = assigns[1].rhs.children[0].get_signature_and_indices()[0] + assert len(reaches) == 2 + assert len(reaches[a_sig]) == 0 + assert len(reaches[b_sig]) == 0 diff --git a/src/psyclone/tests/psyir/nodes/reference_test.py b/src/psyclone/tests/psyir/nodes/reference_test.py index 9658bce946..fc77e433c7 100644 --- a/src/psyclone/tests/psyir/nodes/reference_test.py +++ b/src/psyclone/tests/psyir/nodes/reference_test.py @@ -565,6 +565,10 @@ def test_reference_is_read(fortran_reader): assert references[1].is_read assert references[3].symbol.name == "c" assert references[3].is_read + + # Routine or Intrinsic Symbols are not read. + assert references[5].symbol.name == "LBOUND" + assert not references[5].is_read # For the lbound, d should be an inquiry (so not a read) but # x should be a read assert references[6].symbol.name == "d" diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py index f479c6aeb2..3074cdda8a 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py @@ -145,10 +145,20 @@ def test_definition_use_chain_init_and_properties(fortran_reader): assert ("The 'references' argument passed into a DefinitionUseChain must " "be a list of References but found 'str' in the list." in str(excinfo.value)) + + # Create a containing schedule. + code = """subroutine test + integer :: r1 + r1 = 1 + end subroutine + """ + psyir = fortran_reader.psyir_from_source(code) + r1 = psyir.walk(Assignment)[0].lhs with pytest.raises(InternalError) as excinfo: duc = DefinitionUseChain([r1, Reference(sym)]) assert ("All references provided into a DefinitionUseChain " - "must have the same parent." in str(excinfo.value)) + "must have the same parent in the ancestor Schedule." + in str(excinfo.value)) def test_definition_use_chain_is_basic_block(fortran_reader): From 4e72cac64119246d935b565de0df3da51179b871 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Fri, 10 Apr 2026 15:34:06 +0100 Subject: [PATCH 12/22] Fixed bug in ancestor assignment for searching backwards --- src/psyclone/psyir/tools/definition_use_chains.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index b3994fa427..edd0ac519d 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -1170,8 +1170,9 @@ def find_backward_accesses(self) -> dict[Signature, list[Node]]: if ancestor is not None: # If we get here to check the start part of a loop we need # to handle this differently. - # If any reference is the lhs then we can ignore the RHS. - if any([ancestor.lhs is ref for ref in self._references]): + # If no reference is the lhs then we can skip the rhs when + # searching backwards.. + if all([ancestor.lhs is not ref for ref in self._references]): pass elif ancestor.rhs is self._scope[0] and len(self._scope) == 1: # If the ancestor RHS is the scope of this chain then we @@ -1179,9 +1180,10 @@ def find_backward_accesses(self) -> dict[Signature, list[Node]]: pass else: # Add the rhs as a potential basic block with different - # start and stop positions. + # start and stop positions. This only needs to be computed + # for the Reference on the LHS of the ancestor assignment. chain = DefinitionUseChain( - [ref for ref in self._references], + [ancestor.lhs], [ancestor.rhs], start_point=ancestor.rhs.abs_position, stop_point=sys.maxsize, From 0194db814a7e192d7a3b8c552a92d87d6c3daa02 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Fri, 10 Apr 2026 15:57:39 +0100 Subject: [PATCH 13/22] Fix remaining FIXME --- src/psyclone/psyir/tools/definition_use_chains.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index edd0ac519d..c52532cddc 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -899,12 +899,13 @@ def _compute_backward_uses(self, basic_block_list: list[Node]): # Check if the RHS contains the self._references. # Can't use in since equality is not what we want # here. + # We also only stop if the stop_point of the chain + # is in the assignment's rhs. found = False for ref in assign.rhs.walk(Reference): if ( any([ref is ref2 for ref2 in self._references]) - # FIXME What does this and check? and self._stop_point == ref.abs_position ): found = True From 7677af01931cd872372b46940441f2eb83c15d42 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Mon, 13 Apr 2026 15:11:35 +0100 Subject: [PATCH 14/22] First identity test for DUC versions. This is still in development but shows some easy cases for forward accesses --- .../psyir/tools/definition_use_chains.py | 40 ++++--- ...nition_use_chains_multiref_forward_test.py | 102 ++++++++++++++++++ 2 files changed, 128 insertions(+), 14 deletions(-) create mode 100644 src/psyclone/tests/psyir/tools/definition_use_chains_multiref_forward_test.py diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index c52532cddc..dc7aa6bf39 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -109,14 +109,18 @@ def __init__( # Skip this check if we only have 1 input. if len(references) > 1: parent = references[0].ancestor(Schedule) - parent_path = references[0].path_from(parent)[0] - for ref in references: - if (ref.ancestor(Schedule) is not parent or - ref.path_from(parent)[0] != parent_path): - raise InternalError( - "All references provided into a DefinitionUseChain " - "must have the same parent in the ancestor Schedule." - ) + # Skip this check for detached nodes, since we get copies + # provided to the recursive calls. + if parent: + parent_path = references[0].path_from(parent)[0] + for ref in references: + if (ref.ancestor(Schedule) is not parent or + ref.path_from(parent)[0] != parent_path): + raise InternalError( + "All references provided into a " + "DefinitionUseChain must have the same parent in " + "the ancestor Schedule." + ) # Make a copy of the list so we can modify it. self._references = [ref for ref in references] # Store the absolute positions and signatures for later. @@ -326,14 +330,17 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: self._start_point = last_node.abs_position else: # Add the lhs as a potential basic block with - # different start and stop positions. + # different start and stop positions, but don't + # include the lhs if the lhs is present. chain = DefinitionUseChain( - [ref.copy() for ref in self._references], + [ref.copy() for ref in self._references if + ref != ancestor.lhs], [ancestor.lhs], start_point=ancestor.lhs.abs_position - 1, stop_point=ancestor.lhs.abs_position + 1, ) - control_flow_nodes.append(None) + index = len(chains) + control_flow_nodes.insert(index, None) chains.append(chain) # N.B. For now this assumes that for an expression # b = a * a, that next_access to the first Reference @@ -353,6 +360,7 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: stop_point=self._stop_point, ) chains.append(chain) + for i, chain in enumerate(chains): # Compute the defsout, killed and reaches for the block. chain.find_forward_accesses() @@ -370,9 +378,12 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: self._reaches[sig].append(ref) # If we have a defsout in the chain then we can stop for # that reference as we will never get past the write - # as its not conditional. + # as its not conditional. Since we don't always include + # the LHS of an assignment into the chain we skip the + # signature if its not present in the chain defsout dict. for i, sig in enumerate(self._reference_signatures): - if len(chain.defsout[sig]) > 0: + if (sig in chain.defsout and + len(chain.defsout[sig]) > 0): self._references.pop(i) self._reference_signatures.pop(i) # If we have found an end point for all references then @@ -1084,7 +1095,8 @@ def find_backward_accesses(self) -> dict[Signature, list[Node]]: start_point=ancestor.rhs.abs_position, stop_point=end.abs_position, ) - control_flow_nodes.append(None) + index = len(chains) + control_flow_nodes.insert(index, None) chains.append(chain) # N.B. For now this assumes that for an expression # b = a * a, that next_access to the first Reference diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_forward_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_forward_test.py new file mode 100644 index 0000000000..a6f4dd9126 --- /dev/null +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_forward_test.py @@ -0,0 +1,102 @@ +# ----------------------------------------------------------------------------- +# BSD 3-Clause License +# +# Copyright (c) 2026, Science and Technology Facilities Council. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- +# Author: A. B. G. Chalk, STFC Daresbury Lab +# ----------------------------------------------------------------------------- +'''This module contains tests to check that the DefinitionUseChains return +the same result for multiple inputs vs those inputs individually.''' + +import pytest +from psyclone.psyir.nodes import ( + Assignment, + Reference +) +from psyclone.psyir.tools.definition_use_chains import DefinitionUseChain + + +@pytest.mark.parametrize("code", [ + """subroutine test + integer :: a, b + a = b + b = 2 + a + a = 2 + end subroutine test + """, + """subroutine test + integer :: a, b, i + do i = 1, 100 + a = b + b = 3 + end do + end subroutine test + """, + """subroutine test + integer :: a, b + logical :: x + a = b + if (x) then + a = b + a + else + b = a * b + end if + a = 1 + b = 1 + end subroutine test + """ + ]) +def test_duc_forward_equivalence(code, fortran_reader): + '''TODO''' + + psyir = fortran_reader.psyir_from_source(code) + + assign = psyir.walk(Assignment)[0] + + all_refs = assign.walk(Reference) + lhs_ref = assign.lhs + rhs_ref = assign.rhs + duc1 = DefinitionUseChain(all_refs) + duc2 = DefinitionUseChain([lhs_ref]) + duc3 = DefinitionUseChain([rhs_ref]) + res1 = duc1.find_forward_accesses() + res2 = duc2.find_forward_accesses() + res3 = duc3.find_forward_accesses() + + lhs_sig = lhs_ref.get_signature_and_indices()[0] + rhs_sig = rhs_ref.get_signature_and_indices()[0] + + assert len(res2[lhs_sig]) == len(res1[lhs_sig]) + for i, result in enumerate(res2[lhs_sig]): + assert result is res1[lhs_sig][i] + assert len(res3[rhs_sig]) == len(res1[rhs_sig]) + for i, result in enumerate(res3[rhs_sig]): + assert result is res1[rhs_sig][i] From 0fbe5f99a8010a1f5faf2373098388e4d49d87a6 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Mon, 13 Apr 2026 16:11:52 +0100 Subject: [PATCH 15/22] Fixed forward accesses issue. Still need to check it going backwards --- .../psyir/tools/definition_use_chains.py | 54 ++++++++++++------- ...nition_use_chains_multiref_forward_test.py | 11 +++- 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index dc7aa6bf39..77f9afa153 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -369,13 +369,18 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: if cfn is None: # We're outside a control flow region, updating the reaches # here is to find all the reached nodes. + # Some signatures may already have been removed by being + # killed, so we only add those if they've not already been + # killed. for sig in chain._reaches: - for ref in chain._reaches[sig]: - # Add unique references to reaches. Since we're not - # in a control flow region, we can't have added - # these references into the reaches array yet so - # they're guaranteed to be unique. - self._reaches[sig].append(ref) + if sig in self._reference_signatures: + for ref in chain._reaches[sig]: + # Add unique references to reaches. Since we're + # not in a control flow region, we can't have + # added these references into the reaches + # array yet so they're guaranteed to be + # unique. + self._reaches[sig].append(ref) # If we have a defsout in the chain then we can stop for # that reference as we will never get past the write # as its not conditional. Since we don't always include @@ -386,6 +391,10 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: len(chain.defsout[sig]) > 0): self._references.pop(i) self._reference_signatures.pop(i) + # Make sure we propagate the defsout updates + # to ensure we don't go past it for other + # accesses above. + self._defsout[sig].extend(chain.defsout[sig]) # If we have found an end point for all references then # we can finish. if len(self._references) == 0: @@ -424,18 +433,22 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: return self._reaches for sig in chain._reaches: - for ref in chain._reaches[sig]: - # Add unique references to reaches. Since we could - # have multiple references to the same symbol in - # the input they're not unique, so we need to - # check for uniqueness. As nodes can be == but - # not the same object, this has to be done - # using a loop and `is`. - for ref2 in self._reaches[sig]: - if ref2 is ref: - break - else: - self._reaches[sig].append(ref) + # Not all signatures are still being searched for, + # some have already been killed by previous accesses + # so we should skip them. + if sig in self._reference_signatures: + for ref in chain._reaches[sig]: + # Add unique references to reaches. Since we + # could have multiple references to the same + # symbol in the input they're not unique, so + # we need to check for uniqueness. As nodes + # can be == but not the same object, this + # has to be done using a loop and `is`. + for ref2 in self._reaches[sig]: + if ref2 is ref: + break + else: + self._reaches[sig].append(ref) else: # Check if there is an ancestor Assignment. ancestor = self._references[0].ancestor(Assignment) @@ -464,8 +477,9 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: # Find any forward_accesses in the lhs. chain.find_forward_accesses() for sig in chain._reaches: - for ref in chain._reaches[sig]: - self._reaches[sig].append(ref) + if sig in self._reference_signatures: + for ref in chain._reaches[sig]: + self._reaches[sig].append(ref) # If we have a defsout in the chain then we can stop as we # will never get past the write as its not conditional. for i, sig in enumerate(self._reference_signatures): diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_forward_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_forward_test.py index a6f4dd9126..2370f9cd72 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_forward_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_forward_test.py @@ -72,7 +72,16 @@ a = 1 b = 1 end subroutine test - """ + """, + """subroutine test + integer :: a, b + a = b + b = 2 + if (b > 1) then + a = 1 + end if + a = 2 * b + end subroutine test""" ]) def test_duc_forward_equivalence(code, fortran_reader): '''TODO''' From b84bba6dbdf50794bf4db9365dd86703ec32d0a4 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Tue, 14 Apr 2026 12:09:04 +0100 Subject: [PATCH 16/22] Fixed backward accesses and tests added --- .../psyir/tools/definition_use_chains.py | 70 ++++++++------ ...=> definition_use_chains_multiref_test.py} | 96 ++++++++++++++++++- 2 files changed, 131 insertions(+), 35 deletions(-) rename src/psyclone/tests/psyir/tools/{definition_use_chains_multiref_forward_test.py => definition_use_chains_multiref_test.py} (61%) diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index 77f9afa153..148aa300be 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -375,12 +375,14 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: for sig in chain._reaches: if sig in self._reference_signatures: for ref in chain._reaches[sig]: - # Add unique references to reaches. Since we're - # not in a control flow region, we can't have - # added these references into the reaches - # array yet so they're guaranteed to be - # unique. - self._reaches[sig].append(ref) + # Add unique references to reaches. We always + # need to check as the input can have multiple + # References to the same symbol. + for ref2 in self._reaches[sig]: + if ref2 is ref: + break + else: + self._reaches[sig].append(ref) # If we have a defsout in the chain then we can stop for # that reference as we will never get past the write # as its not conditional. Since we don't always include @@ -1009,7 +1011,7 @@ def find_backward_accesses(self) -> dict[Signature, list[Node]]: if self._stop_point is None: # Find the max abs position, as all of these are # contained in the same parent. - self._stop_point = max(list(self._references_abs_pos.values())) + self._stop_point = min(list(self._references_abs_pos.values())) # If there is no set stop point, then any Reference after # the start point can potentially be a forward access. if self._start_point is None: @@ -1104,7 +1106,7 @@ def find_backward_accesses(self) -> dict[Signature, list[Node]]: # Add the rhs as a potential basic block with # different start and stop positions. chain = DefinitionUseChain( - [ref.copy() for ref in self._references], + [ancestor.lhs.copy()], ancestor.rhs.children[:], start_point=ancestor.rhs.abs_position, stop_point=end.abs_position, @@ -1128,24 +1130,29 @@ def find_backward_accesses(self) -> dict[Signature, list[Node]]: # We're outside a control flow region, updating the reaches # here is to find all the reached nodes. for sig in chain._reaches: - for ref in chain._reaches[sig]: - # Add unique references to reaches. Since we're not - # in a control flow region, we can't have added - # these references into the reaches array yet so - # they're guaranteed to be unique. - found = False - for ref2 in self._reaches[sig]: - if ref is ref2: - found = True - break - if not found: - self._reaches[sig].append(ref) + if sig in self._reference_signatures: + for ref in chain._reaches[sig]: + # Add unique references to reaches. Since + # we're not in a control flow region, we + # can't have added these references into the + # reaches array yet so they're guaranteed + # to be unique. + found = False + for ref2 in self._reaches[sig]: + if ref is ref2: + found = True + break + if not found: + self._reaches[sig].append(ref) # If we have a defsout in the chain then we can stop as we # will never get past the write as its not conditional. for i, sig in enumerate(self._reference_signatures): - if len(chain.defsout[sig]) > 0: - self._references.pop(i) - self._reference_signatures.pop(i) + # Not all references are passed into all subchains. + if sig in chain.defsout: + if len(chain.defsout[sig]) > 0: + self._references.pop(i) + self._reference_signatures.pop(i) + self._defsout[sig].extend(chain.defsout[sig]) # If we have found an end point for all references then # we can stop. if len(self._references) == 0: @@ -1183,14 +1190,15 @@ def find_backward_accesses(self) -> dict[Signature, list[Node]]: self._stop_point = save_stop_position return self._reaches for sig in chain._reaches: - for ref in chain._reaches[sig]: - found = False - for ref2 in self._reaches[sig]: - if ref is ref2: - found = True - break - if not found: - self._reaches[sig].append(ref) + if sig in self._reference_signatures: + for ref in chain._reaches[sig]: + found = False + for ref2 in self._reaches[sig]: + if ref is ref2: + found = True + break + if not found: + self._reaches[sig].append(ref) else: # Check if there is an ancestor Assignment. ancestor = self._references[0].ancestor(Assignment) diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_forward_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_test.py similarity index 61% rename from src/psyclone/tests/psyir/tools/definition_use_chains_multiref_forward_test.py rename to src/psyclone/tests/psyir/tools/definition_use_chains_multiref_test.py index 2370f9cd72..c843b9685e 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_forward_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_test.py @@ -81,18 +81,29 @@ a = 1 end if a = 2 * b - end subroutine test""" + end subroutine test""", + """subroutine test + integer :: a, b + a = b + a + b = 2 + if (b > 1) then + a = 1 + end if + a = 2 * b + end subroutine test""", ]) def test_duc_forward_equivalence(code, fortran_reader): - '''TODO''' + '''Test the DUCs give the same results for multiple inputs as each of + the inputs individually. + ''' psyir = fortran_reader.psyir_from_source(code) assign = psyir.walk(Assignment)[0] all_refs = assign.walk(Reference) - lhs_ref = assign.lhs - rhs_ref = assign.rhs + lhs_ref = all_refs[0] # First Reference is lhs a + rhs_ref = all_refs[1] # Second Reference is always rhs b duc1 = DefinitionUseChain(all_refs) duc2 = DefinitionUseChain([lhs_ref]) duc3 = DefinitionUseChain([rhs_ref]) @@ -109,3 +120,80 @@ def test_duc_forward_equivalence(code, fortran_reader): assert len(res3[rhs_sig]) == len(res1[rhs_sig]) for i, result in enumerate(res3[rhs_sig]): assert result is res1[rhs_sig][i] + + +@pytest.mark.parametrize("code", [ + """subroutine test + integer :: a, b + a = 2 + b = 2 + a + a = b + end subroutine test + """, + """subroutine test + integer :: a, b, i + do i = 1, 100 + b = 3 + a = b + end do + end subroutine test + """, + """subroutine test + integer :: a, b + logical :: x + a = 1 + b = 1 + if (x) then + a = b + a + else + b = a * b + end if + a = b + end subroutine test + """, + """subroutine test + integer :: a, b + a = 2 * b + if (b > 1) then + a = 1 + end if + b = 2 + a = b + end subroutine test""", + """subroutine test + integer :: a, b + a = 2 * b + if (b > 1) then + a = 1 + end if + b = 2 + a = b + a + end subroutine test""", + ]) +def test_duc_backward_equivalence(code, fortran_reader): + '''Test the DUCs give the same results for multiple inputs as each + of the inputs individually. + ''' + psyir = fortran_reader.psyir_from_source(code) + + assign = psyir.walk(Assignment)[-1] + + all_refs = assign.walk(Reference) + lhs_ref = all_refs[0] # First Reference is lhs a + rhs_ref = all_refs[1] # Second Reference is always rhs b + duc1 = DefinitionUseChain(all_refs) + duc2 = DefinitionUseChain([lhs_ref]) + duc3 = DefinitionUseChain([rhs_ref]) + res1 = duc1.find_backward_accesses() + res2 = duc2.find_backward_accesses() + res3 = duc3.find_backward_accesses() + + lhs_sig = lhs_ref.get_signature_and_indices()[0] + rhs_sig = rhs_ref.get_signature_and_indices()[0] + + assert len(res2[lhs_sig]) == len(res1[lhs_sig]) + for i, result in enumerate(res2[lhs_sig]): + assert result is res1[lhs_sig][i] + assert len(res3[rhs_sig]) == len(res1[rhs_sig]) + for i, result in enumerate(res3[rhs_sig]): + assert result is res1[rhs_sig][i] From b566193cdbc1eb399190b952526ccbec00087d15 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Tue, 14 Apr 2026 13:17:02 +0100 Subject: [PATCH 17/22] fixed issue with ITs --- src/psyclone/psyir/tools/definition_use_chains.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index 148aa300be..f0e8d04726 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -328,13 +328,14 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: # Modify the start_point to only include the node after # this assignment. self._start_point = last_node.abs_position - else: + elif (all([ref is not ancestor.lhs for + ref in self._references])): # Add the lhs as a potential basic block with # different start and stop positions, but don't # include the lhs if the lhs is present. chain = DefinitionUseChain( [ref.copy() for ref in self._references if - ref != ancestor.lhs], + ref is not ancestor.lhs], [ancestor.lhs], start_point=ancestor.lhs.abs_position - 1, stop_point=ancestor.lhs.abs_position + 1, From d9c5cb54f2e94e77af280057a4b465f1d5396193 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Tue, 14 Apr 2026 18:07:33 +0100 Subject: [PATCH 18/22] Temporary addition --- ...tion_use_chains_forward_dependence_test.py | 368 ++++++++++++++++++ 1 file changed, 368 insertions(+) diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py index 3074cdda8a..ac30499465 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py @@ -1282,3 +1282,371 @@ def test_forward_accesses_multiple_elements(fortran_reader): assert len(reaches[sig1]) == 2 assert reaches[sig1][0] is assigns[2].rhs assert reaches[sig1][1] is assigns[3].lhs + + +def test_forward_accesses_if_else(fortran_reader): + """Test that for a reference in an if_body of an IfBlock we don't find + dependencies from the else_body""" + code = """subroutine x + integer :: a + logical :: test + + if(test) then + a = 2 + else + a = 4 + end if + end subroutine x""" + psyir = fortran_reader.psyir_from_source(code) + routine = psyir.walk(Routine)[0] + assigns = routine.walk(Assignment) + lhs = assigns[0].lhs + sig = lhs.get_signature_and_indices()[0] + chains = DefinitionUseChain([lhs]) + reaches = chains.find_forward_accesses()[sig] + + assert len(reaches) == 0 + + + + from psyclone.psyir.nodes import ( + Loop, Directive, Node, Reference, CodeBlock, Call, + Schedule, IntrinsicCall, StructureReference, IfBlock) + from psyclone.psyir.symbols import DataSymbol + from psyclone.psyir.transformations import ( + ArrayAssignment2LoopsTrans, HoistLoopBoundExprTrans, HoistLocalArraysTrans, + HoistTrans, InlineTrans, Maxval2LoopTrans, ProfileTrans, + OMPMinimiseSyncTrans, Reference2ArrayRangeTrans, + ScalarisationTrans, IncreaseRankLoopArraysTrans, MaximalRegionTrans, + TransformationError) + + def increase_rank_and_reorder_nemov5_loops(routine: Routine): + ''' This method increases the rank of temporary arrays used inside selected + loops (in order to parallelise the outer loop without overlapping them) + and then rearranges the outer loop next to the inner ones (in order to + collapse them), so that more parallelism can be leverage. This is useful + in GPU contexts, but it increases the memory footprint and may not be + beneficial for caching-architectures. + + :param routine: the target routine. + + ''' + irlatrans = IncreaseRankLoopArraysTrans() + + # Map of routines and arrays + selection = { + "dyn_zdf": ['zwd', 'zwi', 'zws'], + "tra_zdf_imp": ['zwd', 'zwi', 'zws', 'zwt'], + "tke_tke": ['zice_fra', 'zd_lw', 'zd_up', 'zdiag', 'zwlc2', 'zpelc', + 'imlc', 'zhlc', 'zus3'], + "tke_avn": ['zmxlm', 'zmxld'] + } + + if routine.name not in selection: + return + + for outer_loop in routine.walk(Loop, stop_type=Loop): + if outer_loop.variable.name == "jj": + # Increase the rank of the temporary arrays in this loop + irlatrans.apply(outer_loop, arrays=selection[routine.name]) + # Now reorder the code + for child in outer_loop.loop_body[:]: + # Move the contents of the jj loop outside it + outer_loop.parent.addchild(child.detach(), + index=outer_loop.position) + # Add a new jj loop around each inner loop that is not 'jn' + target_loop = [] + for inner_loop in child.walk(Loop, stop_type=Loop): + if inner_loop.variable.name != "jn": + target_loop.append(inner_loop) + else: + for next_loop in inner_loop.loop_body.walk( + Loop, stop_type=Loop): + target_loop.append(next_loop) + for inner_loop in target_loop: + if isinstance(inner_loop.loop_body[0], Loop): + inner_loop = inner_loop.loop_body[0] + inner_loop.replace_with( + Loop.create( + outer_loop.variable, + outer_loop.start_expr.copy(), + outer_loop.stop_expr.copy(), + outer_loop.step_expr.copy(), + children=[inner_loop.copy()] + ) + ) + # Remove the now empty jj loop + outer_loop.detach() + + def normalise_loops( + schedule, + hoist_local_arrays: bool = True, + convert_array_notation: bool = True, + loopify_array_intrinsics: bool = True, + convert_range_loops: bool = True, + scalarise_loops: bool = False, + increase_array_ranks: bool = False, + hoist_expressions: bool = True, + ): + ''' Normalise all loops in the given schedule so that they are in an + appropriate form for the Parallelisation transformations to analyse + them. + + :param schedule: the PSyIR Schedule to transform. + :type schedule: :py:class:`psyclone.psyir.nodes.node` + :param bool hoist_local_arrays: whether to hoist local arrays. + :param bool convert_array_notation: whether to convert array notation + to explicit loops. + :param bool loopify_array_intrinsics: whether to convert intrinsics that + operate on arrays to explicit loops (currently only maxval). + :param bool convert_range_loops: whether to convert ranges to explicit + loops. + :param scalarise_loops: whether to attempt to convert arrays to scalars + where possible, default is False. + :param increase_array_ranks: whether to increase the rank of selected + arrays. + :param hoist_expressions: whether to hoist bounds and loop invariant + statements out of the loop nest. + ''' + if hoist_local_arrays and schedule.name not in CONTAINS_STMT_FUNCTIONS: + # Apply the HoistLocalArraysTrans when possible, it cannot be applied + # to files with statement functions because it will attempt to put the + # allocate above it, which is not valid Fortran. + try: + HoistLocalArraysTrans().apply(schedule) + except TransformationError: + pass + + if convert_array_notation: + for reference in schedule.walk(Reference): + try: + Reference2ArrayRangeTrans().apply(reference) + except TransformationError: + pass + + if loopify_array_intrinsics: + for intr in schedule.walk(IntrinsicCall): + if intr.intrinsic.name == "MAXVAL": + try: + Maxval2LoopTrans().apply(intr, verbose=True) + except TransformationError as err: + print(err.value) + + if convert_range_loops: + # Convert all array implicit loops to explicit loops + explicit_loops = ArrayAssignment2LoopsTrans() + for assignment in schedule.walk(Assignment): + try: + explicit_loops.apply( + assignment, options={'verbose': True}) + except TransformationError: + pass + + if scalarise_loops: + # Apply scalarisation to every loop. Execute this in reverse order + # as sometimes we can scalarise earlier loops if following loops + # have already been scalarised. + loops = schedule.walk(Loop) + loops.reverse() + scalartrans = ScalarisationTrans() + for loop in loops: + scalartrans.apply(loop) + + if increase_array_ranks: + increase_rank_and_reorder_nemov5_loops(schedule) + + if hoist_expressions: + # First hoist all possible expressions + for loop in schedule.walk(Loop): + try: + HoistLoopBoundExprTrans().apply(loop) + except TransformationError: + pass + + # Hoist all possible assignments (in reverse order so the inner loop + # constants are hoisted all the way out if possible) + for loop in reversed(schedule.walk(Loop)): + for statement in list(loop.loop_body): + try: + HoistTrans().apply(statement) + except TransformationError: + pass + + # TODO #1928: In order to perform better on the GPU, nested loops with two + # sibling inner loops need to be fused or apply loop fission to the + # top level. This would allow the collapse clause to be applied. + + + + + + + + + + + + code = """ + SUBROUTINE tra_asm_inc( kt, Kbb, Kmm, pts, Krhs ) + !!---------------------------------------------------------------------- + !! *** ROUTINE tra_asm_inc *** + !! + !! ** Purpose : Apply the tracer (T and S) assimilation increments + !! + !! ** Method : Direct initialization or Incremental Analysis Updating + !! + !! ** Action : + !!---------------------------------------------------------------------- + INTEGER , INTENT(in ) :: kt ! Current time step + INTEGER , INTENT(in ) :: Kbb, Kmm, Krhs ! Time level indices + REAL(wp), DIMENSION(jpi,jpj,jpk,jpts,jpt), INTENT(inout) :: pts ! active tracers and RHS of tracer equation + ! + INTEGER :: ji, jj, jk + INTEGER :: it + REAL(wp) :: zincwgt ! IAU weight for current time step + REAL(wp), DIMENSION(:,:), ALLOCATABLE :: zfzptnz, zdep2d ! Freezing point values + REAL(wp), DIMENSION(jpi,jpj,jpk) :: zvalid_bv ! Mask representing Brunt-Vaisala (N2) checks used to reject T/S + ! increments + !!---------------------------------------------------------------------- + ! !-------------------------------------- + IF ( ln_asmiau ) THEN ! Incremental Analysis Updating + ! !-------------------------------------- + ! + IF ( ( kt >= nitiaustr_r ).AND.( kt <= nitiaufin_r ) ) THEN + ! + it = kt - nit000 + 1 + zincwgt = wgtiau(it) / rn_Dt ! IAU weight for the current time step + ! + IF( .NOT. l_istiled .OR. ntile == 1 ) THEN ! Do only on the first tile + IF(lwp) THEN + WRITE(numout,*) + WRITE(numout,*) 'tra_asm_inc : Tracer IAU at time step = ', kt,' with IAU weight = ', wgtiau(it) + WRITE(numout,*) '~~~~~~~~~~~~' + ENDIF + ENDIF + ! + IF( ln_temnofreeze ) ALLOCATE( zfzptnz(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0)), zdep2d(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0)) ) + ! + ! Call Brunt-Vaisala checks to reject T/S increments + zvalid_bv(:,:,:) = 1.0_wp + IF ( ln_bv_check ) CALL verify_incs_bv( wgtiau(it), Kmm, pts, zvalid_bv ) + ! + ! Update the tracer tendencies + DO jk = 1, jpkm1 + IF (ln_temnofreeze) THEN + ! Do not apply negative increments if the temperature will fall below freezing + DO jj = ntsj-( 0), ntej+( 0 ) ; DO ji = ntsi-( 0), ntei+( 0) + zdep2d(ji,jj) = ((gdept_1d(jk) ) *(1._wp+r3t(ji,jj,Kmm))) ! better solution: define an interface for eos_fzp when ((gdept_1d(jk) ) *(1._wp+r3t(ji,jj,Kmm))) is a scalar + END DO ; END DO + CALL eos_fzp( pts(:,:,jk,jp_sal,Kmm), zfzptnz(:,:), zdep2d(:,:), kbnd=0 ) + ! + WHERE(t_bkginc(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk) > 0.0_wp .OR. & + & pts(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk,jp_tem,Kmm) + pts(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk,jp_tem,Krhs) + t_bkginc(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk) * wgtiau(it) > zfzptnz(:,:) ) + pts(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk,jp_tem,Krhs) = pts(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk,jp_tem,Krhs) + t_bkginc(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk) * zvalid_bv(ji,jj,jk) * zincwgt + END WHERE + ELSE + DO jj = ntsj-( 0), ntej+( 0 ) ; DO ji = ntsi-( 0), ntei+( 0) + pts(ji,jj,jk,jp_tem,Krhs) = pts(ji,jj,jk,jp_tem,Krhs) + t_bkginc(ji,jj,jk) * zvalid_bv(ji,jj,jk) * zincwgt + END DO ; END DO + ENDIF + IF (ln_salfix) THEN + ! Do not apply negative increments if the salinity will fall below a specified + ! minimum value rn_salfixmin + WHERE(s_bkginc(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk) > 0.0_wp .OR. & + & pts(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk,jp_sal,Kmm) + pts(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk,jp_sal,Krhs) + s_bkginc(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk) * wgtiau(it) > rn_salfixmin ) + pts(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk,jp_sal,Krhs) = pts(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk,jp_sal,Krhs) + s_bkginc(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk) * zvalid_bv(ji,jj,jk) * zincwgt + END WHERE + ELSE + DO jj = ntsj-( 0), ntej+( 0 ) ; DO ji = ntsi-( 0), ntei+( 0) + pts(ji,jj,jk,jp_sal,Krhs) = pts(ji,jj,jk,jp_sal,Krhs) + s_bkginc(ji,jj,jk) * zvalid_bv(ji,jj,jk) * zincwgt + END DO ; END DO + ENDIF + END DO + ! + IF( ln_temnofreeze ) DEALLOCATE( zfzptnz, zdep2d ) + ! + ENDIF + ! + IF( .NOT. l_istiled .OR. ntile == nijtile ) THEN ! Do only on the last tile + IF ( kt == nitiaufin_r + 1 ) THEN ! For bias crcn to work + IF (ALLOCATED(t_bkginc)) DEALLOCATE( t_bkginc ) + IF (ALLOCATED(s_bkginc)) DEALLOCATE( s_bkginc ) + ENDIF + ENDIF + ! !-------------------------------------- + ELSEIF ( ln_asmdin ) THEN ! Direct Initialization + ! !-------------------------------------- + ! + IF ( kt == nitdin_r ) THEN + ! + l_1st_euler = .TRUE. ! Force Euler forward step + ! + ! Call Brunt-Vaisala checks to reject T/S increments + zvalid_bv(:,:,:) = 1.0_wp + IF ( ln_bv_check ) CALL verify_incs_bv( 1.0_wp, Kmm, pts, zvalid_bv ) + ! + ! Initialize the now fields with the background + increment + IF (ln_temnofreeze) THEN + ! Do not apply negative increments if the temperature will fall below freezing + ALLOCATE( zfzptnz(ntsi-(nn_hls):ntei+(nn_hls),ntsj-(nn_hls):ntej+(nn_hls)), zdep2d(ntsi-(nn_hls):ntei+(nn_hls),ntsj-(nn_hls):ntej+(nn_hls)) ) + ! + DO jk = 1, jpkm1 + DO jj = ntsj-( nn_hls), ntej+( nn_hls ) ; DO ji = ntsi-( nn_hls), ntei+( nn_hls) + zdep2d(ji,jj) = ((gdept_1d(jk) ) *(1._wp+r3t(ji,jj,Kmm))) ! better solution: define an interface for eos_fzp when ((gdept_1d(jk) ) *(1._wp+r3t(ji,jj,Kmm))) is a scalar + END DO ; END DO + CALL eos_fzp( pts(:,:,jk,jp_sal,Kmm), zfzptnz(:,:), zdep2d(:,:) ) + ! + WHERE( t_bkginc(:,:,jk) > 0.0_wp .OR. pts(:,:,jk,jp_tem,Kmm) + t_bkginc(:,:,jk) > zfzptnz(:,:) ) + pts(:,:,jk,jp_tem,Kmm) = t_bkg(:,:,jk) + t_bkginc(:,:,jk) * zvalid_bv(:,:,jk) + END WHERE + END DO + ! + DEALLOCATE( zfzptnz, zdep2d ) + ELSE + pts(:,:,:,jp_tem,Kmm) = t_bkg(:,:,:) + t_bkginc(:,:,:) * zvalid_bv(:,:,:) + ENDIF + IF (ln_salfix) THEN + ! Do not apply negative increments if the salinity will fall below a specified + ! minimum value rn_salfixmin + WHERE( s_bkginc(:,:,:) > 0.0_wp .OR. pts(:,:,:,jp_sal,Kmm) + s_bkginc(:,:,:) > rn_salfixmin ) + pts(:,:,:,jp_sal,Kmm) = s_bkg(:,:,:) + s_bkginc(:,:,:) * zvalid_bv(:,:,:) + END WHERE + ELSE + pts(:,:,:,jp_sal,Kmm) = s_bkg(:,:,:) + s_bkginc(:,:,:) * zvalid_bv(:,:,:) + ENDIF + + pts(:,:,:,:,Kbb) = pts(:,:,:,:,Kmm) ! Update before fields + CALL eos( pts, Kbb, rhd, rhop ) ! Before potential and in situ densities + + DEALLOCATE( t_bkginc ) + DEALLOCATE( s_bkginc ) + DEALLOCATE( t_bkg ) + DEALLOCATE( s_bkg ) + ENDIF + ! + ENDIF + ! Perhaps the following call should be in step + IF ( ln_sicinc ) CALL sic_asm_inc ( kt ) ! apply sea ice concentration increment + IF ( ln_sitinc ) CALL sit_asm_inc ( kt ) ! apply sea ice thickness increment + ! + END SUBROUTINE tra_asm_inc""" + psyir = fortran_reader.psyir_from_source(code) + normalise_loops( + psyir.walk(Routine)[0], + hoist_local_arrays=False, + convert_array_notation=True, + loopify_array_intrinsics=True, + convert_range_loops=True, + increase_array_ranks=True, + hoist_expressions=True + ) + references = psyir.walk(Reference) + res = None + for ref in references: + if "pts" in ref.parent.debug_string(): + print(ref.parent.debug_string()) + if "pts(widx1,widx2,jk,jp_tem,kmm) = t_bkg(LBOUND(t_bkg, dim=1) + widx1 - 1,LBOUND(t_bkg, dim=2) + widx2 - 1,jk) + t_bkginc(LBOUND(t_bkginc, dim=1) + widx1 - 1,LBOUND(t_bkginc, dim=2) + widx2 - 1,jk) * zvalid_bv(widx1,widx2,jk)" in ref.parent.debug_string(): + res = ref + break + print(res) + assert False From 94a45c125c1202f9869e6aba5e2010f16076b669 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Thu, 16 Apr 2026 13:56:29 +0100 Subject: [PATCH 19/22] Fix error in DUCs resulting in nested if/else failures --- .../psyir/tools/definition_use_chains.py | 2 +- ...tion_use_chains_forward_dependence_test.py | 402 ++---------------- 2 files changed, 42 insertions(+), 362 deletions(-) diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index f0e8d04726..754af2fd87 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -355,7 +355,7 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: if len(block) == 0: continue chain = DefinitionUseChain( - [ref.copy() for ref in self._references], + [ref for ref in self._references], block, start_point=self._start_point, stop_point=self._stop_point, diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py index ac30499465..79266f6334 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py @@ -1284,369 +1284,49 @@ def test_forward_accesses_multiple_elements(fortran_reader): assert reaches[sig1][1] is assigns[3].lhs -def test_forward_accesses_if_else(fortran_reader): - """Test that for a reference in an if_body of an IfBlock we don't find - dependencies from the else_body""" - code = """subroutine x - integer :: a - logical :: test - - if(test) then - a = 2 - else - a = 4 - end if - end subroutine x""" - psyir = fortran_reader.psyir_from_source(code) - routine = psyir.walk(Routine)[0] - assigns = routine.walk(Assignment) - lhs = assigns[0].lhs - sig = lhs.get_signature_and_indices()[0] - chains = DefinitionUseChain([lhs]) - reaches = chains.find_forward_accesses()[sig] - - assert len(reaches) == 0 - - - - from psyclone.psyir.nodes import ( - Loop, Directive, Node, Reference, CodeBlock, Call, - Schedule, IntrinsicCall, StructureReference, IfBlock) - from psyclone.psyir.symbols import DataSymbol - from psyclone.psyir.transformations import ( - ArrayAssignment2LoopsTrans, HoistLoopBoundExprTrans, HoistLocalArraysTrans, - HoistTrans, InlineTrans, Maxval2LoopTrans, ProfileTrans, - OMPMinimiseSyncTrans, Reference2ArrayRangeTrans, - ScalarisationTrans, IncreaseRankLoopArraysTrans, MaximalRegionTrans, - TransformationError) - - def increase_rank_and_reorder_nemov5_loops(routine: Routine): - ''' This method increases the rank of temporary arrays used inside selected - loops (in order to parallelise the outer loop without overlapping them) - and then rearranges the outer loop next to the inner ones (in order to - collapse them), so that more parallelism can be leverage. This is useful - in GPU contexts, but it increases the memory footprint and may not be - beneficial for caching-architectures. - - :param routine: the target routine. - - ''' - irlatrans = IncreaseRankLoopArraysTrans() - - # Map of routines and arrays - selection = { - "dyn_zdf": ['zwd', 'zwi', 'zws'], - "tra_zdf_imp": ['zwd', 'zwi', 'zws', 'zwt'], - "tke_tke": ['zice_fra', 'zd_lw', 'zd_up', 'zdiag', 'zwlc2', 'zpelc', - 'imlc', 'zhlc', 'zus3'], - "tke_avn": ['zmxlm', 'zmxld'] - } - - if routine.name not in selection: - return - - for outer_loop in routine.walk(Loop, stop_type=Loop): - if outer_loop.variable.name == "jj": - # Increase the rank of the temporary arrays in this loop - irlatrans.apply(outer_loop, arrays=selection[routine.name]) - # Now reorder the code - for child in outer_loop.loop_body[:]: - # Move the contents of the jj loop outside it - outer_loop.parent.addchild(child.detach(), - index=outer_loop.position) - # Add a new jj loop around each inner loop that is not 'jn' - target_loop = [] - for inner_loop in child.walk(Loop, stop_type=Loop): - if inner_loop.variable.name != "jn": - target_loop.append(inner_loop) - else: - for next_loop in inner_loop.loop_body.walk( - Loop, stop_type=Loop): - target_loop.append(next_loop) - for inner_loop in target_loop: - if isinstance(inner_loop.loop_body[0], Loop): - inner_loop = inner_loop.loop_body[0] - inner_loop.replace_with( - Loop.create( - outer_loop.variable, - outer_loop.start_expr.copy(), - outer_loop.stop_expr.copy(), - outer_loop.step_expr.copy(), - children=[inner_loop.copy()] - ) - ) - # Remove the now empty jj loop - outer_loop.detach() - - def normalise_loops( - schedule, - hoist_local_arrays: bool = True, - convert_array_notation: bool = True, - loopify_array_intrinsics: bool = True, - convert_range_loops: bool = True, - scalarise_loops: bool = False, - increase_array_ranks: bool = False, - hoist_expressions: bool = True, - ): - ''' Normalise all loops in the given schedule so that they are in an - appropriate form for the Parallelisation transformations to analyse - them. - - :param schedule: the PSyIR Schedule to transform. - :type schedule: :py:class:`psyclone.psyir.nodes.node` - :param bool hoist_local_arrays: whether to hoist local arrays. - :param bool convert_array_notation: whether to convert array notation - to explicit loops. - :param bool loopify_array_intrinsics: whether to convert intrinsics that - operate on arrays to explicit loops (currently only maxval). - :param bool convert_range_loops: whether to convert ranges to explicit - loops. - :param scalarise_loops: whether to attempt to convert arrays to scalars - where possible, default is False. - :param increase_array_ranks: whether to increase the rank of selected - arrays. - :param hoist_expressions: whether to hoist bounds and loop invariant - statements out of the loop nest. - ''' - if hoist_local_arrays and schedule.name not in CONTAINS_STMT_FUNCTIONS: - # Apply the HoistLocalArraysTrans when possible, it cannot be applied - # to files with statement functions because it will attempt to put the - # allocate above it, which is not valid Fortran. - try: - HoistLocalArraysTrans().apply(schedule) - except TransformationError: - pass - - if convert_array_notation: - for reference in schedule.walk(Reference): - try: - Reference2ArrayRangeTrans().apply(reference) - except TransformationError: - pass - - if loopify_array_intrinsics: - for intr in schedule.walk(IntrinsicCall): - if intr.intrinsic.name == "MAXVAL": - try: - Maxval2LoopTrans().apply(intr, verbose=True) - except TransformationError as err: - print(err.value) - - if convert_range_loops: - # Convert all array implicit loops to explicit loops - explicit_loops = ArrayAssignment2LoopsTrans() - for assignment in schedule.walk(Assignment): - try: - explicit_loops.apply( - assignment, options={'verbose': True}) - except TransformationError: - pass - - if scalarise_loops: - # Apply scalarisation to every loop. Execute this in reverse order - # as sometimes we can scalarise earlier loops if following loops - # have already been scalarised. - loops = schedule.walk(Loop) - loops.reverse() - scalartrans = ScalarisationTrans() - for loop in loops: - scalartrans.apply(loop) - - if increase_array_ranks: - increase_rank_and_reorder_nemov5_loops(schedule) - - if hoist_expressions: - # First hoist all possible expressions - for loop in schedule.walk(Loop): - try: - HoistLoopBoundExprTrans().apply(loop) - except TransformationError: - pass - - # Hoist all possible assignments (in reverse order so the inner loop - # constants are hoisted all the way out if possible) - for loop in reversed(schedule.walk(Loop)): - for statement in list(loop.loop_body): - try: - HoistTrans().apply(statement) - except TransformationError: - pass - - # TODO #1928: In order to perform better on the GPU, nested loops with two - # sibling inner loops need to be fused or apply loop fission to the - # top level. This would allow the collapse clause to be applied. - - - - - - - - +def test_if_else_behaviour(fortran_reader): + ''' Test to ensure that with nested if/else statements the DUC + don't find incorrect dependencies from else blocks.''' + code = """subroutine test + integer , dimension(100, 100, 100) :: some_array + logical :: tmp, tmp2, tmp3 + integer :: i, j, k + if(tmp) then + if(tmp2) then + do i = 1, 100 + do j = 1, 100 + do k = 1, 100 + if(tmp3 .or. some_array(1,4,9)) then + some_array(1,4,9) = i * k * j + end if + end do + end do + end do + else + do i = 1, 100 + do j = 1, 100 + do k = 1, 100 + some_array(i,k,j) = i + j + k + end do + end do + end do + end if + else if (.not. tmp) then + do i = 1, 100 + do j = 1, 100 + do k = 1, 100 + some_array(i,j,k) = i*k*j + end do + end do + end do + end if + end subroutine""" - code = """ - SUBROUTINE tra_asm_inc( kt, Kbb, Kmm, pts, Krhs ) - !!---------------------------------------------------------------------- - !! *** ROUTINE tra_asm_inc *** - !! - !! ** Purpose : Apply the tracer (T and S) assimilation increments - !! - !! ** Method : Direct initialization or Incremental Analysis Updating - !! - !! ** Action : - !!---------------------------------------------------------------------- - INTEGER , INTENT(in ) :: kt ! Current time step - INTEGER , INTENT(in ) :: Kbb, Kmm, Krhs ! Time level indices - REAL(wp), DIMENSION(jpi,jpj,jpk,jpts,jpt), INTENT(inout) :: pts ! active tracers and RHS of tracer equation - ! - INTEGER :: ji, jj, jk - INTEGER :: it - REAL(wp) :: zincwgt ! IAU weight for current time step - REAL(wp), DIMENSION(:,:), ALLOCATABLE :: zfzptnz, zdep2d ! Freezing point values - REAL(wp), DIMENSION(jpi,jpj,jpk) :: zvalid_bv ! Mask representing Brunt-Vaisala (N2) checks used to reject T/S - ! increments - !!---------------------------------------------------------------------- - ! !-------------------------------------- - IF ( ln_asmiau ) THEN ! Incremental Analysis Updating - ! !-------------------------------------- - ! - IF ( ( kt >= nitiaustr_r ).AND.( kt <= nitiaufin_r ) ) THEN - ! - it = kt - nit000 + 1 - zincwgt = wgtiau(it) / rn_Dt ! IAU weight for the current time step - ! - IF( .NOT. l_istiled .OR. ntile == 1 ) THEN ! Do only on the first tile - IF(lwp) THEN - WRITE(numout,*) - WRITE(numout,*) 'tra_asm_inc : Tracer IAU at time step = ', kt,' with IAU weight = ', wgtiau(it) - WRITE(numout,*) '~~~~~~~~~~~~' - ENDIF - ENDIF - ! - IF( ln_temnofreeze ) ALLOCATE( zfzptnz(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0)), zdep2d(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0)) ) - ! - ! Call Brunt-Vaisala checks to reject T/S increments - zvalid_bv(:,:,:) = 1.0_wp - IF ( ln_bv_check ) CALL verify_incs_bv( wgtiau(it), Kmm, pts, zvalid_bv ) - ! - ! Update the tracer tendencies - DO jk = 1, jpkm1 - IF (ln_temnofreeze) THEN - ! Do not apply negative increments if the temperature will fall below freezing - DO jj = ntsj-( 0), ntej+( 0 ) ; DO ji = ntsi-( 0), ntei+( 0) - zdep2d(ji,jj) = ((gdept_1d(jk) ) *(1._wp+r3t(ji,jj,Kmm))) ! better solution: define an interface for eos_fzp when ((gdept_1d(jk) ) *(1._wp+r3t(ji,jj,Kmm))) is a scalar - END DO ; END DO - CALL eos_fzp( pts(:,:,jk,jp_sal,Kmm), zfzptnz(:,:), zdep2d(:,:), kbnd=0 ) - ! - WHERE(t_bkginc(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk) > 0.0_wp .OR. & - & pts(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk,jp_tem,Kmm) + pts(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk,jp_tem,Krhs) + t_bkginc(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk) * wgtiau(it) > zfzptnz(:,:) ) - pts(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk,jp_tem,Krhs) = pts(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk,jp_tem,Krhs) + t_bkginc(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk) * zvalid_bv(ji,jj,jk) * zincwgt - END WHERE - ELSE - DO jj = ntsj-( 0), ntej+( 0 ) ; DO ji = ntsi-( 0), ntei+( 0) - pts(ji,jj,jk,jp_tem,Krhs) = pts(ji,jj,jk,jp_tem,Krhs) + t_bkginc(ji,jj,jk) * zvalid_bv(ji,jj,jk) * zincwgt - END DO ; END DO - ENDIF - IF (ln_salfix) THEN - ! Do not apply negative increments if the salinity will fall below a specified - ! minimum value rn_salfixmin - WHERE(s_bkginc(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk) > 0.0_wp .OR. & - & pts(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk,jp_sal,Kmm) + pts(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk,jp_sal,Krhs) + s_bkginc(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk) * wgtiau(it) > rn_salfixmin ) - pts(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk,jp_sal,Krhs) = pts(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk,jp_sal,Krhs) + s_bkginc(ntsi-(0):ntei+(0),ntsj-(0):ntej+(0),jk) * zvalid_bv(ji,jj,jk) * zincwgt - END WHERE - ELSE - DO jj = ntsj-( 0), ntej+( 0 ) ; DO ji = ntsi-( 0), ntei+( 0) - pts(ji,jj,jk,jp_sal,Krhs) = pts(ji,jj,jk,jp_sal,Krhs) + s_bkginc(ji,jj,jk) * zvalid_bv(ji,jj,jk) * zincwgt - END DO ; END DO - ENDIF - END DO - ! - IF( ln_temnofreeze ) DEALLOCATE( zfzptnz, zdep2d ) - ! - ENDIF - ! - IF( .NOT. l_istiled .OR. ntile == nijtile ) THEN ! Do only on the last tile - IF ( kt == nitiaufin_r + 1 ) THEN ! For bias crcn to work - IF (ALLOCATED(t_bkginc)) DEALLOCATE( t_bkginc ) - IF (ALLOCATED(s_bkginc)) DEALLOCATE( s_bkginc ) - ENDIF - ENDIF - ! !-------------------------------------- - ELSEIF ( ln_asmdin ) THEN ! Direct Initialization - ! !-------------------------------------- - ! - IF ( kt == nitdin_r ) THEN - ! - l_1st_euler = .TRUE. ! Force Euler forward step - ! - ! Call Brunt-Vaisala checks to reject T/S increments - zvalid_bv(:,:,:) = 1.0_wp - IF ( ln_bv_check ) CALL verify_incs_bv( 1.0_wp, Kmm, pts, zvalid_bv ) - ! - ! Initialize the now fields with the background + increment - IF (ln_temnofreeze) THEN - ! Do not apply negative increments if the temperature will fall below freezing - ALLOCATE( zfzptnz(ntsi-(nn_hls):ntei+(nn_hls),ntsj-(nn_hls):ntej+(nn_hls)), zdep2d(ntsi-(nn_hls):ntei+(nn_hls),ntsj-(nn_hls):ntej+(nn_hls)) ) - ! - DO jk = 1, jpkm1 - DO jj = ntsj-( nn_hls), ntej+( nn_hls ) ; DO ji = ntsi-( nn_hls), ntei+( nn_hls) - zdep2d(ji,jj) = ((gdept_1d(jk) ) *(1._wp+r3t(ji,jj,Kmm))) ! better solution: define an interface for eos_fzp when ((gdept_1d(jk) ) *(1._wp+r3t(ji,jj,Kmm))) is a scalar - END DO ; END DO - CALL eos_fzp( pts(:,:,jk,jp_sal,Kmm), zfzptnz(:,:), zdep2d(:,:) ) - ! - WHERE( t_bkginc(:,:,jk) > 0.0_wp .OR. pts(:,:,jk,jp_tem,Kmm) + t_bkginc(:,:,jk) > zfzptnz(:,:) ) - pts(:,:,jk,jp_tem,Kmm) = t_bkg(:,:,jk) + t_bkginc(:,:,jk) * zvalid_bv(:,:,jk) - END WHERE - END DO - ! - DEALLOCATE( zfzptnz, zdep2d ) - ELSE - pts(:,:,:,jp_tem,Kmm) = t_bkg(:,:,:) + t_bkginc(:,:,:) * zvalid_bv(:,:,:) - ENDIF - IF (ln_salfix) THEN - ! Do not apply negative increments if the salinity will fall below a specified - ! minimum value rn_salfixmin - WHERE( s_bkginc(:,:,:) > 0.0_wp .OR. pts(:,:,:,jp_sal,Kmm) + s_bkginc(:,:,:) > rn_salfixmin ) - pts(:,:,:,jp_sal,Kmm) = s_bkg(:,:,:) + s_bkginc(:,:,:) * zvalid_bv(:,:,:) - END WHERE - ELSE - pts(:,:,:,jp_sal,Kmm) = s_bkg(:,:,:) + s_bkginc(:,:,:) * zvalid_bv(:,:,:) - ENDIF - - pts(:,:,:,:,Kbb) = pts(:,:,:,:,Kmm) ! Update before fields - CALL eos( pts, Kbb, rhd, rhop ) ! Before potential and in situ densities - - DEALLOCATE( t_bkginc ) - DEALLOCATE( s_bkginc ) - DEALLOCATE( t_bkg ) - DEALLOCATE( s_bkg ) - ENDIF - ! - ENDIF - ! Perhaps the following call should be in step - IF ( ln_sicinc ) CALL sic_asm_inc ( kt ) ! apply sea ice concentration increment - IF ( ln_sitinc ) CALL sit_asm_inc ( kt ) ! apply sea ice thickness increment - ! - END SUBROUTINE tra_asm_inc""" psyir = fortran_reader.psyir_from_source(code) - normalise_loops( - psyir.walk(Routine)[0], - hoist_local_arrays=False, - convert_array_notation=True, - loopify_array_intrinsics=True, - convert_range_loops=True, - increase_array_ranks=True, - hoist_expressions=True - ) references = psyir.walk(Reference) - res = None - for ref in references: - if "pts" in ref.parent.debug_string(): - print(ref.parent.debug_string()) - if "pts(widx1,widx2,jk,jp_tem,kmm) = t_bkg(LBOUND(t_bkg, dim=1) + widx1 - 1,LBOUND(t_bkg, dim=2) + widx2 - 1,jk) + t_bkginc(LBOUND(t_bkginc, dim=1) + widx1 - 1,LBOUND(t_bkginc, dim=2) + widx2 - 1,jk) * zvalid_bv(widx1,widx2,jk)" in ref.parent.debug_string(): - res = ref - break - print(res) - assert False + next_accesses = references[4].next_accesses() + assert len(next_accesses) == 2 + assert next_accesses[0] is references[3] + assert next_accesses[1] is references[4] From 533fe1c09a460383d27dafc332b46cdb6f083975 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Mon, 20 Apr 2026 14:58:57 +0100 Subject: [PATCH 20/22] Changes for review --- src/psyclone/psyir/nodes/assignment.py | 16 ++--- src/psyclone/psyir/nodes/reference.py | 6 +- .../psyir/tools/definition_use_chains.py | 58 ++++++++++--------- ...tion_use_chains_forward_dependence_test.py | 6 +- 4 files changed, 44 insertions(+), 42 deletions(-) diff --git a/src/psyclone/psyir/nodes/assignment.py b/src/psyclone/psyir/nodes/assignment.py index 574756f12c..d66312d3ae 100644 --- a/src/psyclone/psyir/nodes/assignment.py +++ b/src/psyclone/psyir/nodes/assignment.py @@ -247,10 +247,10 @@ def is_literal_assignment(self): def previous_accesses(self) -> dict[Signature, list[Node]]: ''' - :returns: the nodes accessing the same symbols directly before this - after this. It can be multiple nodes for each symbol if - the control flow diverges and there are multiple - possible accesses. + :returns: the nodes containing the previous accesses of the symbols + accessed within this node. It can be multiple nodes for + each symbol if the control flow diverges and there are + multiple possible accesses. ''' # Find all of the read/write References in this assignment. refs = [] @@ -265,10 +265,10 @@ def previous_accesses(self) -> dict[Signature, list[Node]]: def next_accesses(self) -> dict[Signature, list[Node]]: ''' - :returns: the nodes accessing the same symbols directly after this - after this. It can be multiple nodes for each symbol if - the control flow diverges and there are multiple - possible accesses. + :returns: the nodes containing the next accesses of the symbols + accessed within this node. It can be multiple nodes for + each symbol if the control flow diverges and there are + multiple possible accesses. ''' # Find all of the read/write References in this assignment. refs = [] diff --git a/src/psyclone/psyir/nodes/reference.py b/src/psyclone/psyir/nodes/reference.py index 793d2574e5..718cae644d 100644 --- a/src/psyclone/psyir/nodes/reference.py +++ b/src/psyclone/psyir/nodes/reference.py @@ -251,12 +251,11 @@ def previous_accesses(self) -> list[Node]: :returns: the nodes accessing the same symbol directly before this reference. It can be multiple nodes if the control flow diverges and there are multiple possible accesses. - :rtype: List[:py:class:`psyclone.psyir.nodes.Node`] ''' # Avoid circular import # pylint: disable=import-outside-toplevel from psyclone.psyir.tools import DefinitionUseChain - chain = DefinitionUseChain([self]) + chain = DefinitionUseChain(self) sig = self.get_signature_and_indices()[0] return chain.find_backward_accesses()[sig] @@ -265,12 +264,11 @@ def next_accesses(self) -> list[Node]: :returns: the nodes accessing the same symbol directly after this reference. It can be multiple nodes if the control flow diverges and there are multiple possible accesses. - :rtype: List[:py:class:`psyclone.psyir.nodes.Node`] ''' # Avoid circular import # pylint: disable=import-outside-toplevel from psyclone.psyir.tools import DefinitionUseChain - chain = DefinitionUseChain([self]) + chain = DefinitionUseChain(self) sig = self.get_signature_and_indices()[0] return chain.find_forward_accesses()[sig] diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index 30e9d76ddc..26131a21b4 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -36,7 +36,7 @@ """This module contains the DefinitionUseChain class""" import sys -from typing import Iterable, Optional +from typing import Iterable, Optional, Union from fparser.two.Fortran2003 import ( Cycle_Stmt, @@ -86,16 +86,18 @@ class DefinitionUseChain: def __init__( self, - references: list[Reference], + references: Union[list[Reference], Reference], control_flow_region: Iterable[Node] = (), start_point: Optional[int] = None, stop_point: Optional[int] = None, ): + if isinstance(references, Reference): + references = [references] if not isinstance(references, list): raise TypeError( f"The 'references' argument passed into a DefinitionUseChain " - f"must be a list of References but found " - f"'{type(references).__name__}'." + f"must be a list of References or a single Reference but " + f"found '{type(references).__name__}'." ) for ref in references: if not isinstance(ref, Reference): @@ -112,17 +114,17 @@ def __init__( # Skip this check for detached nodes, since we get copies # provided to the recursive calls. if parent: - parent_path = references[0].path_from(parent)[0] + parent_idx = references[0].path_from(parent)[0] for ref in references: if (ref.ancestor(Schedule) is not parent or - ref.path_from(parent)[0] != parent_path): + ref.path_from(parent)[0] != parent_idx): raise InternalError( "All references provided into a " "DefinitionUseChain must have the same parent in " "the ancestor Schedule." ) - # Make a copy of the list so we can modify it. - self._references = [ref for ref in references] + # Make a shallow copy of the list so we can modify it. + self._references = references[:] # Store the absolute positions and signatures for later. self._reference_signatures = [] self._references_abs_pos = {} @@ -185,26 +187,28 @@ def __init__( @property def uses(self) -> dict[list[Node]]: """ - :returns: the lists of nodes using the value that the referenced - symbols has before it is reassigned. + :returns: a map holding, for each referenced Symbol, the list of nodes + that use the value that the Symbol had before it is + reassigned. """ return self._uses @property def defsout(self) -> dict[list[Node]]: """ - :returns: the lists of nodes that reach the end of the block without - being killed, and therefore can have dependencies outside - of this block. + :returns: a map holding, for each referenced Symbol, the list of nodes + whose values reach the end of the block without being killed, + and therefore can have dependencies outside of this block. """ return self._defsout @property def killed(self) -> dict[list[Node]]: """ - :returns: the lists of nodes that represent the last use of an assigned - variable. Calling next_access on any of these nodes will find - a write that reassigns it's value. + :returns: a map holding, for each reference Symbol, the list of nodes + that represent the last use of the assigned Symbol. Calling + next_access on any of these nodes will find a write that + reassigns the value of the Symbol inside this block. """ return self._killed @@ -334,8 +338,7 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: # different start and stop positions, but don't # include the lhs if the lhs is present. chain = DefinitionUseChain( - [ref.copy() for ref in self._references if - ref is not ancestor.lhs], + [ref.copy() for ref in self._references], [ancestor.lhs], start_point=ancestor.lhs.abs_position - 1, stop_point=ancestor.lhs.abs_position + 1, @@ -355,7 +358,7 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: if len(block) == 0: continue chain = DefinitionUseChain( - [ref for ref in self._references], + self._references[:], block, start_point=self._start_point, stop_point=self._stop_point, @@ -371,7 +374,7 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: # We're outside a control flow region, updating the reaches # here is to find all the reached nodes. # Some signatures may already have been removed by being - # killed, so we only add those if they've not already been + # killed, so we only add those that haven't already been # killed. for sig in chain._reaches: if sig in self._reference_signatures: @@ -413,7 +416,8 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: # or if block structures to see if we're guaranteed to # write to the symbol. # If the control flow node is a Loop we have to check - # if the variable is the same symbol as the _reference. + # if the variable is the same symbol as any of the + # References in _references. if isinstance(cfn, Loop): cfn_abs_pos = cfn.abs_position for i, ref in enumerate(self._references[:]): @@ -509,7 +513,7 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: self._reaches[sig].append(ref) else: # If this block killed any accesses, then the first element - # of the killed writes is the access access that we're + # of the killed writes is the access that we're # dependent with. self._reaches[sig].append(self.killed[sig][0]) @@ -605,7 +609,7 @@ def _compute_forward_uses(self, basic_block_list: list[Node]): # the call later as References. continue # For now just assume calls are bad if we have a - # non-local variable and we treat them as though they + # non-local variable: we treat them as though they # were a write. sig = self._reference_signatures[i] if defs_out[sig] is not None: @@ -1010,7 +1014,7 @@ def find_backward_accesses(self) -> dict[Signature, list[Node]]: # If there is no set start point, then we look for all # accesses after the Reference. if self._stop_point is None: - # Find the max abs position, as all of these are + # Find the min abs position, as all of these are # contained in the same parent. self._stop_point = min(list(self._references_abs_pos.values())) # If there is no set stop point, then any Reference after @@ -1035,7 +1039,7 @@ def find_backward_accesses(self) -> dict[Signature, list[Node]]: # Create a copy of the list as it can modify elements # in the list. chain = DefinitionUseChain( - [ref for ref in self._references], + self._references[:], block, start_point=self._start_point, stop_point=self._stop_point, @@ -1068,7 +1072,7 @@ def find_backward_accesses(self) -> dict[Signature, list[Node]]: Assignment ).abs_position else: - sub_start_point = max(list( + sub_start_point = min(list( self._references_abs_pos.values() )) # If we have a basic block with no children then skip it, @@ -1242,7 +1246,7 @@ def find_backward_accesses(self) -> dict[Signature, list[Node]]: self._reaches[sig].append(ref) else: # If this block killed any accesses, then the first element - # of the killed writes is the access access that we're + # of the killed writes is the access that we're # dependent with. self._reaches[sig].append(self.killed[sig][0]) diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py index 79266f6334..d1d7c7064f 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py @@ -106,10 +106,10 @@ def test_definition_use_chain_init_and_properties(fortran_reader): # Test exceptions when passed a non_list for various inputs. with pytest.raises(TypeError) as excinfo: - duc = DefinitionUseChain(a_1) + duc = DefinitionUseChain("a") assert ("The 'references' argument passed into a DefinitionUseChain " - "must be a list of References but found 'Reference'" - in str(excinfo.value)) + "must be a list of References or a single Reference but found " + "'str'" in str(excinfo.value)) with pytest.raises(TypeError) as excinfo: duc = DefinitionUseChain([a_1], control_flow_region=2) From dc5901f5ef8c13a6f17e8b732621c495313656aa Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Mon, 20 Apr 2026 15:03:25 +0100 Subject: [PATCH 21/22] linting --- src/psyclone/psyir/nodes/assignment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/psyclone/psyir/nodes/assignment.py b/src/psyclone/psyir/nodes/assignment.py index d66312d3ae..fc76c376d6 100644 --- a/src/psyclone/psyir/nodes/assignment.py +++ b/src/psyclone/psyir/nodes/assignment.py @@ -248,7 +248,7 @@ def is_literal_assignment(self): def previous_accesses(self) -> dict[Signature, list[Node]]: ''' :returns: the nodes containing the previous accesses of the symbols - accessed within this node. It can be multiple nodes for + accessed within this node. It can be multiple nodes for each symbol if the control flow diverges and there are multiple possible accesses. ''' @@ -266,7 +266,7 @@ def previous_accesses(self) -> dict[Signature, list[Node]]: def next_accesses(self) -> dict[Signature, list[Node]]: ''' :returns: the nodes containing the next accesses of the symbols - accessed within this node. It can be multiple nodes for + accessed within this node. It can be multiple nodes for each symbol if the control flow diverges and there are multiple possible accesses. ''' From ac8fb189511278d679d705fbeda92f1aa6db9652 Mon Sep 17 00:00:00 2001 From: LonelyCat124 <3043914+LonelyCat124@users.noreply.github.com.> Date: Tue, 5 May 2026 14:51:50 +0100 Subject: [PATCH 22/22] Changes from review --- .../psyir/tools/definition_use_chains.py | 17 +++- ...ion_use_chains_backward_dependence_test.py | 52 ++++++------- ...tion_use_chains_forward_dependence_test.py | 78 ++++++++++--------- .../definition_use_chains_multiref_test.py | 64 ++++++--------- 4 files changed, 105 insertions(+), 106 deletions(-) diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index 26131a21b4..9270797e5d 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -103,8 +103,9 @@ def __init__( if not isinstance(ref, Reference): raise TypeError( f"The 'references' argument passed into a " - f"DefinitionUseChain must be a list of References " - f"but found '{type(ref).__name__}' in the list." + f"DefinitionUseChain must be a Reference or " + f"list of References but found " + f"'{type(ref).__name__}'." ) # We need all the references to have the same ancestor Schedule and # belong to the same child of the Schedule. @@ -254,6 +255,10 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: if self._start_point is None: # Find the highest abs position, as all of these are # contained in the same parent. + # We start after the last of the provided references, as + # for a statement such as b = a + a we don't want to return + # any of the References to a if the second a Reference is provided + # as an input to the DUC. self._start_point = max(list(self._references_abs_pos.values())) # If there is no set stop point, then any Reference after # the start point can potentially be a forward access. @@ -344,6 +349,10 @@ def find_forward_accesses(self) -> dict[Signature, list[Node]]: stop_point=ancestor.lhs.abs_position + 1, ) index = len(chains) + # This chain is missed by the call to + # _find_basic_blocks, so we nede to add a None + # into the control_flow_nodes list at the correct + # place to keep behaviour correct. control_flow_nodes.insert(index, None) chains.append(chain) # N.B. For now this assumes that for an expression @@ -1016,6 +1025,10 @@ def find_backward_accesses(self) -> dict[Signature, list[Node]]: if self._stop_point is None: # Find the min abs position, as all of these are # contained in the same parent. + # We start before any of the provided references, as + # for a statement such as b = a + a we don't want to return + # any of the References to a if the first a Reference is provided + # as an input to the DUC. self._stop_point = min(list(self._references_abs_pos.values())) # If there is no set stop point, then any Reference after # the start point can potentially be a forward access. diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_backward_dependence_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_backward_dependence_test.py index 98c9ffc62b..840de59c47 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_backward_dependence_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_backward_dependence_test.py @@ -67,7 +67,7 @@ def test_definition_use_chain_compute_backward_uses(fortran_reader): assert a_3 is psyir.walk(Assignment)[1].rhs sig = a_3.get_signature_and_indices()[0] duc = DefinitionUseChain( - [a_3], control_flow_region=[routine] + a_3, control_flow_region=[routine] ) basic_block_list = routine.children[:] # Need to set the start point and stop points similar to what @@ -94,7 +94,7 @@ def test_definition_use_chain_compute_backward_uses(fortran_reader): a_3 = psyir.walk(Reference)[4] sig = a_3.get_signature_and_indices()[0] duc = DefinitionUseChain( - [a_3], control_flow_region=[routine] + a_3, control_flow_region=[routine] ) basic_block_list = routine.children[:] # Need to set the start point and stop points similar to what @@ -144,7 +144,7 @@ def test_definition_use_chain_find_backward_accesses_basic_example( ref = routine.walk(Assignment)[8].lhs sig = ref.get_signature_and_indices()[0] chains = DefinitionUseChain( - [ref], [routine] + ref, [routine] ) reaches = chains.find_backward_accesses()[sig] # We find 2 results @@ -157,7 +157,7 @@ def test_definition_use_chain_find_backward_accesses_basic_example( # Create use chain for c in b = c + d ref = routine.walk(Assignment)[5].rhs.children[0] sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref], [routine]) + chains = DefinitionUseChain(ref, [routine]) reaches = chains.find_backward_accesses()[sig] # We should find 2 results # C = d * a @@ -183,7 +183,7 @@ def test_definition_use_chain_find_backward_accesses_assignment( # Start chain from A = a * a ref = routine.walk(Assignment)[1].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] # We should find 3 results, both 3 references in # a = A * A @@ -217,7 +217,7 @@ def test_definition_use_chain_find_backward_accesses_ifelse_example( # Start the chain from b = A + d. ref = routine.walk(Assignment)[4].rhs.children[0] sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] # TODO #2760 For now the if statement doesn't kill the accesses, # even though it will always be written to. @@ -268,7 +268,7 @@ def test_definition_use_chain_find_backward_accesses_psy_data_node_example( # Start the chain from b = A + d. ref = routine.walk(Assignment)[4].rhs.children[0] sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] # Inside the ProfileRegion the DUC has to work as before, the 'a' has # 4 backwards accesses as shown in the previous test. @@ -297,9 +297,7 @@ def test_definition_use_chain_find_backward_accesses_loop_example( # Start the chain from A = a + i. ref = routine.walk(Assignment)[1].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain( - [ref] - ) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] # We should have 4? reaches # First b = A + 2 @@ -335,7 +333,7 @@ def test_definition_use_chain_find_backward_accesses_loop_example( # Start the chain from I = 1231. ref = routine.walk(Assignment)[2].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] # We should have 1 reaches # It should be the loop @@ -363,7 +361,7 @@ def test_definition_use_chain_find_backward_accesses_while_loop_example( # Start the chain from A = a + 3. ref = routine.children[2].loop_body.children[0].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 4 @@ -398,7 +396,7 @@ def test_definition_use_chain_backward_accesses_nested_loop_example( loops = routine.walk(WhileLoop) ref = loops[1].walk(Assignment)[0].rhs.children[1] sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] # TODO #2760 The backwards accesses should not continue past a = a + 3 as # to reach the b = b + a statement we must have passed through the @@ -429,7 +427,7 @@ def test_definition_use_chain_find_backward_accesses_structure_example( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[2].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Assignment)[0].lhs @@ -450,7 +448,7 @@ def test_definition_use_chain_find_backward_accesses_no_control_flow_example( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[0].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Assignment)[0].rhs.children[0] @@ -472,7 +470,7 @@ def test_definition_use_chain_find_backward_accesses_codeblock( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[1].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(CodeBlock)[0] @@ -495,7 +493,7 @@ def test_definition_use_chain_find_backward_accesses_codeblock_and_call_nlocal( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[0].rhs.children[0] sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 # Result is the argument of the call @@ -522,7 +520,7 @@ def test_definition_use_chain_find_backward_accesses_codeblock_and_call_cflow( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[0].rhs.children[0] sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 2 assert reaches[0] is routine.walk(Call)[1].children[1] @@ -547,7 +545,7 @@ def test_definition_use_chain_find_backward_accesses_codeblock_and_call_local( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[0].rhs.children[0] sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(CodeBlock)[0] @@ -570,7 +568,7 @@ def test_definition_use_chain_find_backward_accesses_call_and_codeblock_nlocal( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[0].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Call)[0] @@ -619,7 +617,7 @@ def test_definition_use_chains_exit_statement( # Start the chain from a = A +i. ref = routine.walk(Assignment)[1].rhs.children[0] sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] # We should have 2 reaches # First is A = a + i @@ -656,7 +654,7 @@ def test_definition_use_chains_cycle_statement( # Start the chain from a = A +i. ref = routine.walk(Assignment)[1].rhs.children[0] sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] # We should have 2 reaches # A = b * 4 @@ -690,7 +688,7 @@ def test_definition_use_chains_return_statement( # Start the chain from a = A +i. ref = routine.walk(Assignment)[1].rhs.children[0] sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] # We should have 2 reaches # A = b * 4 @@ -722,7 +720,7 @@ def test_definition_use_chains_backward_accesses_multiple_routines( routine = psyir.walk(Routine)[1] ref = routine.walk(Assignment)[0].rhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 0 @@ -749,7 +747,7 @@ def test_definition_use_chains_backward_accesses_nonassign_reference_in_loop( routine = psyir.walk(Routine)[0] ref = routine.walk(Call)[0].children[1] sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] # TODO #2760 The backwards accesses should not continue past a = a + i # when searching backwards in the loop, or to a = 1 @@ -782,7 +780,7 @@ def test_definition_use_chains_backward_accesses_empty_schedules( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[1].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 3 assert reaches[0] is routine.walk(Assignment)[1].rhs.children[1] @@ -809,7 +807,7 @@ def test_definition_use_chains_backward_accesses_inquiry_func( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[1].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 0 diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py index d1d7c7064f..c1108be2c4 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py @@ -143,7 +143,7 @@ def test_definition_use_chain_init_and_properties(fortran_reader): with pytest.raises(TypeError) as excinfo: duc = DefinitionUseChain(["a", "b"]) assert ("The 'references' argument passed into a DefinitionUseChain must " - "be a list of References but found 'str' in the list." + "be a Reference or list of References but found 'str'." in str(excinfo.value)) # Create a containing schedule. @@ -473,7 +473,7 @@ def test_definition_use_chain_find_basic_blocks_inside_loops(fortran_reader): aref = psyir.walk(ArrayReference)[0] assert aref.symbol.name == "ztmp" duc = DefinitionUseChain( - [aref], control_flow_region=[routine] + aref, control_flow_region=[routine] ) cfn, blocks = duc._find_basic_blocks(routine.walk(Loop)[0].children[:]) # The ifblock has to be in cfn twice, once with the contents of the if @@ -520,7 +520,7 @@ def test_definition_use_chain_find_forward_accesses_basic_example( ref = routine.children[0].children[1].children[0] sig = ref.get_signature_and_indices()[0] chains = DefinitionUseChain( - [ref], [routine] + ref, [routine] ) reaches = chains.find_forward_accesses() # We find 3 results @@ -535,7 +535,7 @@ def test_definition_use_chain_find_forward_accesses_basic_example( # Create use chain for d in d = c + 2.0 ref = routine.children[3].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref], [routine]) + chains = DefinitionUseChain(ref, [routine]) reaches = chains.find_forward_accesses() # We should find 2 results # c = D * a (Assignment 5) @@ -547,7 +547,7 @@ def test_definition_use_chain_find_forward_accesses_basic_example( # Create use chain for c in c = d * a (Assignment 5) ref = routine.walk(Assignment)[4].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref], [routine]) + chains = DefinitionUseChain(ref, [routine]) reaches = chains.find_forward_accesses() # 2 results: # b = C + d (Assignment 6) @@ -573,7 +573,7 @@ def test_definition_use_chain_find_forward_accesses_assignment( # Start chain from A = 1 ref = routine.children[0].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses() # We should find 3 results, all 3 references in # A = A * A @@ -607,7 +607,7 @@ def test_definition_use_chain_find_forward_accesses_ifelse_example( # Start the chain from a = 1. ref = routine.children[0].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses() # TODO #2760 For now the if statement doesn't kill the accesses, # even though it will always be written to. @@ -649,9 +649,7 @@ def test_definition_use_chain_find_forward_accesses_loop_example( # Start the chain from b = A +2. ref = routine.children[1].loop_body.children[1].rhs.children[0] sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain( - [ref] - ) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] # We should have 3 reaches # First two are A = A + i @@ -681,9 +679,7 @@ def test_definition_use_chain_find_forward_accesses_loop_example( # Start the chain from I = 1231. ref = routine.children[0].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain( - [ref] - ) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] # We should have 1 reaches # It should be the loop @@ -711,7 +707,7 @@ def test_definition_use_chain_find_forward_accesses_while_loop_example( # Start the chain from A = a + 3. ref = routine.walk(Assignment)[2].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 3 @@ -745,7 +741,7 @@ def test_definition_use_chain_forward_accesses_nested_loop_example( loops = routine.walk(WhileLoop) ref = loops[1].loop_body.children[0].rhs.children[1] sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] # Results should be A = A + 3 and the a < i condition assert reaches[0] is loops[0].condition.children[0] @@ -772,7 +768,7 @@ def test_definition_use_chain_find_forward_accesses_structure_example( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[0].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Assignment)[2].lhs @@ -793,7 +789,7 @@ def test_definition_use_chain_find_forward_accesses_no_control_flow_example( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[0].rhs.children[0] sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Assignment)[0].lhs @@ -816,7 +812,7 @@ def test_definition_use_chain_find_forward_accesses_no_control_flow_example2( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[0].rhs.children[0] sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Assignment)[0].lhs @@ -838,7 +834,7 @@ def test_definition_use_chain_find_forward_accesses_codeblock( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[0].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(CodeBlock)[0] @@ -861,7 +857,7 @@ def test_definition_use_chain_find_forward_accesses_codeblock_and_call_nlocal( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[0].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(CodeBlock)[0] @@ -887,7 +883,7 @@ def test_definition_use_chain_find_forward_accesses_codeblock_and_call_cflow( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[0].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 2 assert reaches[0] is routine.walk(CodeBlock)[0] @@ -912,7 +908,7 @@ def test_definition_use_chain_find_forward_accesses_codeblock_and_call_local( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[0].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(CodeBlock)[0] @@ -935,7 +931,7 @@ def test_definition_use_chain_find_forward_accesses_call_and_codeblock_nlocal( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[0].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Call)[0] @@ -984,7 +980,7 @@ def test_definition_use_chains_exit_statement( # Start the chain from A = a +i. ref = routine.walk(Assignment)[1].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] # We should have 3 reaches # First two are A = A + i @@ -1024,7 +1020,7 @@ def test_definition_use_chains_cycle_statement( # Start the chain from A = a +i. ref = routine.walk(Assignment)[1].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] # We should have 4 reaches # First two are A = A + i @@ -1063,7 +1059,7 @@ def test_definition_use_chains_return_statement( # Start the chain from A = a +i. ref = routine.walk(Assignment)[1].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] # We should have 4 reaches # First two are A = A + i @@ -1100,7 +1096,7 @@ def test_definition_use_chains_forward_accesses_multiple_routines( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[0].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 0 @@ -1128,7 +1124,7 @@ def test_definition_use_chains_forward_accesses_empty_schedules( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[0].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 3 assert reaches[0] is routine.walk(Assignment)[1].rhs.children[0] @@ -1155,7 +1151,7 @@ def test_definition_use_chains_backward_accesses_inquiry_func( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[0].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 0 @@ -1183,7 +1179,7 @@ def test_definition_use_chains_multiple_ancestor_loops( routine = psyir.walk(Routine)[0] ref = routine.walk(Assignment)[1].lhs sig = ref.get_signature_and_indices()[0] - chains = DefinitionUseChain([ref]) + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 3 assert reaches[0] is routine.walk(Assignment)[0].lhs @@ -1274,14 +1270,14 @@ def test_forward_accesses_multiple_elements(fortran_reader): chains = DefinitionUseChain([rhs.children[0], rhs.children[1]]) reaches = chains.find_forward_accesses() - sig0, _ = rhs.children[0].get_signature_and_indices() - sig1, _ = rhs.children[1].get_signature_and_indices() + jsig, _ = rhs.children[0].get_signature_and_indices() + ksig, _ = rhs.children[1].get_signature_and_indices() - assert len(reaches[sig0]) == 1 - assert reaches[sig0][0] is assigns[1].lhs - assert len(reaches[sig1]) == 2 - assert reaches[sig1][0] is assigns[2].rhs - assert reaches[sig1][1] is assigns[3].lhs + assert len(reaches[jsig]) == 1 + assert reaches[jsig][0] is assigns[1].lhs + assert len(reaches[ksig]) == 2 + assert reaches[ksig][0] is assigns[2].rhs + assert reaches[ksig][1] is assigns[3].lhs def test_if_else_behaviour(fortran_reader): @@ -1326,7 +1322,15 @@ def test_if_else_behaviour(fortran_reader): psyir = fortran_reader.psyir_from_source(code) references = psyir.walk(Reference) + # Find the next accesses of some_array(1,4,9) from + # the lhs of the first assignment. next_accesses = references[4].next_accesses() + + # The next accesses should be the some_array(1,4,9) access in + # if(tmp3 .or. some_array(1, 4, 9) and to the same reference. + # Both are due to the containing do k = 1, 100 loop. We musn't + # find any accesses to the some_array writes in the else_bodies + # of the containing IfBlock nodes assert len(next_accesses) == 2 assert next_accesses[0] is references[3] assert next_accesses[1] is references[4] diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_test.py index c843b9685e..a21b252403 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_test.py @@ -45,23 +45,18 @@ @pytest.mark.parametrize("code", [ - """subroutine test - integer :: a, b + """integer :: a, b a = b b = 2 + a a = 2 - end subroutine test """, - """subroutine test - integer :: a, b, i + """integer :: a, b, i do i = 1, 100 a = b b = 3 end do - end subroutine test """, - """subroutine test - integer :: a, b + """integer :: a, b logical :: x a = b if (x) then @@ -71,32 +66,27 @@ end if a = 1 b = 1 - end subroutine test """, - """subroutine test - integer :: a, b + """integer :: a, b a = b b = 2 if (b > 1) then a = 1 end if - a = 2 * b - end subroutine test""", - """subroutine test - integer :: a, b + a = 2 * b""", + """integer :: a, b a = b + a b = 2 if (b > 1) then a = 1 end if - a = 2 * b - end subroutine test""", + a = 2 * b""", ]) def test_duc_forward_equivalence(code, fortran_reader): '''Test the DUCs give the same results for multiple inputs as each of the inputs individually. ''' - + code = f"subroutine test\n{code}\nend subroutine test" psyir = fortran_reader.psyir_from_source(code) assign = psyir.walk(Assignment)[0] @@ -104,9 +94,11 @@ def test_duc_forward_equivalence(code, fortran_reader): all_refs = assign.walk(Reference) lhs_ref = all_refs[0] # First Reference is lhs a rhs_ref = all_refs[1] # Second Reference is always rhs b + assert lhs_ref.symbol.name == "a" + assert rhs_ref.symbol.name == "b" duc1 = DefinitionUseChain(all_refs) - duc2 = DefinitionUseChain([lhs_ref]) - duc3 = DefinitionUseChain([rhs_ref]) + duc2 = DefinitionUseChain(lhs_ref) + duc3 = DefinitionUseChain(rhs_ref) res1 = duc1.find_forward_accesses() res2 = duc2.find_forward_accesses() res3 = duc3.find_forward_accesses() @@ -123,23 +115,18 @@ def test_duc_forward_equivalence(code, fortran_reader): @pytest.mark.parametrize("code", [ - """subroutine test - integer :: a, b + """integer :: a, b a = 2 b = 2 + a a = b - end subroutine test """, - """subroutine test - integer :: a, b, i + """integer :: a, b, i do i = 1, 100 b = 3 a = b end do - end subroutine test """, - """subroutine test - integer :: a, b + """integer :: a, b logical :: x a = 1 b = 1 @@ -149,31 +136,27 @@ def test_duc_forward_equivalence(code, fortran_reader): b = a * b end if a = b - end subroutine test """, - """subroutine test - integer :: a, b + """integer :: a, b a = 2 * b if (b > 1) then a = 1 end if b = 2 - a = b - end subroutine test""", - """subroutine test - integer :: a, b + a = b""", + """integer :: a, b a = 2 * b if (b > 1) then a = 1 end if b = 2 - a = b + a - end subroutine test""", + a = b + a""", ]) def test_duc_backward_equivalence(code, fortran_reader): '''Test the DUCs give the same results for multiple inputs as each of the inputs individually. ''' + code = f"subroutine test\n{code}\nend subroutine test" psyir = fortran_reader.psyir_from_source(code) assign = psyir.walk(Assignment)[-1] @@ -181,16 +164,17 @@ def test_duc_backward_equivalence(code, fortran_reader): all_refs = assign.walk(Reference) lhs_ref = all_refs[0] # First Reference is lhs a rhs_ref = all_refs[1] # Second Reference is always rhs b + assert lhs_ref.symbol.name == "a" + assert rhs_ref.symbol.name == "b" duc1 = DefinitionUseChain(all_refs) - duc2 = DefinitionUseChain([lhs_ref]) - duc3 = DefinitionUseChain([rhs_ref]) + duc2 = DefinitionUseChain(lhs_ref) + duc3 = DefinitionUseChain(rhs_ref) res1 = duc1.find_backward_accesses() res2 = duc2.find_backward_accesses() res3 = duc3.find_backward_accesses() lhs_sig = lhs_ref.get_signature_and_indices()[0] rhs_sig = rhs_ref.get_signature_and_indices()[0] - assert len(res2[lhs_sig]) == len(res1[lhs_sig]) for i, result in enumerate(res2[lhs_sig]): assert result is res1[lhs_sig][i]