diff --git a/src/psyclone/psyir/symbols/symbol_table.py b/src/psyclone/psyir/symbols/symbol_table.py index 82dbc6b8da..80a52e40e9 100644 --- a/src/psyclone/psyir/symbols/symbol_table.py +++ b/src/psyclone/psyir/symbols/symbol_table.py @@ -48,6 +48,7 @@ import inspect import copy import logging +import re from typing import Any, List, Optional, Union, TYPE_CHECKING from psyclone.configuration import Config @@ -57,6 +58,7 @@ ImportInterface, RoutineSymbol, Symbol, SymbolError, UnresolvedInterface) from psyclone.psyir.symbols.intrinsic_symbol import IntrinsicSymbol from psyclone.psyir.symbols.typed_symbol import TypedSymbol +from psyclone.psyir.symbols.datatypes import UnsupportedFortranType if TYPE_CHECKING: from psyclone.psyir.nodes.scoping_node import ScopingNode @@ -608,6 +610,8 @@ def add(self, new_symbol: Symbol, tag: Optional[str] = None): :raises InternalError: if the new_symbol argument is not a symbol. :raises KeyError: if the symbol name is already in use. :raises KeyError: if a tag is supplied and it is already in use. + :raises KeyError: if the symbol is a COMMON-block marker and an + identical declaration is already present under another marker name. :raises SymbolError: if the supplied symbol has an ImportInterface that refers to a ContainerSymbol that is not in scope. @@ -621,6 +625,22 @@ def add(self, new_symbol: Symbol, tag: Optional[str] = None): raise KeyError(f"Symbol table already contains a symbol with " f"name '{new_symbol.name}'.") + # Treat a COMMON-block marker whose declaration exactly matches one + # already present (possibly under a different name) as a duplicate. 
+ if (self._normalize(new_symbol.name).startswith( + "_psyclone_internal_commonblock") + and isinstance(new_symbol.datatype, UnsupportedFortranType)): + if any( + sym.datatype.declaration == new_symbol.datatype.declaration + for sym in self.symbols + if (self._normalize(sym.name).startswith( + "_psyclone_internal_commonblock") + and isinstance(sym.datatype, UnsupportedFortranType)) + ): + raise KeyError( + f"Symbol table already contains a COMMON-block marker " + f"with the same declaration as '{new_symbol.name}'.") + if tag: if tag in self.get_tags(): raise KeyError( @@ -704,6 +724,12 @@ def check_for_clashes(self, other_table, symbols_to_skip=()): isinstance(other_sym, IntrinsicSymbol)): continue + # If both symbols have CommonBlockInterface, they represent the + # same shared COMMON-block data. They cannot (and do not need to) + # be renamed, so treat this as a benign clash. + if this_sym.is_commonblock and other_sym.is_commonblock: + continue + if other_sym.is_import and this_sym.is_import: # Both symbols are imported. That's fine as long as they have # the same import interface (are imported from the same @@ -945,6 +971,7 @@ def _add_symbols_from_table(self, other_table, symbols_to_skip=()): already been updated to refer to a Container in this table. ''' + for old_sym in other_table.symbols: if old_sym in symbols_to_skip or isinstance(old_sym, @@ -959,11 +986,73 @@ def _add_symbols_from_table(self, other_table, symbols_to_skip=()): # We have a clash with a symbol in this table. self._handle_symbol_clash(old_sym, other_table) + def _handle_symbol_clash_common_block(self, old_sym: Symbol) -> bool: + ''' + Handles a name clash for COMMON-block related symbols. Called from + :py:meth:`_handle_symbol_clash` as soon as a COMMON-block symbol is + detected. Returns ``True`` if the clash has been fully resolved + (nothing more to do) or ``False`` if the generic rename-and-add path + should be followed instead. 
+ + Two kinds of COMMON-block symbol are handled: + + * Variables with a + :py:class:`~psyclone.psyir.symbols.CommonBlockInterface` + (``is_commonblock``): the clash has already been approved by + ``check_for_clashes``; nothing to do. + * Internal marker symbols (``_PSYCLONE_INTERNAL_COMMONBLOCK_N``): + if the incoming marker's COMMON-block name(s) overlap with any + marker already in ``self``, the block is already declared and a + second declaration would produce a compile error — skip it. + Otherwise fall through to the rename-and-add path. + + :param old_sym: the Symbol being added to self. + + :returns: ``True`` if the clash is resolved; ``False`` if the + generic rename-and-add path should be followed. + + ''' + try: + self_sym = self.lookup(old_sym.name) + except KeyError: + # old_sym.name is not in this table: add() must have raised + # because an identical declaration is already present under a + # different marker name. The COMMON block is already declared + # so the incoming marker should simply be skipped. + self_sym = None + + if self_sym is None: + # Name absent means same-declaration / different-name duplicate. + return True + + if old_sym.is_commonblock and self_sym.is_commonblock: + # check_for_clashes has already approved this; nothing to do. + return True + + if (isinstance(old_sym.datatype, UnsupportedFortranType) + and isinstance(self_sym.datatype, UnsupportedFortranType) + and self._normalize(old_sym.name).startswith( + "_psyclone_internal_commonblock")): + # Marker with different declaration but possibly overlapping + # block name(s). Skip if the block is already declared. 
+ _blk_re = re.compile(r"/\s*(\w*)\s*/", re.IGNORECASE) + old_blocks = set(_blk_re.findall(old_sym.datatype.declaration)) + for sym in self.symbols: + if (self._normalize(sym.name).startswith( + "_psyclone_internal_commonblock") + and isinstance(sym.datatype, UnsupportedFortranType)): + self_blocks = set(_blk_re.findall( + sym.datatype.declaration)) + if old_blocks & self_blocks: + return True + + return False + def _handle_symbol_clash(self, old_sym, other_table): ''' - Adds the supplied Symbol to the current table in the presence - of a name clash. `check_for_clashes` MUST have been called - prior to this method in order to check for any unresolvable cases. + Adds the supplied Symbol to the current table in the presence of a + name clash. ``check_for_clashes`` MUST have been called prior to this + method in order to check for any unresolvable cases. :param old_sym: the Symbol to be added to self. :type old_sym: :py:class:`psyclone.psyir.symbols.Symbol` @@ -976,6 +1065,15 @@ def _handle_symbol_clash(self, old_sym, other_table): check_for_clashes()). ''' + # Check for COMMON-block markers first, before any lookup, because + # add() may have rejected old_sym because an identical declaration + # already exists under a *different* name. In that case old_sym.name + # is not in this table at all, and lookup() would raise a KeyError. 
+ if old_sym.is_commonblock or self._normalize(old_sym.name).startswith( + "_psyclone_internal_commonblock"): + if self._handle_symbol_clash_common_block(old_sym): + return + self_sym = self.lookup(old_sym.name) if old_sym.is_import: # The clashing symbol is imported from a Container and the table diff --git a/src/psyclone/psyir/transformations/inline_trans.py b/src/psyclone/psyir/transformations/inline_trans.py index 5cd4f3e544..b0fc2eaba2 100644 --- a/src/psyclone/psyir/transformations/inline_trans.py +++ b/src/psyclone/psyir/transformations/inline_trans.py @@ -38,7 +38,7 @@ ''' -from typing import Dict, List, Optional +from typing import Dict, Optional from psyclone.core import SymbolicMaths from psyclone.errors import LazyString, InternalError @@ -56,6 +56,7 @@ INTEGER_TYPE, StructureType, SymbolError, + SymbolTable, UnresolvedType, UnsupportedType, UnsupportedFortranType, @@ -147,6 +148,7 @@ def apply(self, use_first_callee_and_no_arg_check: bool = False, permit_codeblocks: bool = False, permit_unsupported_type_args: bool = False, + parameter_cloning: bool = True, **kwargs ): ''' @@ -163,6 +165,13 @@ def apply(self, if the target routine contains a CodeBlock. :param permit_unsupported_type_args: If `True` then the target routine is permitted to have arguments of UnsupportedType. + :param parameter_cloning: if `True` (the default), constant + (PARAMETER) symbols from the routine being inlined are always + copied into the call-site symbol table, potentially being renamed + to avoid clashes. If `False`, a constant from the routine is + skipped when an identical constant (same name, same type, and same + value) already exists at the call site, so no duplicate is + created. :raises InternalError: if the merge of the symbol tables fails. In theory this should never happen because validate() should @@ -219,6 +228,20 @@ def apply(self, # just delete the if statement. 
self._optional_arg_eliminate_ifblock_if_const_condition(routine) + # If parameter_cloning is disabled, identify duplicate constant + # (PARAMETER) symbols and redirect their references *before* the + # routine body is extracted, so that the extracted statements already + # carry references to the call-site symbols. + extra_skip: list[DataSymbol] = [] + if not parameter_cloning: + extra_skip = self._redirect_duplicate_parameters( + table, routine) + + # Redirect references to COMMON-block variables that are aliased + # (same block position, different name) to the caller's symbol, + # and exclude the now-unreferenced callee symbols from the merge. + extra_skip += self._redirect_common_block_aliases(table, routine) + # Construct lists of the nodes that will be inserted and all of the # References that they contain. new_stmts = [] @@ -231,7 +254,8 @@ def apply(self, # call site. This preserves any references to them. try: table.merge(routine_table, - symbols_to_skip=routine_table.argument_list[:]) + symbols_to_skip=routine_table.argument_list[:] + + extra_skip) except SymbolError as err: raise InternalError( f"Error copying routine symbols to call site. This should " @@ -329,9 +353,193 @@ def apply(self, idx += 1 parent.addchild(child, idx) + def _redirect_duplicate_parameters( + self, + table, + routine: Routine, + ) -> list[DataSymbol]: + ''' + Identifies constant (PARAMETER) symbols in ``routine_table`` that + are identical to constants already present in ``table`` (same name, + same type, and same initial value). For each such symbol, every + :py:class:`~psyclone.psyir.nodes.Reference` to it inside ``routine`` + and inside the datatypes / initial-value expressions of other symbols + in ``routine_table`` is redirected to point to the corresponding + symbol in ``table``. + + Only constants whose initial value is represented as a PSyIR node + (i.e. 
``initial_value is not None``) are considered; constants of + ``UnsupportedFortranType`` with an embedded value string are left + unchanged. + + A constant is only considered a duplicate when every routine-local + symbol referenced inside its initial-value expression is itself a + confirmed duplicate. This prevents false positives for expressions + like ``negflag = .NOT. flag`` when ``flag`` has different values in + the caller and the callee (the names would match but the semantics + would differ). + + :param table: the call-site symbol table. + :param routine: the (copy of the) routine being inlined. + + :returns: the list of symbols that are duplicates of + call-site constants and should be excluded from the subsequent + table merge. + + ''' + routine_table: SymbolTable = routine.symbol_table + # The names of all local data symbols in the routine table (used to + # identify references that point to routine-local constants). + routine_local_names = { + s.name.lower() for s in routine_table.datasymbols + if not s.is_automatic + } + + # First pass: collect all constants from the routine whose name, + # datatype, and initial-value tree match a constant in the call-site + # table. The structural comparison uses __eq__, which compares + # Reference nodes by symbol name. This is correct for leaf constants + # (Literals) and is refined for dependent constants in the second + # pass below. + candidates: dict = {} + for rsym in routine_table.datasymbols: + if not rsym.is_constant or rsym.initial_value is None: + # Skip constants whose value is not represented as a PSyIR + # node (e.g. UnsupportedFortranType with embedded value). 
+ continue + tsym = table.lookup(rsym.name, otherwise=None) + if not isinstance(tsym, DataSymbol): + continue + if not tsym.is_constant or tsym.initial_value is None: + continue + if rsym.datatype != tsym.datatype: + continue + if rsym.initial_value != tsym.initial_value: + continue + candidates[rsym.name.lower()] = rsym + + # Second pass: iteratively remove candidates whose initial-value + # expression references a routine-local symbol that is NOT itself + # a confirmed duplicate. Without this step, an expression like + # ``negflag = .NOT. flag`` would compare as equal by name even when + # ``flag`` has different values in the two routines. + changed = True + while changed: + changed = False + to_remove = [ + name for name, rsym in candidates.items() + if any( + dep.name.lower() in routine_local_names + and dep.name.lower() not in candidates + for dep in rsym.initial_value.get_all_accessed_symbols() + ) + ] + for name in to_remove: + del candidates[name] + if to_remove: + changed = True + + duplicates: list[DataSymbol] = list(candidates.values()) + + # Redirect all references from duplicate symbols in the routine to + # their call-site counterparts. + for rsym in duplicates: + tsym = table.lookup(rsym.name) + # Update all References in the routine body. + routine.replace_symbols_using(tsym) + # Update any references to rsym embedded in the datatypes or + # initial-value expressions of other symbols in routine_table. + for sym in routine_table.symbols: + if sym is rsym: + continue + sym.replace_symbols_using(tsym) + + return duplicates + + def _redirect_common_block_aliases( + self, + table: SymbolTable, + routine: Routine, + ) -> list[DataSymbol]: + '''Redirect references to COMMON-block variables in *routine* that are + aliased to differently-named variables in the caller *table* (same + block, same position). 
+ + For each such pair the caller's symbol is substituted for the + callee's symbol in every :py:class:`~psyclone.psyir.nodes.Reference` + inside *routine*. The callee symbols that have been redirected are + returned so they can be excluded from the subsequent symbol-table + merge (they no longer have any live references). + + The types of each aliased pair must already have been verified to be + compatible by :py:meth:`validate`. + + :param table: the call-site symbol table. + :param routine: the (copy of the) routine being inlined. + + :returns: callee symbols whose references have been redirected and + that should therefore be skipped during the table merge. + ''' + routine_table = routine.symbol_table + caller_blocks = self._common_block_vars(table) + callee_blocks = self._common_block_vars(routine_table) + + symbols_to_skip = [] + for block_name, callee_vars in callee_blocks.items(): + if block_name not in caller_blocks: + continue + caller_vars = caller_blocks[block_name] + for caller_var_name, callee_var_name in zip( + caller_vars, callee_vars): + if caller_var_name.lower() == callee_var_name.lower(): + continue + # Replace all References to the callee's alias with the + # corresponding caller's symbol. + caller_sym = table.lookup(caller_var_name) + callee_sym = routine_table.lookup(callee_var_name) + for ref in routine.walk(Reference): + if ref.symbol is callee_sym: + ref.symbol = caller_sym + symbols_to_skip.append(callee_sym) + + return symbols_to_skip + + @staticmethod + def _common_block_vars(table: SymbolTable) -> dict[str, list[str]]: + '''Return a dict mapping lower-cased COMMON-block name to the + lower-cased list of variable names for every COMMON-block marker + symbol found in *table*. + + Each marker symbol is named ``_PSYCLONE_INTERNAL_COMMONBLOCK_N`` and + carries a declaration such as ``COMMON /name/ var1, var2``. A single + declaration may contain several block groups + (``COMMON /a/ x /b/ y, z``), all of which are extracted. 
+ + :param table: the symbol table to inspect. + :returns: mapping of block name to ordered list of variable names. + + ''' + result = {} + # Match pattern /[common_block_name]/ comma_separated_variables + import re + _group_re = re.compile(r"/\s*(\w*)\s*/\s*([\w\s,]+)", re.IGNORECASE) + for sym in table.symbols: + if (sym.name.lower().startswith( + "_psyclone_internal_commonblock") + and isinstance(sym.datatype, UnsupportedFortranType)): + for m in _group_re.finditer(sym.datatype.declaration): + block_name = m.group(1).strip().lower() + var_names = [ + v.strip().lower() + for v in m.group(2).split(",") + if v.strip() + ] + # A named COMMON block may be continued over several COMMON + # statements (and hence several markers); its storage sequence + # is their concatenation, so append instead of overwriting. + result.setdefault(block_name, []).extend(var_names) + return result + def _optional_arg_resolve_present_intrinsics(self, routine_node: Routine, - arg_match_list: List = []): + arg_match_list: list = []): """Replace PRESENT(some_argument) intrinsics in routine with constant booleans depending on whether `some_argument` has been provided (`True`) or not (`False`). @@ -435,7 +643,7 @@ def _replace_formal_args_in_expr( self, expression: Node, call_node: Call, - formal_args: List[DataSymbol], + formal_args: list[DataSymbol], routine_node: Routine, use_first_callee_and_no_arg_check: bool = False, ) -> Reference: @@ -513,7 +721,7 @@ def _replace_formal_args_in_expr( def _create_inlined_idx( self, call_node: Call, - formal_args: List[DataSymbol], + formal_args: list[DataSymbol], local_idx: DataNode, decln_start: DataNode, actual_start: DataNode, @@ -613,10 +821,10 @@ def _update_actual_indices( actual_arg: ArrayMixin, local_ref: Reference, call_node: Call, - formal_args: List[DataSymbol], + formal_args: list[DataSymbol], routine_node: Routine, use_first_callee_and_no_arg_check: bool = False, - ) -> List[Node]: + ) -> list[Node]: ''' Create a new list of indices for the supplied actual argument (ArrayMixin) by replacing any Ranges with the appropriate expressions @@ -729,7 +937,7 @@ def _generate_formal_arg_replacement( actual_arg: Reference, ref: Reference, call_node: Call, 
- formal_args: List[DataSymbol], + formal_args: list[DataSymbol], routine_node: Routine, use_first_callee_and_no_arg_check: bool = False, ) -> Reference: @@ -962,6 +1170,9 @@ def validate( does not match that of the corresponding actual argument. :raises TransformationError: if one of the declarations in the routine depends on an argument that is written to prior to the call. + :raises TransformationError: if a COMMON block is declared in both + the caller and the routine being inlined with different variable + names and incompatible types. :raises InternalError: if an unhandled Node type is returned by Reference.previous_accesses(). @@ -1116,6 +1327,33 @@ def validate( f"{err.value}") from err routine_table = routine.symbol_table + # Check that COMMON blocks shared between the caller and the callee + # that use *different* variable names are still type-compatible. + # Different names at the same block position mean the two variables + # are memory aliases; that is acceptable as long as their types + # match (the actual reference-redirection happens in apply()). + caller_blocks = self._common_block_vars(parent_routine.symbol_table) + callee_blocks = self._common_block_vars(routine_table) + for block_name, callee_vars in callee_blocks.items(): + if block_name not in caller_blocks: + continue + caller_vars = caller_blocks[block_name] + for caller_var_name, callee_var_name in zip( + caller_vars, callee_vars): + if caller_var_name.lower() == callee_var_name.lower(): + continue + # Different names – check that the types are compatible. 
+ caller_sym = parent_routine.symbol_table.lookup( + caller_var_name) + callee_sym = routine_table.lookup(callee_var_name) + if caller_sym.datatype != callee_sym.datatype: + raise TransformationError( + f"Cannot inline '{routine.name}' because COMMON " + f"block '/{block_name}/' maps '{caller_var_name}' " + f"(type '{caller_sym.datatype}') in the caller to " + f"'{callee_var_name}' (type '{callee_sym.datatype}')" + f" in the routine being inlined - the types are " + f"incompatible.") # Create a list of routine arguments that is actually used routine_arg_list = [ routine_table.argument_list[i] for i in arg_match_list @@ -1125,7 +1363,6 @@ def validate( routine_arg_list, node.arguments ): self._validate_inline_of_call_and_routine_argument_pairs( - call_node=node, call_arg=actual_arg, routine_node=routine, routine_arg=routine_arg @@ -1182,7 +1419,6 @@ def validate( def _validate_inline_of_call_and_routine_argument_pairs( self, - call_node: Call, call_arg: DataNode, routine_node: Routine, routine_arg: DataSymbol diff --git a/src/psyclone/tests/psyir/backend/fortran_common_block_test.py b/src/psyclone/tests/psyir/backend/fortran_common_block_test.py index 9a1b8fd084..919f9093ee 100644 --- a/src/psyclone/tests/psyir/backend/fortran_common_block_test.py +++ b/src/psyclone/tests/psyir/backend/fortran_common_block_test.py @@ -61,6 +61,8 @@ def test_fw_common_blocks(fortran_reader, fortran_writer, tmpdir): routine = psyir.walk(Routine)[0] assert routine.symbol_table.lookup("a").is_commonblock # Sanity check + assert routine.symbol_table.lookup("d").is_commonblock # Sanity check + assert routine.symbol_table.lookup("e").is_commonblock # Sanity check code = fortran_writer(routine) assert code == ( diff --git a/src/psyclone/tests/psyir/symbols/symbol_table_test.py b/src/psyclone/tests/psyir/symbols/symbol_table_test.py index 0d73fb26f9..824d78d5c9 100644 --- a/src/psyclone/tests/psyir/symbols/symbol_table_test.py +++ b/src/psyclone/tests/psyir/symbols/symbol_table_test.py @@ 
-1325,6 +1325,95 @@ def test_handle_symbol_clash_imported_symbols(): "of the same name imported from 'Ridcully'" in str(err.value)) +def test_handle_symbol_clash_commonblock_same_declaration(): + '''Test that _handle_symbol_clash() ignores duplicate COMMON-block + markers with identical declarations.''' + table1 = symbols.SymbolTable() + table2 = symbols.SymbolTable() + decl = "common /keep_me/ a" + marker_name = "_PSYCLONE_INTERNAL_COMMONBLOCK_1" + table1.add(symbols.DataSymbol( + marker_name, symbols.UnsupportedFortranType(decl))) + table2.add(symbols.DataSymbol( + marker_name, symbols.UnsupportedFortranType(decl))) + + old_sym = table2.lookup(marker_name) + table1._handle_symbol_clash(old_sym, table2) + + assert len(table1.symbols) == 1 + assert old_sym.name == marker_name + + +def test_add_symbols_from_table_commonblock_same_decl_different_name(): + '''Test that _add_symbols_from_table() silently skips an incoming + COMMON-block marker whose declaration is identical to one already in the + table but under a *different* marker name. + + This is the regression case for the bug where add() raises a KeyError + for the duplicate declaration, _add_symbols_from_table() forwards it to + _handle_symbol_clash(), and _handle_symbol_clash() previously crashed + because it called self.lookup(old_sym.name) before checking for + COMMON-block markers — and the name was absent from the table.''' + table1 = symbols.SymbolTable() + table2 = symbols.SymbolTable() + decl = "COMMON /ocean/ u, v" + # table1 already has the COMMON block under marker number 10. + table1.add(symbols.DataSymbol( + "_PSYCLONE_INTERNAL_COMMONBLOCK_10", + symbols.UnsupportedFortranType(decl))) + # table2 has the *same* COMMON block under marker number 33 (different + # number, as happens when two routines are independently parsed). 
+ table2.add(symbols.DataSymbol( + "_PSYCLONE_INTERNAL_COMMONBLOCK_33", + symbols.UnsupportedFortranType(decl))) + + table1._add_symbols_from_table(table2) + + # The COMMON block must be present exactly once (no duplicate). + matching = [sym for sym in table1.symbols + if isinstance(sym.datatype, symbols.UnsupportedFortranType) + and sym.datatype.declaration == decl] + assert len(matching) == 1 + + +def test_handle_symbol_clash_commonblock_distinct_blocks_renamed(): + '''Test that _handle_symbol_clash() renames and adds an incoming + COMMON-block marker when block names do not overlap.''' + table1 = symbols.SymbolTable() + table2 = symbols.SymbolTable() + marker_name = "_PSYCLONE_INTERNAL_COMMONBLOCK_1" + table1.add(symbols.DataSymbol( + marker_name, symbols.UnsupportedFortranType("common /first/ a"))) + table2.add(symbols.DataSymbol( + marker_name, symbols.UnsupportedFortranType("common /second/ b"))) + + old_sym = table2.lookup(marker_name) + table1._handle_symbol_clash(old_sym, table2) + + assert old_sym.name != marker_name + assert any(sym.datatype.declaration == "common /second/ b" + for sym in table1.symbols) + + +def test_handle_symbol_clash_unsupported_fortran_non_commonblock_name(): + '''Test that a clash between UnsupportedFortranType symbols with names + unrelated to common-block markers takes the standard rename-and-add path. 
+ ''' + table1 = symbols.SymbolTable() + table2 = symbols.SymbolTable() + table1.add(symbols.DataSymbol( + "clash", symbols.UnsupportedFortranType("type(t1) :: clash"))) + table2.add(symbols.DataSymbol( + "clash", symbols.UnsupportedFortranType("type(t2) :: clash"))) + + old_sym = table2.lookup("clash") + table1._handle_symbol_clash(old_sym, table2) + + assert old_sym.name != "clash" + assert any(sym.datatype.declaration == "type(t2) :: clash" + for sym in table1.symbols) + + def test_swap_symbol_properties(): ''' Test the symboltable swap_properties method ''' # pylint: disable=too-many-statements diff --git a/src/psyclone/tests/psyir/transformations/inline_trans_test.py b/src/psyclone/tests/psyir/transformations/inline_trans_test.py index cda89d5229..a1b6ee165b 100644 --- a/src/psyclone/tests/psyir/transformations/inline_trans_test.py +++ b/src/psyclone/tests/psyir/transformations/inline_trans_test.py @@ -2223,12 +2223,43 @@ def test_validate_array_reshape(fortran_reader): sub_s = psyir.walk(Routine)[1] with pytest.raises(TransformationError) as err: inline_trans._validate_inline_of_call_and_routine_argument_pairs( - call, call.arguments[0], + call.arguments[0], sub_s, sub_s.symbol_table.lookup("x")) assert ("actual argument 'a(:,:)' has rank 2 but the corresponding formal " "argument, 'x', has rank 1" in str(err.value)) +def test_validate_unknown_type_array_arg(fortran_reader): + '''Test that _validate_inline_of_call_and_routine_argument_pairs rejects + an attempt to inline a call when the actual argument has an unknown type + but the corresponding formal argument is an array.''' + code = """\ +module test_mod +contains +subroutine main + use some_mod, only: mystery + call sub(mystery) +end subroutine +subroutine sub(x) + real, dimension(10), intent(inout) :: x + x(:) = 0.0 +end subroutine +end module +""" + psyir = fortran_reader.psyir_from_source(code) + call = psyir.walk(Call)[0] + sub = psyir.walk(Routine)[1] + inline_trans = InlineTrans() + with 
pytest.raises(TransformationError) as err: + inline_trans._validate_inline_of_call_and_routine_argument_pairs( + call.arguments[0], sub, sub.symbol_table.lookup("x")) + assert ( + "Routine 'sub' cannot be inlined because the type of the actual " + "argument 'mystery' corresponding to an array formal argument " + "('x') is unknown." in str(err.value) + ) + + def test_validate_array_arg_expression(fortran_reader): ''' Check that validate rejects a call if an argument corresponding to @@ -2843,3 +2874,702 @@ def test_apply_array_access_check_unresolved_override_option( inline_trans.apply( call, use_first_callee_and_no_arg_check=True) # TODO check results + + +def test_apply_common_block_no_duplicate( + fortran_reader, fortran_writer, tmp_path): + '''Test that inlining two routines that share a COMMON block does not + produce duplicate COMMON declarations (which would cause a Fortran compile + error "Symbol X is already in a COMMON block").''' + + src = """\ +module test_mod + implicit none +contains + subroutine caller() + call sub1() + call sub2() + end subroutine caller + subroutine sub1() + real :: volume, lmmpi + COMMON /blk/ volume, lmmpi + volume = 1.0 + end subroutine sub1 + subroutine sub2() + real :: volume, lmmpi + COMMON /blk/ volume, lmmpi + lmmpi = 2.0 + end subroutine sub2 +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(src) + caller = psyir.walk(Routine)[0] + + trans = InlineTrans() + calls = caller.walk(Call) + trans.apply(calls[0]) + calls = caller.walk(Call) + trans.apply(calls[0]) + + result = fortran_writer(caller) + # Exactly one COMMON declaration must appear. + assert result.count("COMMON /blk/") == 1 + # Both variables must still be present. 
+ assert "volume" in result + assert "lmmpi" in result + assert Compile(tmp_path).string_compiles(result) + + +def test_apply_common_block_no_duplicate_three_routines( + fortran_reader, fortran_writer, tmp_path): + '''Test that inlining three routines that all share the same COMMON block + does not produce duplicate COMMON declarations. This mirrors the real-world + case of inlining zetabc_tile, u2dbc_tile and v2dbc_tile (each of which + includes the same set of COMMON-block headers) into step2D_FB_tile.''' + + src = """\ +module test_mod + implicit none +contains + subroutine caller() + call sub1() + call sub2() + call sub3() + end subroutine caller + subroutine sub1() + real :: zeta, ubar, vbar + COMMON /ocean_zeta/ zeta + COMMON /ocean_ubar/ ubar + COMMON /ocean_vbar/ vbar + zeta = 1.0 + end subroutine sub1 + subroutine sub2() + real :: zeta, ubar, vbar + COMMON /ocean_zeta/ zeta + COMMON /ocean_ubar/ ubar + COMMON /ocean_vbar/ vbar + ubar = 2.0 + end subroutine sub2 + subroutine sub3() + real :: zeta, ubar, vbar + COMMON /ocean_zeta/ zeta + COMMON /ocean_ubar/ ubar + COMMON /ocean_vbar/ vbar + vbar = 3.0 + end subroutine sub3 +end module test_mod +""" + psyir = fortran_reader.psyir_from_source(src) + caller = psyir.walk(Routine)[0] + + trans = InlineTrans() + calls = caller.walk(Call) + trans.apply(calls[0]) + calls = caller.walk(Call) + trans.apply(calls[0]) + calls = caller.walk(Call) + trans.apply(calls[0]) + + result = fortran_writer(caller) + # Each COMMON block must appear exactly once. + assert result.count("COMMON /ocean_zeta/") == 1 + assert result.count("COMMON /ocean_ubar/") == 1 + assert result.count("COMMON /ocean_vbar/") == 1 + # All three variables must still be present. 
+    assert "zeta" in result
+    assert "ubar" in result
+    assert "vbar" in result
+    assert Compile(tmp_path).string_compiles(result)
+
+
+def test_apply_common_block_caller_has_extra_block(
+        fortran_reader, fortran_writer, tmp_path):
+    '''Test that inlining a routine whose only COMMON block is already present
+    in the caller does not produce a duplicate COMMON declaration, even when
+    the caller also has an *additional* COMMON block that the inlined routine
+    does not declare. This is a regression test derived from the real-world
+    test.f file: the presence of the extra /comm_setup_mpi1/ block in the
+    caller was enough to confuse the earlier deduplication logic and caused
+    "Symbol 'zeta' at (1) is already in a COMMON block".'''
+
+    src = """\
+module test_mod
+  implicit none
+contains
+  subroutine caller()
+    integer :: lmmpi
+    COMMON /comm_setup_mpi1/ lmmpi
+    real :: zeta
+    COMMON /ocean_zeta/ zeta
+    call subfoo()
+  end subroutine caller
+  subroutine subfoo()
+    real :: zeta
+    COMMON /ocean_zeta/ zeta
+    zeta = zeta + 1.0
+  end subroutine subfoo
+end module test_mod
+"""
+    psyir = fortran_reader.psyir_from_source(src)
+    caller = psyir.walk(Routine)[0]
+
+    trans = InlineTrans()
+    calls = caller.walk(Call)
+    trans.apply(calls[0])
+
+    result = fortran_writer(caller)
+    # /ocean_zeta/ must appear exactly once – not duplicated.
+    assert result.count("COMMON /ocean_zeta/") == 1
+    # The extra block from the caller must be preserved.
+    assert result.count("COMMON /comm_setup_mpi1/") == 1
+    assert "zeta" in result
+    assert "lmmpi" in result
+    assert Compile(tmp_path).string_compiles(result)
+
+
+def test_apply_common_block_accept_different_names(
+        fortran_reader, fortran_writer, tmp_path):
+    '''Test that inlining is accepted when the same COMMON block is declared
+    in both the caller and the callee with different variable names but the
+    same type. The callee's variable ('height') is an alias of the caller's
+    variable ('depth') at the same block position, so all references to
+    'height' inside the inlined body must be replaced by 'depth'.
+    '''
+
+    src = """\
+module test_mod
+  implicit none
+contains
+  subroutine caller()
+    COMMON /ocean/ depth
+    real :: depth
+    integer :: b
+    call subfoo(b)
+  end subroutine caller
+  subroutine subfoo(a)
+    COMMON /ocean/ height
+    real :: height
+    integer :: a
+
+    a = height
+  end subroutine subfoo
+end module test_mod
+"""
+    psyir = fortran_reader.psyir_from_source(src)
+    caller = psyir.walk(Routine)[0]
+
+    trans = InlineTrans()
+    calls = caller.walk(Call)
+    trans.apply(calls[0])
+
+    result = fortran_writer(psyir)
+    # 'height' must have been replaced by 'depth' (the caller's alias).
+    assert """\
+  subroutine caller()
+    real :: depth
+    integer :: b
+    COMMON /ocean/ depth
+
+    b = depth
+
+  end subroutine caller
+""" in result
+    assert Compile(tmp_path).string_compiles(result)
+
+
+def test_apply_common_block_reject_due_to_different_types(
+        fortran_reader, fortran_writer, tmp_path):
+    '''Test that inlining is rejected when the same COMMON block is declared
+    in both the caller and the callee with different variable names and
+    different types.
+    '''
+
+    src = """\
+module test_mod
+  implicit none
+contains
+  subroutine caller()
+    COMMON /ocean/ depth
+    real(kind=4) :: depth
+    integer :: b
+    call subfoo(b)
+  end subroutine caller
+  subroutine subfoo(a)
+    COMMON /ocean/ height
+    real(kind=8) :: height
+    integer :: a
+
+    a = 3
+  end subroutine subfoo
+end module test_mod
+"""
+    psyir = fortran_reader.psyir_from_source(src)
+    caller = psyir.walk(Routine)[0]
+
+    trans = InlineTrans()
+    calls = caller.walk(Call)
+    with pytest.raises(TransformationError) as err:
+        # Should raise a TransformationError because
+        # the types of the common-block variables differ.
+        trans.apply(calls[0])
+    # The check on the message must sit outside the 'with' body, otherwise
+    # it would never be executed once apply() raises.
+    assert ("Cannot inline 'subfoo' because COMMON block '/ocean/' maps"
+            " 'depth' (type 'Scalar<REAL, Literal[value:'4', "
+            "Scalar<INTEGER, UNDEFINED>]>') in the caller to 'height'"
+            " (type 'Scalar<REAL, Literal[value:'8', "
+            "Scalar<INTEGER, UNDEFINED>]>') in the routine being inlined"
+            " - the types are incompatible.") in str(err.value)
+
+
+def test_apply_parameter_cloning_default(
+        fortran_reader, fortran_writer, tmp_path):
+    '''Test that the default behaviour (parameter_cloning=True) clones a
+    constant from the inlined routine into the call-site table, even when
+    an identical constant already exists there, potentially renaming it.'''
+    code = """\
+module test_mod
+contains
+  subroutine bar(b)
+    real, parameter :: constval = 123.0
+    integer :: b
+    call foo(b)
+  end subroutine bar
+  subroutine foo(a)
+    real, parameter :: constval = 123.0
+    integer :: a
+    a = int(constval)
+  end subroutine foo
+end module test_mod
+"""
+    psyir = fortran_reader.psyir_from_source(code)
+    bar = psyir.walk(Routine)[0]
+    foo = psyir.walk(Routine)[1]
+    call = bar.walk(Call)[0]
+
+    InlineTrans().apply(call, routine=foo)
+
+    result = fortran_writer(bar)
+    # With cloning enabled there must be two parameter declarations whose
+    # names start with 'constval': the caller's own and the (renamed) clone.
+    # A bare '"constval" in result' would hold even without any cloning
+    # because the caller already declares that name.
+    assert result.count("parameter :: constval") == 2
+    assert Compile(tmp_path).string_compiles(result)
+
+
+def test_apply_parameter_cloning_false_identical(fortran_reader,
+                                                 fortran_writer, tmp_path):
+    '''Test that parameter_cloning=False suppresses the duplicate when the
+    call-site already has an identical constant (same name, type, value).
+    This is the main use-case from the user request.'''
+    code = """\
+module test_mod
+contains
+  subroutine bar(b)
+    real, parameter :: constval = 123.0
+    integer :: b
+    call foo(b)
+  end subroutine bar
+  subroutine foo(a)
+    real, parameter :: constval = 123.0
+    integer :: a
+    a = int(constval)
+  end subroutine foo
+end module test_mod
+"""
+    psyir = fortran_reader.psyir_from_source(code)
+    bar = psyir.walk(Routine)[0]
+    foo = psyir.walk(Routine)[1]
+    call = bar.walk(Call)[0]
+
+    InlineTrans().apply(call, routine=foo, parameter_cloning=False)
+
+    result = fortran_writer(bar)
+    # constval should be declared exactly once (no duplicate parameter).
+    assert result.count("parameter :: constval") == 1
+    # The inlined assignment must refer to the caller's constval.
+    assert "b = INT(constval)" in result
+    assert Compile(tmp_path).string_compiles(result)
+
+
+def test_apply_parameter_cloning_false_different_value(
+        fortran_reader, fortran_writer, tmp_path):
+    '''Test that parameter_cloning=False does NOT suppress a parameter when
+    the values differ between the call site and the inlined routine.'''
+    code = """\
+module test_mod
+contains
+  subroutine bar(b)
+    real, parameter :: constval = 42.0
+    integer :: b
+    call foo(b)
+  end subroutine bar
+  subroutine foo(a)
+    real, parameter :: constval = 123.0
+    integer :: a
+    a = int(constval)
+  end subroutine foo
+end module test_mod
+"""
+    psyir = fortran_reader.psyir_from_source(code)
+    bar = psyir.walk(Routine)[0]
+    foo = psyir.walk(Routine)[1]
+    call = bar.walk(Call)[0]
+
+    InlineTrans().apply(call, routine=foo, parameter_cloning=False)
+
+    result = fortran_writer(bar)
+    # Both constant declarations must survive since they have different
+    # values. Each declaration is checked explicitly: counting occurrences
+    # of "constval" would also be satisfied by a single declaration plus
+    # the inlined reference.
+    assert "parameter :: constval = 42.0" in result
+    assert "parameter :: constval_1 = 123.0" in result
+    # The inlined body must use the routine's (renamed) constant.
+    assert "b = INT(constval_1)" in result
+    assert Compile(tmp_path).string_compiles(result)
+
+
+def test_apply_parameter_cloning_false_no_match_in_caller(
+        fortran_reader, fortran_writer, tmp_path):
+    '''Test that parameter_cloning=False still adds a constant that does not
+    exist at the call site.'''
+    code = """\
+module test_mod
+contains
+  subroutine bar(b)
+    integer :: b
+    call foo(b)
+  end subroutine bar
+  subroutine foo(a)
+    real, parameter :: constval = 123.0
+    integer :: a
+    a = int(constval)
+  end subroutine foo
+end module test_mod
+"""
+    psyir = fortran_reader.psyir_from_source(code)
+    bar = psyir.walk(Routine)[0]
+    foo = psyir.walk(Routine)[1]
+    call = bar.walk(Call)[0]
+
+    InlineTrans().apply(call, routine=foo, parameter_cloning=False)
+
+    result = fortran_writer(bar)
+    # constval from foo must be added to bar because bar didn't have it.
+    # Check the declaration itself: the inlined reference alone would also
+    # make '"constval" in result' true.
+    assert "parameter :: constval = 123.0" in result
+    assert "b = INT(constval)" in result
+    assert Compile(tmp_path).string_compiles(result)
+
+
+def test_apply_parameter_cloning_false_used_in_array_dim(
+        fortran_reader, fortran_writer, tmp_path):
+    '''Test that parameter_cloning=False correctly handles a constant that
+    is used as an array-dimension bound inside the inlined routine.'''
+    code = """\
+module test_mod
+contains
+  subroutine bar(b)
+    integer, parameter :: n = 5
+    integer :: b
+    call foo(b)
+  end subroutine bar
+  subroutine foo(a)
+    integer, parameter :: n = 5
+    real, dimension(n) :: tmp
+    integer :: a
+    tmp(1) = real(a)
+    a = int(tmp(1))
+  end subroutine foo
+end module test_mod
+"""
+    psyir = fortran_reader.psyir_from_source(code)
+    bar = psyir.walk(Routine)[0]
+    foo = psyir.walk(Routine)[1]
+    call = bar.walk(Call)[0]
+
+    InlineTrans().apply(call, routine=foo, parameter_cloning=False)
+
+    result = fortran_writer(bar)
+    # n must be declared exactly once. A cloned copy would have been renamed
+    # to 'n_1' (so searching for "n = 5" alone could not detect it), hence
+    # the renamed symbol must be absent.
+    assert result.count("parameter :: n = 5") == 1
+    assert "n_1" not in result
+    # The inlined array tmp should still be present and use n.
+    assert "dimension(n) :: tmp" in result
+    assert Compile(tmp_path).string_compiles(result)
+
+
+def test_apply_parameter_cloning_false_multiple_params(
+        fortran_reader, fortran_writer, tmp_path):
+    '''Test parameter_cloning=False with multiple constants, some matching
+    and some not.'''
+    code = """\
+module test_mod
+contains
+  subroutine bar(b)
+    integer, parameter :: shared = 10
+    integer :: b
+    call foo(b)
+  end subroutine bar
+  subroutine foo(a)
+    integer, parameter :: shared = 10
+    integer, parameter :: local_only = 99
+    integer :: a
+    a = shared + local_only
+  end subroutine foo
+end module test_mod
+"""
+    psyir = fortran_reader.psyir_from_source(code)
+    bar = psyir.walk(Routine)[0]
+    foo = psyir.walk(Routine)[1]
+    call = bar.walk(Call)[0]
+
+    InlineTrans().apply(call, routine=foo, parameter_cloning=False)
+
+    result = fortran_writer(bar)
+    # shared must be declared exactly once (no duplicate parameter).
+    assert result.count("parameter :: shared") == 1
+    # local_only is unique to foo, so its declaration must be added to bar.
+    assert "parameter :: local_only = 99" in result
+    assert Compile(tmp_path).string_compiles(result)
+
+
+def test_apply_parameter_cloning_false_complex_rhs_identical(
+        fortran_reader, fortran_writer, tmp_path):
+    '''Test parameter_cloning=False with constants whose value is a complex
+    PSyIR expression (BinaryOperation) that is identical in the caller and the
+    routine. The duplicate should be suppressed.'''
+    code = """\
+module test_mod
+contains
+  subroutine bar(b)
+    integer, parameter :: base_val = 10
+    integer, parameter :: constval = 123 + base_val
+    integer :: b
+    call foo(b)
+  end subroutine bar
+  subroutine foo(a)
+    integer, parameter :: base_val = 10
+    integer, parameter :: constval = 123 + base_val
+    integer :: a
+    a = constval
+  end subroutine foo
+end module test_mod
+"""
+    psyir = fortran_reader.psyir_from_source(code)
+    bar = psyir.walk(Routine)[0]
+    foo = psyir.walk(Routine)[1]
+    call = bar.walk(Call)[0]
+
+    InlineTrans().apply(call, routine=foo, parameter_cloning=False)
+
+    result = fortran_writer(bar)
+    # Neither base_val nor constval should be duplicated.
+    assert result.count("parameter :: constval") == 1
+    assert result.count("parameter :: base_val") == 1
+    # The inlined body should reference the caller's constval.
+    assert "b = constval\n" in result
+    assert Compile(tmp_path).string_compiles(result)
+
+
+def test_apply_parameter_cloning_false_complex_rhs_different(
+        fortran_reader, fortran_writer, tmp_path):
+    '''Test parameter_cloning=False with constants that have identical names
+    but different complex RHS expressions. Both declarations must be kept.'''
+    code = """\
+module test_mod
+contains
+  subroutine bar(b)
+    integer, parameter :: base_val = 10
+    integer, parameter :: constval = 100 + base_val
+    integer :: b
+    call foo(b)
+  end subroutine bar
+  subroutine foo(a)
+    integer, parameter :: base_val = 10
+    integer, parameter :: constval = 123 + base_val
+    integer :: a
+    a = constval
+  end subroutine foo
+end module test_mod
+"""
+    psyir = fortran_reader.psyir_from_source(code)
+    bar = psyir.walk(Routine)[0]
+    foo = psyir.walk(Routine)[1]
+    call = bar.walk(Call)[0]
+
+    InlineTrans().apply(call, routine=foo, parameter_cloning=False)
+
+    result = fortran_writer(bar)
+    # constval has different values in bar and foo, so both declarations
+    # (the caller's and the renamed clone) must appear.
+    assert result.count("parameter :: constval") == 2
+    assert "constval_1" in result
+    # base_val is identical and should be deduplicated.
+    assert result.count("parameter :: base_val") == 1
+    assert Compile(tmp_path).string_compiles(result)
+
+
+def test_apply_parameter_cloning_false_unary_op_different_base(
+        fortran_reader, fortran_writer, tmp_path):
+    '''Test parameter_cloning=False where .NOT. parameters share a name but
+    their base parameter differs. The derived constant must NOT be deduplicated
+    because the structural match is only nominal (the base has different
+    values), and using the caller's copy would produce wrong semantics.'''
+    code = """\
+module test_mod
+contains
+  subroutine bar(b)
+    logical, parameter :: flag = .true.
+    logical, parameter :: negflag = .not. flag
+    integer :: b
+    call foo(b)
+  end subroutine bar
+  subroutine foo(a)
+    logical, parameter :: flag = .false.
+    logical, parameter :: negflag = .not. flag
+    integer :: a
+    if (negflag) a = 42
+  end subroutine foo
+end module test_mod
+"""
+    psyir = fortran_reader.psyir_from_source(code)
+    bar = psyir.walk(Routine)[0]
+    foo = psyir.walk(Routine)[1]
+    call = bar.walk(Call)[0]
+
+    InlineTrans().apply(call, routine=foo, parameter_cloning=False)
+
+    result = fortran_writer(bar)
+    # flag has different values so both must appear (foo's renamed). The
+    # declaration is matched explicitly because a bare "flag_1" is also a
+    # substring of "negflag_1" and so would prove nothing about flag.
+    assert "parameter :: flag_1" in result
+    # negflag depends on flag which differs, so foo's negflag must also
+    # appear (renamed), and the inlined if must use foo's (renamed) negflag.
+    assert "negflag_1" in result
+    assert "if (negflag_1)" in result
+    assert Compile(tmp_path).string_compiles(result)
+
+
+def test_apply_parameter_type_clash(
+        fortran_reader, fortran_writer):
+    '''Test with parameter types that don't match.'''
+    code = """\
+module test_mod
+contains
+  subroutine bar(b)
+    integer, parameter :: wp = kind(1.0d0)
+    real(kind=wp), parameter :: pi = 3.14592
+    real :: tmp
+    integer :: b
+
+    tmp = wp
+    call foo(b)
+  end subroutine bar
+  subroutine foo(a)
+    integer, parameter :: wp = kind(1.0)
+    integer :: a
+
+    a = 42 * wp
+  end subroutine foo
+end module test_mod
+"""
+    psyir = fortran_reader.psyir_from_source(code)
+    bar = psyir.walk(Routine)[0]
+    foo = psyir.walk(Routine)[1]
+    call = bar.walk(Call)[0]
+
+    InlineTrans().apply(call, routine=foo, parameter_cloning=False)
+
+    result = fortran_writer(bar)
+    assert """\
+subroutine bar(b)
+  integer, parameter :: wp = KIND(1.0d0)
+  real(kind=wp), parameter :: pi = 3.14592
+  integer, parameter :: wp_1 = KIND(1.0)
+  integer :: b
+  real :: tmp
+
+  tmp = wp
+  b = 42 * wp_1
+
+end subroutine bar""" in result
+
+
+def test_apply_parameter_cloning_false_caller_has_non_constant(
+        fortran_reader, fortran_writer, tmp_path):
+    '''Test that parameter_cloning=False does NOT suppress a routine constant
+    when the call-site has a symbol with the same name that is not a constant
+    (i.e. tsym.is_constant is False). This exercises the
+    ``if not tsym.is_constant or tsym.initial_value is None`` branch in
+    _redirect_duplicate_parameters.'''
+    code = """\
+module test_mod
+contains
+  subroutine bar(b)
+    integer :: constval
+    integer :: b
+    constval = 7
+    call foo(b)
+  end subroutine bar
+  subroutine foo(a)
+    integer, parameter :: constval = 10
+    integer :: a
+    a = constval
+  end subroutine foo
+end module test_mod
+"""
+    psyir = fortran_reader.psyir_from_source(code)
+    bar = psyir.walk(Routine)[0]
+    foo = psyir.walk(Routine)[1]
+    call = bar.walk(Call)[0]
+
+    InlineTrans().apply(call, routine=foo, parameter_cloning=False)
+
+    result = fortran_writer(bar)
+    # bar's constval is a variable; foo's is a parameter. They are not
+    # duplicates, so foo's parameter constant must appear (possibly renamed).
+    assert """\
+subroutine bar(b)
+  integer, parameter :: constval_1 = 10
+  integer :: b
+  integer :: constval
+
+  constval = 7
+  b = constval_1
+
+end subroutine bar""" in result
+    assert Compile(tmp_path).string_compiles(result)
+
+
+def test_apply_parameter_cloning_false_different_datatype(
+        fortran_reader, fortran_writer, tmp_path):
+    '''Test that parameter_cloning=False does NOT suppress a routine constant
+    when the call-site has a constant with the same name but a different
+    datatype. This exercises the ``if rsym.datatype != tsym.datatype``
+    branch in _redirect_duplicate_parameters.'''
+    code = """\
+module test_mod
+contains
+  subroutine bar(b)
+    integer, parameter :: constval = 10
+    integer :: b
+    call foo(b)
+  end subroutine bar
+  subroutine foo(a)
+    real, parameter :: constval = 10.0
+    integer :: a
+    a = int(constval)
+  end subroutine foo
+end module test_mod
+"""
+    psyir = fortran_reader.psyir_from_source(code)
+    bar = psyir.walk(Routine)[0]
+    foo = psyir.walk(Routine)[1]
+    call = bar.walk(Call)[0]
+
+    InlineTrans().apply(call, routine=foo, parameter_cloning=False)
+
+    result = fortran_writer(bar)
+    # bar has integer constval=10, foo has real constval=10.0. Different
+    # types so the routine's parameter must be added (renamed) rather than
+    # deduplicated.
+    assert """\
+subroutine bar(b)
+  integer, parameter :: constval = 10
+  real, parameter :: constval_1 = 10.0
+  integer :: b
+
+  b = INT(constval_1)
+
+end subroutine bar""" in result
+    assert Compile(tmp_path).string_compiles(result)