|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import asyncio |
4 | | -import contextlib |
5 | | -import inspect |
| 4 | +import dis |
6 | 5 | from collections.abc import Sequence |
7 | 6 | from typing import Any, Callable, Literal, cast, overload |
8 | 7 |
|
@@ -116,23 +115,37 @@ def __init__( |
116 | 115 | target: str | None = None, |
117 | 116 | ) -> None: |
118 | 117 | self.function = to_event_handler_function(function, positional_args=False) |
119 | | - |
120 | | - if not (stop_propagation and prevent_default): |
121 | | - with contextlib.suppress(Exception): |
122 | | - func_to_inspect = cast(Any, function) |
123 | | - while hasattr(func_to_inspect, "__wrapped__"): |
124 | | - func_to_inspect = func_to_inspect.__wrapped__ |
125 | | - |
126 | | - source = inspect.getsource(func_to_inspect) |
127 | | - if not stop_propagation and ".stopPropagation()" in source: |
128 | | - stop_propagation = True |
129 | | - if not prevent_default and ".preventDefault()" in source: |
130 | | - prevent_default = True |
131 | | - |
132 | 118 | self.prevent_default = prevent_default |
133 | 119 | self.stop_propagation = stop_propagation |
134 | 120 | self.target = target |
135 | 121 |
|
| 122 | + # Check if our `preventDefault` or `stopPropagation` methods were called |
| 123 | + # by inspecting the function's bytecode |
| 124 | + func_to_inspect = cast(Any, function) |
| 125 | + while hasattr(func_to_inspect, "__wrapped__"): |
| 126 | + func_to_inspect = func_to_inspect.__wrapped__ |
| 127 | + |
| 128 | + code = func_to_inspect.__code__ |
| 129 | + if code.co_argcount > 0: |
| 130 | + event_arg_name = code.co_varnames[0] |
| 131 | + last_was_event = False |
| 132 | + |
| 133 | + for instr in dis.get_instructions(func_to_inspect): |
| 134 | + if instr.opname == "LOAD_FAST" and instr.argval == event_arg_name: |
| 135 | + last_was_event = True |
| 136 | + continue |
| 137 | + |
| 138 | + if last_was_event and instr.opname in ( |
| 139 | + "LOAD_METHOD", |
| 140 | + "LOAD_ATTR", |
| 141 | + ): |
| 142 | + if instr.argval == "preventDefault": |
| 143 | + self.prevent_default = True |
| 144 | + elif instr.argval == "stopPropagation": |
| 145 | + self.stop_propagation = True |
| 146 | + |
| 147 | + last_was_event = False |
| 148 | + |
136 | 149 | __hash__ = None # type: ignore |
137 | 150 |
|
138 | 151 | def __eq__(self, other: object) -> bool: |
|
0 commit comments