|
1 | 1 | import type { ContractDefinition, FunctionDefinition } from 'solidity-ast'; |
2 | 2 | import { ASTDereferencer, findAll } from 'solidity-ast/utils'; |
3 | 3 | import { SrcDecoder } from '../../src-decoder'; |
4 | | -import { ValidationExceptionInitializer, skipCheck } from '../run'; |
| 4 | +import { ValidationExceptionInitializer, skipCheck, tryDerefFunction } from '../run'; |
5 | 5 |
|
6 | 6 | /** |
7 | 7 | * Reports if this contract is non-abstract and any of the following are true: |
@@ -141,6 +141,27 @@ function getParentsNotInitializedByOtherParents( |
141 | 141 | return remainingParents; |
142 | 142 | } |
143 | 143 |
|
| 144 | +/** |
| 145 | + * Calls the callback if the referenced function definition is found in the AST. |
| 146 | + * Otherwise, does nothing. |
| 147 | + * |
| 148 | + * @param deref AST dereferencer |
| 149 | + * @param referencedDeclaration ID of the referenced function |
| 150 | + * @param callback Function to call if the referenced function definition is found |
| 151 | + */ |
| 152 | +function doIfReferencedFunctionFound( |
| 153 | + deref: ASTDereferencer, |
| 154 | + referencedDeclaration: number | null | undefined, |
| 155 | + callback: (functionDef: FunctionDefinition) => void, |
| 156 | +) { |
| 157 | + if (referencedDeclaration && referencedDeclaration > 0) { |
| 158 | + const functionDef = tryDerefFunction(deref, referencedDeclaration); |
| 159 | + if (functionDef !== undefined) { |
| 160 | + callback(functionDef); |
| 161 | + } |
| 162 | + } |
| 163 | +} |
| 164 | + |
144 | 165 | /** |
145 | 166 | * Reports exceptions for missing initializer calls, duplicate initializer calls, and incorrect initializer order. |
146 | 167 | * |
@@ -176,10 +197,9 @@ function* getInitializerCallExceptions( |
176 | 197 | (fnCall.expression.nodeType === 'Identifier' || fnCall.expression.nodeType === 'MemberAccess') |
177 | 198 | ) { |
178 | 199 | let recursiveFunctionIds: number[] = []; |
179 | | - const referencedFn = fnCall.expression.referencedDeclaration; |
180 | | - if (referencedFn && referencedFn > 0) { |
181 | | - recursiveFunctionIds = getRecursiveFunctionIds(referencedFn, deref); |
182 | | - } |
| 200 | + doIfReferencedFunctionFound(deref, fnCall.expression.referencedDeclaration, (functionDef: FunctionDefinition) => { |
| 201 | + recursiveFunctionIds = getRecursiveFunctionIds(deref, functionDef); |
| 202 | + }); |
183 | 203 |
|
184 | 204 | // For each recursively called function, if it is a parent initializer, then: |
185 | 205 | // - Check if it was already called (duplicate call) |
@@ -258,38 +278,41 @@ function* getInitializerCallExceptions( |
258 | 278 | /** |
259 | 279 | * Gets the IDs of all functions that are recursively called by the given function, including the given function itself at the end of the list. |
260 | 280 | * |
261 | | - * @param referencedFn The ID of the function to start from |
262 | 281 | * @param deref AST dereferencer |
| 282 | + * @param functionDef The node of the function definition to start from |
263 | 283 | * @param visited Set of function IDs that have already been visited |
264 | 284 | * @returns The IDs of all functions that are recursively called by the given function, including the given function itself at the end of the list. |
265 | 285 | */ |
266 | | -function getRecursiveFunctionIds(referencedFn: number, deref: ASTDereferencer, visited?: Set<number>): number[] { |
| 286 | +function getRecursiveFunctionIds( |
| 287 | + deref: ASTDereferencer, |
| 288 | + functionDef: FunctionDefinition, |
| 289 | + visited?: Set<number>, |
| 290 | +): number[] { |
267 | 291 | const result: number[] = []; |
268 | 292 |
|
269 | 293 | if (visited === undefined) { |
270 | 294 | visited = new Set(); |
271 | 295 | } |
272 | | - if (visited.has(referencedFn)) { |
| 296 | + if (visited.has(functionDef.id)) { |
273 | 297 | return result; |
274 | 298 | } else { |
275 | | - visited.add(referencedFn); |
| 299 | + visited.add(functionDef.id); |
276 | 300 | } |
277 | 301 |
|
278 | | - const fn = deref('FunctionDefinition', referencedFn); |
279 | | - const expressionStatements = fn.body?.statements?.filter(stmt => stmt.nodeType === 'ExpressionStatement') ?? []; |
| 302 | + const expressionStatements = |
| 303 | + functionDef.body?.statements?.filter(stmt => stmt.nodeType === 'ExpressionStatement') ?? []; |
280 | 304 | for (const stmt of expressionStatements) { |
281 | 305 | const fnCall = stmt.expression; |
282 | 306 | if ( |
283 | 307 | fnCall.nodeType === 'FunctionCall' && |
284 | 308 | (fnCall.expression.nodeType === 'Identifier' || fnCall.expression.nodeType === 'MemberAccess') |
285 | 309 | ) { |
286 | | - const referencedId = fnCall.expression.referencedDeclaration; |
287 | | - if (referencedId && referencedId > 0) { |
288 | | - result.push(...getRecursiveFunctionIds(referencedId, deref, visited)); |
289 | | - } |
| 310 | + doIfReferencedFunctionFound(deref, fnCall.expression.referencedDeclaration, (functionDef: FunctionDefinition) => { |
| 311 | + result.push(...getRecursiveFunctionIds(deref, functionDef, visited)); |
| 312 | + }); |
290 | 313 | } |
291 | 314 | } |
292 | | - result.push(referencedFn); |
| 315 | + result.push(functionDef.id); |
293 | 316 |
|
294 | 317 | return result; |
295 | 318 | } |
|
0 commit comments