|
| 1 | +""" |
| 2 | +Lazy bridge to the Julia runtime. |
| 3 | +
|
| 4 | +This module is the single entry point for all the Julia calls of python-sscha. |
| 5 | +The Julia runtime is NOT booted when this module (or sscha) is imported; |
| 6 | +it is initialized on the first call to :func:`get_main`, so that users who do |
| 7 | +not use ``fourier gradient`` never pay the startup cost. |
| 8 | +
|
| 9 | +Two backends are supported: |
| 10 | +
|
| 11 | +* ``juliacall`` (PythonCall.jl) — the default. It has no libpython coupling |
| 12 | + (works with any Python interpreter, no ``python-jl`` needed) and installs |
| 13 | + Julia automatically on a fresh machine through ``juliapkg``. |
| 14 | +* ``pyjulia`` (PyCall.jl) — legacy. It is used only if juliacall is not |
| 15 | + installed, or if another package (e.g. an old python-sscha) has already |
| 16 | + booted the PyJulia runtime in this process: only one Julia runtime can |
| 17 | + exist per process, so in that case we must reuse it. |
| 18 | +
|
| 19 | +The backend can be forced with the environment variable |
| 20 | +``SSCHA_JULIA_BACKEND`` set to ``juliacall``, ``pyjulia`` or ``none``. |
| 21 | +
|
| 22 | +The object returned by :func:`get_main` mimics ``julia.Main`` from PyJulia: |
| 23 | +attribute access gives callables, ``eval`` evaluates a code string and |
| 24 | +``include`` loads a file. Under juliacall, numpy array arguments are |
| 25 | +converted to native Julia ``Array``s (PyJulia semantics), because the |
| 26 | +sscha Julia kernels use strictly-typed signatures that do not dispatch |
| 27 | +on the no-copy ``PyArray`` wrappers, and array results are converted |
| 28 | +back to numpy. |
| 29 | +""" |
| 30 | + |
| 31 | +import importlib.util |
| 32 | +import os |
| 33 | +import sys |
| 34 | +import threading |
| 35 | + |
| 36 | +import numpy as np |
| 37 | + |
| 38 | +# The Julia source files defining the sscha kernels, included at first use. |
| 39 | +_JL_FILES = ["fourier_gradient.jl"] |
| 40 | + |
| 41 | +_BACKEND_ENV = "SSCHA_JULIA_BACKEND" |
| 42 | + |
| 43 | +_lock = threading.RLock() |
| 44 | +_main = None |
| 45 | +_init_error = None |
| 46 | + |
| 47 | + |
| 48 | +class JuliaError(ImportError): |
| 49 | + pass |
| 50 | + |
| 51 | + |
| 52 | +def _requested_backend(): |
| 53 | + backend = os.environ.get(_BACKEND_ENV, "").strip().lower() |
| 54 | + if backend in ("juliacall", "pyjulia", "none"): |
| 55 | + return backend |
| 56 | + return "" |
| 57 | + |
| 58 | + |
| 59 | +def available(): |
| 60 | + """Check if a Julia backend is installed, WITHOUT booting the runtime. |
| 61 | +
|
| 62 | + This is a cheap check (importlib.find_spec). The runtime may still fail |
| 63 | + to initialize at first use (e.g. broken installation); in that case |
| 64 | + :func:`get_main` raises a JuliaError with the details. |
| 65 | + """ |
| 66 | + if _main is not None: |
| 67 | + return True |
| 68 | + if _init_error is not None: |
| 69 | + return False |
| 70 | + |
| 71 | + backend = _requested_backend() |
| 72 | + if backend == "none": |
| 73 | + return False |
| 74 | + if backend == "juliacall": |
| 75 | + return importlib.util.find_spec("juliacall") is not None |
| 76 | + if backend == "pyjulia": |
| 77 | + return importlib.util.find_spec("julia") is not None |
| 78 | + |
| 79 | + if "julia.Main" in sys.modules: |
| 80 | + return True |
| 81 | + return (importlib.util.find_spec("juliacall") is not None |
| 82 | + or importlib.util.find_spec("julia") is not None) |
| 83 | + |
| 84 | + |
| 85 | +def get_main(): |
| 86 | + """Return the (lazily initialized) Julia Main proxy. |
| 87 | +
|
| 88 | + The first call boots the Julia runtime and includes the sscha Julia |
| 89 | + sources; subsequent calls return the cached proxy. Raises JuliaError if |
| 90 | + no working backend is available. |
| 91 | + """ |
| 92 | + global _main, _init_error |
| 93 | + |
| 94 | + if _main is not None: |
| 95 | + return _main |
| 96 | + |
| 97 | + with _lock: |
| 98 | + if _main is not None: |
| 99 | + return _main |
| 100 | + if _init_error is not None: |
| 101 | + raise JuliaError( |
| 102 | + "The Julia extension failed to initialize earlier:\n{}".format( |
| 103 | + _init_error)) |
| 104 | + try: |
| 105 | + _main = _initialize() |
| 106 | + except Exception as e: |
| 107 | + _init_error = "{}: {}".format(type(e).__name__, e) |
| 108 | + raise JuliaError( |
| 109 | + "Could not initialize the Julia extension.\n" |
| 110 | + "Install it with: pip install juliacall\n" |
| 111 | + "(Julia itself is downloaded automatically at first use.)\n" |
| 112 | + "Original error: {}".format(_init_error)) |
| 113 | + return _main |
| 114 | + |
| 115 | + |
| 116 | +def _initialize(): |
| 117 | + backend = _requested_backend() |
| 118 | + if backend == "none": |
| 119 | + raise JuliaError( |
| 120 | + "The Julia extension is disabled ({}=none).".format(_BACKEND_ENV)) |
| 121 | + |
| 122 | + if not backend: |
| 123 | + if "julia.Main" in sys.modules: |
| 124 | + # PyJulia is already running in this process (e.g. booted by |
| 125 | + # python-sscha); a second runtime cannot be created, reuse it. |
| 126 | + backend = "pyjulia" |
| 127 | + elif importlib.util.find_spec("juliacall") is not None: |
| 128 | + backend = "juliacall" |
| 129 | + elif importlib.util.find_spec("julia") is not None: |
| 130 | + backend = "pyjulia" |
| 131 | + else: |
| 132 | + raise JuliaError("Neither juliacall nor pyjulia is installed.") |
| 133 | + |
| 134 | + if backend == "juliacall": |
| 135 | + main = _init_juliacall() |
| 136 | + else: |
| 137 | + main = _init_pyjulia() |
| 138 | + |
| 139 | + dirname = os.path.dirname(os.path.abspath(__file__)) |
| 140 | + for fname in _JL_FILES: |
| 141 | + main.include(os.path.join(dirname, fname)) |
| 142 | + return main |
| 143 | + |
| 144 | + |
| 145 | +def _init_juliacall(): |
| 146 | + # These must be set before the first "import juliacall". |
| 147 | + # PyJulia honored JULIA_NUM_THREADS and defaulted to a single thread: |
| 148 | + # keep exactly that behavior (some kernels use shared buffers inside |
| 149 | + # Threads.@threads loops and are NOT safe with more than one thread). |
| 150 | + n_threads = os.environ.get("JULIA_NUM_THREADS", "1") |
| 151 | + os.environ.setdefault("PYTHON_JULIACALL_THREADS", n_threads) |
| 152 | + if os.environ.get("PYTHON_JULIACALL_THREADS", "1") != "1": |
| 153 | + # Required for safe Julia multithreading from Python. |
| 154 | + os.environ.setdefault("PYTHON_JULIACALL_HANDLE_SIGNALS", "yes") |
| 155 | + |
| 156 | + from juliacall import Main, convert |
| 157 | + return _JuliaCallMain(Main, convert) |
| 158 | + |
| 159 | + |
| 160 | +def _init_pyjulia(): |
| 161 | + import julia |
| 162 | + |
| 163 | + if "julia.Main" not in sys.modules: |
| 164 | + try: |
| 165 | + import julia.Main |
| 166 | + except Exception: |
| 167 | + # Statically linked python or libpython mismatch: PyCall cannot |
| 168 | + # use its precompiled cache. This path is slow (it recompiles |
| 169 | + # PyCall at every launch): prefer installing juliacall. |
| 170 | + from julia.api import Julia |
| 171 | + Julia(compiled_modules=False) |
| 172 | + import julia.Main |
| 173 | + |
| 174 | + return _PyJuliaMain(sys.modules["julia.Main"]) |
| 175 | + |
| 176 | + |
| 177 | +class _PyJuliaMain(object): |
| 178 | + """PyJulia passthrough: julia.Main already speaks numpy.""" |
| 179 | + |
| 180 | + def __init__(self, main): |
| 181 | + self._main = main |
| 182 | + |
| 183 | + def include(self, path): |
| 184 | + return self._main.include(path) |
| 185 | + |
| 186 | + def eval(self, code): |
| 187 | + return self._main.eval(code) |
| 188 | + |
| 189 | + def __getattr__(self, name): |
| 190 | + return getattr(self._main, name) |
| 191 | + |
| 192 | + |
| 193 | +# numpy dtype -> Julia element type, for the nested Vector{Vector{T}} case |
| 194 | +_JL_ELTYPE = { |
| 195 | + "float32": "Float32", |
| 196 | + "float64": "Float64", |
| 197 | + "int32": "Int32", |
| 198 | + "int64": "Int64", |
| 199 | + "complex64": "ComplexF32", |
| 200 | + "complex128": "ComplexF64", |
| 201 | + "bool": "Bool", |
| 202 | +} |
| 203 | + |
| 204 | + |
| 205 | +class _JuliaCallMain(object): |
| 206 | + """juliacall proxy restoring PyJulia argument/return conventions.""" |
| 207 | + |
| 208 | + def __init__(self, main, convert): |
| 209 | + # Avoid __getattr__ recursion: set everything through __dict__. |
| 210 | + self.__dict__["_main"] = main |
| 211 | + self.__dict__["_convert"] = convert |
| 212 | + self.__dict__["_array_type"] = main.seval("Array") |
| 213 | + self.__dict__["_nested_types"] = {} |
| 214 | + |
| 215 | + def include(self, path): |
| 216 | + return self._main.include(path) |
| 217 | + |
| 218 | + def eval(self, code): |
| 219 | + return self._from_julia(self._main.seval(code)) |
| 220 | + |
| 221 | + def __getattr__(self, name): |
| 222 | + func = getattr(self._main, name) |
| 223 | + |
| 224 | + def _call(*args, **kwargs): |
| 225 | + jl_args = [self._to_julia(a) for a in args] |
| 226 | + jl_kwargs = {k: self._to_julia(v) for k, v in kwargs.items()} |
| 227 | + return self._from_julia(func(*jl_args, **jl_kwargs)) |
| 228 | + |
| 229 | + _call.__name__ = name |
| 230 | + return _call |
| 231 | + |
| 232 | + def _to_julia(self, x): |
| 233 | + if isinstance(x, np.ndarray): |
| 234 | + return self._convert(self._array_type, x) |
| 235 | + |
| 236 | + # Lists/tuples of 1d arrays with a common dtype are what the |
| 237 | + # sparse-symmetry initializers expect as Vector{Vector{T}}. |
| 238 | + if (isinstance(x, (list, tuple)) and len(x) > 0 |
| 239 | + and all(isinstance(e, np.ndarray) and e.ndim == 1 for e in x)): |
| 240 | + eltype = _JL_ELTYPE.get(x[0].dtype.name) |
| 241 | + if eltype is not None and all(e.dtype == x[0].dtype for e in x): |
| 242 | + nested = self._nested_types.get(eltype) |
| 243 | + if nested is None: |
| 244 | + nested = self._main.seval( |
| 245 | + "Vector{{Vector{{{}}}}}".format(eltype)) |
| 246 | + self._nested_types[eltype] = nested |
| 247 | + return self._convert(nested, list(x)) |
| 248 | + |
| 249 | + return x |
| 250 | + |
| 251 | + def _from_julia(self, x): |
| 252 | + if isinstance(x, tuple): |
| 253 | + return tuple(self._from_julia(e) for e in x) |
| 254 | + |
| 255 | + import juliacall |
| 256 | + if isinstance(x, juliacall.ArrayValue): |
| 257 | + # Buffer-protocol view on the Julia data; numpy keeps the |
| 258 | + # wrapper alive through ndarray.base. |
| 259 | + return np.asarray(x) |
| 260 | + return x |
0 commit comments