diff --git a/src/psyclone/psyir/nodes/assignment.py b/src/psyclone/psyir/nodes/assignment.py index a8409034d4..fc76c376d6 100644 --- a/src/psyclone/psyir/nodes/assignment.py +++ b/src/psyclone/psyir/nodes/assignment.py @@ -39,13 +39,14 @@ ''' This module contains the Assignment node implementation.''' -from psyclone.core import VariablesAccessMap, AccessType +from psyclone.core import VariablesAccessMap, AccessType, Signature from psyclone.errors import InternalError from psyclone.psyir.nodes.literal import Literal from psyclone.psyir.nodes.array_reference import ArrayReference from psyclone.psyir.nodes.datanode import DataNode from psyclone.psyir.nodes.intrinsic_call import ( IntrinsicCall, REDUCTION_INTRINSICS) +from psyclone.psyir.nodes.node import Node from psyclone.psyir.nodes.ranges import Range from psyclone.psyir.nodes.reference import Reference from psyclone.psyir.nodes.statement import Statement @@ -243,3 +244,39 @@ def is_literal_assignment(self): ''' return isinstance(self.rhs, Literal) + + def previous_accesses(self) -> dict[Signature, list[Node]]: + ''' + :returns: the nodes containing the previous accesses of the symbols + accessed within this node. It can be multiple nodes for + each symbol if the control flow diverges and there are + multiple possible accesses. + ''' + # Find all of the read/write References in this assignment. + refs = [] + for ref in self.walk(Reference): + if ref.is_read or ref.is_write: + refs.append(ref) + # Avoid circular import + # pylint: disable=import-outside-toplevel + from psyclone.psyir.tools import DefinitionUseChain + chain = DefinitionUseChain(refs) + return chain.find_backward_accesses() + + def next_accesses(self) -> dict[Signature, list[Node]]: + ''' + :returns: the nodes containing the next accesses of the symbols + accessed within this node. It can be multiple nodes for + each symbol if the control flow diverges and there are + multiple possible accesses. + ''' + # Find all of the read/write References in this assignment. + refs = [] + for ref in self.walk(Reference): + if ref.is_read or ref.is_write: + refs.append(ref) + # Avoid circular import + # pylint: disable=import-outside-toplevel + from psyclone.psyir.tools import DefinitionUseChain + chain = DefinitionUseChain(refs) + return chain.find_forward_accesses() diff --git a/src/psyclone/psyir/nodes/reference.py b/src/psyclone/psyir/nodes/reference.py index 6d70fdb560..718cae644d 100644 --- a/src/psyclone/psyir/nodes/reference.py +++ b/src/psyclone/psyir/nodes/reference.py @@ -45,7 +45,12 @@ # We cannot import from 'nodes' directly due to circular import from psyclone.psyir.nodes.datanode import DataNode from psyclone.psyir.nodes.node import Node -from psyclone.psyir.symbols import Symbol, AutomaticInterface +from psyclone.psyir.symbols import ( + Symbol, + AutomaticInterface, + RoutineSymbol, + IntrinsicSymbol +) from psyclone.psyir.symbols.datatypes import UnresolvedType @@ -95,6 +100,12 @@ def is_read(self): # pylint: disable=import-outside-toplevel from psyclone.psyir.nodes.assignment import Assignment from psyclone.psyir.nodes.intrinsic_call import IntrinsicCall + + # If the symbol is a RoutineSymbol or IntrinsicSymbol we don't read + # the symbol. + if isinstance(self.symbol, (RoutineSymbol, IntrinsicSymbol)): + return False + parent = self.parent if isinstance(parent, Assignment): if parent.lhs is self: @@ -235,31 +246,31 @@ def datatype(self): return super().datatype return self.symbol.datatype - def previous_accesses(self): + def previous_accesses(self) -> list[Node]: ''' :returns: the nodes accessing the same symbol directly before this reference. It can be multiple nodes if the control flow diverges and there are multiple possible accesses. - :rtype: List[:py:class:`psyclone.psyir.nodes.Node`] ''' # Avoid circular import # pylint: disable=import-outside-toplevel from psyclone.psyir.tools import DefinitionUseChain chain = DefinitionUseChain(self) - return chain.find_backward_accesses() + sig = self.get_signature_and_indices()[0] + return chain.find_backward_accesses()[sig] - def next_accesses(self): + def next_accesses(self) -> list[Node]: ''' :returns: the nodes accessing the same symbol directly after this reference. It can be multiple nodes if the control flow diverges and there are multiple possible accesses. - :rtype: List[:py:class:`psyclone.psyir.nodes.Node`] ''' # Avoid circular import # pylint: disable=import-outside-toplevel from psyclone.psyir.tools import DefinitionUseChain chain = DefinitionUseChain(self) - return chain.find_forward_accesses() + sig = self.get_signature_and_indices()[0] + return chain.find_forward_accesses()[sig] def escapes_scope( self, scope: Node, visited_nodes: Optional[set] = None diff --git a/src/psyclone/psyir/tools/definition_use_chains.py b/src/psyclone/psyir/tools/definition_use_chains.py index b30a0c34db..9270797e5d 100644 --- a/src/psyclone/psyir/tools/definition_use_chains.py +++ b/src/psyclone/psyir/tools/definition_use_chains.py @@ -36,7 +36,7 @@ """This module contains the DefinitionUseChain class""" import sys -from typing import Iterable, Optional +from typing import Iterable, Optional, Union from fparser.two.Fortran2003 import ( Cycle_Stmt, @@ -44,6 +44,8 @@ Goto_Stmt, ) +from psyclone.core import Signature +from psyclone.errors import InternalError from psyclone.psyir.nodes import ( Assignment, Call, @@ -67,8 +69,8 @@ class DefinitionUseChain: """The DefinitionUseChain class is used to find nodes in a tree that have data dependencies on the provided reference. - :param reference: The Reference for which the dependencies will be - computed. + :param references: The References for which the dependencies will be + computed. :param control_flow_region: Optional region to search for data dependencies. Default is the parent Routine or the root of the tree's children if no ancestor @@ -79,25 +81,60 @@ class DefinitionUseChain: dependency search. :raises TypeError: If one of the arguments is the wrong type. - + :raises InternalError: If not all of the references have the same parent. """ def __init__( self, - reference: Reference, + references: Union[list[Reference], Reference], control_flow_region: Iterable[Node] = (), start_point: Optional[int] = None, stop_point: Optional[int] = None, ): - if not isinstance(reference, Reference): + if isinstance(references, Reference): + references = [references] + if not isinstance(references, list): raise TypeError( - f"The 'reference' argument passed into a DefinitionUseChain " - f"must be a Reference but found " - f"'{type(reference).__name__}'." + f"The 'references' argument passed into a DefinitionUseChain " + f"must be a list of References or a single Reference but " + f"found '{type(references).__name__}'." ) - self._reference = reference - # Store the absolute position for later. - self._reference_abs_pos = reference.abs_position + for ref in references: + if not isinstance(ref, Reference): + raise TypeError( + f"The 'references' argument passed into a " + f"DefinitionUseChain must be a Reference or " + f"list of References but found " + f"'{type(ref).__name__}'." + ) + # We need all the references to have the same ancestor Schedule and + # belong to the same child of the Schedule. + # Skip this check if we only have 1 input. + if len(references) > 1: + parent = references[0].ancestor(Schedule) + # Skip this check for detached nodes, since we get copies + # provided to the recursive calls. + if parent: + parent_idx = references[0].path_from(parent)[0] + for ref in references: + if (ref.ancestor(Schedule) is not parent or + ref.path_from(parent)[0] != parent_idx): + raise InternalError( + "All references provided into a " + "DefinitionUseChain must have the same parent in " + "the ancestor Schedule." + ) + # Make a shallow copy of the list so we can modify it. + self._references = references[:] + # Store the absolute positions and signatures for later. + self._reference_signatures = [] + self._references_abs_pos = {} + self._references[0].compute_cached_abs_positions() + for ref in references: + sig, _ = ref.get_signature_and_indices() + self._reference_signatures.append(sig) + self._references_abs_pos[sig] = ref.abs_position + # To enable loops to work correctly we can set the start/stop point # and not just use base it on the reference's absolute position if start_point and not isinstance(start_point, int): @@ -115,9 +152,9 @@ def __init__( self._start_point = start_point self._stop_point = stop_point if not control_flow_region: - self._scope = [reference.ancestor(Routine)] + self._scope = [references[0].ancestor(Routine)] if self._scope[0] is None: - self._scope = reference.root.children[:] + self._scope = references[0].root.children[:] else: # We need a list of regions for control flow. if not isinstance(control_flow_region, list): @@ -135,36 +172,44 @@ def __init__( self._scope = control_flow_region # The uses, defsout and killed sets as defined for each basic block. - self._uses = [] - self._defsout = [] - self._killed = [] + self._uses = {} + self._defsout = {} + self._killed = {} # The output map, mapping between nodes and the reach of that node. - self._reaches = [] + self._reaches = {} + # Initialise the maps. + for sig in self._reference_signatures: + self._uses[sig] = [] + self._defsout[sig] = [] + self._killed[sig] = [] + self._reaches[sig] = [] @property - def uses(self) -> list[Node]: + def uses(self) -> dict[list[Node]]: """ - :returns: the list of nodes using the value that the referenced symbol - has before it is reassigned. + :returns: a map holding, for each referenced Symbol, the list of nodes + that use the value that the Symbol had before it is + reassigned. """ return self._uses @property - def defsout(self) -> list[Node]: + def defsout(self) -> dict[list[Node]]: """ - :returns: the list of nodes that reach the end of the block without - being killed, and therefore can have dependencies outside - of this block. + :returns: a map holding, for each referenced Symbol, the list of nodes + whose values reach the end of the block without being killed, + and therefore can have dependencies outside of this block. """ return self._defsout @property - def killed(self) -> list[Node]: + def killed(self) -> dict[list[Node]]: """ - :returns: the list of nodes that represent the last use of an assigned - variable. Calling next_access on any of these nodes will find - a write that reassigns it's value. + :returns: a map holding, for each reference Symbol, the list of nodes + that represent the last use of the assigned Symbol. Calling + next_access on any of these nodes will find a write that + reassigns the value of the Symbol inside this block. """ return self._killed @@ -183,9 +228,9 @@ def is_basic_block(self) -> bool: return False return True - def find_forward_accesses(self) -> list[Node]: + def find_forward_accesses(self) -> dict[Signature, list[Node]]: """ - Find all the forward accesses for the reference defined in this + Find all the forward accesses for the references defined in this DefinitionUseChain. Forward accesses are all of the References or Calls that read or write to the symbol of the reference up to the point that a @@ -194,13 +239,13 @@ def find_forward_accesses(self) -> list[Node]: that occur inside control flow do not end the forward access chain. - :returns: the forward accesses of the reference given to this + :returns: the forward accesses of the references given to this DefinitionUseChain """ # Compute the abs position caches as we'll use these a lot. # The compute_cached_abs_position will only do this if needed # so we don't need to check here. - self._reference.compute_cached_abs_positions() + self._references[0].compute_cached_abs_positions() # Setup the start and stop positions save_start_position = self._start_point @@ -208,7 +253,13 @@ def find_forward_accesses(self) -> list[Node]: # If there is no set start point, then we look for all # accesses after the Reference. if self._start_point is None: - self._start_point = self._reference_abs_pos + # Find the highest abs position, as all of these are + # contained in the same parent. + # We start after the last of the provided references, as + # for a statement such as b = a + a we don't want to return + # any of the References to a if the second a Reference is provided + # as an input to the DUC. + self._start_point = max(list(self._references_abs_pos.values())) # If there is no set stop point, then any Reference after # the start point can potentially be a forward access. if self._stop_point is None: @@ -229,10 +280,10 @@ def find_forward_accesses(self) -> list[Node]: # called but thats hard to otherwise track. if ( isinstance(self._scope[0], Routine) - or self._scope[0] is self._reference.root + or self._scope[0] is self._references[0].root ): # Check if there is an ancestor Loop/WhileLoop. - ancestor = self._reference.ancestor((Loop, WhileLoop)) + ancestor = self._references[0].ancestor((Loop, WhileLoop)) while ancestor is not None: # Create a basic block for the ancestor Loop. body = ancestor.loop_body.children[:] @@ -240,7 +291,7 @@ def find_forward_accesses(self) -> list[Node]: # Find the stop point - this needs to be the node after # the ancestor statement. sub_stop_point = ( - self._reference.ancestor(Statement) + self._references[0].ancestor(Statement) .walk(Node)[-1] .abs_position + 1 @@ -253,7 +304,7 @@ def find_forward_accesses(self) -> list[Node]: # node to avoid handling the special cases based on # the parents of the reference. chain = DefinitionUseChain( - self._reference.copy(), + [ref.copy() for ref in self._references], body, start_point=ancestor.abs_position, stop_point=sub_stop_point, @@ -265,7 +316,7 @@ def find_forward_accesses(self) -> list[Node]: control_flow_nodes.insert(0, None) sub_stop_point = ancestor.loop_body.abs_position chain = DefinitionUseChain( - self._reference.copy(), + [ref.copy() for ref in self._references], [ancestor.condition], start_point=ancestor.abs_position, stop_point=sub_stop_point, @@ -274,25 +325,35 @@ def find_forward_accesses(self) -> list[Node]: ancestor = ancestor.ancestor((Loop, WhileLoop)) # Check if there is an ancestor Assignment. - ancestor = self._reference.ancestor(Assignment) + ancestor = self._references[0].ancestor(Assignment) if ancestor is not None: # If the reference is the lhs then we can ignore the RHS. - if ancestor.lhs is self._reference: + # This can only be the case if we only have a single + # reference input. + if (ancestor.lhs is self._references[0] + and len(self._references) == 1): # Find the last node in the assignment last_node = ancestor.walk(Node)[-1] # Modify the start_point to only include the node after # this assignment. self._start_point = last_node.abs_position - else: + elif (all([ref is not ancestor.lhs for + ref in self._references])): # Add the lhs as a potential basic block with - # different start and stop positions. + # different start and stop positions, but don't + # include the lhs if the lhs is present. chain = DefinitionUseChain( - self._reference, + [ref.copy() for ref in self._references], [ancestor.lhs], start_point=ancestor.lhs.abs_position - 1, stop_point=ancestor.lhs.abs_position + 1, ) - control_flow_nodes.append(None) + index = len(chains) + # This chain is missed by the call to + # _find_basic_blocks, so we nede to add a None + # into the control_flow_nodes list at the correct + # place to keep behaviour correct. + control_flow_nodes.insert(index, None) chains.append(chain) # N.B. For now this assumes that for an expression # b = a * a, that next_access to the first Reference @@ -306,12 +367,13 @@ def find_forward_accesses(self) -> list[Node]: if len(block) == 0: continue chain = DefinitionUseChain( - self._reference, + self._references[:], block, start_point=self._start_point, stop_point=self._stop_point, ) chains.append(chain) + for i, chain in enumerate(chains): # Compute the defsout, killed and reaches for the block. chain.find_forward_accesses() @@ -320,15 +382,37 @@ def find_forward_accesses(self) -> list[Node]: if cfn is None: # We're outside a control flow region, updating the reaches # here is to find all the reached nodes. - for ref in chain._reaches: - # Add unique references to reaches. Since we're not - # in a control flow region, we can't have added - # these references into the reaches array yet so - # they're guaranteed to be unique. - self._reaches.append(ref) - # If we have a defsout in the chain then we can stop as we - # will never get past the write as its not conditional. - if len(chain.defsout) > 0: + # Some signatures may already have been removed by being + # killed, so we only add those that haven't already been + # killed. + for sig in chain._reaches: + if sig in self._reference_signatures: + for ref in chain._reaches[sig]: + # Add unique references to reaches. We always + # need to check as the input can have multiple + # References to the same symbol. + for ref2 in self._reaches[sig]: + if ref2 is ref: + break + else: + self._reaches[sig].append(ref) + # If we have a defsout in the chain then we can stop for + # that reference as we will never get past the write + # as its not conditional. Since we don't always include + # the LHS of an assignment into the chain we skip the + # signature if its not present in the chain defsout dict. + for i, sig in enumerate(self._reference_signatures): + if (sig in chain.defsout and + len(chain.defsout[sig]) > 0): + self._references.pop(i) + self._reference_signatures.pop(i) + # Make sure we propagate the defsout updates + # to ensure we don't go past it for other + # accesses above. + self._defsout[sig].extend(chain.defsout[sig]) + # If we have found an end point for all references then + # we can finish. + if len(self._references) == 0: # Reset the start and stop points before returning # the result. self._start_point = save_start_position @@ -341,41 +425,53 @@ def find_forward_accesses(self) -> list[Node]: # or if block structures to see if we're guaranteed to # write to the symbol. # If the control flow node is a Loop we have to check - # if the variable is the same symbol as the _reference. + # if the variable is the same symbol as any of the + # References in _references. if isinstance(cfn, Loop): cfn_abs_pos = cfn.abs_position - if ( - cfn.variable == self._reference.symbol - and cfn_abs_pos >= self._start_point - and cfn_abs_pos < self._stop_point - ): - # The loop variable is always written to and so - # we're done if its reached. - self._reaches.append(cfn) - self._start_point = save_start_position - self._stop_point = save_stop_position - return self._reaches + for i, ref in enumerate(self._references[:]): + if ( + cfn.variable == ref.symbol + and cfn_abs_pos >= self._start_point + and cfn_abs_pos < self._stop_point + ): + # The loop variable is always written to. + sig = self._reference_signatures[i] + self._reaches[sig].append(cfn) + # This reference is killed by this access. + self._references.pop(i) + self._reference_signatures.pop(i) + # If we have found an end point for all + # references then we can finish. + if len(self._references) == 0: + self._start_point = save_start_position + self._stop_point = save_stop_position + return self._reaches - for ref in chain._reaches: - # Add node to the _reaches list if it is not already - # contained. Note that a "not in" check is not - # sufficient as it checks for equality, but this needs - # identity uniqueness. - for ref2 in self._reaches: - if ref2 is ref: - break - else: - self._reaches.append(ref) + for sig in chain._reaches: + # Not all signatures are still being searched for, + # some have already been killed by previous accesses + # so we should skip them. + if sig in self._reference_signatures: + for ref in chain._reaches[sig]: + # Add unique references to reaches. Since we + # could have multiple references to the same + # symbol in the input they're not unique, so + # we need to check for uniqueness. As nodes + # can be == but not the same object, this + # has to be done using a loop and `is`. + for ref2 in self._reaches[sig]: + if ref2 is ref: + break + else: + self._reaches[sig].append(ref) else: # Check if there is an ancestor Assignment. - ancestor = self._reference.ancestor(Assignment) + ancestor = self._references[0].ancestor(Assignment) if ancestor is not None: - # If we get here to check the start part of a loop we need - # to handle this differently. - if self._start_point != self._reference_abs_pos: - pass - # If the reference is the lhs then we can ignore the RHS. - if ancestor.lhs is self._reference: + # If any of the references is the lhs then we can ignore the + # RHS. + if any([ancestor.lhs is ref for ref in self._references]): # Find the last node in the assignment last_node = ancestor.walk(Node)[-1] # Modify the start_point to only include the node after @@ -389,18 +485,26 @@ def find_forward_accesses(self) -> list[Node]: # Add the lhs as a potential basic block with different # start and stop positions. chain = DefinitionUseChain( - self._reference, + [ref for ref in self._references], [ancestor.lhs], start_point=ancestor.lhs.abs_position - 1, stop_point=ancestor.lhs.abs_position + 1, ) # Find any forward_accesses in the lhs. chain.find_forward_accesses() - for ref in chain._reaches: - self._reaches.append(ref) + for sig in chain._reaches: + if sig in self._reference_signatures: + for ref in chain._reaches[sig]: + self._reaches[sig].append(ref) # If we have a defsout in the chain then we can stop as we # will never get past the write as its not conditional. - if len(chain.defsout) > 0: + for i, sig in enumerate(self._reference_signatures): + if len(chain.defsout[sig]) > 0: + self._references.pop(i) + self._reference_signatures.pop(i) + # If we have found an end point for all references then + # we can finish. + if len(self._references) == 0: # Reset the start and stop points before returning # the result. self._start_point = save_start_position @@ -408,18 +512,19 @@ def find_forward_accesses(self) -> list[Node]: return self._reaches # We can compute the rest of the accesses self._compute_forward_uses(self._scope) - for ref in self._uses: - self._reaches.append(ref) - # If this block doesn't kill any accesses, then we add - # the defsout into the reaches array. - if len(self.killed) == 0: - for ref in self._defsout: - self._reaches.append(ref) - else: - # If this block killed any accesses, then the first element - # of the killed writes is the access access that we're - # dependent with. - self._reaches.append(self.killed[0]) + for sig in self._reference_signatures: + for ref in self._uses[sig]: + self._reaches[sig].append(ref) + # If this block doesn't kill any accesses, then we add + # the defsout into the reaches array. + if len(self.killed[sig]) == 0: + for ref in self._defsout[sig]: + self._reaches[sig].append(ref) + else: + # If this block killed any accesses, then the first element + # of the killed writes is the access that we're + # dependent with. + self._reaches[sig].append(self.killed[sig][0]) # Reset the start and stop points before returning the result. self._start_point = save_start_position @@ -443,9 +548,10 @@ def _compute_forward_uses(self, basic_block_list: list[Node]): :raises NotImplementedError: If a GOTO statement is found in the code region. """ - sig, _ = self._reference.get_signature_and_indices() - # For a basic block we will only ever have one defsout - defs_out = None + # For a basic block we will only ever have one defsout per reference. + defs_out = {} + for sig in self._reference_signatures: + defs_out[sig] = None for region in basic_block_list: for reference in region.walk((Reference, Call, CodeBlock, Return)): # Store the position instead of computing it twice. @@ -455,8 +561,9 @@ def _compute_forward_uses(self, basic_block_list: list[Node]): if isinstance(reference, Return): # When we find a return statement any following statements # can be ignored so we can return. - if defs_out is not None: - self._defsout.append(defs_out) + for sig in self._reference_signatures: + if defs_out[sig] is not None: + self._defsout[sig].append(defs_out[sig]) return # If its parent is an inquiry function then its neither # a read nor write if its the first argument. @@ -478,42 +585,48 @@ def _compute_forward_uses(self, basic_block_list: list[Node]): if isinstance( reference.parse_tree_nodes[0], (Exit_Stmt, Cycle_Stmt) ): - if defs_out is not None: - self._defsout.append(defs_out) + for sig in self._reference_signatures: + if defs_out[sig] is not None: + self._defsout[sig].append(defs_out[sig]) return - if ( - self._reference.symbol.name - in reference.get_symbol_names() - ): - # Assume the worst for a CodeBlock and we count them - # as killed and defsout and uses. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference - continue + for i, ref in enumerate(self._references[:]): + if ( + ref.symbol.name + in reference.get_symbol_names() + ): + # Assume the worst for a CodeBlock and we count + # them as killed and defsout and uses. + sig = self._reference_signatures[i] + if defs_out[sig] is not None: + self._killed[sig].append(defs_out[sig]) + defs_out[sig] = reference elif isinstance(reference, Call): # If its a local variable we can ignore it as we'll catch # the Reference later if its passed into the Call. - if self._reference.symbol.is_automatic: - continue - if isinstance(reference, IntrinsicCall): - # IntrinsicCall can only do stuff to arguments, these - # will be caught by Reference walk already. - # Note that this assumes two symbols are not - # aliases of each other. - continue - if reference.is_pure: - # Pure subroutines only touch their arguments, so we'll - # catch the arguments that are passed into the call - # later as References. - continue - # For now just assume calls are bad if we have a non-local - # variable and we treat them as though they were a write. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference - continue - elif reference.get_signature_and_indices()[0] == sig: + for i, ref in enumerate(self._references[:]): + if ref.symbol.is_automatic: + continue + if isinstance(reference, IntrinsicCall): + # IntrinsicCall can only do stuff to arguments, + # these will be caught by Reference walk already. + # Note that this assumes two symbols are not + # aliases of each other. + continue + if reference.is_pure: + # Pure subroutines only touch their arguments, so + # we'll catch the arguments that are passed into + # the call later as References. + continue + # For now just assume calls are bad if we have a + # non-local variable: we treat them as though they + # were a write. + sig = self._reference_signatures[i] + if defs_out[sig] is not None: + self._killed[sig].append(defs_out[sig]) + defs_out[sig] = reference + elif (reference.get_signature_and_indices()[0] in + self._reference_signatures): + sig = reference.get_signature_and_indices()[0] # Work out if its read only or not. assign = reference.ancestor(Assignment) if assign is not None: @@ -521,42 +634,46 @@ def _compute_forward_uses(self, basic_block_list: list[Node]): # This is a write to the reference, so kill the # previous defs_out and set this to be the # defs_out. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference + if defs_out[sig] is not None: + self._killed[sig].append(defs_out[sig]) + defs_out[sig] = reference elif ( - assign.lhs is defs_out - and len(self._killed) == 0 + assign.lhs is defs_out[sig] + and len(self._killed[sig]) == 0 and assign.lhs.get_signature_and_indices()[0] == sig - and assign.lhs is not self._reference + and any( + [assign.lhs is not ref for + ref in self._references] + ) ): # reference is on the rhs of an assignment such as # a = a + 1. Since the PSyIR tree walk accesses # the lhs of an assignment before the rhs of an # assignment we need to not ignore these accesses. - self._uses.append(reference) + self._uses[sig].append(reference) else: # Read only, so if we've not yet set written to # this variable this is a use. NB. We need to # check the if the write is the LHS of the parent # assignment and if so check if we killed any # previous assignments. - if defs_out is None: - self._uses.append(reference) + if defs_out[sig] is None: + self._uses[sig].append(reference) elif reference.ancestor(Call): # Otherwise we assume read/write access for now. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference + if defs_out[sig] is not None: + self._killed[sig].append(defs_out[sig]) + defs_out[sig] = reference else: # Reference outside an Assignment - read only # This could be References inside a While loop # condition for example. - if defs_out is None: - self._uses.append(reference) - if defs_out is not None: - self._defsout.append(defs_out) + if defs_out[sig] is None: + self._uses[sig].append(reference) + for sig in self._reference_signatures: + if defs_out[sig] is not None: + self._defsout[sig].append(defs_out[sig]) def _find_basic_blocks( self, nodelist: list[Node] @@ -649,9 +766,12 @@ def _find_basic_blocks( if node.else_body: refs = node.else_body.walk(Reference) for ref in refs: - if ref is self._reference: - # If its in the else_body we don't add the if_body - in_else_body = True + # If its in the else_body we don't add the if_body + for ref2 in self._references: + if ref is ref2: + in_else_body = True + break + if in_else_body: break if not in_else_body: control_flow_nodes.append(node) @@ -660,8 +780,11 @@ def _find_basic_blocks( in_if_body = False refs = node.if_body.walk(Reference) for ref in refs: - if ref is self._reference: - in_if_body = True + for ref2 in self._references: + if ref is ref2: + in_if_body = True + break + if in_if_body: break if node.else_body and not in_if_body: control_flow_nodes.append(node) @@ -715,9 +838,10 @@ def _compute_backward_uses(self, basic_block_list: list[Node]): :raises NotImplementedError: If a GOTO statement is found in the code region. """ - sig, _ = self._reference.get_signature_and_indices() - # For a basic block we will only ever have one defsout - defs_out = None + # For a basic block we will only ever have one defsout per reference. + defs_out = {} + for sig in self._reference_signatures: + defs_out[sig] = None # Working backwards so reverse the basic_block_list basic_block_list.reverse() stop_position = self._stop_point @@ -764,43 +888,48 @@ def _compute_backward_uses(self, basic_block_list: list[Node]): "DefinitionUseChains can't handle code containing" " GOTO statements." ) - if ( - self._reference.symbol.name - in reference.get_symbol_names() - ): - # Assume the worst for a CodeBlock and we count them - # as killed and defsout and uses. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference - continue + for i, ref in enumerate(self._references[:]): + if ( + ref.symbol.name + in reference.get_symbol_names() + ): + # Assume the worst for a CodeBlock and we count + # them as killed and defsout and uses. + sig = self._reference_signatures[i] + if defs_out[sig] is not None: + self._killed[sig].append(defs_out[sig]) + defs_out[sig] = reference elif isinstance(reference, Call): # If its a local variable we can ignore it as we'll catch # the Reference later if its passed into the Call. - if self._reference.symbol.is_automatic: - continue - # If the call is an ancestor of the Reference then - # we skip it for backwards accesses. - if self._reference.is_descendant_of(reference): - continue - if isinstance(reference, IntrinsicCall): - # IntrinsicCall can only do stuff to arguments, these - # will be caught by Reference walk already. - # Note that this assumes two symbols are not - # aliases of each other. - continue - if reference.is_pure: - # Pure subroutines only touch their arguments, so we'll - # catch the arguments that are passed into the call - # later as References. - continue - # For now just assume calls are bad if we have a non-local - # variable and we treat them as though they were a write. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference - continue - elif reference.get_signature_and_indices()[0] == sig: + for i, ref in enumerate(self._references): + if ref.symbol.is_automatic: + continue + # If the call is an ancestor of the Reference then + # we skip it for backwards accesses. + if ref.is_descendant_of(reference): + continue + if isinstance(reference, IntrinsicCall): + # IntrinsicCall can only do stuff to arguments, + # these will be caught by Reference walk already. + # Note that this assumes two symbols are not + # aliases of each other. + continue + if reference.is_pure: + # Pure subroutines only touch their arguments, so + # we'll catch the arguments that are passed into + # the call later as References. + continue + # For now just assume calls are bad if we have a + # non-local variable and we treat them as though + # they were a write. + sig = self._reference_signatures[i] + if defs_out[sig] is not None: + self._killed[sig].append(defs_out[sig]) + defs_out[sig] = reference + elif (reference.get_signature_and_indices()[0] in + self._reference_signatures): + sig = reference.get_signature_and_indices()[0] # Work out if its read only or not. assign = reference.ancestor(Assignment) # RHS reads occur "before" LHS writes, so if we @@ -808,13 +937,16 @@ def _compute_backward_uses(self, basic_block_list: list[Node]): # a dependency to the value used from the LHS. if assign is not None: if assign.lhs is reference: - # Check if the RHS contains the self._reference. + # Check if the RHS contains the self._references. # Can't use in since equality is not what we want # here. + # We also only stop if the stop_point of the chain + # is in the assignment's rhs. found = False for ref in assign.rhs.walk(Reference): if ( - ref is self._reference + any([ref is ref2 for + ref2 in self._references]) and self._stop_point == ref.abs_position ): found = True @@ -825,12 +957,15 @@ def _compute_backward_uses(self, basic_block_list: list[Node]): # This is a write to the reference, so kill the # previous defs_out and set this to be the # defs_out. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference + if defs_out[sig] is not None: + self._killed[sig].append(defs_out[sig]) + defs_out[sig] = reference elif ( assign.lhs.get_signature_and_indices()[0] == sig - and assign.lhs is not self._reference + and any( + [assign.lhs is not ref for + ref in self._references] + ) ): # Reference is on the rhs of an assignment such as # a = a + 1. Since we're looping through the tree @@ -846,23 +981,24 @@ def _compute_backward_uses(self, basic_block_list: list[Node]): # check the if the write is the LHS of the parent # assignment and if so check if we killed any # previous assignments. - if defs_out is None: - self._uses.append(reference) + if defs_out[sig] is None: + self._uses[sig].append(reference) elif reference.ancestor(Call): # Otherwise we assume read/write access for now. - if defs_out is not None: - self._killed.append(defs_out) - defs_out = reference + if defs_out[sig] is not None: + self._killed[sig].append(defs_out[sig]) + defs_out[sig] = reference else: # Reference outside an Assignment - read only # This could be References inside a While loop # condition for example. - if defs_out is None: - self._uses.append(reference) - if defs_out is not None: - self._defsout.append(defs_out) + if defs_out[sig] is None: + self._uses[sig].append(reference) + for sig in self._reference_signatures: + if defs_out[sig] is not None: + self._defsout[sig].append(defs_out[sig]) - def find_backward_accesses(self) -> list[Node]: + def find_backward_accesses(self) -> dict[Signature, list[Node]]: """ Find all the backward accesses for the reference defined in this DefinitionUseChain. @@ -879,7 +1015,7 @@ def find_backward_accesses(self) -> list[Node]: # Compute the abs position caches as we'll use these a lot. # The compute_cached_abs_position will only do this if needed # so we don't need to check here. - self._reference.compute_cached_abs_positions() + self._references[0].compute_cached_abs_positions() # Setup the start and stop positions save_start_position = self._start_point @@ -887,7 +1023,13 @@ def find_backward_accesses(self) -> list[Node]: # If there is no set start point, then we look for all # accesses after the Reference. if self._stop_point is None: - self._stop_point = self._reference_abs_pos + # Find the min abs position, as all of these are + # contained in the same parent. + # We start before any of the provided references, as + # for a statement such as b = a + a we don't want to return + # any of the References to a if the first a Reference is provided + # as an input to the DUC. + self._stop_point = min(list(self._references_abs_pos.values())) # If there is no set stop point, then any Reference after # the start point can potentially be a forward access. if self._start_point is None: @@ -907,8 +1049,10 @@ def find_backward_accesses(self) -> list[Node]: # statement. if len(block) == 0: continue + # Create a copy of the list as it can modify elements + # in the list. chain = DefinitionUseChain( - self._reference, + self._references[:], block, start_point=self._start_point, stop_point=self._stop_point, @@ -923,10 +1067,10 @@ def find_backward_accesses(self) -> list[Node]: # called but thats hard to otherwise track. if ( isinstance(self._scope[0], Routine) - or self._scope[0] is self._reference.root + or self._scope[0] is self._references[0].root ): # Check if there is an ancestor Loop/WhileLoop. - ancestor = self._reference.ancestor((Loop, WhileLoop)) + ancestor = self._references[0].ancestor((Loop, WhileLoop)) while ancestor is not None: # Create a basic block for the ancestor Loop. body = ancestor.loop_body.children[:] @@ -936,18 +1080,20 @@ def find_backward_accesses(self) -> list[Node]: # We make a copy of the reference to have a detached # node to avoid handling the special cases based on # the parents of the reference. - if self._reference.ancestor(Assignment) is not None: - sub_start_point = self._reference.ancestor( + if self._references[0].ancestor(Assignment) is not None: + sub_start_point = self._references[0].ancestor( Assignment ).abs_position else: - sub_start_point = self._reference.abs_position + sub_start_point = min(list( + self._references_abs_pos.values() + )) # If we have a basic block with no children then skip it, # e.g. for an if block with no code before the else # statement. if len(body) > 0: chain = DefinitionUseChain( - self._reference.copy(), + [ref.copy() for ref in self._references], body, start_point=sub_start_point, stop_point=sub_stop_point, @@ -960,7 +1106,7 @@ def find_backward_accesses(self) -> list[Node]: control_flow_nodes.append(None) sub_stop_point = ancestor.loop_body.abs_position chain = DefinitionUseChain( - self._reference.copy(), + [ref.copy() for ref in self._references], [ancestor.condition], start_point=ancestor.abs_position, stop_point=sub_stop_point, @@ -969,21 +1115,22 @@ def find_backward_accesses(self) -> list[Node]: ancestor = ancestor.ancestor((Loop, WhileLoop)) # Check if there is an ancestor Assignment. - ancestor = self._reference.ancestor(Assignment) + ancestor = self._references[0].ancestor(Assignment) if ancestor is not None: # If the reference is not the lhs then we can ignore # the RHS. - if ancestor.lhs is self._reference: + if any([ancestor.lhs is ref for ref in self._references]): end = ancestor.walk(Node)[-1] # Add the rhs as a potential basic block with # different start and stop positions. chain = DefinitionUseChain( - self._reference.copy(), + [ancestor.lhs.copy()], ancestor.rhs.children[:], start_point=ancestor.rhs.abs_position, stop_point=end.abs_position, ) - control_flow_nodes.append(None) + index = len(chains) + control_flow_nodes.insert(index, None) chains.append(chain) # N.B. For now this assumes that for an expression # b = a * a, that next_access to the first Reference @@ -1000,21 +1147,33 @@ def find_backward_accesses(self) -> list[Node]: if cfn is None: # We're outside a control flow region, updating the reaches # here is to find all the reached nodes. - for ref in chain._reaches: - # Add unique references to reaches. Since we're not - # in a control flow region, we can't have added - # these references into the reaches array yet so - # they're guaranteed to be unique. - found = False - for ref2 in self._reaches: - if ref is ref2: - found = True - break - if not found: - self._reaches.append(ref) + for sig in chain._reaches: + if sig in self._reference_signatures: + for ref in chain._reaches[sig]: + # Add unique references to reaches. Since + # we're not in a control flow region, we + # can't have added these references into the + # reaches array yet so they're guaranteed + # to be unique. + found = False + for ref2 in self._reaches[sig]: + if ref is ref2: + found = True + break + if not found: + self._reaches[sig].append(ref) # If we have a defsout in the chain then we can stop as we # will never get past the write as its not conditional. - if len(chain.defsout) > 0: + for i, sig in enumerate(self._reference_signatures): + # Not all references are passed into all subchains. + if sig in chain.defsout: + if len(chain.defsout[sig]) > 0: + self._references.pop(i) + self._reference_signatures.pop(i) + self._defsout[sig].extend(chain.defsout[sig]) + # If we have found an end point for all references then + # we can stop. + if len(self._references) == 0: # Reset the start and stop points before returning # the result. self._start_point = save_start_position @@ -1030,34 +1189,43 @@ def find_backward_accesses(self) -> list[Node]: # if the variable is the same symbol as the _reference. if isinstance(cfn, Loop): cfn_abs_pos = cfn.abs_position - if ( - cfn.variable == self._reference.symbol - and cfn_abs_pos >= self._start_point - and cfn_abs_pos < self._stop_point - ): - # The loop variable is always written to and so - # we're done if its reached. - self._reaches.append(cfn) - self._start_point = save_start_position - self._stop_point = save_stop_position - return self._reaches - - for ref in chain._reaches: - found = False - for ref2 in self._reaches: - if ref is ref2: - found = True - break - if not found: - self._reaches.append(ref) + for i, ref in enumerate(self._references[:]): + if ( + cfn.variable == ref.symbol + and cfn_abs_pos >= self._start_point + and cfn_abs_pos < self._stop_point + ): + # The loop variable is always written + sig = self._reference_signatures[i] + self._reaches[sig].append(cfn) + # This reference is killed by this access. + self._references.pop(i) + self._reference_signatures.pop(i) + # If we have found an end point for all + # references then we can finish. + if len(self._references) == 0: + self._start_point = save_start_position + self._stop_point = save_stop_position + return self._reaches + for sig in chain._reaches: + if sig in self._reference_signatures: + for ref in chain._reaches[sig]: + found = False + for ref2 in self._reaches[sig]: + if ref is ref2: + found = True + break + if not found: + self._reaches[sig].append(ref) else: # Check if there is an ancestor Assignment. - ancestor = self._reference.ancestor(Assignment) + ancestor = self._references[0].ancestor(Assignment) if ancestor is not None: # If we get here to check the start part of a loop we need # to handle this differently. - # If the reference is the lhs then we can ignore the RHS. - if ancestor.lhs is not self._reference: + # If no reference is the lhs then we can skip the rhs when + # searching backwards.. + if all([ancestor.lhs is not ref for ref in self._references]): pass elif ancestor.rhs is self._scope[0] and len(self._scope) == 1: # If the ancestor RHS is the scope of this chain then we @@ -1065,32 +1233,35 @@ def find_backward_accesses(self) -> list[Node]: pass else: # Add the rhs as a potential basic block with different - # start and stop positions. + # start and stop positions. This only needs to be computed + # for the Reference on the LHS of the ancestor assignment. chain = DefinitionUseChain( - self._reference, + [ancestor.lhs], [ancestor.rhs], start_point=ancestor.rhs.abs_position, stop_point=sys.maxsize, ) # Find any backward_accesses in the rhs. chain.find_backward_accesses() - for ref in chain._reaches: - self._reaches.append(ref) + for sig in chain._reaches: + for ref in chain._reaches[sig]: + self._reaches[sig].append(ref) # We can compute the rest of the accesses self._compute_backward_uses(self._scope) - for ref in self._uses: - self._reaches.append(ref) - # If this block doesn't kill any accesses, then we add - # the defsout into the reaches array. - if len(self.killed) == 0: - for ref in self._defsout: - self._reaches.append(ref) - else: - # If this block killed any accesses, then the first element - # of the killed writes is the access access that we're - # dependent with. - self._reaches.append(self.killed[0]) + for sig in self._reference_signatures: + for ref in self._uses[sig]: + self._reaches[sig].append(ref) + # If this block doesn't kill any accesses, then we add + # the defsout into the reaches array. + if len(self.killed[sig]) == 0: + for ref in self._defsout[sig]: + self._reaches[sig].append(ref) + else: + # If this block killed any accesses, then the first element + # of the killed writes is the access that we're + # dependent with. + self._reaches[sig].append(self.killed[sig][0]) # Reset the start and stop points before returning the result. self._start_point = save_start_position diff --git a/src/psyclone/psyir/transformations/hoist_trans.py b/src/psyclone/psyir/transformations/hoist_trans.py index bbcd7eb4b8..279f877d7a 100644 --- a/src/psyclone/psyir/transformations/hoist_trans.py +++ b/src/psyclone/psyir/transformations/hoist_trans.py @@ -239,9 +239,9 @@ def _validate_dependencies(self, statement, parent_loop): written_node = accesses_in_statement[0].node accesses_in_loop = all_loop_vars[written_sig] chains = DefinitionUseChain( - written_node, parent_loop.children[:] + [written_node], parent_loop.children[:] ) - if chains.find_backward_accesses(): + if chains.find_backward_accesses()[written_sig]: code = statement.debug_string().strip() raise TransformationError(f"The statement '{code}' can't be " f"hoisted as variable " diff --git a/src/psyclone/tests/psyir/nodes/assignment_test.py b/src/psyclone/tests/psyir/nodes/assignment_test.py index e12650813f..97215aef4c 100644 --- a/src/psyclone/tests/psyir/nodes/assignment_test.py +++ b/src/psyclone/tests/psyir/nodes/assignment_test.py @@ -379,3 +379,123 @@ def test_reference_accesses(fortran_reader): assert accesses[Signature('g')][0].access_type == AccessType.READ assert accesses[Signature('g')][1].node == assigns[1].lhs assert accesses[Signature('g')][1].access_type == AccessType.WRITE + + +def test_next_accesses(fortran_reader): + ''' Test the assignment.next_accesses function.''' + # Start with a basic test where the assignment just has a single + # Reference. + psyir = fortran_reader.psyir_from_source( + """program test + integer :: a + a = 1 + a = a + 1 + end program""" + ) + assigns = psyir.walk(Assignment) + + reaches = assigns[0].next_accesses() + sig = assigns[0].lhs.get_signature_and_indices()[0] + assert len(reaches) == 1 + assert len(reaches[sig]) == 2 + assert reaches[sig][0] is assigns[1].rhs.children[0] + assert reaches[sig][1] is assigns[1].lhs + + # Next test multiple References to different symbols in the + # Assignment + psyir = fortran_reader.psyir_from_source( + """program test + integer :: a + integer :: b + a = b + a = a + 1 + b = 2 + end program""" + ) + assigns = psyir.walk(Assignment) + reaches = assigns[0].next_accesses() + a_sig = assigns[0].lhs.get_signature_and_indices()[0] + b_sig = assigns[0].rhs.get_signature_and_indices()[0] + assert len(reaches) == 2 + assert len(reaches[a_sig]) == 2 + assert reaches[a_sig][0] is assigns[1].rhs.children[0] + assert reaches[a_sig][1] is assigns[1].lhs + assert len(reaches[b_sig]) == 1 + assert reaches[b_sig][0] is assigns[2].lhs + + # Test References inside an inquiry function are ignored + psyir = fortran_reader.psyir_from_source( + """program test + integer :: a + integer :: b + integer, dimension(100) :: c + a = b + size(c) + c = 1 + end program""" + ) + assigns = psyir.walk(Assignment) + reaches = assigns[0].next_accesses() + a_sig = assigns[0].lhs.get_signature_and_indices()[0] + b_sig = assigns[0].rhs.children[0].get_signature_and_indices()[0] + assert len(reaches) == 2 + assert len(reaches[a_sig]) == 0 + assert len(reaches[b_sig]) == 0 + + +def test_previous_accesses(fortran_reader): + ''' Test the assignment.previous_accesses function.''' + # Start with a basic test where the assignment just has a single + # Reference. + psyir = fortran_reader.psyir_from_source( + """program test + integer :: a + a = a + 1 + a = 1 + end program""" + ) + assigns = psyir.walk(Assignment) + + reaches = assigns[1].previous_accesses() + sig = assigns[0].lhs.get_signature_and_indices()[0] + assert len(reaches) == 1 + assert len(reaches[sig]) == 1 + assert reaches[sig][0] is assigns[0].lhs + + # Next test multiple References to different symbols in the + # Assignment + psyir = fortran_reader.psyir_from_source( + """program test + integer :: a + integer :: b + b = 2 + a = a + 1 + a = b + end program""" + ) + assigns = psyir.walk(Assignment) + reaches = assigns[2].previous_accesses() + a_sig = assigns[2].lhs.get_signature_and_indices()[0] + b_sig = assigns[2].rhs.get_signature_and_indices()[0] + assert len(reaches) == 2 + assert len(reaches[a_sig]) == 1 + assert reaches[a_sig][0] is assigns[1].lhs + assert len(reaches[b_sig]) == 1 + assert reaches[b_sig][0] is assigns[0].lhs + + # Test References inside an inquiry function are ignored + psyir = fortran_reader.psyir_from_source( + """program test + integer :: a + integer :: b + integer, dimension(100) :: c + c = 1 + a = b + size(c) + end program""" + ) + assigns = psyir.walk(Assignment) + reaches = assigns[1].previous_accesses() + a_sig = assigns[1].lhs.get_signature_and_indices()[0] + b_sig = assigns[1].rhs.children[0].get_signature_and_indices()[0] + assert len(reaches) == 2 + assert len(reaches[a_sig]) == 0 + assert len(reaches[b_sig]) == 0 diff --git a/src/psyclone/tests/psyir/nodes/reference_test.py b/src/psyclone/tests/psyir/nodes/reference_test.py index 9658bce946..fc77e433c7 100644 --- a/src/psyclone/tests/psyir/nodes/reference_test.py +++ b/src/psyclone/tests/psyir/nodes/reference_test.py @@ -565,6 +565,10 @@ def test_reference_is_read(fortran_reader): assert references[1].is_read assert references[3].symbol.name == "c" assert references[3].is_read + + # Routine or Intrinsic Symbols are not read. + assert references[5].symbol.name == "LBOUND" + assert not references[5].is_read # For the lbound, d should be an inquiry (so not a read) but # x should be a read assert references[6].symbol.name == "d" diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_backward_dependence_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_backward_dependence_test.py index d7940c0df5..840de59c47 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_backward_dependence_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_backward_dependence_test.py @@ -65,7 +65,7 @@ def test_definition_use_chain_compute_backward_uses(fortran_reader): a_3 = psyir.walk(Reference)[3] # Check this is the lhs of the assignment assert a_3 is psyir.walk(Assignment)[1].rhs - + sig = a_3.get_signature_and_indices()[0] duc = DefinitionUseChain( a_3, control_flow_region=[routine] ) @@ -75,8 +75,9 @@ def test_definition_use_chain_compute_backward_uses(fortran_reader): duc._start_point = routine.children[0].abs_position duc._stop_point = a_3.abs_position-1 duc._compute_backward_uses(basic_block_list) - assert len(duc.defsout) == 1 - assert duc.defsout[0] is psyir.walk(Reference)[0] # The lhs of a = a + 1 + assert len(duc.defsout[sig]) == 1 + # The lhs of a = a + 1 + assert duc.defsout[sig][0] is psyir.walk(Reference)[0] # Next we test a Reference with a write then a read - we should only get # the write, which should be in uses and defsout. @@ -91,7 +92,7 @@ def test_definition_use_chain_compute_backward_uses(fortran_reader): psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] a_3 = psyir.walk(Reference)[4] - + sig = a_3.get_signature_and_indices()[0] duc = DefinitionUseChain( a_3, control_flow_region=[routine] ) @@ -101,10 +102,10 @@ def test_definition_use_chain_compute_backward_uses(fortran_reader): duc._start_point = routine.children[0].abs_position duc._stop_point = a_3.abs_position - 1 duc._compute_backward_uses(basic_block_list) - assert len(duc.uses) == 0 - assert len(duc.defsout) == 1 - assert len(duc.killed) == 0 - assert duc.defsout[0] is psyir.walk(Reference)[2] # The lhs of a = 2 + assert len(duc.uses[sig]) == 0 + assert len(duc.defsout[sig]) == 1 + assert len(duc.killed[sig]) == 0 + assert duc.defsout[sig][0] is psyir.walk(Reference)[2] # The lhs of a = 2 def test_definition_use_chain_find_backward_accesses_basic_example( @@ -140,10 +141,12 @@ def test_definition_use_chain_find_backward_accesses_basic_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Creating use chain for the a in a = 2 + ref = routine.walk(Assignment)[8].lhs + sig = ref.get_signature_and_indices()[0] chains = DefinitionUseChain( - routine.walk(Assignment)[8].lhs, [routine] + ref, [routine] ) - reaches = chains.find_backward_accesses() + reaches = chains.find_backward_accesses()[sig] # We find 2 results # the a in e = a**3 # The call bar(c, b) as a isn't local and we can't guarantee its behaviour. @@ -152,9 +155,10 @@ def test_definition_use_chain_find_backward_accesses_basic_example( assert reaches[1] is routine.walk(Call)[1] # Create use chain for c in b = c + d - chains = DefinitionUseChain(routine.walk(Assignment)[5].rhs.children[0], - [routine]) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[5].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref, [routine]) + reaches = chains.find_backward_accesses()[sig] # We should find 2 results # C = d * a # d = C + 2.0 @@ -177,8 +181,10 @@ def test_definition_use_chain_find_backward_accesses_assignment( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start chain from A = a * a - chains = DefinitionUseChain(routine.walk(Assignment)[1].lhs) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] # We should find 3 results, both 3 references in # a = A * A # and A = 1 @@ -209,8 +215,10 @@ def test_definition_use_chain_find_backward_accesses_ifelse_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from b = A + d. - chains = DefinitionUseChain(routine.walk(Assignment)[4].rhs.children[0]) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[4].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] # TODO #2760 For now the if statement doesn't kill the accesses, # even though it will always be written to. assert len(reaches) == 4 @@ -222,8 +230,9 @@ def test_definition_use_chain_find_backward_accesses_ifelse_example( # Also check that a = 4 backward access is not a = 3. a_3 = routine.walk(Assignment)[2].lhs a_4 = routine.walk(Assignment)[3].lhs - chains = DefinitionUseChain(a_4) - reaches = chains.find_backward_accesses() + sig = a_4.get_signature_and_indices()[0] + chains = DefinitionUseChain([a_4]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 2 assert reaches[0] is not a_3 assert reaches[1] is not a_3 @@ -257,8 +266,10 @@ def test_definition_use_chain_find_backward_accesses_psy_data_node_example( # because it needs to deal with the current_block p_trans.apply(routine[1:]) # Start the chain from b = A + d. - chains = DefinitionUseChain(routine.walk(Assignment)[4].rhs.children[0]) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[4].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] # Inside the ProfileRegion the DUC has to work as before, the 'a' has # 4 backwards accesses as shown in the previous test. assert len(reaches) == 4 @@ -284,10 +295,10 @@ def test_definition_use_chain_find_backward_accesses_loop_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from A = a + i. - chains = DefinitionUseChain( - routine.walk(Assignment)[1].lhs - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] # We should have 4? reaches # First b = A + 2 # Second a = A + i @@ -320,10 +331,10 @@ def test_definition_use_chain_find_backward_accesses_loop_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from I = 1231. - chains = DefinitionUseChain( - routine.walk(Assignment)[2].lhs - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[2].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] # We should have 1 reaches # It should be the loop assert len(reaches) == 1 @@ -348,8 +359,10 @@ def test_definition_use_chain_find_backward_accesses_while_loop_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from A = a + 3. - chains = DefinitionUseChain(routine.children[2].loop_body.children[0].lhs) - reaches = chains.find_backward_accesses() + ref = routine.children[2].loop_body.children[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 4 assert reaches[0] is routine.walk(WhileLoop)[0].condition.children[0] @@ -381,8 +394,10 @@ def test_definition_use_chain_backward_accesses_nested_loop_example( routine = psyir.walk(Routine)[0] # Start the chain from b = b + A. loops = routine.walk(WhileLoop) - chains = DefinitionUseChain(loops[1].walk(Assignment)[0].rhs.children[1]) - reaches = chains.find_backward_accesses() + ref = loops[1].walk(Assignment)[0].rhs.children[1] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] # TODO #2760 The backwards accesses should not continue past a = a + 3 as # to reach the b = b + a statement we must have passed through the # a = a + 3 statement. @@ -410,8 +425,10 @@ def test_definition_use_chain_find_backward_accesses_structure_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[2].lhs) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[2].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Assignment)[0].lhs @@ -429,8 +446,10 @@ def test_definition_use_chain_find_backward_accesses_no_control_flow_example( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Assignment)[0].rhs.children[0] @@ -449,8 +468,10 @@ def test_definition_use_chain_find_backward_accesses_codeblock( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[1].lhs) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(CodeBlock)[0] @@ -470,8 +491,10 @@ def test_definition_use_chain_find_backward_accesses_codeblock_and_call_nlocal( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].rhs.children[0]) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[0].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 # Result is the argument of the call assert reaches[0] is routine.walk(Call)[0].children[1] @@ -495,8 +518,10 @@ def test_definition_use_chain_find_backward_accesses_codeblock_and_call_cflow( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].rhs.children[0]) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[0].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 2 assert reaches[0] is routine.walk(Call)[1].children[1] assert reaches[1] is routine.walk(Call)[0] @@ -518,8 +543,10 @@ def test_definition_use_chain_find_backward_accesses_codeblock_and_call_local( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].rhs.children[0]) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[0].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(CodeBlock)[0] @@ -539,8 +566,10 @@ def test_definition_use_chain_find_backward_accesses_call_and_codeblock_nlocal( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Call)[0] @@ -558,7 +587,7 @@ def test_definition_use_chains_goto_statement( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) + chains = DefinitionUseChain([routine.walk(Assignment)[0].lhs]) with pytest.raises(NotImplementedError) as excinfo: chains.find_backward_accesses() assert ("DefinitionUseChains can't handle code containing GOTO statements" @@ -586,10 +615,10 @@ def test_definition_use_chains_exit_statement( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from a = A +i. - chains = DefinitionUseChain( - routine.walk(Assignment)[1].rhs.children[0] - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[1].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] # We should have 2 reaches # First is A = a + i # Second is A = 1 @@ -623,10 +652,10 @@ def test_definition_use_chains_cycle_statement( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from a = A +i. - chains = DefinitionUseChain( - routine.walk(Assignment)[1].rhs.children[0] - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[1].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] # We should have 2 reaches # A = b * 4 # A = 1 @@ -657,10 +686,10 @@ def test_definition_use_chains_return_statement( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from a = A +i. - chains = DefinitionUseChain( - routine.walk(Assignment)[1].rhs.children[0] - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[1].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] # We should have 2 reaches # A = b * 4 # A = 1 @@ -689,10 +718,10 @@ def test_definition_use_chains_backward_accesses_multiple_routines( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[1] - chains = DefinitionUseChain( - routine.walk(Assignment)[0].rhs - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[0].rhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 0 @@ -716,10 +745,10 @@ def test_definition_use_chains_backward_accesses_nonassign_reference_in_loop( end subroutine x""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain( - routine.walk(Call)[0].children[1] - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Call)[0].children[1] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] # TODO #2760 The backwards accesses should not continue past a = a + i # when searching backwards in the loop, or to a = 1 assert len(reaches) == 3 @@ -749,10 +778,10 @@ def test_definition_use_chains_backward_accesses_empty_schedules( """ psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain( - routine.walk(Assignment)[1].lhs - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 3 assert reaches[0] is routine.walk(Assignment)[1].rhs.children[1] assert reaches[1] is routine.walk(Assignment)[1].rhs.children[0] @@ -776,10 +805,10 @@ def test_definition_use_chains_backward_accesses_inquiry_func( """ psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain( - routine.walk(Assignment)[1].lhs - ) - reaches = chains.find_backward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 0 @@ -806,8 +835,9 @@ def test_definition_use_chain_find_backward_accesses_pure_call( # Find the a in the rhs of the second assignment assign2 = routine.walk(Assignment)[1] rhs_a = assign2.rhs.children[0] - chains = DefinitionUseChain(rhs_a) - reaches = chains.find_backward_accesses() + sig = rhs_a.get_signature_and_indices()[0] + chains = DefinitionUseChain([rhs_a]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 # Result is lhs of the first assignment lhs_assign1 = routine.walk(Assignment)[0].lhs @@ -815,8 +845,9 @@ def test_definition_use_chain_find_backward_accesses_pure_call( # Get the lhs of the b = 1 assignment lhs_assign3 = routine.walk(Assignment)[2].lhs - chains = DefinitionUseChain(lhs_assign3) - reaches = chains.find_backward_accesses() + sig = lhs_assign3.get_signature_and_indices()[0] + chains = DefinitionUseChain([lhs_assign3]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 # We should find the argument in the pure subroutine call assert reaches[0] is routine.walk(Call)[0].children[1] @@ -843,8 +874,9 @@ def test_definition_use_chain_find_backward_accesses_ancestor_call( routine = psyir.find_routine_psyir("foo") call = psyir.walk(Call)[0] arg = call.arguments[1] - chain = DefinitionUseChain(arg) - all_prev = chain.find_backward_accesses() + sig = arg.get_signature_and_indices()[0] + chain = DefinitionUseChain([arg]) + all_prev = chain.find_backward_accesses()[sig] # Check that the ancestor call of b isn't a backward access. assert not isinstance(all_prev[0], Call) # The correct previous access should be the Reference to b in @@ -869,6 +901,7 @@ def test_backward_accesses_nested_loop(fortran_reader): psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] lhs = routine.walk(Assignment)[0].lhs - chains = DefinitionUseChain(lhs) - reaches = chains.find_backward_accesses() + sig = lhs.get_signature_and_indices()[0] + chains = DefinitionUseChain([lhs]) + reaches = chains.find_backward_accesses()[sig] assert len(reaches) == 1 diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py index d4d1d6f3c8..c1108be2c4 100644 --- a/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_forward_dependence_test.py @@ -37,6 +37,7 @@ forward_accesses routine.''' import pytest +from psyclone.errors import InternalError from psyclone.psyir.nodes import ( ArrayReference, Assignment, @@ -74,37 +75,49 @@ def test_definition_use_chain_init_and_properties(fortran_reader): a_1 = references[0] # Check the basic initialisation with default setup. - duc = DefinitionUseChain(a_1) - assert duc._reference is a_1 + duc = DefinitionUseChain([a_1]) + assert len(duc._references) == 1 + assert duc._references[0] is a_1 + sig = a_1.get_signature_and_indices()[0] + assert duc._reference_signatures[0] == sig assert duc._start_point is None assert duc._stop_point is None - assert duc._reference_abs_pos == a_1.abs_position + assert duc._references_abs_pos[sig] == a_1.abs_position assert len(duc._scope) == 1 assert duc._scope[0] is routine # Test the control_flow_region setter - duc = DefinitionUseChain(a_1, routine.children[0:2]) + duc = DefinitionUseChain([a_1], routine.children[0:2]) assert len(duc._scope) == 2 assert duc._scope[0] is routine.children[0] assert duc._scope[1] is routine.children[1] # Test the start and stop point setters - duc = DefinitionUseChain(a_1, start_point=0, stop_point=1) + duc = DefinitionUseChain([a_1], start_point=0, stop_point=1) assert duc._start_point == 0 assert duc._stop_point == 1 - assert len(duc.uses) == 0 - assert len(duc.defsout) == 0 - assert len(duc.killed) == 0 + assert len(duc.uses) == 1 + assert len(duc.uses[sig]) == 0 + assert len(duc.defsout) == 1 + assert len(duc.defsout[sig]) == 0 + assert len(duc.killed) == 1 + assert len(duc.killed[sig]) == 0 - # Test exception when passed a non_list + # Test exceptions when passed a non_list for various inputs. with pytest.raises(TypeError) as excinfo: - duc = DefinitionUseChain(a_1, control_flow_region=2) + duc = DefinitionUseChain("a") + assert ("The 'references' argument passed into a DefinitionUseChain " + "must be a list of References or a single Reference but found " + "'str'" in str(excinfo.value)) + + with pytest.raises(TypeError) as excinfo: + duc = DefinitionUseChain([a_1], control_flow_region=2) assert ("The control_flow_region passed into a DefinitionUseChain " "must be a list but found 'int'." in str(excinfo.value)) with pytest.raises(TypeError) as excinfo: - duc = DefinitionUseChain(a_1, control_flow_region=[2]) + duc = DefinitionUseChain([a_1], control_flow_region=[2]) assert ("Each element of the control_flow_region passed into a " "DefinitionUseChain must be a Node but found a non-Node " "element. Full input is " in str(excinfo.value)) @@ -114,23 +127,38 @@ def test_definition_use_chain_init_and_properties(fortran_reader): r1 = Reference(sym) r2 = Reference(sym) assign = Assignment.create(r1, r2) - duc = DefinitionUseChain(r1) + duc = DefinitionUseChain([r1]) assert duc._scope[0] is assign.lhs assert duc._scope[1] is assign.rhs # Test remaining TypeErrors with pytest.raises(TypeError) as excinfo: - duc = DefinitionUseChain("123") - assert ("The 'reference' argument passed into a DefinitionUseChain must " - "be a Reference but found 'str'." in str(excinfo.value)) - with pytest.raises(TypeError) as excinfo: - duc = DefinitionUseChain(r1, start_point="123") + duc = DefinitionUseChain([r1], start_point="123") assert ("The start_point passed into a DefinitionUseChain must be an " "int but found 'str'." in str(excinfo.value)) with pytest.raises(TypeError) as excinfo: - duc = DefinitionUseChain(r1, stop_point="123") + duc = DefinitionUseChain([r1], stop_point="123") assert ("The stop_point passed into a DefinitionUseChain must be an " "int but found 'str'." in str(excinfo.value)) + with pytest.raises(TypeError) as excinfo: + duc = DefinitionUseChain(["a", "b"]) + assert ("The 'references' argument passed into a DefinitionUseChain must " + "be a Reference or list of References but found 'str'." + in str(excinfo.value)) + + # Create a containing schedule. + code = """subroutine test + integer :: r1 + r1 = 1 + end subroutine + """ + psyir = fortran_reader.psyir_from_source(code) + r1 = psyir.walk(Assignment)[0].lhs + with pytest.raises(InternalError) as excinfo: + duc = DefinitionUseChain([r1, Reference(sym)]) + assert ("All references provided into a DefinitionUseChain " + "must have the same parent in the ancestor Schedule." + in str(excinfo.value)) def test_definition_use_chain_is_basic_block(fortran_reader): @@ -166,17 +194,17 @@ def test_definition_use_chain_is_basic_block(fortran_reader): reference = psyir.walk(Reference)[0] # The full routine is not a basic block as it contains control flow. - duc1 = DefinitionUseChain(reference, control_flow_region=block1) + duc1 = DefinitionUseChain([reference], control_flow_region=block1) assert not duc1.is_basic_block # The if_body of the if statement is a basic block as it contains no # control flow. - duc2 = DefinitionUseChain(reference, control_flow_region=block2) + duc2 = DefinitionUseChain([reference], control_flow_region=block2) assert duc2.is_basic_block # The else_body of the if statement is not a basic block as it contains # control flow. - duc3 = DefinitionUseChain(reference, control_flow_region=block3) + duc3 = DefinitionUseChain([reference], control_flow_region=block3) assert not duc3.is_basic_block # Test that regiondirectives (e.g. OMPParallelDirective) don't count. @@ -194,7 +222,7 @@ def test_definition_use_chain_is_basic_block(fortran_reader): par_trans.apply(psyir.walk(Routine)[0].children[:]) reference = psyir.walk(Reference)[0] parallel = psyir.walk(OMPParallelDirective)[0] - duc = DefinitionUseChain(reference, control_flow_region=[parallel]) + duc = DefinitionUseChain([reference], control_flow_region=[parallel]) assert not duc.is_basic_block @@ -215,16 +243,17 @@ def test_definition_use_chain_compute_forward_uses(fortran_reader): assert a_1 is psyir.walk(Assignment)[0].lhs duc = DefinitionUseChain( - a_1, control_flow_region=[routine] + [a_1], control_flow_region=[routine] ) + sig = a_1.get_signature_and_indices()[0] basic_block_list = routine.children[:] # Need to set the start point and stop points similar to what # forward_accesses would do duc._start_point = a_1.ancestor(Assignment).walk(Node)[-1].abs_position duc._stop_point = 100000000 duc._compute_forward_uses(basic_block_list) - assert len(duc.uses) == 1 - assert duc.uses[0] is psyir.walk(Reference)[3] # The rhs of b=a + assert len(duc.uses[sig]) == 1 + assert duc.uses[sig][0] is psyir.walk(Reference)[3] # The rhs of b=a # Next we test a Reference with a write then a read - we should only get # the write, which should be in uses and defsout. @@ -241,18 +270,19 @@ def test_definition_use_chain_compute_forward_uses(fortran_reader): a_1 = psyir.walk(Reference)[1] duc = DefinitionUseChain( - a_1, control_flow_region=[routine] + [a_1], control_flow_region=[routine] ) + sig = a_1.get_signature_and_indices()[0] basic_block_list = routine.children[:] # Need to set the start point and stop points similar to what # forward_accesses would do duc._start_point = a_1.ancestor(Assignment).walk(Node)[-1].abs_position duc._stop_point = 100000000 duc._compute_forward_uses(basic_block_list) - assert len(duc.uses) == 0 - assert len(duc.defsout) == 1 - assert len(duc.killed) == 0 - assert duc.defsout[0] is psyir.walk(Reference)[2] # The lhs of a = 2 + assert len(duc.uses[sig]) == 0 + assert len(duc.defsout[sig]) == 1 + assert len(duc.killed[sig]) == 0 + assert duc.defsout[sig][0] is psyir.walk(Reference)[2] # The lhs of a = 2 # Finally test a Reference with a write then another write. # The defsout should be the final write and the first write should be @@ -270,19 +300,20 @@ def test_definition_use_chain_compute_forward_uses(fortran_reader): a_1 = psyir.walk(Reference)[1] duc = DefinitionUseChain( - a_1, control_flow_region=[routine] + [a_1], control_flow_region=[routine] ) + sig = a_1.get_signature_and_indices()[0] basic_block_list = routine.children[:] # Need to set the start point and stop points similar to what # forward_accesses would do duc._start_point = a_1.ancestor(Assignment).walk(Node)[-1].abs_position duc._stop_point = 100000000 duc._compute_forward_uses(basic_block_list) - assert len(duc.uses) == 0 - assert len(duc.defsout) == 1 - assert len(duc.killed) == 1 - assert duc.defsout[0] is psyir.walk(Reference)[5] # The lhs of a = 3 - assert duc.killed[0] is psyir.walk(Reference)[2] # The lhs of a = 2 + assert len(duc.uses[sig]) == 0 + assert len(duc.defsout[sig]) == 1 + assert len(duc.killed[sig]) == 1 + assert duc.defsout[sig][0] is psyir.walk(Reference)[5] # The lhs of a = 3 + assert duc.killed[sig][0] is psyir.walk(Reference)[2] # The lhs of a = 2 def test_definition_use_chain_find_basic_blocks(fortran_reader): @@ -309,7 +340,7 @@ def test_definition_use_chain_find_basic_blocks(fortran_reader): psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] duc = DefinitionUseChain( - routine.walk(Reference)[0], control_flow_region=[routine] + [routine.walk(Reference)[0]], control_flow_region=[routine] ) # Find the basic blocks. cfn, blocks = duc._find_basic_blocks(routine.children[:]) @@ -486,38 +517,44 @@ def test_definition_use_chain_find_forward_accesses_basic_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Creating use chain for the a in c = a + 1.0 + ref = routine.children[0].children[1].children[0] + sig = ref.get_signature_and_indices()[0] chains = DefinitionUseChain( - routine.children[0].children[1].children[0], [routine] + ref, [routine] ) reaches = chains.find_forward_accesses() # We find 3 results # the a in e = a**2 (assignment 2) # the a in c = d * a (assignment 4) # The call bar(c, b) as a isn't local and we can't guarantee its behaviour. - assert len(reaches) == 3 - assert reaches[0] is routine.walk(Assignment)[1].rhs.children[0] - assert reaches[1] is routine.walk(Assignment)[4].rhs.children[1] - assert reaches[2] is routine.walk(Call)[1] + assert len(reaches[sig]) == 3 + assert reaches[sig][0] is routine.walk(Assignment)[1].rhs.children[0] + assert reaches[sig][1] is routine.walk(Assignment)[4].rhs.children[1] + assert reaches[sig][2] is routine.walk(Call)[1] # Create use chain for d in d = c + 2.0 - chains = DefinitionUseChain(routine.children[3].lhs, [routine]) + ref = routine.children[3].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref, [routine]) reaches = chains.find_forward_accesses() # We should find 2 results # c = D * a (Assignment 5) # b = c + D (Assignment 6) - assert reaches[0] is routine.walk(Assignment)[4].rhs.children[0] - assert reaches[1] is routine.walk(Assignment)[5].rhs.children[1] - assert len(reaches) == 2 + assert reaches[sig][0] is routine.walk(Assignment)[4].rhs.children[0] + assert reaches[sig][1] is routine.walk(Assignment)[5].rhs.children[1] + assert len(reaches[sig]) == 2 # Create use chain for c in c = d * a (Assignment 5) - chains = DefinitionUseChain(routine.walk(Assignment)[4].lhs, [routine]) + ref = routine.walk(Assignment)[4].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref, [routine]) reaches = chains.find_forward_accesses() # 2 results: # b = C + d (Assignment 6) # call bar(c, d) (The second Call) - assert len(reaches) == 2 - assert reaches[0] is routine.walk(Assignment)[5].rhs.children[0] - assert reaches[1] is routine.walk(Call)[1].arguments[0] + assert len(reaches[sig]) == 2 + assert reaches[sig][0] is routine.walk(Assignment)[5].rhs.children[0] + assert reaches[sig][1] is routine.walk(Call)[1].arguments[0] def test_definition_use_chain_find_forward_accesses_assignment( @@ -534,15 +571,17 @@ def test_definition_use_chain_find_forward_accesses_assignment( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start chain from A = 1 - chains = DefinitionUseChain(routine.children[0].lhs) + ref = routine.children[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses() # We should find 3 results, all 3 references in # A = A * A - assert len(reaches) == 3 assignment = routine.walk(Assignment)[1] - assert reaches[0] is assignment.rhs.children[0] - assert reaches[1] is assignment.rhs.children[1] - assert reaches[2] is assignment.lhs + assert len(reaches[sig]) == 3 + assert reaches[sig][0] is assignment.rhs.children[0] + assert reaches[sig][1] is assignment.rhs.children[1] + assert reaches[sig][2] is assignment.lhs def test_definition_use_chain_find_forward_accesses_ifelse_example( @@ -566,23 +605,26 @@ def test_definition_use_chain_find_forward_accesses_ifelse_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from a = 1. - chains = DefinitionUseChain(routine.children[0].lhs) + ref = routine.children[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) reaches = chains.find_forward_accesses() # TODO #2760 For now the if statement doesn't kill the accesses, # even though it will always be written to. - assert len(reaches) == 4 - assert reaches[0] is routine.walk(Assignment)[1].rhs.children[0] - assert reaches[1] is routine.walk(Assignment)[2].lhs - assert reaches[2] is routine.walk(Assignment)[3].lhs - assert reaches[3] is routine.walk(Assignment)[4].rhs.children[0] + assert len(reaches[sig]) == 4 + assert reaches[sig][0] is routine.walk(Assignment)[1].rhs.children[0] + assert reaches[sig][1] is routine.walk(Assignment)[2].lhs + assert reaches[sig][2] is routine.walk(Assignment)[3].lhs + assert reaches[sig][3] is routine.walk(Assignment)[4].rhs.children[0] # Also check that a = 3 forward access is not a = 4. a_3 = routine.children[2].if_body.children[0].lhs + sig = a_3.get_signature_and_indices()[0] a_4 = routine.children[2].else_body.children[0].rhs - chains = DefinitionUseChain(a_3) + chains = DefinitionUseChain([a_3]) reaches = chains.find_forward_accesses() - assert len(reaches) == 1 - assert reaches[0] is not a_4 + assert len(reaches[sig]) == 1 + assert reaches[sig][0] is not a_4 def test_definition_use_chain_find_forward_accesses_loop_example( @@ -605,10 +647,10 @@ def test_definition_use_chain_find_forward_accesses_loop_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from b = A +2. - chains = DefinitionUseChain( - routine.children[1].loop_body.children[1].rhs.children[0] - ) - reaches = chains.find_forward_accesses() + ref = routine.children[1].loop_body.children[1].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] # We should have 3 reaches # First two are A = A + i # Second is c = a + b @@ -635,10 +677,10 @@ def test_definition_use_chain_find_forward_accesses_loop_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from I = 1231. - chains = DefinitionUseChain( - routine.children[0].lhs - ) - reaches = chains.find_forward_accesses() + ref = routine.children[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] # We should have 1 reaches # It should be the loop assert len(reaches) == 1 @@ -663,8 +705,10 @@ def test_definition_use_chain_find_forward_accesses_while_loop_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from A = a + 3. - chains = DefinitionUseChain(routine.walk(Assignment)[2].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[2].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 3 assert reaches[0] is routine.walk(WhileLoop)[0].condition.children[0] @@ -695,8 +739,10 @@ def test_definition_use_chain_forward_accesses_nested_loop_example( routine = psyir.walk(Routine)[0] # Start the chain from b = b + A. loops = routine.walk(WhileLoop) - chains = DefinitionUseChain(loops[1].loop_body.children[0].rhs.children[1]) - reaches = chains.find_forward_accesses() + ref = loops[1].loop_body.children[0].rhs.children[1] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] # Results should be A = A + 3 and the a < i condition assert reaches[0] is loops[0].condition.children[0] assert reaches[1] is loops[0].walk(Assignment)[0].rhs.children[0] @@ -720,8 +766,10 @@ def test_definition_use_chain_find_forward_accesses_structure_example( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Assignment)[2].lhs @@ -739,8 +787,10 @@ def test_definition_use_chain_find_forward_accesses_no_control_flow_example( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].rhs.children[0]) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Assignment)[0].lhs @@ -760,8 +810,10 @@ def test_definition_use_chain_find_forward_accesses_no_control_flow_example2( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].rhs.children[0]) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].rhs.children[0] + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Assignment)[0].lhs @@ -780,8 +832,10 @@ def test_definition_use_chain_find_forward_accesses_codeblock( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(CodeBlock)[0] @@ -801,8 +855,10 @@ def test_definition_use_chain_find_forward_accesses_codeblock_and_call_nlocal( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(CodeBlock)[0] @@ -825,8 +881,10 @@ def test_definition_use_chain_find_forward_accesses_codeblock_and_call_cflow( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 2 assert reaches[0] is routine.walk(CodeBlock)[0] assert reaches[1] is routine.walk(Call)[1] @@ -848,8 +906,10 @@ def test_definition_use_chain_find_forward_accesses_codeblock_and_call_local( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(CodeBlock)[0] @@ -869,8 +929,10 @@ def test_definition_use_chain_find_forward_accesses_call_and_codeblock_nlocal( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 assert reaches[0] is routine.walk(Call)[0] @@ -888,7 +950,7 @@ def test_definition_use_chains_goto_statement( end subroutine""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) + chains = DefinitionUseChain([routine.walk(Assignment)[0].lhs]) with pytest.raises(NotImplementedError) as excinfo: chains.find_forward_accesses() assert ("DefinitionUseChains can't handle code containing GOTO statements" @@ -916,10 +978,10 @@ def test_definition_use_chains_exit_statement( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from A = a +i. - chains = DefinitionUseChain( - routine.walk(Assignment)[1].lhs - ) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] # We should have 3 reaches # First two are A = A + i # Second is c = a + b @@ -956,10 +1018,10 @@ def test_definition_use_chains_cycle_statement( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from A = a +i. - chains = DefinitionUseChain( - routine.walk(Assignment)[1].lhs - ) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] # We should have 4 reaches # First two are A = A + i # Then A = b * 4 @@ -995,10 +1057,10 @@ def test_definition_use_chains_return_statement( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] # Start the chain from A = a +i. - chains = DefinitionUseChain( - routine.walk(Assignment)[1].lhs - ) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] # We should have 4 reaches # First two are A = A + i # Then A = b + 4 @@ -1032,10 +1094,10 @@ def test_definition_use_chains_forward_accesses_multiple_routines( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain( - routine.walk(Assignment)[0].lhs - ) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 0 @@ -1060,10 +1122,10 @@ def test_definition_use_chains_forward_accesses_empty_schedules( """ psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain( - routine.walk(Assignment)[0].lhs - ) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 3 assert reaches[0] is routine.walk(Assignment)[1].rhs.children[0] assert reaches[1] is routine.walk(Assignment)[1].rhs.children[1] @@ -1087,8 +1149,10 @@ def test_definition_use_chains_backward_accesses_inquiry_func( """ psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[0].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[0].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 0 @@ -1113,8 +1177,10 @@ def test_definition_use_chains_multiple_ancestor_loops( end subroutine test""" psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] - chains = DefinitionUseChain(routine.walk(Assignment)[1].lhs) - reaches = chains.find_forward_accesses() + ref = routine.walk(Assignment)[1].lhs + sig = ref.get_signature_and_indices()[0] + chains = DefinitionUseChain(ref) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 3 assert reaches[0] is routine.walk(Assignment)[0].lhs assert reaches[1] is routine.walk(Assignment)[1].lhs @@ -1142,8 +1208,9 @@ def test_definition_use_chain_find_forward_accesses_pure_call( routine = psyir.walk(Routine)[1] # Start from the lhs of the first assignment lhs_assign1 = routine.walk(Assignment)[0].lhs - chains = DefinitionUseChain(lhs_assign1) - reaches = chains.find_forward_accesses() + sig = lhs_assign1.get_signature_and_indices()[0] + chains = DefinitionUseChain([lhs_assign1]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 2 # Should find the a in the rhs of a = a + 2 and the lhs. rhs_assign3 = routine.walk(Assignment)[2].rhs.children[0] @@ -1153,8 +1220,9 @@ def test_definition_use_chain_find_forward_accesses_pure_call( # Start from lhs of b = 1 lhs_assign2 = routine.walk(Assignment)[1].lhs - chains = DefinitionUseChain(lhs_assign2) - reaches = chains.find_forward_accesses() + sig = lhs_assign2.get_signature_and_indices()[0] + chains = DefinitionUseChain([lhs_assign2]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 # result is first argument of the pure subroutine call argument = routine.walk(Call)[0].children[1] @@ -1178,6 +1246,91 @@ def test_forward_accesses_nested_loop(fortran_reader): psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Routine)[0] lhs = routine.walk(Assignment)[0].lhs - chains = DefinitionUseChain(lhs) - reaches = chains.find_forward_accesses() + sig = lhs.get_signature_and_indices()[0] + chains = DefinitionUseChain([lhs]) + reaches = chains.find_forward_accesses()[sig] assert len(reaches) == 1 + + +def test_forward_accesses_multiple_elements(fortran_reader): + """Test that if we have multiple inputs we get multiple outputs as + expected.""" + code = """subroutine x + integer :: i,j,k,l,m + + i = j + k + j = l + m = k + k = i + end subroutine x""" + psyir = fortran_reader.psyir_from_source(code) + routine = psyir.walk(Routine)[0] + assigns = routine.walk(Assignment) + rhs = assigns[0].rhs + chains = DefinitionUseChain([rhs.children[0], + rhs.children[1]]) + reaches = chains.find_forward_accesses() + jsig, _ = rhs.children[0].get_signature_and_indices() + ksig, _ = rhs.children[1].get_signature_and_indices() + + assert len(reaches[jsig]) == 1 + assert reaches[jsig][0] is assigns[1].lhs + assert len(reaches[ksig]) == 2 + assert reaches[ksig][0] is assigns[2].rhs + assert reaches[ksig][1] is assigns[3].lhs + + +def test_if_else_behaviour(fortran_reader): + ''' Test to ensure that with nested if/else statements the DUC + don't find incorrect dependencies from else blocks.''' + + code = """subroutine test + integer , dimension(100, 100, 100) :: some_array + logical :: tmp, tmp2, tmp3 + integer :: i, j, k + + if(tmp) then + if(tmp2) then + do i = 1, 100 + do j = 1, 100 + do k = 1, 100 + if(tmp3 .or. some_array(1,4,9)) then + some_array(1,4,9) = i * k * j + end if + end do + end do + end do + else + do i = 1, 100 + do j = 1, 100 + do k = 1, 100 + some_array(i,k,j) = i + j + k + end do + end do + end do + end if + else if (.not. tmp) then + do i = 1, 100 + do j = 1, 100 + do k = 1, 100 + some_array(i,j,k) = i*k*j + end do + end do + end do + end if + end subroutine""" + + psyir = fortran_reader.psyir_from_source(code) + references = psyir.walk(Reference) + # Find the next accesses of some_array(1,4,9) from + # the lhs of the first assignment. + next_accesses = references[4].next_accesses() + + # The next accesses should be the some_array(1,4,9) access in + # if(tmp3 .or. some_array(1, 4, 9) and to the same reference. + # Both are due to the containing do k = 1, 100 loop. We musn't + # find any accesses to the some_array writes in the else_bodies + # of the containing IfBlock nodes + assert len(next_accesses) == 2 + assert next_accesses[0] is references[3] + assert next_accesses[1] is references[4] diff --git a/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_test.py b/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_test.py new file mode 100644 index 0000000000..a21b252403 --- /dev/null +++ b/src/psyclone/tests/psyir/tools/definition_use_chains_multiref_test.py @@ -0,0 +1,183 @@ +# ----------------------------------------------------------------------------- +# BSD 3-Clause License +# +# Copyright (c) 2026, Science and Technology Facilities Council. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- +# Author: A. B. G. Chalk, STFC Daresbury Lab +# ----------------------------------------------------------------------------- +'''This module contains tests to check that the DefinitionUseChains return +the same result for multiple inputs vs those inputs individually.''' + +import pytest +from psyclone.psyir.nodes import ( + Assignment, + Reference +) +from psyclone.psyir.tools.definition_use_chains import DefinitionUseChain + + +@pytest.mark.parametrize("code", [ + """integer :: a, b + a = b + b = 2 + a + a = 2 + """, + """integer :: a, b, i + do i = 1, 100 + a = b + b = 3 + end do + """, + """integer :: a, b + logical :: x + a = b + if (x) then + a = b + a + else + b = a * b + end if + a = 1 + b = 1 + """, + """integer :: a, b + a = b + b = 2 + if (b > 1) then + a = 1 + end if + a = 2 * b""", + """integer :: a, b + a = b + a + b = 2 + if (b > 1) then + a = 1 + end if + a = 2 * b""", + ]) +def test_duc_forward_equivalence(code, fortran_reader): + '''Test the DUCs give the same results for multiple inputs as each of + the inputs individually. + ''' + code = f"subroutine test\n{code}\nend subroutine test" + psyir = fortran_reader.psyir_from_source(code) + + assign = psyir.walk(Assignment)[0] + + all_refs = assign.walk(Reference) + lhs_ref = all_refs[0] # First Reference is lhs a + rhs_ref = all_refs[1] # Second Reference is always rhs b + assert lhs_ref.symbol.name == "a" + assert rhs_ref.symbol.name == "b" + duc1 = DefinitionUseChain(all_refs) + duc2 = DefinitionUseChain(lhs_ref) + duc3 = DefinitionUseChain(rhs_ref) + res1 = duc1.find_forward_accesses() + res2 = duc2.find_forward_accesses() + res3 = duc3.find_forward_accesses() + + lhs_sig = lhs_ref.get_signature_and_indices()[0] + rhs_sig = rhs_ref.get_signature_and_indices()[0] + + assert len(res2[lhs_sig]) == len(res1[lhs_sig]) + for i, result in enumerate(res2[lhs_sig]): + assert result is res1[lhs_sig][i] + assert len(res3[rhs_sig]) == len(res1[rhs_sig]) + for i, result in enumerate(res3[rhs_sig]): + assert result is res1[rhs_sig][i] + + +@pytest.mark.parametrize("code", [ + """integer :: a, b + a = 2 + b = 2 + a + a = b + """, + """integer :: a, b, i + do i = 1, 100 + b = 3 + a = b + end do + """, + """integer :: a, b + logical :: x + a = 1 + b = 1 + if (x) then + a = b + a + else + b = a * b + end if + a = b + """, + """integer :: a, b + a = 2 * b + if (b > 1) then + a = 1 + end if + b = 2 + a = b""", + """integer :: a, b + a = 2 * b + if (b > 1) then + a = 1 + end if + b = 2 + a = b + a""", + ]) +def test_duc_backward_equivalence(code, fortran_reader): + '''Test the DUCs give the same results for multiple inputs as each + of the inputs individually. + ''' + code = f"subroutine test\n{code}\nend subroutine test" + psyir = fortran_reader.psyir_from_source(code) + + assign = psyir.walk(Assignment)[-1] + + all_refs = assign.walk(Reference) + lhs_ref = all_refs[0] # First Reference is lhs a + rhs_ref = all_refs[1] # Second Reference is always rhs b + assert lhs_ref.symbol.name == "a" + assert rhs_ref.symbol.name == "b" + duc1 = DefinitionUseChain(all_refs) + duc2 = DefinitionUseChain(lhs_ref) + duc3 = DefinitionUseChain(rhs_ref) + res1 = duc1.find_backward_accesses() + res2 = duc2.find_backward_accesses() + res3 = duc3.find_backward_accesses() + + lhs_sig = lhs_ref.get_signature_and_indices()[0] + rhs_sig = rhs_ref.get_signature_and_indices()[0] + assert len(res2[lhs_sig]) == len(res1[lhs_sig]) + for i, result in enumerate(res2[lhs_sig]): + assert result is res1[lhs_sig][i] + assert len(res3[rhs_sig]) == len(res1[rhs_sig]) + for i, result in enumerate(res3[rhs_sig]): + assert result is res1[rhs_sig][i] diff --git a/tutorial/training/transformation/3.6-sympy/solution/dataflow.py b/tutorial/training/transformation/3.6-sympy/solution/dataflow.py index 35511e50b0..4c7ea7ffd1 100755 --- a/tutorial/training/transformation/3.6-sympy/solution/dataflow.py +++ b/tutorial/training/transformation/3.6-sympy/solution/dataflow.py @@ -110,8 +110,9 @@ stop_position = node.ancestor(Statement).abs_position else: stop_position = node.abs_position - chain = DefinitionUseChain(node, stop_point=stop_position) - all_prev = chain.find_backward_accesses() + chain = DefinitionUseChain([node], stop_point=stop_position) + sig = node.get_signature_and_indices()[0] + all_prev = chain.find_backward_accesses()[sig] # Keep track if a write was found (if not, we will add the # variable as a node by itself)