Skip to content

Commit 165b76e

Browse files
mesonepigreco and claude committed
Lazy Julia loading via juliacall: import sscha.Ensemble 200s -> 6s
Companion patch of tdscha's fast_julia_startup. sscha.Ensemble used to boot the Julia runtime through PyJulia at import time; on any machine where PyCall.jl is not built for the exact running libpython, this falls back to Julia(compiled_modules=False) and recompiles PyCall at every Python launch (~3 minutes per import).

- New Modules/JuliaExt.py lazy bridge: juliacall (PythonCall.jl) backend by default (no libpython coupling, auto-installs Julia via the shipped juliapkg.json), legacy pyjulia reused if already booted in the process.
- Ensemble.py: the eager init block is replaced by a lazy stand-in object, so all julia.Main.* call sites work unchanged and the runtime boots at the first fourier-gradient evaluation instead of at import.
- __JULIA_EXT__ keeps its meaning (backend availability, no boot): SchaMinimizer.use_julia and Ensemble.fourier_gradient defaults work as before; no change needed there.
- Julia keeps running single-threaded by default, as with PyJulia (JULIA_NUM_THREADS is honored).
- New optional extra: pip install python-sscha[julia] -> juliacall.
- SSCHA_JULIA_BACKEND env var: juliacall | pyjulia (legacy) | none.

Design and measurements: julia_design.md in the tdscha repository.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
1 parent c4011c5 commit 165b76e

5 files changed

Lines changed: 293 additions & 28 deletions

File tree

Modules/Ensemble.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -76,34 +76,30 @@
7676
__EPSILON__ = 1e-6
7777
__A_TO_BOHR__ = 1.889725989
7878

79-
__JULIA_EXT__ = False
79+
# The Julia runtime is booted lazily by JuliaExt at the first actual use
80+
# (e.g. the first fourier gradient evaluation), so that importing
81+
# sscha.Ensemble stays fast.
82+
import sscha.JuliaExt as JuliaExt
83+
84+
85+
class _LazyJuliaModule(object):
    """Drop-in replacement for the module object of the former ``import julia``.

    Nothing is initialized when this object is created. Reading ``.Main``
    delegates to ``JuliaExt.get_main()``, which boots the Julia runtime and
    includes fourier_gradient.jl on first use, so the existing
    ``julia.Main.xxx(...)`` call sites in this file need no change with
    either the juliacall or the pyjulia backend.
    """

    @property
    def Main(self):
        """Julia ``Main`` proxy; the first access triggers the runtime boot."""
        main_proxy = JuliaExt.get_main()
        return main_proxy
95+
96+
97+
julia = _LazyJuliaModule()
98+
99+
# Deprecated alias kept for backward compatibility: it only tells whether a
100+
# Julia backend is installed, the runtime is not initialized at import time.
101+
__JULIA_EXT__ = JuliaExt.available()
80102
__JULIA_ERROR__ = ""
81-
try:
82-
import julia, julia.Main
83-
julia.Main.include(os.path.join(os.path.dirname(__file__),
84-
"fourier_gradient.jl"))
85-
__JULIA_EXT__ = True
86-
except:
87-
try:
88-
import julia
89-
try:
90-
from julia.api import Julia
91-
jl = Julia(compiled_modules=False)
92-
import julia.Main
93-
julia.Main.include(os.path.join(os.path.dirname(__file__),
94-
"fourier_gradient.jl"))
95-
__JULIA_EXT__ = True
96-
except:
97-
# Install the required modules
98-
julia.install()
99-
try:
100-
julia.Main.include(os.path.join(os.path.dirname(__file__),
101-
"fourier_gradient.jl"))
102-
__JULIA_EXT__ = True
103-
except Exception as e:
104-
warnings.warn("Julia extension not available.\nError: {}".format(e))
105-
except Exception as e:
106-
warnings.warn("Julia extension not available.\nError: {}".format(e))
107103

