From 29dc3fdc77c1a9a746559f1179a03e110bb7e3e5 Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Sat, 23 Nov 2024 15:19:36 +0100 Subject: [PATCH 01/20] Added changes in 'call' --- src/psyclone/psyir/nodes/call.py | 282 +++++++++++++++--- src/psyclone/psyir/symbols/containersymbol.py | 4 +- .../psyir/transformations/inline_trans.py | 36 +++ 3 files changed, 273 insertions(+), 49 deletions(-) diff --git a/src/psyclone/psyir/nodes/call.py b/src/psyclone/psyir/nodes/call.py index b87fab7285..234c3f69d0 100644 --- a/src/psyclone/psyir/nodes/call.py +++ b/src/psyclone/psyir/nodes/call.py @@ -52,9 +52,12 @@ SymbolError, UnsupportedFortranType, DataSymbol, + SymbolTable, + ContainerSymbol, ) -from typing import List +from typing import List, Union from psyclone.errors import PSycloneError +from psyclone.psyir.symbols.datatypes import ArrayType class CallMatchingArgumentsNotFound(PSycloneError): @@ -273,7 +276,8 @@ def replace_named_arg(self, existing_name, arg): raise ValueError( f"The value of the existing_name argument ({existing_name}) " f"in 'replace_named_arg' in the 'Call' node was not found " - f"in the existing arguments.") + f"in the existing arguments." + ) # The n'th argument is placed at the n'th+1 children position # because the 1st child is the routine reference self.children[index + 1] = arg @@ -458,18 +462,91 @@ def copy(self): return new_copy - def get_callees(self): - ''' + def _get_container_symbols_rec( + self, + container_symbols_list: List[str], + ignore_missing_modules: bool = False, + _stack_container_name_list: List[str] = [], + _depth: int = 0, + ): + """Return a list of all container symbols that can be found + recursively + + :param container_symbols: List of starting set of container symbols + :type container_symbols: List[ContainerSymbol] + :param _stack_container_list: Stack with already visited Containers + to avoid circular searches, defaults to [] + :type _stack_container_list: List[Container], optional + """ + # + # TODO: This function seems to be extremely slow: + # It takes considerable time to build this list over and over + # for each lookup. + # + # An alternative would be to cache it, but then the cache + # needs to be invalidated once some symbols are, e.g., deleted. + # + ret_container_symbol_list = container_symbols_list[:] + + # Cache the container names from symbols + container_names = [cs.name.lower() for cs in container_symbols_list] + + from psyclone.parse import ModuleManager + + module_manager = ModuleManager.get() + + for container_name in container_names: + try: + module_info = module_manager.get_module_info( + container_name.lower() + ) + if module_info is None: + continue + + except (ModuleNotFoundError, FileNotFoundError) as err: + if ignore_missing_modules: + continue + + raise err + + container: Container = module_info.get_psyir_container_node() + + # Avoid circular connections (which shouldn't + # be allowed, but who knows...) + if container.name.lower() in _stack_container_name_list: + continue + + new_container_symbols = self._get_container_symbols_rec( + container_symbols_list=container.symbol_table.containersymbols, + ignore_missing_modules=ignore_missing_modules, + _stack_container_name_list=_stack_container_name_list + + [container.name.lower()], + _depth=_depth + 1, + ) + + # Add symbol if it's not yet in the list of symbols + for container_symbol in new_container_symbols: + if container_symbol not in ret_container_symbol_list: + ret_container_symbol_list.append(container_symbol) + + return ret_container_symbol_list + + def get_callees(self, ignore_missing_modules: bool = False): + """ Searches for the implementation(s) of all potential target routines for this Call without any arguments check. + :param ignore_missing_modules: If a module wasn't found, return 'None' + instead of throwing an exception 'ModuleNotFound'. + :type ignore_missing_modules: bool + :returns: the Routine(s) that this call targets. :rtype: list[:py:class:`psyclone.psyir.nodes.Routine`] :raises NotImplementedError: if the routine is not local and not found in any containers in scope at the call site. - ''' + """ def _location_txt(node): ''' Utility to generate meaningful location text. @@ -506,14 +583,32 @@ def _location_txt(node): # be used to resolve the symbol. wildcard_names = [] containers_not_found = [] - current_table = self.scope.symbol_table + current_table: SymbolTable = self.scope.symbol_table while current_table: + # TODO: Obtaining all container symbols in this way + # breaks some tests. + # It would be better using the ModuleManager to resolve + # (and cache) all containers to look up for this. + # + # current_containersymbols = self._get_container_symbols_rec( + # current_table.containersymbols, + # ignore_missing_modules=ignore_missing_modules, + # ) + # for container_symbol in current_containersymbols: for container_symbol in current_table.containersymbols: + container_symbol: ContainerSymbol if container_symbol.wildcard_import: wildcard_names.append(container_symbol.name) + try: - container = container_symbol.find_container_psyir( - local_node=self) + container: Container = ( + container_symbol.find_container_psyir( + local_node=self, + ignore_missing_modules=( + ignore_missing_modules + ), + ) + ) except SymbolError: container = None if not container: @@ -522,12 +617,21 @@ def _location_txt(node): continue routines = [] for name in container.resolve_routine(rsym.name): - psyir = container.find_routine_psyir(name) + # Allow private imports if an 'interface' + # was used. Here, we assume the name of the routine + # is different to the call. + allow_private = name != rsym.name + psyir = container.find_routine_psyir( + name, allow_private=allow_private + ) + if psyir: routines.append(psyir) + if routines: return routines current_table = current_table.parent_symbol_table() + if not wildcard_names: wc_text = "there are no wildcard imports" else: @@ -612,53 +716,119 @@ def _location_txt(node): f" is within a CodeBlock.") def _check_argument_type_matches( - self, - call_arg: DataSymbol, - routine_arg: DataSymbol, - ) -> bool: + self, + call_arg: DataSymbol, + routine_arg: DataSymbol, + check_strict_array_datatype: bool = True, + ) -> bool: """Return information whether argument types are matching. This also supports 'optional' arguments by using partial types. - :param call_arg: _description_ + :param call_arg: Argument from the call :type call_arg: DataSymbol - :param routine_arg: _description_ + :param routine_arg: Argument from the routine :type routine_arg: DataSymbol - :raises CallMatchingArgumentsNotFound: _description_ - :raises CallMatchingArgumentsNotFound: _description_ + :param check_strict_array_datatype: Check strictly for matching + array types. If `False`, only checks for ArrayType itself are done. + :type check_strict_array_datatype: bool + :returns: True if arguments match, False otherwise + :rtype: bool + :raises CallMatchingArgumentsNotFound: Raised if no matching arguments + were found. """ - if isinstance( - routine_arg.datatype, UnsupportedFortranType - ): - # This could be an 'optional' argument. - # This has at least a partial data type - if ( - call_arg.datatype - != routine_arg.datatype.partial_datatype + + type_matches = False + if not check_strict_array_datatype: + # No strict array checks have to be performed, just accept it + if isinstance(call_arg.datatype, ArrayType) and isinstance( + routine_arg.datatype, ArrayType ): - raise CallMatchingArgumentsNotFound( - f"Argument partial type mismatch of call " - f"argument '{call_arg}' and routine argument " - f"'{routine_arg}'" - ) - else: - if call_arg.datatype != routine_arg.datatype: - raise CallMatchingArgumentsNotFound( - f"Argument type mismatch of call argument " - f"'{call_arg}' and routine argument " - f"'{routine_arg}'" - ) + type_matches = True + + if not type_matches: + if isinstance(routine_arg.datatype, UnsupportedFortranType): + # This could be an 'optional' argument. + # This has at least a partial data type + if call_arg.datatype != routine_arg.datatype.partial_datatype: + raise CallMatchingArgumentsNotFound( + f"Argument partial type mismatch of call " + f"argument '{call_arg}' and routine argument " + f"'{routine_arg}'" + ) + else: + if call_arg.datatype != routine_arg.datatype: + raise CallMatchingArgumentsNotFound( + f"Argument type mismatch of call argument " + f"'{call_arg}' and routine argument " + f"'{routine_arg}'" + ) return True - def _get_argument_routine_match(self, routine: Routine): - '''Return a list of integers giving for each argument of the call + def _check_matching_types( + call_arg: Symbol, + routine_arg: Symbol, + check_strict_array_datatype: bool = True, + check_matching_arguments: bool = True, + ) -> bool: + routine_arg: DataSymbol + + type_matches = False + if not check_strict_array_datatype: + # No strict array checks have to be performed, just accept it + if isinstance(call_arg.datatype, ArrayType) and isinstance( + routine_arg.datatype, ArrayType + ): + type_matches = True + + if not type_matches: + # Do the types of arguments match? + # + # TODO #759: If optional is used, it's an unsupported + # Fortran type and we need to use the following workaround + # Once this issue is resolved, simply remove this if + # branch. + # Optional arguments are processed further down. + if isinstance(routine_arg.datatype, UnsupportedFortranType): + if call_arg.datatype != routine_arg.datatype.partial_datatype: + raise CallMatchingArgumentsNotFound( + f"Argument partial type mismatch of call " + f"argument '{call_arg}' and routine argument " + f"'{routine_arg}'" + ) + else: + if call_arg.datatype != routine_arg.datatype: + raise CallMatchingArgumentsNotFound( + f"Argument type mismatch of call argument " + f"'{call_arg.datatype}' and routine argument " + f"'{routine_arg.datatype}'" + ) + type_matches = True + + def _get_argument_routine_match( + self, + routine: Routine, + check_strict_array_datatype: bool = True, + check_matching_arguments: bool = True, + ) -> Union[None, List[int]]: + """Return a list of integers giving for each argument of the call the index of the argument in argument_list (typically of a routine) + :param check_strict_array_datatype: Strict datatype check for + array types + :type check_strict_array_datatype: bool + + :param check_matching_arguments: If no match is possible, + return the first routine in the list of potential candidates. + :type check_matching_arguments: bool + :return: None if no match was found, otherwise list of integers referring to matching arguments. :rtype: None|List[int] - ''' + :raises CallMatchingArgumentsNotFound: If there was some problem in + finding matching arguments. + """ # Create a copy of the list of actual arguments to the routine. # Once an argument has been successfully matched, set it to 'None' @@ -684,7 +854,9 @@ def _get_argument_routine_match(self, routine: Routine): routine_arg = routine_argument_list[call_arg_idx] routine_arg: DataSymbol - self._check_argument_type_matches(call_arg, routine_arg) + self._check_argument_type_matches( + call_arg, routine_arg, check_strict_array_datatype + ) ret_arg_idx_list.append(call_arg_idx) routine_argument_list[call_arg_idx] = None @@ -706,7 +878,13 @@ def _get_argument_routine_match(self, routine: Routine): continue if arg_name == routine_arg.name: - self._check_argument_type_matches(call_arg, routine_arg) + self._check_argument_type_matches( + call_arg, + routine_arg, + check_strict_array_datatype=( + check_strict_array_datatype + ), + ) ret_arg_idx_list.append(routine_arg_idx) break @@ -742,6 +920,8 @@ def _get_argument_routine_match(self, routine: Routine): def get_callee( self, check_matching_arguments: bool = True, + ignore_missing_modules: bool = False, + ignore_unresolved_symbol: bool = False, ): ''' Searches for the implementation(s) of the target routine for this Call @@ -763,21 +943,27 @@ def get_callee( in any containers in scope at the call site. ''' - routine_list = self.get_callees() + routine_list = self.get_callees( + ignore_missing_modules=ignore_missing_modules + ) error: Exception = None # Search for the routine matching the right arguments - for routine in routine_list: - routine: Routine + for routine_node in routine_list: + routine_node: Routine try: - arg_match_list = self._get_argument_routine_match(routine) + arg_match_list = self._get_argument_routine_match( + routine_node, + check_strict_array_datatype=False, + check_matching_arguments=check_matching_arguments, + ) except CallMatchingArgumentsNotFound as err: error = err continue - return (routine, arg_match_list) + return (routine_node, arg_match_list) # If we didn't find any routine, return some routine if no matching # arguments have been found. @@ -793,5 +979,5 @@ def get_callee( ) from error else: raise NotImplementedError( - f"No matching routine found for " f"'{self.routine.name}'" + f"No matching routine found for '{self.routine.name}'" ) diff --git a/src/psyclone/psyir/symbols/containersymbol.py b/src/psyclone/psyir/symbols/containersymbol.py index 5b734aae74..013b3a4f53 100644 --- a/src/psyclone/psyir/symbols/containersymbol.py +++ b/src/psyclone/psyir/symbols/containersymbol.py @@ -125,7 +125,9 @@ def copy(self): new_symbol.is_intrinsic = self.is_intrinsic return new_symbol - def find_container_psyir(self, local_node=None): + def find_container_psyir( + self, local_node=None, ignore_missing_modules: bool = False + ): ''' Searches for the Container that this Symbol refers to. If it is not available, use the interface to import the container. If `local_node` is supplied then the PSyIR tree below it is searched for diff --git a/src/psyclone/psyir/transformations/inline_trans.py b/src/psyclone/psyir/transformations/inline_trans.py index d44df08d74..f353d94a49 100644 --- a/src/psyclone/psyir/transformations/inline_trans.py +++ b/src/psyclone/psyir/transformations/inline_trans.py @@ -53,6 +53,12 @@ from psyclone.psyir.transformations.transformation_error import ( TransformationError) +from typing import List + +# from typing import Dict, List +# from psyclone.psyir.symbols import BOOLEAN_TYPE +# from psyclone.psyir.symbols import ScalarType + _ONE = Literal("1", INTEGER_TYPE) @@ -122,6 +128,36 @@ class InlineTrans(Transformation): Some of these restrictions will be lifted by #924. ''' + + def __init__(self): + # List of call-to-subroutine argument indices + self._ret_arg_match_list: List[int] = None + + # Routine to be inlines + self.node_routine: Routine = None + + # Make strict checks for matching arguments of array data types. + # If disabled, it's sufficient that both arguments are of ArrayType. + # Then, no further checks are performed + self.option_check_argument_strict_array_datatype: bool = True + + # If searching for modules, don't trigger Exceptions if module + # wasn't found. + self.ignore_missing_modules: bool = False + + def set_option( + self, + check_argument_strict_array_datatype: bool = None, + ignore_missing_modules: bool = None, + ): + if check_argument_strict_array_datatype is not None: + self.option_check_argument_strict_array_datatype = ( + check_argument_strict_array_datatype + ) + + if ignore_missing_modules is not None: + self.ignore_missing_modules = ignore_missing_modules + def apply(self, node, options=None): ''' Takes the body of the routine that is the target of the supplied From 23592fba195c5d350d2807c59fbc656adcfe7a8a Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Sat, 23 Nov 2024 15:35:25 +0100 Subject: [PATCH 02/20] updates for documentation --- doc/Makefile | 12 ++++++++++++ src/psyclone/psyir/nodes/call.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 doc/Makefile diff --git a/doc/Makefile b/doc/Makefile new file mode 100644 index 0000000000..9f3cc35e4e --- /dev/null +++ b/doc/Makefile @@ -0,0 +1,12 @@ + +all: + make -C developer_guide html SPHINXOPTS="-W --keep-going" + make -C developer_guide linkcheck || echo "Ignoring error of link checking" + make -C reference_guide html SPHINXOPTS="-W --keep-going" + make -C user_guide html SPHINXOPTS="-W --keep-going" + +clean: + make -C developer_guide clean + make -C reference_guide allclean + make -C user_guide clean + diff --git a/src/psyclone/psyir/nodes/call.py b/src/psyclone/psyir/nodes/call.py index 234c3f69d0..43e3870073 100644 --- a/src/psyclone/psyir/nodes/call.py +++ b/src/psyclone/psyir/nodes/call.py @@ -537,7 +537,7 @@ def get_callees(self, ignore_missing_modules: bool = False): for this Call without any arguments check. :param ignore_missing_modules: If a module wasn't found, return 'None' - instead of throwing an exception 'ModuleNotFound'. + instead of throwing an exception 'ModuleNotFound'. :type ignore_missing_modules: bool :returns: the Routine(s) that this call targets. From abef8fd3e1997a900b197c833c600035357751d2 Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Sun, 24 Nov 2024 00:57:27 +0100 Subject: [PATCH 03/20] u --- src/psyclone/psyir/nodes/__init__.py | 3 +- src/psyclone/psyir/nodes/call.py | 38 +- src/psyclone/psyir/symbols/symbol_table.py | 8 +- .../psyir/transformations/inline_trans.py | 358 ++-- .../tests/my_shortcut_tests/call_test.py | 1 + .../my_shortcut_tests/inline_trans_test.py | 1 + .../transformations/inline_trans_test.py | 1566 ++++++++++------- 7 files changed, 1157 insertions(+), 818 deletions(-) create mode 120000 src/psyclone/tests/my_shortcut_tests/call_test.py create mode 120000 src/psyclone/tests/my_shortcut_tests/inline_trans_test.py diff --git a/src/psyclone/psyir/nodes/__init__.py b/src/psyclone/psyir/nodes/__init__.py index b43f98f751..ed0d94ae35 100644 --- a/src/psyclone/psyir/nodes/__init__.py +++ b/src/psyclone/psyir/nodes/__init__.py @@ -74,7 +74,7 @@ from psyclone.psyir.nodes.statement import Statement from psyclone.psyir.nodes.structure_reference import StructureReference from psyclone.psyir.nodes.structure_member import StructureMember -from psyclone.psyir.nodes.call import Call +from psyclone.psyir.nodes.call import Call, CallMatchingArgumentsNotFound from psyclone.psyir.nodes.file_container import FileContainer from psyclone.psyir.nodes.directive import ( Directive, StandaloneDirective, RegionDirective) @@ -112,6 +112,7 @@ 'Assignment', 'BinaryOperation', 'Call', + "CallMatchingArgumentsNotFound", 'Clause', 'CodeBlock', 'Container', diff --git a/src/psyclone/psyir/nodes/call.py b/src/psyclone/psyir/nodes/call.py index 43e3870073..8633629b3f 100644 --- a/src/psyclone/psyir/nodes/call.py +++ b/src/psyclone/psyir/nodes/call.py @@ -275,8 +275,8 @@ def replace_named_arg(self, existing_name, arg): else: raise ValueError( f"The value of the existing_name argument ({existing_name}) " - f"in 'replace_named_arg' in the 'Call' node was not found " - f"in the existing arguments." + "in 'replace_named_arg' in the 'Call' node was not found " + "in the existing arguments." ) # The n'th argument is placed at the n'th+1 children position # because the 1st child is the routine reference @@ -759,9 +759,10 @@ def _check_argument_type_matches( else: if call_arg.datatype != routine_arg.datatype: raise CallMatchingArgumentsNotFound( - f"Argument type mismatch of call argument " - f"'{call_arg}' and routine argument " - f"'{routine_arg}'" + "Argument type mismatch of call argument " + f"'{call_arg}' with type '{call_arg.datatype} " + "and routine argument " + f"'{routine_arg}' with type '{routine_arg.datatype}." ) return True @@ -810,7 +811,6 @@ def _get_argument_routine_match( self, routine: Routine, check_strict_array_datatype: bool = True, - check_matching_arguments: bool = True, ) -> Union[None, List[int]]: """Return a list of integers giving for each argument of the call the index of the argument in argument_list (typically of a routine) @@ -920,6 +920,7 @@ def _get_argument_routine_match( def get_callee( self, check_matching_arguments: bool = True, + check_strict_array_datatype: bool = True, ignore_missing_modules: bool = False, ignore_unresolved_symbol: bool = False, ): @@ -947,7 +948,12 @@ def get_callee( ignore_missing_modules=ignore_missing_modules ) - error: Exception = None + if len(routine_list) == 0: + raise NotImplementedError( + f"No routine or interface found for name '{self.routine.name}'" + ) + + err_info = [] # Search for the routine matching the right arguments for routine_node in routine_list: @@ -956,11 +962,10 @@ def get_callee( try: arg_match_list = self._get_argument_routine_match( routine_node, - check_strict_array_datatype=False, - check_matching_arguments=check_matching_arguments, + check_strict_array_datatype=check_strict_array_datatype, ) except CallMatchingArgumentsNotFound as err: - error = err + err_info.append(err.value) continue return (routine_node, arg_match_list) @@ -973,11 +978,8 @@ def get_callee( # Also return a list of dummy argument indices return (routine_list[0], [i for i in range(len(self.arguments))]) - if error is not None: - raise CallMatchingArgumentsNotFound( - f"No matching routine found for '{self.debug_string()}'" - ) from error - else: - raise NotImplementedError( - f"No matching routine found for '{self.routine.name}'" - ) + error_msg = "\n".join(err_info) + raise CallMatchingArgumentsNotFound( + f"No matching routine found for '{self.debug_string()}'" + + error_msg + ) diff --git a/src/psyclone/psyir/symbols/symbol_table.py b/src/psyclone/psyir/symbols/symbol_table.py index 76335ec635..368ad214c8 100644 --- a/src/psyclone/psyir/symbols/symbol_table.py +++ b/src/psyclone/psyir/symbols/symbol_table.py @@ -585,7 +585,9 @@ def add(self, new_symbol, tag=None): self._symbols[key] = new_symbol - def check_for_clashes(self, other_table, symbols_to_skip=()): + def check_for_clashes( + self, other_table, symbols_to_skip=(), check_unresolved_symbols=True + ): ''' Checks the symbols in the supplied table against those in this table. If there is a name clash that cannot be resolved by @@ -648,6 +650,10 @@ def check_for_clashes(self, other_table, symbols_to_skip=()): f"table imports it via '{other_sym.interface}'.") continue + if not check_unresolved_symbols: + # Skip if unresolved symbols shouldn't be checked + continue + if other_sym.is_unresolved and this_sym.is_unresolved: # Both Symbols are unresolved. if shared_wildcard_imports and not unique_wildcard_imports: diff --git a/src/psyclone/psyir/transformations/inline_trans.py b/src/psyclone/psyir/transformations/inline_trans.py index f353d94a49..1756a78afa 100644 --- a/src/psyclone/psyir/transformations/inline_trans.py +++ b/src/psyclone/psyir/transformations/inline_trans.py @@ -53,9 +53,9 @@ from psyclone.psyir.transformations.transformation_error import ( TransformationError) -from typing import List +from psyclone.psyir.nodes import CallMatchingArgumentsNotFound +from typing import Dict, List -# from typing import Dict, List # from psyclone.psyir.symbols import BOOLEAN_TYPE # from psyclone.psyir.symbols import ScalarType @@ -136,50 +136,89 @@ def __init__(self): # Routine to be inlines self.node_routine: Routine = None - # Make strict checks for matching arguments of array data types. + # If 'True', make strict checks for matching arguments of + # array data types. # If disabled, it's sufficient that both arguments are of ArrayType. # Then, no further checks are performed - self.option_check_argument_strict_array_datatype: bool = True + self._option_check_argument_strict_array_datatype: bool = True # If searching for modules, don't trigger Exceptions if module # wasn't found. - self.ignore_missing_modules: bool = False + self._option_ignore_missing_modules: bool = False + + # If 'True', don't inline if a code block is used within the + # Routine. + self._option_check_codeblocks: bool = True + + # If 'True', the callee must have matching arguments. + # The 'matching' criteria can be weakened by other options. + # If 'False', in case no match was found, the first callee is taken. + self._option_check_matching_arguments_of_callee: bool = True + + # check_diff_container_clashes: bool = True, + # check_diff_container_clashes_unresolved_types: bool = True, + # check_resolve_imports: bool = True, + # check_static_interface: bool = True, + # check_array_type: bool = True, + # check_argument_of_unsupported_type: bool = True, + # check_argument_unresolved_symbols: bool = True, def set_option( self, - check_argument_strict_array_datatype: bool = None, ignore_missing_modules: bool = None, + check_argument_strict_array_datatype: bool = None, + check_codeblocks: bool = None, + check_matching_arguments_of_callee: bool = None, ): if check_argument_strict_array_datatype is not None: - self.option_check_argument_strict_array_datatype = ( + self._option_check_argument_strict_array_datatype = ( check_argument_strict_array_datatype ) if ignore_missing_modules is not None: - self.ignore_missing_modules = ignore_missing_modules + self._option_ignore_missing_modules = ignore_missing_modules - def apply(self, node, options=None): - ''' + if check_codeblocks is not None: + self._option_check_codeblocks = check_codeblocks + + if check_matching_arguments_of_callee is not None: + self._option_check_matching_arguments_of_callee = ( + check_matching_arguments_of_callee + ) + + def apply( + self, node_call: Call, node_routine: Routine = None, options=None + ): + """ Takes the body of the routine that is the target of the supplied call and replaces the call with it. - :param node: target PSyIR node. - :type node: :py:class:`psyclone.psyir.nodes.Routine` + :param call_node: target PSyIR node. + :type call_node: :py:class:`psyclone.psyir.nodes.Call` + :param routine: PSyIR subroutine to be inlined. + Default: Automatically determine subroutine (search) + :type routine: :py:class:`psyclone.psyir.nodes.Routine` :param options: a dictionary with options for transformations. :type options: Optional[Dict[str, Any]] :param bool options["force"]: whether or not to permit the inlining of Routines containing CodeBlocks. Default is False. - ''' - self.validate(node, options) + """ + + # Validate that the inlining can also be accomplish. + # This routine will also update + # self.node_routine and self._ret_arg_match_list + # with the routine to be inlined and the relation between the + # arguments and to which routine arguments they are matched to. + self.validate(node_call, node_routine=node_routine, options=options) # The table associated with the scoping region holding the Call. - table = node.scope.symbol_table + table = node_call.scope.symbol_table # Find the routine to be inlined. - orig_routine = node.get_callees()[0] + orig_routine = node_call.get_callees()[0] if not orig_routine.children or isinstance(orig_routine.children[0], Return): # Called routine is empty so just remove the call. - node.detach() + node_call.detach() return # Ensure we don't modify the original Routine by working with a @@ -207,7 +246,7 @@ def apply(self, node, options=None): # as a Reference. ref2arraytrans = Reference2ArrayRangeTrans() - for child in node.arguments: + for child in node_call.arguments: try: # TODO #1858, this won't yet work for arrays inside structures. ref2arraytrans.apply(child) @@ -218,12 +257,12 @@ def apply(self, node, options=None): # actual arguments. formal_args = routine_table.argument_list for ref in refs[:]: - self._replace_formal_arg(ref, node, formal_args) + self._replace_formal_arg(ref, node_call, formal_args) # Store the Routine level symbol table and node's current scope # so we can merge symbol tables later if required. - ancestor_table = node.ancestor(Routine).scope.symbol_table - scope = node.scope + ancestor_table = node_call.ancestor(Routine).scope.symbol_table + scope = node_call.scope # Copy the nodes from the Routine into the call site. # TODO #924 - while doing this we should ensure that any References @@ -236,7 +275,7 @@ def apply(self, node, options=None): if routine.return_symbol: # This is a function - assignment = node.ancestor(Statement) + assignment = node_call.ancestor(Statement) parent = assignment.parent idx = assignment.position-1 for child in new_stmts: @@ -247,12 +286,12 @@ def apply(self, node, options=None): table.rename_symbol( routine.return_symbol, table.next_available_name( f"inlined_{routine.return_symbol.name}")) - node.replace_with(Reference(routine.return_symbol)) + node_call.replace_with(Reference(routine.return_symbol)) else: # This is a call - parent = node.parent - idx = node.position - node.replace_with(new_stmts[0]) + parent = node_call.parent + idx = node_call.position + node_call.replace_with(new_stmts[0]) for child in new_stmts[1:]: idx += 1 parent.addchild(child, idx) @@ -615,12 +654,20 @@ def _replace_formal_struc_arg(self, actual_arg, ref, call_node, # Just an array reference. return ArrayReference.create(actual_arg.symbol, members[0][1]) - def validate(self, node, options=None): - ''' + def validate( + self, + node_call: Call, + node_routine: Routine = None, + options: Dict[str, str] = None, + ): + """ Checks that the supplied node is a valid target for inlining. - :param node: target PSyIR node. - :type node: subclass of :py:class:`psyclone.psyir.nodes.Call` + :param call_node: target PSyIR node. + :type call_node: subclass of :py:class:`psyclone.psyir.nodes.Call` + :param routine_node: Routine to inline. + Default is to search for it. + :type routine_node: subclass of :py:class:`Routine` :param options: a dictionary with options for transformations. :type options: Optional[Dict[str, Any]] :param bool options["force"]: whether or not to ignore any CodeBlocks @@ -652,65 +699,99 @@ def validate(self, node, options=None): :raises TransformationError: if the shape of an array formal argument does not match that of the corresponding actual argument. - ''' - super().validate(node, options=options) + """ + super().validate(node_call, options=options) - options = {} if options is None else options - forced = options.get("force", False) + self.node_routine: Routine = node_routine # The node should be a Call. - if not isinstance(node, Call): + if not isinstance(node_call, Call): raise TransformationError( - f"The target of the InlineTrans transformation " - f"should be a Call but found '{type(node).__name__}'.") + "The target of the InlineTrans transformation " + f"should be a Call but found '{type(node_call).__name__}'." + ) - if isinstance(node, IntrinsicCall): + if isinstance(node_call, IntrinsicCall): raise TransformationError( - f"Cannot inline an IntrinsicCall ('{node.routine.name}')") - name = node.routine.name + f"Cannot inline an IntrinsicCall ('{node_call.routine.name}')" + ) + name = node_call.routine.name - # Check that we can find the source of the routine being inlined. - # TODO #924 allow for multiple routines (interfaces). - try: - routine = node.get_callees()[0] - except (NotImplementedError, FileNotFoundError, SymbolError) as err: - raise TransformationError( - f"Cannot inline routine '{name}' because its source cannot be " - f"found: {err}") from err + # List of indices relating the call's arguments to the subroutine + # arguments. This can be different due to + # - optional arguments + # - named arguments + + if self.node_routine is None: + # Check that we can find the source of the routine being inlined. + # TODO #924 allow for multiple routines (interfaces). + try: + self.node_routine = node_call.get_callees()[0] + except ( + CallMatchingArgumentsNotFound, + NotImplementedError, + FileNotFoundError, + SymbolError, + ) as err: + raise TransformationError( + f"Cannot inline routine '{name}' because its source cannot" + f" be found:\n{str(err)}" + ) from err + + else: + # A routine has been provided. + # We'll now determine the matching argument list + try: + self._ret_arg_match_list = ( + node_call._get_argument_routine_match( + self.node_routine, + check_strict_array_datatype=False, + ) + ) + except CallMatchingArgumentsNotFound as err: + raise TransformationError( + "Routine's arguments doesn't match subroutine" + ) from err - if not routine.children or isinstance(routine.children[0], Return): + if not self.node_routine.children or isinstance( + self.node_routine.children[0], Return + ): # An empty routine is fine. return - return_stmts = routine.walk(Return) + return_stmts = self.node_routine.walk(Return) if return_stmts: - if len(return_stmts) > 1 or not isinstance(routine.children[-1], - Return): + if len(return_stmts) > 1 or not isinstance( + self.node_routine.children[-1], Return + ): # Either there is more than one Return statement or there is # just one but it isn't the last statement of the Routine. raise TransformationError( f"Routine '{name}' contains one or more " f"Return statements and therefore cannot be inlined.") - if routine.walk(CodeBlock) and not forced: - # N.B. we permit the user to specify the "force" option to allow - # CodeBlocks to be included. - raise TransformationError( - f"Routine '{name}' contains one or more CodeBlocks and " - "therefore cannot be inlined. (If you are confident that " - "the code may safely be inlined despite this then use " - "`options={'force': True}` to override.)") + if self._option_check_codeblocks: + if self.node_routine.walk(CodeBlock): + # N.B. we permit the user to specify the "force" option to + # allow CodeBlocks to be included. + raise TransformationError( + f"Routine '{name}' contains one or more CodeBlocks and " + "therefore cannot be inlined. (If you are confident that " + "the code may safely be inlined despite this then use " + "`check_codeblocks=False` to override.)" + ) # Support for routines with named arguments is not yet implemented. # TODO #924. - for arg in node.argument_names: + for arg in node_call.argument_names: if arg: raise TransformationError( - f"Routine '{routine.name}' cannot be inlined because it " - f"has a named argument '{arg}' (TODO #924).") + f"Routine '{self.node_routine.name}' cannot be inlined" + f" because it has a named argument '{arg}' (TODO #924)." + ) - table = node.scope.symbol_table - routine_table = routine.symbol_table + table = node_call.scope.symbol_table + routine_table = self.node_routine.symbol_table for sym in routine_table.datasymbols: # We don't inline symbols that have an UnsupportedType and are @@ -719,25 +800,28 @@ def validate(self, node, options=None): if isinstance(sym.interface, ArgumentInterface): if isinstance(sym.datatype, UnsupportedType): raise TransformationError( - f"Routine '{routine.name}' cannot be inlined because " - f"it contains a Symbol '{sym.name}' which is an " - f"Argument of UnsupportedType: " - f"'{sym.datatype.declaration}'") + f"Routine '{self.node_routine.name}' cannot be inlined" + f" because it contains a Symbol '{sym.name}' which is" + " an Argument of UnsupportedType:" + f" '{sym.datatype.declaration}'" + ) # We don't inline symbols that have an UnknownInterface, as we # don't know how they are brought into this scope. if isinstance(sym.interface, UnknownInterface): raise TransformationError( - f"Routine '{routine.name}' cannot be inlined because it " - f"contains a Symbol '{sym.name}' with an UnknownInterface:" - f" '{sym.datatype.declaration}'") + f"Routine '{self.node_routine.name}' cannot be inlined" + f" because it contains a Symbol '{sym.name}' with an" + f" UnknownInterface: '{sym.datatype.declaration}'" + ) # Check that there are no static variables in the routine (because # we don't know whether the routine is called from other places). if (isinstance(sym.interface, StaticInterface) and not sym.is_constant): raise TransformationError( - f"Routine '{routine.name}' cannot be inlined because it " - f"has a static (Fortran SAVE) interface for Symbol " - f"'{sym.name}'.") + f"Routine '{self.node_routine.name}' cannot be inlined" + " because it has a static (Fortran SAVE) interface for" + f" Symbol '{sym.name}'." + ) # We can't handle a clash between (apparently) different symbols that # share a name but are imported from different containers. @@ -747,8 +831,9 @@ def validate(self, node, options=None): symbols_to_skip=routine_table.argument_list[:]) except SymbolError as err: raise TransformationError( - f"One or more symbols from routine '{routine.name}' cannot be " - f"added to the table at the call site.") from err + f"One or more symbols from routine '{self.node_routine.name}'" + " cannot be added to the table at the call site." + ) from err # Check for unresolved symbols or for any accessed from the Container # containing the target routine. @@ -758,7 +843,7 @@ def validate(self, node, options=None): # that are used to define the precision of other Symbols in the same # table. If a precision symbol is only used within Statements then we # don't currently capture the fact that it is a precision symbol. - ref_or_lits = routine.walk((Reference, Literal)) + ref_or_lits = self.node_routine.walk((Reference, Literal)) # Check for symbols in any initial-value expressions # (including Fortran parameters) or array dimensions. for sym in routine_table.datasymbols: @@ -799,29 +884,37 @@ def validate(self, node, options=None): # table local to the routine. # pylint: disable=raise-missing-from raise TransformationError( - f"Routine '{routine.name}' cannot be inlined " - f"because it accesses variable '{sym.name}' and this " - f"cannot be found in any of the containers directly " - f"imported into its symbol table.") + f"Routine '{self.node_routine.name}' cannot be inlined" + f" because it accesses variable '{sym.name}' and this" + " cannot be found in any of the containers directly" + " imported into its symbol table." + ) else: if sym.name not in routine_table: raise TransformationError( - f"Routine '{routine.name}' cannot be inlined because " - f"it accesses variable '{sym.name}' from its " - f"parent container.") + f"Routine '{self.node_routine.name}' cannot be inlined" + f" because it accesses variable '{sym.name}' from its" + " parent container." + ) # Check that the shapes of any formal array arguments are the same as # those at the call site. - if len(routine_table.argument_list) != len(node.arguments): - raise TransformationError(LazyString( - lambda: f"Cannot inline '{node.debug_string().strip()}' " - f"because the number of arguments supplied to the call " - f"({len(node.arguments)}) does not match the number of " - f"arguments the routine is declared to have " - f"({len(routine_table.argument_list)}).")) - - for formal_arg, actual_arg in zip(routine_table.argument_list, - node.arguments): + if len(routine_table.argument_list) != len(node_call.arguments): + raise TransformationError( + LazyString( + lambda: ( + f"Cannot inline '{node_call.debug_string().strip()}'" + " because the number of arguments supplied to the" + f" call ({len(node_call.arguments)}) does not match" + " the number of arguments the routine is declared to" + f" have ({len(routine_table.argument_list)})." + ) + ) + ) + + for formal_arg, actual_arg in zip( + routine_table.argument_list, node_call.arguments + ): # If the formal argument is an array with non-default bounds then # we also need to know the bounds of that array at the call site. if not isinstance(formal_arg.datatype, ArrayType): @@ -837,12 +930,17 @@ def validate(self, node, options=None): # Reference or a Literal as we don't know whether the result # of any general expression is or is not an array. # pylint: disable=cell-var-from-loop - raise TransformationError(LazyString( - lambda: f"The call '{node.debug_string()}' cannot be " - f"inlined because actual argument " + raise TransformationError( + LazyString( + lambda: ( + f"The call '{node_call.debug_string()}' cannot be " + "inlined because actual argument " f"'{actual_arg.debug_string()}' corresponds to a " - f"formal argument with array type but is not a " - f"Reference or a Literal.")) + "formal argument with array type but is not a " + "Reference or a Literal." + ) + ) + ) # We have an array argument. We are only able to check that the # argument is not re-shaped in the called routine if we have full @@ -856,10 +954,11 @@ def validate(self, node, options=None): isinstance(actual_arg.datatype.intrinsic, (UnresolvedType, UnsupportedType)))): raise TransformationError( - f"Routine '{routine.name}' cannot be inlined because " - f"the type of the actual argument " - f"'{actual_arg.symbol.name}' corresponding to an array" - f" formal argument ('{formal_arg.name}') is unknown.") + f"Routine '{self.node_routine.name}' cannot be inlined" + " because the type of the actual argument" + f" '{actual_arg.symbol.name}' corresponding to an array" + f" formal argument ('{formal_arg.name}') is unknown." + ) formal_rank = 0 actual_rank = 0 @@ -872,12 +971,18 @@ def validate(self, node, options=None): # because if we get to this point then we're going to quit # the loop. # pylint: disable=cell-var-from-loop - raise TransformationError(LazyString( - lambda: f"Cannot inline routine '{routine.name}' " - f"because it reshapes an argument: actual argument " - f"'{actual_arg.debug_string()}' has rank {actual_rank}" - f" but the corresponding formal argument, " - f"'{formal_arg.name}', has rank {formal_rank}")) + raise TransformationError( + LazyString( + lambda: ( + f"Cannot inline routine '{self.node_routine.name}'" + " because it reshapes an argument: actual" + f" argument '{actual_arg.debug_string()}' has rank" + f" {actual_rank} but the corresponding formal" + f" argument, '{formal_arg.name}', has rank" + f" {formal_rank}" + ) + ) + ) if actual_rank: ranges = actual_arg.walk(Range) for rge in ranges: @@ -885,20 +990,33 @@ def validate(self, node, options=None): if ancestor_ref is not actual_arg: # Have a range in an indirect access. # pylint: disable=cell-var-from-loop - raise TransformationError(LazyString( - lambda: f"Cannot inline routine '{routine.name}' " - f"because argument '{actual_arg.debug_string()}' " - f"has an array range in an indirect access (TODO " - f"#924).")) + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self.node_routine.name}' because" + f" argument '{actual_arg.debug_string()}'" + " has an array range in an indirect" + " access (TODO #924)." + ) + ) + ) if rge.step != _ONE: # TODO #1646. We could resolve this problem by making # a new array and copying the necessary values into it. # pylint: disable=cell-var-from-loop - raise TransformationError(LazyString( - lambda: f"Cannot inline routine '{routine.name}' " - f"because one of its arguments is an array slice " - f"with a non-unit stride: " - f"'{actual_arg.debug_string()}' (TODO #1646)")) + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self.node_routine.name}' because one" + " of its arguments is an array slice with" + " a non-unit stride:" + f" '{actual_arg.debug_string()}' (TODO" + " #1646)" + ) + ) + ) # For AutoAPI auto-documentation generation. diff --git a/src/psyclone/tests/my_shortcut_tests/call_test.py b/src/psyclone/tests/my_shortcut_tests/call_test.py new file mode 120000 index 0000000000..61527f250c --- /dev/null +++ b/src/psyclone/tests/my_shortcut_tests/call_test.py @@ -0,0 +1 @@ +../psyir/nodes/call_test.py \ No newline at end of file diff --git a/src/psyclone/tests/my_shortcut_tests/inline_trans_test.py b/src/psyclone/tests/my_shortcut_tests/inline_trans_test.py new file mode 120000 index 0000000000..9f34f83693 --- /dev/null +++ b/src/psyclone/tests/my_shortcut_tests/inline_trans_test.py @@ -0,0 +1 @@ +../psyir/transformations/inline_trans_test.py \ No newline at end of file diff --git a/src/psyclone/tests/psyir/transformations/inline_trans_test.py b/src/psyclone/tests/psyir/transformations/inline_trans_test.py index e7efd1c172..259e581935 100644 --- a/src/psyclone/tests/psyir/transformations/inline_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/inline_trans_test.py @@ -34,8 +34,7 @@ # Author: A. R. Porter, STFC Daresbury Lab # Modified: R. W. Ford and S. Siso, STFC Daresbury Lab -'''This module tests the inlining transformation. -''' +"""This module tests the inlining transformation.""" import os import pytest @@ -43,42 +42,48 @@ from psyclone.configuration import Config from psyclone.psyir.nodes import Call, IntrinsicCall, Reference, Routine, Loop from psyclone.psyir.symbols import ( - AutomaticInterface, DataSymbol, UnresolvedType) -from psyclone.psyir.transformations import ( - InlineTrans, TransformationError) + AutomaticInterface, + DataSymbol, + UnresolvedType, +) +from psyclone.psyir.transformations import InlineTrans, TransformationError from psyclone.tests.utilities import Compile -MY_TYPE = (" integer, parameter :: ngrids = 10\n" - " type other_type\n" - " real, dimension(10) :: data\n" - " integer :: nx\n" - " end type other_type\n" - " type my_type\n" - " integer :: idx\n" - " real, dimension(10) :: data\n" - " real, dimension(5,10) :: data2d\n" - " type(other_type) :: local\n" - " end type my_type\n" - " type big_type\n" - " type(my_type) :: region\n" - " end type big_type\n" - " type vbig_type\n" - " type(big_type), dimension(ngrids) :: grids\n" - " end type vbig_type\n") +MY_TYPE = ( + " integer, parameter :: ngrids = 10\n" + " type other_type\n" + " real, dimension(10) :: data\n" + " integer :: nx\n" + " end type other_type\n" + " type my_type\n" + " integer :: idx\n" + " real, dimension(10) :: data\n" + " real, dimension(5,10) :: data2d\n" + " type(other_type) :: local\n" + " end type my_type\n" + " type big_type\n" + " type(my_type) :: region\n" + " end type big_type\n" + " type vbig_type\n" + " type(big_type), dimension(ngrids) :: grids\n" + " end type vbig_type\n" +) # init + def test_init(): - '''Test an InlineTrans transformation can be successfully created.''' + """Test an InlineTrans transformation can be successfully created.""" inline_trans = InlineTrans() assert isinstance(inline_trans, InlineTrans) # apply + def test_apply_empty_routine(fortran_reader, fortran_writer, tmpdir): - '''Check that a call to an empty routine is simply removed.''' + """Check that a call to an empty routine is simply removed.""" code = ( "module test_mod\n" "contains\n" @@ -90,20 +95,20 @@ def test_apply_empty_routine(fortran_reader, fortran_writer, tmpdir): " subroutine sub(idx)\n" " integer :: idx\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) - assert (" i = 10\n\n" - " end subroutine run_it\n" in output) + assert " i = 10\n\n end subroutine run_it\n" in output assert Compile(tmpdir).string_compiles(output) def test_apply_single_return(fortran_reader, fortran_writer, tmpdir): - '''Check that a call to a routine containing only a return statement - is removed. ''' + """Check that a call to a routine containing only a return statement + is removed.""" code = ( "module test_mod\n" "contains\n" @@ -116,20 +121,20 @@ def test_apply_single_return(fortran_reader, fortran_writer, tmpdir): " integer :: idx\n" " return\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) - assert (" i = 10\n\n" - " end subroutine run_it\n" in output) + assert " i = 10\n\n end subroutine run_it\n" in output assert Compile(tmpdir).string_compiles(output) def test_apply_return_then_cb(fortran_reader, fortran_writer, tmpdir): - '''Check that a call to a routine containing a return statement followed - by a CodeBlock is removed.''' + """Check that a call to a routine containing a return statement followed + by a CodeBlock is removed.""" code = ( "module test_mod\n" "contains\n" @@ -143,20 +148,20 @@ def test_apply_return_then_cb(fortran_reader, fortran_writer, tmpdir): " return\n" " write(*,*) idx\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) - assert (" i = 10\n\n" - " end subroutine run_it\n" in output) + assert " i = 10\n\n end subroutine run_it\n" in output assert Compile(tmpdir).string_compiles(output) def test_apply_array_arg(fortran_reader, fortran_writer, tmpdir): - ''' Check that the apply() method works correctly for a very simple - call to a routine with an array reference as argument. ''' + """Check that the apply() method works correctly for a very simple + call to a routine with an array reference as argument.""" code = ( "module test_mod\n" "contains\n" @@ -172,25 +177,29 @@ def test_apply_array_arg(fortran_reader, fortran_writer, tmpdir): " real, intent(inout) :: x\n" " x = 2.0*x\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) - assert (" do i = 1, 10, 1\n" - " a(i) = 1.0\n" - " a(i) = 2.0 * a(i)\n" - " enddo\n" in output) + assert ( + " do i = 1, 10, 1\n" + " a(i) = 1.0\n" + " a(i) = 2.0 * a(i)\n" + " enddo\n" + in output + ) assert Compile(tmpdir).string_compiles(output) def test_apply_array_access(fortran_reader, fortran_writer, tmpdir): - ''' + """ Check that the apply method works correctly when an array is passed into the routine and then indexed within it. - ''' + """ code = ( "module test_mod\n" "contains\n" @@ -209,27 +218,31 @@ def test_apply_array_access(fortran_reader, fortran_writer, tmpdir): " x(i) = 2.0*ivar\n" " end do\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) - assert (" do i = 1, 10, 1\n" - " do i_1 = 1, 10, 1\n" - " a(i_1) = 2.0 * i\n" - " enddo\n" in output) + assert ( + " do i = 1, 10, 1\n" + " do i_1 = 1, 10, 1\n" + " a(i_1) = 2.0 * i\n" + " enddo\n" + in output + ) assert Compile(tmpdir).string_compiles(output) def test_apply_gocean_kern(fortran_reader, fortran_writer, monkeypatch): - ''' + """ Test the apply method with a typical GOcean kernel. TODO #924 - currently this xfails because we don't resolve the type of the actual argument. - ''' + """ code = ( "module psy_single_invoke_test\n" " use field_mod, only: r2d_field\n" @@ -258,102 +271,114 @@ def test_apply_gocean_kern(fortran_reader, fortran_writer, monkeypatch): # Set up include_path to import the proper module src_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), - "../../../external/dl_esm_inf/finite_difference/src") - monkeypatch.setattr(Config.get(), '_include_paths', [str(src_dir)]) + "../../../external/dl_esm_inf/finite_difference/src", + ) + monkeypatch.setattr(Config.get(), "_include_paths", [str(src_dir)]) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.apply(psyir.walk(Call)[0]) - if ("actual argument 'cu_fld' corresponding to an array formal " - "argument ('cu') is unknown" in str(err.value)): + if ( + "actual argument 'cu_fld' corresponding to an array formal " + "argument ('cu') is unknown" + in str(err.value) + ): pytest.xfail( "TODO #924 - extend validation to attempt to resolve type of " - "actual argument.") + "actual argument." + ) output = fortran_writer(psyir) - assert (" do j = cu_fld%internal%ystart, cu_fld%internal%ystop, 1\n" - " do i = cu_fld%internal%xstart, cu_fld%internal%xstop, 1\n" - " cu_fld%data(i,j) = 0.5d0 * (pf%data(i,j) + " - "pf%data(i - 1,j)) * u_fld%data(i,j)\n" - " enddo\n" - " enddo\n" in output) + assert ( + " do j = cu_fld%internal%ystart, cu_fld%internal%ystop, 1\n" + " do i = cu_fld%internal%xstart, cu_fld%internal%xstop, 1\n" + " cu_fld%data(i,j) = 0.5d0 * (pf%data(i,j) + " + "pf%data(i - 1,j)) * u_fld%data(i,j)\n" + " enddo\n" + " enddo\n" + in output + ) def test_apply_struct_arg(fortran_reader, fortran_writer, tmpdir): - ''' + """ Check that the apply() method works correctly when the routine argument is a StructureReference containing an ArrayMember which is accessed inside the routine. - ''' + """ code = ( - f"module test_mod\n" + "module test_mod\n" f"{MY_TYPE}" - f"contains\n" - f" subroutine run_it()\n" - f" integer :: i\n" - f" type(my_type) :: var\n" - f" type(my_type) :: var_list(10)\n" - f" type(big_type) :: var2(5)\n" - f" do i=1,5\n" - f" call sub(var, i)\n" - f" call sub(var_list(i), i)\n" - f" call sub(var2(i)%region, i)\n" - f" call sub2(var2)\n" - f" end do\n" - f" end subroutine run_it\n" - f" subroutine sub(x, ivar)\n" - f" type(my_type), intent(inout) :: x\n" - f" integer, intent(in) :: ivar\n" - f" integer :: i\n" - f" do i = 1, 10\n" - f" x%data(i) = 2.0*ivar\n" - f" end do\n" - f" x%data(:) = -1.0\n" - f" x%data = -5.0\n" - f" x%data(1:2) = 0.0\n" - f" end subroutine sub\n" - f" subroutine sub2(x)\n" - f" type(big_type), dimension(:), intent(inout) :: x\n" - f" x(:)%region%local%nx = 0\n" - f" end subroutine sub2\n" - f"end module test_mod\n") + "contains\n" + " subroutine run_it()\n" + " integer :: i\n" + " type(my_type) :: var\n" + " type(my_type) :: var_list(10)\n" + " type(big_type) :: var2(5)\n" + " do i=1,5\n" + " call sub(var, i)\n" + " call sub(var_list(i), i)\n" + " call sub(var2(i)%region, i)\n" + " call sub2(var2)\n" + " end do\n" + " end subroutine run_it\n" + " subroutine sub(x, ivar)\n" + " type(my_type), intent(inout) :: x\n" + " integer, intent(in) :: ivar\n" + " integer :: i\n" + " do i = 1, 10\n" + " x%data(i) = 2.0*ivar\n" + " end do\n" + " x%data(:) = -1.0\n" + " x%data = -5.0\n" + " x%data(1:2) = 0.0\n" + " end subroutine sub\n" + " subroutine sub2(x)\n" + " type(big_type), dimension(:), intent(inout) :: x\n" + " x(:)%region%local%nx = 0\n" + " end subroutine sub2\n" + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(routine) output = fortran_writer(psyir) - assert (" do i = 1, 5, 1\n" - " do i_1 = 1, 10, 1\n" - " var%data(i_1) = 2.0 * i\n" - " enddo\n" - " var%data(:) = -1.0\n" - " var%data = -5.0\n" - " var%data(1:2) = 0.0\n" - " do i_2 = 1, 10, 1\n" - " var_list(i)%data(i_2) = 2.0 * i\n" - " enddo\n" - " var_list(i)%data(:) = -1.0\n" - " var_list(i)%data = -5.0\n" - " var_list(i)%data(1:2) = 0.0\n" - " do i_3 = 1, 10, 1\n" - " var2(i)%region%data(i_3) = 2.0 * i\n" - " enddo\n" - " var2(i)%region%data(:) = -1.0\n" - " var2(i)%region%data = -5.0\n" - " var2(i)%region%data(1:2) = 0.0\n" - " var2(1:5)%region%local%nx = 0\n" - " enddo\n" in output) + assert ( + " do i = 1, 5, 1\n" + " do i_1 = 1, 10, 1\n" + " var%data(i_1) = 2.0 * i\n" + " enddo\n" + " var%data(:) = -1.0\n" + " var%data = -5.0\n" + " var%data(1:2) = 0.0\n" + " do i_2 = 1, 10, 1\n" + " var_list(i)%data(i_2) = 2.0 * i\n" + " enddo\n" + " var_list(i)%data(:) = -1.0\n" + " var_list(i)%data = -5.0\n" + " var_list(i)%data(1:2) = 0.0\n" + " do i_3 = 1, 10, 1\n" + " var2(i)%region%data(i_3) = 2.0 * i\n" + " enddo\n" + " var2(i)%region%data(:) = -1.0\n" + " var2(i)%region%data = -5.0\n" + " var2(i)%region%data(1:2) = 0.0\n" + " var2(1:5)%region%local%nx = 0\n" + " enddo\n" + in output + ) assert Compile(tmpdir).string_compiles(output) def test_apply_unresolved_struct_arg(fortran_reader, fortran_writer): - ''' + """ Check that we handle acceptable cases of the type of an argument being unresolved but that we reject the case where we can't be sure of the array indexing. - ''' + """ code = ( "module test_mod\n" "contains\n" @@ -388,7 +413,8 @@ def test_apply_unresolved_struct_arg(fortran_reader, fortran_writer): " type(mystery_type), dimension(3:5), intent(inout) :: x\n" " x(:)%region%local%nx = 0\n" " end subroutine sub4\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() calls = psyir.walk(Call) @@ -397,60 +423,70 @@ def test_apply_unresolved_struct_arg(fortran_reader, fortran_writer): # Second one should fail. with pytest.raises(TransformationError) as err: inline_trans.apply(calls[1]) - assert ("Routine 'sub3' cannot be inlined because the type of the actual " - "argument 'mystery' corresponding to an array formal argument " - "('x') is unknown" in str(err.value)) + assert ( + "Routine 'sub3' cannot be inlined because the type of the actual " + "argument 'mystery' corresponding to an array formal argument " + "('x') is unknown" + in str(err.value) + ) # Third one should be fine because it is a scalar argument. inline_trans.apply(calls[2]) # We can't do the fourth one. with pytest.raises(TransformationError) as err: inline_trans.apply(calls[3]) - assert ("Routine 'sub4' cannot be inlined because the type of the actual " - "argument 'mystery' corresponding to an array formal argument " - "('x') is unknown." in str(err.value)) + assert ( + "Routine 'sub4' cannot be inlined because the type of the actual " + "argument 'mystery' corresponding to an array formal argument " + "('x') is unknown." + in str(err.value) + ) output = fortran_writer(psyir) - assert (" varr(1:5)%region%local%nx = 0\n" - " call sub3(mystery)\n" - " mystery%flag = 1\n" - " call sub4(mystery)\n" in output) + assert ( + " varr(1:5)%region%local%nx = 0\n" + " call sub3(mystery)\n" + " mystery%flag = 1\n" + " call sub4(mystery)\n" + in output + ) def test_apply_struct_slice_arg(fortran_reader, fortran_writer, tmpdir): - ''' + """ Check that the apply() method works correctly when there are slices in structure accesses in both the actual and formal arguments. - ''' + """ code = ( - f"module test_mod\n" + "module test_mod\n" f"{MY_TYPE}" - f"contains\n" - f" subroutine run_it()\n" - f" integer :: i\n" - f" type(my_type) :: var_list(10)\n" - f" type(vbig_type), dimension(5) :: cvar\n" - f" call sub(var_list(:)%local%nx, i)\n" - f" call sub2(var_list(:), 1, 1)\n" - f" call sub2(var_list(:), i, i+2)\n" - f" call sub3(cvar)\n" - f" end subroutine run_it\n" - f" subroutine sub(ix, indx)\n" - f" integer, dimension(:) :: ix\n" - f" integer, intent(in) :: indx\n" - f" ix(:) = ix(:) + 1\n" - f" end subroutine sub\n" - f" subroutine sub2(x, start, stop)\n" - f" type(my_type), dimension(:) :: x\n" - f" integer :: start, stop\n" - f" x(:)%data(2) = 0.0\n" - f" x(:)%local%nx = 4\n" - f" x(start:stop+1)%local%nx = -2\n" - f" end subroutine sub2\n" - f" subroutine sub3(y)\n" - f" type(vbig_type), dimension(:) :: y\n" - f" y(2)%grids(2)%region%data(:) = 0.0\n" - f" end subroutine sub3\n" - f"end module test_mod\n") + "contains\n" + " subroutine run_it()\n" + " integer :: i\n" + " type(my_type) :: var_list(10)\n" + " type(vbig_type), dimension(5) :: cvar\n" + " call sub(var_list(:)%local%nx, i)\n" + " call sub2(var_list(:), 1, 1)\n" + " call sub2(var_list(:), i, i+2)\n" + " call sub3(cvar)\n" + " end subroutine run_it\n" + " subroutine sub(ix, indx)\n" + " integer, dimension(:) :: ix\n" + " integer, intent(in) :: indx\n" + " ix(:) = ix(:) + 1\n" + " end subroutine sub\n" + " subroutine sub2(x, start, stop)\n" + " type(my_type), dimension(:) :: x\n" + " integer :: start, stop\n" + " x(:)%data(2) = 0.0\n" + " x(:)%local%nx = 4\n" + " x(start:stop+1)%local%nx = -2\n" + " end subroutine sub2\n" + " subroutine sub3(y)\n" + " type(vbig_type), dimension(:) :: y\n" + " y(2)%grids(2)%region%data(:) = 0.0\n" + " end subroutine sub3\n" + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): @@ -464,30 +500,32 @@ def test_apply_struct_slice_arg(fortran_reader, fortran_writer, tmpdir): assert Compile(tmpdir).string_compiles(output) -def test_apply_struct_local_limits_caller(fortran_reader, fortran_writer, - tmpdir): - ''' +def test_apply_struct_local_limits_caller( + fortran_reader, fortran_writer, tmpdir +): + """ Test the apply() method when there are array bounds specified in the caller. - ''' + """ code = ( - f"module test_mod\n" + "module test_mod\n" f"{MY_TYPE}" - f"contains\n" - f" subroutine run_it()\n" - f" integer :: i\n" - f" type(my_type) :: var_list(10)\n" - f" call sub2(var_list(3:7), 5, 6)\n" - f" end subroutine run_it\n" - f" subroutine sub2(x, start, stop)\n" - f" type(my_type), dimension(:) :: x\n" - f" integer :: start, stop\n" - f" x(:)%data(2) = 1.0\n" - f" x(:)%local%nx = 3\n" - f" x(start:stop+1)%local%nx = -2\n" - f" end subroutine sub2\n" - f"end module test_mod\n") + "contains\n" + " subroutine run_it()\n" + " integer :: i\n" + " type(my_type) :: var_list(10)\n" + " call sub2(var_list(3:7), 5, 6)\n" + " end subroutine run_it\n" + " subroutine sub2(x, start, stop)\n" + " type(my_type), dimension(:) :: x\n" + " integer :: start, stop\n" + " x(:)%data(2) = 1.0\n" + " x(:)%local%nx = 3\n" + " x(start:stop+1)%local%nx = -2\n" + " end subroutine sub2\n" + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): @@ -500,38 +538,40 @@ def test_apply_struct_local_limits_caller(fortran_reader, fortran_writer, assert Compile(tmpdir).string_compiles(output) -def test_apply_struct_local_limits_caller_decln(fortran_reader, fortran_writer, - tmpdir): - ''' +def test_apply_struct_local_limits_caller_decln( + fortran_reader, fortran_writer, tmpdir +): + """ Test the apply() method when there are non-default array bounds specified in the declaration at the call site. - ''' + """ code = ( - f"module test_mod\n" + "module test_mod\n" f"{MY_TYPE}" - f"contains\n" - f" subroutine run_it()\n" - f" integer :: i\n" - f" type(my_type), dimension(2:9) :: varat2\n" - f" real, dimension(4:8) :: varat3\n" - f" call sub2(varat2(:), 5, 6)\n" - f" call sub2(varat2(3:8), 5, 6)\n" - f" call sub3(varat3(5:6))\n" - f" call sub3(varat3)\n" - f" end subroutine run_it\n" - f" subroutine sub2(x, start, stop)\n" - f" type(my_type), dimension(:) :: x\n" - f" integer :: start, stop\n" - f" x(:)%data(2) = 1.0\n" - f" x(:)%local%nx = 3\n" - f" x(start:stop+1)%local%nx = -2\n" - f" end subroutine sub2\n" - f" subroutine sub3(x)\n" - f" real, dimension(:) :: x\n" - f" x(1:2) = 4.0\n" - f" end subroutine sub3\n" - f"end module test_mod\n") + "contains\n" + " subroutine run_it()\n" + " integer :: i\n" + " type(my_type), dimension(2:9) :: varat2\n" + " real, dimension(4:8) :: varat3\n" + " call sub2(varat2(:), 5, 6)\n" + " call sub2(varat2(3:8), 5, 6)\n" + " call sub3(varat3(5:6))\n" + " call sub3(varat3)\n" + " end subroutine run_it\n" + " subroutine sub2(x, start, stop)\n" + " type(my_type), dimension(:) :: x\n" + " integer :: start, stop\n" + " x(:)%data(2) = 1.0\n" + " x(:)%local%nx = 3\n" + " x(start:stop+1)%local%nx = -2\n" + " end subroutine sub2\n" + " subroutine sub3(x)\n" + " real, dimension(:) :: x\n" + " x(1:2) = 4.0\n" + " end subroutine sub3\n" + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): @@ -553,13 +593,14 @@ def test_apply_struct_local_limits_caller_decln(fortran_reader, fortran_writer, assert Compile(tmpdir).string_compiles(output) -def test_apply_struct_local_limits_routine(fortran_reader, fortran_writer, - tmpdir): - ''' +def test_apply_struct_local_limits_routine( + fortran_reader, fortran_writer, tmpdir +): + """ Test the apply() method when there are non-default array bounds specified in the declaration within the called routine. - ''' + """ code = ( f"module test_mod\n" f"{MY_TYPE}" @@ -584,7 +625,8 @@ def test_apply_struct_local_limits_routine(fortran_reader, fortran_writer, f" y(start:stop+1)%local%nx = -3\n" f" z(start+1) = 8.0\n" f" end subroutine sub3\n" - f"end module test_mod\n") + f"end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): @@ -616,13 +658,13 @@ def test_apply_struct_local_limits_routine(fortran_reader, fortran_writer, def test_apply_array_limits_are_formal_args(fortran_reader, fortran_writer): - ''' + """ Check that apply() correctly handles the case where the start/stop values of an array formal argument are given in terms of other formal arguments. - ''' - code = ''' + """ + code = """ module test_mod implicit none contains @@ -638,7 +680,7 @@ def test_apply_array_limits_are_formal_args(fortran_reader, fortran_writer): var(start+1) = 5.0 end subroutine end module test_mod -''' +""" psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() acall = psyir.walk(Call, stop_type=Call)[0] @@ -648,12 +690,12 @@ def test_apply_array_limits_are_formal_args(fortran_reader, fortran_writer): def test_apply_allocatable_array_arg(fortran_reader, fortran_writer): - ''' + """ Check that apply() works correctly when a formal argument is given the ALLOCATABLE attribute (meaning that the bounds of the formal argument are those of the actual argument). - ''' + """ code = ( "module test_mod\n" " type my_type\n" @@ -684,7 +726,7 @@ def test_apply_allocatable_array_arg(fortran_reader, fortran_writer): " x(ji+2,jj+1) = -1.0\n" " end subroutine sub1\n" "end module test_mod\n" - ) + ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): @@ -702,11 +744,11 @@ def test_apply_allocatable_array_arg(fortran_reader, fortran_writer): def test_apply_array_slice_arg(fortran_reader, fortran_writer, tmpdir): - ''' + """ Check that the apply() method works correctly when an array slice is passed to a routine and then accessed within it. - ''' + """ code = ( "module test_mod\n" "contains\n" @@ -744,56 +786,61 @@ def test_apply_array_slice_arg(fortran_reader, fortran_writer, tmpdir): " x(i,:) = 2.0 * x(i,:)\n" " end do\n" " end subroutine sub2a\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() for call in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(call) output = fortran_writer(psyir) - assert (" do i = 1, 10, 1\n" - " do i_1 = 1, 10, 1\n" - " a(1,i_1,i) = 2.0 * i_1\n" - " enddo\n" - " enddo\n" - " a(1,1,:) = 3.0 * a(1,1,:)\n" - " a(:,1,:) = 2.0 * a(:,1,:)\n" - " b(:,:) = 2.0 * b(:,:)\n" - " do i_4 = 1, 10, 1\n" - " b(i_4,:5) = 2.0 * b(i_4,:5)\n" in output) + assert ( + " do i = 1, 10, 1\n" + " do i_1 = 1, 10, 1\n" + " a(1,i_1,i) = 2.0 * i_1\n" + " enddo\n" + " enddo\n" + " a(1,1,:) = 3.0 * a(1,1,:)\n" + " a(:,1,:) = 2.0 * a(:,1,:)\n" + " b(:,:) = 2.0 * b(:,:)\n" + " do i_4 = 1, 10, 1\n" + " b(i_4,:5) = 2.0 * b(i_4,:5)\n" + in output + ) assert Compile(tmpdir).string_compiles(output) def test_apply_struct_array_arg(fortran_reader, fortran_writer, tmpdir): - '''Check that apply works correctly when the actual argument is an - array element within a structure.''' + """Check that apply works correctly when the actual argument is an + array element within a structure.""" code = ( - f"module test_mod\n" + "module test_mod\n" f"{MY_TYPE}" - f"contains\n" - f" subroutine run_it()\n" - f" integer :: i, ig\n" - f" real :: a(10)\n" - f" type(my_type) :: grid\n" - f" type(my_type), dimension(5) :: grid_list\n" - f" grid%data(:) = 1.0\n" - f" do i=1,10\n" - f" a(i) = 1.0\n" - f" call sub(grid%data(i))\n" - f" end do\n" - f" do i=1,10\n" - f" ig = min(i, 5)\n" - f" call sub(grid_list(ig)%data(i))\n" - f" end do\n" - f" do i=1,10\n" - f" ig = min(i, 5)\n" - f" call sub(grid_list(ig)%local%data(i))\n" - f" end do\n" - f" end subroutine run_it\n" - f" subroutine sub(x)\n" - f" real, intent(inout) :: x\n" - f" x = 2.0*x\n" - f" end subroutine sub\n" - f"end module test_mod\n") + "contains\n" + " subroutine run_it()\n" + " integer :: i, ig\n" + " real :: a(10)\n" + " type(my_type) :: grid\n" + " type(my_type), dimension(5) :: grid_list\n" + " grid%data(:) = 1.0\n" + " do i=1,10\n" + " a(i) = 1.0\n" + " call sub(grid%data(i))\n" + " end do\n" + " do i=1,10\n" + " ig = min(i, 5)\n" + " call sub(grid_list(ig)%data(i))\n" + " end do\n" + " do i=1,10\n" + " ig = min(i, 5)\n" + " call sub(grid_list(ig)%local%data(i))\n" + " end do\n" + " end subroutine run_it\n" + " subroutine sub(x)\n" + " real, intent(inout) :: x\n" + " x = 2.0*x\n" + " end subroutine sub\n" + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) loops = psyir.walk(Loop) inline_trans = InlineTrans() @@ -801,55 +848,65 @@ def test_apply_struct_array_arg(fortran_reader, fortran_writer, tmpdir): inline_trans.apply(loops[1].loop_body.children[1]) inline_trans.apply(loops[2].loop_body.children[1]) output = fortran_writer(psyir).lower() - assert (" do i = 1, 10, 1\n" - " a(i) = 1.0\n" - " grid%data(i) = 2.0 * grid%data(i)\n" - " enddo\n" in output) - assert (" do i = 1, 10, 1\n" - " ig = min(i, 5)\n" - " grid_list(ig)%data(i) = 2.0 * grid_list(ig)%data(i)\n" - " enddo\n" in output) - assert (" do i = 1, 10, 1\n" - " ig = min(i, 5)\n" - " grid_list(ig)%local%data(i) = 2.0 * " - "grid_list(ig)%local%data(i)\n" - " enddo\n" in output) + assert ( + " do i = 1, 10, 1\n" + " a(i) = 1.0\n" + " grid%data(i) = 2.0 * grid%data(i)\n" + " enddo\n" + in output + ) + assert ( + " do i = 1, 10, 1\n" + " ig = min(i, 5)\n" + " grid_list(ig)%data(i) = 2.0 * grid_list(ig)%data(i)\n" + " enddo\n" + in output + ) + assert ( + " do i = 1, 10, 1\n" + " ig = min(i, 5)\n" + " grid_list(ig)%local%data(i) = 2.0 * " + "grid_list(ig)%local%data(i)\n" + " enddo\n" + in output + ) assert Compile(tmpdir).string_compiles(output) def test_apply_struct_array_slice_arg(fortran_reader, fortran_writer, tmpdir): - '''Check that apply works correctly when the actual argument is an - array slice within a structure.''' + """Check that apply works correctly when the actual argument is an + array slice within a structure.""" code = ( - f"module test_mod\n" + "module test_mod\n" f"{MY_TYPE}" - f"contains\n" - f" subroutine run_it()\n" - f" integer :: i\n" - f" real :: a(10)\n" - f" type(my_type) :: grid\n" - f" type(vbig_type) :: micah\n" - f" grid%data(:) = 1.0\n" - f" grid%data2d(:,:) = 1.0\n" - f" do i=1,10\n" - f" a(i) = 1.0\n" - f" call sub(micah%grids(3)%region%data(:))\n" - f" call sub(grid%data2d(:,i))\n" - f" call sub(grid%data2d(1:5,i))\n" - f" call sub(grid%local%data)\n" - f" end do\n" - f" end subroutine run_it\n" - f" subroutine sub(x)\n" - f" real, dimension(:), intent(inout) :: x\n" - f" integer ji\n" - f" do ji = 1, 5\n" - f" x(ji) = 2.0*x(ji)\n" - f" end do\n" - f" x(1:2) = 0.0\n" - f" x(:) = 3.0\n" - f" x = 5.0\n" - f" end subroutine sub\n" - f"end module test_mod\n") + "contains\n" + " subroutine run_it()\n" + " integer :: i\n" + " real :: a(10)\n" + " type(my_type) :: grid\n" + " type(vbig_type) :: micah\n" + " grid%data(:) = 1.0\n" + " grid%data2d(:,:) = 1.0\n" + " do i=1,10\n" + " a(i) = 1.0\n" + " call sub(micah%grids(3)%region%data(:))\n" + " call sub(grid%data2d(:,i))\n" + " call sub(grid%data2d(1:5,i))\n" + " call sub(grid%local%data)\n" + " end do\n" + " end subroutine run_it\n" + " subroutine sub(x)\n" + " real, dimension(:), intent(inout) :: x\n" + " integer ji\n" + " do ji = 1, 5\n" + " x(ji) = 2.0*x(ji)\n" + " end do\n" + " x(1:2) = 0.0\n" + " x(:) = 3.0\n" + " x = 5.0\n" + " end subroutine sub\n" + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() for call in psyir.walk(Call): @@ -860,91 +917,101 @@ def test_apply_struct_array_slice_arg(fortran_reader, fortran_writer, tmpdir): continue inline_trans.apply(call) output = fortran_writer(psyir) - assert (" do i = 1, 10, 1\n" - " a(i) = 1.0\n" - " do ji = 1, 5, 1\n" - " micah%grids(3)%region%data(ji) = 2.0 * " - "micah%grids(3)%region%data(ji)\n" - " enddo\n" - " micah%grids(3)%region%data(1:2) = 0.0\n" - " micah%grids(3)%region%data(:) = 3.0\n" - " micah%grids(3)%region%data(:) = 5.0\n" - " do ji_1 = 1, 5, 1\n" - " grid%data2d(ji_1,i) = 2.0 * grid%data2d(ji_1,i)\n" - " enddo\n" - " grid%data2d(1:2,i) = 0.0\n" - " grid%data2d(:,i) = 3.0\n" - " grid%data2d(:,i) = 5.0\n" - " do ji_2 = 1, 5, 1\n" - " grid%data2d(ji_2,i) = 2.0 * grid%data2d(ji_2,i)\n" - " enddo\n" - " grid%data2d(1:2,i) = 0.0\n" - " grid%data2d(1:5,i) = 3.0\n" - " grid%data2d(1:5,i) = 5.0\n" - # TODO #1858: replace the following line with the commented-out - # lines below. - " call sub(grid%local%data)\n" - # " do ji_3 = 1, 5, 1\n" - # " grid%local%data(ji_3) = 2.0 * grid%local%data(ji_3)\n" - # " enddo\n" - # " grid%local%data(1:2) = 0.0\n" - # " grid%local%data(:) = 3.0\n" - # " grid%local%data = 5.0\n" - " enddo\n" in output) + assert ( + " do i = 1, 10, 1\n" + " a(i) = 1.0\n" + " do ji = 1, 5, 1\n" + " micah%grids(3)%region%data(ji) = 2.0 * " + "micah%grids(3)%region%data(ji)\n" + " enddo\n" + " micah%grids(3)%region%data(1:2) = 0.0\n" + " micah%grids(3)%region%data(:) = 3.0\n" + " micah%grids(3)%region%data(:) = 5.0\n" + " do ji_1 = 1, 5, 1\n" + " grid%data2d(ji_1,i) = 2.0 * grid%data2d(ji_1,i)\n" + " enddo\n" + " grid%data2d(1:2,i) = 0.0\n" + " grid%data2d(:,i) = 3.0\n" + " grid%data2d(:,i) = 5.0\n" + " do ji_2 = 1, 5, 1\n" + " grid%data2d(ji_2,i) = 2.0 * grid%data2d(ji_2,i)\n" + " enddo\n" + " grid%data2d(1:2,i) = 0.0\n" + " grid%data2d(1:5,i) = 3.0\n" + " grid%data2d(1:5,i) = 5.0\n" + # TODO #1858: replace the following line with the commented-out + # lines below. + " call sub(grid%local%data)\n" + # " do ji_3 = 1, 5, 1\n" + # " grid%local%data(ji_3) = 2.0 * grid%local%data(ji_3)\n" + # " enddo\n" + # " grid%local%data(1:2) = 0.0\n" + # " grid%local%data(:) = 3.0\n" + # " grid%local%data = 5.0\n" + " enddo\n" in output + ) assert Compile(tmpdir).string_compiles(output) @pytest.mark.parametrize("type_decln", [MY_TYPE, " use some_mod\n"]) -def test_apply_struct_array(fortran_reader, fortran_writer, tmpdir, - type_decln): - '''Test that apply works correctly when the formal argument is an +def test_apply_struct_array( + fortran_reader, fortran_writer, tmpdir, type_decln +): + """Test that apply works correctly when the formal argument is an array of structures. We test both when the type of the structure is resolved and when it isn't. In the latter case we cannot perform inlining because we don't know the array bounds at the call site. - ''' + """ code = ( - f"module test_mod\n" + "module test_mod\n" f"{type_decln}" - f"contains\n" - f" subroutine run_it()\n" - f" integer :: i\n" - f" real :: a(10)\n" - f" type(my_type) :: grid\n" - f" type(vbig_type) :: micah\n" - f" call sub(micah%grids(:))\n" - f" end subroutine run_it\n" - f" subroutine sub(x)\n" - f" type(big_type), dimension(2:4) :: x\n" - f" integer ji\n" - f" ji = 2\n" - f" x(:)%region%idx = 3.0\n" - f" x(ji)%region%idx = 2.0\n" - f" end subroutine sub\n" - f"end module test_mod\n") + "contains\n" + " subroutine run_it()\n" + " integer :: i\n" + " real :: a(10)\n" + " type(my_type) :: grid\n" + " type(vbig_type) :: micah\n" + " call sub(micah%grids(:))\n" + " end subroutine run_it\n" + " subroutine sub(x)\n" + " type(big_type), dimension(2:4) :: x\n" + " integer ji\n" + " ji = 2\n" + " x(:)%region%idx = 3.0\n" + " x(ji)%region%idx = 2.0\n" + " end subroutine sub\n" + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() if "use some_mod" in type_decln: with pytest.raises(TransformationError) as err: inline_trans.apply(psyir.walk(Call)[0]) - assert ("Routine 'sub' cannot be inlined because the type of the " - "actual argument 'micah' corresponding to an array formal " - "argument ('x') is unknown." in str(err.value)) + assert ( + "Routine 'sub' cannot be inlined because the type of the " + "actual argument 'micah' corresponding to an array formal " + "argument ('x') is unknown." + in str(err.value) + ) else: inline_trans.apply(psyir.walk(Call)[0]) output = fortran_writer(psyir) - assert (" ji = 2\n" - " micah%grids(2 - 2 + 1:4 - 2 + 1)%region%idx = 3.0\n" - " micah%grids(ji - 2 + 1)%region%idx = 2.0\n" in output) + assert ( + " ji = 2\n" + " micah%grids(2 - 2 + 1:4 - 2 + 1)%region%idx = 3.0\n" + " micah%grids(ji - 2 + 1)%region%idx = 2.0\n" + in output + ) assert Compile(tmpdir).string_compiles(output) def test_apply_repeated_module_use(fortran_reader, fortran_writer): - ''' + """ Check that any module use statements are not duplicated when multiple calls are inlined. - ''' + """ code = ( "module test_mod\n" "contains\n" @@ -966,7 +1033,8 @@ def test_apply_repeated_module_use(fortran_reader, fortran_writer): " real, intent(inout), dimension(10) :: x\n" " x(:) = 4*radius\n" " end subroutine sub2\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() for call in psyir.walk(Routine)[0].walk(Call, stop_type=Call): @@ -974,18 +1042,24 @@ def test_apply_repeated_module_use(fortran_reader, fortran_writer): output = fortran_writer(psyir) # Check container symbol has not been renamed. assert "use model_mod_1" not in output - assert (" subroutine run_it()\n" - " use model_mod, only : radius\n" - " integer :: i\n" in output) - assert (" do i = 1, 10, 1\n" - " a(:,i) = 4 * radius\n" - " enddo\n" - " b(:,2) = radius\n" in output) + assert ( + " subroutine run_it()\n" + " use model_mod, only : radius\n" + " integer :: i\n" + in output + ) + assert ( + " do i = 1, 10, 1\n" + " a(:,i) = 4 * radius\n" + " enddo\n" + " b(:,2) = radius\n" + in output + ) def test_apply_name_clash(fortran_reader, fortran_writer, tmpdir): - ''' Check that apply() correctly handles the case where a symbol - in the routine to be in-lined clashes with an existing symbol. ''' + """Check that apply() correctly handles the case where a symbol + in the routine to be in-lined clashes with an existing symbol.""" code = ( "module test_mod\n" "contains\n" @@ -1002,22 +1076,23 @@ def test_apply_name_clash(fortran_reader, fortran_writer, tmpdir): " i = 3.0\n" " x = 2.0*x + i\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) - assert (" i = 10\n" - " y = 1.0\n" - " i_1 = 3.0\n" - " y = 2.0 * y + i_1\n" in output) + assert ( + " i = 10\n y = 1.0\n i_1 = 3.0\n y = 2.0 * y + i_1\n" + in output + ) assert Compile(tmpdir).string_compiles(output) def test_apply_imported_symbols(fortran_reader, fortran_writer): - '''Test that the apply method correctly handles imported symbols in the - routine being inlined. ''' + """Test that the apply method correctly handles imported symbols in the + routine being inlined.""" code = ( "module test_mod\n" "contains\n" @@ -1031,23 +1106,27 @@ def test_apply_imported_symbols(fortran_reader, fortran_writer): " integer, intent(inout) :: idx\n" " idx = 3*var2\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) - assert (" subroutine run_it()\n" - " use some_mod, only : var2\n" - " integer :: i\n\n" - " i = 10\n" - " i = 3 * var2\n" in output) + assert ( + " subroutine run_it()\n" + " use some_mod, only : var2\n" + " integer :: i\n\n" + " i = 10\n" + " i = 3 * var2\n" + in output + ) # We can't check this with compilation because of the import of some_mod. def test_apply_last_stmt_is_return(fortran_reader, fortran_writer, tmpdir): - '''Test that the apply method correctly omits any final 'return' - statement that may be present in the routine to be inlined.''' + """Test that the apply method correctly omits any final 'return' + statement that may be present in the routine to be inlined.""" code = ( "module test_mod\n" "contains\n" @@ -1061,21 +1140,20 @@ def test_apply_last_stmt_is_return(fortran_reader, fortran_writer, tmpdir): " idx = idx + 3\n" " return\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) - assert (" i = 10\n" - " i = i + 3\n\n" - " end subroutine run_it\n" in output) + assert " i = 10\n i = i + 3\n\n end subroutine run_it\n" in output assert Compile(tmpdir).string_compiles(output) def test_apply_call_args(fortran_reader, fortran_writer): - '''Check that apply works correctly if any of the actual - arguments are not simple references.''' + """Check that apply works correctly if any of the actual + arguments are not simple references.""" code = ( "module test_mod\n" " use kinds_mod, only: i_def\n" @@ -1091,22 +1169,24 @@ def test_apply_call_args(fortran_reader, fortran_writer): " integer(kind=i_def), intent(in) :: incr2\n" " idx = idx + incr1 * incr2\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(call) output = fortran_writer(psyir) - assert (" i = 10\n" - " i = i + 2 * i * 5_i_def\n\n" - " end subroutine run_it\n" in output) + assert ( + " i = 10\n i = i + 2 * i * 5_i_def\n\n end subroutine run_it\n" + in output + ) # Cannot test for compilation because of 'kinds_mod'. def test_apply_duplicate_imports(fortran_reader, fortran_writer): - '''Check that apply works correctly when the routine to be inlined + """Check that apply works correctly when the routine to be inlined imports symbols from a container that is also accessed in the - calling routine.''' + calling routine.""" code = ( "module test_mod\n" "contains\n" @@ -1121,24 +1201,29 @@ def test_apply_duplicate_imports(fortran_reader, fortran_writer): " integer, intent(inout) :: idx\n" " idx = idx + 5_i_def\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(call) output = fortran_writer(psyir) - assert (" subroutine run_it()\n" - " use kinds_mod, only : i_def\n" - " integer :: i\n\n" in output) - assert (" i = 10_i_def\n" - " i = i + 5_i_def\n\n" - " end subroutine run_it\n" in output) + assert ( + " subroutine run_it()\n" + " use kinds_mod, only : i_def\n" + " integer :: i\n\n" + in output + ) + assert ( + " i = 10_i_def\n i = i + 5_i_def\n\n end subroutine run_it\n" + in output + ) # Cannot test for compilation because of 'kinds_mod'. def test_apply_wildcard_import(fortran_reader, fortran_writer): - '''Check that apply works correctly when a wildcard import is present - in the routine to be inlined.''' + """Check that apply works correctly when a wildcard import is present + in the routine to be inlined.""" code = ( "module test_mod\n" "contains\n" @@ -1153,22 +1238,24 @@ def test_apply_wildcard_import(fortran_reader, fortran_writer): " integer, intent(inout) :: idx\n" " idx = idx + 5\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(call) output = fortran_writer(psyir) - assert (" subroutine run_it()\n" - " use kinds_mod\n" - " integer :: i\n\n" in output) + assert ( + " subroutine run_it()\n use kinds_mod\n integer :: i\n\n" + in output + ) # Cannot test for compilation because of 'kinds_mod'. def test_apply_import_union(fortran_reader, fortran_writer): - '''Test that the apply method works correctly when the set of symbols + """Test that the apply method works correctly when the set of symbols imported from a given container is not the same as that imported into - the scope of the call site.''' + the scope of the call site.""" code = ( "module test_mod\n" "contains\n" @@ -1183,23 +1270,26 @@ def test_apply_import_union(fortran_reader, fortran_writer): " integer, intent(inout) :: idx\n" " idx = idx + 5_i_def\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(call) output = fortran_writer(psyir) - assert (" subroutine run_it()\n" - " use kinds_mod, only : i_def, r_def\n" - " integer :: i\n\n" in output) - assert (" i = 10.0_r_def\n" - " i = i + 5_i_def\n" in output) + assert ( + " subroutine run_it()\n" + " use kinds_mod, only : i_def, r_def\n" + " integer :: i\n\n" + in output + ) + assert " i = 10.0_r_def\n i = i + 5_i_def\n" in output # Cannot test for compilation because of 'kinds_mod'. def test_apply_callsite_rename(fortran_reader, fortran_writer): - '''Check that a symbol import in the routine causes a - rename of a symbol that is local to the *calling* scope.''' + """Check that a symbol import in the routine causes a + rename of a symbol that is local to the *calling* scope.""" code = ( "module test_mod\n" "contains\n" @@ -1217,26 +1307,30 @@ def test_apply_callsite_rename(fortran_reader, fortran_writer): " integer, intent(inout) :: idx\n" " idx = idx + 5_i_def + a_clash\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(call) output = fortran_writer(psyir) - assert (" subroutine run_it()\n" - " use kinds_mod, only : i_def, r_def\n" - " use a_mod, only : a_clash\n" - " integer :: i\n" - " integer :: a_clash_1\n\n" - " a_clash_1 = 2\n" - " i = 10.0_r_def\n" - " i = i + 5_i_def + a_clash\n" - " i = i * a_clash_1\n" in output) + assert ( + " subroutine run_it()\n" + " use kinds_mod, only : i_def, r_def\n" + " use a_mod, only : a_clash\n" + " integer :: i\n" + " integer :: a_clash_1\n\n" + " a_clash_1 = 2\n" + " i = 10.0_r_def\n" + " i = i + 5_i_def + a_clash\n" + " i = i * a_clash_1\n" + in output + ) def test_apply_callsite_rename_container(fortran_reader, fortran_writer): - '''Check that an import from a container in the routine causes a - rename of a symbol that is local to the *calling* scope.''' + """Check that an import from a container in the routine causes a + rename of a symbol that is local to the *calling* scope.""" code = ( "module test_mod\n" "contains\n" @@ -1254,26 +1348,30 @@ def test_apply_callsite_rename_container(fortran_reader, fortran_writer): " integer, intent(inout) :: idx\n" " idx = idx + 5_i_def + a_clash\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(call) output = fortran_writer(psyir) - assert (" subroutine run_it()\n" - " use kinds_mod, only : i_def, r_def\n" - " use a_mod, only : a_clash\n" - " integer :: i\n" - " integer :: a_mod_1\n\n" - " a_mod_1 = 2\n" - " i = 10.0_r_def\n" - " i = i + 5_i_def + a_clash\n" - " i = i * a_mod_1\n" in output) + assert ( + " subroutine run_it()\n" + " use kinds_mod, only : i_def, r_def\n" + " use a_mod, only : a_clash\n" + " integer :: i\n" + " integer :: a_mod_1\n\n" + " a_mod_1 = 2\n" + " i = 10.0_r_def\n" + " i = i + 5_i_def + a_clash\n" + " i = i * a_mod_1\n" + in output + ) def test_validate_non_local_import(fortran_reader): - '''Test that we reject the case where the routine to be - inlined accesses a symbol from an import in its parent container.''' + """Test that we reject the case where the routine to be + inlined accesses a symbol from an import in its parent container.""" code = ( "module test_mod\n" " use some_mod, only: trouble\n" @@ -1287,22 +1385,26 @@ def test_validate_non_local_import(fortran_reader): " integer :: idx\n" " idx = idx + trouble\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("Routine 'sub' cannot be inlined because it accesses variable " - "'trouble' from its parent container." in str(err.value)) + assert ( + "Routine 'sub' cannot be inlined because it accesses variable " + "'trouble' from its parent container." + in str(err.value) + ) def test_apply_shared_routine_call(fortran_reader): - ''' + """ Test the inlining of a routine that itself calls another routine that is also called from within the scope of the call site. - ''' - code = '''\ + """ + code = """\ module my_mod implicit none contains @@ -1315,7 +1417,7 @@ def test_apply_shared_routine_call(fortran_reader): use slartibartfast, only: norway call norway() end subroutine fijord - end module my_mod''' + end module my_mod""" psyir = fortran_reader.psyir_from_source(code) calls = psyir.walk(Call) inline_trans = InlineTrans() @@ -1329,15 +1431,16 @@ def test_apply_shared_routine_call(fortran_reader): nsym = routines[0].symbol_table.lookup("norway") for call in calls: if call.routine is not nsym: - pytest.xfail("#924 cannot reliably update references in inlined " - "code.") + pytest.xfail( + "#924 cannot reliably update references in inlined code." + ) def test_apply_function(fortran_reader, fortran_writer, tmpdir): - '''Check that the apply() method works correctly for a simple call to + """Check that the apply() method works correctly for a simple call to a function. - ''' + """ code = ( "module test_mod\n" "contains\n" @@ -1349,7 +1452,8 @@ def test_apply_function(fortran_reader, fortran_writer, tmpdir): " real :: b\n" " func = 2.0\n" " end function\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() @@ -1361,34 +1465,41 @@ def test_apply_function(fortran_reader, fortran_writer, tmpdir): " real :: b\n" " real :: inlined_func\n\n" " inlined_func = 2.0\n" - " a = inlined_func") + " a = inlined_func" + ) assert expected in output assert Compile(tmpdir).string_compiles(output) # Try two different forms of function declaration. -@pytest.mark.parametrize("function_header", [ - " function func(b) result(x)\n real :: x\n", - " real function func(b) result(x)\n"]) +@pytest.mark.parametrize( + "function_header", + [ + " function func(b) result(x)\n real :: x\n", + " real function func(b) result(x)\n", + ], +) def test_apply_function_declare_name( - fortran_reader, fortran_writer, tmpdir, function_header): - '''Check that the apply() method works correctly for a simple call to + fortran_reader, fortran_writer, tmpdir, function_header +): + """Check that the apply() method works correctly for a simple call to a function where the name of the return name differs from the function name. - ''' + """ code = ( - f"module test_mod\n" - f"contains\n" - f" subroutine run_it()\n" - f" real :: a,b\n" - f" a = func(b)\n" - f" end subroutine run_it\n" + "module test_mod\n" + "contains\n" + " subroutine run_it()\n" + " real :: a,b\n" + " a = func(b)\n" + " end subroutine run_it\n" f"{function_header}" - f" real :: b\n" - f" x = 2.0\n" - f" end function\n" - f"end module test_mod\n") + " real :: b\n" + " x = 2.0\n" + " end function\n" + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() @@ -1401,16 +1512,17 @@ def test_apply_function_declare_name( " real :: b\n" " real :: inlined_x\n\n" " inlined_x = 2.0\n" - " a = inlined_x") + " a = inlined_x" + ) assert expected in output assert Compile(tmpdir).string_compiles(output) def test_apply_function_expression(fortran_reader, fortran_writer, tmpdir): - '''Check that the apply() method works correctly for a call to a + """Check that the apply() method works correctly for a call to a function that is within an expression. - ''' + """ code = ( "module test_mod\n" "contains\n" @@ -1423,7 +1535,8 @@ def test_apply_function_expression(fortran_reader, fortran_writer, tmpdir): " b = b + 3.0\n" " x = b * 2.0\n" " end function\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() @@ -1434,15 +1547,17 @@ def test_apply_function_expression(fortran_reader, fortran_writer, tmpdir): " real :: inlined_x\n\n" " b = b + 3.0\n" " inlined_x = b * 2.0\n" - " a = (a * inlined_x + 2.0) / a\n" in output) + " a = (a * inlined_x + 2.0) / a\n" + in output + ) assert Compile(tmpdir).string_compiles(output) def test_apply_multi_function(fortran_reader, fortran_writer, tmpdir): - '''Check that the apply() method works correctly when a function is + """Check that the apply() method works correctly when a function is called twice but only one of these function calls is inlined. - ''' + """ code = ( "module test_mod\n" "contains\n" @@ -1455,7 +1570,8 @@ def test_apply_multi_function(fortran_reader, fortran_writer, tmpdir): " real :: b\n" " func = 2.0\n" " end function\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() @@ -1469,7 +1585,8 @@ def test_apply_multi_function(fortran_reader, fortran_writer, tmpdir): " real :: inlined_func\n\n" " inlined_func = 2.0\n" " a = inlined_func\n" - " c = func(a)") + " c = func(a)" + ) assert expected in output assert Compile(tmpdir).string_compiles(output) @@ -1483,35 +1600,45 @@ def test_apply_multi_function(fortran_reader, fortran_writer, tmpdir): " inlined_func = 2.0\n" " a = inlined_func\n" " inlined_func_1 = 2.0\n" - " c = inlined_func_1") + " c = inlined_func_1" + ) assert expected in output -@pytest.mark.parametrize("start, end, indent", [ - ("", "", ""), - ("module test_mod\ncontains\n", "end module test_mod\n", " "), - ("module test_mod\nuse formal\ncontains\n", "end module test_mod\n", - " ")]) +@pytest.mark.parametrize( + "start, end, indent", + [ + ("", "", ""), + ("module test_mod\ncontains\n", "end module test_mod\n", " "), + ( + "module test_mod\nuse formal\ncontains\n", + "end module test_mod\n", + " ", + ), + ], +) def test_apply_raw_subroutine( - fortran_reader, fortran_writer, tmpdir, start, end, indent): - '''Test the apply method works correctly when the routine to be + fortran_reader, fortran_writer, tmpdir, start, end, indent +): + """Test the apply method works correctly when the routine to be inlined is a raw subroutine and is called directly from another raw subroutine, a subroutine within a module but without a use statement and a subroutine within a module with a wildcard use statement. - ''' + """ code = ( f"{start}" - f" subroutine run_it()\n" - f" real :: a\n" - f" call sub(a)\n" - f" end subroutine run_it\n" + " subroutine run_it()\n" + " real :: a\n" + " call sub(a)\n" + " end subroutine run_it\n" f"{end}" - f"subroutine sub(x)\n" - f" real, intent(inout) :: x\n" - f" x = 2.0*x\n" - f"end subroutine sub\n") + "subroutine sub(x)\n" + " real, intent(inout) :: x\n" + " x = 2.0*x\n" + "end subroutine sub\n" + ) psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() @@ -1521,7 +1648,8 @@ def test_apply_raw_subroutine( f"{indent}subroutine run_it()\n" f"{indent} real :: a\n\n" f"{indent} a = 2.0 * a\n\n" - f"{indent}end subroutine run_it\n") + f"{indent}end subroutine run_it\n" + ) assert expected in output if "use formal" not in output: # Compilation will not work with "use formal" as there is no @@ -1529,78 +1657,93 @@ def test_apply_raw_subroutine( assert Compile(tmpdir).string_compiles(output) -@pytest.mark.parametrize("use1, use2", [ - ("use inline_mod, only : sub\n", ""), ("use inline_mod\n", ""), - ("", "use inline_mod, only : sub\n"), ("", "use inline_mod\n")]) +@pytest.mark.parametrize( + "use1, use2", + [ + ("use inline_mod, only : sub\n", ""), + ("use inline_mod\n", ""), + ("", "use inline_mod, only : sub\n"), + ("", "use inline_mod\n"), + ], +) def test_apply_container_subroutine( - fortran_reader, fortran_writer, tmpdir, use1, use2): - '''Test the apply method works correctly when the routine to be + fortran_reader, fortran_writer, tmpdir, use1, use2 +): + """Test the apply method works correctly when the routine to be inlined is in a different container and is within a module (so a use statement is required). - ''' + """ code = ( - f"module inline_mod\n" - f"contains\n" - f" subroutine sub(x)\n" - f" real, intent(inout) :: x\n" - f" x = 2.0*x\n" - f" end subroutine sub\n" - f"end module inline_mod\n" - f"module test_mod\n" + "module inline_mod\n" + "contains\n" + " subroutine sub(x)\n" + " real, intent(inout) :: x\n" + " x = 2.0*x\n" + " end subroutine sub\n" + "end module inline_mod\n" + "module test_mod\n" f"{use1}" - f"contains\n" - f" subroutine run_it()\n" + "contains\n" + " subroutine run_it()\n" f" {use2}" - f" real :: a\n" - f" call sub(a)\n" - f" end subroutine run_it\n" - f"end module test_mod\n") + " real :: a\n" + " call sub(a)\n" + " end subroutine run_it\n" + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) assert ( - " real :: a\n\n" - " a = 2.0 * a\n\n" - " end subroutine run_it" in output) + " real :: a\n\n a = 2.0 * a\n\n end subroutine run_it" in output + ) assert Compile(tmpdir).string_compiles(output) def test_apply_validate(): - '''Test the apply method calls the validate method.''' + """Test the apply method calls the validate method.""" inline_trans = InlineTrans() with pytest.raises(TransformationError) as info: inline_trans.apply(None) - assert ("The target of the InlineTrans transformation should be " - "a Call but found 'NoneType'." in str(info.value)) + assert ( + "The target of the InlineTrans transformation should be " + "a Call but found 'NoneType'." + in str(info.value) + ) # validate + def test_validate_node(): - ''' Test the expected exception is raised if an invalid node is - supplied to the transformation. ''' + """Test the expected exception is raised if an invalid node is + supplied to the transformation.""" inline_trans = InlineTrans() with pytest.raises(TransformationError) as info: inline_trans.validate(None) - assert ("The target of the InlineTrans transformation should be " - "a Call but found 'NoneType'." in str(info.value)) - call = IntrinsicCall.create(IntrinsicCall.Intrinsic.ALLOCATE, - [Reference(DataSymbol("array", - UnresolvedType()))]) + assert ( + "The target of the InlineTrans transformation should be " + "a Call but found 'NoneType'." + in str(info.value) + ) + call = IntrinsicCall.create( + IntrinsicCall.Intrinsic.ALLOCATE, + [Reference(DataSymbol("array", UnresolvedType()))], + ) with pytest.raises(TransformationError) as info: inline_trans.validate(call) assert "Cannot inline an IntrinsicCall ('ALLOCATE')" in str(info.value) def test_validate_calls_find_routine(fortran_reader): - '''Test that validate() calls the _find_routine method. Use an example + """Test that validate() calls the _find_routine method. Use an example where an exception is raised as the source of the routine to be inlined cannot be found. - ''' + """ code = ( "module test_mod\n" " use some_mod\n" @@ -1610,23 +1753,28 @@ def test_validate_calls_find_routine(fortran_reader): " i = 10\n" " call sub(i)\n" " end subroutine run_it\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("Cannot inline routine 'sub' because its source cannot be found: " - "Failed to find the source code of the unresolved routine 'sub' - " - "looked at any routines in the same source file and attempted to " - "resolve the wildcard imports from ['some_mod']. However, failed " - "to find the source for ['some_mod']" in str(err.value)) + print(err.value) + assert ( + "Cannot inline routine 'sub' because its source cannot be found:\n" + "Failed to find the source code of the unresolved routine 'sub' - " + "looked at any routines in the same source file and attempted to " + "resolve the wildcard imports from ['some_mod']. However, failed " + "to find the source for ['some_mod']" + in str(err.value) + ) def test_validate_return_stmt(fortran_reader): - '''Test that validate() raises the expected error if the target routine + """Test that validate() raises the expected error if the target routine contains one or more Returns which that aren't either the very first - statement or very last statement.''' + statement or very last statement.""" code = ( "module test_mod\n" "contains\n" @@ -1641,20 +1789,24 @@ def test_validate_return_stmt(fortran_reader): " return\n" " idx = idx + 3\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("Routine 'sub' contains one or more Return statements and " - "therefore cannot be inlined" in str(err.value)) + assert ( + "Routine 'sub' contains one or more Return statements and " + "therefore cannot be inlined" + in str(err.value) + ) def test_validate_codeblock(fortran_reader): - '''Test that validate() raises the expected error for a routine that + """Test that validate() raises the expected error for a routine that contains a CodeBlock. Also test that using the "force" option overrides - this check.''' + this check.""" code = ( "module test_mod\n" "contains\n" @@ -1667,22 +1819,27 @@ def test_validate_codeblock(fortran_reader): " integer :: idx\n" " write(*,*) idx\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("Routine 'sub' contains one or more CodeBlocks and therefore " - "cannot be inlined. (If you are confident " in str(err.value)) - inline_trans.validate(call, options={"force": True}) + assert ( + "Routine 'sub' contains one or more CodeBlocks and therefore " + "cannot be inlined. (If you are confident " + in str(err.value) + ) + inline_trans.set_option(check_codeblocks=False) + inline_trans.validate(call) def test_validate_unsupportedtype_argument(fortran_reader): - ''' + """ Test that validate rejects a subroutine with arguments of UnsupportedType. - ''' + """ code = ( "module test_mod\n" "contains\n" @@ -1703,17 +1860,20 @@ def test_validate_unsupportedtype_argument(fortran_reader): inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(routine) - assert ("Routine 'sub' cannot be inlined because it contains a Symbol 'x' " - "which is an Argument of UnsupportedType: 'REAL, POINTER, " - "INTENT(INOUT) :: x'" in str(err.value)) + assert ( + "Routine 'sub' cannot be inlined because it contains a Symbol 'x' " + "which is an Argument of UnsupportedType: 'REAL, POINTER, " + "INTENT(INOUT) :: x'" + in str(err.value) + ) def test_validate_unknowninterface(fortran_reader, fortran_writer, tmpdir): - ''' + """ Test that validate rejects a subroutine containing variables with UnknownInterface. - ''' + """ code = ( "module test_mod\n" "contains\n" @@ -1731,15 +1891,19 @@ def test_validate_unknowninterface(fortran_reader, fortran_writer, tmpdir): inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(routine) - assert (" Routine 'sub' cannot be inlined because it contains a Symbol " - "'x' with an UnknownInterface: 'REAL, POINTER :: x'" - in str(err.value)) + assert ( + " Routine 'sub' cannot be inlined because it contains a Symbol " + "'x' with an UnknownInterface: 'REAL, POINTER :: x'" + in str(err.value) + ) # But if the interface is known, it has no problem inlining it xvar = psyir.walk(Routine)[1].symbol_table.lookup("x") xvar.interface = AutomaticInterface() inline_trans.apply(routine) - assert fortran_writer(psyir.walk(Routine)[0]) == """\ + assert ( + fortran_writer(psyir.walk(Routine)[0]) + == """\ subroutine main() REAL, POINTER :: x @@ -1747,14 +1911,15 @@ def test_validate_unknowninterface(fortran_reader, fortran_writer, tmpdir): end subroutine main """ + ) assert Compile(tmpdir).string_compiles(fortran_writer(psyir)) def test_validate_static_var(fortran_reader): - ''' + """ Test that validate rejects a subroutine with StaticInterface variables. - ''' + """ code = ( "module test_mod\n" "contains\n" @@ -1768,37 +1933,43 @@ def test_validate_static_var(fortran_reader): " state = state + x\n" " x = 2.0*x + state\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(routine) - assert ("Routine 'sub' cannot be inlined because it has a static (Fortran " - "SAVE) interface for Symbol 'state'." in str(err.value)) + assert ( + "Routine 'sub' cannot be inlined because it has a static (Fortran " + "SAVE) interface for Symbol 'state'." + in str(err.value) + ) -@pytest.mark.parametrize("code_body", ["idx = idx + 5_i_def", - "real, parameter :: pi = 3_wp\n" - "idx = idx + 1\n"]) +@pytest.mark.parametrize( + "code_body", + ["idx = idx + 5_i_def", "real, parameter :: pi = 3_wp\nidx = idx + 1\n"], +) def test_validate_unresolved_precision_sym(fortran_reader, code_body): - '''Test that a routine that uses an unresolved precision symbol is + """Test that a routine that uses an unresolved precision symbol is rejected. We test when the precision symbol appears in an executable - statement and when it appears in a constant initialisation.''' + statement and when it appears in a constant initialisation.""" code = ( - f"module test_mod\n" - f" use kinds_mod\n" - f"contains\n" - f" subroutine run_it()\n" - f" integer :: i\n" - f" i = 10_i_def\n" - f" call sub(i)\n" - f" end subroutine run_it\n" - f" subroutine sub(idx)\n" - f" integer, intent(inout) :: idx\n" + "module test_mod\n" + " use kinds_mod\n" + "contains\n" + " subroutine run_it()\n" + " integer :: i\n" + " i = 10_i_def\n" + " call sub(i)\n" + " end subroutine run_it\n" + " subroutine sub(idx)\n" + " integer, intent(inout) :: idx\n" f" {code_body}\n" - f" end subroutine sub\n" - f"end module test_mod\n") + " end subroutine sub\n" + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() call = psyir.walk(Call)[0] @@ -1808,15 +1979,17 @@ def test_validate_unresolved_precision_sym(fortran_reader, code_body): var_name = "wp" else: var_name = "i_def" - assert (f"Routine 'sub' cannot be inlined because it accesses variable " - f"'{var_name}' and this cannot be found in any of the containers " - f"directly imported into its symbol table" in str(err.value)) + assert ( + "Routine 'sub' cannot be inlined because it accesses variable " + f"'{var_name}' and this cannot be found in any of the containers " + "directly imported into its symbol table" + in str(err.value) + ) -def test_validate_resolved_precision_sym(fortran_reader, monkeypatch, - tmpdir): - '''Test that a routine that uses a resolved precision symbol from its - parent Container is rejected.''' +def test_validate_resolved_precision_sym(fortran_reader, monkeypatch, tmpdir): + """Test that a routine that uses a resolved precision symbol from its + parent Container is rejected.""" code = ( "module test_mod\n" " use kinds_mod\n" @@ -1836,34 +2009,40 @@ def test_validate_resolved_precision_sym(fortran_reader, monkeypatch, " integer, intent(inout) :: idx\n" " idx = idx + 5_i_def\n" " end subroutine sub2\n" - "end module test_mod\n") + "end module test_mod\n" + ) # Set up include_path to import the proper module - monkeypatch.setattr(Config.get(), '_include_paths', [str(tmpdir)]) + monkeypatch.setattr(Config.get(), "_include_paths", [str(tmpdir)]) filename = os.path.join(str(tmpdir), "kinds_mod.f90") - with open(filename, "w", encoding='UTF-8') as module: - module.write(''' + with open(filename, "w", encoding="UTF-8") as module: + module.write( + """ module kinds_mod integer, parameter :: i_def = kind(1) end module kinds_mod - ''') + """ + ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() # First subroutine accesses i_def from parent Container. calls = psyir.walk(Call) with pytest.raises(TransformationError) as err: inline_trans.validate(calls[0]) - assert ("Routine 'sub' cannot be inlined because it accesses variable " - "'i_def' and this cannot be found in any of the containers " - "directly imported into its symbol table." in str(err.value)) + assert ( + "Routine 'sub' cannot be inlined because it accesses variable " + "'i_def' and this cannot be found in any of the containers " + "directly imported into its symbol table." + in str(err.value) + ) # Second subroutine imports i_def directly into its own SymbolTable and # so is OK to inline. inline_trans.validate(calls[1]) def test_validate_import_clash(fortran_reader): - '''Test that validate() raises the expected error when two symbols of the + """Test that validate() raises the expected error when two symbols of the same name are imported from different containers at the call site and - within the routine.''' + within the routine.""" code = ( "module test_mod\n" "contains\n" @@ -1878,19 +2057,23 @@ def test_validate_import_clash(fortran_reader): " integer :: idx\n" " idx = idx + trouble\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("One or more symbols from routine 'sub' cannot be added to the " - "table at the call site." in str(err.value)) + assert ( + "One or more symbols from routine 'sub' cannot be added to the " + "table at the call site." + in str(err.value) + ) def test_validate_non_local_symbol(fortran_reader): - '''Test that validate() raises the expected error when the routine to be - inlined accesses a symbol from its parent container.''' + """Test that validate() raises the expected error when the routine to be + inlined accesses a symbol from its parent container.""" code = ( "module test_mod\n" " integer :: trouble\n" @@ -1904,20 +2087,24 @@ def test_validate_non_local_symbol(fortran_reader): " integer :: idx\n" " idx = idx + trouble\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("Routine 'sub' cannot be inlined because it accesses variable " - "'trouble' from its parent container" in str(err.value)) + assert ( + "Routine 'sub' cannot be inlined because it accesses variable " + "'trouble' from its parent container" + in str(err.value) + ) def test_validate_wrong_number_args(fortran_reader): - ''' Test that validate rejects inlining routines with different number + """Test that validate rejects inlining routines with different number of arguments. - ''' + """ code = ( "module test_mod\n" " integer :: trouble\n" @@ -1931,20 +2118,24 @@ def test_validate_wrong_number_args(fortran_reader): " integer :: idx\n" " idx = idx + 1\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("Cannot inline 'call sub(i, trouble)' because the number of " - "arguments supplied to the call (2) does not match the number of " - "arguments the routine is declared to have (1)" in str(err.value)) + assert ( + "Cannot inline 'call sub(i, trouble)' because the number of " + "arguments supplied to the call (2) does not match the number of " + "arguments the routine is declared to have (1)" + in str(err.value) + ) def test_validate_unresolved_import(fortran_reader): - '''Test that validate rejects a routine that accesses a symbol which - is unresolved.''' + """Test that validate rejects a routine that accesses a symbol which + is unresolved.""" code = ( "module test_mod\n" " use some_mod\n" @@ -1958,23 +2149,27 @@ def test_validate_unresolved_import(fortran_reader): " integer :: idx\n" " idx = idx + trouble\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("Routine 'sub' cannot be inlined because it accesses variable " - "'trouble' and this cannot be found in any of the containers " - "directly imported into its symbol table." in str(err.value)) + assert ( + "Routine 'sub' cannot be inlined because it accesses variable " + "'trouble' and this cannot be found in any of the containers " + "directly imported into its symbol table." + in str(err.value) + ) def test_validate_unresolved_array_dim(fortran_reader): - ''' + """ Check that validate rejects a routine if it uses an unresolved Symbol when defining an array dimension. - ''' + """ code = ( "module test_mod\n" " use some_mod\n" @@ -1989,21 +2184,25 @@ def test_validate_unresolved_array_dim(fortran_reader): " integer, dimension(some_size) :: var\n" " idx = idx + 2\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("Routine 'sub' cannot be inlined because it accesses variable " - "'some_size' and this cannot be found in any of the containers " - "directly imported into its symbol table" in str(err.value)) + assert ( + "Routine 'sub' cannot be inlined because it accesses variable " + "'some_size' and this cannot be found in any of the containers " + "directly imported into its symbol table" + in str(err.value) + ) def test_validate_array_reshape(fortran_reader): - '''Test that the validate method rejects an attempt to inline a routine + """Test that the validate method rejects an attempt to inline a routine if any of its formal arguments are declared to be a different shape from - those at the call site.''' + those at the call site.""" code = ( "module test_mod\n" "contains\n" @@ -2019,23 +2218,27 @@ def test_validate_array_reshape(fortran_reader): " x(i) = x(i) + m\n" " enddo\n" "end subroutine\n" - "end module\n") + "end module\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("Cannot inline routine 's' because it reshapes an argument: actual" - " argument 'a(:,:)' has rank 2 but the corresponding formal " - "argument, 'x', has rank 1" in str(err.value)) + assert ( + "Cannot inline routine 's' because it reshapes an argument: actual" + " argument 'a(:,:)' has rank 2 but the corresponding formal " + "argument, 'x', has rank 1" + in str(err.value) + ) def test_validate_array_arg_expression(fortran_reader): - ''' + """ Check that validate rejects a call if an argument corresponding to a formal array argument is not a simple Reference or Literal. - ''' + """ code = ( "module test_mod\n" "contains\n" @@ -2051,20 +2254,24 @@ def test_validate_array_arg_expression(fortran_reader): " x(i) = x(i) + m\n" " enddo\n" "end subroutine\n" - "end module\n") + "end module\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("The call 'call s(a + b, 10)\n' cannot be inlined because actual " - "argument 'a + b' corresponds to a formal argument with array " - "type but is not a Reference or a Literal" in str(err.value)) + assert ( + "The call 'call s(a + b, 10)\n' cannot be inlined because actual " + "argument 'a + b' corresponds to a formal argument with array " + "type but is not a Reference or a Literal" + in str(err.value) + ) def test_validate_indirect_range(fortran_reader): - '''Test that validate rejects an attempt to inline a call to a routine - with an argument constructed using an indirect slice.''' + """Test that validate rejects an attempt to inline a call to a routine + with an argument constructed using an indirect slice.""" code = ( "module test_mod\n" " integer, dimension(10) :: indices\n" @@ -2077,19 +2284,23 @@ def test_validate_indirect_range(fortran_reader): " real, dimension(:), intent(inout) :: x\n" " x(:) = 0.0\n" "end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("Cannot inline routine 'sub' because argument 'var(indices(:))' " - "has an array range in an indirect access" in str(err.value)) + assert ( + "Cannot inline routine 'sub' because argument 'var(indices(:))' " + "has an array range in an indirect access" + in str(err.value) + ) def test_validate_non_unit_stride_slice(fortran_reader): - '''Test that validate rejects an attempt to inline a call to a routine - with an argument constructed using an array slice with non-unit stride.''' + """Test that validate rejects an attempt to inline a call to a routine + with an argument constructed using an array slice with non-unit stride.""" code = ( "module test_mod\n" "contains\n" @@ -2101,20 +2312,23 @@ def test_validate_non_unit_stride_slice(fortran_reader): " real, dimension(:), intent(inout) :: x\n" " x(:) = 0.0\n" "end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("Cannot inline routine 'sub' because one of its arguments is an " - "array slice with a non-unit stride: 'var(::2)' (TODO #1646)" in - str(err.value)) + assert ( + "Cannot inline routine 'sub' because one of its arguments is an " + "array slice with a non-unit stride: 'var(::2)' (TODO #1646)" + in str(err.value) + ) def test_validate_named_arg(fortran_reader): - '''Test that the validate method rejects an attempt to inline a routine - that has a named argument.''' + """Test that the validate method rejects an attempt to inline a routine + that has a named argument.""" # In reality, the routine with a named argument would almost certainly # use the 'present' intrinsic but, since that gives a CodeBlock that itself # prevents inlining, our test example omits it. @@ -2140,8 +2354,11 @@ def test_validate_named_arg(fortran_reader): inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ("Routine 'sub' cannot be inlined because it has a named argument " - "'opt' (TODO #924)" in str(err.value)) + assert ( + "Routine 'sub' cannot be inlined because it has a named argument " + "'opt' (TODO #924)" + in str(err.value) + ) CALL_IN_SUB_USE = ( @@ -2149,26 +2366,18 @@ def test_validate_named_arg(fortran_reader): " use inline_mod, only : sub\n" " real :: a\n" " call sub(a)\n" - "end subroutine run_it\n") -CALL_IN_SUB = CALL_IN_SUB_USE.replace( - " use inline_mod, only : sub\n", "") -SUB = ( - "subroutine sub(x)\n" - " real :: x\n" - " x = 1.0\n" - "end subroutine sub\n") -SUB_IN_MODULE = ( - f"module inline_mod\n" - f"contains\n" - f"{SUB}" - f"end module inline_mod\n") + "end subroutine run_it\n" +) +CALL_IN_SUB = CALL_IN_SUB_USE.replace(" use inline_mod, only : sub\n", "") +SUB = "subroutine sub(x)\n real :: x\n x = 1.0\nend subroutine sub\n" +SUB_IN_MODULE = f"module inline_mod\ncontains\n{SUB}end module inline_mod\n" def test_apply_merges_symbol_table_with_routine(fortran_reader): - ''' + """ Check that the apply method merges the inlined function's symbol table to the containing Routine when the call node is inside a child ScopingNode. - ''' + """ code = ( "module test_mod\n" "contains\n" @@ -2187,21 +2396,22 @@ def test_apply_merges_symbol_table_with_routine(fortran_reader): " x(i) = 2.0*ivar\n" " end do\n" " end subroutine sub\n" - "end module test_mod\n") + "end module test_mod\n" + ) psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) # The i_1 symbol is the renamed i from the inlined call. - assert psyir.walk(Routine)[0].symbol_table.get_symbols()['i_1'] is not None + assert psyir.walk(Routine)[0].symbol_table.get_symbols()["i_1"] is not None def test_apply_argument_clash(fortran_reader, fortran_writer, tmpdir): - ''' + """ Check that the formal arguments to the inlined routine are not included when checking for clashes (since they will be replaced by the actual arguments to the call). - ''' + """ code_clash = """ subroutine sub(Istr) @@ -2223,7 +2433,7 @@ def test_apply_argument_clash(fortran_reader, fortran_writer, tmpdir): call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(call) - expected = '''\ + expected = """\ subroutine sub(istr) integer :: istr real :: x @@ -2234,7 +2444,7 @@ def test_apply_argument_clash(fortran_reader, fortran_writer, tmpdir): b(istr:) = 1.0 end subroutine sub -''' +""" output = fortran_writer(psyir) assert expected in output assert Compile(tmpdir).string_compiles(output) From a521faf473a078d1d6820128fbb427309194d6a6 Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Sun, 24 Nov 2024 01:54:36 +0100 Subject: [PATCH 04/20] u --- src/psyclone/psyir/nodes/call.py | 7 +- .../psyir/transformations/inline_trans.py | 350 +++++----- .../psyir/transformations/omp_task_trans.py | 27 + src/psyclone/tests/psyir/nodes/call_test.py | 630 ++++++++++-------- .../transformations/inline_trans_test.py | 54 +- .../omp_task_transformations_test.py | 3 +- 6 files changed, 633 insertions(+), 438 deletions(-) diff --git a/src/psyclone/psyir/nodes/call.py b/src/psyclone/psyir/nodes/call.py index 8633629b3f..c24e852da0 100644 --- a/src/psyclone/psyir/nodes/call.py +++ b/src/psyclone/psyir/nodes/call.py @@ -837,8 +837,9 @@ def _get_argument_routine_match( ) if len(self.arguments) > len(routine.symbol_table.argument_list): + call_str = self.debug_string().replace("\n", "") raise CallMatchingArgumentsNotFound( - f"More arguments in call ('{self.debug_string()}')" + f"More arguments in call ('{call_str}')" f" than callee (routine '{routine.name}')" ) @@ -979,7 +980,9 @@ def get_callee( return (routine_list[0], [i for i in range(len(self.arguments))]) error_msg = "\n".join(err_info) + raise CallMatchingArgumentsNotFound( - f"No matching routine found for '{self.debug_string()}'" + "Found routines, but no routine with matching arguments found " + f"for '{self.routine.name}':\n" + error_msg ) diff --git a/src/psyclone/psyir/transformations/inline_trans.py b/src/psyclone/psyir/transformations/inline_trans.py index 1756a78afa..3532fa9dc5 100644 --- a/src/psyclone/psyir/transformations/inline_trans.py +++ b/src/psyclone/psyir/transformations/inline_trans.py @@ -212,25 +212,24 @@ def apply( self.validate(node_call, node_routine=node_routine, options=options) # The table associated with the scoping region holding the Call. table = node_call.scope.symbol_table - # Find the routine to be inlined. - orig_routine = node_call.get_callees()[0] - if not orig_routine.children or isinstance(orig_routine.children[0], - Return): + if not self.node_routine.children or isinstance( + self.node_routine.children[0], Return + ): # Called routine is empty so just remove the call. node_call.detach() return # Ensure we don't modify the original Routine by working with a # copy of it. - routine = orig_routine.copy() - routine_table = routine.symbol_table + self.node_routine = self.node_routine.copy() + routine_table = self.node_routine.symbol_table # Construct lists of the nodes that will be inserted and all of the # References that they contain. new_stmts = [] refs = [] - for child in routine.children: + for child in self.node_routine.children: new_stmts.append(child.copy()) refs.extend(new_stmts[-1].walk(Reference)) @@ -273,7 +272,7 @@ def apply( # remove it from the list. del new_stmts[-1] - if routine.return_symbol: + if self.node_routine.return_symbol: # This is a function assignment = node_call.ancestor(Statement) parent = assignment.parent @@ -284,9 +283,12 @@ def apply( table = parent.scope.symbol_table # Avoid a potential name clash with the original function table.rename_symbol( - routine.return_symbol, table.next_available_name( - f"inlined_{routine.return_symbol.name}")) - node_call.replace_with(Reference(routine.return_symbol)) + self.node_routine.return_symbol, + table.next_available_name( + f"inlined_{self.node_routine.return_symbol.name}" + ), + ) + node_call.replace_with(Reference(self.node_routine.return_symbol)) else: # This is a call parent = node_call.parent @@ -726,7 +728,19 @@ def validate( # Check that we can find the source of the routine being inlined. # TODO #924 allow for multiple routines (interfaces). try: - self.node_routine = node_call.get_callees()[0] + (self.node_routine, self._ret_arg_match_list) = ( + node_call.get_callee( + check_matching_arguments=( + self._option_check_matching_arguments_of_callee + ), + check_strict_array_datatype=( + self._option_check_argument_strict_array_datatype + ), + ignore_missing_modules=( + self._option_ignore_missing_modules + ), + ) + ) except ( CallMatchingArgumentsNotFound, NotImplementedError, @@ -781,59 +795,73 @@ def validate( "`check_codeblocks=False` to override.)" ) - # Support for routines with named arguments is not yet implemented. - # TODO #924. - for arg in node_call.argument_names: - if arg: - raise TransformationError( - f"Routine '{self.node_routine.name}' cannot be inlined" - f" because it has a named argument '{arg}' (TODO #924)." - ) - table = node_call.scope.symbol_table routine_table = self.node_routine.symbol_table + # TODO: Maybe move me to options + check_argument_unsupported_type = True + check_static_interface = True + check_diff_container_clashes = True + check_diff_container_clashes_unresolved_types = True + check_resolve_imports = True + check_array_type = True + for sym in routine_table.datasymbols: # We don't inline symbols that have an UnsupportedType and are # arguments since we don't know if a simple assignment if # enough (e.g. pointers) - if isinstance(sym.interface, ArgumentInterface): - if isinstance(sym.datatype, UnsupportedType): + if check_argument_unsupported_type: + if isinstance(sym.interface, ArgumentInterface): + if isinstance(sym.datatype, UnsupportedType): + if ", OPTIONAL" not in sym.datatype.declaration: + raise TransformationError( + f"Routine '{self.node_routine.name}' cannot be" + " inlined because it contains a Symbol" + f" '{sym.name}' which is an Argument of" + " UnsupportedType:" + f" '{sym.datatype.declaration}'" + ) + # We don't inline symbols that have an UnknownInterface, as we + # don't know how they are brought into this scope. + if isinstance(sym.interface, UnknownInterface): raise TransformationError( - f"Routine '{self.node_routine.name}' cannot be inlined" - f" because it contains a Symbol '{sym.name}' which is" - " an Argument of UnsupportedType:" - f" '{sym.datatype.declaration}'" + f"Routine '{self.node_routine.name}' cannot be " + "inlined because it contains a Symbol " + f"'{sym.name}' with an UnknownInterface: " + f"'{sym.datatype.declaration}'" ) - # We don't inline symbols that have an UnknownInterface, as we - # don't know how they are brought into this scope. - if isinstance(sym.interface, UnknownInterface): - raise TransformationError( - f"Routine '{self.node_routine.name}' cannot be inlined" - f" because it contains a Symbol '{sym.name}' with an" - f" UnknownInterface: '{sym.datatype.declaration}'" + + if check_static_interface: + # Check that there are no static variables in the routine + # (because we don't know whether the routine is called from + # other places). + if ( + isinstance(sym.interface, StaticInterface) + and not sym.is_constant + ): + raise TransformationError( + f"Routine '{self.node_routine.name}' cannot be " + "inlined because it has a static (Fortran SAVE) " + f"interface for Symbol '{sym.name}'." + ) + + if check_diff_container_clashes: + # We can't handle a clash between (apparently) different symbols + # that share a name but are imported from different containers. + try: + table.check_for_clashes( + routine_table, + symbols_to_skip=routine_table.argument_list[:], + check_unresolved_symbols=( + check_diff_container_clashes_unresolved_types + ), ) - # Check that there are no static variables in the routine (because - # we don't know whether the routine is called from other places). - if (isinstance(sym.interface, StaticInterface) and - not sym.is_constant): + except SymbolError as err: raise TransformationError( - f"Routine '{self.node_routine.name}' cannot be inlined" - " because it has a static (Fortran SAVE) interface for" - f" Symbol '{sym.name}'." - ) - - # We can't handle a clash between (apparently) different symbols that - # share a name but are imported from different containers. - try: - table.check_for_clashes( - routine_table, - symbols_to_skip=routine_table.argument_list[:]) - except SymbolError as err: - raise TransformationError( - f"One or more symbols from routine '{self.node_routine.name}'" - " cannot be added to the table at the call site." - ) from err + "One or more symbols from routine " + f"'{self.node_routine.name}' cannot be added to the " + "table at the call site." + ) from err # Check for unresolved symbols or for any accessed from the Container # containing the target routine. @@ -849,16 +877,19 @@ def validate( for sym in routine_table.datasymbols: if sym.initial_value: ref_or_lits.extend( - sym.initial_value.walk((Reference, Literal))) + sym.initial_value.walk((Reference, Literal)) + ) if isinstance(sym.datatype, ArrayType): for dim in sym.shape: if isinstance(dim, ArrayType.ArrayBounds): if isinstance(dim.lower, Node): - ref_or_lits.extend(dim.lower.walk(Reference, - Literal)) + ref_or_lits.extend( + dim.lower.walk(Reference, Literal) + ) if isinstance(dim.upper, Node): - ref_or_lits.extend(dim.upper.walk(Reference, - Literal)) + ref_or_lits.extend( + dim.upper.walk(Reference, Literal) + ) # Keep a reference to each Symbol that we check so that we can avoid # repeatedly checking the same Symbol. _symbol_cache = set() @@ -875,45 +906,38 @@ def validate( _symbol_cache.add(sym) if isinstance(sym, IntrinsicSymbol): continue - # We haven't seen this Symbol before. - if sym.is_unresolved: - try: - routine_table.resolve_imports(symbol_target=sym) - except KeyError: - # The symbol is not (directly) imported into the symbol - # table local to the routine. - # pylint: disable=raise-missing-from - raise TransformationError( - f"Routine '{self.node_routine.name}' cannot be inlined" - f" because it accesses variable '{sym.name}' and this" - " cannot be found in any of the containers directly" - " imported into its symbol table." - ) - else: - if sym.name not in routine_table: - raise TransformationError( - f"Routine '{self.node_routine.name}' cannot be inlined" - f" because it accesses variable '{sym.name}' from its" - " parent container." - ) - # Check that the shapes of any formal array arguments are the same as - # those at the call site. - if len(routine_table.argument_list) != len(node_call.arguments): - raise TransformationError( - LazyString( - lambda: ( - f"Cannot inline '{node_call.debug_string().strip()}'" - " because the number of arguments supplied to the" - f" call ({len(node_call.arguments)}) does not match" - " the number of arguments the routine is declared to" - f" have ({len(routine_table.argument_list)})." - ) - ) - ) + if check_resolve_imports: + # We haven't seen this Symbol before. + if sym.is_unresolved: + try: + routine_table.resolve_imports(symbol_target=sym) + except KeyError: + # The symbol is not (directly) imported into the symbol + # table local to the routine. + # pylint: disable=raise-missing-from + raise TransformationError( + f"Routine '{self.node_routine.name}' cannot be " + "inlined because it accesses variable " + f"'{sym.name}' and this cannot be found in any " + "of the containers directly imported into its " + "symbol table." + ) + else: + if sym.name not in routine_table: + raise TransformationError( + f"Routine '{self.node_routine.name}' cannot be " + "inlined because it accesses variable " + f"'{sym.name}' from its parent container." + ) + + # Create a list of routine arguments that is actually used + routine_arg_list = [ + routine_table.argument_list[i] for i in self._ret_arg_match_list + ] for formal_arg, actual_arg in zip( - routine_table.argument_list, node_call.arguments + routine_arg_list, node_call.arguments ): # If the formal argument is an array with non-default bounds then # we also need to know the bounds of that array at the call site. @@ -933,8 +957,8 @@ def validate( raise TransformationError( LazyString( lambda: ( - f"The call '{node_call.debug_string()}' cannot be " - "inlined because actual argument " + f"The call '{node_call.debug_string()}' " + "cannot be inlined because actual argument " f"'{actual_arg.debug_string()}' corresponds to a " "formal argument with array type but is not a " "Reference or a Literal." @@ -948,75 +972,83 @@ def validate( # TODO #924. It would be useful if the `datatype` property was # a method that took an optional 'resolve' argument to indicate # that it should attempt to resolve any UnresolvedTypes. - if (isinstance(actual_arg.datatype, - (UnresolvedType, UnsupportedType)) or - (isinstance(actual_arg.datatype, ArrayType) and - isinstance(actual_arg.datatype.intrinsic, - (UnresolvedType, UnsupportedType)))): - raise TransformationError( - f"Routine '{self.node_routine.name}' cannot be inlined" - " because the type of the actual argument" - f" '{actual_arg.symbol.name}' corresponding to an array" - f" formal argument ('{formal_arg.name}') is unknown." - ) + if check_array_type: + if isinstance( + actual_arg.datatype, (UnresolvedType, UnsupportedType) + ) or ( + isinstance(actual_arg.datatype, ArrayType) + and isinstance( + actual_arg.datatype.intrinsic, + (UnresolvedType, UnsupportedType), + ) + ): + raise TransformationError( + f"Routine '{self.node_routine.name}' cannot be " + "inlined because the type of the actual argument " + f"'{actual_arg.symbol.name}' corresponding to an array" + f" formal argument ('{formal_arg.name}') is unknown." + ) - formal_rank = 0 - actual_rank = 0 - if isinstance(formal_arg.datatype, ArrayType): - formal_rank = len(formal_arg.datatype.shape) - if isinstance(actual_arg.datatype, ArrayType): - actual_rank = len(actual_arg.datatype.shape) - if formal_rank != actual_rank: - # It's OK to use the loop variable in the lambda definition - # because if we get to this point then we're going to quit - # the loop. - # pylint: disable=cell-var-from-loop - raise TransformationError( - LazyString( - lambda: ( - f"Cannot inline routine '{self.node_routine.name}'" - " because it reshapes an argument: actual" - f" argument '{actual_arg.debug_string()}' has rank" - f" {actual_rank} but the corresponding formal" - f" argument, '{formal_arg.name}', has rank" - f" {formal_rank}" + formal_rank = 0 + actual_rank = 0 + if isinstance(formal_arg.datatype, ArrayType): + formal_rank = len(formal_arg.datatype.shape) + if isinstance(actual_arg.datatype, ArrayType): + actual_rank = len(actual_arg.datatype.shape) + if formal_rank != actual_rank: + # It's OK to use the loop variable in the lambda definition + # because if we get to this point then we're going to quit + # the loop. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self.node_routine.name}' because it" + " reshapes an argument: actual argument" + f" '{actual_arg.debug_string()}' has rank" + f" {actual_rank} but the corresponding formal" + f" argument, '{formal_arg.name}', has rank" + f" {formal_rank}" + ) ) ) - ) - if actual_rank: - ranges = actual_arg.walk(Range) - for rge in ranges: - ancestor_ref = rge.ancestor(Reference) - if ancestor_ref is not actual_arg: - # Have a range in an indirect access. - # pylint: disable=cell-var-from-loop - raise TransformationError( - LazyString( - lambda: ( - "Cannot inline routine" - f" '{self.node_routine.name}' because" - f" argument '{actual_arg.debug_string()}'" - " has an array range in an indirect" - " access (TODO #924)." + if actual_rank: + ranges = actual_arg.walk(Range) + for rge in ranges: + ancestor_ref = rge.ancestor(Reference) + if ancestor_ref is not actual_arg: + # Have a range in an indirect access. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self.node_routine.name}' because" + " argument" + f" '{actual_arg.debug_string()}' has" + " an array range in an indirect" + " access #(TODO 924)." + ) ) ) - ) - if rge.step != _ONE: - # TODO #1646. We could resolve this problem by making - # a new array and copying the necessary values into it. - # pylint: disable=cell-var-from-loop - raise TransformationError( - LazyString( - lambda: ( - "Cannot inline routine" - f" '{self.node_routine.name}' because one" - " of its arguments is an array slice with" - " a non-unit stride:" - f" '{actual_arg.debug_string()}' (TODO" - " #1646)" + if rge.step != _ONE: + # TODO #1646. We could resolve this problem by + # making a new array and copying the necessary + # values into it. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self.node_routine.name}' because" + " one of its arguments is an array" + " slice with a non-unit stride:" + f" '{actual_arg.debug_string()}' (TODO" + " #1646)" + ) ) ) - ) # For AutoAPI auto-documentation generation. diff --git a/src/psyclone/psyir/transformations/omp_task_trans.py b/src/psyclone/psyir/transformations/omp_task_trans.py index f11e6fa975..bd096a9c5c 100644 --- a/src/psyclone/psyir/transformations/omp_task_trans.py +++ b/src/psyclone/psyir/transformations/omp_task_trans.py @@ -58,6 +58,23 @@ class OMPTaskTrans(ParallelLoopTrans): implementation. ''' + def __init__(self): + super().__init__() + + # If 'True', the callee must have matching arguments. + # The 'matching' criteria can be weakened by other options. + # If 'False', in case no match was found, the first callee is taken. + self._option_check_matching_arguments_of_callee: bool = True + + def set_option( + self, + check_matching_arguments_of_callee: bool = None, + ): + if check_matching_arguments_of_callee is not None: + self._option_check_matching_arguments_of_callee = ( + check_matching_arguments_of_callee + ) + def __str__(self): return "Adds an 'OMP TASK' directive to a statement" @@ -98,6 +115,11 @@ def validate(self, node, options=None): kintrans = KernelModuleInlineTrans() cond_trans = FoldConditionalReturnExpressionsTrans() intrans = InlineTrans() + intrans.set_option( + check_matching_arguments_of_callee=( + self._option_check_matching_arguments_of_callee + ) + ) for kern in kerns: kintrans.validate(kern) cond_trans.validate(kern.get_kernel_schedule()) @@ -157,6 +179,11 @@ def _inline_kernels(self, node): kintrans = KernelModuleInlineTrans() cond_trans = FoldConditionalReturnExpressionsTrans() intrans = InlineTrans() + intrans.set_option( + check_matching_arguments_of_callee=( + self._option_check_matching_arguments_of_callee + ) + ) for kern in kerns: kintrans.apply(kern) cond_trans.apply(kern.get_kernel_schedule()) diff --git a/src/psyclone/tests/psyir/nodes/call_test.py b/src/psyclone/tests/psyir/nodes/call_test.py index 0fc67bf187..853fcead1e 100644 --- a/src/psyclone/tests/psyir/nodes/call_test.py +++ b/src/psyclone/tests/psyir/nodes/call_test.py @@ -34,7 +34,7 @@ # Authors: R. W. Ford and A. R. Porter, STFC Daresbury Lab # ----------------------------------------------------------------------------- -''' Performs py.test tests on the Call PSyIR node. ''' +"""Performs py.test tests on the Call PSyIR node.""" import os import pytest @@ -42,26 +42,42 @@ from psyclone.core import Signature, VariablesAccessInfo from psyclone.parse import ModuleManager from psyclone.psyir.nodes import ( - ArrayReference, Assignment, BinaryOperation, Call, CodeBlock, Literal, - Node, Reference, Routine, Schedule) + ArrayReference, + Assignment, + BinaryOperation, + Call, + CodeBlock, + Literal, + Node, + Reference, + Routine, + Schedule, +) from psyclone.psyir.nodes.node import colored from psyclone.psyir.symbols import ( - ArrayType, INTEGER_TYPE, DataSymbol, NoType, RoutineSymbol, REAL_TYPE, - SymbolError, UnsupportedFortranType) + ArrayType, + INTEGER_TYPE, + DataSymbol, + NoType, + RoutineSymbol, + REAL_TYPE, + SymbolError, + UnsupportedFortranType, +) from psyclone.errors import GenerationError from psyclone.psyir.nodes.call import CallMatchingArgumentsNotFound class SpecialCall(Call): - '''Test Class specialising the Call class''' + """Test Class specialising the Call class""" def test_call_init(): - '''Test that a Call can be created as expected. Also test the routine + """Test that a Call can be created as expected. Also test the routine property. - ''' + """ # Initialise without a RoutineSymbol call = Call() # By default everything is None @@ -76,15 +92,15 @@ def test_call_init(): routine = RoutineSymbol("jo", NoType()) call = Call(parent=parent) call.addchild(Reference(routine)) - call.addchild(Literal('3', INTEGER_TYPE)) + call.addchild(Literal("3", INTEGER_TYPE)) assert call.routine.symbol is routine assert call.parent is parent - assert call.arguments == [Literal('3', INTEGER_TYPE)] + assert call.arguments == [Literal("3", INTEGER_TYPE)] def test_call_is_elemental(): - '''Test the is_elemental property of a Call is set correctly and can be - queried.''' + """Test the is_elemental property of a Call is set correctly and can be + queried.""" routine = RoutineSymbol("zaphod", NoType()) call = Call.create(routine) assert call.is_elemental is None @@ -94,8 +110,8 @@ def test_call_is_elemental(): def test_call_is_pure(): - '''Test the is_pure property of a Call is set correctly and can be - queried.''' + """Test the is_pure property of a Call is set correctly and can be + queried.""" routine = RoutineSymbol("zaphod", NoType()) call = Call.create(routine) assert call.is_pure is None @@ -105,15 +121,15 @@ def test_call_is_pure(): def test_call_is_available_on_device(): - '''Test the is_available_on_device() method of a Call (currently always - returns False). ''' + """Test the is_available_on_device() method of a Call (currently always + returns False).""" routine = RoutineSymbol("zaphod", NoType()) call = Call.create(routine) assert call.is_available_on_device() is False def test_call_equality(): - '''Test the __eq__ method of the Call class. ''' + """Test the __eq__ method of the Call class.""" # routine arguments routine = RoutineSymbol("j", NoType()) routine2 = RoutineSymbol("k", NoType()) @@ -138,92 +154,114 @@ def test_call_equality(): assert call4 != call7 # Check when a Reference (to the same RoutineSymbol) is provided. - call8 = Call.create(Reference(routine), - [("new_name", Literal("1.0", REAL_TYPE))]) + call8 = Call.create( + Reference(routine), [("new_name", Literal("1.0", REAL_TYPE))] + ) assert call8 == call7 @pytest.mark.parametrize("cls", [Call, SpecialCall]) def test_call_create(cls): - '''Test that the create method creates a valid call with arguments, + """Test that the create method creates a valid call with arguments, some of which are named. Also checks the routine and argument_names properties. - ''' + """ routine = RoutineSymbol("ellie", INTEGER_TYPE) array_type = ArrayType(INTEGER_TYPE, shape=[10, 20]) - arguments = [Reference(DataSymbol("arg1", INTEGER_TYPE)), - ArrayReference(DataSymbol("arg2", array_type))] + arguments = [ + Reference(DataSymbol("arg1", INTEGER_TYPE)), + ArrayReference(DataSymbol("arg2", array_type)), + ] call = cls.create(routine, [arguments[0], ("name", arguments[1])]) # pylint: disable=unidiomatic-typecheck assert type(call) is cls assert call.routine.symbol is routine assert call.argument_names == [None, "name"] - for idx, child, in enumerate(call.arguments): + for ( + idx, + child, + ) in enumerate(call.arguments): assert child is arguments[idx] assert child.parent is call def test_call_create_error1(): - '''Test that the appropriate exception is raised if the routine + """Test that the appropriate exception is raised if the routine argument to the create method is not a RoutineSymbol. - ''' + """ with pytest.raises(TypeError) as info: _ = Call.create(None, []) - assert ("The Call routine argument should be a Reference to a " - "RoutineSymbol or a RoutineSymbol, but found " - "'NoneType'." in str(info.value)) + assert ( + "The Call routine argument should be a Reference to a " + "RoutineSymbol or a RoutineSymbol, but found " + "'NoneType'." + in str(info.value) + ) def test_call_create_error2(): - '''Test that the appropriate exception is raised if the arguments - argument to the create method is not a list''' + """Test that the appropriate exception is raised if the arguments + argument to the create method is not a list""" routine = RoutineSymbol("isaac", NoType()) with pytest.raises(GenerationError) as info: _ = Call.create(routine, None) - assert ("Call.create 'arguments' argument should be an Iterable but found " - "'NoneType'." in str(info.value)) + assert ( + "Call.create 'arguments' argument should be an Iterable but found " + "'NoneType'." + in str(info.value) + ) def test_call_create_error3(): - '''Test that the appropriate exception is raised if one or more of the - argument names is not valid.''' + """Test that the appropriate exception is raised if one or more of the + argument names is not valid.""" routine = RoutineSymbol("roo", INTEGER_TYPE) with pytest.raises(ValueError) as info: _ = Call.create( - routine, [Reference(DataSymbol( - "arg1", INTEGER_TYPE)), (" a", None)]) + routine, + [Reference(DataSymbol("arg1", INTEGER_TYPE)), (" a", None)], + ) assert "Invalid Fortran name ' a' found." in str(info.value) def test_call_create_error4(): - '''Test that the appropriate exception is raised if one or more of the + """Test that the appropriate exception is raised if one or more of the arguments argument list entries to the create method is not a DataNode. - ''' + """ routine = RoutineSymbol("roo", INTEGER_TYPE) with pytest.raises(GenerationError) as info: _ = Call.create( - routine, [Reference(DataSymbol( - "arg1", INTEGER_TYPE)), ("name", None)]) - assert ("Item 'NoneType' can't be child 2 of 'Call'. The valid format " - "is: 'Reference, [DataNode]*'." in str(info.value)) + routine, + [Reference(DataSymbol("arg1", INTEGER_TYPE)), ("name", None)], + ) + assert ( + "Item 'NoneType' can't be child 2 of 'Call'. The valid format " + "is: 'Reference, [DataNode]*'." + in str(info.value) + ) def test_call_add_args(): - '''Test the _add_args method in the Call class.''' + """Test the _add_args method in the Call class.""" routine = RoutineSymbol("myeloma", INTEGER_TYPE) call = Call.create(routine) array_type = ArrayType(INTEGER_TYPE, shape=[10, 20]) - arguments = [Reference(DataSymbol("arg1", INTEGER_TYPE)), - ArrayReference(DataSymbol("arg2", array_type))] + arguments = [ + Reference(DataSymbol("arg1", INTEGER_TYPE)), + ArrayReference(DataSymbol("arg2", array_type)), + ] Call._add_args(call, [arguments[0], ("name", arguments[1])]) assert call.routine.symbol is routine assert call.argument_names == [None, "name"] - for idx, child, in enumerate(call.arguments): + for ( + idx, + child, + ) in enumerate(call.arguments): assert child is arguments[idx] assert child.parent is call # For some reason pylint thinks that call.children[0,1] are of @@ -235,37 +273,42 @@ def test_call_add_args(): def test_call_add_args_error1(): - '''Test that the appropriate exception is raised if an entry in the + """Test that the appropriate exception is raised if an entry in the arguments argument to the _add_args method is a tuple that does not have two elements. - ''' + """ routine = RoutineSymbol("isaac", NoType()) with pytest.raises(GenerationError) as info: _ = Call._add_args(routine, [(1, 2, 3)]) - assert ("If a child of the children argument in create method of Call " - "class is a tuple, it's length should be 2, but found 3." - in str(info.value)) + assert ( + "If a child of the children argument in create method of Call " + "class is a tuple, it's length should be 2, but found 3." + in str(info.value) + ) def test_call_add_args_error2(): - '''Test that the appropriate exception is raised if an entry in the + """Test that the appropriate exception is raised if an entry in the arguments argument to the _add_args method is is a tuple with two - elements and the first element is not a string.''' + elements and the first element is not a string.""" routine = RoutineSymbol("isaac", NoType()) with pytest.raises(GenerationError) as info: _ = Call._add_args(routine, [(1, 2)]) - assert ("If a child of the children argument in create method of Call " - "class is a tuple, its first argument should be a str, but " - "found int." in str(info.value)) + assert ( + "If a child of the children argument in create method of Call " + "class is a tuple, its first argument should be a str, but " + "found int." + in str(info.value) + ) def test_call_appendnamedarg(): - '''Test the append_named_arg method in the Call class. Check + """Test the append_named_arg method in the Call class. Check it raises the expected exceptions if arguments are invalid and that it works as expected when the input is valid. - ''' + """ op1 = Literal("1", INTEGER_TYPE) op2 = Literal("2", INTEGER_TYPE) op3 = Literal("3", INTEGER_TYPE) @@ -273,8 +316,7 @@ def test_call_appendnamedarg(): # name arg wrong type with pytest.raises(TypeError) as info: call.append_named_arg(1, op1) - assert ("A name should be a string, but found 'int'." - in str(info.value)) + assert "A name should be a string, but found 'int'." in str(info.value) # invalid name with pytest.raises(ValueError) as info: call.append_named_arg("_", op1) @@ -283,9 +325,11 @@ def test_call_appendnamedarg(): call.append_named_arg("name1", op1) with pytest.raises(ValueError) as info: call.append_named_arg("name1", op2) - assert ("The value of the name argument (name1) in 'append_named_arg' in " - "the 'Call' node is already used for a named argument." - in str(info.value)) + assert ( + "The value of the name argument (name1) in 'append_named_arg' in " + "the 'Call' node is already used for a named argument." + in str(info.value) + ) # ok call.append_named_arg("name2", op2) call.append_named_arg(None, op3) @@ -294,11 +338,11 @@ def test_call_appendnamedarg(): def test_call_insertnamedarg(): - '''Test the insert_named_arg method in the Call class. Check + """Test the insert_named_arg method in the Call class. Check it raises the expected exceptions if arguments are invalid and that it works as expected when the input is valid. - ''' + """ op1 = Literal("1", INTEGER_TYPE) op2 = Literal("2", INTEGER_TYPE) op3 = Literal("3", INTEGER_TYPE) @@ -306,8 +350,7 @@ def test_call_insertnamedarg(): # name arg wrong type with pytest.raises(TypeError) as info: call.insert_named_arg(1, op1, 0) - assert ("A name should be a string, but found 'int'." - in str(info.value)) + assert "A name should be a string, but found 'int'." in str(info.value) # invalid name with pytest.raises(ValueError) as info: call.insert_named_arg("1", op1, 0) @@ -316,14 +359,19 @@ def test_call_insertnamedarg(): call.insert_named_arg("name1", op1, 0) with pytest.raises(ValueError) as info: call.insert_named_arg("name1", op2, 0) - assert ("The value of the name argument (name1) in 'insert_named_arg' in " - "the 'Call' node is already used for a named argument." - in str(info.value)) + assert ( + "The value of the name argument (name1) in 'insert_named_arg' in " + "the 'Call' node is already used for a named argument." + in str(info.value) + ) # invalid index type with pytest.raises(TypeError) as info: call.insert_named_arg("name2", op2, "hello") - assert ("The 'index' argument in 'insert_named_arg' in the 'Call' node " - "should be an int but found str." in str(info.value)) + assert ( + "The 'index' argument in 'insert_named_arg' in the 'Call' node " + "should be an int but found str." + in str(info.value) + ) # ok assert call.arguments == [op1] assert call.argument_names == ["name1"] @@ -336,28 +384,35 @@ def test_call_insertnamedarg(): def test_call_replacenamedarg(): - '''Test the replace_named_arg method in the Call class. Check + """Test the replace_named_arg method in the Call class. Check it raises the expected exceptions if arguments are invalid and that it works as expected when the input is valid. - ''' + """ op1 = Literal("1", INTEGER_TYPE) op2 = Literal("2", INTEGER_TYPE) op3 = Literal("3", INTEGER_TYPE) - call = Call.create(RoutineSymbol("hello"), - [("name1", op1), ("name2", op2)]) + call = Call.create( + RoutineSymbol("hello"), [("name1", op1), ("name2", op2)] + ) # name arg wrong type with pytest.raises(TypeError) as info: call.replace_named_arg(1, op3) - assert ("The 'name' argument in 'replace_named_arg' in the 'Call' " - "node should be a string, but found int." in str(info.value)) + assert ( + "The 'name' argument in 'replace_named_arg' in the 'Call' " + "node should be a string, but found int." + in str(info.value) + ) # name arg is not found with pytest.raises(ValueError) as info: call.replace_named_arg("new_name", op3) - assert ("The value of the existing_name argument (new_name) in " - "'replace_named_arg' in the 'Call' node was not found in the " - "existing arguments." in str(info.value)) + assert ( + "The value of the existing_name argument (new_name) in " + "'replace_named_arg' in the 'Call' node was not found in the " + "existing arguments." + in str(info.value) + ) # ok assert call.arguments == [op1, op2] assert call.argument_names == ["name1", "name2"] @@ -371,7 +426,7 @@ def test_call_replacenamedarg(): def test_call_reference_accesses(): - '''Test the reference_accesses() method.''' + """Test the reference_accesses() method.""" rsym = RoutineSymbol("trillian") # A call with an argument passed by value. call1 = Call.create(rsym, [Literal("1", INTEGER_TYPE)]) @@ -396,16 +451,20 @@ def test_call_reference_accesses(): assert var_info.has_read_write(Signature("gamma")) assert var_info.is_read(Signature("ji")) # Argument is a temporary so any inputs to it are READ only. - expr = BinaryOperation.create(BinaryOperation.Operator.MUL, - Literal("2", INTEGER_TYPE), Reference(dsym)) + expr = BinaryOperation.create( + BinaryOperation.Operator.MUL, + Literal("2", INTEGER_TYPE), + Reference(dsym), + ) call4 = Call.create(rsym, [expr]) var_info = VariablesAccessInfo() call4.reference_accesses(var_info) assert var_info.is_read(Signature("beta")) # Argument is itself a function call: call trillian(some_func(gamma(ji))) fsym = RoutineSymbol("some_func") - fcall = Call.create(fsym, - [ArrayReference.create(asym, [Reference(idx_sym)])]) + fcall = Call.create( + fsym, [ArrayReference.create(asym, [Reference(idx_sym)])] + ) call5 = Call.create(rsym, [fcall]) call5.reference_accesses(var_info) assert var_info.has_read_write(Signature("gamma")) @@ -420,11 +479,11 @@ def test_call_reference_accesses(): def test_call_argumentnames_after_removearg(): - '''Test the argument_names property makes things consistent if a child + """Test the argument_names property makes things consistent if a child argument is removed. This is used transparently by the class to keep things consistent. - ''' + """ op1 = Literal("1", INTEGER_TYPE) op2 = Literal("1", INTEGER_TYPE) call = Call.create(RoutineSymbol("name"), [("name1", op1), ("name2", op2)]) @@ -440,11 +499,11 @@ def test_call_argumentnames_after_removearg(): def test_call_argumentnames_after_addarg(): - '''Test the argument_names property makes things consistent if a child + """Test the argument_names property makes things consistent if a child argument is added. This is used transparently by the class to keep things consistent. - ''' + """ op1 = Literal("1", INTEGER_TYPE) op2 = Literal("1", INTEGER_TYPE) op3 = Literal("1", INTEGER_TYPE) @@ -461,11 +520,11 @@ def test_call_argumentnames_after_addarg(): def test_call_argumentnames_after_replacearg(): - '''Test the argument_names property makes things consistent if a child + """Test the argument_names property makes things consistent if a child argument is replaced. This is used transparently by the class to keep things consistent. - ''' + """ op1 = Literal("1", INTEGER_TYPE) op2 = Literal("1", INTEGER_TYPE) op3 = Literal("1", INTEGER_TYPE) @@ -484,11 +543,11 @@ def test_call_argumentnames_after_replacearg(): def test_call_argumentnames_after_reorderarg(): - '''Test the argument_names property makes things consistent if a child + """Test the argument_names property makes things consistent if a child argument is replaced. This is used transparently by the class to keep things consistent. - ''' + """ op1 = Literal("1", INTEGER_TYPE) op2 = Literal("1", INTEGER_TYPE) op3 = Literal("1", INTEGER_TYPE) @@ -505,10 +564,10 @@ def test_call_argumentnames_after_reorderarg(): def test_call_node_reconcile_add(): - '''Test that the reconcile method behaves as expected. Use an example + """Test that the reconcile method behaves as expected. Use an example where we add a new arg. - ''' + """ op1 = Literal("1", INTEGER_TYPE) op2 = Literal("1", INTEGER_TYPE) op3 = Literal("1", INTEGER_TYPE) @@ -531,10 +590,10 @@ def test_call_node_reconcile_add(): def test_call_node_reconcile_reorder(): - '''Test that the reconcile method behaves as expected. Use an example + """Test that the reconcile method behaves as expected. Use an example where we reorder the arguments. - ''' + """ op1 = Literal("1", INTEGER_TYPE) op2 = Literal("2", INTEGER_TYPE) call = Call.create(RoutineSymbol("name"), [("name1", op1), ("name2", op2)]) @@ -558,22 +617,22 @@ def test_call_node_reconcile_reorder(): def test_call_node_str(): - ''' Test that the node_str method behaves as expected ''' + """Test that the node_str method behaves as expected""" routine = RoutineSymbol("isaac", NoType()) call = Call.create(routine) colouredtext = colored("Call", Call._colour) - assert call.node_str() == colouredtext+"[name='isaac']" + assert call.node_str() == colouredtext + "[name='isaac']" def test_call_str(): - ''' Test that the str method behaves as expected ''' + """Test that the str method behaves as expected""" routine = RoutineSymbol("roo", NoType()) call = Call.create(routine) assert str(call) == "Call[name='roo']" def test_copy(): - ''' Test that the copy() method behaves as expected. ''' + """Test that the copy() method behaves as expected.""" op1 = Literal("1", INTEGER_TYPE) op2 = Literal("2", INTEGER_TYPE) call = Call.create(RoutineSymbol("name"), [("name1", op1), ("name2", op2)]) @@ -606,11 +665,11 @@ def test_copy(): def test_call_get_callees_local(fortran_reader): - ''' + """ Check that get_callees() works as expected when the target of the Call exists in the same Container as the call site. - ''' - code = ''' + """ + code = """ module some_mod implicit none integer :: luggage @@ -623,7 +682,7 @@ def test_call_get_callees_local(fortran_reader): subroutine bottom() luggage = luggage + 1 end subroutine bottom -end module some_mod''' +end module some_mod""" psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] result = call.get_callees() @@ -631,11 +690,11 @@ def test_call_get_callees_local(fortran_reader): def test_call_get_callee_1_simple_match(fortran_reader): - ''' + """ Check that the right routine has been found for a single routine implementation. - ''' - code = ''' + """ + code = """ module some_mod implicit none contains @@ -650,7 +709,7 @@ def test_call_get_callee_1_simple_match(fortran_reader): integer :: a, b, c end subroutine -end module some_mod''' +end module some_mod""" psyir = fortran_reader.psyir_from_source(code) @@ -666,10 +725,10 @@ def test_call_get_callee_1_simple_match(fortran_reader): def test_call_get_callee_2_optional_args(fortran_reader): - ''' + """ Check that optional arguments have been correlated correctly. - ''' - code = ''' + """ + code = """ module some_mod implicit none contains @@ -685,7 +744,7 @@ def test_call_get_callee_2_optional_args(fortran_reader): integer, optional :: c end subroutine -end module some_mod''' +end module some_mod""" root_node: Node = fortran_reader.psyir_from_source(code) @@ -709,10 +768,10 @@ def test_call_get_callee_2_optional_args(fortran_reader): def test_call_get_callee_3_trigger_error(fortran_reader): - ''' + """ Test which is supposed to trigger an error. - ''' - code = ''' + """ + code = """ module some_mod implicit none contains @@ -727,7 +786,7 @@ def test_call_get_callee_3_trigger_error(fortran_reader): integer :: a, b end subroutine -end module some_mod''' +end module some_mod""" root_node: Node = fortran_reader.psyir_from_source(code) @@ -743,14 +802,17 @@ def test_call_get_callee_3_trigger_error(fortran_reader): with pytest.raises(CallMatchingArgumentsNotFound) as err: call_foo.get_callee() - assert "No matching routine found for" in str(err.value) + assert ( + "Found routines, but no routine with matching arguments found" + in str(err.value) + ) def test_call_get_callee_4_named_arguments(fortran_reader): - ''' + """ Check that named arguments have been correlated correctly - ''' - code = ''' + """ + code = """ module some_mod implicit none contains @@ -765,7 +827,7 @@ def test_call_get_callee_4_named_arguments(fortran_reader): integer :: a, b, c end subroutine -end module some_mod''' +end module some_mod""" root_node: Node = fortran_reader.psyir_from_source(code) @@ -790,11 +852,11 @@ def test_call_get_callee_4_named_arguments(fortran_reader): def test_call_get_callee_5_optional_and_named_arguments(fortran_reader): - ''' + """ Check that optional and named arguments have been correlated correctly when the call is to a generic interface. - ''' - code = ''' + """ + code = """ module some_mod implicit none contains @@ -810,7 +872,7 @@ def test_call_get_callee_5_optional_and_named_arguments(fortran_reader): integer, optional :: c end subroutine -end module some_mod''' +end module some_mod""" root_node: Node = fortran_reader.psyir_from_source(code) @@ -833,7 +895,7 @@ def test_call_get_callee_5_optional_and_named_arguments(fortran_reader): assert result is routine_match -_code_test_get_callee_6 = ''' +_code_test_get_callee_6 = """ module some_mod implicit none @@ -890,13 +952,13 @@ def test_call_get_callee_5_optional_and_named_arguments(fortran_reader): integer, optional :: c end subroutine -end module some_mod''' +end module some_mod""" def test_call_get_callee_6_interfaces_0_0(fortran_reader): - ''' + """ Check that optional and named arguments have been correlated correctly - ''' + """ root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) @@ -920,9 +982,9 @@ def test_call_get_callee_6_interfaces_0_0(fortran_reader): def test_call_get_callee_6_interfaces_0_1(fortran_reader): - ''' + """ Check that optional and named arguments have been correlated correctly - ''' + """ root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) @@ -947,9 +1009,9 @@ def test_call_get_callee_6_interfaces_0_1(fortran_reader): def test_call_get_callee_6_interfaces_1_0(fortran_reader): - ''' + """ Check that optional and named arguments have been correlated correctly - ''' + """ root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) @@ -973,9 +1035,9 @@ def test_call_get_callee_6_interfaces_1_0(fortran_reader): def test_call_get_callee_6_interfaces_1_1(fortran_reader): - ''' + """ Check that optional and named arguments have been correlated correctly - ''' + """ root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) @@ -1000,9 +1062,9 @@ def test_call_get_callee_6_interfaces_1_1(fortran_reader): def test_call_get_callee_6_interfaces_1_2(fortran_reader): - ''' + """ Check that optional and named arguments have been correlated correctly - ''' + """ root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) @@ -1027,9 +1089,9 @@ def test_call_get_callee_6_interfaces_1_2(fortran_reader): def test_call_get_callee_6_interfaces_2_0(fortran_reader): - ''' + """ Check that optional and named arguments have been correlated correctly - ''' + """ root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) @@ -1059,9 +1121,9 @@ def test_call_get_callee_6_interfaces_2_0(fortran_reader): def test_call_get_callee_6_interfaces_2_1(fortran_reader): - ''' + """ Check that optional and named arguments have been correlated correctly - ''' + """ root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) @@ -1085,9 +1147,9 @@ def test_call_get_callee_6_interfaces_2_1(fortran_reader): def test_call_get_callee_6_interfaces_2_2(fortran_reader): - ''' + """ Check that optional and named arguments have been correlated correctly - ''' + """ root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) @@ -1112,10 +1174,10 @@ def test_call_get_callee_6_interfaces_2_2(fortran_reader): def test_call_get_callee_7_matching_arguments_not_found(fortran_reader): - ''' + """ Trigger error that matching arguments were not found - ''' - code = ''' + """ + code = """ module some_mod implicit none contains @@ -1132,7 +1194,7 @@ def test_call_get_callee_7_matching_arguments_not_found(fortran_reader): integer :: a, b, c end subroutine -end module some_mod''' +end module some_mod""" psyir = fortran_reader.psyir_from_source(code) @@ -1144,16 +1206,17 @@ def test_call_get_callee_7_matching_arguments_not_found(fortran_reader): with pytest.raises(CallMatchingArgumentsNotFound) as err: call_foo.get_callee() - assert "No matching routine found for 'call foo(e, f, d=g)" in str( - err.value + assert ( + "Found routines, but no routine with matching arguments found" + in str(err.value) ) def test_call_get_callee_8_arguments_not_handled(fortran_reader): - ''' + """ Trigger error that matching arguments were not found - ''' - code = ''' + """ + code = """ module some_mod implicit none contains @@ -1169,7 +1232,7 @@ def test_call_get_callee_8_arguments_not_handled(fortran_reader): integer :: a, b, c end subroutine -end module some_mod''' +end module some_mod""" psyir = fortran_reader.psyir_from_source(code) @@ -1181,80 +1244,98 @@ def test_call_get_callee_8_arguments_not_handled(fortran_reader): with pytest.raises(CallMatchingArgumentsNotFound) as err: call_foo.get_callee() - assert "No matching routine found for 'call foo(e, f)" in str(err.value) + assert ( + "Found routines, but no routine with matching arguments found" + in str(err.value) + ) @pytest.mark.usefixtures("clear_module_manager_instance") def test_call_get_callees_unresolved(fortran_reader, tmpdir, monkeypatch): - ''' + """ Test that get_callees() raises the expected error if the called routine is unresolved. - ''' - code = ''' + """ + code = """ subroutine top() call bottom() -end subroutine top''' +end subroutine top""" psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] with pytest.raises(NotImplementedError) as err: _ = call.get_callees() - assert ("Failed to find the source code of the unresolved routine 'bottom'" - " - looked at any routines in the same source file and there are " - "no wildcard imports." in str(err.value)) + assert ( + "Failed to find the source code of the unresolved routine 'bottom'" + " - looked at any routines in the same source file and there are " + "no wildcard imports." + in str(err.value) + ) # Repeat but in the presence of a wildcard import. - code = ''' + code = """ subroutine top() use some_mod_somewhere call bottom() -end subroutine top''' +end subroutine top""" psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] with pytest.raises(NotImplementedError) as err: _ = call.get_callees() - assert ("Failed to find the source code of the unresolved routine 'bottom'" - " - looked at any routines in the same source file and attempted " - "to resolve the wildcard imports from ['some_mod_somewhere']. " - "However, failed to find the source for ['some_mod_somewhere']. " - "The module search path is set to []" in str(err.value)) + assert ( + "Failed to find the source code of the unresolved routine 'bottom'" + " - looked at any routines in the same source file and attempted " + "to resolve the wildcard imports from ['some_mod_somewhere']. " + "However, failed to find the source for ['some_mod_somewhere']. " + "The module search path is set to []" + in str(err.value) + ) # Repeat but when some_mod_somewhere *is* resolved but doesn't help us # find the routine we're looking for. mod_manager = ModuleManager.get() monkeypatch.setattr(mod_manager, "_instance", None) path = str(tmpdir) - monkeypatch.setattr(Config.get(), '_include_paths', [path]) - with open(os.path.join(path, "some_mod_somewhere.f90"), "w", - encoding="utf-8") as ofile: - ofile.write('''\ + monkeypatch.setattr(Config.get(), "_include_paths", [path]) + with open( + os.path.join(path, "some_mod_somewhere.f90"), "w", encoding="utf-8" + ) as ofile: + ofile.write( + """\ module some_mod_somewhere end module some_mod_somewhere -''') +""" + ) with pytest.raises(NotImplementedError) as err: _ = call.get_callees() - assert ("Failed to find the source code of the unresolved routine 'bottom'" - " - looked at any routines in the same source file and wildcard " - "imports from ['some_mod_somewhere']." in str(err.value)) + assert ( + "Failed to find the source code of the unresolved routine 'bottom'" + " - looked at any routines in the same source file and wildcard " + "imports from ['some_mod_somewhere']." + in str(err.value) + ) mod_manager = ModuleManager.get() monkeypatch.setattr(mod_manager, "_instance", None) - code = ''' + code = """ subroutine top() use another_mod, only: this_one call this_one() -end subroutine top''' +end subroutine top""" psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] with pytest.raises(NotImplementedError) as err: _ = call.get_callees() - assert ("RoutineSymbol 'this_one' is imported from Container 'another_mod'" - " but the source defining that container could not be found. The " - "module search path is set to [" in str(err.value)) + assert ( + "RoutineSymbol 'this_one' is imported from Container 'another_mod'" + " but the source defining that container could not be found. The " + "module search path is set to [" + in str(err.value) + ) def test_call_get_callees_interface(fortran_reader): - ''' + """ Check that get_callees() works correctly when the target of a call is actually a generic interface. - ''' - code = ''' + """ + code = """ module my_mod interface bottom @@ -1277,7 +1358,7 @@ def test_call_get_callees_interface(fortran_reader): luggage = luggage + 1.0 end subroutine rbottom end module my_mod -''' +""" psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] callees = call.get_callees() @@ -1289,13 +1370,13 @@ def test_call_get_callees_interface(fortran_reader): def test_call_get_callees_unsupported_type(fortran_reader): - ''' + """ Check that get_callees() raises the expected error when the called routine is of UnsupportedFortranType. This is hard to achieve so we have to manually construct some aspects of the test case. - ''' - code = ''' + """ + code = """ module my_mod integer, target :: value contains @@ -1308,7 +1389,7 @@ def test_call_get_callees_unsupported_type(fortran_reader): fval => value end function bottom end module my_mod -''' +""" psyir = fortran_reader.psyir_from_source(code) container = psyir.children[0] routine = container.find_routine_psyir("bottom") @@ -1325,16 +1406,19 @@ def test_call_get_callees_unsupported_type(fortran_reader): call = psyir.walk(Call)[0] with pytest.raises(NotImplementedError) as err: _ = call.get_callees() - assert ("RoutineSymbol 'bottom' exists in Container 'my_mod' but is of " - "UnsupportedFortranType" in str(err.value)) + assert ( + "RoutineSymbol 'bottom' exists in Container 'my_mod' but is of " + "UnsupportedFortranType" + in str(err.value) + ) def test_call_get_callees_file_container(fortran_reader): - ''' + """ Check that get_callees works if the called routine happens to be in file scope, even when there's no Container. - ''' - code = ''' + """ + code = """ subroutine top() integer :: luggage luggage = 0 @@ -1345,7 +1429,7 @@ def test_call_get_callees_file_container(fortran_reader): integer :: luggage luggage = luggage + 1 end subroutine bottom -''' +""" psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] result = call.get_callees() @@ -1355,13 +1439,13 @@ def test_call_get_callees_file_container(fortran_reader): def test_call_get_callees_no_container(fortran_reader): - ''' + """ Check that get_callees() raises the expected error when the Call is not within a Container and the target routine cannot be found. - ''' + """ # To avoid having the routine symbol immediately dismissed as # unresolved, the code that we initially process *does* have a Container. - code = ''' + code = """ module my_mod contains @@ -1376,7 +1460,7 @@ def test_call_get_callees_no_container(fortran_reader): luggage = luggage + 1 end subroutine bottom end module my_mod -''' +""" psyir = fortran_reader.psyir_from_source(code) top_routine = psyir.walk(Routine)[0] # Deliberately make the Routine node an orphan so there's no Container. @@ -1384,16 +1468,18 @@ def test_call_get_callees_no_container(fortran_reader): call = top_routine.walk(Call)[0] with pytest.raises(SymbolError) as err: _ = call.get_callees() - assert ("Failed to find a Routine named 'bottom' in code:\n'subroutine " - "top()" in str(err.value)) + assert ( + "Failed to find a Routine named 'bottom' in code:\n'subroutine top()" + in str(err.value) + ) def test_call_get_callees_wildcard_import_local_container(fortran_reader): - ''' + """ Check that get_callees() works successfully for a routine accessed via a wildcard import from another module in the same file. - ''' - code = ''' + """ + code = """ module some_mod contains subroutine just_do_it() @@ -1407,7 +1493,7 @@ def test_call_get_callees_wildcard_import_local_container(fortran_reader): call just_do_it() end subroutine run_it end module other_mod -''' +""" psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] routines = call.get_callees() @@ -1417,11 +1503,11 @@ def test_call_get_callees_wildcard_import_local_container(fortran_reader): def test_call_get_callees_import_local_container(fortran_reader): - ''' + """ Check that get_callees() works successfully for a routine accessed via a specific import from another module in the same file. - ''' - code = ''' + """ + code = """ module some_mod contains subroutine just_do_it() @@ -1435,7 +1521,7 @@ def test_call_get_callees_import_local_container(fortran_reader): call just_do_it() end subroutine run_it end module other_mod -''' +""" psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] routines = call.get_callees() @@ -1445,13 +1531,14 @@ def test_call_get_callees_import_local_container(fortran_reader): @pytest.mark.usefixtures("clear_module_manager_instance") -def test_call_get_callees_wildcard_import_container(fortran_reader, - tmpdir, monkeypatch): - ''' +def test_call_get_callees_wildcard_import_container( + fortran_reader, tmpdir, monkeypatch +): + """ Check that get_callees() works successfully for a routine accessed via a wildcard import from a module in another file. - ''' - code = ''' + """ + code = """ module other_mod use some_mod contains @@ -1459,29 +1546,34 @@ def test_call_get_callees_wildcard_import_container(fortran_reader, call just_do_it() end subroutine run_it end module other_mod -''' +""" psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] # This should fail as it can't find the module. with pytest.raises(NotImplementedError) as err: _ = call.get_callees() - assert ("Failed to find the source code of the unresolved routine " - "'just_do_it' - looked at any routines in the same source file" - in str(err.value)) + assert ( + "Failed to find the source code of the unresolved routine " + "'just_do_it' - looked at any routines in the same source file" + in str(err.value) + ) # Create the module containing the subroutine definition, # write it to file and set the search path so that PSyclone can find it. path = str(tmpdir) - monkeypatch.setattr(Config.get(), '_include_paths', [path]) + monkeypatch.setattr(Config.get(), "_include_paths", [path]) - with open(os.path.join(path, "some_mod.f90"), - "w", encoding="utf-8") as mfile: - mfile.write('''\ + with open( + os.path.join(path, "some_mod.f90"), "w", encoding="utf-8" + ) as mfile: + mfile.write( + """\ module some_mod contains subroutine just_do_it() write(*,*) "hello" end subroutine just_do_it -end module some_mod''') +end module some_mod""" + ) routines = call.get_callees() assert len(routines) == 1 assert isinstance(routines[0], Routine) @@ -1489,10 +1581,10 @@ def test_call_get_callees_wildcard_import_container(fortran_reader, def test_fn_call_get_callees(fortran_reader): - ''' + """ Test that get_callees() works for a function call. - ''' - code = ''' + """ + code = """ module some_mod implicit none integer :: luggage @@ -1507,7 +1599,7 @@ def test_fn_call_get_callees(fortran_reader): integer :: my_func my_func = 1 + val end function my_func -end module some_mod''' +end module some_mod""" psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] result = call.get_callees() @@ -1515,9 +1607,9 @@ def test_fn_call_get_callees(fortran_reader): def test_get_callees_code_block(fortran_reader): - '''Test that get_callees() raises the expected error when the called - routine is in a CodeBlock.''' - code = ''' + """Test that get_callees() raises the expected error when the called + routine is in a CodeBlock.""" + code = """ module some_mod implicit none integer :: luggage @@ -1531,22 +1623,24 @@ def test_get_callees_code_block(fortran_reader): integer, intent(in) :: val my_func = CMPLX(1 + val, 1.0) end function my_func -end module some_mod''' +end module some_mod""" psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[1] with pytest.raises(SymbolError) as err: _ = call.get_callees() - assert ("Failed to find a Routine named 'my_func' in Container " - "'some_mod'" in str(err.value)) + assert ( + "Failed to find a Routine named 'my_func' in Container 'some_mod'" + in str(err.value) + ) @pytest.mark.usefixtures("clear_module_manager_instance") def test_get_callees_follow_imports(fortran_reader, tmpdir, monkeypatch): - ''' + """ Test that get_callees() follows imports to find the definition of the called routine. - ''' - code = ''' + """ + code = """ module some_mod use other_mod, only: pack_it implicit none @@ -1555,24 +1649,29 @@ def test_get_callees_follow_imports(fortran_reader, tmpdir, monkeypatch): integer :: luggage = 0 call pack_it(luggage) end subroutine top -end module some_mod''' +end module some_mod""" # Create the module containing an import of the subroutine definition, # write it to file and set the search path so that PSyclone can find it. path = str(tmpdir) - monkeypatch.setattr(Config.get(), '_include_paths', [path]) + monkeypatch.setattr(Config.get(), "_include_paths", [path]) - with open(os.path.join(path, "other_mod.f90"), - "w", encoding="utf-8") as mfile: - mfile.write('''\ + with open( + os.path.join(path, "other_mod.f90"), "w", encoding="utf-8" + ) as mfile: + mfile.write( + """\ module other_mod use another_mod, only: pack_it contains end module other_mod - ''') + """ + ) # Finally, create the module containing the routine definition. - with open(os.path.join(path, "another_mod.f90"), - "w", encoding="utf-8") as mfile: - mfile.write('''\ + with open( + os.path.join(path, "another_mod.f90"), "w", encoding="utf-8" + ) as mfile: + mfile.write( + """\ module another_mod contains subroutine pack_it(arg) @@ -1580,7 +1679,8 @@ def test_get_callees_follow_imports(fortran_reader, tmpdir, monkeypatch): arg = arg + 2 end subroutine pack_it end module another_mod - ''') + """ + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] result = call.get_callees() @@ -1591,12 +1691,12 @@ def test_get_callees_follow_imports(fortran_reader, tmpdir, monkeypatch): @pytest.mark.usefixtures("clear_module_manager_instance") def test_get_callees_import_private_clash(fortran_reader, tmpdir, monkeypatch): - ''' + """ Test that get_callees() raises the expected error if a module from which a routine is imported has a private shadow of that routine (and thus we don't know where to look for the target routine). - ''' - code = ''' + """ + code = """ module some_mod use other_mod, only: pack_it implicit none @@ -1605,16 +1705,18 @@ def test_get_callees_import_private_clash(fortran_reader, tmpdir, monkeypatch): integer :: luggage = 0 call pack_it(luggage) end subroutine top -end module some_mod''' +end module some_mod""" # Create the module containing a private routine with the name we are # searching for, write it to file and set the search path so that PSyclone # can find it. path = str(tmpdir) - monkeypatch.setattr(Config.get(), '_include_paths', [path]) + monkeypatch.setattr(Config.get(), "_include_paths", [path]) - with open(os.path.join(path, "other_mod.f90"), - "w", encoding="utf-8") as mfile: - mfile.write('''\ + with open( + os.path.join(path, "other_mod.f90"), "w", encoding="utf-8" + ) as mfile: + mfile.write( + """\ module other_mod use another_mod private pack_it @@ -1624,12 +1726,16 @@ def test_get_callees_import_private_clash(fortran_reader, tmpdir, monkeypatch): integer :: pack_it end function pack_it end module other_mod - ''') + """ + ) psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] with pytest.raises(NotImplementedError) as err: _ = call.get_callees() - assert ("RoutineSymbol 'pack_it' is imported from Container 'other_mod' " - "but that Container defines a private Symbol of the same name. " - "Searching for the Container that defines a public Routine with " - "that name is not yet supported - TODO #924" in str(err.value)) + assert ( + "RoutineSymbol 'pack_it' is imported from Container 'other_mod' " + "but that Container defines a private Symbol of the same name. " + "Searching for the Container that defines a public Routine with " + "that name is not yet supported - TODO #924" + in str(err.value) + ) diff --git a/src/psyclone/tests/psyir/transformations/inline_trans_test.py b/src/psyclone/tests/psyir/transformations/inline_trans_test.py index 259e581935..9c50465d57 100644 --- a/src/psyclone/tests/psyir/transformations/inline_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/inline_trans_test.py @@ -276,6 +276,7 @@ def test_apply_gocean_kern(fortran_reader, fortran_writer, monkeypatch): monkeypatch.setattr(Config.get(), "_include_paths", [str(src_dir)]) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) with pytest.raises(TransformationError) as err: inline_trans.apply(psyir.walk(Call)[0]) if ( @@ -341,6 +342,7 @@ def test_apply_struct_arg(fortran_reader, fortran_writer, tmpdir): ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(routine) @@ -417,6 +419,7 @@ def test_apply_unresolved_struct_arg(fortran_reader, fortran_writer): ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) calls = psyir.walk(Call) # First one should be fine. inline_trans.apply(calls[0]) @@ -489,6 +492,7 @@ def test_apply_struct_slice_arg(fortran_reader, fortran_writer, tmpdir): ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(routine) output = fortran_writer(psyir) @@ -528,6 +532,7 @@ def test_apply_struct_local_limits_caller( ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(routine) output = fortran_writer(psyir) @@ -574,6 +579,7 @@ def test_apply_struct_local_limits_caller_decln( ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(routine) output = fortran_writer(psyir) @@ -629,6 +635,7 @@ def test_apply_struct_local_limits_routine( ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(routine) output = fortran_writer(psyir) @@ -683,6 +690,7 @@ def test_apply_array_limits_are_formal_args(fortran_reader, fortran_writer): """ psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) acall = psyir.walk(Call, stop_type=Call)[0] inline_trans.apply(acall) output = fortran_writer(psyir) @@ -729,6 +737,7 @@ def test_apply_allocatable_array_arg(fortran_reader, fortran_writer): ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): if not isinstance(routine, IntrinsicCall): inline_trans.apply(routine) @@ -790,6 +799,7 @@ def test_apply_array_slice_arg(fortran_reader, fortran_writer, tmpdir): ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) for call in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(call) output = fortran_writer(psyir) @@ -844,6 +854,7 @@ def test_apply_struct_array_arg(fortran_reader, fortran_writer, tmpdir): psyir = fortran_reader.psyir_from_source(code) loops = psyir.walk(Loop) inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) inline_trans.apply(loops[0].loop_body.children[1]) inline_trans.apply(loops[1].loop_body.children[1]) inline_trans.apply(loops[2].loop_body.children[1]) @@ -909,6 +920,8 @@ def test_apply_struct_array_slice_arg(fortran_reader, fortran_writer, tmpdir): ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_matching_arguments_of_callee=False) for call in psyir.walk(Call): if not isinstance(call, IntrinsicCall): if call.arguments[0].debug_string() == "grid%local%data": @@ -985,6 +998,7 @@ def test_apply_struct_array( ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) if "use some_mod" in type_decln: with pytest.raises(TransformationError) as err: inline_trans.apply(psyir.walk(Call)[0]) @@ -1037,6 +1051,7 @@ def test_apply_repeated_module_use(fortran_reader, fortran_writer): ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) for call in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(call) output = fortran_writer(psyir) @@ -1760,7 +1775,7 @@ def test_validate_calls_find_routine(fortran_reader): inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - print(err.value) + print assert ( "Cannot inline routine 'sub' because its source cannot be found:\n" "Failed to find the source code of the unresolved routine 'sub' - " @@ -1860,10 +1875,17 @@ def test_validate_unsupportedtype_argument(fortran_reader): inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(routine) + print(err.value) + assert ( + "Found routines, but no routine with matching arguments found for" + " 'sub'" + in str(err.value) + ) assert ( - "Routine 'sub' cannot be inlined because it contains a Symbol 'x' " - "which is an Argument of UnsupportedType: 'REAL, POINTER, " - "INTENT(INOUT) :: x'" + "Argument partial type mismatch of call argument" + " 'Reference[name:'ptr']' and routine argument 'x:" + " DataSymbol'" in str(err.value) ) @@ -2126,9 +2148,14 @@ def test_validate_wrong_number_args(fortran_reader): with pytest.raises(TransformationError) as err: inline_trans.validate(call) assert ( - "Cannot inline 'call sub(i, trouble)' because the number of " - "arguments supplied to the call (2) does not match the number of " - "arguments the routine is declared to have (1)" + "Found routines, but no routine with matching arguments found for" + " 'sub':" + in str(err.value) + ) + + assert ( + "More arguments in call ('call sub(i, trouble)') than callee (routine" + " 'sub')" in str(err.value) ) @@ -2223,6 +2250,7 @@ def test_validate_array_reshape(fortran_reader): psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) with pytest.raises(TransformationError) as err: inline_trans.validate(call) assert ( @@ -2259,6 +2287,7 @@ def test_validate_array_arg_expression(fortran_reader): psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) with pytest.raises(TransformationError) as err: inline_trans.validate(call) assert ( @@ -2289,6 +2318,7 @@ def test_validate_indirect_range(fortran_reader): psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) with pytest.raises(TransformationError) as err: inline_trans.validate(call) assert ( @@ -2317,8 +2347,10 @@ def test_validate_non_unit_stride_slice(fortran_reader): psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() + inline_trans.set_option(check_matching_arguments_of_callee=False) with pytest.raises(TransformationError) as err: inline_trans.validate(call) + print(err.value) assert ( "Cannot inline routine 'sub' because one of its arguments is an " "array slice with a non-unit stride: 'var(::2)' (TODO #1646)" @@ -2352,13 +2384,7 @@ def test_validate_named_arg(fortran_reader): psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() - with pytest.raises(TransformationError) as err: - inline_trans.validate(call) - assert ( - "Routine 'sub' cannot be inlined because it has a named argument " - "'opt' (TODO #924)" - in str(err.value) - ) + inline_trans.validate(call) CALL_IN_SUB_USE = ( diff --git a/src/psyclone/tests/psyir/transformations/omp_task_transformations_test.py b/src/psyclone/tests/psyir/transformations/omp_task_transformations_test.py index f033fae6d6..6bd1b675ed 100644 --- a/src/psyclone/tests/psyir/transformations/omp_task_transformations_test.py +++ b/src/psyclone/tests/psyir/transformations/omp_task_transformations_test.py @@ -178,7 +178,8 @@ def test_omptask_apply_kern(fortran_reader, fortran_writer): new_container.addchild(my_test) sym = my_test.symbol_table.lookup("test_kernel") sym.interface.container_symbol._reference = test_kernel_mod - trans = OMPTaskTrans() + trans: OMPTaskTrans = OMPTaskTrans() + trans.set_option(check_matching_arguments_of_callee=False) master = OMPSingleTrans() parallel = OMPParallelTrans() calls = my_test.walk(Call) From aa6f2e6cacf04746d0ea9fcca1b20dc2f56bd203 Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Sun, 24 Nov 2024 02:23:43 +0100 Subject: [PATCH 05/20] updates --- src/psyclone/psyir/nodes/call.py | 194 ++++++++++++++---- .../psyir/transformations/inline_trans.py | 182 +++++++++++++--- .../transformations/inline_trans_test.py | 162 ++++++++++++++- 3 files changed, 457 insertions(+), 81 deletions(-) diff --git a/src/psyclone/psyir/nodes/call.py b/src/psyclone/psyir/nodes/call.py index c24e852da0..75754de8b0 100644 --- a/src/psyclone/psyir/nodes/call.py +++ b/src/psyclone/psyir/nodes/call.py @@ -715,6 +715,152 @@ def _location_txt(node): f"{_location_txt(root_node)}. This is normally because the routine" f" is within a CodeBlock.") + def _check_inline_types( + self, + call_arg: DataSymbol, + routine_arg: DataSymbol, + check_array_type: bool = True, + ): + """This function performs tests to see whether the + inlining can cope with it. + + :param call_arg: The argument of a call + :type call_arg: DataSymbol + :param routine_arg: The argument of a routine + :type routine_arg: DataSymbol + :param check_array_type: Perform strong checks on array types, + defaults to `True` + :type check_array_type: bool, optional + + :raises TransformationError: Raised if transformation can't be done + + :return: 'True' if checks are successful + :rtype: bool + """ + from psyclone.psyir.transformations.transformation_error import ( + TransformationError, + ) + from psyclone.errors import LazyString + from psyclone.psyir.nodes import Literal, Range + from psyclone.psyir.symbols import ( + UnresolvedType, + UnsupportedType, + INTEGER_TYPE, + ) + + _ONE = Literal("1", INTEGER_TYPE) + + # If the formal argument is an array with non-default bounds then + # we also need to know the bounds of that array at the call site. + if not isinstance(routine_arg.datatype, ArrayType): + # Formal argument is not an array so we don't need to do any + # further checks. + return True + + if not isinstance(call_arg, (Reference, Literal)): + # TODO #1799 this really needs the `datatype` method to be + # extended to support all nodes. For now we have to abort + # if we encounter an argument that is not a scalar (according + # to the corresponding formal argument) but is not a + # Reference or a Literal as we don't know whether the result + # of any general expression is or is not an array. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + f"The call '{self.debug_string()}' " + "cannot be inlined because actual argument " + f"'{call_arg.debug_string()}' corresponds to a " + "formal argument with array type but is not a " + "Reference or a Literal." + ) + ) + ) + + # We have an array argument. We are only able to check that the + # argument is not re-shaped in the called routine if we have full + # type information on the actual argument. + # TODO #924. It would be useful if the `datatype` property was + # a method that took an optional 'resolve' argument to indicate + # that it should attempt to resolve any UnresolvedTypes. + if check_array_type: + if isinstance( + call_arg.datatype, (UnresolvedType, UnsupportedType) + ) or ( + isinstance(call_arg.datatype, ArrayType) + and isinstance( + call_arg.datatype.intrinsic, + (UnresolvedType, UnsupportedType), + ) + ): + raise TransformationError( + f"Routine '{self.routine.name}' cannot be " + "inlined because the type of the actual argument " + f"'{call_arg.symbol.name}' corresponding to an array" + f" formal argument ('{routine_arg.name}') is unknown." + ) + + formal_rank = 0 + actual_rank = 0 + if isinstance(routine_arg.datatype, ArrayType): + formal_rank = len(routine_arg.datatype.shape) + if isinstance(call_arg.datatype, ArrayType): + actual_rank = len(call_arg.datatype.shape) + if formal_rank != actual_rank: + # It's OK to use the loop variable in the lambda definition + # because if we get to this point then we're going to quit + # the loop. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self.routine.name}' because it" + " reshapes an argument: actual argument" + f" '{call_arg.debug_string()}' has rank" + f" {actual_rank} but the corresponding formal" + f" argument, '{routine_arg.name}', has rank" + f" {formal_rank}" + ) + ) + ) + if actual_rank: + ranges = call_arg.walk(Range) + for rge in ranges: + ancestor_ref = rge.ancestor(Reference) + if ancestor_ref is not call_arg: + # Have a range in an indirect access. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self.routine.name}' because" + " argument" + f" '{call_arg.debug_string()}' has" + " an array range in an indirect" + " access #(TODO 924)." + ) + ) + ) + if rge.step != _ONE: + # TODO #1646. We could resolve this problem by + # making a new array and copying the necessary + # values into it. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self.routine.name}' because" + " one of its arguments is an array" + " slice with a non-unit stride:" + f" '{call_arg.debug_string()}' (TODO" + " #1646)" + ) + ) + ) + def _check_argument_type_matches( self, call_arg: DataSymbol, @@ -738,6 +884,8 @@ def _check_argument_type_matches( were found. """ + self._check_inline_types(call_arg, routine_arg) + type_matches = False if not check_strict_array_datatype: # No strict array checks have to be performed, just accept it @@ -767,46 +915,6 @@ def _check_argument_type_matches( return True - def _check_matching_types( - call_arg: Symbol, - routine_arg: Symbol, - check_strict_array_datatype: bool = True, - check_matching_arguments: bool = True, - ) -> bool: - routine_arg: DataSymbol - - type_matches = False - if not check_strict_array_datatype: - # No strict array checks have to be performed, just accept it - if isinstance(call_arg.datatype, ArrayType) and isinstance( - routine_arg.datatype, ArrayType - ): - type_matches = True - - if not type_matches: - # Do the types of arguments match? - # - # TODO #759: If optional is used, it's an unsupported - # Fortran type and we need to use the following workaround - # Once this issue is resolved, simply remove this if - # branch. - # Optional arguments are processed further down. - if isinstance(routine_arg.datatype, UnsupportedFortranType): - if call_arg.datatype != routine_arg.datatype.partial_datatype: - raise CallMatchingArgumentsNotFound( - f"Argument partial type mismatch of call " - f"argument '{call_arg}' and routine argument " - f"'{routine_arg}'" - ) - else: - if call_arg.datatype != routine_arg.datatype: - raise CallMatchingArgumentsNotFound( - f"Argument type mismatch of call argument " - f"'{call_arg.datatype}' and routine argument " - f"'{routine_arg.datatype}'" - ) - type_matches = True - def _get_argument_routine_match( self, routine: Routine, @@ -954,7 +1062,7 @@ def get_callee( f"No routine or interface found for name '{self.routine.name}'" ) - err_info = [] + err_info_list = [] # Search for the routine matching the right arguments for routine_node in routine_list: @@ -966,7 +1074,7 @@ def get_callee( check_strict_array_datatype=check_strict_array_datatype, ) except CallMatchingArgumentsNotFound as err: - err_info.append(err.value) + err_info_list.append(err.value) continue return (routine_node, arg_match_list) @@ -979,7 +1087,7 @@ def get_callee( # Also return a list of dummy argument indices return (routine_list[0], [i for i in range(len(self.arguments))]) - error_msg = "\n".join(err_info) + error_msg = "\n".join(err_info_list) raise CallMatchingArgumentsNotFound( "Found routines, but no routine with matching arguments found " diff --git a/src/psyclone/psyir/transformations/inline_trans.py b/src/psyclone/psyir/transformations/inline_trans.py index 3532fa9dc5..0a397037f5 100644 --- a/src/psyclone/psyir/transformations/inline_trans.py +++ b/src/psyclone/psyir/transformations/inline_trans.py @@ -45,9 +45,18 @@ Return, Literal, Statement, StructureMember, StructureReference) from psyclone.psyir.nodes.array_mixin import ArrayMixin from psyclone.psyir.symbols import ( - ArgumentInterface, ArrayType, DataSymbol, UnresolvedType, INTEGER_TYPE, - StaticInterface, SymbolError, UnknownInterface, - UnsupportedType, IntrinsicSymbol) + ArgumentInterface, + ArrayType, + DataSymbol, + UnresolvedType, + INTEGER_TYPE, + StaticInterface, + SymbolError, + UnknownInterface, + UnsupportedType, + UnsupportedFortranType, + IntrinsicSymbol, +) from psyclone.psyir.transformations.reference2arrayrange_trans import ( Reference2ArrayRangeTrans) from psyclone.psyir.transformations.transformation_error import ( @@ -56,8 +65,8 @@ from psyclone.psyir.nodes import CallMatchingArgumentsNotFound from typing import Dict, List -# from psyclone.psyir.symbols import BOOLEAN_TYPE -# from psyclone.psyir.symbols import ScalarType +from psyclone.psyir.symbols import BOOLEAN_TYPE +from psyclone.psyir.symbols import ScalarType _ONE = Literal("1", INTEGER_TYPE) @@ -224,19 +233,24 @@ def apply( # copy of it. self.node_routine = self.node_routine.copy() routine_table = self.node_routine.symbol_table + self._remove_unused_optional_arguments() # Construct lists of the nodes that will be inserted and all of the # References that they contain. new_stmts = [] refs = [] for child in self.node_routine.children: + child: Node new_stmts.append(child.copy()) refs.extend(new_stmts[-1].walk(Reference)) # Shallow copy the symbols from the routine into the table at the # call site. - table.merge(routine_table, - symbols_to_skip=routine_table.argument_list[:]) + table.merge( + routine_table, + symbols_to_skip=routine_table.argument_list[:], + # check_unresolved_symbols=self._option_check_argument_unresolved_symbols, + ) # When constructing new references to replace references to formal # args, we need to know whether any of the actual arguments are array @@ -308,6 +322,111 @@ def apply( scope.symbol_table.detach() replacement.attach(scope) + def _remove_ifblock_if_const_args(self, node: Node): + + def if_else_replace(main_schedule, if_block, if_body_schedule): + """Little helper routine to eliminate one branch of an IfBlock + + :param main_schedule: Schedule where if-branch is used + :type main_schedule: Schedule + :param if_block: If-else block itself + :type if_block: IfBlock + :param if_body_schedule: The body of the if or else block + :type if_body_schedule: Schedule + """ + + from psyclone.psyir.nodes import Schedule + + assert isinstance(main_schedule, Schedule) + assert isinstance(if_body_schedule, Schedule) + + # Obtain index in main schedule + idx = main_schedule.children.index(if_block) + + # Detach it + if_block.detach() + + # Insert childreen of if-body schedule + for child in if_body_schedule.children: + main_schedule.addchild(child.copy(), idx) + idx += 1 + + from psyclone.psyir.nodes import IfBlock + + for if_block in node.walk(IfBlock): + if_block: IfBlock + + condition = if_block.condition + + # Check if the condition is a BooleanLiteral + if not isinstance(condition, Literal): + continue + + # Check for right datatype + if ( + condition.datatype.intrinsic + is not ScalarType.Intrinsic.BOOLEAN + ): + continue + + if condition.value == "true": + # Only keep if_block + + if_else_replace(if_block.parent, if_block, if_block.if_body) + + else: + # If there's an else block, replace if-condition with + # else-block + if not if_block.else_body: + if_block.detach() + continue + + if_else_replace(if_block.parent, if_block, if_block.else_body) + + def _remove_unused_optional_arguments(self): + # We first build a lookup table of all optional arguments + # to see whether it's present or not. + optional_sym_present_dict: Dict[str, bool] = dict() + for optional_arg_idx, datasymbol in enumerate( + self.node_routine.symbol_table.datasymbols + ): + if not isinstance(datasymbol.datatype, UnsupportedFortranType): + continue + + if ", OPTIONAL" not in str(datasymbol.datatype): + continue + + sym_name = datasymbol.name.lower() + + if optional_arg_idx not in self._ret_arg_match_list: + optional_sym_present_dict[sym_name] = False + else: + optional_sym_present_dict[sym_name] = True + + # Check if we have any optional arguments at all and if not, return + if len(optional_sym_present_dict) == 0: + return + + # Find all "PRESENT()" calls + for intrinsic_call in self.node_routine.walk(IntrinsicCall): + intrinsic_call: IntrinsicCall + if intrinsic_call.routine.name.lower() == "present": + + # The argument is in the 2nd child + present_arg: Reference = intrinsic_call.children[1] + present_arg_name = present_arg.name.lower() + + assert present_arg_name in optional_sym_present_dict + + if optional_sym_present_dict[present_arg_name]: + # The argument is present. + intrinsic_call.replace_with(Literal("true", BOOLEAN_TYPE)) + else: + intrinsic_call.replace_with(Literal("false", BOOLEAN_TYPE)) + + # Evaluate all if-blocks with constant booleans + self._remove_ifblock_if_const_args(self.node_routine) + def _replace_formal_arg(self, ref, call_node, formal_args): ''' Recursively combines any References to formal arguments in the supplied @@ -337,8 +456,16 @@ def _replace_formal_arg(self, ref, call_node, formal_args): # The supplied reference is not to a formal argument. return ref + # Lookup index in routine argument + routine_arg_idx = formal_args.index(ref.symbol) + + # Lookup index of actual argument + # If this is an optional argument, but not used, this index lookup + # shouldn't fail + actual_arg_idx = self._ret_arg_match_list.index(routine_arg_idx) + # Lookup the actual argument that corresponds to this formal argument. - actual_arg = call_node.arguments[formal_args.index(ref.symbol)] + actual_arg = call_node.arguments[actual_arg_idx] # If the local reference is a simple Reference then we can just # replace it with a copy of the actual argument, e.g. @@ -746,6 +873,7 @@ def validate( NotImplementedError, FileNotFoundError, SymbolError, + TransformationError, ) as err: raise TransformationError( f"Cannot inline routine '{name}' because its source cannot" @@ -936,17 +1064,17 @@ def validate( routine_table.argument_list[i] for i in self._ret_arg_match_list ] - for formal_arg, actual_arg in zip( + for routine_arg, call_arg in zip( routine_arg_list, node_call.arguments ): # If the formal argument is an array with non-default bounds then # we also need to know the bounds of that array at the call site. - if not isinstance(formal_arg.datatype, ArrayType): + if not isinstance(routine_arg.datatype, ArrayType): # Formal argument is not an array so we don't need to do any # further checks. continue - if not isinstance(actual_arg, (Reference, Literal)): + if not isinstance(call_arg, (Reference, Literal)): # TODO #1799 this really needs the `datatype` method to be # extended to support all nodes. For now we have to abort # if we encounter an argument that is not a scalar (according @@ -959,7 +1087,7 @@ def validate( lambda: ( f"The call '{node_call.debug_string()}' " "cannot be inlined because actual argument " - f"'{actual_arg.debug_string()}' corresponds to a " + f"'{call_arg.debug_string()}' corresponds to a " "formal argument with array type but is not a " "Reference or a Literal." ) @@ -974,27 +1102,27 @@ def validate( # that it should attempt to resolve any UnresolvedTypes. if check_array_type: if isinstance( - actual_arg.datatype, (UnresolvedType, UnsupportedType) + call_arg.datatype, (UnresolvedType, UnsupportedType) ) or ( - isinstance(actual_arg.datatype, ArrayType) + isinstance(call_arg.datatype, ArrayType) and isinstance( - actual_arg.datatype.intrinsic, + call_arg.datatype.intrinsic, (UnresolvedType, UnsupportedType), ) ): raise TransformationError( f"Routine '{self.node_routine.name}' cannot be " "inlined because the type of the actual argument " - f"'{actual_arg.symbol.name}' corresponding to an array" - f" formal argument ('{formal_arg.name}') is unknown." + f"'{call_arg.symbol.name}' corresponding to an array" + f" formal argument ('{routine_arg.name}') is unknown." ) formal_rank = 0 actual_rank = 0 - if isinstance(formal_arg.datatype, ArrayType): - formal_rank = len(formal_arg.datatype.shape) - if isinstance(actual_arg.datatype, ArrayType): - actual_rank = len(actual_arg.datatype.shape) + if isinstance(routine_arg.datatype, ArrayType): + formal_rank = len(routine_arg.datatype.shape) + if isinstance(call_arg.datatype, ArrayType): + actual_rank = len(call_arg.datatype.shape) if formal_rank != actual_rank: # It's OK to use the loop variable in the lambda definition # because if we get to this point then we're going to quit @@ -1006,18 +1134,18 @@ def validate( "Cannot inline routine" f" '{self.node_routine.name}' because it" " reshapes an argument: actual argument" - f" '{actual_arg.debug_string()}' has rank" + f" '{call_arg.debug_string()}' has rank" f" {actual_rank} but the corresponding formal" - f" argument, '{formal_arg.name}', has rank" + f" argument, '{routine_arg.name}', has rank" f" {formal_rank}" ) ) ) if actual_rank: - ranges = actual_arg.walk(Range) + ranges = call_arg.walk(Range) for rge in ranges: ancestor_ref = rge.ancestor(Reference) - if ancestor_ref is not actual_arg: + if ancestor_ref is not call_arg: # Have a range in an indirect access. # pylint: disable=cell-var-from-loop raise TransformationError( @@ -1026,7 +1154,7 @@ def validate( "Cannot inline routine" f" '{self.node_routine.name}' because" " argument" - f" '{actual_arg.debug_string()}' has" + f" '{call_arg.debug_string()}' has" " an array range in an indirect" " access #(TODO 924)." ) @@ -1044,7 +1172,7 @@ def validate( f" '{self.node_routine.name}' because" " one of its arguments is an array" " slice with a non-unit stride:" - f" '{actual_arg.debug_string()}' (TODO" + f" '{call_arg.debug_string()}' (TODO" " #1646)" ) ) diff --git a/src/psyclone/tests/psyir/transformations/inline_trans_test.py b/src/psyclone/tests/psyir/transformations/inline_trans_test.py index 9c50465d57..3401ad13b8 100644 --- a/src/psyclone/tests/psyir/transformations/inline_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/inline_trans_test.py @@ -40,7 +40,14 @@ import pytest from psyclone.configuration import Config -from psyclone.psyir.nodes import Call, IntrinsicCall, Reference, Routine, Loop +from psyclone.psyir.nodes import ( + Call, + IntrinsicCall, + Loop, + Node, + Reference, + Routine, +) from psyclone.psyir.symbols import ( AutomaticInterface, DataSymbol, @@ -2358,12 +2365,10 @@ def test_validate_non_unit_stride_slice(fortran_reader): ) -def test_validate_named_arg(fortran_reader): - """Test that the validate method rejects an attempt to inline a routine - that has a named argument.""" - # In reality, the routine with a named argument would almost certainly - # use the 'present' intrinsic but, since that gives a CodeBlock that itself - # prevents inlining, our test example omits it. +def test_apply_named_arg(fortran_reader): + """Test that the validate method inlines a routine that has a named + argument.""" + code = ( "module test_mod\n" "contains\n" @@ -2373,10 +2378,35 @@ def test_validate_named_arg(fortran_reader): "end subroutine main\n" "subroutine sub(x, opt)\n" " real, intent(inout) :: x\n" + " real :: opt\n" + " x = x + 1.0\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + inline_trans = InlineTrans() + + inline_trans.apply(call) + + +def test_validate_optional_arg(fortran_reader): + """Test that the validate method inlines a routine + that has an optional argument.""" + + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var)\n" + "end subroutine main\n" + "subroutine sub(x, opt)\n" + " real, intent(inout) :: x\n" " real, optional :: opt\n" - " !if( present(opt) )then\n" - " ! x = x + opt\n" - " !end if\n" + " if( present(opt) )then\n" + " x = x + opt\n" + " end if\n" " x = x + 1.0\n" "end subroutine sub\n" "end module test_mod\n" @@ -2384,7 +2414,117 @@ def test_validate_named_arg(fortran_reader): psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() - inline_trans.validate(call) + inline_trans.apply(call) + + +def test_validate_optional_and_named_arg(fortran_reader): + """Test that the validate method inlines a routine + that has an optional argument.""" + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var, named=1.0)\n" + " ! Result:\n" + " ! var = var + 1.0 + 1.0\n" + " call sub(var, 2.0, named=1.0)\n" + " ! Result:\n" + " ! var = var + 2.0\n" + " ! var = var + 1.0 + 1.0\n" + "end subroutine main\n" + "subroutine sub(x, opt, named)\n" + " real, intent(inout) :: x\n" + " real, optional :: opt\n" + " real :: named\n" + " if( present(opt) )then\n" + " x = x + opt\n" + " end if\n" + " x = x + 1.0 + named\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + psyir: Node = fortran_reader.psyir_from_source(code) + + inline_trans = InlineTrans() + + routine_main: Routine = psyir.walk(Routine)[0] + assert routine_main.name == "main" + for call in psyir.walk(Call, stop_type=Call): + call: Call + if call.routine.name != "sub": + continue + + inline_trans.apply(call) + + assert ( + """var = var + 1.0 + 1.0 + var = var + 2.0 + var = var + 1.0 + 1.0""" + in routine_main.debug_string() + ) + + +def test_validate_optional_and_named_arg_2(fortran_reader): + """Test that the validate method inlines a routine + that has an optional argument.""" + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var, 1.0)\n" + " ! Result:\n" + " ! var = var + 2.0 + 1.0\n" + " ! var = var + 4.0 + 1.0\n" + " ! var = var + 5.0 + 1.0\n" + " call sub(var)\n" + " ! Result:\n" + " ! var = var + 3.0\n" + " ! var = var + 6.0\n" + " ! var = var + 7.0\n" + "end subroutine main\n" + "subroutine sub(x, opt)\n" + " real, intent(inout) :: x\n" + " real, optional :: opt\n" + " if( present(opt) )then\n" + " x = x + 2.0 + opt\n" + " else\n" + " x = x + 3.0\n" + " end if\n" + " if( present(opt) )then\n" + " x = x + 4.0 + opt\n" + " x = x + 5.0 + opt\n" + " else\n" + " x = x + 6.0\n" + " x = x + 7.0\n" + " end if\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + psyir: Node = fortran_reader.psyir_from_source(code) + + inline_trans = InlineTrans() + + routine_main: Routine = psyir.walk(Routine)[0] + assert routine_main.name == "main" + for call in psyir.walk(Call, stop_type=Call): + call: Call + if call.routine.name != "sub": + continue + + inline_trans.apply(call) + + print(routine_main.debug_string()) + assert ( + """var = var + 2.0 + 1.0 + var = var + 4.0 + 1.0 + var = var + 5.0 + 1.0 + var = var + 3.0 + var = var + 6.0 + var = var + 7.0""" + in routine_main.debug_string() + ) CALL_IN_SUB_USE = ( From 87f5172a77983885680af4adaadbe5f71b3f56e4 Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Sun, 24 Nov 2024 11:25:50 +0100 Subject: [PATCH 06/20] cleanups --- src/psyclone/psyir/nodes/__init__.py | 190 +-- src/psyclone/psyir/nodes/call.py | 28 +- src/psyclone/psyir/symbols/containersymbol.py | 13 +- src/psyclone/psyir/symbols/symbol_table.py | 7 +- .../psyir/transformations/inline_trans.py | 6 +- .../tests/my_shortcut_tests/call_test.py | 1 - .../my_shortcut_tests/inline_trans_test.py | 1 - src/psyclone/tests/psyir/nodes/call_test.py | 599 +++---- .../transformations/inline_trans_test.py | 1516 +++++++---------- 9 files changed, 1050 insertions(+), 1311 deletions(-) delete mode 120000 src/psyclone/tests/my_shortcut_tests/call_test.py delete mode 120000 src/psyclone/tests/my_shortcut_tests/inline_trans_test.py diff --git a/src/psyclone/psyir/nodes/__init__.py b/src/psyclone/psyir/nodes/__init__.py index ed0d94ae35..7efe06b811 100644 --- a/src/psyclone/psyir/nodes/__init__.py +++ b/src/psyclone/psyir/nodes/__init__.py @@ -74,7 +74,7 @@ from psyclone.psyir.nodes.statement import Statement from psyclone.psyir.nodes.structure_reference import StructureReference from psyclone.psyir.nodes.structure_member import StructureMember -from psyclone.psyir.nodes.call import Call, CallMatchingArgumentsNotFound +from psyclone.psyir.nodes.call import Call, CallMatchingArgumentsNotFoundError from psyclone.psyir.nodes.file_container import FileContainer from psyclone.psyir.nodes.directive import ( Directive, StandaloneDirective, RegionDirective) @@ -104,97 +104,97 @@ # The entities in the __all__ list are made available to import directly from # this package e.g. 'from psyclone.psyir.nodes import Literal' __all__ = [ - 'colored', - 'ArrayMember', - 'ArrayReference', - 'ArrayOfStructuresMember', - 'ArrayOfStructuresReference', - 'Assignment', - 'BinaryOperation', - 'Call', - "CallMatchingArgumentsNotFound", - 'Clause', - 'CodeBlock', - 'Container', - 'DataNode', - 'FileContainer', - 'IfBlock', - 'IntrinsicCall', - 'Literal', - 'Loop', - 'Member', - 'Node', - 'OperandClause', - 'Operation', - 'Range', - 'Reference', - 'Return', - 'Routine', - 'Schedule', - 'Statement', - 'StructureMember', - 'StructureReference', - 'UnaryOperation', - 'ScopingNode', - 'WhileLoop', - # PSyclone-specific nodes - 'KernelSchedule', - # PSyData Nodes - 'PSyDataNode', - 'ExtractNode', - 'ProfileNode', - 'ReadOnlyVerifyNode', - 'ValueRangeCheckNode', - # Directive Nodes - 'Directive', - 'RegionDirective', - 'StandaloneDirective', - # OpenACC Directive Nodes - 'ACCAtomicDirective', - 'ACCDirective', - 'ACCRegionDirective', - 'ACCStandaloneDirective', - 'ACCDataDirective', - 'ACCEnterDataDirective', - 'ACCParallelDirective', - 'ACCLoopDirective', - 'ACCKernelsDirective', - 'ACCUpdateDirective', - 'ACCRoutineDirective', - # OpenACC Clause Nodes - 'ACCCopyClause', - 'ACCCopyInClause', - 'ACCCopyOutClause', - # OpenMP Directive Nodes - 'OMPAtomicDirective', - 'OMPDirective', - 'OMPRegionDirective', - 'OMPStandaloneDirective', - 'OMPParallelDirective', - 'OMPSerialDirective', - 'OMPSingleDirective', - 'OMPMasterDirective', - 'OMPTaskloopDirective', - 'OMPTaskDirective', - 'DynamicOMPTaskDirective', - 'OMPDoDirective', - 'OMPParallelDoDirective', - 'OMPTaskwaitDirective', - 'OMPTargetDirective', - 'OMPLoopDirective', - 'OMPDeclareTargetDirective', - 'OMPSimdDirective', - 'OMPTeamsDistributeParallelDoDirective', - # OMP Clause Nodes - 'OMPGrainsizeClause', - 'OMPNogroupClause', - 'OMPNowaitClause', - 'OMPNumTasksClause', - 'OMPPrivateClause', - 'OMPDefaultClause', - 'OMPReductionClause', - 'OMPScheduleClause', - 'OMPFirstprivateClause', - 'OMPSharedClause', - 'OMPDependClause' - ] + "colored", + "ArrayMember", + "ArrayReference", + "ArrayOfStructuresMember", + "ArrayOfStructuresReference", + "Assignment", + "BinaryOperation", + "Call", + "CallMatchingArgumentsNotFoundError", + "Clause", + "CodeBlock", + "Container", + "DataNode", + "FileContainer", + "IfBlock", + "IntrinsicCall", + "Literal", + "Loop", + "Member", + "Node", + "OperandClause", + "Operation", + "Range", + "Reference", + "Return", + "Routine", + "Schedule", + "Statement", + "StructureMember", + "StructureReference", + "UnaryOperation", + "ScopingNode", + "WhileLoop", + # PSyclone-specific nodes + "KernelSchedule", + # PSyData Nodes + "PSyDataNode", + "ExtractNode", + "ProfileNode", + "ReadOnlyVerifyNode", + "ValueRangeCheckNode", + # Directive Nodes + "Directive", + "RegionDirective", + "StandaloneDirective", + # OpenACC Directive Nodes + "ACCAtomicDirective", + "ACCDirective", + "ACCRegionDirective", + "ACCStandaloneDirective", + "ACCDataDirective", + "ACCEnterDataDirective", + "ACCParallelDirective", + "ACCLoopDirective", + "ACCKernelsDirective", + "ACCUpdateDirective", + "ACCRoutineDirective", + # OpenACC Clause Nodes + "ACCCopyClause", + "ACCCopyInClause", + "ACCCopyOutClause", + # OpenMP Directive Nodes + "OMPAtomicDirective", + "OMPDirective", + "OMPRegionDirective", + "OMPStandaloneDirective", + "OMPParallelDirective", + "OMPSerialDirective", + "OMPSingleDirective", + "OMPMasterDirective", + "OMPTaskloopDirective", + "OMPTaskDirective", + "DynamicOMPTaskDirective", + "OMPDoDirective", + "OMPParallelDoDirective", + "OMPTaskwaitDirective", + "OMPTargetDirective", + "OMPLoopDirective", + "OMPDeclareTargetDirective", + "OMPSimdDirective", + "OMPTeamsDistributeParallelDoDirective", + # OMP Clause Nodes + "OMPGrainsizeClause", + "OMPNogroupClause", + "OMPNowaitClause", + "OMPNumTasksClause", + "OMPPrivateClause", + "OMPDefaultClause", + "OMPReductionClause", + "OMPScheduleClause", + "OMPFirstprivateClause", + "OMPSharedClause", + "OMPDependClause", +] diff --git a/src/psyclone/psyir/nodes/call.py b/src/psyclone/psyir/nodes/call.py index 75754de8b0..bb61beac91 100644 --- a/src/psyclone/psyir/nodes/call.py +++ b/src/psyclone/psyir/nodes/call.py @@ -60,7 +60,7 @@ from psyclone.psyir.symbols.datatypes import ArrayType -class CallMatchingArgumentsNotFound(PSycloneError): +class CallMatchingArgumentsNotFoundError(PSycloneError): '''Exception to signal that matching arguments have not been found for this routine ''' @@ -477,11 +477,15 @@ def _get_container_symbols_rec( :param _stack_container_list: Stack with already visited Containers to avoid circular searches, defaults to [] :type _stack_container_list: List[Container], optional + :param _depth: Depth of recursive search + :type _depth: int """ # - # TODO: This function seems to be extremely slow: - # It takes considerable time to build this list over and over - # for each lookup. + # TODO: + # - This function seems to be extremely slow: + # It takes considerable time to build this list over and over + # for each lookup. + # - This function can also be written in a non-resursive way # # An alternative would be to cache it, but then the cache # needs to be invalidated once some symbols are, e.g., deleted. @@ -899,14 +903,14 @@ def _check_argument_type_matches( # This could be an 'optional' argument. # This has at least a partial data type if call_arg.datatype != routine_arg.datatype.partial_datatype: - raise CallMatchingArgumentsNotFound( - f"Argument partial type mismatch of call " + raise CallMatchingArgumentsNotFoundError( + "Argument partial type mismatch of call " f"argument '{call_arg}' and routine argument " f"'{routine_arg}'" ) else: if call_arg.datatype != routine_arg.datatype: - raise CallMatchingArgumentsNotFound( + raise CallMatchingArgumentsNotFoundError( "Argument type mismatch of call argument " f"'{call_arg}' with type '{call_arg.datatype} " "and routine argument " @@ -946,7 +950,7 @@ def _get_argument_routine_match( if len(self.arguments) > len(routine.symbol_table.argument_list): call_str = self.debug_string().replace("\n", "") - raise CallMatchingArgumentsNotFound( + raise CallMatchingArgumentsNotFoundError( f"More arguments in call ('{call_str}')" f" than callee (routine '{routine.name}')" ) @@ -999,7 +1003,7 @@ def _get_argument_routine_match( else: # It doesn't match => Raise exception - raise CallMatchingArgumentsNotFound( + raise CallMatchingArgumentsNotFoundError( f"Named argument '{arg_name}' not found" ) @@ -1019,7 +1023,7 @@ def _get_argument_routine_match( if ", OPTIONAL" in str(routine_arg.datatype): continue - raise CallMatchingArgumentsNotFound( + raise CallMatchingArgumentsNotFoundError( f"Argument '{routine_arg.name}' in subroutine" f" '{routine.name}' not handled" ) @@ -1073,7 +1077,7 @@ def get_callee( routine_node, check_strict_array_datatype=check_strict_array_datatype, ) - except CallMatchingArgumentsNotFound as err: + except CallMatchingArgumentsNotFoundError as err: err_info_list.append(err.value) continue @@ -1089,7 +1093,7 @@ def get_callee( error_msg = "\n".join(err_info_list) - raise CallMatchingArgumentsNotFound( + raise CallMatchingArgumentsNotFoundError( "Found routines, but no routine with matching arguments found " f"for '{self.routine.name}':\n" + error_msg diff --git a/src/psyclone/psyir/symbols/containersymbol.py b/src/psyclone/psyir/symbols/containersymbol.py index 013b3a4f53..02d211de45 100644 --- a/src/psyclone/psyir/symbols/containersymbol.py +++ b/src/psyclone/psyir/symbols/containersymbol.py @@ -128,7 +128,7 @@ def copy(self): def find_container_psyir( self, local_node=None, ignore_missing_modules: bool = False ): - ''' Searches for the Container that this Symbol refers to. If it is + """Searches for the Container that this Symbol refers to. If it is not available, use the interface to import the container. If `local_node` is supplied then the PSyIR tree below it is searched for the container first. @@ -137,10 +137,14 @@ def find_container_psyir( the container. :type local_node: Optional[:py:class:`psyclone.psyir.nodes.Node`] + :param ignore_missing_modules: If 'True', no ModuleNotFound exception= + is raised in case in case the module wasn't found. + :type ignore_missing_modules: bool + :returns: referenced container. :rtype: :py:class:`psyclone.psyir.nodes.Container` - ''' + """ if not self._reference: # First check in the current PSyIR tree (if supplied). if local_node: @@ -153,7 +157,10 @@ def find_container_psyir( self._reference = local return self._reference # We didn't find it so now attempt to import the container. - self._reference = self._interface.get_container(self._name) + try: + self._reference = self._interface.get_container(self._name) + except ModuleNotFoundError: + return None return self._reference def __str__(self): diff --git a/src/psyclone/psyir/symbols/symbol_table.py b/src/psyclone/psyir/symbols/symbol_table.py index 368ad214c8..3afa5aea5f 100644 --- a/src/psyclone/psyir/symbols/symbol_table.py +++ b/src/psyclone/psyir/symbols/symbol_table.py @@ -588,7 +588,7 @@ def add(self, new_symbol, tag=None): def check_for_clashes( self, other_table, symbols_to_skip=(), check_unresolved_symbols=True ): - ''' + """ Checks the symbols in the supplied table against those in this table. If there is a name clash that cannot be resolved by renaming then a SymbolError is raised. Any symbols appearing @@ -600,13 +600,16 @@ def check_for_clashes( the check. :type symbols_to_skip: Iterable[ :py:class:`psyclone.psyir.symbols.Symbol`] + :param check_unresolved_symbols: If 'True', also check unresolved + symbols + :type check_unresolved_symbols: bool :raises TypeError: if symbols_to_skip is supplied but is not an instance of Iterable. :raises SymbolError: if there would be an unresolvable name clash when importing symbols from `other_table` into this table. - ''' + """ # pylint: disable-next=import-outside-toplevel from psyclone.psyir.nodes import IntrinsicCall diff --git a/src/psyclone/psyir/transformations/inline_trans.py b/src/psyclone/psyir/transformations/inline_trans.py index 0a397037f5..183675aaad 100644 --- a/src/psyclone/psyir/transformations/inline_trans.py +++ b/src/psyclone/psyir/transformations/inline_trans.py @@ -62,7 +62,7 @@ from psyclone.psyir.transformations.transformation_error import ( TransformationError) -from psyclone.psyir.nodes import CallMatchingArgumentsNotFound +from psyclone.psyir.nodes import CallMatchingArgumentsNotFoundError from typing import Dict, List from psyclone.psyir.symbols import BOOLEAN_TYPE @@ -869,7 +869,7 @@ def validate( ) ) except ( - CallMatchingArgumentsNotFound, + CallMatchingArgumentsNotFoundError, NotImplementedError, FileNotFoundError, SymbolError, @@ -890,7 +890,7 @@ def validate( check_strict_array_datatype=False, ) ) - except CallMatchingArgumentsNotFound as err: + except CallMatchingArgumentsNotFoundError as err: raise TransformationError( "Routine's arguments doesn't match subroutine" ) from err diff --git a/src/psyclone/tests/my_shortcut_tests/call_test.py b/src/psyclone/tests/my_shortcut_tests/call_test.py deleted file mode 120000 index 61527f250c..0000000000 --- a/src/psyclone/tests/my_shortcut_tests/call_test.py +++ /dev/null @@ -1 +0,0 @@ -../psyir/nodes/call_test.py \ No newline at end of file diff --git a/src/psyclone/tests/my_shortcut_tests/inline_trans_test.py b/src/psyclone/tests/my_shortcut_tests/inline_trans_test.py deleted file mode 120000 index 9f34f83693..0000000000 --- a/src/psyclone/tests/my_shortcut_tests/inline_trans_test.py +++ /dev/null @@ -1 +0,0 @@ -../psyir/transformations/inline_trans_test.py \ No newline at end of file diff --git a/src/psyclone/tests/psyir/nodes/call_test.py b/src/psyclone/tests/psyir/nodes/call_test.py index 853fcead1e..4fc30e443a 100644 --- a/src/psyclone/tests/psyir/nodes/call_test.py +++ b/src/psyclone/tests/psyir/nodes/call_test.py @@ -34,7 +34,7 @@ # Authors: R. W. Ford and A. R. Porter, STFC Daresbury Lab # ----------------------------------------------------------------------------- -"""Performs py.test tests on the Call PSyIR node.""" +''' Performs py.test tests on the Call PSyIR node. ''' import os import pytest @@ -66,18 +66,18 @@ ) from psyclone.errors import GenerationError -from psyclone.psyir.nodes.call import CallMatchingArgumentsNotFound +from psyclone.psyir.nodes.call import CallMatchingArgumentsNotFoundError class SpecialCall(Call): - """Test Class specialising the Call class""" + '''Test Class specialising the Call class''' def test_call_init(): - """Test that a Call can be created as expected. Also test the routine + '''Test that a Call can be created as expected. Also test the routine property. - """ + ''' # Initialise without a RoutineSymbol call = Call() # By default everything is None @@ -92,15 +92,15 @@ def test_call_init(): routine = RoutineSymbol("jo", NoType()) call = Call(parent=parent) call.addchild(Reference(routine)) - call.addchild(Literal("3", INTEGER_TYPE)) + call.addchild(Literal('3', INTEGER_TYPE)) assert call.routine.symbol is routine assert call.parent is parent - assert call.arguments == [Literal("3", INTEGER_TYPE)] + assert call.arguments == [Literal('3', INTEGER_TYPE)] def test_call_is_elemental(): - """Test the is_elemental property of a Call is set correctly and can be - queried.""" + '''Test the is_elemental property of a Call is set correctly and can be + queried.''' routine = RoutineSymbol("zaphod", NoType()) call = Call.create(routine) assert call.is_elemental is None @@ -110,8 +110,8 @@ def test_call_is_elemental(): def test_call_is_pure(): - """Test the is_pure property of a Call is set correctly and can be - queried.""" + '''Test the is_pure property of a Call is set correctly and can be + queried.''' routine = RoutineSymbol("zaphod", NoType()) call = Call.create(routine) assert call.is_pure is None @@ -121,15 +121,15 @@ def test_call_is_pure(): def test_call_is_available_on_device(): - """Test the is_available_on_device() method of a Call (currently always - returns False).""" + '''Test the is_available_on_device() method of a Call (currently always + returns False). ''' routine = RoutineSymbol("zaphod", NoType()) call = Call.create(routine) assert call.is_available_on_device() is False def test_call_equality(): - """Test the __eq__ method of the Call class.""" + '''Test the __eq__ method of the Call class. ''' # routine arguments routine = RoutineSymbol("j", NoType()) routine2 = RoutineSymbol("k", NoType()) @@ -154,114 +154,92 @@ def test_call_equality(): assert call4 != call7 # Check when a Reference (to the same RoutineSymbol) is provided. - call8 = Call.create( - Reference(routine), [("new_name", Literal("1.0", REAL_TYPE))] - ) + call8 = Call.create(Reference(routine), + [("new_name", Literal("1.0", REAL_TYPE))]) assert call8 == call7 @pytest.mark.parametrize("cls", [Call, SpecialCall]) def test_call_create(cls): - """Test that the create method creates a valid call with arguments, + '''Test that the create method creates a valid call with arguments, some of which are named. Also checks the routine and argument_names properties. - """ + ''' routine = RoutineSymbol("ellie", INTEGER_TYPE) array_type = ArrayType(INTEGER_TYPE, shape=[10, 20]) - arguments = [ - Reference(DataSymbol("arg1", INTEGER_TYPE)), - ArrayReference(DataSymbol("arg2", array_type)), - ] + arguments = [Reference(DataSymbol("arg1", INTEGER_TYPE)), + ArrayReference(DataSymbol("arg2", array_type))] call = cls.create(routine, [arguments[0], ("name", arguments[1])]) # pylint: disable=unidiomatic-typecheck assert type(call) is cls assert call.routine.symbol is routine assert call.argument_names == [None, "name"] - for ( - idx, - child, - ) in enumerate(call.arguments): + for idx, child, in enumerate(call.arguments): assert child is arguments[idx] assert child.parent is call def test_call_create_error1(): - """Test that the appropriate exception is raised if the routine + '''Test that the appropriate exception is raised if the routine argument to the create method is not a RoutineSymbol. - """ + ''' with pytest.raises(TypeError) as info: _ = Call.create(None, []) - assert ( - "The Call routine argument should be a Reference to a " - "RoutineSymbol or a RoutineSymbol, but found " - "'NoneType'." - in str(info.value) - ) + assert ("The Call routine argument should be a Reference to a " + "RoutineSymbol or a RoutineSymbol, but found " + "'NoneType'." in str(info.value)) def test_call_create_error2(): - """Test that the appropriate exception is raised if the arguments - argument to the create method is not a list""" + '''Test that the appropriate exception is raised if the arguments + argument to the create method is not a list''' routine = RoutineSymbol("isaac", NoType()) with pytest.raises(GenerationError) as info: _ = Call.create(routine, None) - assert ( - "Call.create 'arguments' argument should be an Iterable but found " - "'NoneType'." - in str(info.value) - ) + assert ("Call.create 'arguments' argument should be an Iterable but found " + "'NoneType'." in str(info.value)) def test_call_create_error3(): - """Test that the appropriate exception is raised if one or more of the - argument names is not valid.""" + '''Test that the appropriate exception is raised if one or more of the + argument names is not valid.''' routine = RoutineSymbol("roo", INTEGER_TYPE) with pytest.raises(ValueError) as info: _ = Call.create( - routine, - [Reference(DataSymbol("arg1", INTEGER_TYPE)), (" a", None)], - ) + routine, [Reference(DataSymbol( + "arg1", INTEGER_TYPE)), (" a", None)]) assert "Invalid Fortran name ' a' found." in str(info.value) def test_call_create_error4(): - """Test that the appropriate exception is raised if one or more of the + '''Test that the appropriate exception is raised if one or more of the arguments argument list entries to the create method is not a DataNode. - """ + ''' routine = RoutineSymbol("roo", INTEGER_TYPE) with pytest.raises(GenerationError) as info: _ = Call.create( - routine, - [Reference(DataSymbol("arg1", INTEGER_TYPE)), ("name", None)], - ) - assert ( - "Item 'NoneType' can't be child 2 of 'Call'. The valid format " - "is: 'Reference, [DataNode]*'." - in str(info.value) - ) + routine, [Reference(DataSymbol( + "arg1", INTEGER_TYPE)), ("name", None)]) + assert ("Item 'NoneType' can't be child 2 of 'Call'. The valid format " + "is: 'Reference, [DataNode]*'." in str(info.value)) def test_call_add_args(): - """Test the _add_args method in the Call class.""" + '''Test the _add_args method in the Call class.''' routine = RoutineSymbol("myeloma", INTEGER_TYPE) call = Call.create(routine) array_type = ArrayType(INTEGER_TYPE, shape=[10, 20]) - arguments = [ - Reference(DataSymbol("arg1", INTEGER_TYPE)), - ArrayReference(DataSymbol("arg2", array_type)), - ] + arguments = [Reference(DataSymbol("arg1", INTEGER_TYPE)), + ArrayReference(DataSymbol("arg2", array_type))] Call._add_args(call, [arguments[0], ("name", arguments[1])]) assert call.routine.symbol is routine assert call.argument_names == [None, "name"] - for ( - idx, - child, - ) in enumerate(call.arguments): + for idx, child, in enumerate(call.arguments): assert child is arguments[idx] assert child.parent is call # For some reason pylint thinks that call.children[0,1] are of @@ -273,42 +251,37 @@ def test_call_add_args(): def test_call_add_args_error1(): - """Test that the appropriate exception is raised if an entry in the + '''Test that the appropriate exception is raised if an entry in the arguments argument to the _add_args method is a tuple that does not have two elements. - """ + ''' routine = RoutineSymbol("isaac", NoType()) with pytest.raises(GenerationError) as info: _ = Call._add_args(routine, [(1, 2, 3)]) - assert ( - "If a child of the children argument in create method of Call " - "class is a tuple, it's length should be 2, but found 3." - in str(info.value) - ) + assert ("If a child of the children argument in create method of Call " + "class is a tuple, it's length should be 2, but found 3." + in str(info.value)) def test_call_add_args_error2(): - """Test that the appropriate exception is raised if an entry in the + '''Test that the appropriate exception is raised if an entry in the arguments argument to the _add_args method is is a tuple with two - elements and the first element is not a string.""" + elements and the first element is not a string.''' routine = RoutineSymbol("isaac", NoType()) with pytest.raises(GenerationError) as info: _ = Call._add_args(routine, [(1, 2)]) - assert ( - "If a child of the children argument in create method of Call " - "class is a tuple, its first argument should be a str, but " - "found int." - in str(info.value) - ) + assert ("If a child of the children argument in create method of Call " + "class is a tuple, its first argument should be a str, but " + "found int." in str(info.value)) def test_call_appendnamedarg(): - """Test the append_named_arg method in the Call class. Check + '''Test the append_named_arg method in the Call class. Check it raises the expected exceptions if arguments are invalid and that it works as expected when the input is valid. - """ + ''' op1 = Literal("1", INTEGER_TYPE) op2 = Literal("2", INTEGER_TYPE) op3 = Literal("3", INTEGER_TYPE) @@ -316,7 +289,8 @@ def test_call_appendnamedarg(): # name arg wrong type with pytest.raises(TypeError) as info: call.append_named_arg(1, op1) - assert "A name should be a string, but found 'int'." in str(info.value) + assert ("A name should be a string, but found 'int'." + in str(info.value)) # invalid name with pytest.raises(ValueError) as info: call.append_named_arg("_", op1) @@ -325,11 +299,9 @@ def test_call_appendnamedarg(): call.append_named_arg("name1", op1) with pytest.raises(ValueError) as info: call.append_named_arg("name1", op2) - assert ( - "The value of the name argument (name1) in 'append_named_arg' in " - "the 'Call' node is already used for a named argument." - in str(info.value) - ) + assert ("The value of the name argument (name1) in 'append_named_arg' in " + "the 'Call' node is already used for a named argument." + in str(info.value)) # ok call.append_named_arg("name2", op2) call.append_named_arg(None, op3) @@ -338,11 +310,11 @@ def test_call_appendnamedarg(): def test_call_insertnamedarg(): - """Test the insert_named_arg method in the Call class. Check + '''Test the insert_named_arg method in the Call class. Check it raises the expected exceptions if arguments are invalid and that it works as expected when the input is valid. - """ + ''' op1 = Literal("1", INTEGER_TYPE) op2 = Literal("2", INTEGER_TYPE) op3 = Literal("3", INTEGER_TYPE) @@ -350,7 +322,8 @@ def test_call_insertnamedarg(): # name arg wrong type with pytest.raises(TypeError) as info: call.insert_named_arg(1, op1, 0) - assert "A name should be a string, but found 'int'." in str(info.value) + assert ("A name should be a string, but found 'int'." + in str(info.value)) # invalid name with pytest.raises(ValueError) as info: call.insert_named_arg("1", op1, 0) @@ -359,19 +332,14 @@ def test_call_insertnamedarg(): call.insert_named_arg("name1", op1, 0) with pytest.raises(ValueError) as info: call.insert_named_arg("name1", op2, 0) - assert ( - "The value of the name argument (name1) in 'insert_named_arg' in " - "the 'Call' node is already used for a named argument." - in str(info.value) - ) + assert ("The value of the name argument (name1) in 'insert_named_arg' in " + "the 'Call' node is already used for a named argument." + in str(info.value)) # invalid index type with pytest.raises(TypeError) as info: call.insert_named_arg("name2", op2, "hello") - assert ( - "The 'index' argument in 'insert_named_arg' in the 'Call' node " - "should be an int but found str." - in str(info.value) - ) + assert ("The 'index' argument in 'insert_named_arg' in the 'Call' node " + "should be an int but found str." in str(info.value)) # ok assert call.arguments == [op1] assert call.argument_names == ["name1"] @@ -384,35 +352,28 @@ def test_call_insertnamedarg(): def test_call_replacenamedarg(): - """Test the replace_named_arg method in the Call class. Check + '''Test the replace_named_arg method in the Call class. Check it raises the expected exceptions if arguments are invalid and that it works as expected when the input is valid. - """ + ''' op1 = Literal("1", INTEGER_TYPE) op2 = Literal("2", INTEGER_TYPE) op3 = Literal("3", INTEGER_TYPE) - call = Call.create( - RoutineSymbol("hello"), [("name1", op1), ("name2", op2)] - ) + call = Call.create(RoutineSymbol("hello"), + [("name1", op1), ("name2", op2)]) # name arg wrong type with pytest.raises(TypeError) as info: call.replace_named_arg(1, op3) - assert ( - "The 'name' argument in 'replace_named_arg' in the 'Call' " - "node should be a string, but found int." - in str(info.value) - ) + assert ("The 'name' argument in 'replace_named_arg' in the 'Call' " + "node should be a string, but found int." in str(info.value)) # name arg is not found with pytest.raises(ValueError) as info: call.replace_named_arg("new_name", op3) - assert ( - "The value of the existing_name argument (new_name) in " - "'replace_named_arg' in the 'Call' node was not found in the " - "existing arguments." - in str(info.value) - ) + assert ("The value of the existing_name argument (new_name) in " + "'replace_named_arg' in the 'Call' node was not found in the " + "existing arguments." in str(info.value)) # ok assert call.arguments == [op1, op2] assert call.argument_names == ["name1", "name2"] @@ -426,7 +387,7 @@ def test_call_replacenamedarg(): def test_call_reference_accesses(): - """Test the reference_accesses() method.""" + '''Test the reference_accesses() method.''' rsym = RoutineSymbol("trillian") # A call with an argument passed by value. call1 = Call.create(rsym, [Literal("1", INTEGER_TYPE)]) @@ -451,20 +412,16 @@ def test_call_reference_accesses(): assert var_info.has_read_write(Signature("gamma")) assert var_info.is_read(Signature("ji")) # Argument is a temporary so any inputs to it are READ only. - expr = BinaryOperation.create( - BinaryOperation.Operator.MUL, - Literal("2", INTEGER_TYPE), - Reference(dsym), - ) + expr = BinaryOperation.create(BinaryOperation.Operator.MUL, + Literal("2", INTEGER_TYPE), Reference(dsym)) call4 = Call.create(rsym, [expr]) var_info = VariablesAccessInfo() call4.reference_accesses(var_info) assert var_info.is_read(Signature("beta")) # Argument is itself a function call: call trillian(some_func(gamma(ji))) fsym = RoutineSymbol("some_func") - fcall = Call.create( - fsym, [ArrayReference.create(asym, [Reference(idx_sym)])] - ) + fcall = Call.create(fsym, + [ArrayReference.create(asym, [Reference(idx_sym)])]) call5 = Call.create(rsym, [fcall]) call5.reference_accesses(var_info) assert var_info.has_read_write(Signature("gamma")) @@ -479,11 +436,11 @@ def test_call_reference_accesses(): def test_call_argumentnames_after_removearg(): - """Test the argument_names property makes things consistent if a child + '''Test the argument_names property makes things consistent if a child argument is removed. This is used transparently by the class to keep things consistent. - """ + ''' op1 = Literal("1", INTEGER_TYPE) op2 = Literal("1", INTEGER_TYPE) call = Call.create(RoutineSymbol("name"), [("name1", op1), ("name2", op2)]) @@ -499,11 +456,11 @@ def test_call_argumentnames_after_removearg(): def test_call_argumentnames_after_addarg(): - """Test the argument_names property makes things consistent if a child + '''Test the argument_names property makes things consistent if a child argument is added. This is used transparently by the class to keep things consistent. - """ + ''' op1 = Literal("1", INTEGER_TYPE) op2 = Literal("1", INTEGER_TYPE) op3 = Literal("1", INTEGER_TYPE) @@ -520,11 +477,11 @@ def test_call_argumentnames_after_addarg(): def test_call_argumentnames_after_replacearg(): - """Test the argument_names property makes things consistent if a child + '''Test the argument_names property makes things consistent if a child argument is replaced. This is used transparently by the class to keep things consistent. - """ + ''' op1 = Literal("1", INTEGER_TYPE) op2 = Literal("1", INTEGER_TYPE) op3 = Literal("1", INTEGER_TYPE) @@ -543,11 +500,11 @@ def test_call_argumentnames_after_replacearg(): def test_call_argumentnames_after_reorderarg(): - """Test the argument_names property makes things consistent if a child + '''Test the argument_names property makes things consistent if a child argument is replaced. This is used transparently by the class to keep things consistent. - """ + ''' op1 = Literal("1", INTEGER_TYPE) op2 = Literal("1", INTEGER_TYPE) op3 = Literal("1", INTEGER_TYPE) @@ -564,10 +521,10 @@ def test_call_argumentnames_after_reorderarg(): def test_call_node_reconcile_add(): - """Test that the reconcile method behaves as expected. Use an example + '''Test that the reconcile method behaves as expected. Use an example where we add a new arg. - """ + ''' op1 = Literal("1", INTEGER_TYPE) op2 = Literal("1", INTEGER_TYPE) op3 = Literal("1", INTEGER_TYPE) @@ -590,10 +547,10 @@ def test_call_node_reconcile_add(): def test_call_node_reconcile_reorder(): - """Test that the reconcile method behaves as expected. Use an example + '''Test that the reconcile method behaves as expected. Use an example where we reorder the arguments. - """ + ''' op1 = Literal("1", INTEGER_TYPE) op2 = Literal("2", INTEGER_TYPE) call = Call.create(RoutineSymbol("name"), [("name1", op1), ("name2", op2)]) @@ -617,22 +574,22 @@ def test_call_node_reconcile_reorder(): def test_call_node_str(): - """Test that the node_str method behaves as expected""" + ''' Test that the node_str method behaves as expected ''' routine = RoutineSymbol("isaac", NoType()) call = Call.create(routine) colouredtext = colored("Call", Call._colour) - assert call.node_str() == colouredtext + "[name='isaac']" + assert call.node_str() == colouredtext+"[name='isaac']" def test_call_str(): - """Test that the str method behaves as expected""" + ''' Test that the str method behaves as expected ''' routine = RoutineSymbol("roo", NoType()) call = Call.create(routine) assert str(call) == "Call[name='roo']" def test_copy(): - """Test that the copy() method behaves as expected.""" + ''' Test that the copy() method behaves as expected. ''' op1 = Literal("1", INTEGER_TYPE) op2 = Literal("2", INTEGER_TYPE) call = Call.create(RoutineSymbol("name"), [("name1", op1), ("name2", op2)]) @@ -665,11 +622,11 @@ def test_copy(): def test_call_get_callees_local(fortran_reader): - """ + ''' Check that get_callees() works as expected when the target of the Call exists in the same Container as the call site. - """ - code = """ + ''' + code = ''' module some_mod implicit none integer :: luggage @@ -682,7 +639,7 @@ def test_call_get_callees_local(fortran_reader): subroutine bottom() luggage = luggage + 1 end subroutine bottom -end module some_mod""" +end module some_mod''' psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] result = call.get_callees() @@ -690,11 +647,11 @@ def test_call_get_callees_local(fortran_reader): def test_call_get_callee_1_simple_match(fortran_reader): - """ + ''' Check that the right routine has been found for a single routine implementation. - """ - code = """ + ''' + code = ''' module some_mod implicit none contains @@ -709,7 +666,7 @@ def test_call_get_callee_1_simple_match(fortran_reader): integer :: a, b, c end subroutine -end module some_mod""" +end module some_mod''' psyir = fortran_reader.psyir_from_source(code) @@ -725,10 +682,10 @@ def test_call_get_callee_1_simple_match(fortran_reader): def test_call_get_callee_2_optional_args(fortran_reader): - """ + ''' Check that optional arguments have been correlated correctly. - """ - code = """ + ''' + code = ''' module some_mod implicit none contains @@ -744,7 +701,7 @@ def test_call_get_callee_2_optional_args(fortran_reader): integer, optional :: c end subroutine -end module some_mod""" +end module some_mod''' root_node: Node = fortran_reader.psyir_from_source(code) @@ -768,10 +725,10 @@ def test_call_get_callee_2_optional_args(fortran_reader): def test_call_get_callee_3_trigger_error(fortran_reader): - """ + ''' Test which is supposed to trigger an error. - """ - code = """ + ''' + code = ''' module some_mod implicit none contains @@ -786,7 +743,7 @@ def test_call_get_callee_3_trigger_error(fortran_reader): integer :: a, b end subroutine -end module some_mod""" +end module some_mod''' root_node: Node = fortran_reader.psyir_from_source(code) @@ -799,7 +756,7 @@ def test_call_get_callee_3_trigger_error(fortran_reader): call_foo: Call = routine_main.walk(Call)[0] assert call_foo.routine.name == "foo" - with pytest.raises(CallMatchingArgumentsNotFound) as err: + with pytest.raises(CallMatchingArgumentsNotFoundError) as err: call_foo.get_callee() assert ( @@ -809,10 +766,10 @@ def test_call_get_callee_3_trigger_error(fortran_reader): def test_call_get_callee_4_named_arguments(fortran_reader): - """ + ''' Check that named arguments have been correlated correctly - """ - code = """ + ''' + code = ''' module some_mod implicit none contains @@ -827,7 +784,7 @@ def test_call_get_callee_4_named_arguments(fortran_reader): integer :: a, b, c end subroutine -end module some_mod""" +end module some_mod''' root_node: Node = fortran_reader.psyir_from_source(code) @@ -852,11 +809,11 @@ def test_call_get_callee_4_named_arguments(fortran_reader): def test_call_get_callee_5_optional_and_named_arguments(fortran_reader): - """ + ''' Check that optional and named arguments have been correlated correctly when the call is to a generic interface. - """ - code = """ + ''' + code = ''' module some_mod implicit none contains @@ -872,7 +829,7 @@ def test_call_get_callee_5_optional_and_named_arguments(fortran_reader): integer, optional :: c end subroutine -end module some_mod""" +end module some_mod''' root_node: Node = fortran_reader.psyir_from_source(code) @@ -895,7 +852,7 @@ def test_call_get_callee_5_optional_and_named_arguments(fortran_reader): assert result is routine_match -_code_test_get_callee_6 = """ +_code_test_get_callee_6 = ''' module some_mod implicit none @@ -952,13 +909,13 @@ def test_call_get_callee_5_optional_and_named_arguments(fortran_reader): integer, optional :: c end subroutine -end module some_mod""" +end module some_mod''' def test_call_get_callee_6_interfaces_0_0(fortran_reader): - """ + ''' Check that optional and named arguments have been correlated correctly - """ + ''' root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) @@ -982,9 +939,9 @@ def test_call_get_callee_6_interfaces_0_0(fortran_reader): def test_call_get_callee_6_interfaces_0_1(fortran_reader): - """ + ''' Check that optional and named arguments have been correlated correctly - """ + ''' root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) @@ -1009,9 +966,9 @@ def test_call_get_callee_6_interfaces_0_1(fortran_reader): def test_call_get_callee_6_interfaces_1_0(fortran_reader): - """ + ''' Check that optional and named arguments have been correlated correctly - """ + ''' root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) @@ -1035,9 +992,9 @@ def test_call_get_callee_6_interfaces_1_0(fortran_reader): def test_call_get_callee_6_interfaces_1_1(fortran_reader): - """ + ''' Check that optional and named arguments have been correlated correctly - """ + ''' root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) @@ -1062,9 +1019,9 @@ def test_call_get_callee_6_interfaces_1_1(fortran_reader): def test_call_get_callee_6_interfaces_1_2(fortran_reader): - """ + ''' Check that optional and named arguments have been correlated correctly - """ + ''' root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) @@ -1089,9 +1046,9 @@ def test_call_get_callee_6_interfaces_1_2(fortran_reader): def test_call_get_callee_6_interfaces_2_0(fortran_reader): - """ + ''' Check that optional and named arguments have been correlated correctly - """ + ''' root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) @@ -1121,9 +1078,9 @@ def test_call_get_callee_6_interfaces_2_0(fortran_reader): def test_call_get_callee_6_interfaces_2_1(fortran_reader): - """ + ''' Check that optional and named arguments have been correlated correctly - """ + ''' root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) @@ -1147,9 +1104,9 @@ def test_call_get_callee_6_interfaces_2_1(fortran_reader): def test_call_get_callee_6_interfaces_2_2(fortran_reader): - """ + ''' Check that optional and named arguments have been correlated correctly - """ + ''' root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) @@ -1174,10 +1131,10 @@ def test_call_get_callee_6_interfaces_2_2(fortran_reader): def test_call_get_callee_7_matching_arguments_not_found(fortran_reader): - """ + ''' Trigger error that matching arguments were not found - """ - code = """ + ''' + code = ''' module some_mod implicit none contains @@ -1194,7 +1151,7 @@ def test_call_get_callee_7_matching_arguments_not_found(fortran_reader): integer :: a, b, c end subroutine -end module some_mod""" +end module some_mod''' psyir = fortran_reader.psyir_from_source(code) @@ -1203,7 +1160,7 @@ def test_call_get_callee_7_matching_arguments_not_found(fortran_reader): call_foo: Call = routine_main.walk(Call)[0] - with pytest.raises(CallMatchingArgumentsNotFound) as err: + with pytest.raises(CallMatchingArgumentsNotFoundError) as err: call_foo.get_callee() assert ( @@ -1213,10 +1170,10 @@ def test_call_get_callee_7_matching_arguments_not_found(fortran_reader): def test_call_get_callee_8_arguments_not_handled(fortran_reader): - """ + ''' Trigger error that matching arguments were not found - """ - code = """ + ''' + code = ''' module some_mod implicit none contains @@ -1232,7 +1189,7 @@ def test_call_get_callee_8_arguments_not_handled(fortran_reader): integer :: a, b, c end subroutine -end module some_mod""" +end module some_mod''' psyir = fortran_reader.psyir_from_source(code) @@ -1241,7 +1198,7 @@ def test_call_get_callee_8_arguments_not_handled(fortran_reader): call_foo: Call = routine_main.walk(Call)[0] - with pytest.raises(CallMatchingArgumentsNotFound) as err: + with pytest.raises(CallMatchingArgumentsNotFoundError) as err: call_foo.get_callee() assert ( @@ -1252,90 +1209,75 @@ def test_call_get_callee_8_arguments_not_handled(fortran_reader): @pytest.mark.usefixtures("clear_module_manager_instance") def test_call_get_callees_unresolved(fortran_reader, tmpdir, monkeypatch): - """ + ''' Test that get_callees() raises the expected error if the called routine is unresolved. - """ - code = """ + ''' + code = ''' subroutine top() call bottom() -end subroutine top""" +end subroutine top''' psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] with pytest.raises(NotImplementedError) as err: _ = call.get_callees() - assert ( - "Failed to find the source code of the unresolved routine 'bottom'" - " - looked at any routines in the same source file and there are " - "no wildcard imports." - in str(err.value) - ) + assert ("Failed to find the source code of the unresolved routine 'bottom'" + " - looked at any routines in the same source file and there are " + "no wildcard imports." in str(err.value)) # Repeat but in the presence of a wildcard import. - code = """ + code = ''' subroutine top() use some_mod_somewhere call bottom() -end subroutine top""" +end subroutine top''' psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] with pytest.raises(NotImplementedError) as err: _ = call.get_callees() - assert ( - "Failed to find the source code of the unresolved routine 'bottom'" - " - looked at any routines in the same source file and attempted " - "to resolve the wildcard imports from ['some_mod_somewhere']. " - "However, failed to find the source for ['some_mod_somewhere']. " - "The module search path is set to []" - in str(err.value) - ) + assert ("Failed to find the source code of the unresolved routine 'bottom'" + " - looked at any routines in the same source file and attempted " + "to resolve the wildcard imports from ['some_mod_somewhere']. " + "However, failed to find the source for ['some_mod_somewhere']. " + "The module search path is set to []" in str(err.value)) # Repeat but when some_mod_somewhere *is* resolved but doesn't help us # find the routine we're looking for. mod_manager = ModuleManager.get() monkeypatch.setattr(mod_manager, "_instance", None) path = str(tmpdir) - monkeypatch.setattr(Config.get(), "_include_paths", [path]) - with open( - os.path.join(path, "some_mod_somewhere.f90"), "w", encoding="utf-8" - ) as ofile: - ofile.write( - """\ + monkeypatch.setattr(Config.get(), '_include_paths', [path]) + with open(os.path.join(path, "some_mod_somewhere.f90"), "w", + encoding="utf-8") as ofile: + ofile.write('''\ module some_mod_somewhere end module some_mod_somewhere -""" - ) +''') with pytest.raises(NotImplementedError) as err: _ = call.get_callees() - assert ( - "Failed to find the source code of the unresolved routine 'bottom'" - " - looked at any routines in the same source file and wildcard " - "imports from ['some_mod_somewhere']." - in str(err.value) - ) + assert ("Failed to find the source code of the unresolved routine 'bottom'" + " - looked at any routines in the same source file and wildcard " + "imports from ['some_mod_somewhere']." in str(err.value)) mod_manager = ModuleManager.get() monkeypatch.setattr(mod_manager, "_instance", None) - code = """ + code = ''' subroutine top() use another_mod, only: this_one call this_one() -end subroutine top""" +end subroutine top''' psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] with pytest.raises(NotImplementedError) as err: _ = call.get_callees() - assert ( - "RoutineSymbol 'this_one' is imported from Container 'another_mod'" - " but the source defining that container could not be found. The " - "module search path is set to [" - in str(err.value) - ) + assert ("RoutineSymbol 'this_one' is imported from Container 'another_mod'" + " but the source defining that container could not be found. The " + "module search path is set to [" in str(err.value)) def test_call_get_callees_interface(fortran_reader): - """ + ''' Check that get_callees() works correctly when the target of a call is actually a generic interface. - """ - code = """ + ''' + code = ''' module my_mod interface bottom @@ -1358,7 +1300,7 @@ def test_call_get_callees_interface(fortran_reader): luggage = luggage + 1.0 end subroutine rbottom end module my_mod -""" +''' psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] callees = call.get_callees() @@ -1370,13 +1312,13 @@ def test_call_get_callees_interface(fortran_reader): def test_call_get_callees_unsupported_type(fortran_reader): - """ + ''' Check that get_callees() raises the expected error when the called routine is of UnsupportedFortranType. This is hard to achieve so we have to manually construct some aspects of the test case. - """ - code = """ + ''' + code = ''' module my_mod integer, target :: value contains @@ -1389,7 +1331,7 @@ def test_call_get_callees_unsupported_type(fortran_reader): fval => value end function bottom end module my_mod -""" +''' psyir = fortran_reader.psyir_from_source(code) container = psyir.children[0] routine = container.find_routine_psyir("bottom") @@ -1406,19 +1348,16 @@ def test_call_get_callees_unsupported_type(fortran_reader): call = psyir.walk(Call)[0] with pytest.raises(NotImplementedError) as err: _ = call.get_callees() - assert ( - "RoutineSymbol 'bottom' exists in Container 'my_mod' but is of " - "UnsupportedFortranType" - in str(err.value) - ) + assert ("RoutineSymbol 'bottom' exists in Container 'my_mod' but is of " + "UnsupportedFortranType" in str(err.value)) def test_call_get_callees_file_container(fortran_reader): - """ + ''' Check that get_callees works if the called routine happens to be in file scope, even when there's no Container. - """ - code = """ + ''' + code = ''' subroutine top() integer :: luggage luggage = 0 @@ -1429,7 +1368,7 @@ def test_call_get_callees_file_container(fortran_reader): integer :: luggage luggage = luggage + 1 end subroutine bottom -""" +''' psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] result = call.get_callees() @@ -1439,13 +1378,13 @@ def test_call_get_callees_file_container(fortran_reader): def test_call_get_callees_no_container(fortran_reader): - """ + ''' Check that get_callees() raises the expected error when the Call is not within a Container and the target routine cannot be found. - """ + ''' # To avoid having the routine symbol immediately dismissed as # unresolved, the code that we initially process *does* have a Container. - code = """ + code = ''' module my_mod contains @@ -1460,7 +1399,7 @@ def test_call_get_callees_no_container(fortran_reader): luggage = luggage + 1 end subroutine bottom end module my_mod -""" +''' psyir = fortran_reader.psyir_from_source(code) top_routine = psyir.walk(Routine)[0] # Deliberately make the Routine node an orphan so there's no Container. @@ -1468,18 +1407,16 @@ def test_call_get_callees_no_container(fortran_reader): call = top_routine.walk(Call)[0] with pytest.raises(SymbolError) as err: _ = call.get_callees() - assert ( - "Failed to find a Routine named 'bottom' in code:\n'subroutine top()" - in str(err.value) - ) + assert ("Failed to find a Routine named 'bottom' in code:\n'subroutine " + "top()" in str(err.value)) def test_call_get_callees_wildcard_import_local_container(fortran_reader): - """ + ''' Check that get_callees() works successfully for a routine accessed via a wildcard import from another module in the same file. - """ - code = """ + ''' + code = ''' module some_mod contains subroutine just_do_it() @@ -1493,7 +1430,7 @@ def test_call_get_callees_wildcard_import_local_container(fortran_reader): call just_do_it() end subroutine run_it end module other_mod -""" +''' psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] routines = call.get_callees() @@ -1503,11 +1440,11 @@ def test_call_get_callees_wildcard_import_local_container(fortran_reader): def test_call_get_callees_import_local_container(fortran_reader): - """ + ''' Check that get_callees() works successfully for a routine accessed via a specific import from another module in the same file. - """ - code = """ + ''' + code = ''' module some_mod contains subroutine just_do_it() @@ -1521,7 +1458,7 @@ def test_call_get_callees_import_local_container(fortran_reader): call just_do_it() end subroutine run_it end module other_mod -""" +''' psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] routines = call.get_callees() @@ -1531,14 +1468,13 @@ def test_call_get_callees_import_local_container(fortran_reader): @pytest.mark.usefixtures("clear_module_manager_instance") -def test_call_get_callees_wildcard_import_container( - fortran_reader, tmpdir, monkeypatch -): - """ +def test_call_get_callees_wildcard_import_container(fortran_reader, + tmpdir, monkeypatch): + ''' Check that get_callees() works successfully for a routine accessed via a wildcard import from a module in another file. - """ - code = """ + ''' + code = ''' module other_mod use some_mod contains @@ -1546,34 +1482,29 @@ def test_call_get_callees_wildcard_import_container( call just_do_it() end subroutine run_it end module other_mod -""" +''' psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] # This should fail as it can't find the module. with pytest.raises(NotImplementedError) as err: _ = call.get_callees() - assert ( - "Failed to find the source code of the unresolved routine " - "'just_do_it' - looked at any routines in the same source file" - in str(err.value) - ) + assert ("Failed to find the source code of the unresolved routine " + "'just_do_it' - looked at any routines in the same source file" + in str(err.value)) # Create the module containing the subroutine definition, # write it to file and set the search path so that PSyclone can find it. path = str(tmpdir) - monkeypatch.setattr(Config.get(), "_include_paths", [path]) + monkeypatch.setattr(Config.get(), '_include_paths', [path]) - with open( - os.path.join(path, "some_mod.f90"), "w", encoding="utf-8" - ) as mfile: - mfile.write( - """\ + with open(os.path.join(path, "some_mod.f90"), + "w", encoding="utf-8") as mfile: + mfile.write('''\ module some_mod contains subroutine just_do_it() write(*,*) "hello" end subroutine just_do_it -end module some_mod""" - ) +end module some_mod''') routines = call.get_callees() assert len(routines) == 1 assert isinstance(routines[0], Routine) @@ -1581,10 +1512,10 @@ def test_call_get_callees_wildcard_import_container( def test_fn_call_get_callees(fortran_reader): - """ + ''' Test that get_callees() works for a function call. - """ - code = """ + ''' + code = ''' module some_mod implicit none integer :: luggage @@ -1599,7 +1530,7 @@ def test_fn_call_get_callees(fortran_reader): integer :: my_func my_func = 1 + val end function my_func -end module some_mod""" +end module some_mod''' psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] result = call.get_callees() @@ -1607,9 +1538,9 @@ def test_fn_call_get_callees(fortran_reader): def test_get_callees_code_block(fortran_reader): - """Test that get_callees() raises the expected error when the called - routine is in a CodeBlock.""" - code = """ + '''Test that get_callees() raises the expected error when the called + routine is in a CodeBlock.''' + code = ''' module some_mod implicit none integer :: luggage @@ -1623,24 +1554,22 @@ def test_get_callees_code_block(fortran_reader): integer, intent(in) :: val my_func = CMPLX(1 + val, 1.0) end function my_func -end module some_mod""" +end module some_mod''' psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[1] with pytest.raises(SymbolError) as err: _ = call.get_callees() - assert ( - "Failed to find a Routine named 'my_func' in Container 'some_mod'" - in str(err.value) - ) + assert ("Failed to find a Routine named 'my_func' in Container " + "'some_mod'" in str(err.value)) @pytest.mark.usefixtures("clear_module_manager_instance") def test_get_callees_follow_imports(fortran_reader, tmpdir, monkeypatch): - """ + ''' Test that get_callees() follows imports to find the definition of the called routine. - """ - code = """ + ''' + code = ''' module some_mod use other_mod, only: pack_it implicit none @@ -1649,29 +1578,24 @@ def test_get_callees_follow_imports(fortran_reader, tmpdir, monkeypatch): integer :: luggage = 0 call pack_it(luggage) end subroutine top -end module some_mod""" +end module some_mod''' # Create the module containing an import of the subroutine definition, # write it to file and set the search path so that PSyclone can find it. path = str(tmpdir) - monkeypatch.setattr(Config.get(), "_include_paths", [path]) + monkeypatch.setattr(Config.get(), '_include_paths', [path]) - with open( - os.path.join(path, "other_mod.f90"), "w", encoding="utf-8" - ) as mfile: - mfile.write( - """\ + with open(os.path.join(path, "other_mod.f90"), + "w", encoding="utf-8") as mfile: + mfile.write('''\ module other_mod use another_mod, only: pack_it contains end module other_mod - """ - ) + ''') # Finally, create the module containing the routine definition. - with open( - os.path.join(path, "another_mod.f90"), "w", encoding="utf-8" - ) as mfile: - mfile.write( - """\ + with open(os.path.join(path, "another_mod.f90"), + "w", encoding="utf-8") as mfile: + mfile.write('''\ module another_mod contains subroutine pack_it(arg) @@ -1679,8 +1603,7 @@ def test_get_callees_follow_imports(fortran_reader, tmpdir, monkeypatch): arg = arg + 2 end subroutine pack_it end module another_mod - """ - ) + ''') psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] result = call.get_callees() @@ -1691,12 +1614,12 @@ def test_get_callees_follow_imports(fortran_reader, tmpdir, monkeypatch): @pytest.mark.usefixtures("clear_module_manager_instance") def test_get_callees_import_private_clash(fortran_reader, tmpdir, monkeypatch): - """ + ''' Test that get_callees() raises the expected error if a module from which a routine is imported has a private shadow of that routine (and thus we don't know where to look for the target routine). - """ - code = """ + ''' + code = ''' module some_mod use other_mod, only: pack_it implicit none @@ -1705,18 +1628,16 @@ def test_get_callees_import_private_clash(fortran_reader, tmpdir, monkeypatch): integer :: luggage = 0 call pack_it(luggage) end subroutine top -end module some_mod""" +end module some_mod''' # Create the module containing a private routine with the name we are # searching for, write it to file and set the search path so that PSyclone # can find it. path = str(tmpdir) - monkeypatch.setattr(Config.get(), "_include_paths", [path]) + monkeypatch.setattr(Config.get(), '_include_paths', [path]) - with open( - os.path.join(path, "other_mod.f90"), "w", encoding="utf-8" - ) as mfile: - mfile.write( - """\ + with open(os.path.join(path, "other_mod.f90"), + "w", encoding="utf-8") as mfile: + mfile.write('''\ module other_mod use another_mod private pack_it @@ -1726,16 +1647,12 @@ def test_get_callees_import_private_clash(fortran_reader, tmpdir, monkeypatch): integer :: pack_it end function pack_it end module other_mod - """ - ) + ''') psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] with pytest.raises(NotImplementedError) as err: _ = call.get_callees() - assert ( - "RoutineSymbol 'pack_it' is imported from Container 'other_mod' " - "but that Container defines a private Symbol of the same name. " - "Searching for the Container that defines a public Routine with " - "that name is not yet supported - TODO #924" - in str(err.value) - ) + assert ("RoutineSymbol 'pack_it' is imported from Container 'other_mod' " + "but that Container defines a private Symbol of the same name. " + "Searching for the Container that defines a public Routine with " + "that name is not yet supported - TODO #924" in str(err.value)) diff --git a/src/psyclone/tests/psyir/transformations/inline_trans_test.py b/src/psyclone/tests/psyir/transformations/inline_trans_test.py index 3401ad13b8..c359c79433 100644 --- a/src/psyclone/tests/psyir/transformations/inline_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/inline_trans_test.py @@ -34,7 +34,7 @@ # Author: A. R. Porter, STFC Daresbury Lab # Modified: R. W. Ford and S. Siso, STFC Daresbury Lab -"""This module tests the inlining transformation.""" +'''This module tests the inlining transformation.''' import os import pytest @@ -49,48 +49,42 @@ Routine, ) from psyclone.psyir.symbols import ( - AutomaticInterface, - DataSymbol, - UnresolvedType, -) -from psyclone.psyir.transformations import InlineTrans, TransformationError + AutomaticInterface, DataSymbol, UnresolvedType) +from psyclone.psyir.transformations import ( + InlineTrans, TransformationError) from psyclone.tests.utilities import Compile -MY_TYPE = ( - " integer, parameter :: ngrids = 10\n" - " type other_type\n" - " real, dimension(10) :: data\n" - " integer :: nx\n" - " end type other_type\n" - " type my_type\n" - " integer :: idx\n" - " real, dimension(10) :: data\n" - " real, dimension(5,10) :: data2d\n" - " type(other_type) :: local\n" - " end type my_type\n" - " type big_type\n" - " type(my_type) :: region\n" - " end type big_type\n" - " type vbig_type\n" - " type(big_type), dimension(ngrids) :: grids\n" - " end type vbig_type\n" -) +MY_TYPE = (" integer, parameter :: ngrids = 10\n" + " type other_type\n" + " real, dimension(10) :: data\n" + " integer :: nx\n" + " end type other_type\n" + " type my_type\n" + " integer :: idx\n" + " real, dimension(10) :: data\n" + " real, dimension(5,10) :: data2d\n" + " type(other_type) :: local\n" + " end type my_type\n" + " type big_type\n" + " type(my_type) :: region\n" + " end type big_type\n" + " type vbig_type\n" + " type(big_type), dimension(ngrids) :: grids\n" + " end type vbig_type\n") # init - def test_init(): - """Test an InlineTrans transformation can be successfully created.""" + '''Test an InlineTrans transformation can be successfully created.''' inline_trans = InlineTrans() assert isinstance(inline_trans, InlineTrans) # apply - def test_apply_empty_routine(fortran_reader, fortran_writer, tmpdir): - """Check that a call to an empty routine is simply removed.""" + '''Check that a call to an empty routine is simply removed.''' code = ( "module test_mod\n" "contains\n" @@ -102,20 +96,20 @@ def test_apply_empty_routine(fortran_reader, fortran_writer, tmpdir): " subroutine sub(idx)\n" " integer :: idx\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) - assert " i = 10\n\n end subroutine run_it\n" in output + assert (" i = 10\n\n" + " end subroutine run_it\n" in output) assert Compile(tmpdir).string_compiles(output) def test_apply_single_return(fortran_reader, fortran_writer, tmpdir): - """Check that a call to a routine containing only a return statement - is removed.""" + '''Check that a call to a routine containing only a return statement + is removed. ''' code = ( "module test_mod\n" "contains\n" @@ -128,20 +122,20 @@ def test_apply_single_return(fortran_reader, fortran_writer, tmpdir): " integer :: idx\n" " return\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) - assert " i = 10\n\n end subroutine run_it\n" in output + assert (" i = 10\n\n" + " end subroutine run_it\n" in output) assert Compile(tmpdir).string_compiles(output) def test_apply_return_then_cb(fortran_reader, fortran_writer, tmpdir): - """Check that a call to a routine containing a return statement followed - by a CodeBlock is removed.""" + '''Check that a call to a routine containing a return statement followed + by a CodeBlock is removed.''' code = ( "module test_mod\n" "contains\n" @@ -155,20 +149,20 @@ def test_apply_return_then_cb(fortran_reader, fortran_writer, tmpdir): " return\n" " write(*,*) idx\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) - assert " i = 10\n\n end subroutine run_it\n" in output + assert (" i = 10\n\n" + " end subroutine run_it\n" in output) assert Compile(tmpdir).string_compiles(output) def test_apply_array_arg(fortran_reader, fortran_writer, tmpdir): - """Check that the apply() method works correctly for a very simple - call to a routine with an array reference as argument.""" + ''' Check that the apply() method works correctly for a very simple + call to a routine with an array reference as argument. ''' code = ( "module test_mod\n" "contains\n" @@ -184,29 +178,25 @@ def test_apply_array_arg(fortran_reader, fortran_writer, tmpdir): " real, intent(inout) :: x\n" " x = 2.0*x\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) - assert ( - " do i = 1, 10, 1\n" - " a(i) = 1.0\n" - " a(i) = 2.0 * a(i)\n" - " enddo\n" - in output - ) + assert (" do i = 1, 10, 1\n" + " a(i) = 1.0\n" + " a(i) = 2.0 * a(i)\n" + " enddo\n" in output) assert Compile(tmpdir).string_compiles(output) def test_apply_array_access(fortran_reader, fortran_writer, tmpdir): - """ + ''' Check that the apply method works correctly when an array is passed into the routine and then indexed within it. - """ + ''' code = ( "module test_mod\n" "contains\n" @@ -225,31 +215,27 @@ def test_apply_array_access(fortran_reader, fortran_writer, tmpdir): " x(i) = 2.0*ivar\n" " end do\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) - assert ( - " do i = 1, 10, 1\n" - " do i_1 = 1, 10, 1\n" - " a(i_1) = 2.0 * i\n" - " enddo\n" - in output - ) + assert (" do i = 1, 10, 1\n" + " do i_1 = 1, 10, 1\n" + " a(i_1) = 2.0 * i\n" + " enddo\n" in output) assert Compile(tmpdir).string_compiles(output) def test_apply_gocean_kern(fortran_reader, fortran_writer, monkeypatch): - """ + ''' Test the apply method with a typical GOcean kernel. TODO #924 - currently this xfails because we don't resolve the type of the actual argument. - """ + ''' code = ( "module psy_single_invoke_test\n" " use field_mod, only: r2d_field\n" @@ -278,75 +264,66 @@ def test_apply_gocean_kern(fortran_reader, fortran_writer, monkeypatch): # Set up include_path to import the proper module src_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), - "../../../external/dl_esm_inf/finite_difference/src", - ) - monkeypatch.setattr(Config.get(), "_include_paths", [str(src_dir)]) + "../../../external/dl_esm_inf/finite_difference/src") + monkeypatch.setattr(Config.get(), '_include_paths', [str(src_dir)]) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) with pytest.raises(TransformationError) as err: inline_trans.apply(psyir.walk(Call)[0]) - if ( - "actual argument 'cu_fld' corresponding to an array formal " - "argument ('cu') is unknown" - in str(err.value) - ): + if ("actual argument 'cu_fld' corresponding to an array formal " + "argument ('cu') is unknown" in str(err.value)): pytest.xfail( "TODO #924 - extend validation to attempt to resolve type of " - "actual argument." - ) + "actual argument.") output = fortran_writer(psyir) - assert ( - " do j = cu_fld%internal%ystart, cu_fld%internal%ystop, 1\n" - " do i = cu_fld%internal%xstart, cu_fld%internal%xstop, 1\n" - " cu_fld%data(i,j) = 0.5d0 * (pf%data(i,j) + " - "pf%data(i - 1,j)) * u_fld%data(i,j)\n" - " enddo\n" - " enddo\n" - in output - ) + assert (" do j = cu_fld%internal%ystart, cu_fld%internal%ystop, 1\n" + " do i = cu_fld%internal%xstart, cu_fld%internal%xstop, 1\n" + " cu_fld%data(i,j) = 0.5d0 * (pf%data(i,j) + " + "pf%data(i - 1,j)) * u_fld%data(i,j)\n" + " enddo\n" + " enddo\n" in output) def test_apply_struct_arg(fortran_reader, fortran_writer, tmpdir): - """ + ''' Check that the apply() method works correctly when the routine argument is a StructureReference containing an ArrayMember which is accessed inside the routine. - """ + ''' code = ( - "module test_mod\n" + f"module test_mod\n" f"{MY_TYPE}" - "contains\n" - " subroutine run_it()\n" - " integer :: i\n" - " type(my_type) :: var\n" - " type(my_type) :: var_list(10)\n" - " type(big_type) :: var2(5)\n" - " do i=1,5\n" - " call sub(var, i)\n" - " call sub(var_list(i), i)\n" - " call sub(var2(i)%region, i)\n" - " call sub2(var2)\n" - " end do\n" - " end subroutine run_it\n" - " subroutine sub(x, ivar)\n" - " type(my_type), intent(inout) :: x\n" - " integer, intent(in) :: ivar\n" - " integer :: i\n" - " do i = 1, 10\n" - " x%data(i) = 2.0*ivar\n" - " end do\n" - " x%data(:) = -1.0\n" - " x%data = -5.0\n" - " x%data(1:2) = 0.0\n" - " end subroutine sub\n" - " subroutine sub2(x)\n" - " type(big_type), dimension(:), intent(inout) :: x\n" - " x(:)%region%local%nx = 0\n" - " end subroutine sub2\n" - "end module test_mod\n" - ) + f"contains\n" + f" subroutine run_it()\n" + f" integer :: i\n" + f" type(my_type) :: var\n" + f" type(my_type) :: var_list(10)\n" + f" type(big_type) :: var2(5)\n" + f" do i=1,5\n" + f" call sub(var, i)\n" + f" call sub(var_list(i), i)\n" + f" call sub(var2(i)%region, i)\n" + f" call sub2(var2)\n" + f" end do\n" + f" end subroutine run_it\n" + f" subroutine sub(x, ivar)\n" + f" type(my_type), intent(inout) :: x\n" + f" integer, intent(in) :: ivar\n" + f" integer :: i\n" + f" do i = 1, 10\n" + f" x%data(i) = 2.0*ivar\n" + f" end do\n" + f" x%data(:) = -1.0\n" + f" x%data = -5.0\n" + f" x%data(1:2) = 0.0\n" + f" end subroutine sub\n" + f" subroutine sub2(x)\n" + f" type(big_type), dimension(:), intent(inout) :: x\n" + f" x(:)%region%local%nx = 0\n" + f" end subroutine sub2\n" + f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) @@ -354,40 +331,37 @@ def test_apply_struct_arg(fortran_reader, fortran_writer, tmpdir): inline_trans.apply(routine) output = fortran_writer(psyir) - assert ( - " do i = 1, 5, 1\n" - " do i_1 = 1, 10, 1\n" - " var%data(i_1) = 2.0 * i\n" - " enddo\n" - " var%data(:) = -1.0\n" - " var%data = -5.0\n" - " var%data(1:2) = 0.0\n" - " do i_2 = 1, 10, 1\n" - " var_list(i)%data(i_2) = 2.0 * i\n" - " enddo\n" - " var_list(i)%data(:) = -1.0\n" - " var_list(i)%data = -5.0\n" - " var_list(i)%data(1:2) = 0.0\n" - " do i_3 = 1, 10, 1\n" - " var2(i)%region%data(i_3) = 2.0 * i\n" - " enddo\n" - " var2(i)%region%data(:) = -1.0\n" - " var2(i)%region%data = -5.0\n" - " var2(i)%region%data(1:2) = 0.0\n" - " var2(1:5)%region%local%nx = 0\n" - " enddo\n" - in output - ) + assert (" do i = 1, 5, 1\n" + " do i_1 = 1, 10, 1\n" + " var%data(i_1) = 2.0 * i\n" + " enddo\n" + " var%data(:) = -1.0\n" + " var%data = -5.0\n" + " var%data(1:2) = 0.0\n" + " do i_2 = 1, 10, 1\n" + " var_list(i)%data(i_2) = 2.0 * i\n" + " enddo\n" + " var_list(i)%data(:) = -1.0\n" + " var_list(i)%data = -5.0\n" + " var_list(i)%data(1:2) = 0.0\n" + " do i_3 = 1, 10, 1\n" + " var2(i)%region%data(i_3) = 2.0 * i\n" + " enddo\n" + " var2(i)%region%data(:) = -1.0\n" + " var2(i)%region%data = -5.0\n" + " var2(i)%region%data(1:2) = 0.0\n" + " var2(1:5)%region%local%nx = 0\n" + " enddo\n" in output) assert Compile(tmpdir).string_compiles(output) def test_apply_unresolved_struct_arg(fortran_reader, fortran_writer): - """ + ''' Check that we handle acceptable cases of the type of an argument being unresolved but that we reject the case where we can't be sure of the array indexing. - """ + ''' code = ( "module test_mod\n" "contains\n" @@ -422,8 +396,7 @@ def test_apply_unresolved_struct_arg(fortran_reader, fortran_writer): " type(mystery_type), dimension(3:5), intent(inout) :: x\n" " x(:)%region%local%nx = 0\n" " end subroutine sub4\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) @@ -433,70 +406,60 @@ def test_apply_unresolved_struct_arg(fortran_reader, fortran_writer): # Second one should fail. with pytest.raises(TransformationError) as err: inline_trans.apply(calls[1]) - assert ( - "Routine 'sub3' cannot be inlined because the type of the actual " - "argument 'mystery' corresponding to an array formal argument " - "('x') is unknown" - in str(err.value) - ) + assert ("Routine 'sub3' cannot be inlined because the type of the actual " + "argument 'mystery' corresponding to an array formal argument " + "('x') is unknown" in str(err.value)) # Third one should be fine because it is a scalar argument. inline_trans.apply(calls[2]) # We can't do the fourth one. with pytest.raises(TransformationError) as err: inline_trans.apply(calls[3]) - assert ( - "Routine 'sub4' cannot be inlined because the type of the actual " - "argument 'mystery' corresponding to an array formal argument " - "('x') is unknown." - in str(err.value) - ) + assert ("Routine 'sub4' cannot be inlined because the type of the actual " + "argument 'mystery' corresponding to an array formal argument " + "('x') is unknown." in str(err.value)) output = fortran_writer(psyir) - assert ( - " varr(1:5)%region%local%nx = 0\n" - " call sub3(mystery)\n" - " mystery%flag = 1\n" - " call sub4(mystery)\n" - in output - ) + assert (" varr(1:5)%region%local%nx = 0\n" + " call sub3(mystery)\n" + " mystery%flag = 1\n" + " call sub4(mystery)\n" in output) def test_apply_struct_slice_arg(fortran_reader, fortran_writer, tmpdir): - """ + ''' Check that the apply() method works correctly when there are slices in structure accesses in both the actual and formal arguments. - """ + ''' code = ( - "module test_mod\n" + f"module test_mod\n" f"{MY_TYPE}" - "contains\n" - " subroutine run_it()\n" - " integer :: i\n" - " type(my_type) :: var_list(10)\n" - " type(vbig_type), dimension(5) :: cvar\n" - " call sub(var_list(:)%local%nx, i)\n" - " call sub2(var_list(:), 1, 1)\n" - " call sub2(var_list(:), i, i+2)\n" - " call sub3(cvar)\n" - " end subroutine run_it\n" - " subroutine sub(ix, indx)\n" - " integer, dimension(:) :: ix\n" - " integer, intent(in) :: indx\n" - " ix(:) = ix(:) + 1\n" - " end subroutine sub\n" - " subroutine sub2(x, start, stop)\n" - " type(my_type), dimension(:) :: x\n" - " integer :: start, stop\n" - " x(:)%data(2) = 0.0\n" - " x(:)%local%nx = 4\n" - " x(start:stop+1)%local%nx = -2\n" - " end subroutine sub2\n" - " subroutine sub3(y)\n" - " type(vbig_type), dimension(:) :: y\n" - " y(2)%grids(2)%region%data(:) = 0.0\n" - " end subroutine sub3\n" - "end module test_mod\n" - ) + f"contains\n" + f" subroutine run_it()\n" + f" integer :: i\n" + f" type(my_type) :: var_list(10)\n" + f" type(vbig_type), dimension(5) :: cvar\n" + f" call sub(var_list(:)%local%nx, i)\n" + f" call sub2(var_list(:), 1, 1)\n" + f" call sub2(var_list(:), i, i+2)\n" + f" call sub3(cvar)\n" + f" end subroutine run_it\n" + f" subroutine sub(ix, indx)\n" + f" integer, dimension(:) :: ix\n" + f" integer, intent(in) :: indx\n" + f" ix(:) = ix(:) + 1\n" + f" end subroutine sub\n" + f" subroutine sub2(x, start, stop)\n" + f" type(my_type), dimension(:) :: x\n" + f" integer :: start, stop\n" + f" x(:)%data(2) = 0.0\n" + f" x(:)%local%nx = 4\n" + f" x(start:stop+1)%local%nx = -2\n" + f" end subroutine sub2\n" + f" subroutine sub3(y)\n" + f" type(vbig_type), dimension(:) :: y\n" + f" y(2)%grids(2)%region%data(:) = 0.0\n" + f" end subroutine sub3\n" + f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) @@ -511,32 +474,30 @@ def test_apply_struct_slice_arg(fortran_reader, fortran_writer, tmpdir): assert Compile(tmpdir).string_compiles(output) -def test_apply_struct_local_limits_caller( - fortran_reader, fortran_writer, tmpdir -): - """ +def test_apply_struct_local_limits_caller(fortran_reader, fortran_writer, + tmpdir): + ''' Test the apply() method when there are array bounds specified in the caller. - """ + ''' code = ( - "module test_mod\n" + f"module test_mod\n" f"{MY_TYPE}" - "contains\n" - " subroutine run_it()\n" - " integer :: i\n" - " type(my_type) :: var_list(10)\n" - " call sub2(var_list(3:7), 5, 6)\n" - " end subroutine run_it\n" - " subroutine sub2(x, start, stop)\n" - " type(my_type), dimension(:) :: x\n" - " integer :: start, stop\n" - " x(:)%data(2) = 1.0\n" - " x(:)%local%nx = 3\n" - " x(start:stop+1)%local%nx = -2\n" - " end subroutine sub2\n" - "end module test_mod\n" - ) + f"contains\n" + f" subroutine run_it()\n" + f" integer :: i\n" + f" type(my_type) :: var_list(10)\n" + f" call sub2(var_list(3:7), 5, 6)\n" + f" end subroutine run_it\n" + f" subroutine sub2(x, start, stop)\n" + f" type(my_type), dimension(:) :: x\n" + f" integer :: start, stop\n" + f" x(:)%data(2) = 1.0\n" + f" x(:)%local%nx = 3\n" + f" x(start:stop+1)%local%nx = -2\n" + f" end subroutine sub2\n" + f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) @@ -550,40 +511,38 @@ def test_apply_struct_local_limits_caller( assert Compile(tmpdir).string_compiles(output) -def test_apply_struct_local_limits_caller_decln( - fortran_reader, fortran_writer, tmpdir -): - """ +def test_apply_struct_local_limits_caller_decln(fortran_reader, fortran_writer, + tmpdir): + ''' Test the apply() method when there are non-default array bounds specified in the declaration at the call site. - """ + ''' code = ( - "module test_mod\n" + f"module test_mod\n" f"{MY_TYPE}" - "contains\n" - " subroutine run_it()\n" - " integer :: i\n" - " type(my_type), dimension(2:9) :: varat2\n" - " real, dimension(4:8) :: varat3\n" - " call sub2(varat2(:), 5, 6)\n" - " call sub2(varat2(3:8), 5, 6)\n" - " call sub3(varat3(5:6))\n" - " call sub3(varat3)\n" - " end subroutine run_it\n" - " subroutine sub2(x, start, stop)\n" - " type(my_type), dimension(:) :: x\n" - " integer :: start, stop\n" - " x(:)%data(2) = 1.0\n" - " x(:)%local%nx = 3\n" - " x(start:stop+1)%local%nx = -2\n" - " end subroutine sub2\n" - " subroutine sub3(x)\n" - " real, dimension(:) :: x\n" - " x(1:2) = 4.0\n" - " end subroutine sub3\n" - "end module test_mod\n" - ) + f"contains\n" + f" subroutine run_it()\n" + f" integer :: i\n" + f" type(my_type), dimension(2:9) :: varat2\n" + f" real, dimension(4:8) :: varat3\n" + f" call sub2(varat2(:), 5, 6)\n" + f" call sub2(varat2(3:8), 5, 6)\n" + f" call sub3(varat3(5:6))\n" + f" call sub3(varat3)\n" + f" end subroutine run_it\n" + f" subroutine sub2(x, start, stop)\n" + f" type(my_type), dimension(:) :: x\n" + f" integer :: start, stop\n" + f" x(:)%data(2) = 1.0\n" + f" x(:)%local%nx = 3\n" + f" x(start:stop+1)%local%nx = -2\n" + f" end subroutine sub2\n" + f" subroutine sub3(x)\n" + f" real, dimension(:) :: x\n" + f" x(1:2) = 4.0\n" + f" end subroutine sub3\n" + f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) @@ -606,14 +565,13 @@ def test_apply_struct_local_limits_caller_decln( assert Compile(tmpdir).string_compiles(output) -def test_apply_struct_local_limits_routine( - fortran_reader, fortran_writer, tmpdir -): - """ +def test_apply_struct_local_limits_routine(fortran_reader, fortran_writer, + tmpdir): + ''' Test the apply() method when there are non-default array bounds specified in the declaration within the called routine. - """ + ''' code = ( f"module test_mod\n" f"{MY_TYPE}" @@ -638,8 +596,7 @@ def test_apply_struct_local_limits_routine( f" y(start:stop+1)%local%nx = -3\n" f" z(start+1) = 8.0\n" f" end subroutine sub3\n" - f"end module test_mod\n" - ) + f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) @@ -672,13 +629,13 @@ def test_apply_struct_local_limits_routine( def test_apply_array_limits_are_formal_args(fortran_reader, fortran_writer): - """ + ''' Check that apply() correctly handles the case where the start/stop values of an array formal argument are given in terms of other formal arguments. - """ - code = """ + ''' + code = ''' module test_mod implicit none contains @@ -694,7 +651,7 @@ def test_apply_array_limits_are_formal_args(fortran_reader, fortran_writer): var(start+1) = 5.0 end subroutine end module test_mod -""" +''' psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) @@ -705,12 +662,12 @@ def test_apply_array_limits_are_formal_args(fortran_reader, fortran_writer): def test_apply_allocatable_array_arg(fortran_reader, fortran_writer): - """ + ''' Check that apply() works correctly when a formal argument is given the ALLOCATABLE attribute (meaning that the bounds of the formal argument are those of the actual argument). - """ + ''' code = ( "module test_mod\n" " type my_type\n" @@ -741,7 +698,7 @@ def test_apply_allocatable_array_arg(fortran_reader, fortran_writer): " x(ji+2,jj+1) = -1.0\n" " end subroutine sub1\n" "end module test_mod\n" - ) + ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) @@ -760,11 +717,11 @@ def test_apply_allocatable_array_arg(fortran_reader, fortran_writer): def test_apply_array_slice_arg(fortran_reader, fortran_writer, tmpdir): - """ + ''' Check that the apply() method works correctly when an array slice is passed to a routine and then accessed within it. - """ + ''' code = ( "module test_mod\n" "contains\n" @@ -802,62 +759,57 @@ def test_apply_array_slice_arg(fortran_reader, fortran_writer, tmpdir): " x(i,:) = 2.0 * x(i,:)\n" " end do\n" " end subroutine sub2a\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) for call in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(call) output = fortran_writer(psyir) - assert ( - " do i = 1, 10, 1\n" - " do i_1 = 1, 10, 1\n" - " a(1,i_1,i) = 2.0 * i_1\n" - " enddo\n" - " enddo\n" - " a(1,1,:) = 3.0 * a(1,1,:)\n" - " a(:,1,:) = 2.0 * a(:,1,:)\n" - " b(:,:) = 2.0 * b(:,:)\n" - " do i_4 = 1, 10, 1\n" - " b(i_4,:5) = 2.0 * b(i_4,:5)\n" - in output - ) + assert (" do i = 1, 10, 1\n" + " do i_1 = 1, 10, 1\n" + " a(1,i_1,i) = 2.0 * i_1\n" + " enddo\n" + " enddo\n" + " a(1,1,:) = 3.0 * a(1,1,:)\n" + " a(:,1,:) = 2.0 * a(:,1,:)\n" + " b(:,:) = 2.0 * b(:,:)\n" + " do i_4 = 1, 10, 1\n" + " b(i_4,:5) = 2.0 * b(i_4,:5)\n" in output) assert Compile(tmpdir).string_compiles(output) def test_apply_struct_array_arg(fortran_reader, fortran_writer, tmpdir): - """Check that apply works correctly when the actual argument is an - array element within a structure.""" + '''Check that apply works correctly when the actual argument is an + array element within a structure.''' code = ( - "module test_mod\n" + f"module test_mod\n" f"{MY_TYPE}" - "contains\n" - " subroutine run_it()\n" - " integer :: i, ig\n" - " real :: a(10)\n" - " type(my_type) :: grid\n" - " type(my_type), dimension(5) :: grid_list\n" - " grid%data(:) = 1.0\n" - " do i=1,10\n" - " a(i) = 1.0\n" - " call sub(grid%data(i))\n" - " end do\n" - " do i=1,10\n" - " ig = min(i, 5)\n" - " call sub(grid_list(ig)%data(i))\n" - " end do\n" - " do i=1,10\n" - " ig = min(i, 5)\n" - " call sub(grid_list(ig)%local%data(i))\n" - " end do\n" - " end subroutine run_it\n" - " subroutine sub(x)\n" - " real, intent(inout) :: x\n" - " x = 2.0*x\n" - " end subroutine sub\n" - "end module test_mod\n" - ) + f"contains\n" + f" subroutine run_it()\n" + f" integer :: i, ig\n" + f" real :: a(10)\n" + f" type(my_type) :: grid\n" + f" type(my_type), dimension(5) :: grid_list\n" + f" grid%data(:) = 1.0\n" + f" do i=1,10\n" + f" a(i) = 1.0\n" + f" call sub(grid%data(i))\n" + f" end do\n" + f" do i=1,10\n" + f" ig = min(i, 5)\n" + f" call sub(grid_list(ig)%data(i))\n" + f" end do\n" + f" do i=1,10\n" + f" ig = min(i, 5)\n" + f" call sub(grid_list(ig)%local%data(i))\n" + f" end do\n" + f" end subroutine run_it\n" + f" subroutine sub(x)\n" + f" real, intent(inout) :: x\n" + f" x = 2.0*x\n" + f" end subroutine sub\n" + f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) loops = psyir.walk(Loop) inline_trans = InlineTrans() @@ -866,69 +818,58 @@ def test_apply_struct_array_arg(fortran_reader, fortran_writer, tmpdir): inline_trans.apply(loops[1].loop_body.children[1]) inline_trans.apply(loops[2].loop_body.children[1]) output = fortran_writer(psyir).lower() - assert ( - " do i = 1, 10, 1\n" - " a(i) = 1.0\n" - " grid%data(i) = 2.0 * grid%data(i)\n" - " enddo\n" - in output - ) - assert ( - " do i = 1, 10, 1\n" - " ig = min(i, 5)\n" - " grid_list(ig)%data(i) = 2.0 * grid_list(ig)%data(i)\n" - " enddo\n" - in output - ) - assert ( - " do i = 1, 10, 1\n" - " ig = min(i, 5)\n" - " grid_list(ig)%local%data(i) = 2.0 * " - "grid_list(ig)%local%data(i)\n" - " enddo\n" - in output - ) + assert (" do i = 1, 10, 1\n" + " a(i) = 1.0\n" + " grid%data(i) = 2.0 * grid%data(i)\n" + " enddo\n" in output) + assert (" do i = 1, 10, 1\n" + " ig = min(i, 5)\n" + " grid_list(ig)%data(i) = 2.0 * grid_list(ig)%data(i)\n" + " enddo\n" in output) + assert (" do i = 1, 10, 1\n" + " ig = min(i, 5)\n" + " grid_list(ig)%local%data(i) = 2.0 * " + "grid_list(ig)%local%data(i)\n" + " enddo\n" in output) assert Compile(tmpdir).string_compiles(output) def test_apply_struct_array_slice_arg(fortran_reader, fortran_writer, tmpdir): - """Check that apply works correctly when the actual argument is an - array slice within a structure.""" + '''Check that apply works correctly when the actual argument is an + array slice within a structure.''' code = ( - "module test_mod\n" + f"module test_mod\n" f"{MY_TYPE}" - "contains\n" - " subroutine run_it()\n" - " integer :: i\n" - " real :: a(10)\n" - " type(my_type) :: grid\n" - " type(vbig_type) :: micah\n" - " grid%data(:) = 1.0\n" - " grid%data2d(:,:) = 1.0\n" - " do i=1,10\n" - " a(i) = 1.0\n" - " call sub(micah%grids(3)%region%data(:))\n" - " call sub(grid%data2d(:,i))\n" - " call sub(grid%data2d(1:5,i))\n" - " call sub(grid%local%data)\n" - " end do\n" - " end subroutine run_it\n" - " subroutine sub(x)\n" - " real, dimension(:), intent(inout) :: x\n" - " integer ji\n" - " do ji = 1, 5\n" - " x(ji) = 2.0*x(ji)\n" - " end do\n" - " x(1:2) = 0.0\n" - " x(:) = 3.0\n" - " x = 5.0\n" - " end subroutine sub\n" - "end module test_mod\n" - ) + f"contains\n" + f" subroutine run_it()\n" + f" integer :: i\n" + f" real :: a(10)\n" + f" type(my_type) :: grid\n" + f" type(vbig_type) :: micah\n" + f" grid%data(:) = 1.0\n" + f" grid%data2d(:,:) = 1.0\n" + f" do i=1,10\n" + f" a(i) = 1.0\n" + f" call sub(micah%grids(3)%region%data(:))\n" + f" call sub(grid%data2d(:,i))\n" + f" call sub(grid%data2d(1:5,i))\n" + f" call sub(grid%local%data)\n" + f" end do\n" + f" end subroutine run_it\n" + f" subroutine sub(x)\n" + f" real, dimension(:), intent(inout) :: x\n" + f" integer ji\n" + f" do ji = 1, 5\n" + f" x(ji) = 2.0*x(ji)\n" + f" end do\n" + f" x(1:2) = 0.0\n" + f" x(:) = 3.0\n" + f" x = 5.0\n" + f" end subroutine sub\n" + f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) - inline_trans.set_option(check_matching_arguments_of_callee=False) for call in psyir.walk(Call): if not isinstance(call, IntrinsicCall): if call.arguments[0].debug_string() == "grid%local%data": @@ -937,102 +878,92 @@ def test_apply_struct_array_slice_arg(fortran_reader, fortran_writer, tmpdir): continue inline_trans.apply(call) output = fortran_writer(psyir) - assert ( - " do i = 1, 10, 1\n" - " a(i) = 1.0\n" - " do ji = 1, 5, 1\n" - " micah%grids(3)%region%data(ji) = 2.0 * " - "micah%grids(3)%region%data(ji)\n" - " enddo\n" - " micah%grids(3)%region%data(1:2) = 0.0\n" - " micah%grids(3)%region%data(:) = 3.0\n" - " micah%grids(3)%region%data(:) = 5.0\n" - " do ji_1 = 1, 5, 1\n" - " grid%data2d(ji_1,i) = 2.0 * grid%data2d(ji_1,i)\n" - " enddo\n" - " grid%data2d(1:2,i) = 0.0\n" - " grid%data2d(:,i) = 3.0\n" - " grid%data2d(:,i) = 5.0\n" - " do ji_2 = 1, 5, 1\n" - " grid%data2d(ji_2,i) = 2.0 * grid%data2d(ji_2,i)\n" - " enddo\n" - " grid%data2d(1:2,i) = 0.0\n" - " grid%data2d(1:5,i) = 3.0\n" - " grid%data2d(1:5,i) = 5.0\n" - # TODO #1858: replace the following line with the commented-out - # lines below. - " call sub(grid%local%data)\n" - # " do ji_3 = 1, 5, 1\n" - # " grid%local%data(ji_3) = 2.0 * grid%local%data(ji_3)\n" - # " enddo\n" - # " grid%local%data(1:2) = 0.0\n" - # " grid%local%data(:) = 3.0\n" - # " grid%local%data = 5.0\n" - " enddo\n" in output - ) + assert (" do i = 1, 10, 1\n" + " a(i) = 1.0\n" + " do ji = 1, 5, 1\n" + " micah%grids(3)%region%data(ji) = 2.0 * " + "micah%grids(3)%region%data(ji)\n" + " enddo\n" + " micah%grids(3)%region%data(1:2) = 0.0\n" + " micah%grids(3)%region%data(:) = 3.0\n" + " micah%grids(3)%region%data(:) = 5.0\n" + " do ji_1 = 1, 5, 1\n" + " grid%data2d(ji_1,i) = 2.0 * grid%data2d(ji_1,i)\n" + " enddo\n" + " grid%data2d(1:2,i) = 0.0\n" + " grid%data2d(:,i) = 3.0\n" + " grid%data2d(:,i) = 5.0\n" + " do ji_2 = 1, 5, 1\n" + " grid%data2d(ji_2,i) = 2.0 * grid%data2d(ji_2,i)\n" + " enddo\n" + " grid%data2d(1:2,i) = 0.0\n" + " grid%data2d(1:5,i) = 3.0\n" + " grid%data2d(1:5,i) = 5.0\n" + # TODO #1858: replace the following line with the commented-out + # lines below. + " call sub(grid%local%data)\n" + # " do ji_3 = 1, 5, 1\n" + # " grid%local%data(ji_3) = 2.0 * grid%local%data(ji_3)\n" + # " enddo\n" + # " grid%local%data(1:2) = 0.0\n" + # " grid%local%data(:) = 3.0\n" + # " grid%local%data = 5.0\n" + " enddo\n" in output) assert Compile(tmpdir).string_compiles(output) @pytest.mark.parametrize("type_decln", [MY_TYPE, " use some_mod\n"]) -def test_apply_struct_array( - fortran_reader, fortran_writer, tmpdir, type_decln -): - """Test that apply works correctly when the formal argument is an +def test_apply_struct_array(fortran_reader, fortran_writer, tmpdir, + type_decln): + '''Test that apply works correctly when the formal argument is an array of structures. We test both when the type of the structure is resolved and when it isn't. In the latter case we cannot perform inlining because we don't know the array bounds at the call site. - """ + ''' code = ( - "module test_mod\n" + f"module test_mod\n" f"{type_decln}" - "contains\n" - " subroutine run_it()\n" - " integer :: i\n" - " real :: a(10)\n" - " type(my_type) :: grid\n" - " type(vbig_type) :: micah\n" - " call sub(micah%grids(:))\n" - " end subroutine run_it\n" - " subroutine sub(x)\n" - " type(big_type), dimension(2:4) :: x\n" - " integer ji\n" - " ji = 2\n" - " x(:)%region%idx = 3.0\n" - " x(ji)%region%idx = 2.0\n" - " end subroutine sub\n" - "end module test_mod\n" - ) + f"contains\n" + f" subroutine run_it()\n" + f" integer :: i\n" + f" real :: a(10)\n" + f" type(my_type) :: grid\n" + f" type(vbig_type) :: micah\n" + f" call sub(micah%grids(:))\n" + f" end subroutine run_it\n" + f" subroutine sub(x)\n" + f" type(big_type), dimension(2:4) :: x\n" + f" integer ji\n" + f" ji = 2\n" + f" x(:)%region%idx = 3.0\n" + f" x(ji)%region%idx = 2.0\n" + f" end subroutine sub\n" + f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) if "use some_mod" in type_decln: with pytest.raises(TransformationError) as err: inline_trans.apply(psyir.walk(Call)[0]) - assert ( - "Routine 'sub' cannot be inlined because the type of the " - "actual argument 'micah' corresponding to an array formal " - "argument ('x') is unknown." - in str(err.value) - ) + assert ("Routine 'sub' cannot be inlined because the type of the " + "actual argument 'micah' corresponding to an array formal " + "argument ('x') is unknown." in str(err.value)) else: inline_trans.apply(psyir.walk(Call)[0]) output = fortran_writer(psyir) - assert ( - " ji = 2\n" - " micah%grids(2 - 2 + 1:4 - 2 + 1)%region%idx = 3.0\n" - " micah%grids(ji - 2 + 1)%region%idx = 2.0\n" - in output - ) + assert (" ji = 2\n" + " micah%grids(2 - 2 + 1:4 - 2 + 1)%region%idx = 3.0\n" + " micah%grids(ji - 2 + 1)%region%idx = 2.0\n" in output) assert Compile(tmpdir).string_compiles(output) def test_apply_repeated_module_use(fortran_reader, fortran_writer): - """ + ''' Check that any module use statements are not duplicated when multiple calls are inlined. - """ + ''' code = ( "module test_mod\n" "contains\n" @@ -1054,8 +985,7 @@ def test_apply_repeated_module_use(fortran_reader, fortran_writer): " real, intent(inout), dimension(10) :: x\n" " x(:) = 4*radius\n" " end subroutine sub2\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) @@ -1064,24 +994,18 @@ def test_apply_repeated_module_use(fortran_reader, fortran_writer): output = fortran_writer(psyir) # Check container symbol has not been renamed. assert "use model_mod_1" not in output - assert ( - " subroutine run_it()\n" - " use model_mod, only : radius\n" - " integer :: i\n" - in output - ) - assert ( - " do i = 1, 10, 1\n" - " a(:,i) = 4 * radius\n" - " enddo\n" - " b(:,2) = radius\n" - in output - ) + assert (" subroutine run_it()\n" + " use model_mod, only : radius\n" + " integer :: i\n" in output) + assert (" do i = 1, 10, 1\n" + " a(:,i) = 4 * radius\n" + " enddo\n" + " b(:,2) = radius\n" in output) def test_apply_name_clash(fortran_reader, fortran_writer, tmpdir): - """Check that apply() correctly handles the case where a symbol - in the routine to be in-lined clashes with an existing symbol.""" + ''' Check that apply() correctly handles the case where a symbol + in the routine to be in-lined clashes with an existing symbol. ''' code = ( "module test_mod\n" "contains\n" @@ -1098,23 +1022,22 @@ def test_apply_name_clash(fortran_reader, fortran_writer, tmpdir): " i = 3.0\n" " x = 2.0*x + i\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) - assert ( - " i = 10\n y = 1.0\n i_1 = 3.0\n y = 2.0 * y + i_1\n" - in output - ) + assert (" i = 10\n" + " y = 1.0\n" + " i_1 = 3.0\n" + " y = 2.0 * y + i_1\n" in output) assert Compile(tmpdir).string_compiles(output) def test_apply_imported_symbols(fortran_reader, fortran_writer): - """Test that the apply method correctly handles imported symbols in the - routine being inlined.""" + '''Test that the apply method correctly handles imported symbols in the + routine being inlined. ''' code = ( "module test_mod\n" "contains\n" @@ -1128,27 +1051,23 @@ def test_apply_imported_symbols(fortran_reader, fortran_writer): " integer, intent(inout) :: idx\n" " idx = 3*var2\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) - assert ( - " subroutine run_it()\n" - " use some_mod, only : var2\n" - " integer :: i\n\n" - " i = 10\n" - " i = 3 * var2\n" - in output - ) + assert (" subroutine run_it()\n" + " use some_mod, only : var2\n" + " integer :: i\n\n" + " i = 10\n" + " i = 3 * var2\n" in output) # We can't check this with compilation because of the import of some_mod. def test_apply_last_stmt_is_return(fortran_reader, fortran_writer, tmpdir): - """Test that the apply method correctly omits any final 'return' - statement that may be present in the routine to be inlined.""" + '''Test that the apply method correctly omits any final 'return' + statement that may be present in the routine to be inlined.''' code = ( "module test_mod\n" "contains\n" @@ -1162,20 +1081,21 @@ def test_apply_last_stmt_is_return(fortran_reader, fortran_writer, tmpdir): " idx = idx + 3\n" " return\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) - assert " i = 10\n i = i + 3\n\n end subroutine run_it\n" in output + assert (" i = 10\n" + " i = i + 3\n\n" + " end subroutine run_it\n" in output) assert Compile(tmpdir).string_compiles(output) def test_apply_call_args(fortran_reader, fortran_writer): - """Check that apply works correctly if any of the actual - arguments are not simple references.""" + '''Check that apply works correctly if any of the actual + arguments are not simple references.''' code = ( "module test_mod\n" " use kinds_mod, only: i_def\n" @@ -1191,24 +1111,22 @@ def test_apply_call_args(fortran_reader, fortran_writer): " integer(kind=i_def), intent(in) :: incr2\n" " idx = idx + incr1 * incr2\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(call) output = fortran_writer(psyir) - assert ( - " i = 10\n i = i + 2 * i * 5_i_def\n\n end subroutine run_it\n" - in output - ) + assert (" i = 10\n" + " i = i + 2 * i * 5_i_def\n\n" + " end subroutine run_it\n" in output) # Cannot test for compilation because of 'kinds_mod'. def test_apply_duplicate_imports(fortran_reader, fortran_writer): - """Check that apply works correctly when the routine to be inlined + '''Check that apply works correctly when the routine to be inlined imports symbols from a container that is also accessed in the - calling routine.""" + calling routine.''' code = ( "module test_mod\n" "contains\n" @@ -1223,29 +1141,24 @@ def test_apply_duplicate_imports(fortran_reader, fortran_writer): " integer, intent(inout) :: idx\n" " idx = idx + 5_i_def\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(call) output = fortran_writer(psyir) - assert ( - " subroutine run_it()\n" - " use kinds_mod, only : i_def\n" - " integer :: i\n\n" - in output - ) - assert ( - " i = 10_i_def\n i = i + 5_i_def\n\n end subroutine run_it\n" - in output - ) + assert (" subroutine run_it()\n" + " use kinds_mod, only : i_def\n" + " integer :: i\n\n" in output) + assert (" i = 10_i_def\n" + " i = i + 5_i_def\n\n" + " end subroutine run_it\n" in output) # Cannot test for compilation because of 'kinds_mod'. def test_apply_wildcard_import(fortran_reader, fortran_writer): - """Check that apply works correctly when a wildcard import is present - in the routine to be inlined.""" + '''Check that apply works correctly when a wildcard import is present + in the routine to be inlined.''' code = ( "module test_mod\n" "contains\n" @@ -1260,24 +1173,22 @@ def test_apply_wildcard_import(fortran_reader, fortran_writer): " integer, intent(inout) :: idx\n" " idx = idx + 5\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(call) output = fortran_writer(psyir) - assert ( - " subroutine run_it()\n use kinds_mod\n integer :: i\n\n" - in output - ) + assert (" subroutine run_it()\n" + " use kinds_mod\n" + " integer :: i\n\n" in output) # Cannot test for compilation because of 'kinds_mod'. def test_apply_import_union(fortran_reader, fortran_writer): - """Test that the apply method works correctly when the set of symbols + '''Test that the apply method works correctly when the set of symbols imported from a given container is not the same as that imported into - the scope of the call site.""" + the scope of the call site.''' code = ( "module test_mod\n" "contains\n" @@ -1292,26 +1203,23 @@ def test_apply_import_union(fortran_reader, fortran_writer): " integer, intent(inout) :: idx\n" " idx = idx + 5_i_def\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(call) output = fortran_writer(psyir) - assert ( - " subroutine run_it()\n" - " use kinds_mod, only : i_def, r_def\n" - " integer :: i\n\n" - in output - ) - assert " i = 10.0_r_def\n i = i + 5_i_def\n" in output + assert (" subroutine run_it()\n" + " use kinds_mod, only : i_def, r_def\n" + " integer :: i\n\n" in output) + assert (" i = 10.0_r_def\n" + " i = i + 5_i_def\n" in output) # Cannot test for compilation because of 'kinds_mod'. def test_apply_callsite_rename(fortran_reader, fortran_writer): - """Check that a symbol import in the routine causes a - rename of a symbol that is local to the *calling* scope.""" + '''Check that a symbol import in the routine causes a + rename of a symbol that is local to the *calling* scope.''' code = ( "module test_mod\n" "contains\n" @@ -1329,30 +1237,26 @@ def test_apply_callsite_rename(fortran_reader, fortran_writer): " integer, intent(inout) :: idx\n" " idx = idx + 5_i_def + a_clash\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(call) output = fortran_writer(psyir) - assert ( - " subroutine run_it()\n" - " use kinds_mod, only : i_def, r_def\n" - " use a_mod, only : a_clash\n" - " integer :: i\n" - " integer :: a_clash_1\n\n" - " a_clash_1 = 2\n" - " i = 10.0_r_def\n" - " i = i + 5_i_def + a_clash\n" - " i = i * a_clash_1\n" - in output - ) + assert (" subroutine run_it()\n" + " use kinds_mod, only : i_def, r_def\n" + " use a_mod, only : a_clash\n" + " integer :: i\n" + " integer :: a_clash_1\n\n" + " a_clash_1 = 2\n" + " i = 10.0_r_def\n" + " i = i + 5_i_def + a_clash\n" + " i = i * a_clash_1\n" in output) def test_apply_callsite_rename_container(fortran_reader, fortran_writer): - """Check that an import from a container in the routine causes a - rename of a symbol that is local to the *calling* scope.""" + '''Check that an import from a container in the routine causes a + rename of a symbol that is local to the *calling* scope.''' code = ( "module test_mod\n" "contains\n" @@ -1370,30 +1274,26 @@ def test_apply_callsite_rename_container(fortran_reader, fortran_writer): " integer, intent(inout) :: idx\n" " idx = idx + 5_i_def + a_clash\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(call) output = fortran_writer(psyir) - assert ( - " subroutine run_it()\n" - " use kinds_mod, only : i_def, r_def\n" - " use a_mod, only : a_clash\n" - " integer :: i\n" - " integer :: a_mod_1\n\n" - " a_mod_1 = 2\n" - " i = 10.0_r_def\n" - " i = i + 5_i_def + a_clash\n" - " i = i * a_mod_1\n" - in output - ) + assert (" subroutine run_it()\n" + " use kinds_mod, only : i_def, r_def\n" + " use a_mod, only : a_clash\n" + " integer :: i\n" + " integer :: a_mod_1\n\n" + " a_mod_1 = 2\n" + " i = 10.0_r_def\n" + " i = i + 5_i_def + a_clash\n" + " i = i * a_mod_1\n" in output) def test_validate_non_local_import(fortran_reader): - """Test that we reject the case where the routine to be - inlined accesses a symbol from an import in its parent container.""" + '''Test that we reject the case where the routine to be + inlined accesses a symbol from an import in its parent container.''' code = ( "module test_mod\n" " use some_mod, only: trouble\n" @@ -1407,26 +1307,22 @@ def test_validate_non_local_import(fortran_reader): " integer :: idx\n" " idx = idx + trouble\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ( - "Routine 'sub' cannot be inlined because it accesses variable " - "'trouble' from its parent container." - in str(err.value) - ) + assert ("Routine 'sub' cannot be inlined because it accesses variable " + "'trouble' from its parent container." in str(err.value)) def test_apply_shared_routine_call(fortran_reader): - """ + ''' Test the inlining of a routine that itself calls another routine that is also called from within the scope of the call site. - """ - code = """\ + ''' + code = '''\ module my_mod implicit none contains @@ -1439,7 +1335,7 @@ def test_apply_shared_routine_call(fortran_reader): use slartibartfast, only: norway call norway() end subroutine fijord - end module my_mod""" + end module my_mod''' psyir = fortran_reader.psyir_from_source(code) calls = psyir.walk(Call) inline_trans = InlineTrans() @@ -1453,16 +1349,15 @@ def test_apply_shared_routine_call(fortran_reader): nsym = routines[0].symbol_table.lookup("norway") for call in calls: if call.routine is not nsym: - pytest.xfail( - "#924 cannot reliably update references in inlined code." - ) + pytest.xfail("#924 cannot reliably update references in inlined " + "code.") def test_apply_function(fortran_reader, fortran_writer, tmpdir): - """Check that the apply() method works correctly for a simple call to + '''Check that the apply() method works correctly for a simple call to a function. - """ + ''' code = ( "module test_mod\n" "contains\n" @@ -1474,8 +1369,7 @@ def test_apply_function(fortran_reader, fortran_writer, tmpdir): " real :: b\n" " func = 2.0\n" " end function\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() @@ -1487,41 +1381,34 @@ def test_apply_function(fortran_reader, fortran_writer, tmpdir): " real :: b\n" " real :: inlined_func\n\n" " inlined_func = 2.0\n" - " a = inlined_func" - ) + " a = inlined_func") assert expected in output assert Compile(tmpdir).string_compiles(output) # Try two different forms of function declaration. -@pytest.mark.parametrize( - "function_header", - [ - " function func(b) result(x)\n real :: x\n", - " real function func(b) result(x)\n", - ], -) +@pytest.mark.parametrize("function_header", [ + " function func(b) result(x)\n real :: x\n", + " real function func(b) result(x)\n"]) def test_apply_function_declare_name( - fortran_reader, fortran_writer, tmpdir, function_header -): - """Check that the apply() method works correctly for a simple call to + fortran_reader, fortran_writer, tmpdir, function_header): + '''Check that the apply() method works correctly for a simple call to a function where the name of the return name differs from the function name. - """ + ''' code = ( - "module test_mod\n" - "contains\n" - " subroutine run_it()\n" - " real :: a,b\n" - " a = func(b)\n" - " end subroutine run_it\n" + f"module test_mod\n" + f"contains\n" + f" subroutine run_it()\n" + f" real :: a,b\n" + f" a = func(b)\n" + f" end subroutine run_it\n" f"{function_header}" - " real :: b\n" - " x = 2.0\n" - " end function\n" - "end module test_mod\n" - ) + f" real :: b\n" + f" x = 2.0\n" + f" end function\n" + f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() @@ -1534,17 +1421,16 @@ def test_apply_function_declare_name( " real :: b\n" " real :: inlined_x\n\n" " inlined_x = 2.0\n" - " a = inlined_x" - ) + " a = inlined_x") assert expected in output assert Compile(tmpdir).string_compiles(output) def test_apply_function_expression(fortran_reader, fortran_writer, tmpdir): - """Check that the apply() method works correctly for a call to a + '''Check that the apply() method works correctly for a call to a function that is within an expression. - """ + ''' code = ( "module test_mod\n" "contains\n" @@ -1557,8 +1443,7 @@ def test_apply_function_expression(fortran_reader, fortran_writer, tmpdir): " b = b + 3.0\n" " x = b * 2.0\n" " end function\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() @@ -1569,17 +1454,15 @@ def test_apply_function_expression(fortran_reader, fortran_writer, tmpdir): " real :: inlined_x\n\n" " b = b + 3.0\n" " inlined_x = b * 2.0\n" - " a = (a * inlined_x + 2.0) / a\n" - in output - ) + " a = (a * inlined_x + 2.0) / a\n" in output) assert Compile(tmpdir).string_compiles(output) def test_apply_multi_function(fortran_reader, fortran_writer, tmpdir): - """Check that the apply() method works correctly when a function is + '''Check that the apply() method works correctly when a function is called twice but only one of these function calls is inlined. - """ + ''' code = ( "module test_mod\n" "contains\n" @@ -1592,8 +1475,7 @@ def test_apply_multi_function(fortran_reader, fortran_writer, tmpdir): " real :: b\n" " func = 2.0\n" " end function\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() @@ -1607,8 +1489,7 @@ def test_apply_multi_function(fortran_reader, fortran_writer, tmpdir): " real :: inlined_func\n\n" " inlined_func = 2.0\n" " a = inlined_func\n" - " c = func(a)" - ) + " c = func(a)") assert expected in output assert Compile(tmpdir).string_compiles(output) @@ -1622,45 +1503,35 @@ def test_apply_multi_function(fortran_reader, fortran_writer, tmpdir): " inlined_func = 2.0\n" " a = inlined_func\n" " inlined_func_1 = 2.0\n" - " c = inlined_func_1" - ) + " c = inlined_func_1") assert expected in output -@pytest.mark.parametrize( - "start, end, indent", - [ - ("", "", ""), - ("module test_mod\ncontains\n", "end module test_mod\n", " "), - ( - "module test_mod\nuse formal\ncontains\n", - "end module test_mod\n", - " ", - ), - ], -) +@pytest.mark.parametrize("start, end, indent", [ + ("", "", ""), + ("module test_mod\ncontains\n", "end module test_mod\n", " "), + ("module test_mod\nuse formal\ncontains\n", "end module test_mod\n", + " ")]) def test_apply_raw_subroutine( - fortran_reader, fortran_writer, tmpdir, start, end, indent -): - """Test the apply method works correctly when the routine to be + fortran_reader, fortran_writer, tmpdir, start, end, indent): + '''Test the apply method works correctly when the routine to be inlined is a raw subroutine and is called directly from another raw subroutine, a subroutine within a module but without a use statement and a subroutine within a module with a wildcard use statement. - """ + ''' code = ( f"{start}" - " subroutine run_it()\n" - " real :: a\n" - " call sub(a)\n" - " end subroutine run_it\n" + f" subroutine run_it()\n" + f" real :: a\n" + f" call sub(a)\n" + f" end subroutine run_it\n" f"{end}" - "subroutine sub(x)\n" - " real, intent(inout) :: x\n" - " x = 2.0*x\n" - "end subroutine sub\n" - ) + f"subroutine sub(x)\n" + f" real, intent(inout) :: x\n" + f" x = 2.0*x\n" + f"end subroutine sub\n") psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() @@ -1670,8 +1541,7 @@ def test_apply_raw_subroutine( f"{indent}subroutine run_it()\n" f"{indent} real :: a\n\n" f"{indent} a = 2.0 * a\n\n" - f"{indent}end subroutine run_it\n" - ) + f"{indent}end subroutine run_it\n") assert expected in output if "use formal" not in output: # Compilation will not work with "use formal" as there is no @@ -1679,93 +1549,78 @@ def test_apply_raw_subroutine( assert Compile(tmpdir).string_compiles(output) -@pytest.mark.parametrize( - "use1, use2", - [ - ("use inline_mod, only : sub\n", ""), - ("use inline_mod\n", ""), - ("", "use inline_mod, only : sub\n"), - ("", "use inline_mod\n"), - ], -) +@pytest.mark.parametrize("use1, use2", [ + ("use inline_mod, only : sub\n", ""), ("use inline_mod\n", ""), + ("", "use inline_mod, only : sub\n"), ("", "use inline_mod\n")]) def test_apply_container_subroutine( - fortran_reader, fortran_writer, tmpdir, use1, use2 -): - """Test the apply method works correctly when the routine to be + fortran_reader, fortran_writer, tmpdir, use1, use2): + '''Test the apply method works correctly when the routine to be inlined is in a different container and is within a module (so a use statement is required). - """ + ''' code = ( - "module inline_mod\n" - "contains\n" - " subroutine sub(x)\n" - " real, intent(inout) :: x\n" - " x = 2.0*x\n" - " end subroutine sub\n" - "end module inline_mod\n" - "module test_mod\n" + f"module inline_mod\n" + f"contains\n" + f" subroutine sub(x)\n" + f" real, intent(inout) :: x\n" + f" x = 2.0*x\n" + f" end subroutine sub\n" + f"end module inline_mod\n" + f"module test_mod\n" f"{use1}" - "contains\n" - " subroutine run_it()\n" + f"contains\n" + f" subroutine run_it()\n" f" {use2}" - " real :: a\n" - " call sub(a)\n" - " end subroutine run_it\n" - "end module test_mod\n" - ) + f" real :: a\n" + f" call sub(a)\n" + f" end subroutine run_it\n" + f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) output = fortran_writer(psyir) assert ( - " real :: a\n\n a = 2.0 * a\n\n end subroutine run_it" in output - ) + " real :: a\n\n" + " a = 2.0 * a\n\n" + " end subroutine run_it" in output) assert Compile(tmpdir).string_compiles(output) def test_apply_validate(): - """Test the apply method calls the validate method.""" + '''Test the apply method calls the validate method.''' inline_trans = InlineTrans() with pytest.raises(TransformationError) as info: inline_trans.apply(None) - assert ( - "The target of the InlineTrans transformation should be " - "a Call but found 'NoneType'." - in str(info.value) - ) + assert ("The target of the InlineTrans transformation should be " + "a Call but found 'NoneType'." in str(info.value)) # validate - def test_validate_node(): - """Test the expected exception is raised if an invalid node is - supplied to the transformation.""" + ''' Test the expected exception is raised if an invalid node is + supplied to the transformation. ''' inline_trans = InlineTrans() with pytest.raises(TransformationError) as info: inline_trans.validate(None) - assert ( - "The target of the InlineTrans transformation should be " - "a Call but found 'NoneType'." - in str(info.value) - ) - call = IntrinsicCall.create( - IntrinsicCall.Intrinsic.ALLOCATE, - [Reference(DataSymbol("array", UnresolvedType()))], - ) + assert ("The target of the InlineTrans transformation should be " + "a Call but found 'NoneType'." in str(info.value)) + call = IntrinsicCall.create(IntrinsicCall.Intrinsic.ALLOCATE, + [Reference(DataSymbol("array", + UnresolvedType()))]) with pytest.raises(TransformationError) as info: inline_trans.validate(call) assert "Cannot inline an IntrinsicCall ('ALLOCATE')" in str(info.value) def test_validate_calls_find_routine(fortran_reader): - """Test that validate() calls the _find_routine method. Use an example + '''Test that validate() calls the _find_routine method. Use an example where an exception is raised as the source of the routine to be inlined cannot be found. - """ + ''' code = ( "module test_mod\n" " use some_mod\n" @@ -1775,14 +1630,12 @@ def test_validate_calls_find_routine(fortran_reader): " i = 10\n" " call sub(i)\n" " end subroutine run_it\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - print assert ( "Cannot inline routine 'sub' because its source cannot be found:\n" "Failed to find the source code of the unresolved routine 'sub' - " @@ -1794,9 +1647,9 @@ def test_validate_calls_find_routine(fortran_reader): def test_validate_return_stmt(fortran_reader): - """Test that validate() raises the expected error if the target routine + '''Test that validate() raises the expected error if the target routine contains one or more Returns which that aren't either the very first - statement or very last statement.""" + statement or very last statement.''' code = ( "module test_mod\n" "contains\n" @@ -1811,24 +1664,20 @@ def test_validate_return_stmt(fortran_reader): " return\n" " idx = idx + 3\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ( - "Routine 'sub' contains one or more Return statements and " - "therefore cannot be inlined" - in str(err.value) - ) + assert ("Routine 'sub' contains one or more Return statements and " + "therefore cannot be inlined" in str(err.value)) def test_validate_codeblock(fortran_reader): - """Test that validate() raises the expected error for a routine that + '''Test that validate() raises the expected error for a routine that contains a CodeBlock. Also test that using the "force" option overrides - this check.""" + this check.''' code = ( "module test_mod\n" "contains\n" @@ -1841,8 +1690,7 @@ def test_validate_codeblock(fortran_reader): " integer :: idx\n" " write(*,*) idx\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() @@ -1858,10 +1706,10 @@ def test_validate_codeblock(fortran_reader): def test_validate_unsupportedtype_argument(fortran_reader): - """ + ''' Test that validate rejects a subroutine with arguments of UnsupportedType. - """ + ''' code = ( "module test_mod\n" "contains\n" @@ -1882,7 +1730,6 @@ def test_validate_unsupportedtype_argument(fortran_reader): inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(routine) - print(err.value) assert ( "Found routines, but no routine with matching arguments found for" " 'sub'" @@ -1898,11 +1745,11 @@ def test_validate_unsupportedtype_argument(fortran_reader): def test_validate_unknowninterface(fortran_reader, fortran_writer, tmpdir): - """ + ''' Test that validate rejects a subroutine containing variables with UnknownInterface. - """ + ''' code = ( "module test_mod\n" "contains\n" @@ -1930,9 +1777,7 @@ def test_validate_unknowninterface(fortran_reader, fortran_writer, tmpdir): xvar = psyir.walk(Routine)[1].symbol_table.lookup("x") xvar.interface = AutomaticInterface() inline_trans.apply(routine) - assert ( - fortran_writer(psyir.walk(Routine)[0]) - == """\ + assert fortran_writer(psyir.walk(Routine)[0]) == """\ subroutine main() REAL, POINTER :: x @@ -1940,15 +1785,14 @@ def test_validate_unknowninterface(fortran_reader, fortran_writer, tmpdir): end subroutine main """ - ) assert Compile(tmpdir).string_compiles(fortran_writer(psyir)) def test_validate_static_var(fortran_reader): - """ + ''' Test that validate rejects a subroutine with StaticInterface variables. - """ + ''' code = ( "module test_mod\n" "contains\n" @@ -1962,8 +1806,7 @@ def test_validate_static_var(fortran_reader): " state = state + x\n" " x = 2.0*x + state\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() @@ -1976,29 +1819,27 @@ def test_validate_static_var(fortran_reader): ) -@pytest.mark.parametrize( - "code_body", - ["idx = idx + 5_i_def", "real, parameter :: pi = 3_wp\nidx = idx + 1\n"], -) +@pytest.mark.parametrize("code_body", ["idx = idx + 5_i_def", + "real, parameter :: pi = 3_wp\n" + "idx = idx + 1\n"]) def test_validate_unresolved_precision_sym(fortran_reader, code_body): - """Test that a routine that uses an unresolved precision symbol is + '''Test that a routine that uses an unresolved precision symbol is rejected. We test when the precision symbol appears in an executable - statement and when it appears in a constant initialisation.""" + statement and when it appears in a constant initialisation.''' code = ( - "module test_mod\n" - " use kinds_mod\n" - "contains\n" - " subroutine run_it()\n" - " integer :: i\n" - " i = 10_i_def\n" - " call sub(i)\n" - " end subroutine run_it\n" - " subroutine sub(idx)\n" - " integer, intent(inout) :: idx\n" + f"module test_mod\n" + f" use kinds_mod\n" + f"contains\n" + f" subroutine run_it()\n" + f" integer :: i\n" + f" i = 10_i_def\n" + f" call sub(i)\n" + f" end subroutine run_it\n" + f" subroutine sub(idx)\n" + f" integer, intent(inout) :: idx\n" f" {code_body}\n" - " end subroutine sub\n" - "end module test_mod\n" - ) + f" end subroutine sub\n" + f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() call = psyir.walk(Call)[0] @@ -2016,9 +1857,10 @@ def test_validate_unresolved_precision_sym(fortran_reader, code_body): ) -def test_validate_resolved_precision_sym(fortran_reader, monkeypatch, tmpdir): - """Test that a routine that uses a resolved precision symbol from its - parent Container is rejected.""" +def test_validate_resolved_precision_sym(fortran_reader, monkeypatch, + tmpdir): + '''Test that a routine that uses a resolved precision symbol from its + parent Container is rejected.''' code = ( "module test_mod\n" " use kinds_mod\n" @@ -2038,40 +1880,34 @@ def test_validate_resolved_precision_sym(fortran_reader, monkeypatch, tmpdir): " integer, intent(inout) :: idx\n" " idx = idx + 5_i_def\n" " end subroutine sub2\n" - "end module test_mod\n" - ) + "end module test_mod\n") # Set up include_path to import the proper module - monkeypatch.setattr(Config.get(), "_include_paths", [str(tmpdir)]) + monkeypatch.setattr(Config.get(), '_include_paths', [str(tmpdir)]) filename = os.path.join(str(tmpdir), "kinds_mod.f90") - with open(filename, "w", encoding="UTF-8") as module: - module.write( - """ + with open(filename, "w", encoding='UTF-8') as module: + module.write(''' module kinds_mod integer, parameter :: i_def = kind(1) end module kinds_mod - """ - ) + ''') psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() # First subroutine accesses i_def from parent Container. calls = psyir.walk(Call) with pytest.raises(TransformationError) as err: inline_trans.validate(calls[0]) - assert ( - "Routine 'sub' cannot be inlined because it accesses variable " - "'i_def' and this cannot be found in any of the containers " - "directly imported into its symbol table." - in str(err.value) - ) + assert ("Routine 'sub' cannot be inlined because it accesses variable " + "'i_def' and this cannot be found in any of the containers " + "directly imported into its symbol table." in str(err.value)) # Second subroutine imports i_def directly into its own SymbolTable and # so is OK to inline. inline_trans.validate(calls[1]) def test_validate_import_clash(fortran_reader): - """Test that validate() raises the expected error when two symbols of the + '''Test that validate() raises the expected error when two symbols of the same name are imported from different containers at the call site and - within the routine.""" + within the routine.''' code = ( "module test_mod\n" "contains\n" @@ -2086,23 +1922,19 @@ def test_validate_import_clash(fortran_reader): " integer :: idx\n" " idx = idx + trouble\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ( - "One or more symbols from routine 'sub' cannot be added to the " - "table at the call site." - in str(err.value) - ) + assert ("One or more symbols from routine 'sub' cannot be added to the " + "table at the call site." in str(err.value)) def test_validate_non_local_symbol(fortran_reader): - """Test that validate() raises the expected error when the routine to be - inlined accesses a symbol from its parent container.""" + '''Test that validate() raises the expected error when the routine to be + inlined accesses a symbol from its parent container.''' code = ( "module test_mod\n" " integer :: trouble\n" @@ -2116,24 +1948,20 @@ def test_validate_non_local_symbol(fortran_reader): " integer :: idx\n" " idx = idx + trouble\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ( - "Routine 'sub' cannot be inlined because it accesses variable " - "'trouble' from its parent container" - in str(err.value) - ) + assert ("Routine 'sub' cannot be inlined because it accesses variable " + "'trouble' from its parent container" in str(err.value)) def test_validate_wrong_number_args(fortran_reader): - """Test that validate rejects inlining routines with different number + ''' Test that validate rejects inlining routines with different number of arguments. - """ + ''' code = ( "module test_mod\n" " integer :: trouble\n" @@ -2147,8 +1975,7 @@ def test_validate_wrong_number_args(fortran_reader): " integer :: idx\n" " idx = idx + 1\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() @@ -2168,8 +1995,8 @@ def test_validate_wrong_number_args(fortran_reader): def test_validate_unresolved_import(fortran_reader): - """Test that validate rejects a routine that accesses a symbol which - is unresolved.""" + '''Test that validate rejects a routine that accesses a symbol which + is unresolved.''' code = ( "module test_mod\n" " use some_mod\n" @@ -2183,27 +2010,23 @@ def test_validate_unresolved_import(fortran_reader): " integer :: idx\n" " idx = idx + trouble\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ( - "Routine 'sub' cannot be inlined because it accesses variable " - "'trouble' and this cannot be found in any of the containers " - "directly imported into its symbol table." - in str(err.value) - ) + assert ("Routine 'sub' cannot be inlined because it accesses variable " + "'trouble' and this cannot be found in any of the containers " + "directly imported into its symbol table." in str(err.value)) def test_validate_unresolved_array_dim(fortran_reader): - """ + ''' Check that validate rejects a routine if it uses an unresolved Symbol when defining an array dimension. - """ + ''' code = ( "module test_mod\n" " use some_mod\n" @@ -2218,25 +2041,21 @@ def test_validate_unresolved_array_dim(fortran_reader): " integer, dimension(some_size) :: var\n" " idx = idx + 2\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ( - "Routine 'sub' cannot be inlined because it accesses variable " - "'some_size' and this cannot be found in any of the containers " - "directly imported into its symbol table" - in str(err.value) - ) + assert ("Routine 'sub' cannot be inlined because it accesses variable " + "'some_size' and this cannot be found in any of the containers " + "directly imported into its symbol table" in str(err.value)) def test_validate_array_reshape(fortran_reader): - """Test that the validate method rejects an attempt to inline a routine + '''Test that the validate method rejects an attempt to inline a routine if any of its formal arguments are declared to be a different shape from - those at the call site.""" + those at the call site.''' code = ( "module test_mod\n" "contains\n" @@ -2252,28 +2071,24 @@ def test_validate_array_reshape(fortran_reader): " x(i) = x(i) + m\n" " enddo\n" "end subroutine\n" - "end module\n" - ) + "end module\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ( - "Cannot inline routine 's' because it reshapes an argument: actual" - " argument 'a(:,:)' has rank 2 but the corresponding formal " - "argument, 'x', has rank 1" - in str(err.value) - ) + assert ("Cannot inline routine 's' because it reshapes an argument: actual" + " argument 'a(:,:)' has rank 2 but the corresponding formal " + "argument, 'x', has rank 1" in str(err.value)) def test_validate_array_arg_expression(fortran_reader): - """ + ''' Check that validate rejects a call if an argument corresponding to a formal array argument is not a simple Reference or Literal. - """ + ''' code = ( "module test_mod\n" "contains\n" @@ -2289,25 +2104,21 @@ def test_validate_array_arg_expression(fortran_reader): " x(i) = x(i) + m\n" " enddo\n" "end subroutine\n" - "end module\n" - ) + "end module\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ( - "The call 'call s(a + b, 10)\n' cannot be inlined because actual " - "argument 'a + b' corresponds to a formal argument with array " - "type but is not a Reference or a Literal" - in str(err.value) - ) + assert ("The call 'call s(a + b, 10)\n' cannot be inlined because actual " + "argument 'a + b' corresponds to a formal argument with array " + "type but is not a Reference or a Literal" in str(err.value)) def test_validate_indirect_range(fortran_reader): - """Test that validate rejects an attempt to inline a call to a routine - with an argument constructed using an indirect slice.""" + '''Test that validate rejects an attempt to inline a call to a routine + with an argument constructed using an indirect slice.''' code = ( "module test_mod\n" " integer, dimension(10) :: indices\n" @@ -2320,24 +2131,20 @@ def test_validate_indirect_range(fortran_reader): " real, dimension(:), intent(inout) :: x\n" " x(:) = 0.0\n" "end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ( - "Cannot inline routine 'sub' because argument 'var(indices(:))' " - "has an array range in an indirect access" - in str(err.value) - ) + assert ("Cannot inline routine 'sub' because argument 'var(indices(:))' " + "has an array range in an indirect access" in str(err.value)) def test_validate_non_unit_stride_slice(fortran_reader): - """Test that validate rejects an attempt to inline a call to a routine - with an argument constructed using an array slice with non-unit stride.""" + '''Test that validate rejects an attempt to inline a call to a routine + with an argument constructed using an array slice with non-unit stride.''' code = ( "module test_mod\n" "contains\n" @@ -2349,25 +2156,21 @@ def test_validate_non_unit_stride_slice(fortran_reader): " real, dimension(:), intent(inout) :: x\n" " x(:) = 0.0\n" "end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.set_option(check_matching_arguments_of_callee=False) with pytest.raises(TransformationError) as err: inline_trans.validate(call) - print(err.value) - assert ( - "Cannot inline routine 'sub' because one of its arguments is an " - "array slice with a non-unit stride: 'var(::2)' (TODO #1646)" - in str(err.value) - ) + assert ("Cannot inline routine 'sub' because one of its arguments is an " + "array slice with a non-unit stride: 'var(::2)' (TODO #1646)" in + str(err.value)) def test_apply_named_arg(fortran_reader): - """Test that the validate method inlines a routine that has a named - argument.""" + '''Test that the validate method inlines a routine that has a named + argument.''' code = ( "module test_mod\n" @@ -2391,8 +2194,8 @@ def test_apply_named_arg(fortran_reader): def test_validate_optional_arg(fortran_reader): - """Test that the validate method inlines a routine - that has an optional argument.""" + '''Test that the validate method inlines a routine + that has an optional argument.''' code = ( "module test_mod\n" @@ -2418,8 +2221,8 @@ def test_validate_optional_arg(fortran_reader): def test_validate_optional_and_named_arg(fortran_reader): - """Test that the validate method inlines a routine - that has an optional argument.""" + '''Test that the validate method inlines a routine + that has an optional argument.''' code = ( "module test_mod\n" "contains\n" @@ -2458,16 +2261,16 @@ def test_validate_optional_and_named_arg(fortran_reader): inline_trans.apply(call) assert ( - """var = var + 1.0 + 1.0 + '''var = var + 1.0 + 1.0 var = var + 2.0 - var = var + 1.0 + 1.0""" + var = var + 1.0 + 1.0''' in routine_main.debug_string() ) def test_validate_optional_and_named_arg_2(fortran_reader): - """Test that the validate method inlines a routine - that has an optional argument.""" + '''Test that the validate method inlines a routine + that has an optional argument.''' code = ( "module test_mod\n" "contains\n" @@ -2517,12 +2320,12 @@ def test_validate_optional_and_named_arg_2(fortran_reader): print(routine_main.debug_string()) assert ( - """var = var + 2.0 + 1.0 + '''var = var + 2.0 + 1.0 var = var + 4.0 + 1.0 var = var + 5.0 + 1.0 var = var + 3.0 var = var + 6.0 - var = var + 7.0""" + var = var + 7.0''' in routine_main.debug_string() ) @@ -2532,18 +2335,26 @@ def test_validate_optional_and_named_arg_2(fortran_reader): " use inline_mod, only : sub\n" " real :: a\n" " call sub(a)\n" - "end subroutine run_it\n" -) -CALL_IN_SUB = CALL_IN_SUB_USE.replace(" use inline_mod, only : sub\n", "") -SUB = "subroutine sub(x)\n real :: x\n x = 1.0\nend subroutine sub\n" -SUB_IN_MODULE = f"module inline_mod\ncontains\n{SUB}end module inline_mod\n" + "end subroutine run_it\n") +CALL_IN_SUB = CALL_IN_SUB_USE.replace( + " use inline_mod, only : sub\n", "") +SUB = ( + "subroutine sub(x)\n" + " real :: x\n" + " x = 1.0\n" + "end subroutine sub\n") +SUB_IN_MODULE = ( + f"module inline_mod\n" + f"contains\n" + f"{SUB}" + f"end module inline_mod\n") def test_apply_merges_symbol_table_with_routine(fortran_reader): - """ + ''' Check that the apply method merges the inlined function's symbol table to the containing Routine when the call node is inside a child ScopingNode. - """ + ''' code = ( "module test_mod\n" "contains\n" @@ -2562,22 +2373,21 @@ def test_apply_merges_symbol_table_with_routine(fortran_reader): " x(i) = 2.0*ivar\n" " end do\n" " end subroutine sub\n" - "end module test_mod\n" - ) + "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(routine) # The i_1 symbol is the renamed i from the inlined call. - assert psyir.walk(Routine)[0].symbol_table.get_symbols()["i_1"] is not None + assert psyir.walk(Routine)[0].symbol_table.get_symbols()['i_1'] is not None def test_apply_argument_clash(fortran_reader, fortran_writer, tmpdir): - """ + ''' Check that the formal arguments to the inlined routine are not included when checking for clashes (since they will be replaced by the actual arguments to the call). - """ + ''' code_clash = """ subroutine sub(Istr) @@ -2599,7 +2409,7 @@ def test_apply_argument_clash(fortran_reader, fortran_writer, tmpdir): call = psyir.walk(Call)[0] inline_trans = InlineTrans() inline_trans.apply(call) - expected = """\ + expected = '''\ subroutine sub(istr) integer :: istr real :: x @@ -2610,7 +2420,7 @@ def test_apply_argument_clash(fortran_reader, fortran_writer, tmpdir): b(istr:) = 1.0 end subroutine sub -""" +''' output = fortran_writer(psyir) assert expected in output assert Compile(tmpdir).string_compiles(output) From 08cb731fb04527d309e5aac3c274648813a5a29b Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Sun, 24 Nov 2024 13:52:04 +0100 Subject: [PATCH 07/20] Further cleanups --- .coverage.mardinateur.37407.XOevDdvx.wgw0 | 0 .coverage.mardinateur.37413.XOWfjlOx.wgw2 | 0 src/psyclone/psyir/nodes/__init__.py | 3 +- src/psyclone/psyir/nodes/call.py | 606 ++------------- src/psyclone/psyir/tools/__init__.py | 6 +- .../psyir/tools/call_routine_matcher.py | 688 ++++++++++++++++++ src/psyclone/psyir/tools/call_tree_utils.py | 2 +- .../psyir/transformations/inline_trans.py | 508 +++++++------ .../psyir/transformations/omp_task_trans.py | 4 +- src/psyclone/tests/psyir/nodes/call_test.py | 2 +- .../transformations/inline_trans_test.py | 38 +- 11 files changed, 1047 insertions(+), 810 deletions(-) create mode 100644 .coverage.mardinateur.37407.XOevDdvx.wgw0 create mode 100644 .coverage.mardinateur.37413.XOWfjlOx.wgw2 create mode 100644 src/psyclone/psyir/tools/call_routine_matcher.py diff --git a/.coverage.mardinateur.37407.XOevDdvx.wgw0 b/.coverage.mardinateur.37407.XOevDdvx.wgw0 new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.coverage.mardinateur.37413.XOWfjlOx.wgw2 b/.coverage.mardinateur.37413.XOWfjlOx.wgw2 new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/psyclone/psyir/nodes/__init__.py b/src/psyclone/psyir/nodes/__init__.py index 7efe06b811..69f066cc9a 100644 --- a/src/psyclone/psyir/nodes/__init__.py +++ b/src/psyclone/psyir/nodes/__init__.py @@ -74,7 +74,7 @@ from psyclone.psyir.nodes.statement import Statement from psyclone.psyir.nodes.structure_reference import StructureReference from psyclone.psyir.nodes.structure_member import StructureMember -from psyclone.psyir.nodes.call import Call, CallMatchingArgumentsNotFoundError +from psyclone.psyir.nodes.call import Call from psyclone.psyir.nodes.file_container import FileContainer from psyclone.psyir.nodes.directive import ( Directive, StandaloneDirective, RegionDirective) @@ -112,7 +112,6 @@ "Assignment", "BinaryOperation", "Call", - "CallMatchingArgumentsNotFoundError", "Clause", "CodeBlock", "Container", diff --git a/src/psyclone/psyir/nodes/call.py b/src/psyclone/psyir/nodes/call.py index bb61beac91..d1f7ff93fc 100644 --- a/src/psyclone/psyir/nodes/call.py +++ b/src/psyclone/psyir/nodes/call.py @@ -56,17 +56,6 @@ ContainerSymbol, ) from typing import List, Union -from psyclone.errors import PSycloneError -from psyclone.psyir.symbols.datatypes import ArrayType - - -class CallMatchingArgumentsNotFoundError(PSycloneError): - '''Exception to signal that matching arguments have not been found - for this routine - ''' - def __init__(self, value): - PSycloneError.__init__(self, value) - self.value = "CallMatchingArgumentsNotFound: " + str(value) class Call(Statement, DataNode): @@ -133,14 +122,16 @@ def create(cls, routine, arguments=()): ''' if not isinstance(routine, (Reference, RoutineSymbol)): raise TypeError( - f"The Call routine argument should be a Reference to a " - f"RoutineSymbol or a RoutineSymbol, but " - f"found '{type(routine).__name__}'.") + "The Call routine argument should be a Reference to a " + "RoutineSymbol or a RoutineSymbol, but " + f"found '{type(routine).__name__}'." + ) if not isinstance(arguments, Iterable): raise GenerationError( - f"Call.create 'arguments' argument should be an Iterable but " - f"found '{type(arguments).__name__}'.") + "Call.create 'arguments' argument should be an Iterable but " + f"found '{type(arguments).__name__}'." + ) call = cls() if isinstance(routine, Reference): @@ -174,15 +165,17 @@ def _add_args(call, arguments): if isinstance(arg, tuple): if not len(arg) == 2: raise GenerationError( - f"If a child of the children argument in create " - f"method of Call class is a tuple, it's " - f"length should be 2, but found {len(arg)}.") + "If a child of the children argument in create " + "method of Call class is a tuple, it's " + f"length should be 2, but found {len(arg)}." + ) if not isinstance(arg[0], str): raise GenerationError( - f"If a child of the children argument in create " - f"method of Call class is a tuple, its first " - f"argument should be a str, but found " - f"{type(arg[0]).__name__}.") + "If a child of the children argument in create " + "method of Call class is a tuple, its first " + "argument should be a str, but found " + f"{type(arg[0]).__name__}." + ) name, arg = arg call.append_named_arg(name, arg) @@ -202,13 +195,15 @@ def append_named_arg(self, name, arg): # Avoid circular import. # pylint: disable=import-outside-toplevel from psyclone.psyir.frontend.fortran import FortranReader + FortranReader.validate_name(name) for check_name in self.argument_names: if check_name and check_name.lower() == name.lower(): raise ValueError( f"The value of the name argument ({name}) in " - f"'append_named_arg' in the 'Call' node is " - f"already used for a named argument.") + "'append_named_arg' in the 'Call' node is " + "already used for a named argument." + ) self._argument_names.append((id(arg), name)) self.children.append(arg) @@ -231,18 +226,21 @@ def insert_named_arg(self, name, arg, index): # Avoid circular import. # pylint: disable=import-outside-toplevel from psyclone.psyir.frontend.fortran import FortranReader + FortranReader.validate_name(name) for check_name in self.argument_names: if check_name and check_name.lower() == name.lower(): raise ValueError( f"The value of the name argument ({name}) in " - f"'insert_named_arg' in the 'Call' node is " - f"already used for a named argument.") + "'insert_named_arg' in the 'Call' node is " + "already used for a named argument." + ) if not isinstance(index, int): raise TypeError( - f"The 'index' argument in 'insert_named_arg' in the " - f"'Call' node should be an int but found " - f"{type(index).__name__}.") + "The 'index' argument in 'insert_named_arg' in the " + "'Call' node should be an int but found " + f"{type(index).__name__}." + ) self._argument_names.insert(index, (id(arg), name)) # The n'th argument is placed at the n'th+1 children position # because the 1st child is the routine reference @@ -264,9 +262,10 @@ def replace_named_arg(self, existing_name, arg): ''' if not isinstance(existing_name, str): raise TypeError( - f"The 'name' argument in 'replace_named_arg' in the " - f"'Call' node should be a string, but found " - f"{type(existing_name).__name__}.") + "The 'name' argument in 'replace_named_arg' in the " + "'Call' node should be a string, but found " + f"{type(existing_name).__name__}." + ) index = 0 for _, name in self._argument_names: if name is not None and name.lower() == existing_name: @@ -432,8 +431,10 @@ def node_str(self, colour=True): :rtype: str ''' - return (f"{self.coloured_name(colour)}" - f"[name='{self.routine.debug_string()}']") + return ( + f"{self.coloured_name(colour)}" + f"[name='{self.routine.debug_string()}']" + ) def __str__(self): return self.node_str(False) @@ -469,17 +470,17 @@ def _get_container_symbols_rec( _stack_container_name_list: List[str] = [], _depth: int = 0, ): - """Return a list of all container symbols that can be found + '''Return a list of all container symbols that can be found recursively :param container_symbols: List of starting set of container symbols :type container_symbols: List[ContainerSymbol] :param _stack_container_list: Stack with already visited Containers - to avoid circular searches, defaults to [] + to avoid circular searches, defaults to [] :type _stack_container_list: List[Container], optional :param _depth: Depth of recursive search :type _depth: int - """ + ''' # # TODO: # - This function seems to be extremely slow: @@ -536,7 +537,7 @@ def _get_container_symbols_rec( return ret_container_symbol_list def get_callees(self, ignore_missing_modules: bool = False): - """ + ''' Searches for the implementation(s) of all potential target routines for this Call without any arguments check. @@ -550,485 +551,17 @@ def get_callees(self, ignore_missing_modules: bool = False): :raises NotImplementedError: if the routine is not local and not found in any containers in scope at the call site. - """ - def _location_txt(node): - ''' - Utility to generate meaningful location text. - - :param node: a PSyIR node. - :type node: :py:class:`psyclone.psyir.nodes.Node` - - :returns: description of location of node. - :rtype: str - ''' - if isinstance(node, Container): - return f"Container '{node.name}'" - out_lines = node.debug_string().split("\n") - idx = -1 - while not out_lines[idx]: - idx -= 1 - last_line = out_lines[idx] - return f"code:\n'{out_lines[0]}\n...\n{last_line}'" - - rsym = self.routine.symbol - if rsym.is_unresolved: - - # Check for any "raw" Routines, i.e. ones that are not - # in a Container. Such Routines would exist in the PSyIR - # as a child of a FileContainer (if the PSyIR contains a - # FileContainer). Note, if the PSyIR does contain a - # FileContainer, it will be the root node of the PSyIR. - for routine in self.root.children: - if (isinstance(routine, Routine) and - routine.name.lower() == rsym.name.lower()): - return [routine] - - # Now check for any wildcard imports and see if they can - # be used to resolve the symbol. - wildcard_names = [] - containers_not_found = [] - current_table: SymbolTable = self.scope.symbol_table - while current_table: - # TODO: Obtaining all container symbols in this way - # breaks some tests. - # It would be better using the ModuleManager to resolve - # (and cache) all containers to look up for this. - # - # current_containersymbols = self._get_container_symbols_rec( - # current_table.containersymbols, - # ignore_missing_modules=ignore_missing_modules, - # ) - # for container_symbol in current_containersymbols: - for container_symbol in current_table.containersymbols: - container_symbol: ContainerSymbol - if container_symbol.wildcard_import: - wildcard_names.append(container_symbol.name) - - try: - container: Container = ( - container_symbol.find_container_psyir( - local_node=self, - ignore_missing_modules=( - ignore_missing_modules - ), - ) - ) - except SymbolError: - container = None - if not container: - # Failed to find/process this Container. - containers_not_found.append(container_symbol.name) - continue - routines = [] - for name in container.resolve_routine(rsym.name): - # Allow private imports if an 'interface' - # was used. Here, we assume the name of the routine - # is different to the call. - allow_private = name != rsym.name - psyir = container.find_routine_psyir( - name, allow_private=allow_private - ) - - if psyir: - routines.append(psyir) - - if routines: - return routines - current_table = current_table.parent_symbol_table() - - if not wildcard_names: - wc_text = "there are no wildcard imports" - else: - if containers_not_found: - wc_text = ( - f"attempted to resolve the wildcard imports from" - f" {wildcard_names}. However, failed to find the " - f"source for {containers_not_found}. The module search" - f" path is set to {Config.get().include_paths}") - else: - wc_text = (f"wildcard imports from {wildcard_names}") - raise NotImplementedError( - f"Failed to find the source code of the unresolved routine " - f"'{rsym.name}' - looked at any routines in the same source " - f"file and {wc_text}. Searching for external routines " - f"that are only resolved at link time is not supported.") - - root_node = self.ancestor(Container) - if not root_node: - root_node = self.root - container = root_node - can_be_private = True - - if rsym.is_import: - cursor = rsym - # A Routine imported from another Container must be public in that - # Container. - can_be_private = False - while cursor.is_import: - csym = cursor.interface.container_symbol - try: - container = csym.find_container_psyir(local_node=self) - except SymbolError: - raise NotImplementedError( - f"RoutineSymbol '{rsym.name}' is imported from " - f"Container '{csym.name}' but the source defining " - f"that container could not be found. The module search" - f" path is set to {Config.get().include_paths}") - imported_sym = container.symbol_table.lookup(cursor.name) - if imported_sym.visibility != Symbol.Visibility.PUBLIC: - # The required Symbol must be shadowed with a PRIVATE - # Symbol in this Container. This means that the one we - # actually want is brought into scope via a wildcard - # import. - # TODO #924 - Use ModuleManager to search? - raise NotImplementedError( - f"RoutineSymbol '{rsym.name}' is imported from " - f"Container '{csym.name}' but that Container defines " - f"a private Symbol of the same name. Searching for the" - f" Container that defines a public Routine with that " - f"name is not yet supported - TODO #924") - if not isinstance(imported_sym, RoutineSymbol): - # We now know that this is a RoutineSymbol so specialise it - # in place. - imported_sym.specialise(RoutineSymbol) - cursor = imported_sym - rsym = cursor - root_node = container - - if isinstance(rsym.datatype, UnsupportedFortranType): - # TODO #924 - an UnsupportedFortranType here typically indicates - # that the target is actually an interface. - raise NotImplementedError( - f"RoutineSymbol '{rsym.name}' exists in " - f"{_location_txt(root_node)} but is of " - f"UnsupportedFortranType:\n{rsym.datatype.declaration}\n" - f"Cannot get the PSyIR of such a routine.") - - if isinstance(container, Container): - routines = [] - for name in container.resolve_routine(rsym.name): - psyir = container.find_routine_psyir( - name, allow_private=can_be_private) - if psyir: - routines.append(psyir) - if routines: - return routines - - raise SymbolError( - f"Failed to find a Routine named '{rsym.name}' in " - f"{_location_txt(root_node)}. This is normally because the routine" - f" is within a CodeBlock.") - - def _check_inline_types( - self, - call_arg: DataSymbol, - routine_arg: DataSymbol, - check_array_type: bool = True, - ): - """This function performs tests to see whether the - inlining can cope with it. - - :param call_arg: The argument of a call - :type call_arg: DataSymbol - :param routine_arg: The argument of a routine - :type routine_arg: DataSymbol - :param check_array_type: Perform strong checks on array types, - defaults to `True` - :type check_array_type: bool, optional - - :raises TransformationError: Raised if transformation can't be done - - :return: 'True' if checks are successful - :rtype: bool - """ - from psyclone.psyir.transformations.transformation_error import ( - TransformationError, - ) - from psyclone.errors import LazyString - from psyclone.psyir.nodes import Literal, Range - from psyclone.psyir.symbols import ( - UnresolvedType, - UnsupportedType, - INTEGER_TYPE, + ''' + from psyclone.psyir.tools import ( + CallRoutineMatcher ) - _ONE = Literal("1", INTEGER_TYPE) - - # If the formal argument is an array with non-default bounds then - # we also need to know the bounds of that array at the call site. - if not isinstance(routine_arg.datatype, ArrayType): - # Formal argument is not an array so we don't need to do any - # further checks. - return True - - if not isinstance(call_arg, (Reference, Literal)): - # TODO #1799 this really needs the `datatype` method to be - # extended to support all nodes. For now we have to abort - # if we encounter an argument that is not a scalar (according - # to the corresponding formal argument) but is not a - # Reference or a Literal as we don't know whether the result - # of any general expression is or is not an array. - # pylint: disable=cell-var-from-loop - raise TransformationError( - LazyString( - lambda: ( - f"The call '{self.debug_string()}' " - "cannot be inlined because actual argument " - f"'{call_arg.debug_string()}' corresponds to a " - "formal argument with array type but is not a " - "Reference or a Literal." - ) - ) - ) - - # We have an array argument. We are only able to check that the - # argument is not re-shaped in the called routine if we have full - # type information on the actual argument. - # TODO #924. It would be useful if the `datatype` property was - # a method that took an optional 'resolve' argument to indicate - # that it should attempt to resolve any UnresolvedTypes. - if check_array_type: - if isinstance( - call_arg.datatype, (UnresolvedType, UnsupportedType) - ) or ( - isinstance(call_arg.datatype, ArrayType) - and isinstance( - call_arg.datatype.intrinsic, - (UnresolvedType, UnsupportedType), - ) - ): - raise TransformationError( - f"Routine '{self.routine.name}' cannot be " - "inlined because the type of the actual argument " - f"'{call_arg.symbol.name}' corresponding to an array" - f" formal argument ('{routine_arg.name}') is unknown." - ) - - formal_rank = 0 - actual_rank = 0 - if isinstance(routine_arg.datatype, ArrayType): - formal_rank = len(routine_arg.datatype.shape) - if isinstance(call_arg.datatype, ArrayType): - actual_rank = len(call_arg.datatype.shape) - if formal_rank != actual_rank: - # It's OK to use the loop variable in the lambda definition - # because if we get to this point then we're going to quit - # the loop. - # pylint: disable=cell-var-from-loop - raise TransformationError( - LazyString( - lambda: ( - "Cannot inline routine" - f" '{self.routine.name}' because it" - " reshapes an argument: actual argument" - f" '{call_arg.debug_string()}' has rank" - f" {actual_rank} but the corresponding formal" - f" argument, '{routine_arg.name}', has rank" - f" {formal_rank}" - ) - ) - ) - if actual_rank: - ranges = call_arg.walk(Range) - for rge in ranges: - ancestor_ref = rge.ancestor(Reference) - if ancestor_ref is not call_arg: - # Have a range in an indirect access. - # pylint: disable=cell-var-from-loop - raise TransformationError( - LazyString( - lambda: ( - "Cannot inline routine" - f" '{self.routine.name}' because" - " argument" - f" '{call_arg.debug_string()}' has" - " an array range in an indirect" - " access #(TODO 924)." - ) - ) - ) - if rge.step != _ONE: - # TODO #1646. We could resolve this problem by - # making a new array and copying the necessary - # values into it. - # pylint: disable=cell-var-from-loop - raise TransformationError( - LazyString( - lambda: ( - "Cannot inline routine" - f" '{self.routine.name}' because" - " one of its arguments is an array" - " slice with a non-unit stride:" - f" '{call_arg.debug_string()}' (TODO" - " #1646)" - ) - ) - ) - - def _check_argument_type_matches( - self, - call_arg: DataSymbol, - routine_arg: DataSymbol, - check_strict_array_datatype: bool = True, - ) -> bool: - """Return information whether argument types are matching. - This also supports 'optional' arguments by using - partial types. - - :param call_arg: Argument from the call - :type call_arg: DataSymbol - :param routine_arg: Argument from the routine - :type routine_arg: DataSymbol - :param check_strict_array_datatype: Check strictly for matching - array types. If `False`, only checks for ArrayType itself are done. - :type check_strict_array_datatype: bool - :returns: True if arguments match, False otherwise - :rtype: bool - :raises CallMatchingArgumentsNotFound: Raised if no matching arguments - were found. - """ - - self._check_inline_types(call_arg, routine_arg) - - type_matches = False - if not check_strict_array_datatype: - # No strict array checks have to be performed, just accept it - if isinstance(call_arg.datatype, ArrayType) and isinstance( - routine_arg.datatype, ArrayType - ): - type_matches = True - - if not type_matches: - if isinstance(routine_arg.datatype, UnsupportedFortranType): - # This could be an 'optional' argument. - # This has at least a partial data type - if call_arg.datatype != routine_arg.datatype.partial_datatype: - raise CallMatchingArgumentsNotFoundError( - "Argument partial type mismatch of call " - f"argument '{call_arg}' and routine argument " - f"'{routine_arg}'" - ) - else: - if call_arg.datatype != routine_arg.datatype: - raise CallMatchingArgumentsNotFoundError( - "Argument type mismatch of call argument " - f"'{call_arg}' with type '{call_arg.datatype} " - "and routine argument " - f"'{routine_arg}' with type '{routine_arg.datatype}." - ) - - return True - - def _get_argument_routine_match( - self, - routine: Routine, - check_strict_array_datatype: bool = True, - ) -> Union[None, List[int]]: - """Return a list of integers giving for each argument of the call - the index of the argument in argument_list (typically of a routine) - - :param check_strict_array_datatype: Strict datatype check for - array types - :type check_strict_array_datatype: bool - - :param check_matching_arguments: If no match is possible, - return the first routine in the list of potential candidates. - :type check_matching_arguments: bool - - :return: None if no match was found, otherwise list of integers - referring to matching arguments. - :rtype: None|List[int] - :raises CallMatchingArgumentsNotFound: If there was some problem in - finding matching arguments. - """ - - # Create a copy of the list of actual arguments to the routine. - # Once an argument has been successfully matched, set it to 'None' - routine_argument_list: List[DataSymbol] = ( - routine.symbol_table.argument_list[:] + call_routine_matcher: CallRoutineMatcher = CallRoutineMatcher(self) + call_routine_matcher.set_option( + ignore_missing_modules=ignore_missing_modules, ) - if len(self.arguments) > len(routine.symbol_table.argument_list): - call_str = self.debug_string().replace("\n", "") - raise CallMatchingArgumentsNotFoundError( - f"More arguments in call ('{call_str}')" - f" than callee (routine '{routine.name}')" - ) - - # Iterate over all arguments to the call - ret_arg_idx_list = [] - for call_arg_idx, call_arg in enumerate(self.arguments): - call_arg_idx: int - call_arg: DataSymbol - - # If the associated name is None, it's a positional argument - # => Just return the index if the types match - if self.argument_names[call_arg_idx] is None: - routine_arg = routine_argument_list[call_arg_idx] - routine_arg: DataSymbol - - self._check_argument_type_matches( - call_arg, routine_arg, check_strict_array_datatype - ) - - ret_arg_idx_list.append(call_arg_idx) - routine_argument_list[call_arg_idx] = None - continue - - # - # Next, we handle all named arguments - # - arg_name = self.argument_names[call_arg_idx] - routine_arg_idx = None - - for routine_arg_idx, routine_arg in enumerate( - routine_argument_list - ): - routine_arg: DataSymbol - - # Check if argument was already processed - if routine_arg is None: - continue - - if arg_name == routine_arg.name: - self._check_argument_type_matches( - call_arg, - routine_arg, - check_strict_array_datatype=( - check_strict_array_datatype - ), - ) - ret_arg_idx_list.append(routine_arg_idx) - break - - else: - # It doesn't match => Raise exception - raise CallMatchingArgumentsNotFoundError( - f"Named argument '{arg_name}' not found" - ) - - routine_argument_list[routine_arg_idx] = None - - # - # Finally, we check if all left-over arguments are optional arguments - # - for routine_arg in routine_argument_list: - routine_arg: DataSymbol - - if routine_arg is None: - continue - - # TODO #759: Optional keyword is not yet supported in psyir. - # Hence, we use a simple string match. - if ", OPTIONAL" in str(routine_arg.datatype): - continue - - raise CallMatchingArgumentsNotFoundError( - f"Argument '{routine_arg.name}' in subroutine" - f" '{routine.name}' not handled" - ) - - return ret_arg_idx_list + return call_routine_matcher.get_callee_candidates() def get_callee( self, @@ -1057,44 +590,17 @@ def get_callee( in any containers in scope at the call site. ''' - routine_list = self.get_callees( - ignore_missing_modules=ignore_missing_modules + from psyclone.psyir.tools import ( + CallRoutineMatcher ) - if len(routine_list) == 0: - raise NotImplementedError( - f"No routine or interface found for name '{self.routine.name}'" + call_routine_matcher: CallRoutineMatcher = CallRoutineMatcher(self) + call_routine_matcher.set_option( + check_matching_arguments=check_matching_arguments, + check_argument_strict_array_datatype=check_strict_array_datatype, + ignore_missing_modules=ignore_missing_modules, + ignore_unresolved_symbol=ignore_unresolved_symbol, ) - err_info_list = [] - - # Search for the routine matching the right arguments - for routine_node in routine_list: - routine_node: Routine - - try: - arg_match_list = self._get_argument_routine_match( - routine_node, - check_strict_array_datatype=check_strict_array_datatype, - ) - except CallMatchingArgumentsNotFoundError as err: - err_info_list.append(err.value) - continue - - return (routine_node, arg_match_list) - - # If we didn't find any routine, return some routine if no matching - # arguments have been found. - # This is handy for the transition phase until optional argument - # matching is supported. - if not check_matching_arguments: - # Also return a list of dummy argument indices - return (routine_list[0], [i for i in range(len(self.arguments))]) + return call_routine_matcher.get_callee() - error_msg = "\n".join(err_info_list) - - raise CallMatchingArgumentsNotFoundError( - "Found routines, but no routine with matching arguments found " - f"for '{self.routine.name}':\n" - + error_msg - ) diff --git a/src/psyclone/psyir/tools/__init__.py b/src/psyclone/psyir/tools/__init__.py index 8ccd12f9c2..6f8f3fd719 100644 --- a/src/psyclone/psyir/tools/__init__.py +++ b/src/psyclone/psyir/tools/__init__.py @@ -40,10 +40,14 @@ from psyclone.psyir.tools.dependency_tools import DTCode, DependencyTools from psyclone.psyir.tools.read_write_info import ReadWriteInfo from psyclone.psyir.tools.definition_use_chains import DefinitionUseChain +from psyclone.psyir.tools.call_routine_matcher import CallRoutineMatcher, CallMatchingArgumentsNotFoundError # For AutoAPI documentation generation. __all__ = ['CallTreeUtils', 'DTCode', 'DependencyTools', 'DefinitionUseChain', - 'ReadWriteInfo'] + 'ReadWriteInfo', + 'CallRoutineMatcher', + 'CallMatchingArgumentsNotFoundError', + ] diff --git a/src/psyclone/psyir/tools/call_routine_matcher.py b/src/psyclone/psyir/tools/call_routine_matcher.py new file mode 100644 index 0000000000..780e3300c4 --- /dev/null +++ b/src/psyclone/psyir/tools/call_routine_matcher.py @@ -0,0 +1,688 @@ +# ----------------------------------------------------------------------------- +# BSD 3-Clause License +# +# Copyright (c) 2024, Science and Technology Facilities Council and +# University Grenoble Alpes +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- +# This file is based on gathering various components related to +# calls and routines from across psyclone. Hence, there's no clear author. +# Initial author of this file: M. Schreiber, University Grenoble Alpes +# Further authors: R. W. Ford, A. R. Porter and S. Siso, STFC Daresbury Lab +# ----------------------------------------------------------------------------- + +from typing import List, Union +from psyclone.psyir.symbols.datatypes import ArrayType +from psyclone.psyir.nodes import Call, Reference, Routine +from psyclone.errors import PSycloneError +from psyclone.configuration import Config +from psyclone.psyir.nodes.container import Container +from psyclone.psyir.nodes.reference import Reference +from psyclone.psyir.nodes.routine import Routine +from psyclone.psyir.symbols import ( + RoutineSymbol, + Symbol, + SymbolError, + UnsupportedFortranType, + DataSymbol, + SymbolTable, + ContainerSymbol, +) + +class CallMatchingArgumentsNotFoundError(PSycloneError): + """Exception to signal that matching arguments have not been found + for this routine + """ + + def __init__(self, value): + PSycloneError.__init__(self, value) + self.value = "CallMatchingArgumentsNotFound: " + str(value) + + +class CallRoutineMatcher: + """Helper routines to help matching 'Call' and 'Routines'. + This includes, e.g., + - searching for matching 'Routines', + - argument matching + """ + + def __init__(self, call_node: Call = None, routine_node: Routine = None): + + # Psyir node of Call + self._call_node: Call = call_node + + # Psyir node of Routine + self._routine_node: Call = routine_node + + # List of indices relating each argument of call to one argument + # of routine. This is required to support optional arguments. + self._arg_match_list: List[int] = None + + # Also check argument types to match. + # If set to `False` and in case it doesn't find matching arguments, + # the very first implementation of the matching routine will be + # returned (even if the argument type check failed). The argument + # types and number of arguments might therefore mismatch! + self._option_check_matching_arguments: bool = True + + # Use strict array datatype checks for matching + self._option_check_strict_array_datatype: bool = True + + # If 'True', missing modules don't raise an Exception + self._option_ignore_missing_modules: bool = False + + # If 'True', unresolved symbols don't raise an Exception + self._option_ignore_unresolved_symbol: bool = False + + def set_call_node(self, call_node: Call): + self._call_node = call_node + + def set_routine_node(self, routine_node: Routine): + self._routine_node = routine_node + + def set_option(self, + check_matching_arguments: bool = None, + check_argument_strict_array_datatype: bool = None, + ignore_missing_modules: bool = None, + ignore_unresolved_symbol: bool = None, + ): + + if check_matching_arguments is not None: + self._option_check_matching_arguments = check_matching_arguments + + if check_argument_strict_array_datatype is not None: + self._option_check_strict_array_datatype = check_argument_strict_array_datatype + + if ignore_missing_modules is not None: + self._option_ignore_missing_modules = ignore_missing_modules + + if ignore_unresolved_symbol is not None: + self._option_ignore_unresolved_symbol = ignore_unresolved_symbol + + def _check_inline_types( + self, + call_arg: DataSymbol, + routine_arg: DataSymbol, + check_array_type: bool = True, + ): + """This function performs tests to see whether the + inlining can cope with it. + + :param call_arg: The argument of a call + :type call_arg: DataSymbol + :param routine_arg: The argument of a routine + :type routine_arg: DataSymbol + :param check_array_type: Perform strong checks on array types, + defaults to `True` + :type check_array_type: bool, optional + + :raises TransformationError: Raised if transformation can't be done + + :return: 'True' if checks are successful + :rtype: bool + """ + from psyclone.psyir.transformations.transformation_error import ( + TransformationError, + ) + from psyclone.errors import LazyString + from psyclone.psyir.nodes import Literal, Range + from psyclone.psyir.symbols import ( + UnresolvedType, + UnsupportedType, + INTEGER_TYPE, + ) + + _ONE = Literal("1", INTEGER_TYPE) + + # If the formal argument is an array with non-default bounds then + # we also need to know the bounds of that array at the call site. + if not isinstance(routine_arg.datatype, ArrayType): + # Formal argument is not an array so we don't need to do any + # further checks. + return True + + if not isinstance(call_arg, (Reference, Literal)): + # TODO #1799 this really needs the `datatype` method to be + # extended to support all nodes. For now we have to abort + # if we encounter an argument that is not a scalar (according + # to the corresponding formal argument) but is not a + # Reference or a Literal as we don't know whether the result + # of any general expression is or is not an array. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + f"The call '{self._call_node.debug_string()}' " + "cannot be inlined because actual argument " + f"'{call_arg.debug_string()}' corresponds to a " + "formal argument with array type but is not a " + "Reference or a Literal." + ) + ) + ) + + # We have an array argument. We are only able to check that the + # argument is not re-shaped in the called routine if we have full + # type information on the actual argument. + # TODO #924. It would be useful if the `datatype` property was + # a method that took an optional 'resolve' argument to indicate + # that it should attempt to resolve any UnresolvedTypes. + if check_array_type: + if isinstance( + call_arg.datatype, (UnresolvedType, UnsupportedType) + ) or ( + isinstance(call_arg.datatype, ArrayType) + and isinstance( + call_arg.datatype.intrinsic, + (UnresolvedType, UnsupportedType), + ) + ): + raise TransformationError( + f"Routine '{self._routine_node.name}' cannot be " + "inlined because the type of the actual argument " + f"'{call_arg.symbol.name}' corresponding to an array" + f" formal argument ('{routine_arg.name}') is unknown." + ) + + formal_rank = 0 + actual_rank = 0 + if isinstance(routine_arg.datatype, ArrayType): + formal_rank = len(routine_arg.datatype.shape) + if isinstance(call_arg.datatype, ArrayType): + actual_rank = len(call_arg.datatype.shape) + if formal_rank != actual_rank: + # It's OK to use the loop variable in the lambda definition + # because if we get to this point then we're going to quit + # the loop. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self._routine_node.name}' because it" + " reshapes an argument: actual argument" + f" '{call_arg.debug_string()}' has rank" + f" {actual_rank} but the corresponding formal" + f" argument, '{routine_arg.name}', has rank" + f" {formal_rank}" + ) + ) + ) + if actual_rank: + ranges = call_arg.walk(Range) + for rge in ranges: + ancestor_ref = rge.ancestor(Reference) + if ancestor_ref is not call_arg: + # Have a range in an indirect access. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self._routine_node.name}' because" + " argument" + f" '{call_arg.debug_string()}' has" + " an array range in an indirect" + " access #(TODO 924)." + ) + ) + ) + if rge.step != _ONE: + # TODO #1646. We could resolve this problem by + # making a new array and copying the necessary + # values into it. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self._routine_node.name}' because" + " one of its arguments is an array" + " slice with a non-unit stride:" + f" '{call_arg.debug_string()}' (TODO" + " #1646)" + ) + ) + ) + + def _check_argument_type_matches( + self, + call_arg: DataSymbol, + routine_arg: DataSymbol, + check_strict_array_datatype: bool = True, + ) -> bool: + """Return information whether argument types are matching. + This also supports 'optional' arguments by using + partial types. + + :param call_arg: Argument from the call + :type call_arg: DataSymbol + :param routine_arg: Argument from the routine + :type routine_arg: DataSymbol + :param check_strict_array_datatype: Check strictly for matching + array types. If `False`, only checks for ArrayType itself are done. + :type check_strict_array_datatype: bool + :returns: True if arguments match, False otherwise + :rtype: bool + :raises CallMatchingArgumentsNotFound: Raised if no matching arguments + were found. + """ + + # self._check_inline_types(call_arg, routine_arg) + + type_matches = False + if not check_strict_array_datatype: + # No strict array checks have to be performed, just accept it + if isinstance(call_arg.datatype, ArrayType) and isinstance( + routine_arg.datatype, ArrayType + ): + type_matches = True + + if not type_matches: + if isinstance(routine_arg.datatype, UnsupportedFortranType): + # This could be an 'optional' argument. + # This has at least a partial data type + if call_arg.datatype != routine_arg.datatype.partial_datatype: + raise CallMatchingArgumentsNotFoundError( + "Argument partial type mismatch of call " + f"argument '{call_arg}' and routine argument " + f"'{routine_arg}'" + ) + else: + if call_arg.datatype != routine_arg.datatype: + raise CallMatchingArgumentsNotFoundError( + "Argument type mismatch of call argument " + f"'{call_arg}' with type '{call_arg.datatype} " + "and routine argument " + f"'{routine_arg}' with type '{routine_arg.datatype}." + ) + + return True + + def get_argument_routine_match_list( + self, + ) -> Union[None, List[int]]: + '''Return a list of integers giving for each argument of the call + the index of the argument in argument_list (typically of a routine) + + :return: None if no match was found, otherwise list of integers + referring to matching arguments. + :rtype: None|List[int] + :raises CallMatchingArgumentsNotFound: If there was some problem in + finding matching arguments. + ''' + + # Create a copy of the list of actual arguments to the routine. + # Once an argument has been successfully matched, set it to 'None' + routine_argument_list: List[DataSymbol] = ( + self._routine_node.symbol_table.argument_list[:] + ) + + if len(self._call_node.arguments) > len( + self._routine_node.symbol_table.argument_list): + call_str = self._call_node.debug_string().replace("\n", "") + raise CallMatchingArgumentsNotFoundError( + f"More arguments in call ('{call_str}')" + f" than callee (routine '{self._routine_node.name}')" + ) + + # Iterate over all arguments to the call + ret_arg_idx_list = [] + for call_arg_idx, call_arg in enumerate(self._call_node.arguments): + call_arg_idx: int + call_arg: DataSymbol + + # If the associated name is None, it's a positional argument + # => Just return the index if the types match + if self._call_node.argument_names[call_arg_idx] is None: + routine_arg = routine_argument_list[call_arg_idx] + routine_arg: DataSymbol + + self._check_argument_type_matches( + call_arg, routine_arg, + self._option_check_strict_array_datatype + ) + + ret_arg_idx_list.append(call_arg_idx) + routine_argument_list[call_arg_idx] = None + continue + + # + # Next, we handle all named arguments + # + arg_name = self._call_node.argument_names[call_arg_idx] + routine_arg_idx = None + + for routine_arg_idx, routine_arg in enumerate( + routine_argument_list + ): + routine_arg: DataSymbol + + # Check if argument was already processed + if routine_arg is None: + continue + + if arg_name == routine_arg.name: + self._check_argument_type_matches( + call_arg, + routine_arg, + check_strict_array_datatype=( + self._option_check_strict_array_datatype + ), + ) + ret_arg_idx_list.append(routine_arg_idx) + break + + else: + # It doesn't match => Raise exception + raise CallMatchingArgumentsNotFoundError( + f"Named argument '{arg_name}' not found" + ) + + routine_argument_list[routine_arg_idx] = None + + # + # Finally, we check if all left-over arguments are optional arguments + # + for routine_arg in routine_argument_list: + routine_arg: DataSymbol + + if routine_arg is None: + continue + + # TODO #759: Optional keyword is not yet supported in psyir. + # Hence, we use a simple string match. + if ", OPTIONAL" in str(routine_arg.datatype): + continue + + raise CallMatchingArgumentsNotFoundError( + f"Argument '{routine_arg.name}' in subroutine" + f" '{self._routine_node.name}' not handled" + ) + + return ret_arg_idx_list + + def get_callee_candidates(self, ignore_missing_modules: bool = False): + ''' + Searches for the implementation(s) of all potential target routines + for this Call without any arguments check. + + :param ignore_missing_modules: If a module wasn't found, return 'None' + instead of throwing an exception 'ModuleNotFound'. + :type ignore_missing_modules: bool + + :returns: the Routine(s) that this call targets. + :rtype: list[:py:class:`psyclone.psyir.nodes.Routine`] + + :raises NotImplementedError: if the routine is not local and not found + in any containers in scope at the call site. + + ''' + + def _location_txt(node): + ''' + Utility to generate meaningful location text. + + :param node: a PSyIR node. + :type node: :py:class:`psyclone.psyir.nodes.Node` + + :returns: description of location of node. + :rtype: str + ''' + if isinstance(node, Container): + return f"Container '{node.name}'" + out_lines = node.debug_string().split("\n") + idx = -1 + while not out_lines[idx]: + idx -= 1 + last_line = out_lines[idx] + return f"code:\n'{out_lines[0]}\n...\n{last_line}'" + + rsym = self._call_node.routine.symbol + if rsym.is_unresolved: + + # Check for any "raw" Routines, i.e. ones that are not + # in a Container. Such Routines would exist in the PSyIR + # as a child of a FileContainer (if the PSyIR contains a + # FileContainer). Note, if the PSyIR does contain a + # FileContainer, it will be the root node of the PSyIR. + for routine in self._call_node.root.children: + if ( + isinstance(routine, Routine) + and routine.name.lower() == rsym.name.lower() + ): + return [routine] + + # Now check for any wildcard imports and see if they can + # be used to resolve the symbol. + wildcard_names = [] + containers_not_found = [] + current_table: SymbolTable = self._call_node.scope.symbol_table + while current_table: + # TODO: Obtaining all container symbols in this way + # breaks some tests. + # It would be better using the ModuleManager to resolve + # (and cache) all containers to look up for this. + # + # current_containersymbols = self._call_node._get_container_symbols_rec( + # current_table.containersymbols, + # ignore_missing_modules=ignore_missing_modules, + # ) + # for container_symbol in current_containersymbols: + for container_symbol in current_table.containersymbols: + container_symbol: ContainerSymbol + if container_symbol.wildcard_import: + wildcard_names.append(container_symbol.name) + + try: + container: Container = ( + container_symbol.find_container_psyir( + local_node=self._call_node, + ignore_missing_modules=( + ignore_missing_modules + ), + ) + ) + except SymbolError: + container = None + if not container: + # Failed to find/process this Container. + containers_not_found.append(container_symbol.name) + continue + routines = [] + for name in container.resolve_routine(rsym.name): + # Allow private imports if an 'interface' + # was used. Here, we assume the name of the routine + # is different to the call. + allow_private = name != rsym.name + psyir = container.find_routine_psyir( + name, allow_private=allow_private + ) + + if psyir: + routines.append(psyir) + + if routines: + return routines + current_table = current_table.parent_symbol_table() + + if not wildcard_names: + wc_text = "there are no wildcard imports" + else: + if containers_not_found: + wc_text = ( + "attempted to resolve the wildcard imports from" + f" {wildcard_names}. However, failed to find the " + f"source for {containers_not_found}. The module search" + f" path is set to {Config.get().include_paths}" + ) + else: + wc_text = f"wildcard imports from {wildcard_names}" + raise NotImplementedError( + "Failed to find the source code of the unresolved routine " + f"'{rsym.name}' - looked at any routines in the same source " + f"file and {wc_text}. Searching for external routines " + "that are only resolved at link time is not supported." + ) + + root_node = self._call_node.ancestor(Container) + if not root_node: + root_node = self._call_node.root + container = root_node + can_be_private = True + + if rsym.is_import: + cursor = rsym + # A Routine imported from another Container must be public in that + # Container. + can_be_private = False + while cursor.is_import: + csym = cursor.interface.container_symbol + try: + container = csym.find_container_psyir( + local_node=self._call_node) + except SymbolError: + raise NotImplementedError( + f"RoutineSymbol '{rsym.name}' is imported from " + f"Container '{csym.name}' but the source defining " + "that container could not be found. The module search" + f" path is set to {Config.get().include_paths}" + ) + imported_sym = container.symbol_table.lookup(cursor.name) + if imported_sym.visibility != Symbol.Visibility.PUBLIC: + # The required Symbol must be shadowed with a PRIVATE + # Symbol in this Container. This means that the one we + # actually want is brought into scope via a wildcard + # import. + # TODO #924 - Use ModuleManager to search? + raise NotImplementedError( + f"RoutineSymbol '{rsym.name}' is imported from " + f"Container '{csym.name}' but that Container defines " + "a private Symbol of the same name. Searching for the" + " Container that defines a public Routine with that " + "name is not yet supported - TODO #924" + ) + if not isinstance(imported_sym, RoutineSymbol): + # We now know that this is a RoutineSymbol so specialise it + # in place. + imported_sym.specialise(RoutineSymbol) + cursor = imported_sym + rsym = cursor + root_node = container + + if isinstance(rsym.datatype, UnsupportedFortranType): + # TODO #924 - an UnsupportedFortranType here typically indicates + # that the target is actually an interface. + raise NotImplementedError( + f"RoutineSymbol '{rsym.name}' exists in " + f"{_location_txt(root_node)} but is of " + f"UnsupportedFortranType:\n{rsym.datatype.declaration}\n" + "Cannot get the PSyIR of such a routine." + ) + + if isinstance(container, Container): + routines = [] + for name in container.resolve_routine(rsym.name): + psyir = container.find_routine_psyir( + name, allow_private=can_be_private + ) + if psyir: + routines.append(psyir) + if routines: + return routines + + raise SymbolError( + f"Failed to find a Routine named '{rsym.name}' in " + f"{_location_txt(root_node)}. This is normally because the routine" + " is within a CodeBlock." + ) + + def get_callee(self): + ''' + Searches for the implementation(s) of the target routine for this Call + including argument checks. + + :param check_matching_arguments: Also check argument types to match. + If set to `False` and in case it doesn't find matching arguments, + the very first implementation of the matching routine will be + returned (even if the argument type check failed). The argument + types and number of arguments might therefore mismatch! + :type ret_arg_match_list: bool + + :returns: A tuple of two elements. The first element is the routine + that this call targets. The second one a list of arguments + providing the information on matching argument indices. + :rtype: Set[psyclone.psyir.nodes.Routine, List[int]] + + :raises NotImplementedError: if the routine is not local and not found + in any containers in scope at the call site. + ''' + + routine_list = self.get_callee_candidates() + + call_name = self._call_node.routine.name + + if len(routine_list) == 0: + raise NotImplementedError( + f"No routine or interface found for name '{call_name}'" + ) + + err_info_list = [] + + # Search for the routine matching the right arguments + for routine_node in routine_list: + routine_node: Routine + self._routine_node = routine_node + + try: + self._arg_match_list = self.get_argument_routine_match_list() + except CallMatchingArgumentsNotFoundError as err: + err_info_list.append(err.value) + continue + + return (self._routine_node, self._arg_match_list) + + # If we didn't find any routine, return some routine if no matching + # arguments have been found. + # This is handy for the transition phase until optional argument + # matching is supported. + if not self._option_check_matching_arguments: + # Also return a list of dummy argument indices + self._routine_node = routine_list[0] + self._arg_match_list = list(range(len(self._call_node.arguments))) + return (self._routine_node, self._arg_match_list) + + error_msg = "\n".join(err_info_list) + + raise CallMatchingArgumentsNotFoundError( + "Found routines, but no routine with matching arguments found " + f"for '{self._call_node.routine.name}':\n" + + error_msg + ) diff --git a/src/psyclone/psyir/tools/call_tree_utils.py b/src/psyclone/psyir/tools/call_tree_utils.py index a5aca9ebda..02f1d0eb95 100644 --- a/src/psyclone/psyir/tools/call_tree_utils.py +++ b/src/psyclone/psyir/tools/call_tree_utils.py @@ -38,7 +38,7 @@ across different subroutines and modules.''' from psyclone.core import Signature, VariablesAccessInfo -from psyclone.parse import ModuleManager +from psyclone.parse.module_manager import ModuleManager from psyclone.psyGen import BuiltIn, Kern from psyclone.psyir.nodes import Container, Reference from psyclone.psyir.symbols import ( diff --git a/src/psyclone/psyir/transformations/inline_trans.py b/src/psyclone/psyir/transformations/inline_trans.py index 183675aaad..7de074efdb 100644 --- a/src/psyclone/psyir/transformations/inline_trans.py +++ b/src/psyclone/psyir/transformations/inline_trans.py @@ -62,7 +62,6 @@ from psyclone.psyir.transformations.transformation_error import ( TransformationError) -from psyclone.psyir.nodes import CallMatchingArgumentsNotFoundError from typing import Dict, List from psyclone.psyir.symbols import BOOLEAN_TYPE @@ -142,8 +141,15 @@ def __init__(self): # List of call-to-subroutine argument indices self._ret_arg_match_list: List[int] = None - # Routine to be inlines - self.node_routine: Routine = None + # Call to routine + self._call_node: Call = None + + # Routine to be inlined for call + self._routine_node: Routine = None + + from psyclone.psyir.tools import CallRoutineMatcher + + self._call_routine_matcher: CallRoutineMatcher = CallRoutineMatcher() # If 'True', make strict checks for matching arguments of # array data types. @@ -151,52 +157,42 @@ def __init__(self): # Then, no further checks are performed self._option_check_argument_strict_array_datatype: bool = True - # If searching for modules, don't trigger Exceptions if module - # wasn't found. - self._option_ignore_missing_modules: bool = False - # If 'True', don't inline if a code block is used within the # Routine. self._option_check_codeblocks: bool = True - # If 'True', the callee must have matching arguments. - # The 'matching' criteria can be weakened by other options. - # If 'False', in case no match was found, the first callee is taken. - self._option_check_matching_arguments_of_callee: bool = True - - # check_diff_container_clashes: bool = True, - # check_diff_container_clashes_unresolved_types: bool = True, - # check_resolve_imports: bool = True, - # check_static_interface: bool = True, - # check_array_type: bool = True, - # check_argument_of_unsupported_type: bool = True, - # check_argument_unresolved_symbols: bool = True, + self._option_check_diff_container_clashes: bool = True + self._option_check_diff_container_clashes_unresolved_types: bool = True + self._option_check_resolve_imports: bool = True + self._option_check_static_interface: bool = True + self._option_check_array_type: bool = True + self._option_check_argument_of_unsupported_type: bool = True + self._option_check_argument_unresolved_symbols: bool = True def set_option( self, ignore_missing_modules: bool = None, check_argument_strict_array_datatype: bool = None, - check_codeblocks: bool = None, - check_matching_arguments_of_callee: bool = None, + check_argument_matching: bool = None, + check_inline_codeblocks: bool = None, ): if check_argument_strict_array_datatype is not None: self._option_check_argument_strict_array_datatype = ( check_argument_strict_array_datatype ) + self._call_routine_matcher.set_option( + ignore_missing_modules=ignore_missing_modules) + self._call_routine_matcher.set_option( + check_argument_strict_array_datatype=( + check_argument_strict_array_datatype)) + self._call_routine_matcher.set_option( + check_matching_arguments=check_argument_matching) - if ignore_missing_modules is not None: - self._option_ignore_missing_modules = ignore_missing_modules - - if check_codeblocks is not None: - self._option_check_codeblocks = check_codeblocks - - if check_matching_arguments_of_callee is not None: - self._option_check_matching_arguments_of_callee = ( - check_matching_arguments_of_callee - ) + if check_inline_codeblocks is not None: + self._option_check_codeblocks = check_inline_codeblocks def apply( - self, node_call: Call, node_routine: Routine = None, options=None + self, call_node: Call, routine_node: Routine = None, options=None ): """ Takes the body of the routine that is the target of the supplied @@ -218,28 +214,29 @@ def apply( # self.node_routine and self._ret_arg_match_list # with the routine to be inlined and the relation between the # arguments and to which routine arguments they are matched to. - self.validate(node_call, node_routine=node_routine, options=options) + self.validate(call_node, routine_node=routine_node, options=options) + # The table associated with the scoping region holding the Call. - table = node_call.scope.symbol_table + table = call_node.scope.symbol_table - if not self.node_routine.children or isinstance( - self.node_routine.children[0], Return + if not self._routine_node.children or isinstance( + self._routine_node.children[0], Return ): # Called routine is empty so just remove the call. - node_call.detach() + call_node.detach() return # Ensure we don't modify the original Routine by working with a # copy of it. - self.node_routine = self.node_routine.copy() - routine_table = self.node_routine.symbol_table - self._remove_unused_optional_arguments() + self._routine_node = self._routine_node.copy() + routine_table = self._routine_node.symbol_table + self._optional_arg_remove_unused_optional_arguments() # Construct lists of the nodes that will be inserted and all of the # References that they contain. new_stmts = [] refs = [] - for child in self.node_routine.children: + for child in self._routine_node.children: child: Node new_stmts.append(child.copy()) refs.extend(new_stmts[-1].walk(Reference)) @@ -259,7 +256,7 @@ def apply( # as a Reference. ref2arraytrans = Reference2ArrayRangeTrans() - for child in node_call.arguments: + for child in call_node.arguments: try: # TODO #1858, this won't yet work for arrays inside structures. ref2arraytrans.apply(child) @@ -270,12 +267,12 @@ def apply( # actual arguments. formal_args = routine_table.argument_list for ref in refs[:]: - self._replace_formal_arg(ref, node_call, formal_args) + self._replace_formal_arg(ref, call_node, formal_args) # Store the Routine level symbol table and node's current scope # so we can merge symbol tables later if required. - ancestor_table = node_call.ancestor(Routine).scope.symbol_table - scope = node_call.scope + ancestor_table = call_node.ancestor(Routine).scope.symbol_table + scope = call_node.scope # Copy the nodes from the Routine into the call site. # TODO #924 - while doing this we should ensure that any References @@ -286,9 +283,9 @@ def apply( # remove it from the list. del new_stmts[-1] - if self.node_routine.return_symbol: + if self._routine_node.return_symbol: # This is a function - assignment = node_call.ancestor(Statement) + assignment = call_node.ancestor(Statement) parent = assignment.parent idx = assignment.position-1 for child in new_stmts: @@ -297,17 +294,17 @@ def apply( table = parent.scope.symbol_table # Avoid a potential name clash with the original function table.rename_symbol( - self.node_routine.return_symbol, + self._routine_node.return_symbol, table.next_available_name( - f"inlined_{self.node_routine.return_symbol.name}" + f"inlined_{self._routine_node.return_symbol.name}" ), ) - node_call.replace_with(Reference(self.node_routine.return_symbol)) + call_node.replace_with(Reference(self._routine_node.return_symbol)) else: # This is a call - parent = node_call.parent - idx = node_call.position - node_call.replace_with(new_stmts[0]) + parent = call_node.parent + idx = call_node.position + call_node.replace_with(new_stmts[0]) for child in new_stmts[1:]: idx += 1 parent.addchild(child, idx) @@ -322,7 +319,56 @@ def apply( scope.symbol_table.detach() replacement.attach(scope) - def _remove_ifblock_if_const_args(self, node: Node): + def _optional_arg_resolve_present_intrinsics(self): + """Replace PRESENT() intrinsics with `True` or `False` + + :rtype: None + """ + # We first build a lookup table of all optional arguments + # to see whether it's present or not. + optional_sym_present_dict: Dict[str, bool] = dict() + for optional_arg_idx, datasymbol in enumerate( + self._routine_node.symbol_table.datasymbols + ): + if not isinstance(datasymbol.datatype, UnsupportedFortranType): + continue + + if ", OPTIONAL" not in str(datasymbol.datatype): + continue + + sym_name = datasymbol.name.lower() + + if optional_arg_idx not in self._ret_arg_match_list: + optional_sym_present_dict[sym_name] = False + else: + optional_sym_present_dict[sym_name] = True + + # Check if we have any optional arguments at all and if not, return + if len(optional_sym_present_dict) == 0: + return + + # Find all "PRESENT()" calls + for intrinsic_call in self._routine_node.walk(IntrinsicCall): + intrinsic_call: IntrinsicCall + if intrinsic_call.routine.name.lower() == "present": + + # The argument is in the 2nd child + present_arg: Reference = intrinsic_call.children[1] + present_arg_name = present_arg.name.lower() + + assert present_arg_name in optional_sym_present_dict + + if optional_sym_present_dict[present_arg_name]: + # The argument is present. + intrinsic_call.replace_with(Literal("true", BOOLEAN_TYPE)) + else: + intrinsic_call.replace_with(Literal("false", BOOLEAN_TYPE)) + + def _optional_arg_specialize_ifblock_if_const_condition(self): + """Specialize if-block if conditions are constant booleans + + :rtype: None + """ def if_else_replace(main_schedule, if_block, if_body_schedule): """Little helper routine to eliminate one branch of an IfBlock @@ -353,7 +399,7 @@ def if_else_replace(main_schedule, if_block, if_body_schedule): from psyclone.psyir.nodes import IfBlock - for if_block in node.walk(IfBlock): + for if_block in self._routine_node.walk(IfBlock): if_block: IfBlock condition = if_block.condition @@ -383,49 +429,30 @@ def if_else_replace(main_schedule, if_block, if_body_schedule): if_else_replace(if_block.parent, if_block, if_block.else_body) - def _remove_unused_optional_arguments(self): - # We first build a lookup table of all optional arguments - # to see whether it's present or not. - optional_sym_present_dict: Dict[str, bool] = dict() - for optional_arg_idx, datasymbol in enumerate( - self.node_routine.symbol_table.datasymbols - ): - if not isinstance(datasymbol.datatype, UnsupportedFortranType): - continue + def _optional_arg_remove_unused_optional_arguments(self): + """Remove all optional arguments which are not used. - if ", OPTIONAL" not in str(datasymbol.datatype): - continue + Steps: - sym_name = datasymbol.name.lower() + - Build lookup dictionary for all optional arguments: - if optional_arg_idx not in self._ret_arg_match_list: - optional_sym_present_dict[sym_name] = False - else: - optional_sym_present_dict[sym_name] = True - - # Check if we have any optional arguments at all and if not, return - if len(optional_sym_present_dict) == 0: - return + - For all `PRESENT(...)`: + - Lookup variable in dictionary + - Replace with `True` or `False`, depending on whether + it's provided or not. - # Find all "PRESENT()" calls - for intrinsic_call in self.node_routine.walk(IntrinsicCall): - intrinsic_call: IntrinsicCall - if intrinsic_call.routine.name.lower() == "present": + - For all If-Statements, handle constant conditions: + - `True`: Replace If-Block with If-Body + - `False`: Replace If-Block with Else-Body. If it doesn't exist + just delete the if statement. - # The argument is in the 2nd child - present_arg: Reference = intrinsic_call.children[1] - present_arg_name = present_arg.name.lower() + :rtype: None + """ - assert present_arg_name in optional_sym_present_dict - - if optional_sym_present_dict[present_arg_name]: - # The argument is present. - intrinsic_call.replace_with(Literal("true", BOOLEAN_TYPE)) - else: - intrinsic_call.replace_with(Literal("false", BOOLEAN_TYPE)) + self._optional_arg_resolve_present_intrinsics() # Evaluate all if-blocks with constant booleans - self._remove_ifblock_if_const_args(self.node_routine) + self._optional_arg_specialize_ifblock_if_const_condition() def _replace_formal_arg(self, ref, call_node, formal_args): ''' @@ -783,128 +810,37 @@ def _replace_formal_struc_arg(self, actual_arg, ref, call_node, # Just an array reference. return ArrayReference.create(actual_arg.symbol, members[0][1]) - def validate( - self, - node_call: Call, - node_routine: Routine = None, - options: Dict[str, str] = None, - ): - """ - Checks that the supplied node is a valid target for inlining. - - :param call_node: target PSyIR node. - :type call_node: subclass of :py:class:`psyclone.psyir.nodes.Call` - :param routine_node: Routine to inline. - Default is to search for it. - :type routine_node: subclass of :py:class:`Routine` - :param options: a dictionary with options for transformations. - :type options: Optional[Dict[str, Any]] - :param bool options["force"]: whether or not to ignore any CodeBlocks - in the candidate routine. Default is False. - - :raises TransformationError: if the supplied node is not a Call or is - an IntrinsicCall or call to a PSyclone-generated routine. - :raises TransformationError: if the routine has a return value. - :raises TransformationError: if the routine body contains a Return - that is not the first or last statement. - :raises TransformationError: if the routine body contains a CodeBlock - and the 'force' option is not True. - :raises TransformationError: if the called routine has a named - argument. - :raises TransformationError: if any of the variables declared within - the called routine are of UnknownInterface. - :raises TransformationError: if any of the variables declared within - the called routine have a StaticInterface. - :raises TransformationError: if any of the subroutine arguments is of - UnsupportedType. - :raises TransformationError: if a symbol of a given name is imported - from different containers at the call site and within the routine. - :raises TransformationError: if the routine accesses an un-resolved - symbol. - :raises TransformationError: if the number of arguments in the call - does not match the number of formal arguments of the routine. - :raises TransformationError: if a symbol declared in the parent - container is accessed in the target routine. - :raises TransformationError: if the shape of an array formal argument - does not match that of the corresponding actual argument. - + def _validate_inline_arguments_of_call_and_routine( + self, + call_node: Call, + routine_node: Routine, + arg_index_list: List[int] + ): + """Performs various checks that the inlining is supported for the + combination of the call's and routine's arguments. + + :param call_node: Call to be replaced by the inlined Routine + :type call_node: Call + :param routine_node: Routine to be inlined + :type routine_node: Routine + :param arg_index_list: Argument index list to match the arguments of + the call to those of the routine in case of optional arguments. + :type arg_index_list: List[int] + :raises TransformationError: Arguments are not in a form to be inlined """ - super().validate(node_call, options=options) - - self.node_routine: Routine = node_routine - - # The node should be a Call. - if not isinstance(node_call, Call): - raise TransformationError( - "The target of the InlineTrans transformation " - f"should be a Call but found '{type(node_call).__name__}'." - ) - - if isinstance(node_call, IntrinsicCall): - raise TransformationError( - f"Cannot inline an IntrinsicCall ('{node_call.routine.name}')" - ) - name = node_call.routine.name - - # List of indices relating the call's arguments to the subroutine - # arguments. This can be different due to - # - optional arguments - # - named arguments - - if self.node_routine is None: - # Check that we can find the source of the routine being inlined. - # TODO #924 allow for multiple routines (interfaces). - try: - (self.node_routine, self._ret_arg_match_list) = ( - node_call.get_callee( - check_matching_arguments=( - self._option_check_matching_arguments_of_callee - ), - check_strict_array_datatype=( - self._option_check_argument_strict_array_datatype - ), - ignore_missing_modules=( - self._option_ignore_missing_modules - ), - ) - ) - except ( - CallMatchingArgumentsNotFoundError, - NotImplementedError, - FileNotFoundError, - SymbolError, - TransformationError, - ) as err: - raise TransformationError( - f"Cannot inline routine '{name}' because its source cannot" - f" be found:\n{str(err)}" - ) from err + + name = call_node.routine.name - else: - # A routine has been provided. - # We'll now determine the matching argument list - try: - self._ret_arg_match_list = ( - node_call._get_argument_routine_match( - self.node_routine, - check_strict_array_datatype=False, - ) - ) - except CallMatchingArgumentsNotFoundError as err: - raise TransformationError( - "Routine's arguments doesn't match subroutine" - ) from err - - if not self.node_routine.children or isinstance( - self.node_routine.children[0], Return + if not routine_node.children or isinstance( + routine_node.children[0], Return ): # An empty routine is fine. return - return_stmts = self.node_routine.walk(Return) + return_stmts = routine_node.walk(Return) if return_stmts: if len(return_stmts) > 1 or not isinstance( - self.node_routine.children[-1], Return + routine_node.children[-1], Return ): # Either there is more than one Return statement or there is # just one but it isn't the last statement of the Routine. @@ -913,7 +849,7 @@ def validate( f"Return statements and therefore cannot be inlined.") if self._option_check_codeblocks: - if self.node_routine.walk(CodeBlock): + if routine_node.walk(CodeBlock): # N.B. we permit the user to specify the "force" option to # allow CodeBlocks to be included. raise TransformationError( @@ -923,27 +859,20 @@ def validate( "`check_codeblocks=False` to override.)" ) - table = node_call.scope.symbol_table - routine_table = self.node_routine.symbol_table + table = call_node.scope.symbol_table + routine_table = routine_node.symbol_table - # TODO: Maybe move me to options - check_argument_unsupported_type = True - check_static_interface = True - check_diff_container_clashes = True - check_diff_container_clashes_unresolved_types = True - check_resolve_imports = True - check_array_type = True for sym in routine_table.datasymbols: # We don't inline symbols that have an UnsupportedType and are # arguments since we don't know if a simple assignment if # enough (e.g. pointers) - if check_argument_unsupported_type: + if self._option_check_argument_unsupported_type: if isinstance(sym.interface, ArgumentInterface): if isinstance(sym.datatype, UnsupportedType): if ", OPTIONAL" not in sym.datatype.declaration: raise TransformationError( - f"Routine '{self.node_routine.name}' cannot be" + f"Routine '{routine_node.name}' cannot be" " inlined because it contains a Symbol" f" '{sym.name}' which is an Argument of" " UnsupportedType:" @@ -953,13 +882,13 @@ def validate( # don't know how they are brought into this scope. if isinstance(sym.interface, UnknownInterface): raise TransformationError( - f"Routine '{self.node_routine.name}' cannot be " + f"Routine '{routine_node.name}' cannot be " "inlined because it contains a Symbol " f"'{sym.name}' with an UnknownInterface: " f"'{sym.datatype.declaration}'" ) - if check_static_interface: + if self._option_check_static_interface: # Check that there are no static variables in the routine # (because we don't know whether the routine is called from # other places). @@ -968,12 +897,12 @@ def validate( and not sym.is_constant ): raise TransformationError( - f"Routine '{self.node_routine.name}' cannot be " + f"Routine '{routine_node.name}' cannot be " "inlined because it has a static (Fortran SAVE) " f"interface for Symbol '{sym.name}'." ) - if check_diff_container_clashes: + if self._option_check_diff_container_clashes: # We can't handle a clash between (apparently) different symbols # that share a name but are imported from different containers. try: @@ -981,13 +910,13 @@ def validate( routine_table, symbols_to_skip=routine_table.argument_list[:], check_unresolved_symbols=( - check_diff_container_clashes_unresolved_types + self._option_check_diff_container_clashes_unresolved_types ), ) except SymbolError as err: raise TransformationError( "One or more symbols from routine " - f"'{self.node_routine.name}' cannot be added to the " + f"'{routine_node.name}' cannot be added to the " "table at the call site." ) from err @@ -999,7 +928,7 @@ def validate( # that are used to define the precision of other Symbols in the same # table. If a precision symbol is only used within Statements then we # don't currently capture the fact that it is a precision symbol. - ref_or_lits = self.node_routine.walk((Reference, Literal)) + ref_or_lits = routine_node.walk((Reference, Literal)) # Check for symbols in any initial-value expressions # (including Fortran parameters) or array dimensions. for sym in routine_table.datasymbols: @@ -1035,7 +964,7 @@ def validate( if isinstance(sym, IntrinsicSymbol): continue - if check_resolve_imports: + if self._option_check_resolve_imports: # We haven't seen this Symbol before. if sym.is_unresolved: try: @@ -1045,7 +974,7 @@ def validate( # table local to the routine. # pylint: disable=raise-missing-from raise TransformationError( - f"Routine '{self.node_routine.name}' cannot be " + f"Routine '{routine_node.name}' cannot be " "inlined because it accesses variable " f"'{sym.name}' and this cannot be found in any " "of the containers directly imported into its " @@ -1054,18 +983,18 @@ def validate( else: if sym.name not in routine_table: raise TransformationError( - f"Routine '{self.node_routine.name}' cannot be " + f"Routine '{routine_node.name}' cannot be " "inlined because it accesses variable " f"'{sym.name}' from its parent container." ) # Create a list of routine arguments that is actually used routine_arg_list = [ - routine_table.argument_list[i] for i in self._ret_arg_match_list + routine_table.argument_list[i] for i in arg_index_list ] for routine_arg, call_arg in zip( - routine_arg_list, node_call.arguments + routine_arg_list, call_node.arguments ): # If the formal argument is an array with non-default bounds then # we also need to know the bounds of that array at the call site. @@ -1085,7 +1014,7 @@ def validate( raise TransformationError( LazyString( lambda: ( - f"The call '{node_call.debug_string()}' " + f"The call '{call_node.debug_string()}' " "cannot be inlined because actual argument " f"'{call_arg.debug_string()}' corresponds to a " "formal argument with array type but is not a " @@ -1100,7 +1029,7 @@ def validate( # TODO #924. It would be useful if the `datatype` property was # a method that took an optional 'resolve' argument to indicate # that it should attempt to resolve any UnresolvedTypes. - if check_array_type: + if self._option_check_array_type: if isinstance( call_arg.datatype, (UnresolvedType, UnsupportedType) ) or ( @@ -1111,7 +1040,7 @@ def validate( ) ): raise TransformationError( - f"Routine '{self.node_routine.name}' cannot be " + f"Routine '{routine_node.name}' cannot be " "inlined because the type of the actual argument " f"'{call_arg.symbol.name}' corresponding to an array" f" formal argument ('{routine_arg.name}') is unknown." @@ -1132,7 +1061,7 @@ def validate( LazyString( lambda: ( "Cannot inline routine" - f" '{self.node_routine.name}' because it" + f" '{routine_node.name}' because it" " reshapes an argument: actual argument" f" '{call_arg.debug_string()}' has rank" f" {actual_rank} but the corresponding formal" @@ -1152,7 +1081,7 @@ def validate( LazyString( lambda: ( "Cannot inline routine" - f" '{self.node_routine.name}' because" + f" '{routine_node.name}' because" " argument" f" '{call_arg.debug_string()}' has" " an array range in an indirect" @@ -1169,7 +1098,7 @@ def validate( LazyString( lambda: ( "Cannot inline routine" - f" '{self.node_routine.name}' because" + f" '{routine_node.name}' because" " one of its arguments is an array" " slice with a non-unit stride:" f" '{call_arg.debug_string()}' (TODO" @@ -1178,6 +1107,117 @@ def validate( ) ) + def validate( + self, + call_node: Call, + routine_node: Routine = None, + options: Dict[str, str] = None, + ): + """ + Checks that the supplied node is a valid target for inlining. + + :param call_node: target PSyIR node. + :type call_node: subclass of :py:class:`psyclone.psyir.nodes.Call` + :param routine_node: Routine to inline. + Default is to search for it. + :type routine_node: subclass of :py:class:`Routine` + :param options: a dictionary with options for transformations. + :type options: Optional[Dict[str, Any]] + :param bool options["force"]: whether or not to ignore any CodeBlocks + in the candidate routine. Default is False. + + :raises TransformationError: if the supplied node is not a Call or is + an IntrinsicCall or call to a PSyclone-generated routine. + :raises TransformationError: if the routine has a return value. + :raises TransformationError: if the routine body contains a Return + that is not the first or last statement. + :raises TransformationError: if the routine body contains a CodeBlock + and the 'force' option is not True. + :raises TransformationError: if the called routine has a named + argument. + :raises TransformationError: if any of the variables declared within + the called routine are of UnknownInterface. + :raises TransformationError: if any of the variables declared within + the called routine have a StaticInterface. + :raises TransformationError: if any of the subroutine arguments is of + UnsupportedType. + :raises TransformationError: if a symbol of a given name is imported + from different containers at the call site and within the routine. + :raises TransformationError: if the routine accesses an un-resolved + symbol. + :raises TransformationError: if the number of arguments in the call + does not match the number of formal arguments of the routine. + :raises TransformationError: if a symbol declared in the parent + container is accessed in the target routine. + :raises TransformationError: if the shape of an array formal argument + does not match that of the corresponding actual argument. + + """ + super().validate(call_node, options=options) + + self._call_node = call_node + self._routine_node = routine_node + + # The node should be a Call. + if not isinstance(self._call_node, Call): + raise TransformationError( + "The target of the InlineTrans transformation should" + f" be a Call but found '{type(self._call_node).__name__}'." + ) + + call_name = self._call_node.routine.name + if isinstance(self._call_node, IntrinsicCall): + raise TransformationError( + f"Cannot inline an IntrinsicCall ('{call_name}')" + ) + + # List of indices relating the call's arguments to the subroutine + # arguments. This can be different due to + # - optional arguments + # - named arguments + + from psyclone.psyir.tools import CallMatchingArgumentsNotFoundError + + self._call_routine_matcher.set_call_node(self._call_node) + + if self._routine_node is None: + # Check that we can find the source of the routine being inlined. + # TODO #924 allow for multiple routines (interfaces). + try: + (self._routine_node, self._ret_arg_match_list) = \ + self._call_routine_matcher.get_callee() + except ( + CallMatchingArgumentsNotFoundError, + NotImplementedError, + FileNotFoundError, + SymbolError, + TransformationError, + ) as err: + raise TransformationError( + f"Cannot inline routine '{call_name}' because its source" + f" cannot be found:\n{str(err)}" + ) from err + + else: + # A routine has been provided. + # Therefore, we just determine the matching argument list + # if it matches. + try: + rm = self._call_routine_matcher + rm.set_routine_node(self._routine_node) + rm.set_option( + check_argument_strict_array_datatype=False) + self._ret_arg_match_list = ( + rm.get_argument_routine_match_list() + ) + except CallMatchingArgumentsNotFoundError as err: + raise TransformationError( + "Routine's arguments doesn't match subroutine" + ) from err + + self._validate_inline_arguments_of_call_and_routine( + call_node, self._routine_node, self._ret_arg_match_list) + # For AutoAPI auto-documentation generation. __all__ = ["InlineTrans"] diff --git a/src/psyclone/psyir/transformations/omp_task_trans.py b/src/psyclone/psyir/transformations/omp_task_trans.py index bd096a9c5c..c1d341b59f 100644 --- a/src/psyclone/psyir/transformations/omp_task_trans.py +++ b/src/psyclone/psyir/transformations/omp_task_trans.py @@ -116,7 +116,7 @@ def validate(self, node, options=None): cond_trans = FoldConditionalReturnExpressionsTrans() intrans = InlineTrans() intrans.set_option( - check_matching_arguments_of_callee=( + check_argument_matching=( self._option_check_matching_arguments_of_callee ) ) @@ -180,7 +180,7 @@ def _inline_kernels(self, node): cond_trans = FoldConditionalReturnExpressionsTrans() intrans = InlineTrans() intrans.set_option( - check_matching_arguments_of_callee=( + check_argument_matching=( self._option_check_matching_arguments_of_callee ) ) diff --git a/src/psyclone/tests/psyir/nodes/call_test.py b/src/psyclone/tests/psyir/nodes/call_test.py index 4fc30e443a..1e498df5ad 100644 --- a/src/psyclone/tests/psyir/nodes/call_test.py +++ b/src/psyclone/tests/psyir/nodes/call_test.py @@ -66,7 +66,7 @@ ) from psyclone.errors import GenerationError -from psyclone.psyir.nodes.call import CallMatchingArgumentsNotFoundError +from psyclone.psyir.tools.call_routine_matcher import CallMatchingArgumentsNotFoundError class SpecialCall(Call): diff --git a/src/psyclone/tests/psyir/transformations/inline_trans_test.py b/src/psyclone/tests/psyir/transformations/inline_trans_test.py index c359c79433..2fde975a34 100644 --- a/src/psyclone/tests/psyir/transformations/inline_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/inline_trans_test.py @@ -268,7 +268,7 @@ def test_apply_gocean_kern(fortran_reader, fortran_writer, monkeypatch): monkeypatch.setattr(Config.get(), '_include_paths', [str(src_dir)]) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) with pytest.raises(TransformationError) as err: inline_trans.apply(psyir.walk(Call)[0]) if ("actual argument 'cu_fld' corresponding to an array formal " @@ -326,7 +326,7 @@ def test_apply_struct_arg(fortran_reader, fortran_writer, tmpdir): f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(routine) @@ -399,7 +399,7 @@ def test_apply_unresolved_struct_arg(fortran_reader, fortran_writer): "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) calls = psyir.walk(Call) # First one should be fine. inline_trans.apply(calls[0]) @@ -462,7 +462,7 @@ def test_apply_struct_slice_arg(fortran_reader, fortran_writer, tmpdir): f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(routine) output = fortran_writer(psyir) @@ -500,7 +500,7 @@ def test_apply_struct_local_limits_caller(fortran_reader, fortran_writer, f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(routine) output = fortran_writer(psyir) @@ -545,7 +545,7 @@ def test_apply_struct_local_limits_caller_decln(fortran_reader, fortran_writer, f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(routine) output = fortran_writer(psyir) @@ -599,7 +599,7 @@ def test_apply_struct_local_limits_routine(fortran_reader, fortran_writer, f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(routine) output = fortran_writer(psyir) @@ -654,7 +654,7 @@ def test_apply_array_limits_are_formal_args(fortran_reader, fortran_writer): ''' psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) acall = psyir.walk(Call, stop_type=Call)[0] inline_trans.apply(acall) output = fortran_writer(psyir) @@ -701,7 +701,7 @@ def test_apply_allocatable_array_arg(fortran_reader, fortran_writer): ) psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) for routine in psyir.walk(Routine)[0].walk(Call, stop_type=Call): if not isinstance(routine, IntrinsicCall): inline_trans.apply(routine) @@ -762,7 +762,7 @@ def test_apply_array_slice_arg(fortran_reader, fortran_writer, tmpdir): "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) for call in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(call) output = fortran_writer(psyir) @@ -813,7 +813,7 @@ def test_apply_struct_array_arg(fortran_reader, fortran_writer, tmpdir): psyir = fortran_reader.psyir_from_source(code) loops = psyir.walk(Loop) inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) inline_trans.apply(loops[0].loop_body.children[1]) inline_trans.apply(loops[1].loop_body.children[1]) inline_trans.apply(loops[2].loop_body.children[1]) @@ -869,7 +869,7 @@ def test_apply_struct_array_slice_arg(fortran_reader, fortran_writer, tmpdir): f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) for call in psyir.walk(Call): if not isinstance(call, IntrinsicCall): if call.arguments[0].debug_string() == "grid%local%data": @@ -942,7 +942,7 @@ def test_apply_struct_array(fortran_reader, fortran_writer, tmpdir, f"end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) if "use some_mod" in type_decln: with pytest.raises(TransformationError) as err: inline_trans.apply(psyir.walk(Call)[0]) @@ -988,7 +988,7 @@ def test_apply_repeated_module_use(fortran_reader, fortran_writer): "end module test_mod\n") psyir = fortran_reader.psyir_from_source(code) inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) for call in psyir.walk(Routine)[0].walk(Call, stop_type=Call): inline_trans.apply(call) output = fortran_writer(psyir) @@ -1701,7 +1701,7 @@ def test_validate_codeblock(fortran_reader): "cannot be inlined. (If you are confident " in str(err.value) ) - inline_trans.set_option(check_codeblocks=False) + inline_trans.set_option(check_inline_codeblocks=False) inline_trans.validate(call) @@ -2075,7 +2075,7 @@ def test_validate_array_reshape(fortran_reader): psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) with pytest.raises(TransformationError) as err: inline_trans.validate(call) assert ("Cannot inline routine 's' because it reshapes an argument: actual" @@ -2108,7 +2108,7 @@ def test_validate_array_arg_expression(fortran_reader): psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) with pytest.raises(TransformationError) as err: inline_trans.validate(call) assert ("The call 'call s(a + b, 10)\n' cannot be inlined because actual " @@ -2135,7 +2135,7 @@ def test_validate_indirect_range(fortran_reader): psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) with pytest.raises(TransformationError) as err: inline_trans.validate(call) assert ("Cannot inline routine 'sub' because argument 'var(indices(:))' " @@ -2160,7 +2160,7 @@ def test_validate_non_unit_stride_slice(fortran_reader): psyir = fortran_reader.psyir_from_source(code) call = psyir.walk(Call)[0] inline_trans = InlineTrans() - inline_trans.set_option(check_matching_arguments_of_callee=False) + inline_trans.set_option(check_argument_matching=False) with pytest.raises(TransformationError) as err: inline_trans.validate(call) assert ("Cannot inline routine 'sub' because one of its arguments is an " From c842289a721d54fbf916fbad5b1777df60ec62c8 Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Sun, 24 Nov 2024 17:00:13 +0100 Subject: [PATCH 08/20] further cleanups + working on coverage test --- .coverage.mardinateur.37407.XOevDdvx.wgw0 | 0 .coverage.mardinateur.37413.XOWfjlOx.wgw2 | 0 src/psyclone/psyir/nodes/__init__.py | 184 ++++---- src/psyclone/psyir/nodes/call.py | 163 ++++---- .../psyir/tools/call_routine_matcher.py | 162 +------- .../psyir/transformations/inline_trans.py | 393 +++++++++++------- src/psyclone/tests/psyir/nodes/call_test.py | 3 +- .../transformations/inline_trans_test.py | 124 +++++- 8 files changed, 534 insertions(+), 495 deletions(-) delete mode 100644 .coverage.mardinateur.37407.XOevDdvx.wgw0 delete mode 100644 .coverage.mardinateur.37413.XOWfjlOx.wgw2 diff --git a/.coverage.mardinateur.37407.XOevDdvx.wgw0 b/.coverage.mardinateur.37407.XOevDdvx.wgw0 deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/.coverage.mardinateur.37413.XOWfjlOx.wgw2 b/.coverage.mardinateur.37413.XOWfjlOx.wgw2 deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/psyclone/psyir/nodes/__init__.py b/src/psyclone/psyir/nodes/__init__.py index 69f066cc9a..ba2c6dd1e7 100644 --- a/src/psyclone/psyir/nodes/__init__.py +++ b/src/psyclone/psyir/nodes/__init__.py @@ -104,96 +104,96 @@ # The entities in the __all__ list are made available to import directly from # this package e.g. 'from psyclone.psyir.nodes import Literal' __all__ = [ - "colored", - "ArrayMember", - "ArrayReference", - "ArrayOfStructuresMember", - "ArrayOfStructuresReference", - "Assignment", - "BinaryOperation", - "Call", - "Clause", - "CodeBlock", - "Container", - "DataNode", - "FileContainer", - "IfBlock", - "IntrinsicCall", - "Literal", - "Loop", - "Member", - "Node", - "OperandClause", - "Operation", - "Range", - "Reference", - "Return", - "Routine", - "Schedule", - "Statement", - "StructureMember", - "StructureReference", - "UnaryOperation", - "ScopingNode", - "WhileLoop", - # PSyclone-specific nodes - "KernelSchedule", - # PSyData Nodes - "PSyDataNode", - "ExtractNode", - "ProfileNode", - "ReadOnlyVerifyNode", - "ValueRangeCheckNode", - # Directive Nodes - "Directive", - "RegionDirective", - "StandaloneDirective", - # OpenACC Directive Nodes - "ACCAtomicDirective", - "ACCDirective", - "ACCRegionDirective", - "ACCStandaloneDirective", - "ACCDataDirective", - "ACCEnterDataDirective", - "ACCParallelDirective", - "ACCLoopDirective", - "ACCKernelsDirective", - "ACCUpdateDirective", - "ACCRoutineDirective", - # OpenACC Clause Nodes - "ACCCopyClause", - "ACCCopyInClause", - "ACCCopyOutClause", - # OpenMP Directive Nodes - "OMPAtomicDirective", - "OMPDirective", - "OMPRegionDirective", - "OMPStandaloneDirective", - "OMPParallelDirective", - "OMPSerialDirective", - "OMPSingleDirective", - "OMPMasterDirective", - "OMPTaskloopDirective", - "OMPTaskDirective", - "DynamicOMPTaskDirective", - "OMPDoDirective", - "OMPParallelDoDirective", - "OMPTaskwaitDirective", - "OMPTargetDirective", - "OMPLoopDirective", - "OMPDeclareTargetDirective", - "OMPSimdDirective", - "OMPTeamsDistributeParallelDoDirective", - # OMP Clause Nodes - "OMPGrainsizeClause", - "OMPNogroupClause", - "OMPNowaitClause", - "OMPNumTasksClause", - "OMPPrivateClause", - "OMPDefaultClause", - "OMPReductionClause", - "OMPScheduleClause", - "OMPFirstprivateClause", - "OMPSharedClause", - "OMPDependClause", + 'colored', + 'ArrayMember', + 'ArrayReference', + 'ArrayOfStructuresMember', + 'ArrayOfStructuresReference', + 'Assignment', + 'BinaryOperation', + 'Call', + 'Clause', + 'CodeBlock', + 'Container', + 'DataNode', + 'FileContainer', + 'IfBlock', + 'IntrinsicCall', + 'Literal', + 'Loop', + 'Member', + 'Node', + 'OperandClause', + 'Operation', + 'Range', + 'Reference', + 'Return', + 'Routine', + 'Schedule', + 'Statement', + 'StructureMember', + 'StructureReference', + 'UnaryOperation', + 'ScopingNode', + 'WhileLoop', + # PSyclone-specific nodes + 'KernelSchedule', + # PSyData Nodes + 'PSyDataNode', + 'ExtractNode', + 'ProfileNode', + 'ReadOnlyVerifyNode', + 'ValueRangeCheckNode', + # Directive Nodes + 'Directive', + 'RegionDirective', + 'StandaloneDirective', + # OpenACC Directive Nodes + 'ACCAtomicDirective', + 'ACCDirective', + 'ACCRegionDirective', + 'ACCStandaloneDirective', + 'ACCDataDirective', + 'ACCEnterDataDirective', + 'ACCParallelDirective', + 'ACCLoopDirective', + 'ACCKernelsDirective', + 'ACCUpdateDirective', + 'ACCRoutineDirective', + # OpenACC Clause Nodes + 'ACCCopyClause', + 'ACCCopyInClause', + 'ACCCopyOutClause', + # OpenMP Directive Nodes + 'OMPAtomicDirective', + 'OMPDirective', + 'OMPRegionDirective', + 'OMPStandaloneDirective', + 'OMPParallelDirective', + 'OMPSerialDirective', + 'OMPSingleDirective', + 'OMPMasterDirective', + 'OMPTaskloopDirective', + 'OMPTaskDirective', + 'DynamicOMPTaskDirective', + 'OMPDoDirective', + 'OMPParallelDoDirective', + 'OMPTaskwaitDirective', + 'OMPTargetDirective', + 'OMPLoopDirective', + 'OMPDeclareTargetDirective', + 'OMPSimdDirective', + 'OMPTeamsDistributeParallelDoDirective', + # OMP Clause Nodes + 'OMPGrainsizeClause', + 'OMPNogroupClause', + 'OMPNowaitClause', + 'OMPNumTasksClause', + 'OMPPrivateClause', + 'OMPDefaultClause', + 'OMPReductionClause', + 'OMPScheduleClause', + 'OMPFirstprivateClause', + 'OMPSharedClause', + 'OMPDependClause', ] diff --git a/src/psyclone/psyir/nodes/call.py b/src/psyclone/psyir/nodes/call.py index d1f7ff93fc..f587520430 100644 --- a/src/psyclone/psyir/nodes/call.py +++ b/src/psyclone/psyir/nodes/call.py @@ -38,24 +38,13 @@ from collections.abc import Iterable -from psyclone.configuration import Config from psyclone.core import AccessType from psyclone.errors import GenerationError -from psyclone.psyir.nodes.container import Container from psyclone.psyir.nodes.statement import Statement from psyclone.psyir.nodes.datanode import DataNode from psyclone.psyir.nodes.reference import Reference -from psyclone.psyir.nodes.routine import Routine -from psyclone.psyir.symbols import ( - RoutineSymbol, - Symbol, - SymbolError, - UnsupportedFortranType, - DataSymbol, - SymbolTable, - ContainerSymbol, -) -from typing import List, Union +from psyclone.psyir.symbols import RoutineSymbol +from typing import List class Call(Statement, DataNode): @@ -463,78 +452,78 @@ def copy(self): return new_copy - def _get_container_symbols_rec( - self, - container_symbols_list: List[str], - ignore_missing_modules: bool = False, - _stack_container_name_list: List[str] = [], - _depth: int = 0, - ): - '''Return a list of all container symbols that can be found - recursively - - :param container_symbols: List of starting set of container symbols - :type container_symbols: List[ContainerSymbol] - :param _stack_container_list: Stack with already visited Containers - to avoid circular searches, defaults to [] - :type _stack_container_list: List[Container], optional - :param _depth: Depth of recursive search - :type _depth: int - ''' - # - # TODO: - # - This function seems to be extremely slow: - # It takes considerable time to build this list over and over - # for each lookup. - # - This function can also be written in a non-resursive way - # - # An alternative would be to cache it, but then the cache - # needs to be invalidated once some symbols are, e.g., deleted. - # - ret_container_symbol_list = container_symbols_list[:] - - # Cache the container names from symbols - container_names = [cs.name.lower() for cs in container_symbols_list] - - from psyclone.parse import ModuleManager - - module_manager = ModuleManager.get() - - for container_name in container_names: - try: - module_info = module_manager.get_module_info( - container_name.lower() - ) - if module_info is None: - continue - - except (ModuleNotFoundError, FileNotFoundError) as err: - if ignore_missing_modules: - continue - - raise err - - container: Container = module_info.get_psyir_container_node() - - # Avoid circular connections (which shouldn't - # be allowed, but who knows...) - if container.name.lower() in _stack_container_name_list: - continue - - new_container_symbols = self._get_container_symbols_rec( - container_symbols_list=container.symbol_table.containersymbols, - ignore_missing_modules=ignore_missing_modules, - _stack_container_name_list=_stack_container_name_list - + [container.name.lower()], - _depth=_depth + 1, - ) - - # Add symbol if it's not yet in the list of symbols - for container_symbol in new_container_symbols: - if container_symbol not in ret_container_symbol_list: - ret_container_symbol_list.append(container_symbol) - - return ret_container_symbol_list + # def _get_container_symbols_rec( + # self, + # container_symbols_list: List[str], + # ignore_missing_modules: bool = False, + # _stack_container_name_list: List[str] = [], + # _depth: int = 0, + # ): + # '''Return a list of all container symbols that can be found + # recursively + + # :param container_symbols: List of starting set of container symbols + # :type container_symbols: List[ContainerSymbol] + # :param _stack_container_list: Stack with already visited Containers + # to avoid circular searches, defaults to [] + # :type _stack_container_list: List[Container], optional + # :param _depth: Depth of recursive search + # :type _depth: int + # ''' + # # + # # TODO: + # # - This function seems to be extremely slow: + # # It takes considerable time to build this list over and over + # # for each lookup. + # # - This function can also be written in a non-resursive way + # # + # # An alternative would be to cache it, but then the cache + # # needs to be invalidated once some symbols are, e.g., deleted. + # # + # ret_container_symbol_list = container_symbols_list[:] + + # # Cache the container names from symbols + # container_names = [cs.name.lower() for cs in container_symbols_list] + + # from psyclone.parse import ModuleManager + + # module_manager = ModuleManager.get() + + # for container_name in container_names: + # try: + # module_info = module_manager.get_module_info( + # container_name.lower() + # ) + # if module_info is None: + # continue + + # except (ModuleNotFoundError, FileNotFoundError) as err: + # if ignore_missing_modules: + # continue + + # raise err + + # container: Container = module_info.get_psyir_container_node() + + # # Avoid circular connections (which shouldn't + # # be allowed, but who knows...) + # if container.name.lower() in _stack_container_name_list: + # continue + + # new_container_symbols = self._get_container_symbols_rec( + # container_symbols_list=container.symbol_table.containersymbols, + # ignore_missing_modules=ignore_missing_modules, + # _stack_container_name_list=_stack_container_name_list + # + [container.name.lower()], + # _depth=_depth + 1, + # ) + + # # Add symbol if it's not yet in the list of symbols + # for container_symbol in new_container_symbols: + # if container_symbol not in ret_container_symbol_list: + # ret_container_symbol_list.append(container_symbol) + + # return ret_container_symbol_list def get_callees(self, ignore_missing_modules: bool = False): ''' @@ -597,10 +586,10 @@ def get_callee( call_routine_matcher: CallRoutineMatcher = CallRoutineMatcher(self) call_routine_matcher.set_option( check_matching_arguments=check_matching_arguments, - check_argument_strict_array_datatype=check_strict_array_datatype, + check_argument_strict_array_datatype=( + check_strict_array_datatype), ignore_missing_modules=ignore_missing_modules, ignore_unresolved_symbol=ignore_unresolved_symbol, ) return call_routine_matcher.get_callee() - diff --git a/src/psyclone/psyir/tools/call_routine_matcher.py b/src/psyclone/psyir/tools/call_routine_matcher.py index 780e3300c4..1869c064b2 100644 --- a/src/psyclone/psyir/tools/call_routine_matcher.py +++ b/src/psyclone/psyir/tools/call_routine_matcher.py @@ -40,12 +40,10 @@ from typing import List, Union from psyclone.psyir.symbols.datatypes import ArrayType -from psyclone.psyir.nodes import Call, Reference, Routine +from psyclone.psyir.nodes import Call, Routine from psyclone.errors import PSycloneError from psyclone.configuration import Config from psyclone.psyir.nodes.container import Container -from psyclone.psyir.nodes.reference import Reference -from psyclone.psyir.nodes.routine import Routine from psyclone.psyir.symbols import ( RoutineSymbol, Symbol, @@ -56,9 +54,11 @@ ContainerSymbol, ) + class CallMatchingArgumentsNotFoundError(PSycloneError): """Exception to signal that matching arguments have not been found for this routine + """ def __init__(self, value): @@ -112,13 +112,14 @@ def set_option(self, check_argument_strict_array_datatype: bool = None, ignore_missing_modules: bool = None, ignore_unresolved_symbol: bool = None, - ): + ): if check_matching_arguments is not None: self._option_check_matching_arguments = check_matching_arguments if check_argument_strict_array_datatype is not None: - self._option_check_strict_array_datatype = check_argument_strict_array_datatype + self._option_check_strict_array_datatype = ( + check_argument_strict_array_datatype) if ignore_missing_modules is not None: self._option_ignore_missing_modules = ignore_missing_modules @@ -126,152 +127,6 @@ def set_option(self, if ignore_unresolved_symbol is not None: self._option_ignore_unresolved_symbol = ignore_unresolved_symbol - def _check_inline_types( - self, - call_arg: DataSymbol, - routine_arg: DataSymbol, - check_array_type: bool = True, - ): - """This function performs tests to see whether the - inlining can cope with it. - - :param call_arg: The argument of a call - :type call_arg: DataSymbol - :param routine_arg: The argument of a routine - :type routine_arg: DataSymbol - :param check_array_type: Perform strong checks on array types, - defaults to `True` - :type check_array_type: bool, optional - - :raises TransformationError: Raised if transformation can't be done - - :return: 'True' if checks are successful - :rtype: bool - """ - from psyclone.psyir.transformations.transformation_error import ( - TransformationError, - ) - from psyclone.errors import LazyString - from psyclone.psyir.nodes import Literal, Range - from psyclone.psyir.symbols import ( - UnresolvedType, - UnsupportedType, - INTEGER_TYPE, - ) - - _ONE = Literal("1", INTEGER_TYPE) - - # If the formal argument is an array with non-default bounds then - # we also need to know the bounds of that array at the call site. - if not isinstance(routine_arg.datatype, ArrayType): - # Formal argument is not an array so we don't need to do any - # further checks. - return True - - if not isinstance(call_arg, (Reference, Literal)): - # TODO #1799 this really needs the `datatype` method to be - # extended to support all nodes. For now we have to abort - # if we encounter an argument that is not a scalar (according - # to the corresponding formal argument) but is not a - # Reference or a Literal as we don't know whether the result - # of any general expression is or is not an array. - # pylint: disable=cell-var-from-loop - raise TransformationError( - LazyString( - lambda: ( - f"The call '{self._call_node.debug_string()}' " - "cannot be inlined because actual argument " - f"'{call_arg.debug_string()}' corresponds to a " - "formal argument with array type but is not a " - "Reference or a Literal." - ) - ) - ) - - # We have an array argument. We are only able to check that the - # argument is not re-shaped in the called routine if we have full - # type information on the actual argument. - # TODO #924. It would be useful if the `datatype` property was - # a method that took an optional 'resolve' argument to indicate - # that it should attempt to resolve any UnresolvedTypes. - if check_array_type: - if isinstance( - call_arg.datatype, (UnresolvedType, UnsupportedType) - ) or ( - isinstance(call_arg.datatype, ArrayType) - and isinstance( - call_arg.datatype.intrinsic, - (UnresolvedType, UnsupportedType), - ) - ): - raise TransformationError( - f"Routine '{self._routine_node.name}' cannot be " - "inlined because the type of the actual argument " - f"'{call_arg.symbol.name}' corresponding to an array" - f" formal argument ('{routine_arg.name}') is unknown." - ) - - formal_rank = 0 - actual_rank = 0 - if isinstance(routine_arg.datatype, ArrayType): - formal_rank = len(routine_arg.datatype.shape) - if isinstance(call_arg.datatype, ArrayType): - actual_rank = len(call_arg.datatype.shape) - if formal_rank != actual_rank: - # It's OK to use the loop variable in the lambda definition - # because if we get to this point then we're going to quit - # the loop. - # pylint: disable=cell-var-from-loop - raise TransformationError( - LazyString( - lambda: ( - "Cannot inline routine" - f" '{self._routine_node.name}' because it" - " reshapes an argument: actual argument" - f" '{call_arg.debug_string()}' has rank" - f" {actual_rank} but the corresponding formal" - f" argument, '{routine_arg.name}', has rank" - f" {formal_rank}" - ) - ) - ) - if actual_rank: - ranges = call_arg.walk(Range) - for rge in ranges: - ancestor_ref = rge.ancestor(Reference) - if ancestor_ref is not call_arg: - # Have a range in an indirect access. - # pylint: disable=cell-var-from-loop - raise TransformationError( - LazyString( - lambda: ( - "Cannot inline routine" - f" '{self._routine_node.name}' because" - " argument" - f" '{call_arg.debug_string()}' has" - " an array range in an indirect" - " access #(TODO 924)." - ) - ) - ) - if rge.step != _ONE: - # TODO #1646. We could resolve this problem by - # making a new array and copying the necessary - # values into it. - # pylint: disable=cell-var-from-loop - raise TransformationError( - LazyString( - lambda: ( - "Cannot inline routine" - f" '{self._routine_node.name}' because" - " one of its arguments is an array" - " slice with a non-unit stride:" - f" '{call_arg.debug_string()}' (TODO" - " #1646)" - ) - ) - ) - def _check_argument_type_matches( self, call_arg: DataSymbol, @@ -295,8 +150,6 @@ def _check_argument_type_matches( were found. """ - # self._check_inline_types(call_arg, routine_arg) - type_matches = False if not check_strict_array_datatype: # No strict array checks have to be performed, just accept it @@ -491,7 +344,8 @@ def _location_txt(node): # It would be better using the ModuleManager to resolve # (and cache) all containers to look up for this. # - # current_containersymbols = self._call_node._get_container_symbols_rec( + # current_containersymbols = + # self._call_node._get_container_symbols_rec( # current_table.containersymbols, # ignore_missing_modules=ignore_missing_modules, # ) diff --git a/src/psyclone/psyir/transformations/inline_trans.py b/src/psyclone/psyir/transformations/inline_trans.py index 7de074efdb..5c940656dd 100644 --- a/src/psyclone/psyir/transformations/inline_trans.py +++ b/src/psyclone/psyir/transformations/inline_trans.py @@ -37,7 +37,6 @@ This module contains the InlineTrans transformation. ''' -from psyclone.errors import LazyString from psyclone.psyGen import Transformation from psyclone.psyir.nodes import ( ArrayReference, ArrayOfStructuresReference, BinaryOperation, Call, @@ -48,7 +47,6 @@ ArgumentInterface, ArrayType, DataSymbol, - UnresolvedType, INTEGER_TYPE, StaticInterface, SymbolError, @@ -162,7 +160,7 @@ def __init__(self): self._option_check_codeblocks: bool = True self._option_check_diff_container_clashes: bool = True - self._option_check_diff_container_clashes_unresolved_types: bool = True + self._option_check_diff_container_clashes_unres_types: bool = True self._option_check_resolve_imports: bool = True self._option_check_static_interface: bool = True self._option_check_array_type: bool = True @@ -174,7 +172,15 @@ def set_option( ignore_missing_modules: bool = None, check_argument_strict_array_datatype: bool = None, check_argument_matching: bool = None, + check_inline_codeblocks: bool = None, + check_diff_container_clashes: bool = None, + check_diff_container_clashes_unres_types: bool = None, + check_resolve_imports: bool = None, + check_static_interface: bool = None, + check_array_type: bool = None, + check_argument_of_unsupported_type: bool = None, + check_argument_unresolved_symbols: bool = None, ): if check_argument_strict_array_datatype is not None: self._option_check_argument_strict_array_datatype = ( @@ -191,6 +197,34 @@ def set_option( if check_inline_codeblocks is not None: self._option_check_codeblocks = check_inline_codeblocks + if check_diff_container_clashes is not None: + self._option_check_diff_container_clashes = ( + check_diff_container_clashes) + + if check_diff_container_clashes_unres_types is not None: + self._option_check_diff_container_clashes_unres_types = ( + check_diff_container_clashes_unres_types + ) + + if check_resolve_imports is not None: + self._option_check_resolve_imports = check_resolve_imports + + if check_static_interface is not None: + self._option_check_static_interface = check_static_interface + + if check_array_type is not None: + self._option_check_array_type = check_array_type + + if check_argument_of_unsupported_type is not None: + self._option_check_argument_of_unsupported_type = ( + check_argument_of_unsupported_type + ) + + if check_argument_unresolved_symbols is not None: + self._option_check_argument_unresolved_symbols = ( + check_argument_unresolved_symbols + ) + def apply( self, call_node: Call, routine_node: Routine = None, options=None ): @@ -230,7 +264,23 @@ def apply( # copy of it. self._routine_node = self._routine_node.copy() routine_table = self._routine_node.symbol_table - self._optional_arg_remove_unused_optional_arguments() + + # Next, we remove all optional arguments which are not used. + # Step 1) + # - Build lookup dictionary for all optional arguments: + + # - For all `PRESENT(...)`: + # - Lookup variable in dictionary + # - Replace with `True` or `False`, depending on whether + # it's provided or not. + self._optional_arg_resolve_present_intrinsics() + + # Step 2) + # - For all If-Statements, handle constant conditions: + # - `True`: Replace If-Block with If-Body + # - `False`: Replace If-Block with Else-Body. If it doesn't exist + # just delete the if statement. + self._optional_arg_eliminate_ifblock_if_const_condition() # Construct lists of the nodes that will be inserted and all of the # References that they contain. @@ -320,7 +370,9 @@ def apply( replacement.attach(scope) def _optional_arg_resolve_present_intrinsics(self): - """Replace PRESENT() intrinsics with `True` or `False` + """Replace PRESENT(some_argument) intrinsics in routine with constant + booleans depending on whether `some_argument` has been provided + (`True`) or not (`False`). :rtype: None """ @@ -364,8 +416,12 @@ def _optional_arg_resolve_present_intrinsics(self): else: intrinsic_call.replace_with(Literal("false", BOOLEAN_TYPE)) - def _optional_arg_specialize_ifblock_if_const_condition(self): - """Specialize if-block if conditions are constant booleans + def _optional_arg_eliminate_ifblock_if_const_condition(self): + """Eliminate if-block if conditions are constant booleans. + + TODO: This also requires support of conditions containing logical + expressions such as `(.true. .or. .false.)` + TODO: This could also become a Psyclone transformation. :rtype: None """ @@ -404,20 +460,19 @@ def if_else_replace(main_schedule, if_block, if_body_schedule): condition = if_block.condition - # Check if the condition is a BooleanLiteral + # Make sure we only handle a BooleanLiteral as a condition + # TODO #2802 if not isinstance(condition, Literal): continue - # Check for right datatype - if ( + # Check that it's a boolean Literal + assert ( condition.datatype.intrinsic - is not ScalarType.Intrinsic.BOOLEAN - ): - continue + is ScalarType.Intrinsic.BOOLEAN + ), "Found non-boolean expression in conditional of if branch" if condition.value == "true": # Only keep if_block - if_else_replace(if_block.parent, if_block, if_block.if_body) else: @@ -427,33 +482,9 @@ def if_else_replace(main_schedule, if_block, if_body_schedule): if_block.detach() continue + # Only keep else block if_else_replace(if_block.parent, if_block, if_block.else_body) - def _optional_arg_remove_unused_optional_arguments(self): - """Remove all optional arguments which are not used. - - Steps: - - - Build lookup dictionary for all optional arguments: - - - For all `PRESENT(...)`: - - Lookup variable in dictionary - - Replace with `True` or `False`, depending on whether - it's provided or not. - - - For all If-Statements, handle constant conditions: - - `True`: Replace If-Block with If-Body - - `False`: Replace If-Block with Else-Body. If it doesn't exist - just delete the if statement. - - :rtype: None - """ - - self._optional_arg_resolve_present_intrinsics() - - # Evaluate all if-blocks with constant booleans - self._optional_arg_specialize_ifblock_if_const_condition() - def _replace_formal_arg(self, ref, call_node, formal_args): ''' Recursively combines any References to formal arguments in the supplied @@ -489,7 +520,17 @@ def _replace_formal_arg(self, ref, call_node, formal_args): # Lookup index of actual argument # If this is an optional argument, but not used, this index lookup # shouldn't fail - actual_arg_idx = self._ret_arg_match_list.index(routine_arg_idx) + try: + actual_arg_idx = self._ret_arg_match_list.index(routine_arg_idx) + except ValueError as err: + arg_list = self._routine_node.symbol_table.argument_list + arg_name = arg_list[routine_arg_idx].name + raise TransformationError( + f"Subroutine argument '{arg_name}' is not provided by call," + f" but used in the subroutine." + f" If this is correct code, this is likely due to" + f" some non-eliminated if-branches using `PRESENT(...)` as" + f" conditional (TODO #2802).") from err # Lookup the actual argument that corresponds to this formal argument. actual_arg = call_node.arguments[actual_arg_idx] @@ -810,7 +851,149 @@ def _replace_formal_struc_arg(self, actual_arg, ref, call_node, # Just an array reference. return ArrayReference.create(actual_arg.symbol, members[0][1]) - def _validate_inline_arguments_of_call_and_routine( + def _validate_inline_of_call_and_routine_argument_pairs( + self, + call_arg: DataSymbol, + routine_arg: DataSymbol + ) -> bool: + """This function performs tests to see whether the + inlining can cope with it. + + :param call_arg: The argument of a call + :type call_arg: DataSymbol + :param routine_arg: The argument of a routine + :type routine_arg: DataSymbol + + :raises TransformationError: Raised if transformation can't be done + + :return: 'True' if checks are successful + :rtype: bool + """ + from psyclone.psyir.transformations.transformation_error import ( + TransformationError, + ) + from psyclone.errors import LazyString + from psyclone.psyir.nodes import Literal, Range + from psyclone.psyir.symbols import ( + UnresolvedType, + UnsupportedType, + INTEGER_TYPE, + ) + + _ONE = Literal("1", INTEGER_TYPE) + + # If the formal argument is an array with non-default bounds then + # we also need to know the bounds of that array at the call site. + if not isinstance(routine_arg.datatype, ArrayType): + # Formal argument is not an array so we don't need to do any + # further checks. + return True + + if not isinstance(call_arg, (Reference, Literal)): + # TODO #1799 this really needs the `datatype` method to be + # extended to support all nodes. For now we have to abort + # if we encounter an argument that is not a scalar (according + # to the corresponding formal argument) but is not a + # Reference or a Literal as we don't know whether the result + # of any general expression is or is not an array. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + f"The call '{self._call_node.debug_string()}' " + "cannot be inlined because actual argument " + f"'{call_arg.debug_string()}' corresponds to a " + "formal argument with array type but is not a " + "Reference or a Literal." + ) + ) + ) + + # We have an array argument. We are only able to check that the + # argument is not re-shaped in the called routine if we have full + # type information on the actual argument. + # TODO #924. It would be useful if the `datatype` property was + # a method that took an optional 'resolve' argument to indicate + # that it should attempt to resolve any UnresolvedTypes. + if self._option_check_array_type: + if isinstance( + call_arg.datatype, (UnresolvedType, UnsupportedType) + ) or ( + isinstance(call_arg.datatype, ArrayType) + and isinstance( + call_arg.datatype.intrinsic, + (UnresolvedType, UnsupportedType), + ) + ): + raise TransformationError( + f"Routine '{self._routine_node.name}' cannot be " + "inlined because the type of the actual argument " + f"'{call_arg.symbol.name}' corresponding to an array" + f" formal argument ('{routine_arg.name}') is unknown." + ) + + formal_rank = 0 + actual_rank = 0 + if isinstance(routine_arg.datatype, ArrayType): + formal_rank = len(routine_arg.datatype.shape) + if isinstance(call_arg.datatype, ArrayType): + actual_rank = len(call_arg.datatype.shape) + if formal_rank != actual_rank: + # It's OK to use the loop variable in the lambda definition + # because if we get to this point then we're going to quit + # the loop. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self._routine_node.name}' because it" + " reshapes an argument: actual argument" + f" '{call_arg.debug_string()}' has rank" + f" {actual_rank} but the corresponding formal" + f" argument, '{routine_arg.name}', has rank" + f" {formal_rank}" + ) + ) + ) + if actual_rank: + ranges = call_arg.walk(Range) + for rge in ranges: + ancestor_ref = rge.ancestor(Reference) + if ancestor_ref is not call_arg: + # Have a range in an indirect access. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self._routine_node.name}' because" + " argument" + f" '{call_arg.debug_string()}' has" + " an array range in an indirect" + " access #(TODO 924)." + ) + ) + ) + if rge.step != _ONE: + # TODO #1646. We could resolve this problem by + # making a new array and copying the necessary + # values into it. + # pylint: disable=cell-var-from-loop + raise TransformationError( + LazyString( + lambda: ( + "Cannot inline routine" + f" '{self._routine_node.name}' because" + " one of its arguments is an array" + " slice with a non-unit stride:" + f" '{call_arg.debug_string()}' (TODO" + " #1646)" + ) + ) + ) + + def _validate_inline_of_call_and_routine( self, call_node: Call, routine_node: Routine, @@ -827,8 +1010,9 @@ def _validate_inline_arguments_of_call_and_routine( the call to those of the routine in case of optional arguments. :type arg_index_list: List[int] :raises TransformationError: Arguments are not in a form to be inlined + """ - + name = call_node.routine.name if not routine_node.children or isinstance( @@ -862,12 +1046,11 @@ def _validate_inline_arguments_of_call_and_routine( table = call_node.scope.symbol_table routine_table = routine_node.symbol_table - for sym in routine_table.datasymbols: # We don't inline symbols that have an UnsupportedType and are # arguments since we don't know if a simple assignment if # enough (e.g. pointers) - if self._option_check_argument_unsupported_type: + if self._option_check_argument_of_unsupported_type: if isinstance(sym.interface, ArgumentInterface): if isinstance(sym.datatype, UnsupportedType): if ", OPTIONAL" not in sym.datatype.declaration: @@ -876,7 +1059,7 @@ def _validate_inline_arguments_of_call_and_routine( " inlined because it contains a Symbol" f" '{sym.name}' which is an Argument of" " UnsupportedType:" - f" '{sym.datatype.declaration}'" + f" '{sym.datatype.declaration}'." ) # We don't inline symbols that have an UnknownInterface, as we # don't know how they are brought into this scope. @@ -885,7 +1068,7 @@ def _validate_inline_arguments_of_call_and_routine( f"Routine '{routine_node.name}' cannot be " "inlined because it contains a Symbol " f"'{sym.name}' with an UnknownInterface: " - f"'{sym.datatype.declaration}'" + f"'{sym.datatype.declaration}'." ) if self._option_check_static_interface: @@ -910,7 +1093,7 @@ def _validate_inline_arguments_of_call_and_routine( routine_table, symbols_to_skip=routine_table.argument_list[:], check_unresolved_symbols=( - self._option_check_diff_container_clashes_unresolved_types + self._option_check_diff_container_clashes_unres_types ), ) except SymbolError as err: @@ -996,116 +1179,10 @@ def _validate_inline_arguments_of_call_and_routine( for routine_arg, call_arg in zip( routine_arg_list, call_node.arguments ): - # If the formal argument is an array with non-default bounds then - # we also need to know the bounds of that array at the call site. - if not isinstance(routine_arg.datatype, ArrayType): - # Formal argument is not an array so we don't need to do any - # further checks. - continue - - if not isinstance(call_arg, (Reference, Literal)): - # TODO #1799 this really needs the `datatype` method to be - # extended to support all nodes. For now we have to abort - # if we encounter an argument that is not a scalar (according - # to the corresponding formal argument) but is not a - # Reference or a Literal as we don't know whether the result - # of any general expression is or is not an array. - # pylint: disable=cell-var-from-loop - raise TransformationError( - LazyString( - lambda: ( - f"The call '{call_node.debug_string()}' " - "cannot be inlined because actual argument " - f"'{call_arg.debug_string()}' corresponds to a " - "formal argument with array type but is not a " - "Reference or a Literal." - ) - ) - ) - - # We have an array argument. We are only able to check that the - # argument is not re-shaped in the called routine if we have full - # type information on the actual argument. - # TODO #924. It would be useful if the `datatype` property was - # a method that took an optional 'resolve' argument to indicate - # that it should attempt to resolve any UnresolvedTypes. - if self._option_check_array_type: - if isinstance( - call_arg.datatype, (UnresolvedType, UnsupportedType) - ) or ( - isinstance(call_arg.datatype, ArrayType) - and isinstance( - call_arg.datatype.intrinsic, - (UnresolvedType, UnsupportedType), - ) - ): - raise TransformationError( - f"Routine '{routine_node.name}' cannot be " - "inlined because the type of the actual argument " - f"'{call_arg.symbol.name}' corresponding to an array" - f" formal argument ('{routine_arg.name}') is unknown." - ) - - formal_rank = 0 - actual_rank = 0 - if isinstance(routine_arg.datatype, ArrayType): - formal_rank = len(routine_arg.datatype.shape) - if isinstance(call_arg.datatype, ArrayType): - actual_rank = len(call_arg.datatype.shape) - if formal_rank != actual_rank: - # It's OK to use the loop variable in the lambda definition - # because if we get to this point then we're going to quit - # the loop. - # pylint: disable=cell-var-from-loop - raise TransformationError( - LazyString( - lambda: ( - "Cannot inline routine" - f" '{routine_node.name}' because it" - " reshapes an argument: actual argument" - f" '{call_arg.debug_string()}' has rank" - f" {actual_rank} but the corresponding formal" - f" argument, '{routine_arg.name}', has rank" - f" {formal_rank}" - ) - ) - ) - if actual_rank: - ranges = call_arg.walk(Range) - for rge in ranges: - ancestor_ref = rge.ancestor(Reference) - if ancestor_ref is not call_arg: - # Have a range in an indirect access. - # pylint: disable=cell-var-from-loop - raise TransformationError( - LazyString( - lambda: ( - "Cannot inline routine" - f" '{routine_node.name}' because" - " argument" - f" '{call_arg.debug_string()}' has" - " an array range in an indirect" - " access #(TODO 924)." - ) - ) - ) - if rge.step != _ONE: - # TODO #1646. We could resolve this problem by - # making a new array and copying the necessary - # values into it. - # pylint: disable=cell-var-from-loop - raise TransformationError( - LazyString( - lambda: ( - "Cannot inline routine" - f" '{routine_node.name}' because" - " one of its arguments is an array" - " slice with a non-unit stride:" - f" '{call_arg.debug_string()}' (TODO" - " #1646)" - ) - ) - ) + self._validate_inline_of_call_and_routine_argument_pairs( + call_arg, + routine_arg + ) def validate( self, @@ -1215,7 +1292,7 @@ def validate( "Routine's arguments doesn't match subroutine" ) from err - self._validate_inline_arguments_of_call_and_routine( + self._validate_inline_of_call_and_routine( call_node, self._routine_node, self._ret_arg_match_list) diff --git a/src/psyclone/tests/psyir/nodes/call_test.py b/src/psyclone/tests/psyir/nodes/call_test.py index 1e498df5ad..00ba761c98 100644 --- a/src/psyclone/tests/psyir/nodes/call_test.py +++ b/src/psyclone/tests/psyir/nodes/call_test.py @@ -66,7 +66,8 @@ ) from psyclone.errors import GenerationError -from psyclone.psyir.tools.call_routine_matcher import CallMatchingArgumentsNotFoundError +from psyclone.psyir.tools.call_routine_matcher import ( + CallMatchingArgumentsNotFoundError) class SpecialCall(Call): diff --git a/src/psyclone/tests/psyir/transformations/inline_trans_test.py b/src/psyclone/tests/psyir/transformations/inline_trans_test.py index 2fde975a34..e6e42199bd 100644 --- a/src/psyclone/tests/psyir/transformations/inline_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/inline_trans_test.py @@ -2168,6 +2168,26 @@ def test_validate_non_unit_stride_slice(fortran_reader): str(err.value)) +def test_set_options(fortran_reader): + '''Test that simply sets all options for sake of the coverage test.''' + + inline_trans = InlineTrans() + inline_trans.set_option( + ignore_missing_modules=False, + check_argument_strict_array_datatype=False, + check_argument_matching=False, + + check_inline_codeblocks=False, + check_diff_container_clashes=False, + check_diff_container_clashes_unres_types=False, + check_resolve_imports=False, + check_static_interface=False, + check_array_type=False, + check_argument_of_unsupported_type=False, + check_argument_unresolved_symbols=False, + ) + + def test_apply_named_arg(fortran_reader): '''Test that the validate method inlines a routine that has a named argument.''' @@ -2193,7 +2213,7 @@ def test_apply_named_arg(fortran_reader): inline_trans.apply(call) -def test_validate_optional_arg(fortran_reader): +def test_apply_optional_arg(fortran_reader): '''Test that the validate method inlines a routine that has an optional argument.''' @@ -2220,7 +2240,105 @@ def test_validate_optional_arg(fortran_reader): inline_trans.apply(call) -def test_validate_optional_and_named_arg(fortran_reader): +def test_apply_optional_arg_with_special_cases(fortran_reader): + '''Test that the validate method inlines a routine + that has an optional argument. + This example has an additional if-branching condition + `1.0==1.0` which is not directly of type `Literal` + ''' + + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var)\n" + "end subroutine main\n" + "subroutine sub(x, opt)\n" + " real, intent(inout) :: x\n" + " real, optional :: opt\n" + " if( present(opt) )then\n" + " x = x + opt\n" + " end if\n" + " if( 1.0 == 1.0 )then\n" + " x = x\n" + " end if\n" + " x = x + 1.0\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + inline_trans = InlineTrans() + inline_trans.apply(call) + + +def test_apply_optional_arg_error(fortran_reader): + '''Test that the validate method can't inline a routine + where the optional argument is still used. + ''' + + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var)\n" + "end subroutine main\n" + "subroutine sub(x, opt)\n" + " real, intent(inout) :: x\n" + " real, optional :: opt\n" + " if( present(opt) )then\n" + " x = x + opt\n" + " end if\n" + " x = x + opt\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + inline_trans = InlineTrans() + with pytest.raises(TransformationError) as einfo: + inline_trans.apply(call) + + assert ("Subroutine argument 'opt' is not provided by call," + " but used in the subroutine." in str(einfo.value)) + + +def test_apply_unsupported_pointer_error(fortran_reader): + '''Test that the validate method can't inline a routine + where a pointer argument is used. + This covers a special code + `if ", OPTIONAL" not in sym.datatype.declaration:` + which doesn't work that reliably and should be replaced + with something more robust. + ''' + + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var)\n" + "end subroutine main\n" + "subroutine sub(x)\n" + " real, intent(inout), pointer :: x\n" + " x = 1.0\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + inline_trans = InlineTrans() + with pytest.raises(TransformationError) as einfo: + inline_trans.apply(call) + + assert ("Routine 'sub' cannot be inlined because it contains a Symbol 'x'" + " which is an Argument of UnsupportedType:" + " 'REAL, INTENT(INOUT), POINTER :: x'." in str(einfo.value)) + + +def test_apply_optional_and_named_arg(fortran_reader): '''Test that the validate method inlines a routine that has an optional argument.''' code = ( @@ -2268,7 +2386,7 @@ def test_validate_optional_and_named_arg(fortran_reader): ) -def test_validate_optional_and_named_arg_2(fortran_reader): +def test_apply_optional_and_named_arg_2(fortran_reader): '''Test that the validate method inlines a routine that has an optional argument.''' code = ( From 0344b47b2a2399f83e13e23d2b93f4cfee150d30 Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Sun, 24 Nov 2024 20:16:25 +0100 Subject: [PATCH 09/20] work on coverage --- src/psyclone/psyir/nodes/call.py | 2 +- src/psyclone/psyir/symbols/containersymbol.py | 7 +- src/psyclone/psyir/symbols/symbol_table.py | 13 ++- .../psyir/tools/call_routine_matcher.py | 34 ++---- .../psyir/transformations/inline_trans.py | 72 +++++++++--- .../transformations/inline_trans_test.py | 104 ++++++++++++++++++ 6 files changed, 184 insertions(+), 48 deletions(-) diff --git a/src/psyclone/psyir/nodes/call.py b/src/psyclone/psyir/nodes/call.py index f587520430..365850eb92 100644 --- a/src/psyclone/psyir/nodes/call.py +++ b/src/psyclone/psyir/nodes/call.py @@ -586,7 +586,7 @@ def get_callee( call_routine_matcher: CallRoutineMatcher = CallRoutineMatcher(self) call_routine_matcher.set_option( check_matching_arguments=check_matching_arguments, - check_argument_strict_array_datatype=( + check_strict_array_datatype=( check_strict_array_datatype), ignore_missing_modules=ignore_missing_modules, ignore_unresolved_symbol=ignore_unresolved_symbol, diff --git a/src/psyclone/psyir/symbols/containersymbol.py b/src/psyclone/psyir/symbols/containersymbol.py index 02d211de45..f9594e11ac 100644 --- a/src/psyclone/psyir/symbols/containersymbol.py +++ b/src/psyclone/psyir/symbols/containersymbol.py @@ -126,7 +126,7 @@ def copy(self): return new_symbol def find_container_psyir( - self, local_node=None, ignore_missing_modules: bool = False + self, local_node=None ): """Searches for the Container that this Symbol refers to. If it is not available, use the interface to import the container. If @@ -157,10 +157,7 @@ def find_container_psyir( self._reference = local return self._reference # We didn't find it so now attempt to import the container. - try: - self._reference = self._interface.get_container(self._name) - except ModuleNotFoundError: - return None + self._reference = self._interface.get_container(self._name) return self._reference def __str__(self): diff --git a/src/psyclone/psyir/symbols/symbol_table.py b/src/psyclone/psyir/symbols/symbol_table.py index 3afa5aea5f..479b2befaf 100644 --- a/src/psyclone/psyir/symbols/symbol_table.py +++ b/src/psyclone/psyir/symbols/symbol_table.py @@ -586,7 +586,8 @@ def add(self, new_symbol, tag=None): self._symbols[key] = new_symbol def check_for_clashes( - self, other_table, symbols_to_skip=(), check_unresolved_symbols=True + self, other_table, symbols_to_skip=(), + check_unresolved_symbols: bool = True ): """ Checks the symbols in the supplied table against those in @@ -831,7 +832,8 @@ def _handle_symbol_clash(self, old_sym, other_table): self.rename_symbol(self_sym, new_name) self.add(old_sym) - def merge(self, other_table, symbols_to_skip=()): + def merge(self, other_table, symbols_to_skip=(), + check_unresolved_symbols: bool = True): '''Merges all of the symbols found in `other_table` into this table. Symbol objects in *either* table may be renamed in the event of clashes. @@ -844,6 +846,9 @@ def merge(self, other_table, symbols_to_skip=()): the merge. :type symbols_to_skip: Iterable[ :py:class:`psyclone.psyir.symbols.Symbol`] + :param check_unresolved_symbols: If `True`, also check unresolved + symbols. + :type check_unresolved_symbols: bool :raises TypeError: if `other_table` is not a SymbolTable. :raises TypeError: if `symbols_to_skip` is not an Iterable. @@ -860,7 +865,9 @@ def merge(self, other_table, symbols_to_skip=()): try: self.check_for_clashes(other_table, - symbols_to_skip=symbols_to_skip) + symbols_to_skip=symbols_to_skip, + check_unresolved_symbols=( + check_unresolved_symbols)) except SymbolError as err: raise SymbolError( f"Cannot merge {other_table.view()} with {self.view()} due to " diff --git a/src/psyclone/psyir/tools/call_routine_matcher.py b/src/psyclone/psyir/tools/call_routine_matcher.py index 1869c064b2..cdddf9b2ab 100644 --- a/src/psyclone/psyir/tools/call_routine_matcher.py +++ b/src/psyclone/psyir/tools/call_routine_matcher.py @@ -40,9 +40,8 @@ from typing import List, Union from psyclone.psyir.symbols.datatypes import ArrayType -from psyclone.psyir.nodes import Call, Routine from psyclone.errors import PSycloneError -from psyclone.configuration import Config +from psyclone.psyir.nodes import Call, Routine from psyclone.psyir.nodes.container import Container from psyclone.psyir.symbols import ( RoutineSymbol, @@ -53,6 +52,7 @@ SymbolTable, ContainerSymbol, ) +from psyclone.configuration import Config class CallMatchingArgumentsNotFoundError(PSycloneError): @@ -109,7 +109,7 @@ def set_routine_node(self, routine_node: Routine): def set_option(self, check_matching_arguments: bool = None, - check_argument_strict_array_datatype: bool = None, + check_strict_array_datatype: bool = None, ignore_missing_modules: bool = None, ignore_unresolved_symbol: bool = None, ): @@ -117,9 +117,9 @@ def set_option(self, if check_matching_arguments is not None: self._option_check_matching_arguments = check_matching_arguments - if check_argument_strict_array_datatype is not None: + if check_strict_array_datatype is not None: self._option_check_strict_array_datatype = ( - check_argument_strict_array_datatype) + check_strict_array_datatype) if ignore_missing_modules is not None: self._option_ignore_missing_modules = ignore_missing_modules @@ -130,8 +130,7 @@ def set_option(self, def _check_argument_type_matches( self, call_arg: DataSymbol, - routine_arg: DataSymbol, - check_strict_array_datatype: bool = True, + routine_arg: DataSymbol ) -> bool: """Return information whether argument types are matching. This also supports 'optional' arguments by using @@ -151,7 +150,7 @@ def _check_argument_type_matches( """ type_matches = False - if not check_strict_array_datatype: + if not self._option_check_strict_array_datatype: # No strict array checks have to be performed, just accept it if isinstance(call_arg.datatype, ArrayType) and isinstance( routine_arg.datatype, ArrayType @@ -219,8 +218,7 @@ def get_argument_routine_match_list( routine_arg: DataSymbol self._check_argument_type_matches( - call_arg, routine_arg, - self._option_check_strict_array_datatype + call_arg, routine_arg ) ret_arg_idx_list.append(call_arg_idx) @@ -245,10 +243,7 @@ def get_argument_routine_match_list( if arg_name == routine_arg.name: self._check_argument_type_matches( call_arg, - routine_arg, - check_strict_array_datatype=( - self._option_check_strict_array_datatype - ), + routine_arg ) ret_arg_idx_list.append(routine_arg_idx) break @@ -359,9 +354,6 @@ def _location_txt(node): container: Container = ( container_symbol.find_container_psyir( local_node=self._call_node, - ignore_missing_modules=( - ignore_missing_modules - ), ) ) except SymbolError: @@ -500,13 +492,7 @@ def get_callee(self): ''' routine_list = self.get_callee_candidates() - - call_name = self._call_node.routine.name - - if len(routine_list) == 0: - raise NotImplementedError( - f"No routine or interface found for name '{call_name}'" - ) + assert len(routine_list) != 0 err_info_list = [] diff --git a/src/psyclone/psyir/transformations/inline_trans.py b/src/psyclone/psyir/transformations/inline_trans.py index 5c940656dd..62a855b89d 100644 --- a/src/psyclone/psyir/transformations/inline_trans.py +++ b/src/psyclone/psyir/transformations/inline_trans.py @@ -149,12 +149,6 @@ def __init__(self): self._call_routine_matcher: CallRoutineMatcher = CallRoutineMatcher() - # If 'True', make strict checks for matching arguments of - # array data types. - # If disabled, it's sufficient that both arguments are of ArrayType. - # Then, no further checks are performed - self._option_check_argument_strict_array_datatype: bool = True - # If 'True', don't inline if a code block is used within the # Routine. self._option_check_codeblocks: bool = True @@ -182,14 +176,58 @@ def set_option( check_argument_of_unsupported_type: bool = None, check_argument_unresolved_symbols: bool = None, ): - if check_argument_strict_array_datatype is not None: - self._option_check_argument_strict_array_datatype = ( - check_argument_strict_array_datatype - ) + """Set special options + + :param ignore_missing_modules: If `True`, raise ModuleNotFound if + module is not available, defaults to None + :type ignore_missing_modules: bool, optional + :param check_argument_strict_array_datatype: + If `True`, make strict checks for matching arguments of + array data types. + If disabled, it's sufficient that both arguments are of ArrayType. + Then, no further checks are performed, defaults to None + :type check_argument_strict_array_datatype: bool, optional + :param check_argument_matching: If `True`, check for all arguments + to match. If `False`, if no matching argument was found, take + 1st one in list. Defaults to None + :type check_argument_matching: bool, optional + :param check_inline_codeblocks: If `True`, raise Exception + if encountering code blocks, defaults to None + :type check_inline_codeblocks: bool, optional + :param check_diff_container_clashes: + If `True` and different symbols share a name but are imported + from different containers, raise Exception. + If `True`, raise Exception if + containers are clashing, defaults to None + :type check_diff_container_clashes: bool, optional + :param check_diff_container_clashes_unres_types: If `True`, + raise Exception if unresolved types are clashing, defaults to None + :type check_diff_container_clashes_unres_types: bool, optional + :param check_resolve_imports: If `True`, also resolve imports, + defaults to None + :type check_resolve_imports: bool, optional + :param check_static_interface: + Check that there are no static variables in the routine + (because we don't know whether the routine is called from + other places). Defaults to None + :type check_static_interface: bool, optional + :param check_array_type: If `True` and argument is an array, + check that inlining is working for this array type, + defaults to None + :type check_array_type: bool, optional + :param check_argument_of_unsupported_type: If `True`, + also perform checks (fail inlining) on arguments of + unsupported type, defaults to None + :type check_argument_of_unsupported_type: bool, optional + :param check_argument_unresolved_symbols: If `True`, + stop if encountering an unresolved symbol, defaults to None + :type check_argument_unresolved_symbols: bool, optional + """ + self._call_routine_matcher.set_option( ignore_missing_modules=ignore_missing_modules) self._call_routine_matcher.set_option( - check_argument_strict_array_datatype=( + check_strict_array_datatype=( check_argument_strict_array_datatype)) self._call_routine_matcher.set_option( check_matching_arguments=check_argument_matching) @@ -296,7 +334,8 @@ def apply( table.merge( routine_table, symbols_to_skip=routine_table.argument_list[:], - # check_unresolved_symbols=self._option_check_argument_unresolved_symbols, + check_unresolved_symbols=( + self._option_check_argument_unresolved_symbols), ) # When constructing new references to replace references to formal @@ -364,7 +403,10 @@ def apply( # the ancestor Routine. This avoids issues like #2424 when # applying ParallelLoopTrans to loops containing inlined calls. if ancestor_table is not scope.symbol_table: - ancestor_table.merge(scope.symbol_table) + ancestor_table.merge( + scope.symbol_table, + check_unresolved_symbols=( + self._option_check_argument_unresolved_symbols)) replacement = type(scope.symbol_table)() scope.symbol_table.detach() replacement.attach(scope) @@ -1283,13 +1325,13 @@ def validate( rm = self._call_routine_matcher rm.set_routine_node(self._routine_node) rm.set_option( - check_argument_strict_array_datatype=False) + check_strict_array_datatype=False) self._ret_arg_match_list = ( rm.get_argument_routine_match_list() ) except CallMatchingArgumentsNotFoundError as err: raise TransformationError( - "Routine's arguments doesn't match subroutine" + "Routine's argument(s) don't match:\n"+str(err) ) from err self._validate_inline_of_call_and_routine( diff --git a/src/psyclone/tests/psyir/transformations/inline_trans_test.py b/src/psyclone/tests/psyir/transformations/inline_trans_test.py index e6e42199bd..323447ab36 100644 --- a/src/psyclone/tests/psyir/transformations/inline_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/inline_trans_test.py @@ -107,6 +107,32 @@ def test_apply_empty_routine(fortran_reader, fortran_writer, tmpdir): assert Compile(tmpdir).string_compiles(output) +def test_apply_empty_routine_coverage_option_check_strict_array_datatype( + fortran_reader, fortran_writer, tmpdir): + '''For coverage of particular branch in `inline_trans.py`.''' + code = ( + "module test_mod\n" + "contains\n" + " subroutine run_it()\n" + " integer, dimension(6) :: i\n" + " i = 10\n" + " call sub(i)\n" + " end subroutine run_it\n" + " subroutine sub(idx)\n" + " integer, dimension(:) :: idx\n" + " end subroutine sub\n" + "end module test_mod\n") + psyir = fortran_reader.psyir_from_source(code) + routine = psyir.walk(Call)[0] + inline_trans = InlineTrans() + inline_trans.set_option(check_argument_strict_array_datatype=False) + inline_trans.apply(routine) + output = fortran_writer(psyir) + assert (" i = 10\n\n" + " end subroutine run_it\n" in output) + assert Compile(tmpdir).string_compiles(output) + + def test_apply_single_return(fortran_reader, fortran_writer, tmpdir): '''Check that a call to a routine containing only a return statement is removed. ''' @@ -160,6 +186,44 @@ def test_apply_return_then_cb(fortran_reader, fortran_writer, tmpdir): assert Compile(tmpdir).string_compiles(output) +def test_apply_provided_routine(fortran_reader, fortran_writer, tmpdir): + ''' Check that the apply() method works also for a provided routine. ''' + code = ( + "module test_mod\n" + "contains\n" + " subroutine run_it()\n" + " integer :: i\n" + " real :: a(10)\n" + " do i=1,10\n" + " a(i) = 1.0\n" + " call sub(a(i))\n" + " end do\n" + " end subroutine run_it\n" + " subroutine sub(x)\n" + " real, intent(inout) :: x\n" + " x = 2.0*x\n" + " end subroutine sub\n" + " subroutine sub2(x, y)\n" + " real, intent(inout) :: x, y\n" + " x = 2.0*x\n" + " end subroutine sub2\n" + "end module test_mod\n") + psyir = fortran_reader.psyir_from_source(code) + + call = psyir.walk(Call)[0] + + routine = psyir.walk(Routine)[1] + inline_trans = InlineTrans() + inline_trans.apply(call, routine) + + routine = psyir.walk(Routine)[2] + with pytest.raises(TransformationError) as einfo: + inline_trans.apply(call, routine) + + assert "Routine's argument(s) don't match:" in str(einfo.value) + assert "Argument 'y' in subroutine 'sub2' not handled" in str(einfo.value) + + def test_apply_array_arg(fortran_reader, fortran_writer, tmpdir): ''' Check that the apply() method works correctly for a very simple call to a routine with an array reference as argument. ''' @@ -228,6 +292,46 @@ def test_apply_array_access(fortran_reader, fortran_writer, tmpdir): assert Compile(tmpdir).string_compiles(output) +def test_apply_array_access_check_unresolved_symbols_error( + fortran_reader, fortran_writer, tmpdir): + ''' + This check solely exists for the coverage report to + catch the simple case `if not check_unresolved_symbols:` + in `symbol_table.py` + + ''' + code = ( + "module test_mod\n" + "contains\n" + " subroutine run_it()\n" + " integer :: i\n" + " real :: a(10)\n" + " do i=1,10\n" + " call sub(a, i)\n" + " end do\n" + " end subroutine run_it\n" + " subroutine sub(x, ivar)\n" + " real, intent(inout), dimension(10) :: x\n" + " integer, intent(in) :: ivar\n" + " integer :: i\n" + " do i = 1, 10\n" + " x(i) = 2.0*ivar\n" + " end do\n" + " end subroutine sub\n" + "end module test_mod\n") + psyir = fortran_reader.psyir_from_source(code) + routine = psyir.walk(Call)[0] + inline_trans = InlineTrans() + inline_trans.set_option(check_argument_unresolved_symbols=False) + inline_trans.apply(routine) + output = fortran_writer(psyir) + assert (" do i = 1, 10, 1\n" + " do i_1 = 1, 10, 1\n" + " a(i_1) = 2.0 * i\n" + " enddo\n" in output) + assert Compile(tmpdir).string_compiles(output) + + def test_apply_gocean_kern(fortran_reader, fortran_writer, monkeypatch): ''' Test the apply method with a typical GOcean kernel. From 3fd04e87c2e7f288e7b03ae49d48d7650711fa5d Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Sun, 24 Nov 2024 20:30:01 +0100 Subject: [PATCH 10/20] Updates for CI tests --- doc/Makefile | 10 ++++++++-- src/psyclone/psyir/transformations/inline_trans.py | 8 ++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/doc/Makefile b/doc/Makefile index 9f3cc35e4e..b4cdc968b6 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -1,8 +1,14 @@ -all: +all: developer_guide reference_guide user_guide +.PHONY: developer_guide reference_guide user_guide + +developer_guide: make -C developer_guide html SPHINXOPTS="-W --keep-going" - make -C developer_guide linkcheck || echo "Ignoring error of link checking" + +reference_guide: make -C reference_guide html SPHINXOPTS="-W --keep-going" + +user_guide: make -C user_guide html SPHINXOPTS="-W --keep-going" clean: diff --git a/src/psyclone/psyir/transformations/inline_trans.py b/src/psyclone/psyir/transformations/inline_trans.py index 62a855b89d..eb0e3ce744 100644 --- a/src/psyclone/psyir/transformations/inline_trans.py +++ b/src/psyclone/psyir/transformations/inline_trans.py @@ -197,8 +197,6 @@ def set_option( :param check_diff_container_clashes: If `True` and different symbols share a name but are imported from different containers, raise Exception. - If `True`, raise Exception if - containers are clashing, defaults to None :type check_diff_container_clashes: bool, optional :param check_diff_container_clashes_unres_types: If `True`, raise Exception if unresolved types are clashing, defaults to None @@ -228,9 +226,11 @@ def set_option( ignore_missing_modules=ignore_missing_modules) self._call_routine_matcher.set_option( check_strict_array_datatype=( - check_argument_strict_array_datatype)) + check_argument_strict_array_datatype) + ) self._call_routine_matcher.set_option( - check_matching_arguments=check_argument_matching) + check_matching_arguments=check_argument_matching + ) if check_inline_codeblocks is not None: self._option_check_codeblocks = check_inline_codeblocks From e9a56aa88f8451d788f291ed00c89e1bd1de7029 Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Mon, 25 Nov 2024 21:35:34 +0100 Subject: [PATCH 11/20] intermediate backup --- utils/run_pytest_cov.sh | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/utils/run_pytest_cov.sh b/utils/run_pytest_cov.sh index 7ae7bd36c4..b52a8405cc 100755 --- a/utils/run_pytest_cov.sh +++ b/utils/run_pytest_cov.sh @@ -23,10 +23,13 @@ COV_REPORT="xml:cov.xml" # Additional options # Also write to Terminal -#OPTS=" --cov-report term" +OPTS=" --cov-report term" + +if [[ -e cov.xml ]]; then + echo "Removing previoud reporting file 'cov.xml'" + rm -rf cov.xml +fi -#echo "Running 'pytest --cov $PSYCLONE_MODULE --cov-report term-missing -n $(nproc) $SRC_DIR'" -#pytest --cov $PSYCLONE_MODULE -v --cov-report term-missing -n $(nproc) $SRC_DIR echo "Running 'pytest --cov $PSYCLONE_MODULE --cov-report $COV_REPORT -n $(nproc) $SRC_DIR'" pytest --cov $PSYCLONE_MODULE -v --cov-report $COV_REPORT $OPTS -n $(nproc) $SRC_DIR From 6ebffae432179d7cd3cb8ed842299becaeb574a6 Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Sat, 30 Nov 2024 21:12:11 +0100 Subject: [PATCH 12/20] cleanups --- src/psyclone/configuration.py | 9 +++ src/psyclone/psyir/nodes/call.py | 2 +- .../psyir/tools/call_routine_matcher.py | 58 ++++++++++--------- .../psyir/transformations/inline_trans.py | 53 +++++++---------- .../transformations/inline_trans_test.py | 6 +- 5 files changed, 63 insertions(+), 65 deletions(-) diff --git a/src/psyclone/configuration.py b/src/psyclone/configuration.py index 7e65304c5d..02cd272d0e 100644 --- a/src/psyclone/configuration.py +++ b/src/psyclone/configuration.py @@ -449,9 +449,18 @@ def find_file(): if not within_virtual_env(): # 4. /share/psyclone/ _file_paths.append(share_dir) + # 5. /share/psyclone/ _file_paths.extend(pkg_share_dir) + # 6. /config/ + # Search for configuration file relative to this source file + dev_path_list = os.path.split( + os.path.abspath(__file__))[:-1]+( + "..", "..", "config") + dev_path = os.path.abspath(os.path.join(*dev_path_list)) + _file_paths.append(dev_path) + for cfile in [os.path.join(cdir, _FILE_NAME) for cdir in _file_paths]: if os.path.isfile(cfile): return cfile diff --git a/src/psyclone/psyir/nodes/call.py b/src/psyclone/psyir/nodes/call.py index 36ff7e4eb7..a961285116 100644 --- a/src/psyclone/psyir/nodes/call.py +++ b/src/psyclone/psyir/nodes/call.py @@ -521,7 +521,7 @@ def get_callee( check_strict_array_datatype=( check_strict_array_datatype), ignore_missing_modules=ignore_missing_modules, - ignore_unresolved_symbol=ignore_unresolved_symbol, + ignore_unresolved_types=ignore_unresolved_symbol, ) return call_routine_matcher.get_callee() diff --git a/src/psyclone/psyir/tools/call_routine_matcher.py b/src/psyclone/psyir/tools/call_routine_matcher.py index cdddf9b2ab..389a3f66bc 100644 --- a/src/psyclone/psyir/tools/call_routine_matcher.py +++ b/src/psyclone/psyir/tools/call_routine_matcher.py @@ -39,7 +39,7 @@ # ----------------------------------------------------------------------------- from typing import List, Union -from psyclone.psyir.symbols.datatypes import ArrayType +from psyclone.psyir.symbols.datatypes import ArrayType, UnresolvedType from psyclone.errors import PSycloneError from psyclone.psyir.nodes import Call, Routine from psyclone.psyir.nodes.container import Container @@ -98,8 +98,8 @@ def __init__(self, call_node: Call = None, routine_node: Routine = None): # If 'True', missing modules don't raise an Exception self._option_ignore_missing_modules: bool = False - # If 'True', unresolved symbols don't raise an Exception - self._option_ignore_unresolved_symbol: bool = False + # If 'True', unresolved types don't raise an Exception + self._option_ignore_unresolved_types: bool = False def set_call_node(self, call_node: Call): self._call_node = call_node @@ -111,7 +111,7 @@ def set_option(self, check_matching_arguments: bool = None, check_strict_array_datatype: bool = None, ignore_missing_modules: bool = None, - ignore_unresolved_symbol: bool = None, + ignore_unresolved_types: bool = None, ): if check_matching_arguments is not None: @@ -124,8 +124,8 @@ def set_option(self, if ignore_missing_modules is not None: self._option_ignore_missing_modules = ignore_missing_modules - if ignore_unresolved_symbol is not None: - self._option_ignore_unresolved_symbol = ignore_unresolved_symbol + if ignore_unresolved_types is not None: + self._option_ignore_unresolved_types = ignore_unresolved_types def _check_argument_type_matches( self, @@ -149,32 +149,34 @@ def _check_argument_type_matches( were found. """ - type_matches = False - if not self._option_check_strict_array_datatype: + if self._option_check_strict_array_datatype: # No strict array checks have to be performed, just accept it if isinstance(call_arg.datatype, ArrayType) and isinstance( routine_arg.datatype, ArrayType ): - type_matches = True - - if not type_matches: - if isinstance(routine_arg.datatype, UnsupportedFortranType): - # This could be an 'optional' argument. - # This has at least a partial data type - if call_arg.datatype != routine_arg.datatype.partial_datatype: - raise CallMatchingArgumentsNotFoundError( - "Argument partial type mismatch of call " - f"argument '{call_arg}' and routine argument " - f"'{routine_arg}'" - ) - else: - if call_arg.datatype != routine_arg.datatype: - raise CallMatchingArgumentsNotFoundError( - "Argument type mismatch of call argument " - f"'{call_arg}' with type '{call_arg.datatype} " - "and routine argument " - f"'{routine_arg}' with type '{routine_arg.datatype}." - ) + return True + + if self._option_ignore_unresolved_types: + if isinstance(call_arg.datatype, UnresolvedType): + return True + + if isinstance(routine_arg.datatype, UnsupportedFortranType): + # This could be an 'optional' argument. + # This has at least a partial data type + if call_arg.datatype != routine_arg.datatype.partial_datatype: + raise CallMatchingArgumentsNotFoundError( + "Argument partial type mismatch of call " + f"argument '{call_arg}' and routine argument " + f"'{routine_arg}'" + ) + else: + if call_arg.datatype != routine_arg.datatype: + raise CallMatchingArgumentsNotFoundError( + "Argument type mismatch of call argument " + f"'{call_arg}' with type '{call_arg.datatype}' " + "and routine argument " + f"'{routine_arg}' with type '{routine_arg.datatype}'." + ) return True diff --git a/src/psyclone/psyir/transformations/inline_trans.py b/src/psyclone/psyir/transformations/inline_trans.py index eb0e3ce744..ff280c44e9 100644 --- a/src/psyclone/psyir/transformations/inline_trans.py +++ b/src/psyclone/psyir/transformations/inline_trans.py @@ -158,14 +158,15 @@ def __init__(self): self._option_check_resolve_imports: bool = True self._option_check_static_interface: bool = True self._option_check_array_type: bool = True - self._option_check_argument_of_unsupported_type: bool = True - self._option_check_argument_unresolved_symbols: bool = True + self._option_check_unsupported_type: bool = True + self._option_check_unresolved_symbols: bool = True def set_option( self, ignore_missing_modules: bool = None, check_argument_strict_array_datatype: bool = None, check_argument_matching: bool = None, + check_argument_ignore_unresolved_types: bool = None, check_inline_codeblocks: bool = None, check_diff_container_clashes: bool = None, @@ -173,64 +174,50 @@ def set_option( check_resolve_imports: bool = None, check_static_interface: bool = None, check_array_type: bool = None, - check_argument_of_unsupported_type: bool = None, - check_argument_unresolved_symbols: bool = None, + check_unsupported_type: bool = None, + check_unresolved_symbols: bool = None, ): """Set special options :param ignore_missing_modules: If `True`, raise ModuleNotFound if module is not available, defaults to None - :type ignore_missing_modules: bool, optional :param check_argument_strict_array_datatype: If `True`, make strict checks for matching arguments of array data types. If disabled, it's sufficient that both arguments are of ArrayType. Then, no further checks are performed, defaults to None - :type check_argument_strict_array_datatype: bool, optional :param check_argument_matching: If `True`, check for all arguments to match. If `False`, if no matching argument was found, take 1st one in list. Defaults to None - :type check_argument_matching: bool, optional :param check_inline_codeblocks: If `True`, raise Exception if encountering code blocks, defaults to None - :type check_inline_codeblocks: bool, optional :param check_diff_container_clashes: If `True` and different symbols share a name but are imported from different containers, raise Exception. - :type check_diff_container_clashes: bool, optional :param check_diff_container_clashes_unres_types: If `True`, raise Exception if unresolved types are clashing, defaults to None - :type check_diff_container_clashes_unres_types: bool, optional :param check_resolve_imports: If `True`, also resolve imports, defaults to None - :type check_resolve_imports: bool, optional :param check_static_interface: Check that there are no static variables in the routine (because we don't know whether the routine is called from other places). Defaults to None - :type check_static_interface: bool, optional :param check_array_type: If `True` and argument is an array, check that inlining is working for this array type, defaults to None - :type check_array_type: bool, optional - :param check_argument_of_unsupported_type: If `True`, + :param check_unsupported_type: If `True`, also perform checks (fail inlining) on arguments of unsupported type, defaults to None - :type check_argument_of_unsupported_type: bool, optional :param check_argument_unresolved_symbols: If `True`, stop if encountering an unresolved symbol, defaults to None - :type check_argument_unresolved_symbols: bool, optional """ self._call_routine_matcher.set_option( - ignore_missing_modules=ignore_missing_modules) - self._call_routine_matcher.set_option( - check_strict_array_datatype=( - check_argument_strict_array_datatype) - ) - self._call_routine_matcher.set_option( - check_matching_arguments=check_argument_matching - ) + ignore_missing_modules=ignore_missing_modules, + check_strict_array_datatype=check_argument_strict_array_datatype, + check_matching_arguments=check_argument_matching, + ignore_unresolved_types=check_argument_ignore_unresolved_types + ) if check_inline_codeblocks is not None: self._option_check_codeblocks = check_inline_codeblocks @@ -253,14 +240,14 @@ def set_option( if check_array_type is not None: self._option_check_array_type = check_array_type - if check_argument_of_unsupported_type is not None: - self._option_check_argument_of_unsupported_type = ( - check_argument_of_unsupported_type + if check_unsupported_type is not None: + self._option_check_unsupported_type = ( + check_unsupported_type ) - if check_argument_unresolved_symbols is not None: - self._option_check_argument_unresolved_symbols = ( - check_argument_unresolved_symbols + if check_unresolved_symbols is not None: + self._option_check_unresolved_symbols = ( + check_unresolved_symbols ) def apply( @@ -335,7 +322,7 @@ def apply( routine_table, symbols_to_skip=routine_table.argument_list[:], check_unresolved_symbols=( - self._option_check_argument_unresolved_symbols), + self._option_check_unresolved_symbols), ) # When constructing new references to replace references to formal @@ -406,7 +393,7 @@ def apply( ancestor_table.merge( scope.symbol_table, check_unresolved_symbols=( - self._option_check_argument_unresolved_symbols)) + self._option_check_unresolved_symbols)) replacement = type(scope.symbol_table)() scope.symbol_table.detach() replacement.attach(scope) @@ -1092,7 +1079,7 @@ def _validate_inline_of_call_and_routine( # We don't inline symbols that have an UnsupportedType and are # arguments since we don't know if a simple assignment if # enough (e.g. pointers) - if self._option_check_argument_of_unsupported_type: + if self._option_check_unsupported_type: if isinstance(sym.interface, ArgumentInterface): if isinstance(sym.datatype, UnsupportedType): if ", OPTIONAL" not in sym.datatype.declaration: diff --git a/src/psyclone/tests/psyir/transformations/inline_trans_test.py b/src/psyclone/tests/psyir/transformations/inline_trans_test.py index 323447ab36..ed66858522 100644 --- a/src/psyclone/tests/psyir/transformations/inline_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/inline_trans_test.py @@ -322,7 +322,7 @@ def test_apply_array_access_check_unresolved_symbols_error( psyir = fortran_reader.psyir_from_source(code) routine = psyir.walk(Call)[0] inline_trans = InlineTrans() - inline_trans.set_option(check_argument_unresolved_symbols=False) + inline_trans.set_option(check_unresolved_symbols=False) inline_trans.apply(routine) output = fortran_writer(psyir) assert (" do i = 1, 10, 1\n" @@ -2287,8 +2287,8 @@ def test_set_options(fortran_reader): check_resolve_imports=False, check_static_interface=False, check_array_type=False, - check_argument_of_unsupported_type=False, - check_argument_unresolved_symbols=False, + check_unsupported_type=False, + check_unresolved_symbols=False, ) From cfa95aa84f235a4a6f5356120f5a47c3ad7ee263 Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Wed, 4 Dec 2024 17:48:21 +0100 Subject: [PATCH 13/20] bugfix --- .gitignore | 1 + src/psyclone/psyir/tools/call_routine_matcher.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 0fd8c5b122..3db75bdf3f 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,4 @@ src/*.egg-info .venv cov.xml .coverage.* +__pycache__ diff --git a/src/psyclone/psyir/tools/call_routine_matcher.py b/src/psyclone/psyir/tools/call_routine_matcher.py index 389a3f66bc..7899a74839 100644 --- a/src/psyclone/psyir/tools/call_routine_matcher.py +++ b/src/psyclone/psyir/tools/call_routine_matcher.py @@ -242,7 +242,7 @@ def get_argument_routine_match_list( if routine_arg is None: continue - if arg_name == routine_arg.name: + if arg_name.lower() == routine_arg.name.lower(): self._check_argument_type_matches( call_arg, routine_arg From 3737df5deebcd40bb09c09424aeb28ef3452ca29 Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Mon, 13 Jan 2025 23:46:05 +0100 Subject: [PATCH 14/20] merge with master --- .../psyir/tools/call_routine_matcher.py | 50 +++++++++---------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/src/psyclone/psyir/tools/call_routine_matcher.py b/src/psyclone/psyir/tools/call_routine_matcher.py index 7899a74839..05648efa8c 100644 --- a/src/psyclone/psyir/tools/call_routine_matcher.py +++ b/src/psyclone/psyir/tools/call_routine_matcher.py @@ -38,7 +38,7 @@ # Further authors: R. W. Ford, A. R. Porter and S. Siso, STFC Daresbury Lab # ----------------------------------------------------------------------------- -from typing import List, Union +from typing import List, Union, Set from psyclone.psyir.symbols.datatypes import ArrayType, UnresolvedType from psyclone.errors import PSycloneError from psyclone.psyir.nodes import Call, Routine @@ -169,14 +169,24 @@ def _check_argument_type_matches( f"argument '{call_arg}' and routine argument " f"'{routine_arg}'" ) - else: - if call_arg.datatype != routine_arg.datatype: - raise CallMatchingArgumentsNotFoundError( - "Argument type mismatch of call argument " - f"'{call_arg}' with type '{call_arg.datatype}' " - "and routine argument " - f"'{routine_arg}' with type '{routine_arg.datatype}'." - ) + + return True + + if (isinstance(routine_arg.datatype, ArrayType) and + isinstance(call_arg.datatype, ArrayType)): + + # If these are two arrays, only make sure that the types + # match. + if (call_arg.datatype.datatype == routine_arg.datatype.datatype): + return True + + if call_arg.datatype != routine_arg.datatype: + raise CallMatchingArgumentsNotFoundError( + "Argument type mismatch of call argument " + f"'{call_arg}' with type '{call_arg.datatype}' " + "and routine argument " + f"'{routine_arg}' with type '{routine_arg.datatype}'." + ) return True @@ -279,17 +289,12 @@ def get_argument_routine_match_list( return ret_arg_idx_list - def get_callee_candidates(self, ignore_missing_modules: bool = False): + def get_callee_candidates(self) -> List[Routine]: ''' Searches for the implementation(s) of all potential target routines for this Call without any arguments check. - :param ignore_missing_modules: If a module wasn't found, return 'None' - instead of throwing an exception 'ModuleNotFound'. - :type ignore_missing_modules: bool - :returns: the Routine(s) that this call targets. - :rtype: list[:py:class:`psyclone.psyir.nodes.Routine`] :raises NotImplementedError: if the routine is not local and not found in any containers in scope at the call site. @@ -472,25 +477,18 @@ def _location_txt(node): " is within a CodeBlock." ) - def get_callee(self): + def get_callee(self) -> Set[Union[Routine, List[int]]]: ''' Searches for the implementation(s) of the target routine for this Call including argument checks. - :param check_matching_arguments: Also check argument types to match. - If set to `False` and in case it doesn't find matching arguments, - the very first implementation of the matching routine will be - returned (even if the argument type check failed). The argument - types and number of arguments might therefore mismatch! - :type ret_arg_match_list: bool - :returns: A tuple of two elements. The first element is the routine that this call targets. The second one a list of arguments providing the information on matching argument indices. - :rtype: Set[psyclone.psyir.nodes.Routine, List[int]] - :raises NotImplementedError: if the routine is not local and not found - in any containers in scope at the call site. + :raises CallMatchingArgumentsNotFoundError: if the routine is not local + and not found in any containers in scope at the call site or if + the arguments don't match. ''' routine_list = self.get_callee_candidates() From 8f26e80833ee4b83a901638baf2271f66ce909a5 Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Tue, 14 Jan 2025 00:12:10 +0100 Subject: [PATCH 15/20] cleanups --- src/psyclone/psyir/nodes/call.py | 68 ++++++++----------- src/psyclone/psyir/symbols/containersymbol.py | 8 +-- .../psyir/tools/call_routine_matcher.py | 14 ++-- src/psyclone/psyir/tools/call_tree_utils.py | 2 +- src/psyclone/tests/psyir/nodes/call_test.py | 20 +++--- utils/run_pytest_cov.sh | 2 +- 6 files changed, 54 insertions(+), 60 deletions(-) diff --git a/src/psyclone/psyir/nodes/call.py b/src/psyclone/psyir/nodes/call.py index 2c8ee0b102..75e65d057b 100644 --- a/src/psyclone/psyir/nodes/call.py +++ b/src/psyclone/psyir/nodes/call.py @@ -111,16 +111,14 @@ def create(cls, routine, arguments=()): ''' if not isinstance(routine, (Reference, RoutineSymbol)): raise TypeError( - "The Call routine argument should be a Reference to a " - "RoutineSymbol or a RoutineSymbol, but " - f"found '{type(routine).__name__}'." - ) + f"The Call routine argument should be a Reference to a " + f"RoutineSymbol or a RoutineSymbol, but " + f"found '{type(routine).__name__}'.") if not isinstance(arguments, Iterable): raise GenerationError( - "Call.create 'arguments' argument should be an Iterable but " - f"found '{type(arguments).__name__}'." - ) + f"Call.create 'arguments' argument should be an Iterable but " + f"found '{type(arguments).__name__}'.") call = cls() if isinstance(routine, Reference): @@ -154,17 +152,15 @@ def _add_args(call, arguments): if isinstance(arg, tuple): if not len(arg) == 2: raise GenerationError( - "If a child of the children argument in create " - "method of Call class is a tuple, it's " - f"length should be 2, but found {len(arg)}." - ) + f"If a child of the children argument in create " + f"method of Call class is a tuple, it's " + f"length should be 2, but found {len(arg)}.") if not isinstance(arg[0], str): raise GenerationError( - "If a child of the children argument in create " - "method of Call class is a tuple, its first " - "argument should be a str, but found " - f"{type(arg[0]).__name__}." - ) + f"If a child of the children argument in create " + f"method of Call class is a tuple, its first " + f"argument should be a str, but found " + f"{type(arg[0]).__name__}.") name, arg = arg call.append_named_arg(name, arg) @@ -184,15 +180,13 @@ def append_named_arg(self, name, arg): # Avoid circular import. # pylint: disable=import-outside-toplevel from psyclone.psyir.frontend.fortran import FortranReader - FortranReader.validate_name(name) for check_name in self.argument_names: if check_name and check_name.lower() == name.lower(): raise ValueError( f"The value of the name argument ({name}) in " - "'append_named_arg' in the 'Call' node is " - "already used for a named argument." - ) + f"'append_named_arg' in the 'Call' node is " + f"already used for a named argument.") self._argument_names.append((id(arg), name)) self.children.append(arg) @@ -215,21 +209,18 @@ def insert_named_arg(self, name, arg, index): # Avoid circular import. # pylint: disable=import-outside-toplevel from psyclone.psyir.frontend.fortran import FortranReader - FortranReader.validate_name(name) for check_name in self.argument_names: if check_name and check_name.lower() == name.lower(): raise ValueError( f"The value of the name argument ({name}) in " - "'insert_named_arg' in the 'Call' node is " - "already used for a named argument." - ) + f"'insert_named_arg' in the 'Call' node is " + f"already used for a named argument.") if not isinstance(index, int): raise TypeError( - "The 'index' argument in 'insert_named_arg' in the " - "'Call' node should be an int but found " - f"{type(index).__name__}." - ) + f"The 'index' argument in 'insert_named_arg' in the " + f"'Call' node should be an int but found " + f"{type(index).__name__}.") self._argument_names.insert(index, (id(arg), name)) # The n'th argument is placed at the n'th+1 children position # because the 1st child is the routine reference @@ -249,10 +240,9 @@ def replace_named_arg(self, existing_name: str, arg: DataNode): ''' if not isinstance(existing_name, str): raise TypeError( - "The 'name' argument in 'replace_named_arg' in the " - "'Call' node should be a string, but found " - f"{type(existing_name).__name__}." - ) + f"The 'name' argument in 'replace_named_arg' in the " + f"'Call' node should be a string, but found " + f"{type(existing_name).__name__}.") index = 0 for _, name in self._argument_names: if name is not None and name.lower() == existing_name: @@ -261,9 +251,8 @@ def replace_named_arg(self, existing_name: str, arg: DataNode): else: raise ValueError( f"The value of the existing_name argument ({existing_name}) " - "in 'replace_named_arg' in the 'Call' node was not found " - "in the existing arguments." - ) + f"in 'replace_named_arg' in the 'Call' node was not found " + f"in the existing arguments.") # The n'th argument is placed at the n'th+1 children position # because the 1st child is the routine reference self.children[index + 1] = arg @@ -418,10 +407,8 @@ def node_str(self, colour=True): :rtype: str ''' - return ( - f"{self.coloured_name(colour)}" - f"[name='{self.routine.debug_string()}']" - ) + return (f"{self.coloured_name(colour)}" + f"[name='{self.routine.debug_string()}']") def __str__(self): return self.node_str(False) @@ -455,6 +442,9 @@ def get_callees(self, ignore_missing_modules: bool = False): Searches for the implementation(s) of all potential target routines for this Call without any arguments check. + Deprecation warning: This only exists for backwards compatibility + reason. It's recommende to directly use `CallRoutineMatcher`. + :param ignore_missing_modules: If a module wasn't found, return 'None' instead of throwing an exception 'ModuleNotFound'. :type ignore_missing_modules: bool diff --git a/src/psyclone/psyir/symbols/containersymbol.py b/src/psyclone/psyir/symbols/containersymbol.py index 008f808a2e..43122c2f71 100644 --- a/src/psyclone/psyir/symbols/containersymbol.py +++ b/src/psyclone/psyir/symbols/containersymbol.py @@ -125,10 +125,8 @@ def copy(self): new_symbol.is_intrinsic = self.is_intrinsic return new_symbol - def find_container_psyir( - self, local_node=None - ): - """Searches for the Container that this Symbol refers to. If it is + def find_container_psyir(self, local_node=None): + ''' Searches for the Container that this Symbol refers to. If it is not available, use the interface to import the container. If `local_node` is supplied then the PSyIR tree below it is searched for the container first. @@ -144,7 +142,7 @@ def find_container_psyir( :returns: referenced container. :rtype: :py:class:`psyclone.psyir.nodes.Container` - """ + ''' if not self._reference: # First check in the current PSyIR tree (if supplied). if local_node: diff --git a/src/psyclone/psyir/tools/call_routine_matcher.py b/src/psyclone/psyir/tools/call_routine_matcher.py index 05648efa8c..d344278eb9 100644 --- a/src/psyclone/psyir/tools/call_routine_matcher.py +++ b/src/psyclone/psyir/tools/call_routine_matcher.py @@ -34,8 +34,9 @@ # ----------------------------------------------------------------------------- # This file is based on gathering various components related to # calls and routines from across psyclone. Hence, there's no clear author. -# Initial author of this file: M. Schreiber, University Grenoble Alpes -# Further authors: R. W. Ford, A. R. Porter and S. Siso, STFC Daresbury Lab +# Authors of gathered files: R. W. Ford, A. R. Porter and +# S. Siso, STFC Daresbury Lab +# Creator/partial author of this file: M. Schreiber, University Grenoble Alpes # ----------------------------------------------------------------------------- from typing import List, Union, Set @@ -174,7 +175,7 @@ def _check_argument_type_matches( if (isinstance(routine_arg.datatype, ArrayType) and isinstance(call_arg.datatype, ArrayType)): - + # If these are two arrays, only make sure that the types # match. if (call_arg.datatype.datatype == routine_arg.datatype.datatype): @@ -263,7 +264,7 @@ def get_argument_routine_match_list( else: # It doesn't match => Raise exception raise CallMatchingArgumentsNotFoundError( - f"Named argument '{arg_name}' not found" + f"Named argument '{arg_name}' not found." ) routine_argument_list[routine_arg_idx] = None @@ -284,7 +285,7 @@ def get_argument_routine_match_list( raise CallMatchingArgumentsNotFoundError( f"Argument '{routine_arg.name}' in subroutine" - f" '{self._routine_node.name}' not handled" + f" '{self._routine_node.name}' not handled." ) return ret_arg_idx_list @@ -521,8 +522,9 @@ def get_callee(self) -> Set[Union[Routine, List[int]]]: error_msg = "\n".join(err_info_list) + s = str(self._call_node.debug_string()).replace("\n", "") raise CallMatchingArgumentsNotFoundError( "Found routines, but no routine with matching arguments found " - f"for '{self._call_node.routine.name}':\n" + f"for '{s}':\n" + error_msg ) diff --git a/src/psyclone/psyir/tools/call_tree_utils.py b/src/psyclone/psyir/tools/call_tree_utils.py index 643426abec..3d11f7020c 100644 --- a/src/psyclone/psyir/tools/call_tree_utils.py +++ b/src/psyclone/psyir/tools/call_tree_utils.py @@ -38,7 +38,7 @@ across different subroutines and modules.''' from psyclone.core import Signature, VariablesAccessInfo -from psyclone.parse.module_manager import ModuleManager +from psyclone.parse import ModuleManager from psyclone.psyGen import BuiltIn, Kern from psyclone.psyir.nodes import Container, Reference from psyclone.psyir.symbols import ( diff --git a/src/psyclone/tests/psyir/nodes/call_test.py b/src/psyclone/tests/psyir/nodes/call_test.py index dd8c4160c3..c04730afab 100644 --- a/src/psyclone/tests/psyir/nodes/call_test.py +++ b/src/psyclone/tests/psyir/nodes/call_test.py @@ -1245,10 +1245,12 @@ def test_call_get_callee_7_matching_arguments_not_found(fortran_reader): with pytest.raises(CallMatchingArgumentsNotFoundError) as err: call_foo.get_callee() - assert ( - "Found routines, but no routine with matching arguments found" - in str(err.value) - ) + assert ("Found routines, but no routine with matching arguments" + " found for 'call foo(e, f, d=g)':" in str(err.value)) + + print(str(err.value)) + assert ("CallMatchingArgumentsNotFound: Named argument" + " 'd' not found." in str(err.value)) def test_call_get_callee_8_arguments_not_handled(fortran_reader): @@ -1285,10 +1287,12 @@ def test_call_get_callee_8_arguments_not_handled(fortran_reader): with pytest.raises(CallMatchingArgumentsNotFoundError) as err: call_foo.get_callee() - assert ( - "Found routines, but no routine with matching arguments found" - in str(err.value) - ) + assert ("CallMatchingArgumentsNotFound: Found routines, but" + " no routine with matching arguments found for 'call" + " foo(e, f)':" in str(err.value)) + + assert ("CallMatchingArgumentsNotFound: Argument 'c' in" + " subroutine 'foo' not handled." in str(err.value)) @pytest.mark.usefixtures("clear_module_manager_instance") diff --git a/utils/run_pytest_cov.sh b/utils/run_pytest_cov.sh index b9d9ee3a12..f8f52e98d3 100755 --- a/utils/run_pytest_cov.sh +++ b/utils/run_pytest_cov.sh @@ -26,7 +26,7 @@ COV_REPORT="xml:cov.xml" OPTS=" --cov-report term" if [[ -e cov.xml ]]; then - echo "Removing previoud reporting file 'cov.xml'" + echo "Removing previous reporting file 'cov.xml'" rm -rf cov.xml fi From 237f56d9a93add2182bc33e36427095891010d2e Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Tue, 14 Jan 2025 00:37:49 +0100 Subject: [PATCH 16/20] u --- .../transformations/inline_trans_test.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/psyclone/tests/psyir/transformations/inline_trans_test.py b/src/psyclone/tests/psyir/transformations/inline_trans_test.py index 9c9aa86f46..dc59e924f2 100644 --- a/src/psyclone/tests/psyir/transformations/inline_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/inline_trans_test.py @@ -1834,9 +1834,10 @@ def test_validate_unsupportedtype_argument(fortran_reader): inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(routine) + assert ( - "Found routines, but no routine with matching arguments found for" - " 'sub'" + "Transformation Error: Cannot inline routine 'sub'" + " because its source cannot be found:" in str(err.value) ) assert ( @@ -2085,15 +2086,16 @@ def test_validate_wrong_number_args(fortran_reader): inline_trans = InlineTrans() with pytest.raises(TransformationError) as err: inline_trans.validate(call) - assert ( - "Found routines, but no routine with matching arguments found for" - " 'sub':" - in str(err.value) - ) assert ( - "More arguments in call ('call sub(i, trouble)') than callee (routine" - " 'sub')" + "Transformation Error: Cannot inline routine 'sub'" + " because its source cannot be found:\n" + "CallMatchingArgumentsNotFound: Found routines," + " but no routine with matching arguments found" + " for 'call sub(i, trouble)':\n" + "CallMatchingArgumentsNotFound: More arguments" + " in call ('call sub(i, trouble)') than callee" + " (routine 'sub')" in str(err.value) ) @@ -2540,7 +2542,6 @@ def test_apply_optional_and_named_arg_2(fortran_reader): inline_trans.apply(call) - print(routine_main.debug_string()) assert ( '''var = var + 2.0 + 1.0 var = var + 4.0 + 1.0 From cbb65f0a506d96bd85ac928287a878471f32c449 Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Tue, 14 Jan 2025 13:09:15 +0100 Subject: [PATCH 17/20] intermediate commit --- .../psyir/tools/call_routine_matcher_test.py | 455 ++++++++++++++++++ .../transformations/inline_trans_test.py | 134 ++---- 2 files changed, 499 insertions(+), 90 deletions(-) create mode 100644 src/psyclone/tests/psyir/tools/call_routine_matcher_test.py diff --git a/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py b/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py new file mode 100644 index 0000000000..64e2d4ec94 --- /dev/null +++ b/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py @@ -0,0 +1,455 @@ +# ----------------------------------------------------------------------------- +# BSD 3-Clause License +# +# Copyright (c) 2020-2025, Science and Technology Facilities Council. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# ----------------------------------------------------------------------------- +# Author: Martin Schreiber, Univ. Grenoble Alpes / LJK / Inria +# ----------------------------------------------------------------------------- + + +import pytest +from psyclone.psyir.tools.call_routine_matcher import ( + CallRoutineMatcher, + CallMatchingArgumentsNotFoundError, +) +from psyclone.psyir.nodes import Call, Node, Routine +from psyclone.psyir.transformations import InlineTrans + + +def test_apply_optional_and_named_arg(fortran_reader): + """Test that the validate method inlines a routine + that has an optional argument.""" + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var, named=1.0)\n" + " ! Result:\n" + " ! var = var + 1.0 + 1.0\n" + " call sub(var, 2.0, named=1.0)\n" + " ! Result:\n" + " ! var = var + 2.0\n" + " ! var = var + 1.0 + 1.0\n" + "end subroutine main\n" + "subroutine sub(x, opt, named)\n" + " real, intent(inout) :: x\n" + " real, optional :: opt\n" + " real :: named\n" + " if( present(opt) )then\n" + " x = x + opt\n" + " end if\n" + " x = x + 1.0 + named\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + psyir: Node = fortran_reader.psyir_from_source(code) + + inline_trans = InlineTrans() + + routine_main: Routine = psyir.walk(Routine)[0] + assert routine_main.name == "main" + for call in psyir.walk(Call, stop_type=Call): + call: Call + if call.routine.name != "sub": + continue + + inline_trans.apply(call) + + assert ( + """var = var + 1.0 + 1.0 + var = var + 2.0 + var = var + 1.0 + 1.0""" + in routine_main.debug_string() + ) + + +def test_unresolved_types(fortran_reader): + """Test that the validate method inlines a routine that has a named + argument.""" + + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var, opt=1.0)\n" + "end subroutine main\n" + "subroutine sub(x, opt)\n" + " real, intent(inout) :: x\n" + " real :: opt\n" + " x = x + 1.0\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + + crm = CallRoutineMatcher(call) + + crm.set_option(ignore_unresolved_types=True) + crm.get_callee_candidates() + + +def test_call_get_callee_1_simple_match(fortran_reader): + """ + Check that the right routine has been found for a single routine + implementation. + """ + code = """ +module some_mod + implicit none +contains + + subroutine main() + integer :: e, f, g + call foo(e, f, g) + end subroutine + + ! Matching routine + subroutine foo(a, b, c) + integer :: a, b, c + end subroutine + +end module some_mod""" + + psyir = fortran_reader.psyir_from_source(code) + + routine_main: Routine = psyir.walk(Routine)[0] + assert routine_main.name == "main" + + call_foo: Call = routine_main.walk(Call)[0] + + (result, _) = call_foo.get_callee() + + routine_match: Routine = psyir.walk(Routine)[1] + assert result is routine_match + + +def test_call_get_callee_2_optional_args(fortran_reader): + """ + Check that optional arguments have been correlated correctly. + """ + code = """ +module some_mod + implicit none +contains + + subroutine main() + integer :: e, f, g + call foo(e, f) + end subroutine + + ! Matching routine + subroutine foo(a, b, c) + integer :: a, b + integer, optional :: c + end subroutine + +end module some_mod""" + + root_node: Node = fortran_reader.psyir_from_source(code) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_match: Routine = root_node.walk(Routine)[1] + assert routine_match.name == "foo" + + call_foo: Call = routine_main.walk(Call)[0] + assert call_foo.routine.name == "foo" + + (result, arg_idx_list) = call_foo.get_callee() + result: Routine + + assert len(arg_idx_list) == 2 + assert arg_idx_list[0] == 0 + assert arg_idx_list[1] == 1 + + assert result is routine_match + + +def test_call_get_callee_3a_trigger_error(fortran_reader): + """ + Test which is supposed to trigger an error when no matching routine + is found + """ + code = """ +module some_mod + implicit none +contains + + subroutine main() + integer :: e, f, g + call foo(e, f, g) + end subroutine + + ! Matching routine + subroutine foo(a, b) + integer :: a, b + end subroutine + +end module some_mod""" + + root_node: Node = fortran_reader.psyir_from_source(code) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + call_foo: Call = routine_main.walk(Call)[0] + assert call_foo.routine.name == "foo" + + with pytest.raises(CallMatchingArgumentsNotFoundError) as err: + call_foo.get_callee() + + assert ( + "Found routines, but no routine with matching arguments found" + in str(err.value) + ) + + +def test_call_get_callee_3c_trigger_error(fortran_reader): + """ + Test which is supposed to trigger an error when no matching routine + is found, but we use the special option check_matching_arguments=False + to find one. + """ + code = """ +module some_mod + implicit none +contains + + subroutine main() + integer :: e, f, g + call foo(e, f, g) + end subroutine + + ! Matching routine + subroutine foo(a, b) + integer :: a, b + end subroutine + +end module some_mod""" + + root_node: Node = fortran_reader.psyir_from_source(code) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + call_foo: Call = routine_main.walk(Call)[0] + assert call_foo.routine.name == "foo" + + call_foo.get_callee(check_matching_arguments=False) + + +def test_call_get_callee_4_named_arguments(fortran_reader): + """ + Check that named arguments have been correlated correctly + """ + code = """ +module some_mod + implicit none +contains + + subroutine main() + integer :: e, f, g + call foo(c=e, a=f, b=g) + end subroutine + + ! Matching routine + subroutine foo(a, b, c) + integer :: a, b, c + end subroutine + +end module some_mod""" + + root_node: Node = fortran_reader.psyir_from_source(code) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_match: Routine = root_node.walk(Routine)[1] + assert routine_match.name == "foo" + + call_foo: Call = routine_main.walk(Call)[0] + assert call_foo.routine.name == "foo" + + (result, arg_idx_list) = call_foo.get_callee() + result: Routine + + assert len(arg_idx_list) == 3 + assert arg_idx_list[0] == 2 + assert arg_idx_list[1] == 0 + assert arg_idx_list[2] == 1 + + assert result is routine_match + + +def test_call_get_callee_5_optional_and_named_arguments(fortran_reader): + """ + Check that optional and named arguments have been correlated correctly + when the call is to a generic interface. + """ + code = """ +module some_mod + implicit none +contains + + subroutine main() + integer :: e, f, g + call foo(b=e, a=f) + end subroutine + + ! Matching routine + subroutine foo(a, b, c) + integer :: a, b + integer, optional :: c + end subroutine + +end module some_mod""" + + root_node: Node = fortran_reader.psyir_from_source(code) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_match: Routine = root_node.walk(Routine)[1] + assert routine_match.name == "foo" + + call_foo: Call = routine_main.walk(Call)[0] + assert call_foo.routine.name == "foo" + + (result, arg_idx_list) = call_foo.get_callee() + result: Routine + + assert len(arg_idx_list) == 2 + assert arg_idx_list[0] == 1 + assert arg_idx_list[1] == 0 + + assert result is routine_match + + +_code_test_get_callee_6 = """ +module some_mod + implicit none + + interface foo + procedure foo_a, foo_b, foo_c, foo_optional + end interface +contains + + subroutine main() + integer :: e_int, f_int, g_int + real :: e_real, f_real, g_real + + ! Should match foo_a, test_call_get_callee_6_interfaces_0_0 + call foo(e_int, f_int) + + ! Should match foo_a, test_call_get_callee_6_interfaces_0_1 + call foo(e_int, f_int, g_int) + + ! Should match foo_b, test_call_get_callee_6_interfaces_1_0 + call foo(e_real, f_int) + + ! Should match foo_b, test_call_get_callee_6_interfaces_1_1 + call foo(e_real, f_int, g_int) + + ! Should match foo_b, test_call_get_callee_6_interfaces_1_2 + call foo(e_real, c=f_int, b=g_int) + + ! Should match foo_c, test_call_get_callee_6_interfaces_2_0 + call foo(e_int, f_real, g_int) + + ! Should match foo_c, test_call_get_callee_6_interfaces_2_1 + call foo(b=e_real, a=f_int) + + ! Should match foo_c, test_call_get_callee_6_interfaces_2_2 + call foo(b=e_real, a=f_int, g_int) + + ! Should not match foo_optional because of invalid type, + ! test_call_get_callee_6_interfaces_3_0_mismatch + call foo(f_int, e_real, g_int, g_int) + end subroutine + + subroutine foo_a(a, b, c) + integer :: a, b + integer, optional :: c + end subroutine + + subroutine foo_b(a, b, c) + real :: a + integer :: b + integer, optional :: c + end subroutine + + subroutine foo_c(a, b, c) + integer :: a + real :: b + integer, optional :: c + end subroutine + + subroutine foo_optional(a, b, c, d) + integer :: a + real :: b + integer :: c + real, optional :: d ! real vs. int + end subroutine + + +end module some_mod""" + + +def test_set_routine(fortran_reader): + """Test the routine setter (not in the constructor).""" + + code = ( + "module test_mod\n" + "contains\n" + "subroutine main\n" + " real :: var = 0.0\n" + " call sub(var, opt=1.0)\n" + "end subroutine main\n" + "subroutine sub(x, opt)\n" + " real, intent(inout) :: x\n" + " real :: opt\n" + " x = x + 1.0\n" + "end subroutine sub\n" + "end module test_mod\n" + ) + + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + routine = psyir.walk(Routine)[0] + + crm = CallRoutineMatcher() + crm.set_call_node(call) + crm.set_routine_node(routine) diff --git a/src/psyclone/tests/psyir/transformations/inline_trans_test.py b/src/psyclone/tests/psyir/transformations/inline_trans_test.py index dc59e924f2..7226d71065 100644 --- a/src/psyclone/tests/psyir/transformations/inline_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/inline_trans_test.py @@ -107,6 +107,50 @@ def test_apply_empty_routine(fortran_reader, fortran_writer, tmpdir): assert Compile(tmpdir).string_compiles(output) +def test_apply_argument_clash(fortran_reader, fortran_writer, tmpdir): + ''' + Check that the formal arguments to the inlined routine are not included + when checking for clashes (since they will be replaced by the actual + arguments to the call). + ''' + + code_clash = """ + subroutine sub(Istr) + integer :: Istr + real :: x + x = 2.0*x + call sub_sub(Istr) + end subroutine sub + + subroutine sub_sub(Istr) + integer :: i + integer :: Istr + real :: b(10) + + b(Istr:10) = 1.0 + end subroutine sub_sub""" + + psyir = fortran_reader.psyir_from_source(code_clash) + call = psyir.walk(Call)[0] + inline_trans = InlineTrans() + inline_trans.apply(call) + expected = '''\ +subroutine sub(istr) + integer :: istr + real :: x + integer :: i + real, dimension(10) :: b + + x = 2.0 * x + b(istr:) = 1.0 + +end subroutine sub +''' + output = fortran_writer(psyir) + assert expected in output + assert Compile(tmpdir).string_compiles(output) + + def test_apply_empty_routine_coverage_option_check_strict_array_datatype( fortran_reader, fortran_writer, tmpdir): '''For coverage of particular branch in `inline_trans.py`.''' @@ -2444,53 +2488,6 @@ def test_apply_unsupported_pointer_error(fortran_reader): " 'REAL, INTENT(INOUT), POINTER :: x'." in str(einfo.value)) -def test_apply_optional_and_named_arg(fortran_reader): - '''Test that the validate method inlines a routine - that has an optional argument.''' - code = ( - "module test_mod\n" - "contains\n" - "subroutine main\n" - " real :: var = 0.0\n" - " call sub(var, named=1.0)\n" - " ! Result:\n" - " ! var = var + 1.0 + 1.0\n" - " call sub(var, 2.0, named=1.0)\n" - " ! Result:\n" - " ! var = var + 2.0\n" - " ! var = var + 1.0 + 1.0\n" - "end subroutine main\n" - "subroutine sub(x, opt, named)\n" - " real, intent(inout) :: x\n" - " real, optional :: opt\n" - " real :: named\n" - " if( present(opt) )then\n" - " x = x + opt\n" - " end if\n" - " x = x + 1.0 + named\n" - "end subroutine sub\n" - "end module test_mod\n" - ) - psyir: Node = fortran_reader.psyir_from_source(code) - - inline_trans = InlineTrans() - - routine_main: Routine = psyir.walk(Routine)[0] - assert routine_main.name == "main" - for call in psyir.walk(Call, stop_type=Call): - call: Call - if call.routine.name != "sub": - continue - - inline_trans.apply(call) - - assert ( - '''var = var + 1.0 + 1.0 - var = var + 2.0 - var = var + 1.0 + 1.0''' - in routine_main.debug_string() - ) - def test_apply_optional_and_named_arg_2(fortran_reader): '''Test that the validate method inlines a routine @@ -2604,46 +2601,3 @@ def test_apply_merges_symbol_table_with_routine(fortran_reader): # The i_1 symbol is the renamed i from the inlined call. assert psyir.walk(Routine)[0].symbol_table.get_symbols()['i_1'] is not None - -def test_apply_argument_clash(fortran_reader, fortran_writer, tmpdir): - ''' - Check that the formal arguments to the inlined routine are not included - when checking for clashes (since they will be replaced by the actual - arguments to the call). - ''' - - code_clash = """ - subroutine sub(Istr) - integer :: Istr - real :: x - x = 2.0*x - call sub_sub(Istr) - end subroutine sub - - subroutine sub_sub(Istr) - integer :: i - integer :: Istr - real :: b(10) - - b(Istr:10) = 1.0 - end subroutine sub_sub""" - - psyir = fortran_reader.psyir_from_source(code_clash) - call = psyir.walk(Call)[0] - inline_trans = InlineTrans() - inline_trans.apply(call) - expected = '''\ -subroutine sub(istr) - integer :: istr - real :: x - integer :: i - real, dimension(10) :: b - - x = 2.0 * x - b(istr:) = 1.0 - -end subroutine sub -''' - output = fortran_writer(psyir) - assert expected in output - assert Compile(tmpdir).string_compiles(output) From f77715ae34861e5cc431052b129504fcbcd9027f Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Tue, 14 Jan 2025 20:32:10 +0100 Subject: [PATCH 18/20] coverage test cases --- src/psyclone/tests/psyir/nodes/call_test.py | 1045 +---------------- .../psyir/tools/call_routine_matcher_test.py | 969 ++++++++++++++- .../transformations/inline_trans_test.py | 64 - 3 files changed, 970 insertions(+), 1108 deletions(-) diff --git a/src/psyclone/tests/psyir/nodes/call_test.py b/src/psyclone/tests/psyir/nodes/call_test.py index c04730afab..342d4dbc82 100644 --- a/src/psyclone/tests/psyir/nodes/call_test.py +++ b/src/psyclone/tests/psyir/nodes/call_test.py @@ -41,15 +41,11 @@ from psyclone.configuration import Config from psyclone.core import Signature, VariablesAccessInfo from psyclone.errors import GenerationError -from psyclone.parse import ModuleManager from psyclone.psyir.nodes import ( ArrayReference, - Assignment, BinaryOperation, Call, - CodeBlock, Literal, - Node, Reference, Routine, Schedule, @@ -62,8 +58,7 @@ NoType, RoutineSymbol, REAL_TYPE, - SymbolError, - UnsupportedFortranType, + SymbolError ) from psyclone.psyir.tools.call_routine_matcher import ( @@ -622,638 +617,7 @@ def test_copy(): assert call._argument_names != call2._argument_names -def test_call_get_callees_local(fortran_reader): - ''' - Check that get_callees() works as expected when the target of the Call - exists in the same Container as the call site. - ''' - code = ''' -module some_mod - implicit none - integer :: luggage -contains - subroutine top() - luggage = 0 - call bottom() - end subroutine top - - subroutine bottom() - luggage = luggage + 1 - end subroutine bottom -end module some_mod''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - result = call.get_callees() - assert result == [psyir.walk(Routine)[1]] - - -def test_call_get_callee_1_simple_match(fortran_reader): - ''' - Check that the right routine has been found for a single routine - implementation. - ''' - code = ''' -module some_mod - implicit none -contains - - subroutine main() - integer :: e, f, g - call foo(e, f, g) - end subroutine - - ! Matching routine - subroutine foo(a, b, c) - integer :: a, b, c - end subroutine - -end module some_mod''' - - psyir = fortran_reader.psyir_from_source(code) - - routine_main: Routine = psyir.walk(Routine)[0] - assert routine_main.name == "main" - - call_foo: Call = routine_main.walk(Call)[0] - - (result, _) = call_foo.get_callee() - - routine_match: Routine = psyir.walk(Routine)[1] - assert result is routine_match - - -def test_call_get_callee_2_optional_args(fortran_reader): - ''' - Check that optional arguments have been correlated correctly. - ''' - code = ''' -module some_mod - implicit none -contains - - subroutine main() - integer :: e, f, g - call foo(e, f) - end subroutine - - ! Matching routine - subroutine foo(a, b, c) - integer :: a, b - integer, optional :: c - end subroutine - -end module some_mod''' - - root_node: Node = fortran_reader.psyir_from_source(code) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_match: Routine = root_node.walk(Routine)[1] - assert routine_match.name == "foo" - - call_foo: Call = routine_main.walk(Call)[0] - assert call_foo.routine.name == "foo" - - (result, arg_idx_list) = call_foo.get_callee() - result: Routine - - assert len(arg_idx_list) == 2 - assert arg_idx_list[0] == 0 - assert arg_idx_list[1] == 1 - - assert result is routine_match - - -def test_call_get_callee_3a_trigger_error(fortran_reader): - ''' - Test which is supposed to trigger an error when no matching routine - is found - ''' - code = ''' -module some_mod - implicit none -contains - - subroutine main() - integer :: e, f, g - call foo(e, f, g) - end subroutine - - ! Matching routine - subroutine foo(a, b) - integer :: a, b - end subroutine - -end module some_mod''' - - root_node: Node = fortran_reader.psyir_from_source(code) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - call_foo: Call = routine_main.walk(Call)[0] - assert call_foo.routine.name == "foo" - - with pytest.raises(CallMatchingArgumentsNotFoundError) as err: - call_foo.get_callee() - - assert ( - "Found routines, but no routine with matching arguments found" - in str(err.value) - ) - - -def test_call_get_callee_3c_trigger_error(fortran_reader): - ''' - Test which is supposed to trigger an error when no matching routine - is found, but we use the special option check_matching_arguments=False - to find one. - ''' - code = ''' -module some_mod - implicit none -contains - - subroutine main() - integer :: e, f, g - call foo(e, f, g) - end subroutine - - ! Matching routine - subroutine foo(a, b) - integer :: a, b - end subroutine - -end module some_mod''' - - root_node: Node = fortran_reader.psyir_from_source(code) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - call_foo: Call = routine_main.walk(Call)[0] - assert call_foo.routine.name == "foo" - - call_foo.get_callee(check_matching_arguments=False) - - -def test_call_get_callee_4_named_arguments(fortran_reader): - ''' - Check that named arguments have been correlated correctly - ''' - code = ''' -module some_mod - implicit none -contains - - subroutine main() - integer :: e, f, g - call foo(c=e, a=f, b=g) - end subroutine - - ! Matching routine - subroutine foo(a, b, c) - integer :: a, b, c - end subroutine - -end module some_mod''' - - root_node: Node = fortran_reader.psyir_from_source(code) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_match: Routine = root_node.walk(Routine)[1] - assert routine_match.name == "foo" - - call_foo: Call = routine_main.walk(Call)[0] - assert call_foo.routine.name == "foo" - - (result, arg_idx_list) = call_foo.get_callee() - result: Routine - - assert len(arg_idx_list) == 3 - assert arg_idx_list[0] == 2 - assert arg_idx_list[1] == 0 - assert arg_idx_list[2] == 1 - - assert result is routine_match - - -def test_call_get_callee_5_optional_and_named_arguments(fortran_reader): - ''' - Check that optional and named arguments have been correlated correctly - when the call is to a generic interface. - ''' - code = ''' -module some_mod - implicit none -contains - - subroutine main() - integer :: e, f, g - call foo(b=e, a=f) - end subroutine - - ! Matching routine - subroutine foo(a, b, c) - integer :: a, b - integer, optional :: c - end subroutine - -end module some_mod''' - - root_node: Node = fortran_reader.psyir_from_source(code) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_match: Routine = root_node.walk(Routine)[1] - assert routine_match.name == "foo" - - call_foo: Call = routine_main.walk(Call)[0] - assert call_foo.routine.name == "foo" - - (result, arg_idx_list) = call_foo.get_callee() - result: Routine - - assert len(arg_idx_list) == 2 - assert arg_idx_list[0] == 1 - assert arg_idx_list[1] == 0 - - assert result is routine_match - - -_code_test_get_callee_6 = ''' -module some_mod - implicit none - - interface foo - procedure foo_a, foo_b, foo_c, foo_optional - end interface -contains - - subroutine main() - integer :: e_int, f_int, g_int - real :: e_real, f_real, g_real - - ! Should match foo_a, test_call_get_callee_6_interfaces_0_0 - call foo(e_int, f_int) - - ! Should match foo_a, test_call_get_callee_6_interfaces_0_1 - call foo(e_int, f_int, g_int) - - ! Should match foo_b, test_call_get_callee_6_interfaces_1_0 - call foo(e_real, f_int) - - ! Should match foo_b, test_call_get_callee_6_interfaces_1_1 - call foo(e_real, f_int, g_int) - - ! Should match foo_b, test_call_get_callee_6_interfaces_1_2 - call foo(e_real, c=f_int, b=g_int) - - ! Should match foo_c, test_call_get_callee_6_interfaces_2_0 - call foo(e_int, f_real, g_int) - - ! Should match foo_c, test_call_get_callee_6_interfaces_2_1 - call foo(b=e_real, a=f_int) - - ! Should match foo_c, test_call_get_callee_6_interfaces_2_2 - call foo(b=e_real, a=f_int, g_int) - - ! Should not match foo_optional because of invalid type, - ! test_call_get_callee_6_interfaces_3_0_mismatch - call foo(f_int, e_real, g_int, g_int) - end subroutine - - subroutine foo_a(a, b, c) - integer :: a, b - integer, optional :: c - end subroutine - - subroutine foo_b(a, b, c) - real :: a - integer :: b - integer, optional :: c - end subroutine - - subroutine foo_c(a, b, c) - integer :: a - real :: b - integer, optional :: c - end subroutine - - subroutine foo_optional(a, b, c, d) - integer :: a - real :: b - integer :: c - real, optional :: d ! real vs. int - end subroutine - - -end module some_mod''' - - -def test_call_get_callee_6_interfaces_0_0(fortran_reader): - ''' - Check that a non-existing optional argument at the end of the list - has been correctly determined. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_a: Routine = root_node.walk(Routine)[1] - assert routine_foo_a.name == "foo_a" - - call_foo_a: Call = routine_main.walk(Call)[0] - assert call_foo_a.routine.name == "foo" - - (result, arg_idx_list) = call_foo_a.get_callee() - result: Routine - - assert len(arg_idx_list) == 2 - assert arg_idx_list[0] == 0 - assert arg_idx_list[1] == 1 - - assert result is routine_foo_a - - -def test_call_get_callee_6_interfaces_0_1(fortran_reader): - ''' - Check that an existing optional argument at the end of the list - has been correctly determined. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_a: Routine = root_node.walk(Routine)[1] - assert routine_foo_a.name == "foo_a" - - call_foo_a: Call = routine_main.walk(Call)[1] - assert call_foo_a.routine.name == "foo" - - (result, arg_idx_list) = call_foo_a.get_callee() - result: Routine - - assert len(arg_idx_list) == 3 - assert arg_idx_list[0] == 0 - assert arg_idx_list[1] == 1 - assert arg_idx_list[2] == 2 - - assert result is routine_foo_a - - -def test_call_get_callee_6_interfaces_1_0(fortran_reader): - ''' - Check that - - different argument types and - - non-existing optional argument at the end of the list - have been correctly determined. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_b: Routine = root_node.walk(Routine)[2] - assert routine_foo_b.name == "foo_b" - - call_foo_b: Call = routine_main.walk(Call)[2] - assert call_foo_b.routine.name == "foo" - - (result, arg_idx_list) = call_foo_b.get_callee() - result: Routine - - assert len(arg_idx_list) == 2 - assert arg_idx_list[0] == 0 - assert arg_idx_list[1] == 1 - - assert result is routine_foo_b - - -def test_call_get_callee_6_interfaces_1_1(fortran_reader): - ''' - Check that - - different argument types and - - existing optional argument at the end of the list - have been correctly determined. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_b: Routine = root_node.walk(Routine)[2] - assert routine_foo_b.name == "foo_b" - - call_foo_b: Call = routine_main.walk(Call)[3] - assert call_foo_b.routine.name == "foo" - - (result, arg_idx_list) = call_foo_b.get_callee() - result: Routine - - assert len(arg_idx_list) == 3 - assert arg_idx_list[0] == 0 - assert arg_idx_list[1] == 1 - assert arg_idx_list[2] == 2 - - assert result is routine_foo_b - - -def test_call_get_callee_6_interfaces_1_2(fortran_reader): - ''' - Check that - - different argument types and - - naming arguments resulting in a different order - have been correctly determined. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_b: Routine = root_node.walk(Routine)[2] - assert routine_foo_b.name == "foo_b" - - call_foo_b: Call = routine_main.walk(Call)[4] - assert call_foo_b.routine.name == "foo" - - (result, arg_idx_list) = call_foo_b.get_callee() - result: Routine - - assert len(arg_idx_list) == 3 - assert arg_idx_list[0] == 0 - assert arg_idx_list[1] == 2 - assert arg_idx_list[2] == 1 - - assert result is routine_foo_b - - -def test_call_get_callee_6_interfaces_2_0(fortran_reader): - ''' - Check that - - different argument types (different order than in tests before) - have been correctly determined. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_c: Routine = root_node.walk(Routine)[3] - assert routine_foo_c.name == "foo_c" - - call_foo_c: Call = routine_main.walk(Call)[5] - assert call_foo_c.routine.name == "foo" - - (result, arg_idx_list) = call_foo_c.get_callee() - result: Routine - - assert len(arg_idx_list) == 3 - assert arg_idx_list[0] == 0 - assert arg_idx_list[1] == 1 - assert arg_idx_list[2] == 2 - - assert result is routine_foo_c - - -def test_call_get_callee_6_interfaces_2_1(fortran_reader): - ''' - Check that - - different argument types (different order than in tests before) and - - naming arguments resulting in a different order and - - optional argument - have been correctly determined. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_c: Routine = root_node.walk(Routine)[3] - assert routine_foo_c.name == "foo_c" - - call_foo_c: Call = routine_main.walk(Call)[6] - assert call_foo_c.routine.name == "foo" - - (result, arg_idx_list) = call_foo_c.get_callee() - result: Routine - - assert len(arg_idx_list) == 2 - assert arg_idx_list[0] == 1 - assert arg_idx_list[1] == 0 - - assert result is routine_foo_c - - -def test_call_get_callee_6_interfaces_2_2(fortran_reader): - ''' - Check that - - different argument types (different order than in tests before) and - - naming arguments resulting in a different order and - - last call argument without naming - have been correctly determined. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_c: Routine = root_node.walk(Routine)[3] - assert routine_foo_c.name == "foo_c" - - call_foo_c: Call = routine_main.walk(Call)[7] - assert call_foo_c.routine.name == "foo" - - (result, arg_idx_list) = call_foo_c.get_callee() - result: Routine - - assert len(arg_idx_list) == 3 - assert arg_idx_list[0] == 1 - assert arg_idx_list[1] == 0 - assert arg_idx_list[2] == 2 - - assert result is routine_foo_c - - -def test_call_get_callee_6_interfaces_3_0_mismatch(fortran_reader): - ''' - Check that matching a partial data type can also go wrong. - ''' - - root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) - - routine_main: Routine = root_node.walk(Routine)[0] - assert routine_main.name == "main" - - routine_foo_optional: Routine = root_node.walk(Routine)[4] - assert routine_foo_optional.name == "foo_optional" - - call_foo_optional: Call = routine_main.walk(Call)[8] - assert call_foo_optional.routine.name == "foo" - - with pytest.raises(CallMatchingArgumentsNotFoundError) as einfo: - call_foo_optional.get_callee() - - assert "Argument partial type mismatch of call argument" in ( - str(einfo.value)) - - -def test_call_get_callee_7_matching_arguments_not_found(fortran_reader): - ''' - Trigger error that matching arguments were not found - ''' - code = ''' -module some_mod - implicit none -contains - - subroutine main() - integer :: e, f, g - ! Use named argument 'd', which doesn't exist - ! to trigger an error when searching for the matching routine. - call foo(e, f, d=g) - end subroutine - - ! Matching routine - subroutine foo(a, b, c) - integer :: a, b, c - end subroutine - -end module some_mod''' - - psyir = fortran_reader.psyir_from_source(code) - - routine_main: Routine = psyir.walk(Routine)[0] - assert routine_main.name == "main" - - call_foo: Call = routine_main.walk(Call)[0] - - with pytest.raises(CallMatchingArgumentsNotFoundError) as err: - call_foo.get_callee() - - assert ("Found routines, but no routine with matching arguments" - " found for 'call foo(e, f, d=g)':" in str(err.value)) - - print(str(err.value)) - assert ("CallMatchingArgumentsNotFound: Named argument" - " 'd' not found." in str(err.value)) - - -def test_call_get_callee_8_arguments_not_handled(fortran_reader): +def test_call_get_callee_arguments_not_handled(fortran_reader): ''' Trigger error that matching arguments were not found. In this test, this is caused by omitting the required third non-optional @@ -1295,411 +659,6 @@ def test_call_get_callee_8_arguments_not_handled(fortran_reader): " subroutine 'foo' not handled." in str(err.value)) -@pytest.mark.usefixtures("clear_module_manager_instance") -def test_call_get_callees_unresolved(fortran_reader, tmpdir, monkeypatch): - ''' - Test that get_callees() raises the expected error if the called routine - is unresolved. - ''' - code = ''' -subroutine top() - call bottom() -end subroutine top''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - with pytest.raises(NotImplementedError) as err: - _ = call.get_callees() - assert ("Failed to find the source code of the unresolved routine 'bottom'" - " - looked at any routines in the same source file and there are " - "no wildcard imports." in str(err.value)) - # Repeat but in the presence of a wildcard import. - code = ''' -subroutine top() - use some_mod_somewhere - call bottom() -end subroutine top''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - with pytest.raises(NotImplementedError) as err: - _ = call.get_callees() - assert ("Failed to find the source code of the unresolved routine 'bottom'" - " - looked at any routines in the same source file and attempted " - "to resolve the wildcard imports from ['some_mod_somewhere']. " - "However, failed to find the source for ['some_mod_somewhere']. " - "The module search path is set to []" in str(err.value)) - # Repeat but when some_mod_somewhere *is* resolved but doesn't help us - # find the routine we're looking for. - mod_manager = ModuleManager.get() - monkeypatch.setattr(mod_manager, "_instance", None) - path = str(tmpdir) - monkeypatch.setattr(Config.get(), '_include_paths', [path]) - with open(os.path.join(path, "some_mod_somewhere.f90"), "w", - encoding="utf-8") as ofile: - ofile.write('''\ -module some_mod_somewhere -end module some_mod_somewhere -''') - with pytest.raises(NotImplementedError) as err: - _ = call.get_callees() - assert ("Failed to find the source code of the unresolved routine 'bottom'" - " - looked at any routines in the same source file and wildcard " - "imports from ['some_mod_somewhere']." in str(err.value)) - mod_manager = ModuleManager.get() - monkeypatch.setattr(mod_manager, "_instance", None) - code = ''' -subroutine top() - use another_mod, only: this_one - call this_one() -end subroutine top''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - with pytest.raises(NotImplementedError) as err: - _ = call.get_callees() - assert ("RoutineSymbol 'this_one' is imported from Container 'another_mod'" - " but the source defining that container could not be found. The " - "module search path is set to [" in str(err.value)) - - -def test_call_get_callees_interface(fortran_reader): - ''' - Check that get_callees() works correctly when the target of a call is - actually a generic interface. - ''' - code = ''' -module my_mod - - interface bottom - module procedure :: rbottom, ibottom - end interface bottom -contains - subroutine top() - integer :: luggage - luggage = 0 - call bottom(luggage) - end subroutine top - - subroutine ibottom(luggage) - integer :: luggage - luggage = luggage + 1 - end subroutine ibottom - - subroutine rbottom(luggage) - real :: luggage - luggage = luggage + 1.0 - end subroutine rbottom -end module my_mod -''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - callees = call.get_callees() - assert len(callees) == 2 - assert isinstance(callees[0], Routine) - assert callees[0].name == "rbottom" - assert isinstance(callees[1], Routine) - assert callees[1].name == "ibottom" - - -def test_call_get_callees_unsupported_type(fortran_reader): - ''' - Check that get_callees() raises the expected error when the called routine - is of UnsupportedFortranType. This is hard to achieve so we have to - manually construct some aspects of the test case. - - ''' - code = ''' -module my_mod - integer, target :: value -contains - subroutine top() - integer :: luggage - luggage = bottom() - end subroutine top - function bottom() result(fval) - integer, pointer :: fval - fval => value - end function bottom -end module my_mod -''' - psyir = fortran_reader.psyir_from_source(code) - container = psyir.children[0] - routine = container.find_routine_psyir("bottom") - rsym = container.symbol_table.lookup(routine.name) - # Ensure the type of this RoutineSymbol is UnsupportedFortranType. - rsym.datatype = UnsupportedFortranType("integer, pointer :: fval") - assign = container.walk(Assignment)[0] - # Currently `bottom()` gets matched by fparser2 as a structure constructor - # and the fparser2 frontend leaves this as a CodeBlock (TODO #2429) so - # replace it with a Call. Once #2429 is fixed the next two lines can be - # removed. - assert isinstance(assign.rhs, CodeBlock) - assign.rhs.replace_with(Call.create(rsym)) - call = psyir.walk(Call)[0] - with pytest.raises(NotImplementedError) as err: - _ = call.get_callees() - assert ("RoutineSymbol 'bottom' exists in Container 'my_mod' but is of " - "UnsupportedFortranType" in str(err.value)) - - -def test_call_get_callees_file_container(fortran_reader): - ''' - Check that get_callees works if the called routine happens to be in file - scope, even when there's no Container. - ''' - code = ''' - subroutine top() - integer :: luggage - luggage = 0 - call bottom(luggage) - end subroutine top - - subroutine bottom(luggage) - integer :: luggage - luggage = luggage + 1 - end subroutine bottom -''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - result = call.get_callees() - assert len(result) == 1 - assert isinstance(result[0], Routine) - assert result[0].name == "bottom" - - -def test_call_get_callees_no_container(fortran_reader): - ''' - Check that get_callees() raises the expected error when the Call is not - within a Container and the target routine cannot be found. - ''' - # To avoid having the routine symbol immediately dismissed as - # unresolved, the code that we initially process *does* have a Container. - code = ''' -module my_mod - -contains - subroutine top() - integer :: luggage - luggage = 0 - call bottom(luggage) - end subroutine top - - subroutine bottom(luggage) - integer :: luggage - luggage = luggage + 1 - end subroutine bottom -end module my_mod -''' - psyir = fortran_reader.psyir_from_source(code) - top_routine = psyir.walk(Routine)[0] - # Deliberately make the Routine node an orphan so there's no Container. - top_routine.detach() - call = top_routine.walk(Call)[0] - with pytest.raises(SymbolError) as err: - _ = call.get_callees() - assert ("Failed to find a Routine named 'bottom' in code:\n'subroutine " - "top()" in str(err.value)) - - -def test_call_get_callees_wildcard_import_local_container(fortran_reader): - ''' - Check that get_callees() works successfully for a routine accessed via - a wildcard import from another module in the same file. - ''' - code = ''' -module some_mod -contains - subroutine just_do_it() - write(*,*) "hello" - end subroutine just_do_it -end module some_mod -module other_mod - use some_mod -contains - subroutine run_it() - call just_do_it() - end subroutine run_it -end module other_mod -''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - routines = call.get_callees() - assert len(routines) == 1 - assert isinstance(routines[0], Routine) - assert routines[0].name == "just_do_it" - - -def test_call_get_callees_import_local_container(fortran_reader): - ''' - Check that get_callees() works successfully for a routine accessed via - a specific import from another module in the same file. - ''' - code = ''' -module some_mod -contains - subroutine just_do_it() - write(*,*) "hello" - end subroutine just_do_it -end module some_mod -module other_mod - use some_mod, only: just_do_it -contains - subroutine run_it() - call just_do_it() - end subroutine run_it -end module other_mod -''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - routines = call.get_callees() - assert len(routines) == 1 - assert isinstance(routines[0], Routine) - assert routines[0].name == "just_do_it" - - -@pytest.mark.usefixtures("clear_module_manager_instance") -def test_call_get_callees_wildcard_import_container(fortran_reader, - tmpdir, monkeypatch): - ''' - Check that get_callees() works successfully for a routine accessed via - a wildcard import from a module in another file. - ''' - code = ''' -module other_mod - use some_mod -contains - subroutine run_it() - call just_do_it() - end subroutine run_it -end module other_mod -''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - # This should fail as it can't find the module. - with pytest.raises(NotImplementedError) as err: - _ = call.get_callees() - assert ("Failed to find the source code of the unresolved routine " - "'just_do_it' - looked at any routines in the same source file" - in str(err.value)) - # Create the module containing the subroutine definition, - # write it to file and set the search path so that PSyclone can find it. - path = str(tmpdir) - monkeypatch.setattr(Config.get(), '_include_paths', [path]) - - with open(os.path.join(path, "some_mod.f90"), - "w", encoding="utf-8") as mfile: - mfile.write('''\ -module some_mod -contains - subroutine just_do_it() - write(*,*) "hello" - end subroutine just_do_it -end module some_mod''') - routines = call.get_callees() - assert len(routines) == 1 - assert isinstance(routines[0], Routine) - assert routines[0].name == "just_do_it" - - -def test_fn_call_get_callees(fortran_reader): - ''' - Test that get_callees() works for a function call. - ''' - code = ''' -module some_mod - implicit none - integer :: luggage -contains - subroutine top() - luggage = 0 - luggage = luggage + my_func(1) - end subroutine top - - function my_func(val) - integer, intent(in) :: val - integer :: my_func - my_func = 1 + val - end function my_func -end module some_mod''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - result = call.get_callees() - assert result == [psyir.walk(Routine)[1]] - - -def test_get_callees_code_block(fortran_reader): - '''Test that get_callees() raises the expected error when the called - routine is in a CodeBlock.''' - code = ''' -module some_mod - implicit none - integer :: luggage -contains - subroutine top() - luggage = 0 - luggage = luggage + real(my_func(1)) - end subroutine top - - complex function my_func(val) - integer, intent(in) :: val - my_func = CMPLX(1 + val, 1.0) - end function my_func -end module some_mod''' - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[1] - with pytest.raises(SymbolError) as err: - _ = call.get_callees() - assert ("Failed to find a Routine named 'my_func' in Container " - "'some_mod'" in str(err.value)) - - -@pytest.mark.usefixtures("clear_module_manager_instance") -def test_get_callees_follow_imports(fortran_reader, tmpdir, monkeypatch): - ''' - Test that get_callees() follows imports to find the definition of the - called routine. - ''' - code = ''' -module some_mod - use other_mod, only: pack_it - implicit none -contains - subroutine top() - integer :: luggage = 0 - call pack_it(luggage) - end subroutine top -end module some_mod''' - # Create the module containing an import of the subroutine definition, - # write it to file and set the search path so that PSyclone can find it. - path = str(tmpdir) - monkeypatch.setattr(Config.get(), '_include_paths', [path]) - - with open(os.path.join(path, "other_mod.f90"), - "w", encoding="utf-8") as mfile: - mfile.write('''\ - module other_mod - use another_mod, only: pack_it - contains - end module other_mod - ''') - # Finally, create the module containing the routine definition. - with open(os.path.join(path, "another_mod.f90"), - "w", encoding="utf-8") as mfile: - mfile.write('''\ - module another_mod - contains - subroutine pack_it(arg) - integer, intent(inout) :: arg - arg = arg + 2 - end subroutine pack_it - end module another_mod - ''') - psyir = fortran_reader.psyir_from_source(code) - call = psyir.walk(Call)[0] - result = call.get_callees() - assert len(result) == 1 - assert isinstance(result[0], Routine) - assert result[0].name == "pack_it" - - @pytest.mark.usefixtures("clear_module_manager_instance") def test_get_callees_import_private_clash(fortran_reader, tmpdir, monkeypatch): ''' diff --git a/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py b/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py index 64e2d4ec94..06647cf83c 100644 --- a/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py +++ b/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py @@ -35,13 +35,19 @@ # ----------------------------------------------------------------------------- +from psyclone.psyir.symbols.datatypes import ArrayType, UnresolvedType +import os import pytest +from psyclone.configuration import Config +from psyclone.parse import ModuleManager from psyclone.psyir.tools.call_routine_matcher import ( CallRoutineMatcher, CallMatchingArgumentsNotFoundError, ) -from psyclone.psyir.nodes import Call, Node, Routine +from psyclone.psyir.symbols import UnsupportedFortranType, SymbolError +from psyclone.psyir.nodes import Call, Node, Routine, Assignment, CodeBlock from psyclone.psyir.transformations import InlineTrans +from psyclone.tests.utilities import Compile def test_apply_optional_and_named_arg(fortran_reader): @@ -428,6 +434,309 @@ def test_call_get_callee_5_optional_and_named_arguments(fortran_reader): end module some_mod""" +def test_call_get_callee_6_interfaces_0_0(fortran_reader): + """ + Check that a non-existing optional argument at the end of the list + has been correctly determined. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_a: Routine = root_node.walk(Routine)[1] + assert routine_foo_a.name == "foo_a" + + call_foo_a: Call = routine_main.walk(Call)[0] + assert call_foo_a.routine.name == "foo" + + (result, arg_idx_list) = call_foo_a.get_callee() + result: Routine + + assert len(arg_idx_list) == 2 + assert arg_idx_list[0] == 0 + assert arg_idx_list[1] == 1 + + assert result is routine_foo_a + + +def test_call_get_callee_6_interfaces_0_1(fortran_reader): + """ + Check that an existing optional argument at the end of the list + has been correctly determined. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_a: Routine = root_node.walk(Routine)[1] + assert routine_foo_a.name == "foo_a" + + call_foo_a: Call = routine_main.walk(Call)[1] + assert call_foo_a.routine.name == "foo" + + (result, arg_idx_list) = call_foo_a.get_callee() + result: Routine + + assert len(arg_idx_list) == 3 + assert arg_idx_list[0] == 0 + assert arg_idx_list[1] == 1 + assert arg_idx_list[2] == 2 + + assert result is routine_foo_a + + +def test_call_get_callee_6_interfaces_1_0(fortran_reader): + """ + Check that + - different argument types and + - non-existing optional argument at the end of the list + have been correctly determined. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_b: Routine = root_node.walk(Routine)[2] + assert routine_foo_b.name == "foo_b" + + call_foo_b: Call = routine_main.walk(Call)[2] + assert call_foo_b.routine.name == "foo" + + (result, arg_idx_list) = call_foo_b.get_callee() + result: Routine + + assert len(arg_idx_list) == 2 + assert arg_idx_list[0] == 0 + assert arg_idx_list[1] == 1 + + assert result is routine_foo_b + + +def test_call_get_callee_6_interfaces_1_1(fortran_reader): + """ + Check that + - different argument types and + - existing optional argument at the end of the list + have been correctly determined. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_b: Routine = root_node.walk(Routine)[2] + assert routine_foo_b.name == "foo_b" + + call_foo_b: Call = routine_main.walk(Call)[3] + assert call_foo_b.routine.name == "foo" + + (result, arg_idx_list) = call_foo_b.get_callee() + result: Routine + + assert len(arg_idx_list) == 3 + assert arg_idx_list[0] == 0 + assert arg_idx_list[1] == 1 + assert arg_idx_list[2] == 2 + + assert result is routine_foo_b + + +def test_call_get_callee_6_interfaces_1_2(fortran_reader): + """ + Check that + - different argument types and + - naming arguments resulting in a different order + have been correctly determined. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_b: Routine = root_node.walk(Routine)[2] + assert routine_foo_b.name == "foo_b" + + call_foo_b: Call = routine_main.walk(Call)[4] + assert call_foo_b.routine.name == "foo" + + (result, arg_idx_list) = call_foo_b.get_callee() + result: Routine + + assert len(arg_idx_list) == 3 + assert arg_idx_list[0] == 0 + assert arg_idx_list[1] == 2 + assert arg_idx_list[2] == 1 + + assert result is routine_foo_b + + +def test_call_get_callee_6_interfaces_2_0(fortran_reader): + """ + Check that + - different argument types (different order than in tests before) + have been correctly determined. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_c: Routine = root_node.walk(Routine)[3] + assert routine_foo_c.name == "foo_c" + + call_foo_c: Call = routine_main.walk(Call)[5] + assert call_foo_c.routine.name == "foo" + + (result, arg_idx_list) = call_foo_c.get_callee() + result: Routine + + assert len(arg_idx_list) == 3 + assert arg_idx_list[0] == 0 + assert arg_idx_list[1] == 1 + assert arg_idx_list[2] == 2 + + assert result is routine_foo_c + + +def test_call_get_callee_6_interfaces_2_1(fortran_reader): + """ + Check that + - different argument types (different order than in tests before) and + - naming arguments resulting in a different order and + - optional argument + have been correctly determined. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_c: Routine = root_node.walk(Routine)[3] + assert routine_foo_c.name == "foo_c" + + call_foo_c: Call = routine_main.walk(Call)[6] + assert call_foo_c.routine.name == "foo" + + (result, arg_idx_list) = call_foo_c.get_callee() + result: Routine + + assert len(arg_idx_list) == 2 + assert arg_idx_list[0] == 1 + assert arg_idx_list[1] == 0 + + assert result is routine_foo_c + + +def test_call_get_callee_6_interfaces_2_2(fortran_reader): + """ + Check that + - different argument types (different order than in tests before) and + - naming arguments resulting in a different order and + - last call argument without naming + have been correctly determined. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_c: Routine = root_node.walk(Routine)[3] + assert routine_foo_c.name == "foo_c" + + call_foo_c: Call = routine_main.walk(Call)[7] + assert call_foo_c.routine.name == "foo" + + (result, arg_idx_list) = call_foo_c.get_callee() + result: Routine + + assert len(arg_idx_list) == 3 + assert arg_idx_list[0] == 1 + assert arg_idx_list[1] == 0 + assert arg_idx_list[2] == 2 + + assert result is routine_foo_c + + +def test_call_get_callee_6_interfaces_3_0_mismatch(fortran_reader): + """ + Check that matching a partial data type can also go wrong. + """ + + root_node: Node = fortran_reader.psyir_from_source(_code_test_get_callee_6) + + routine_main: Routine = root_node.walk(Routine)[0] + assert routine_main.name == "main" + + routine_foo_optional: Routine = root_node.walk(Routine)[4] + assert routine_foo_optional.name == "foo_optional" + + call_foo_optional: Call = routine_main.walk(Call)[8] + assert call_foo_optional.routine.name == "foo" + + with pytest.raises(CallMatchingArgumentsNotFoundError) as einfo: + call_foo_optional.get_callee() + + assert "Argument partial type mismatch of call argument" in ( + str(einfo.value) + ) + + +def test_call_get_callee_7_matching_arguments_not_found(fortran_reader): + """ + Trigger error that matching arguments were not found + """ + code = """ +module some_mod + implicit none +contains + + subroutine main() + integer :: e, f, g + ! Use named argument 'd', which doesn't exist + ! to trigger an error when searching for the matching routine. + call foo(e, f, d=g) + end subroutine + + ! Matching routine + subroutine foo(a, b, c) + integer :: a, b, c + end subroutine + +end module some_mod""" + + psyir = fortran_reader.psyir_from_source(code) + + routine_main: Routine = psyir.walk(Routine)[0] + assert routine_main.name == "main" + + call_foo: Call = routine_main.walk(Call)[0] + + with pytest.raises(CallMatchingArgumentsNotFoundError) as err: + call_foo.get_callee() + + assert ( + "Found routines, but no routine with matching arguments" + " found for 'call foo(e, f, d=g)':" in str(err.value) + ) + + print(str(err.value)) + assert ( + "CallMatchingArgumentsNotFound: Named argument" + " 'd' not found." in str(err.value) + ) + + def test_set_routine(fortran_reader): """Test the routine setter (not in the constructor).""" @@ -453,3 +762,661 @@ def test_set_routine(fortran_reader): crm = CallRoutineMatcher() crm.set_call_node(call) crm.set_routine_node(routine) + + +def test_fn_call_get_callees(fortran_reader): + """ + Test that get_callees() works for a function call. + """ + code = """ +module some_mod + implicit none + integer :: luggage +contains + subroutine top() + luggage = 0 + luggage = luggage + my_func(1) + end subroutine top + + function my_func(val) + integer, intent(in) :: val + integer :: my_func + my_func = 1 + val + end function my_func +end module some_mod""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + result = call.get_callees() + assert result == [psyir.walk(Routine)[1]] + + +@pytest.mark.usefixtures("clear_module_manager_instance") +def test_call_get_callees_wildcard_import_container( + fortran_reader, tmpdir, monkeypatch +): + """ + Check that get_callees() works successfully for a routine accessed via + a wildcard import from a module in another file. + """ + code = """ +module other_mod + use some_mod +contains + subroutine run_it() + call just_do_it() + end subroutine run_it +end module other_mod +""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + # This should fail as it can't find the module. + with pytest.raises(NotImplementedError) as err: + _ = call.get_callees() + assert ( + "Failed to find the source code of the unresolved routine " + "'just_do_it' - looked at any routines in the same source file" + in str(err.value) + ) + # Create the module containing the subroutine definition, + # write it to file and set the search path so that PSyclone can find it. + path = str(tmpdir) + monkeypatch.setattr(Config.get(), "_include_paths", [path]) + + with open( + os.path.join(path, "some_mod.f90"), "w", encoding="utf-8" + ) as mfile: + mfile.write( + """\ +module some_mod +contains + subroutine just_do_it() + write(*,*) "hello" + end subroutine just_do_it +end module some_mod""" + ) + routines = call.get_callees() + assert len(routines) == 1 + assert isinstance(routines[0], Routine) + assert routines[0].name == "just_do_it" + + +@pytest.mark.usefixtures("clear_module_manager_instance") +def test_call_get_callees_unresolved(fortran_reader, tmpdir, monkeypatch): + """ + Test that get_callees() raises the expected error if the called routine + is unresolved. + """ + code = """ +subroutine top() + call bottom() +end subroutine top""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + with pytest.raises(NotImplementedError) as err: + _ = call.get_callees() + assert ( + "Failed to find the source code of the unresolved routine 'bottom'" + " - looked at any routines in the same source file and there are " + "no wildcard imports." in str(err.value) + ) + # Repeat but in the presence of a wildcard import. + code = """ +subroutine top() + use some_mod_somewhere + call bottom() +end subroutine top""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + with pytest.raises(NotImplementedError) as err: + _ = call.get_callees() + assert ( + "Failed to find the source code of the unresolved routine 'bottom'" + " - looked at any routines in the same source file and attempted " + "to resolve the wildcard imports from ['some_mod_somewhere']. " + "However, failed to find the source for ['some_mod_somewhere']. " + "The module search path is set to []" in str(err.value) + ) + # Repeat but when some_mod_somewhere *is* resolved but doesn't help us + # find the routine we're looking for. + mod_manager = ModuleManager.get() + monkeypatch.setattr(mod_manager, "_instance", None) + path = str(tmpdir) + monkeypatch.setattr(Config.get(), "_include_paths", [path]) + with open( + os.path.join(path, "some_mod_somewhere.f90"), "w", encoding="utf-8" + ) as ofile: + ofile.write( + """\ +module some_mod_somewhere +end module some_mod_somewhere +""" + ) + with pytest.raises(NotImplementedError) as err: + _ = call.get_callees() + assert ( + "Failed to find the source code of the unresolved routine 'bottom'" + " - looked at any routines in the same source file and wildcard " + "imports from ['some_mod_somewhere']." in str(err.value) + ) + mod_manager = ModuleManager.get() + monkeypatch.setattr(mod_manager, "_instance", None) + code = """ +subroutine top() + use another_mod, only: this_one + call this_one() +end subroutine top""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + with pytest.raises(NotImplementedError) as err: + _ = call.get_callees() + assert ( + "RoutineSymbol 'this_one' is imported from Container 'another_mod'" + " but the source defining that container could not be found. The " + "module search path is set to [" in str(err.value) + ) + + +def test_call_get_callees_interface(fortran_reader): + """ + Check that get_callees() works correctly when the target of a call is + actually a generic interface. + """ + code = """ +module my_mod + + interface bottom + module procedure :: rbottom, ibottom + end interface bottom +contains + subroutine top() + integer :: luggage + luggage = 0 + call bottom(luggage) + end subroutine top + + subroutine ibottom(luggage) + integer :: luggage + luggage = luggage + 1 + end subroutine ibottom + + subroutine rbottom(luggage) + real :: luggage + luggage = luggage + 1.0 + end subroutine rbottom +end module my_mod +""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + callees = call.get_callees() + assert len(callees) == 2 + assert isinstance(callees[0], Routine) + assert callees[0].name == "rbottom" + assert isinstance(callees[1], Routine) + assert callees[1].name == "ibottom" + + +def test_call_get_callees_unsupported_type(fortran_reader): + """ + Check that get_callees() raises the expected error when the called routine + is of UnsupportedFortranType. This is hard to achieve so we have to + manually construct some aspects of the test case. + + """ + code = """ +module my_mod + integer, target :: value +contains + subroutine top() + integer :: luggage + luggage = bottom() + end subroutine top + function bottom() result(fval) + integer, pointer :: fval + fval => value + end function bottom +end module my_mod +""" + psyir = fortran_reader.psyir_from_source(code) + container = psyir.children[0] + routine = container.find_routine_psyir("bottom") + rsym = container.symbol_table.lookup(routine.name) + # Ensure the type of this RoutineSymbol is UnsupportedFortranType. + rsym.datatype = UnsupportedFortranType("integer, pointer :: fval") + assign = container.walk(Assignment)[0] + # Currently `bottom()` gets matched by fparser2 as a structure constructor + # and the fparser2 frontend leaves this as a CodeBlock (TODO #2429) so + # replace it with a Call. Once #2429 is fixed the next two lines can be + # removed. + assert isinstance(assign.rhs, CodeBlock) + assign.rhs.replace_with(Call.create(rsym)) + call = psyir.walk(Call)[0] + with pytest.raises(NotImplementedError) as err: + _ = call.get_callees() + assert ( + "RoutineSymbol 'bottom' exists in Container 'my_mod' but is of " + "UnsupportedFortranType" in str(err.value) + ) + + +def test_call_get_callees_local(fortran_reader): + """ + Check that get_callees() works as expected when the target of the Call + exists in the same Container as the call site. + """ + code = """ +module some_mod + implicit none + integer :: luggage +contains + subroutine top() + luggage = 0 + call bottom() + end subroutine top + + subroutine bottom() + luggage = luggage + 1 + end subroutine bottom +end module some_mod""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + result = call.get_callees() + assert result == [psyir.walk(Routine)[1]] + + +def test_call_get_callee_matching_arguments_not_found(fortran_reader): + """ + Trigger error that matching arguments were not found. + In this test, this is caused by omitting the required third non-optional + argument. + """ + code = """ +module some_mod + implicit none +contains + + subroutine main() + integer :: e, f + ! Omit the 3rd required argument + call foo(e, f) + end subroutine + + ! Routine matching by 'name', but not by argument matching + subroutine foo(a, b, c) + integer :: a, b, c + end subroutine + +end module some_mod""" + + psyir = fortran_reader.psyir_from_source(code) + + routine_main: Routine = psyir.walk(Routine)[0] + assert routine_main.name == "main" + + call_foo: Call = routine_main.walk(Call)[0] + + with pytest.raises(CallMatchingArgumentsNotFoundError) as err: + call_foo.get_callee() + + assert ( + "CallMatchingArgumentsNotFound: Found routines, but" + " no routine with matching arguments found for 'call" + " foo(e, f)':" in str(err.value) + ) + + assert ( + "CallMatchingArgumentsNotFound: Argument 'c' in" + " subroutine 'foo' not handled." in str(err.value) + ) + + +def test_call_get_callees_file_container(fortran_reader): + """ + Check that get_callees works if the called routine happens to be in file + scope, even when there's no Container. + """ + code = """ + subroutine top() + integer :: luggage + luggage = 0 + call bottom(luggage) + end subroutine top + + subroutine bottom(luggage) + integer :: luggage + luggage = luggage + 1 + end subroutine bottom +""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + result = call.get_callees() + assert len(result) == 1 + assert isinstance(result[0], Routine) + assert result[0].name == "bottom" + + +def test_call_get_callees_no_container(fortran_reader): + """ + Check that get_callees() raises the expected error when the Call is not + within a Container and the target routine cannot be found. + """ + # To avoid having the routine symbol immediately dismissed as + # unresolved, the code that we initially process *does* have a Container. + code = """ +module my_mod + +contains + subroutine top() + integer :: luggage + luggage = 0 + call bottom(luggage) + end subroutine top + + subroutine bottom(luggage) + integer :: luggage + luggage = luggage + 1 + end subroutine bottom +end module my_mod +""" + psyir = fortran_reader.psyir_from_source(code) + top_routine = psyir.walk(Routine)[0] + # Deliberately make the Routine node an orphan so there's no Container. + top_routine.detach() + call = top_routine.walk(Call)[0] + with pytest.raises(SymbolError) as err: + _ = call.get_callees() + assert ( + "Failed to find a Routine named 'bottom' in code:\n'subroutine " + "top()" in str(err.value) + ) + + +def test_call_get_callees_wildcard_import_local_container(fortran_reader): + """ + Check that get_callees() works successfully for a routine accessed via + a wildcard import from another module in the same file. + """ + code = """ +module some_mod +contains + subroutine just_do_it() + write(*,*) "hello" + end subroutine just_do_it +end module some_mod +module other_mod + use some_mod +contains + subroutine run_it() + call just_do_it() + end subroutine run_it +end module other_mod +""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + routines = call.get_callees() + assert len(routines) == 1 + assert isinstance(routines[0], Routine) + assert routines[0].name == "just_do_it" + + +def test_call_get_callees_import_local_container(fortran_reader): + """ + Check that get_callees() works successfully for a routine accessed via + a specific import from another module in the same file. + """ + code = """ +module some_mod +contains + subroutine just_do_it() + write(*,*) "hello" + end subroutine just_do_it +end module some_mod +module other_mod + use some_mod, only: just_do_it +contains + subroutine run_it() + call just_do_it() + end subroutine run_it +end module other_mod +""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + routines = call.get_callees() + assert len(routines) == 1 + assert isinstance(routines[0], Routine) + assert routines[0].name == "just_do_it" + + +def test_get_callees_code_block(fortran_reader): + """Test that get_callees() raises the expected error when the called + routine is in a CodeBlock.""" + code = """ +module some_mod + implicit none + integer :: luggage +contains + subroutine top() + luggage = 0 + luggage = luggage + real(my_func(1)) + end subroutine top + + complex function my_func(val) + integer, intent(in) :: val + my_func = CMPLX(1 + val, 1.0) + end function my_func +end module some_mod""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[1] + with pytest.raises(SymbolError) as err: + _ = call.get_callees() + assert ( + "Failed to find a Routine named 'my_func' in Container " + "'some_mod'" in str(err.value) + ) + + +@pytest.mark.usefixtures("clear_module_manager_instance") +def test_get_callees_follow_imports(fortran_reader, tmpdir, monkeypatch): + """ + Test that get_callees() follows imports to find the definition of the + called routine. + """ + code = """ +module some_mod + use other_mod, only: pack_it + implicit none +contains + subroutine top() + integer :: luggage = 0 + call pack_it(luggage) + end subroutine top +end module some_mod""" + # Create the module containing an import of the subroutine definition, + # write it to file and set the search path so that PSyclone can find it. + path = str(tmpdir) + monkeypatch.setattr(Config.get(), "_include_paths", [path]) + + with open( + os.path.join(path, "other_mod.f90"), "w", encoding="utf-8" + ) as mfile: + mfile.write( + """\ + module other_mod + use another_mod, only: pack_it + contains + end module other_mod + """ + ) + # Finally, create the module containing the routine definition. + with open( + os.path.join(path, "another_mod.f90"), "w", encoding="utf-8" + ) as mfile: + mfile.write( + """\ + module another_mod + contains + subroutine pack_it(arg) + integer, intent(inout) :: arg + arg = arg + 2 + end subroutine pack_it + end module another_mod + """ + ) + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + result = call.get_callees() + assert len(result) == 1 + assert isinstance(result[0], Routine) + assert result[0].name == "pack_it" + + +@pytest.mark.usefixtures("clear_module_manager_instance") +def test_get_callees_import_private_clash(fortran_reader, tmpdir, monkeypatch): + """ + Test that get_callees() raises the expected error if a module from which + a routine is imported has a private shadow of that routine (and thus we + don't know where to look for the target routine). + """ + code = """ +module some_mod + use other_mod, only: pack_it + implicit none +contains + subroutine top() + integer :: luggage = 0 + call pack_it(luggage) + end subroutine top +end module some_mod""" + # Create the module containing a private routine with the name we are + # searching for, write it to file and set the search path so that PSyclone + # can find it. + path = str(tmpdir) + monkeypatch.setattr(Config.get(), "_include_paths", [path]) + + with open( + os.path.join(path, "other_mod.f90"), "w", encoding="utf-8" + ) as mfile: + mfile.write( + """\ + module other_mod + use another_mod + private pack_it + contains + function pack_it(arg) + integer :: arg + integer :: pack_it + end function pack_it + end module other_mod + """ + ) + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + with pytest.raises(NotImplementedError) as err: + _ = call.get_callees() + assert ( + "RoutineSymbol 'pack_it' is imported from Container 'other_mod' " + "but that Container defines a private Symbol of the same name. " + "Searching for the Container that defines a public Routine with " + "that name is not yet supported - TODO #924" in str(err.value) + ) + + +def test_apply_empty_routine_coverage_option_check_strict_array_datatype( + fortran_reader, fortran_writer, tmpdir +): + """For coverage of particular branch in `inline_trans.py`.""" + code = ( + "module test_mod\n" + "contains\n" + " subroutine run_it()\n" + " integer, dimension(6) :: i\n" + " i = 10\n" + " call sub(i)\n" + " end subroutine run_it\n" + " subroutine sub(idx)\n" + " integer, dimension(:) :: idx\n" + " end subroutine sub\n" + "end module test_mod\n" + ) + psyir = fortran_reader.psyir_from_source(code) + routine = psyir.walk(Call)[0] + inline_trans = InlineTrans() + inline_trans.set_option(check_argument_strict_array_datatype=False) + inline_trans.apply(routine) + output = fortran_writer(psyir) + assert " i = 10\n\n" " end subroutine run_it\n" in output + assert Compile(tmpdir).string_compiles(output) + + +def test_apply_array_access_check_unresolved_symbols_error( + fortran_reader, fortran_writer, tmpdir +): + """ + This check solely exists for the coverage report to + catch the simple case `if not check_unresolved_symbols:` + in `symbol_table.py` + + """ + code = ( + "module test_mod\n" + "contains\n" + " subroutine run_it()\n" + " integer :: i\n" + " real :: a(10)\n" + " do i=1,10\n" + " call sub(a, i)\n" + " end do\n" + " end subroutine run_it\n" + " subroutine sub(x, ivar)\n" + " real, intent(inout), dimension(10) :: x\n" + " integer, intent(in) :: ivar\n" + " integer :: i\n" + " do i = 1, 10\n" + " x(i) = 2.0*ivar\n" + " end do\n" + " end subroutine sub\n" + "end module test_mod\n" + ) + psyir = fortran_reader.psyir_from_source(code) + routine = psyir.walk(Call)[0] + inline_trans = InlineTrans() + inline_trans.set_option(check_unresolved_symbols=False) + inline_trans.apply(routine) + output = fortran_writer(psyir) + assert ( + " do i = 1, 10, 1\n" + " do i_1 = 1, 10, 1\n" + " a(i_1) = 2.0 * i\n" + " enddo\n" in output + ) + assert Compile(tmpdir).string_compiles(output) + + +def test_apply_array_access_check_unresolved_override_option( + fortran_reader, fortran_writer, tmpdir +): + """ + This check solely exists for the coverage report to catch + the case where the override option to ignore unresolved + types is used. + + """ + code = ( + "module test_mod\n" + "use does_not_exist\n" + "contains\n" + " subroutine run_it()\n" + " type(unknown_type) :: a\n" + " call sub(a%unresolved_type)\n" + " end subroutine run_it\n" + " subroutine sub(a)\n" + " type(unresolved) :: a\n" + " end subroutine sub\n" + "end module test_mod\n" + ) + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + inline_trans = InlineTrans() + inline_trans.set_option( + check_argument_ignore_unresolved_types=True + ) + inline_trans.apply(call) diff --git a/src/psyclone/tests/psyir/transformations/inline_trans_test.py b/src/psyclone/tests/psyir/transformations/inline_trans_test.py index 7226d71065..83ed070d43 100644 --- a/src/psyclone/tests/psyir/transformations/inline_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/inline_trans_test.py @@ -151,31 +151,6 @@ def test_apply_argument_clash(fortran_reader, fortran_writer, tmpdir): assert Compile(tmpdir).string_compiles(output) -def test_apply_empty_routine_coverage_option_check_strict_array_datatype( - fortran_reader, fortran_writer, tmpdir): - '''For coverage of particular branch in `inline_trans.py`.''' - code = ( - "module test_mod\n" - "contains\n" - " subroutine run_it()\n" - " integer, dimension(6) :: i\n" - " i = 10\n" - " call sub(i)\n" - " end subroutine run_it\n" - " subroutine sub(idx)\n" - " integer, dimension(:) :: idx\n" - " end subroutine sub\n" - "end module test_mod\n") - psyir = fortran_reader.psyir_from_source(code) - routine = psyir.walk(Call)[0] - inline_trans = InlineTrans() - inline_trans.set_option(check_argument_strict_array_datatype=False) - inline_trans.apply(routine) - output = fortran_writer(psyir) - assert (" i = 10\n\n" - " end subroutine run_it\n" in output) - assert Compile(tmpdir).string_compiles(output) - def test_apply_single_return(fortran_reader, fortran_writer, tmpdir): '''Check that a call to a routine containing only a return statement @@ -336,45 +311,6 @@ def test_apply_array_access(fortran_reader, fortran_writer, tmpdir): assert Compile(tmpdir).string_compiles(output) -def test_apply_array_access_check_unresolved_symbols_error( - fortran_reader, fortran_writer, tmpdir): - ''' - This check solely exists for the coverage report to - catch the simple case `if not check_unresolved_symbols:` - in `symbol_table.py` - - ''' - code = ( - "module test_mod\n" - "contains\n" - " subroutine run_it()\n" - " integer :: i\n" - " real :: a(10)\n" - " do i=1,10\n" - " call sub(a, i)\n" - " end do\n" - " end subroutine run_it\n" - " subroutine sub(x, ivar)\n" - " real, intent(inout), dimension(10) :: x\n" - " integer, intent(in) :: ivar\n" - " integer :: i\n" - " do i = 1, 10\n" - " x(i) = 2.0*ivar\n" - " end do\n" - " end subroutine sub\n" - "end module test_mod\n") - psyir = fortran_reader.psyir_from_source(code) - routine = psyir.walk(Call)[0] - inline_trans = InlineTrans() - inline_trans.set_option(check_unresolved_symbols=False) - inline_trans.apply(routine) - output = fortran_writer(psyir) - assert (" do i = 1, 10, 1\n" - " do i_1 = 1, 10, 1\n" - " a(i_1) = 2.0 * i\n" - " enddo\n" in output) - assert Compile(tmpdir).string_compiles(output) - def test_apply_gocean_kern(fortran_reader, fortran_writer, monkeypatch): ''' From 8651a19b55a56c7b801ac5b9bdd1825d32ecb04a Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Tue, 14 Jan 2025 20:33:16 +0100 Subject: [PATCH 19/20] u --- src/psyclone/tests/psyir/nodes/call_test.py | 3 +-- src/psyclone/tests/psyir/tools/call_routine_matcher_test.py | 1 - src/psyclone/tests/psyir/transformations/inline_trans_test.py | 4 ---- 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/psyclone/tests/psyir/nodes/call_test.py b/src/psyclone/tests/psyir/nodes/call_test.py index 342d4dbc82..d07cb73aea 100644 --- a/src/psyclone/tests/psyir/nodes/call_test.py +++ b/src/psyclone/tests/psyir/nodes/call_test.py @@ -57,8 +57,7 @@ DataSymbol, NoType, RoutineSymbol, - REAL_TYPE, - SymbolError + REAL_TYPE ) from psyclone.psyir.tools.call_routine_matcher import ( diff --git a/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py b/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py index 06647cf83c..38cae8b290 100644 --- a/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py +++ b/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py @@ -35,7 +35,6 @@ # ----------------------------------------------------------------------------- -from psyclone.psyir.symbols.datatypes import ArrayType, UnresolvedType import os import pytest from psyclone.configuration import Config diff --git a/src/psyclone/tests/psyir/transformations/inline_trans_test.py b/src/psyclone/tests/psyir/transformations/inline_trans_test.py index 83ed070d43..2ad16f0c6c 100644 --- a/src/psyclone/tests/psyir/transformations/inline_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/inline_trans_test.py @@ -151,7 +151,6 @@ def test_apply_argument_clash(fortran_reader, fortran_writer, tmpdir): assert Compile(tmpdir).string_compiles(output) - def test_apply_single_return(fortran_reader, fortran_writer, tmpdir): '''Check that a call to a routine containing only a return statement is removed. ''' @@ -311,7 +310,6 @@ def test_apply_array_access(fortran_reader, fortran_writer, tmpdir): assert Compile(tmpdir).string_compiles(output) - def test_apply_gocean_kern(fortran_reader, fortran_writer, monkeypatch): ''' Test the apply method with a typical GOcean kernel. @@ -2424,7 +2422,6 @@ def test_apply_unsupported_pointer_error(fortran_reader): " 'REAL, INTENT(INOUT), POINTER :: x'." in str(einfo.value)) - def test_apply_optional_and_named_arg_2(fortran_reader): '''Test that the validate method inlines a routine that has an optional argument.''' @@ -2536,4 +2533,3 @@ def test_apply_merges_symbol_table_with_routine(fortran_reader): inline_trans.apply(routine) # The i_1 symbol is the renamed i from the inlined call. assert psyir.walk(Routine)[0].symbol_table.get_symbols()['i_1'] is not None - From 447b147967cea50aa8daee360eca4136532835ce Mon Sep 17 00:00:00 2001 From: SCHREIBER Martin Date: Tue, 14 Jan 2025 20:54:38 +0100 Subject: [PATCH 20/20] u --- src/psyclone/tests/psyir/tools/call_routine_matcher_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py b/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py index 38cae8b290..7173622194 100644 --- a/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py +++ b/src/psyclone/tests/psyir/tools/call_routine_matcher_test.py @@ -31,7 +31,11 @@ # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. # ----------------------------------------------------------------------------- -# Author: Martin Schreiber, Univ. Grenoble Alpes / LJK / Inria +# This file is based on gathering various components related to +# calls and routines from across psyclone. Hence, there's no clear author. +# Authors of gathered files: R. W. Ford, A. R. Porter and +# S. Siso, STFC Daresbury Lab +# Author: M. Schreiber, Univ. Grenoble Alpes / LJK / Inria # -----------------------------------------------------------------------------