Skip to content

Commit 8b6642d

Browse files
authored
[mypyc] Fix free-threading race condition in argument parsing (#21613)
Global argument parser state initialization was missing synchronization. As an optimization, the fix avoids taking a mutex on the hot path. Fixes #21578 (probably). I added a regression test that failed about 30% of the time on master (on macOS). I used a coding agent for assistance.
1 parent 0f9676e commit 8b6642d

2 files changed

Lines changed: 147 additions & 8 deletions

File tree

mypyc/lib-rt/getargsfast.c

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,22 @@
1818
#include <Python.h>
1919
#include "CPy.h"
2020

21-
#define PARSER_INITED(parser) ((parser)->kwtuple != NULL)
21+
// The kwtuple field doubles as the "parser has been initialized" flag: it is
22+
// written last (after all other parser fields) and read first. On free-threaded
23+
// builds the fast paths read it without holding any lock, so the write must be a
24+
// release store and the reads acquire loads. That way a thread that observes a
25+
// non-NULL kwtuple is guaranteed to also see the fully-initialized min/max/etc.
26+
// fields. On GIL builds these are plain accesses with no overhead.
27+
#ifdef Py_GIL_DISABLED
28+
#define PARSER_KWTUPLE(parser) _Py_atomic_load_ptr_acquire(&(parser)->kwtuple)
29+
#define SET_PARSER_KWTUPLE(parser, value) \
30+
_Py_atomic_store_ptr_release(&(parser)->kwtuple, (value))
31+
#else
32+
#define PARSER_KWTUPLE(parser) ((parser)->kwtuple)
33+
#define SET_PARSER_KWTUPLE(parser, value) ((parser)->kwtuple = (value))
34+
#endif
35+
36+
#define PARSER_INITED(parser) (PARSER_KWTUPLE(parser) != NULL)
2237

2338
/* Forward */
2439
static int
@@ -115,19 +130,21 @@ CPyArg_ParseStackAndKeywordsSimple(PyObject *const *args, Py_ssize_t nargs, PyOb
115130
/* List of static parsers. */
116131
static struct CPyArg_Parser *static_arg_parsers = NULL;
117132

133+
#ifdef Py_GIL_DISABLED
134+
// Serializes one-time initialization of parsers and insertion into the
135+
// static_arg_parsers list. Only contended the first time a given compiled
136+
// function is called; once a parser is initialized the fast paths never lock.
137+
static PyMutex static_arg_parsers_mutex;
138+
#endif
139+
118140
static int
119-
parser_init(CPyArg_Parser *parser)
141+
parser_init_locked(CPyArg_Parser *parser)
120142
{
121143
const char * const *keywords;
122144
const char *format;
123145
int i, len, min, max, nkw;
124146
PyObject *kwtuple;
125147

126-
assert(parser->keywords != NULL);
127-
if (PARSER_INITED(parser)) {
128-
return 1;
129-
}
130-
131148
keywords = parser->keywords;
132149
/* scan keywords and count the number of positional-only parameters */
133150
for (i = 0; keywords[i] && !*keywords[i]; i++) {
@@ -244,14 +261,53 @@ parser_init(CPyArg_Parser *parser)
244261
PyUnicode_InternInPlace(&str);
245262
PyTuple_SET_ITEM(kwtuple, i, str);
246263
}
247-
parser->kwtuple = kwtuple;
248264

249265
assert(parser->next == NULL);
250266
parser->next = static_arg_parsers;
251267
static_arg_parsers = parser;
268+
269+
// Publish the parser last: storing kwtuple marks it as initialized, so all
270+
// other fields (and the list insertion above) must already be in place. On
271+
// free-threaded builds this is a release store paired with the acquire loads
272+
// in PARSER_INITED/PARSER_KWTUPLE.
273+
SET_PARSER_KWTUPLE(parser, kwtuple);
252274
return 1;
253275
}
254276

277+
// Cold path of parser_init: perform the one-time initialization. On
278+
// free-threaded builds this is serialized so that only one thread builds the
279+
// parser and inserts it into the static_arg_parsers list.
280+
static CPy_NOINLINE int
281+
parser_init_slow(CPyArg_Parser *parser)
282+
{
283+
#ifdef Py_GIL_DISABLED
284+
PyMutex_Lock(&static_arg_parsers_mutex);
285+
// Re-check now that we hold the lock: another thread may have initialized
286+
// the parser while we were waiting.
287+
if (PARSER_INITED(parser)) {
288+
PyMutex_Unlock(&static_arg_parsers_mutex);
289+
return 1;
290+
}
291+
int retval = parser_init_locked(parser);
292+
PyMutex_Unlock(&static_arg_parsers_mutex);
293+
return retval;
294+
#else
295+
return parser_init_locked(parser);
296+
#endif
297+
}
298+
299+
// Hot path: a parser is almost always already initialized, so keep the common
300+
// case inline and branch out to parser_init_slow only on first use.
301+
static inline int
302+
parser_init(CPyArg_Parser *parser)
303+
{
304+
assert(parser->keywords != NULL);
305+
if (likely(PARSER_INITED(parser))) {
306+
return 1;
307+
}
308+
return parser_init_slow(parser);
309+
}
310+
255311
static PyObject*
256312
find_keyword(PyObject *kwnames, PyObject *const *kwstack, PyObject *key)
257313
{

mypyc/test-data/run-functions.test

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,3 +1524,86 @@ def test_multiple_params() -> None:
15241524
# Different parameter name than loop variable
15251525
funcs2 = uses_multiple_params_different_name(["a", "b"], "!")
15261526
assert [f() for f in funcs2] == ["a!", "b!"]
1527+
1528+
[case testConcurrentFirstCallWithKeywordArgs]
1529+
# Regression test for a free-threading data race in argument parser
1530+
# initialization. The first call to a compiled function with keyword arguments
1531+
# lazily initializes a static CPyArg_Parser and pushes it onto a global list.
1532+
# When many threads race that first call on a free-threaded build, the
1533+
# initialization and list insertion must be synchronized, or the runtime hits
1534+
# a failed assertion (parser->next == NULL) and aborts.
1535+
import sys
import threading

# Thirty separate functions that stay uncalled until the concurrent test runs,
# so the very first call to each one happens while several threads compete.
# That first call lazily initializes the function's static CPyArg_Parser, and
# passing the arguments by keyword sends the call through the path that runs
# parser_init. Having many functions yields many independent first-call races
# per run, so an unfixed free-threaded build fails reliably rather than only
# when one narrow timing window is hit.
def g0(a: int, b: int, c: int) -> int: return a + b + c
def g1(a: int, b: int, c: int) -> int: return a + b + c
def g2(a: int, b: int, c: int) -> int: return a + b + c
def g3(a: int, b: int, c: int) -> int: return a + b + c
def g4(a: int, b: int, c: int) -> int: return a + b + c
def g5(a: int, b: int, c: int) -> int: return a + b + c
def g6(a: int, b: int, c: int) -> int: return a + b + c
def g7(a: int, b: int, c: int) -> int: return a + b + c
def g8(a: int, b: int, c: int) -> int: return a + b + c
def g9(a: int, b: int, c: int) -> int: return a + b + c
def g10(a: int, b: int, c: int) -> int: return a + b + c
def g11(a: int, b: int, c: int) -> int: return a + b + c
def g12(a: int, b: int, c: int) -> int: return a + b + c
def g13(a: int, b: int, c: int) -> int: return a + b + c
def g14(a: int, b: int, c: int) -> int: return a + b + c
def g15(a: int, b: int, c: int) -> int: return a + b + c
def g16(a: int, b: int, c: int) -> int: return a + b + c
def g17(a: int, b: int, c: int) -> int: return a + b + c
def g18(a: int, b: int, c: int) -> int: return a + b + c
def g19(a: int, b: int, c: int) -> int: return a + b + c
def g20(a: int, b: int, c: int) -> int: return a + b + c
def g21(a: int, b: int, c: int) -> int: return a + b + c
def g22(a: int, b: int, c: int) -> int: return a + b + c
def g23(a: int, b: int, c: int) -> int: return a + b + c
def g24(a: int, b: int, c: int) -> int: return a + b + c
def g25(a: int, b: int, c: int) -> int: return a + b + c
def g26(a: int, b: int, c: int) -> int: return a + b + c
def g27(a: int, b: int, c: int) -> int: return a + b + c
def g28(a: int, b: int, c: int) -> int: return a + b + c
def g29(a: int, b: int, c: int) -> int: return a + b + c

FUNCS = [
    g0, g1, g2, g3, g4, g5, g6, g7, g8, g9,
    g10, g11, g12, g13, g14, g15, g16, g17, g18, g19,
    g20, g21, g22, g23, g24, g25, g26, g27, g28, g29,
]

def is_gil_disabled() -> bool:
    # An interpreter without sys._is_gil_enabled is reported as having the GIL.
    if not hasattr(sys, "_is_gil_enabled"):
        return False
    return not sys._is_gil_enabled()

def test_concurrent_first_call() -> None:
    # The race needs threads running without the GIL; otherwise do nothing.
    if not is_gil_disabled():
        return

    worker_count = 16
    start_gate = threading.Barrier(worker_count)
    failures: list[str] = []

    def run() -> None:
        # Hold every thread at the barrier so they all start together, then
        # let each one walk the whole pool. The first call to each function
        # initializes its parser, so every function is a separate race.
        start_gate.wait()
        try:
            for fn in FUNCS:
                assert fn(a=1, b=2, c=3) == 6
        except BaseException as exc:
            failures.append(repr(exc))

    workers = [threading.Thread(target=run) for _ in range(worker_count)]
    # Start every thread before joining any of them.
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    assert not failures, failures

0 commit comments

Comments
 (0)