108104

109105
try:

Modules/JuliaExt.py

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
"""
2+
Lazy bridge to the Julia runtime.
3+
4+
This module is the single entry point for all the Julia calls of python-sscha.
5+
The Julia runtime is NOT booted when this module (or sscha) is imported;
6+
it is initialized on the first call to :func:`get_main`, so that users who do
7+
not use ``fourier gradient`` never pay the startup cost.
8+
9+
Two backends are supported:
10+
11+
* ``juliacall`` (PythonCall.jl) — the default. It has no libpython coupling
12+
(works with any Python interpreter, no ``python-jl`` needed) and installs
13+
Julia automatically on a fresh machine through ``juliapkg``.
14+
* ``pyjulia`` (PyCall.jl) — legacy. It is used only if juliacall is not
15+
installed, or if another package (e.g. an old python-sscha) has already
16+
booted the PyJulia runtime in this process: only one Julia runtime can
17+
exist per process, so in that case we must reuse it.
18+
19+
The backend can be forced with the environment variable
20+
``SSCHA_JULIA_BACKEND`` set to ``juliacall``, ``pyjulia`` or ``none``.
21+
22+
The object returned by :func:`get_main` mimics ``julia.Main`` from PyJulia:
23+
attribute access gives callables, ``eval`` evaluates a code string and
24+
``include`` loads a file. Under juliacall, numpy array arguments are
25+
converted to native Julia ``Array``s (PyJulia semantics), because the
26+
sscha Julia kernels use strictly-typed signatures that do not dispatch
27+
on the no-copy ``PyArray`` wrappers, and array results are converted
28+
back to numpy.
29+
"""
30+
31+
import importlib.util
32+
import os
33+
import sys
34+
import threading
35+
36+
import numpy as np
37+
38+
# The Julia source files defining the sscha kernels, included at first use.
39+
_JL_FILES = ["fourier_gradient.jl"]
40+
41+
_BACKEND_ENV = "SSCHA_JULIA_BACKEND"
42+
43+
_lock = threading.RLock()
44+
_main = None
45+
_init_error = None
46+
47+
48+
class JuliaError(ImportError):
    """Raised when the Julia extension is disabled or cannot be initialized.

    It derives from ImportError, so callers that guard the Julia code path
    with ``except ImportError`` also catch it.
    """
50+
51+
52+
def _requested_backend():
    """Return the backend forced through the environment, if any.

    Reads the environment variable named by ``_BACKEND_ENV``, ignoring
    surrounding whitespace and letter case.

    Returns
    -------
    str
        ``"juliacall"``, ``"pyjulia"`` or ``"none"`` when one of these is
        requested; the empty string when the variable is unset, empty, or
        holds an unrecognized value (automatic detection is then used).

    An unrecognized non-empty value triggers a warning instead of being
    dropped silently, so that a typo does not quietly select another backend.
    """
    raw_value = os.environ.get(_BACKEND_ENV, "")
    backend = raw_value.strip().lower()
    if backend in ("juliacall", "pyjulia", "none"):
        return backend

    if backend:
        # Local import: warnings is only needed on this uncommon path.
        import warnings
        warnings.warn(
            "Ignoring unrecognized {}={!r}: expected 'juliacall', 'pyjulia' "
            "or 'none'. Falling back to automatic backend detection.".format(
                _BACKEND_ENV, raw_value))
    return ""
