Skip to content

Commit 0ff4be2

Browse files
committed
Add BINARY_OP and CALL adaptive specialization
BINARY_OP: Specialize int add/subtract/multiply and float add/subtract/multiply with type guards and deoptimization. CALL: Add func_version to PyFunction, specialize simple function calls (CallPyExactArgs, CallBoundMethodExactArgs) with invoke_exact_args fast path that skips FuncArgs allocation and fill_locals_from_args.
1 parent 111acb5 commit 0ff4be2

File tree

2 files changed

+413
-12
lines changed

2 files changed

+413
-12
lines changed

crates/vm/src/builtins/function.rs

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use super::{
99
use crate::common::lock::OnceCell;
1010
use crate::common::lock::PyMutex;
1111
use crate::function::ArgMapping;
12+
use core::sync::atomic::{AtomicU32, Ordering::Relaxed};
1213
use crate::object::{Traverse, TraverseFn};
1314
use crate::{
1415
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
@@ -72,10 +73,13 @@ pub struct PyFunction {
7273
annotate: PyMutex<Option<PyObjectRef>>,
7374
module: PyMutex<PyObjectRef>,
7475
doc: PyMutex<PyObjectRef>,
76+
func_version: AtomicU32,
7577
#[cfg(feature = "jit")]
7678
jitted_code: OnceCell<CompiledCode>,
7779
}
7880

81+
/// Process-wide source of function version numbers.
///
/// Starts at 1; each new `PyFunction` takes the current value via `fetch_add(1, Relaxed)`,
/// and 0 is used elsewhere in this file as the "invalidated" marker.
// NOTE(review): `AtomicU32::fetch_add` wraps on overflow, so after 2^32 functions the
// counter would hand out 0 and then repeat earlier versions — confirm whether that matters.
static FUNC_VERSION_COUNTER: AtomicU32 = AtomicU32::new(1);
82+
7983
unsafe impl Traverse for PyFunction {
8084
fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
8185
self.globals.traverse(tracer_fn);
@@ -200,6 +204,7 @@ impl PyFunction {
200204
annotate: PyMutex::new(None),
201205
module: PyMutex::new(module),
202206
doc: PyMutex::new(doc),
207+
func_version: AtomicU32::new(FUNC_VERSION_COUNTER.fetch_add(1, Relaxed)),
203208
#[cfg(feature = "jit")]
204209
jitted_code: OnceCell::new(),
205210
};
@@ -592,6 +597,66 @@ impl Py<PyFunction> {
592597
/// Calls this function with the given `func_args`.
///
/// Delegates to `invoke_with_locals`, passing `None` as its second argument
/// (presumably "no explicit locals" — confirm against `invoke_with_locals`).
pub fn invoke(&self, func_args: FuncArgs, vm: &VirtualMachine) -> PyResult {
593598
self.invoke_with_locals(func_args, None, vm)
594599
}
600+
601+
/// Returns the function version, or 0 if invalidated.
602+
#[inline]
603+
pub fn func_version(&self) -> u32 {
604+
self.func_version.load(Relaxed)
605+
}
606+
607+
/// Check if this function is eligible for exact-args call specialization.
608+
/// Returns true if: no VARARGS, no VARKEYWORDS, no kwonly args, not generator/coroutine,
609+
/// and effective_nargs matches co_argcount.
610+
pub(crate) fn can_specialize_call(&self, effective_nargs: u32) -> bool {
611+
let code = self.code.lock();
612+
let flags = code.flags;
613+
!flags.intersects(
614+
bytecode::CodeFlags::VARARGS
615+
| bytecode::CodeFlags::VARKEYWORDS
616+
| bytecode::CodeFlags::GENERATOR
617+
| bytecode::CodeFlags::COROUTINE,
618+
) && code.kwonlyarg_count == 0
619+
&& code.arg_count == effective_nargs
620+
}
621+
622+
/// Fast path for calling a simple function with exact positional args.
623+
/// Skips FuncArgs allocation, prepend_arg, and fill_locals_from_args.
624+
/// Only valid when: no VARARGS, no VARKEYWORDS, no kwonlyargs, not generator/coroutine,
625+
/// and nargs == co_argcount.
626+
pub fn invoke_exact_args(&self, args: &[PyObjectRef], vm: &VirtualMachine) -> PyResult {
627+
let code = self.code.lock().clone();
628+
629+
let locals = ArgMapping::from_dict_exact(vm.ctx.new_dict());
630+
631+
let frame = Frame::new(
632+
code.clone(),
633+
Scope::new(Some(locals), self.globals.clone()),
634+
self.builtins.clone(),
635+
self.closure.as_ref().map_or(&[], |c| c.as_slice()),
636+
Some(self.to_owned().into()),
637+
vm,
638+
)
639+
.into_ref(&vm.ctx);
640+
641+
// Copy args directly into fastlocals
642+
{
643+
let fastlocals = unsafe { frame.fastlocals.borrow_mut() };
644+
for (i, arg) in args.iter().enumerate() {
645+
fastlocals[i] = Some(arg.clone());
646+
}
647+
}
648+
649+
// Handle cell2arg
650+
if let Some(cell2arg) = code.cell2arg.as_deref() {
651+
let fastlocals = unsafe { frame.fastlocals.borrow_mut() };
652+
for (cell_idx, arg_idx) in cell2arg.iter().enumerate().filter(|(_, i)| **i != -1) {
653+
let x = fastlocals[*arg_idx as usize].take();
654+
frame.set_cell_contents(cell_idx, x);
655+
}
656+
}
657+
658+
vm.run_frame(frame)
659+
}
595660
}
596661

597662
impl PyPayload for PyFunction {
@@ -614,12 +679,7 @@ impl PyFunction {
614679
// Setter for `__code__`: replaces the stored code object, then resets the function version.
#[pygetset(setter)]
615680
fn set___code__(&self, code: PyRef<PyCode>) {
616681
*self.code.lock() = code;
617-
// TODO: jit support
618-
// #[cfg(feature = "jit")]
619-
// {
620-
// // If available, clear cached compiled code.
621-
// let _ = self.jitted_code.take();
622-
// }
682+
// Version 0 is the "invalidated" value reported by `func_version`.
self.func_version.store(0, Relaxed);
623683
}
624684

625685
#[pygetset]
@@ -628,7 +688,8 @@ impl PyFunction {
628688
}
629689
#[pygetset(setter)]
630690
fn set___defaults__(&self, defaults: Option<PyTupleRef>) {
631-
self.defaults_and_kwdefaults.lock().0 = defaults
691+
self.defaults_and_kwdefaults.lock().0 = defaults;
692+
self.func_version.store(0, Relaxed);
632693
}
633694

634695
#[pygetset]
@@ -637,7 +698,8 @@ impl PyFunction {
637698
}
638699
#[pygetset(setter)]
639700
fn set___kwdefaults__(&self, kwdefaults: Option<PyDictRef>) {
640-
self.defaults_and_kwdefaults.lock().1 = kwdefaults
701+
self.defaults_and_kwdefaults.lock().1 = kwdefaults;
702+
self.func_version.store(0, Relaxed);
641703
}
642704

643705
// {"__closure__", T_OBJECT, OFF(func_closure), READONLY},

0 commit comments

Comments
 (0)