57+
58+
59+
def _module_installed(name):
    """Tell whether the top-level module ``name`` can be found, without raising.

    ``importlib.util.find_spec`` raises ValueError when the module is already
    in ``sys.modules`` but its ``__spec__`` is unset or None; the module is
    loaded in that case, so it counts as installed.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ValueError:
        return True
    except ImportError:
        return False


def available():
    """Check if a Julia backend is installed, WITHOUT booting the runtime.

    This is a cheap check (importlib.find_spec) and it does not raise when
    the lookup itself fails. The runtime may still fail to initialize at
    first use (e.g. broken installation); in that case :func:`get_main`
    raises a JuliaError with the details.

    Returns
    -------
    bool
        True if the runtime is already initialized, or if the backend forced
        through the environment (or, without a forced backend, any of
        juliacall / pyjulia) is installed. False if a previous
        initialization failed, if the backend is forced to ``none``, or if
        no backend is found.
    """
    # The outcome of an initialization that already happened wins.
    if _main is not None:
        return True
    if _init_error is not None:
        return False

    backend = _requested_backend()
    if backend == "none":
        return False
    if backend == "juliacall":
        return _module_installed("juliacall")
    if backend == "pyjulia":
        return _module_installed("julia")

    # No forced backend: a PyJulia runtime already loaded in this process
    # counts as available, otherwise any installed backend will do.
    if "julia.Main" in sys.modules:
        return True
    return _module_installed("juliacall") or _module_installed("julia")
83+
84+
85+
def get_main():
    """Return the (lazily initialized) Julia Main proxy.

    The first call boots the Julia runtime and includes the sscha Julia
    sources; subsequent calls return the cached proxy.

    Raises
    ------
    JuliaError
        If no working backend is available. The first failure is remembered:
        later calls raise again without retrying the initialization. The
        exception that caused the first failure is attached as
        ``__cause__``.
    """
    global _main, _init_error

    # Fast path once the runtime is up: no lock needed.
    if _main is not None:
        return _main

    with _lock:
        # Check again while holding the lock: another thread may have
        # completed the initialization in the meantime.
        if _main is not None:
            return _main
        if _init_error is not None:
            raise JuliaError(
                "The Julia extension failed to initialize earlier:\n{}".format(
                    _init_error))
        try:
            _main = _initialize()
        except Exception as e:
            _init_error = "{}: {}".format(type(e).__name__, e)
            # "from e" keeps the original traceback as the explicit cause.
            raise JuliaError(
                "Could not initialize the Julia extension.\n"
                "Install it with: pip install juliacall\n"
                "(Julia itself is downloaded automatically at first use.)\n"
                "Original error: {}".format(_init_error)) from e
        return _main
114+
115+
116+
def _initialize():
    """Select a backend, boot it and include the sscha Julia sources.

    The backend forced through the environment is used when set. Otherwise
    the choice is, in order: an already loaded PyJulia runtime, juliacall if
    installed, pyjulia if installed.

    Returns the Main proxy of the selected backend. Raises JuliaError when
    the extension is disabled or no backend is installed.
    """
    choice = _requested_backend()
    if choice == "none":
        raise JuliaError(
            "The Julia extension is disabled ({}=none).".format(_BACKEND_ENV))

    if not choice:
        if "julia.Main" in sys.modules:
            # PyJulia is already running in this process (e.g. booted by
            # python-sscha); a second runtime cannot be created, reuse it.
            choice = "pyjulia"
        else:
            candidates = (("juliacall", "juliacall"), ("pyjulia", "julia"))
            for backend_name, module_name in candidates:
                if importlib.util.find_spec(module_name) is not None:
                    choice = backend_name
                    break
            else:
                raise JuliaError("Neither juliacall nor pyjulia is installed.")

    boot = _init_juliacall if choice == "juliacall" else _init_pyjulia
    proxy = boot()

    # The Julia sources are shipped next to this module.
    here = os.path.dirname(os.path.abspath(__file__))
    for jl_name in _JL_FILES:
        proxy.include(os.path.join(here, jl_name))
    return proxy
143+
144+
145+
def _init_juliacall():
    """Import juliacall and return its Main wrapped in a _JuliaCallMain proxy.

    The juliacall environment variables are filled in (only when not already
    set) before the import, because they must be set before the first
    "import juliacall".
    """
    env = os.environ

    # PyJulia honored JULIA_NUM_THREADS and defaulted to a single thread:
    # keep exactly that behavior (some kernels use shared buffers inside
    # Threads.@threads loops and are NOT safe with more than one thread).
    env.setdefault("PYTHON_JULIACALL_THREADS",
                   env.get("JULIA_NUM_THREADS", "1"))

    multithreaded = env.get("PYTHON_JULIACALL_THREADS", "1") != "1"
    if multithreaded:
        # Required for safe Julia multithreading from Python.
        env.setdefault("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes")

    import juliacall
    return _JuliaCallMain(juliacall.Main, juliacall.convert)
158+
159+
160+
def _init_pyjulia():
    """Return a _PyJuliaMain proxy around PyJulia's ``julia.Main``.

    A ``julia.Main`` already present in ``sys.modules`` is reused. Otherwise
    it is imported, retrying after ``Julia(compiled_modules=False)`` if the
    plain import fails.
    """
    import julia

    if "julia.Main" in sys.modules:
        return _PyJuliaMain(sys.modules["julia.Main"])

    try:
        import julia.Main
    except Exception:
        # Statically linked python or libpython mismatch: PyCall cannot
        # use its precompiled cache. This path is slow (it recompiles
        # PyCall at every launch): prefer installing juliacall.
        from julia.api import Julia
        Julia(compiled_modules=False)
        import julia.Main

    return _PyJuliaMain(sys.modules["julia.Main"])
175+
176+
177+
class _PyJuliaMain(object):
    """PyJulia passthrough: julia.Main already speaks numpy.

    Every call is forwarded unchanged to the wrapped ``julia.Main`` object.
    """

    def __init__(self, main):
        """Wrap ``main``, the PyJulia ``julia.Main`` module object."""
        self._main = main

    def include(self, path):
        """Forward to ``include(path)`` of the wrapped Main."""
        return self._main.include(path)

    def eval(self, code):
        """Forward to ``eval(code)`` of the wrapped Main."""
        return self._main.eval(code)

    def __getattr__(self, name):
        """Forward any other attribute lookup to the wrapped Main.

        Raises AttributeError if the instance has no wrapped Main yet.
        """
        # __getattr__ only runs when the normal lookup failed. If _main
        # itself is missing (instance created without __init__, as copy and
        # pickle do), reading self._main here would re-enter __getattr__
        # until RecursionError; go through __dict__ instead.
        try:
            main = self.__dict__["_main"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(main, name)
191+
192+
193+
# numpy dtype -> Julia element type, for the nested Vector{Vector{T}} case
194+
_JL_ELTYPE = {
195+
"float32": "Float32",
196+
"float64": "Float64",
197+
"int32": "Int32",
198+
"int64": "Int64",
199+
"complex64": "ComplexF32",
200+
"complex128": "ComplexF64",
201+
"bool": "Bool",
202+
}
203+
204+
205+
class _JuliaCallMain(object):
    """juliacall proxy restoring PyJulia argument/return conventions.

    Attribute access returns a callable that converts its numpy arguments
    to Julia arrays, calls the Julia function of that name, and converts
    array results back to numpy.
    """

    def __init__(self, main, convert):
        """Store the juliacall handles.

        main: the juliacall ``Main`` object (``seval`` and ``include`` are
            used, plus attribute access for the Julia functions).
        convert: the juliacall ``convert(T, x)`` function.
        """
        # Avoid __getattr__ recursion: set everything through __dict__.
        self.__dict__["_main"] = main
        self.__dict__["_convert"] = convert
        self.__dict__["_array_type"] = main.seval("Array")
        # Cache of the Vector{Vector{T}} Julia types, keyed by element type.
        self.__dict__["_nested_types"] = {}

    def include(self, path):
        """Include the Julia source file ``path`` in Main."""
        return self._main.include(path)

    def eval(self, code):
        """Evaluate the Julia code string and convert the result."""
        return self._from_julia(self._main.seval(code))

    def __getattr__(self, name):
        """Return a converting wrapper around the Julia object ``name``.

        Raises AttributeError if the instance has no wrapped Main yet.
        """
        # __getattr__ only runs when the normal lookup failed. If _main
        # itself is missing (instance created without __init__, as copy and
        # pickle do), reading self._main here would re-enter __getattr__
        # until RecursionError; go through __dict__ instead.
        try:
            main = self.__dict__["_main"]
        except KeyError:
            raise AttributeError(name) from None
        func = getattr(main, name)

        def _call(*args, **kwargs):
            jl_args = [self._to_julia(a) for a in args]
            jl_kwargs = {k: self._to_julia(v) for k, v in kwargs.items()}
            return self._from_julia(func(*jl_args, **jl_kwargs))

        _call.__name__ = name
        return _call

    def _to_julia(self, x):
        """Convert one Python argument to what the Julia kernels expect.

        numpy arrays become Julia ``Array``s; a non-empty list/tuple of 1d
        numpy arrays sharing a dtype listed in ``_JL_ELTYPE`` becomes a
        ``Vector{Vector{T}}``; anything else is passed through unchanged.
        """
        if isinstance(x, np.ndarray):
            return self._convert(self._array_type, x)

        # Lists/tuples of 1d arrays with a common dtype are what the
        # sparse-symmetry initializers expect as Vector{Vector{T}}.
        if (isinstance(x, (list, tuple)) and len(x) > 0
                and all(isinstance(e, np.ndarray) and e.ndim == 1 for e in x)):
            eltype = _JL_ELTYPE.get(x[0].dtype.name)
            if eltype is not None and all(e.dtype == x[0].dtype for e in x):
                nested = self._nested_types.get(eltype)
                if nested is None:
                    nested = self._main.seval(
                        "Vector{{Vector{{{}}}}}".format(eltype))
                    self._nested_types[eltype] = nested
                return self._convert(nested, list(x))

        return x

    def _from_julia(self, x):
        """Convert one Julia result: arrays to numpy, tuples element-wise."""
        if isinstance(x, tuple):
            return tuple(self._from_julia(e) for e in x)

        import juliacall
        if isinstance(x, juliacall.ArrayValue):
            # Buffer-protocol view on the Julia data; numpy keeps the
            # wrapper alive through ndarray.base.
            return np.asarray(x)
        return x

Modules/juliapkg.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"julia": "^1.10"
3+
}

meson.build

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ py.install_sources([
210210
'Modules/Dynamical.py',
211211
'Modules/Ensemble.py',
212212
# 'Modules/fourier_gradient.jl',
213+
'Modules/JuliaExt.py',
213214
'Modules/LocalCluster.py',
214215
'Modules/Minimizer.py',
215216
'Modules/Optimizer.py',
@@ -241,7 +242,7 @@ py.install_sources([
241242
# Create a 'sscha' subdirectory within the Python installation directory
242243
# and copy the .jl files there.
243244
install_data(
244-
['Modules/fourier_gradient.jl'], # List the .jl files you need
245+
['Modules/fourier_gradient.jl', 'Modules/juliapkg.json'], # List the .jl files you need
245246
install_dir : py.get_install_dir() / 'sscha'
246247
)
247248
# If there are many .jl files in multiple subdirectories,

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ dependencies = [
2323
# entries = { scripts = ["scripts/sscha", "scripts/cluster_check.x", ...] }
2424
# However, Meson is better at handling scripts like install_data
2525

26+
[project.optional-dependencies]
27+
# Fast Julia-accelerated fourier gradients. juliacall installs the Julia
28+
# runtime automatically at first use, no further setup is required.
29+
julia = ["juliacall"]
30+
2631
[project.scripts]
2732
sscha-plot-data="sscha.cli:plot_data"
2833
sscha="sscha.cli:main"

0 commit comments

Comments
 (0)