diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fa8219..bdf38e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,28 @@ +- 0.4.11 + - Add pulse properties for off-resonance +- 0.4.9 & 0.4.10 + - Fix ismrmrd export +- 0.4.8 + - New `backend="pulseq_rs"` option on `Sequence.import_file` that parses + `.seq` files through [pulseq-rs](https://github.com/pulseq-frame/pulseq-rs) + via new PyO3 bindings in the `_prepass` extension; default backend + remains `pydisseqt`. Forwards `larmor_hz`, `fov_scale`, `fov_pos`, + `fov_rot`, and `soft_delays` to the pulseq-rs interpreter. + - New `Sequence.get_adc_labels(name)` returns the pulseq label + (`lin`, `par`, `seg`, `slc`, …) for every measured ADC sample as a 1-D + `int32` tensor — useful for reconstruction. + - New `Sequence.get_label_changes(name)` returns `(rep_index, value)` + pairs at each per-repetition label transition, for splitting a + sequence on label state. Raises if a single repetition has ADC + samples with multiple distinct values for the label. + - `Repetition` gained an `adc_labels: dict[str, torch.Tensor]` attribute + populated by the pulseq-rs backend; empty for the pydisseqt backend. +- 0.4.7 + - Bumped minimum Python to 3.10 + - Added `pp14` extra (`pip install mrzerocore[pp14]`) that pins + `pypulseq < 1.5` + - Bugfix: removed a double degrees→radians conversion in the + `mr0_TSE_2D_multi_shot_seq` playground notebook - 0.4.6 - Small bugfixes - restrict pypulseq dependency to < 1.5.0 diff --git a/Cargo.lock b/Cargo.lock index 99ac838..a4e86c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -48,6 +48,12 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + [[package]] name = "memoffset" version = "0.9.1" @@ -59,9 +65,10 @@ dependencies = [ [[package]] name = "mrzero_core" -version = "0.4.6" +version = "0.4.11" dependencies = [ "num-complex", + "pulseq-rs", "pyo3", ] @@ -120,13 +127,23 @@ checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" [[package]] name = "proc-macro2" -version = "1.0.83" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b33eb56c327dec362a9e55b3ad14f9d2f0904fb5a5b03b513ab5465399e9f43" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" dependencies = [ "unicode-ident", ] +[[package]] +name = "pulseq-rs" +version = "0.2.0" +source = "git+https://github.com/pulseq-frame/pulseq-rs.git#865a890edb1a68fe91f9581fe22539f6e53f9990" +dependencies = [ + "num-complex", + "thiserror", + "winnow", +] + [[package]] name = "pyo3" version = "0.21.2" @@ -223,9 +240,9 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "syn" -version = "2.0.65" +version = "2.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2863d96a84c6439701d7a38f9de935ec562c8832cc55d1dde0f513b52fad106" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" dependencies = [ "proc-macro2", "quote", @@ -238,6 +255,26 @@ version = "0.12.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "unicode-ident" version = "1.0.12" @@ -313,3 +350,12 @@ name = "windows_x86_64_msvc" version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" + +[[package]] +name = "winnow" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ee1708bef14716a11bae175f579062d4554d95be2c6829f518df847b7b3fdd0" +dependencies = [ + "memchr", +] diff --git a/Cargo.toml b/Cargo.toml index dade8c7..56c3824 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mrzero_core" -version = "0.4.6" +version = "0.4.11" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -10,3 +10,4 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version = "0.21.2", features = ["abi3-py37", "extension-module", "num-complex"] } num-complex = "0.4.6" +pulseq-rs = { git = "https://github.com/pulseq-frame/pulseq-rs.git" } diff --git a/pyproject.toml b/pyproject.toml index 85efeea..9468db0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "mrzerocore" -version = "0.4.6" +version = "0.4.11" description = "Core functionality of MRzero" authors = [ {name = "Jonathan Endres", email = "jonathan.endres@uk-erlangen.de"}, @@ -16,12 +16,12 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", "License :: OSI Approved :: GNU Affero General Public License v3", ] -requires-python = ">=3.9" +requires-python = ">=3.10" dependencies = [ "ismrmrd", "matplotlib>=3.5", "pydisseqt>=0.1.13", - "pypulseq<1.5.0", + "pypulseq<1.5", "requests>=2.20", "scikit-image", "scipy>=1.7", @@ -31,6 +31,11 @@ dependencies = [ "nibabel", ] +# --- OPTIONAL INSTALLS --- +[project.optional-dependencies] +# to select pypulseq 1.4 or <1.5 use pip install mrzerocore[pp14] +pp15 = ["pypulseq"] + [project.urls] Repository = "https://github.com/MRsources/MRzero-Core" Documentation = "https://mrsources.github.io/MRzero-Core/" @@ -41,3 +46,4 @@ profile = "release" strip = true module-name = "MRzeroCore._prepass" python-source = "python" + diff --git a/python/MRzeroCore/sequence.py b/python/MRzeroCore/sequence.py index 5d6078a..1a5bb4e 100644 --- a/python/MRzeroCore/sequence.py +++ b/python/MRzeroCore/sequence.py @@ -1,9 +1,10 @@ from __future__ import annotations from time import time +import warnings import torch import numpy as np from enum import Enum -from typing import Iterable, Optional +from typing import Iterable, Literal, Optional import matplotlib.pyplot as plt # TODO: if everything is working, deprecate old pulseq loader @@ -55,6 +56,16 @@ class Pulse: Flip angle in radians phase : torch.Tensor Pulse phase in radians + pulse_freq: torch.Tensor, to be removed in the future! + pulse frequency omega_1 = angle/duration + freq_offset: torch.Tensor + Frequency offset in Hz + duration: torch.Tensor + pulse duration in seconds + grad: torch.Tensor (dim=3) + gradient during the pulse in Hz/m per channel (x,y,z) + off_ress: bool + Specifies if the pulse should be simulated with the off-resonance treatment shim_array : torch.Tensor Contains B1 mag and phase, used for pTx. 2D tensor([[1, 0]]) for 1Tx. selective : bool @@ -66,6 +77,13 @@ def __init__( usage: PulseUsage, angle: torch.Tensor, phase: torch.Tensor, + + pulse_freq: torch.Tensor, # to be removed in the future + freq_offset: torch.Tensor, + duration: torch.Tensor, + grad: torch.Tensor, + off_res: bool, + shim_array: torch.Tensor, selective: bool, ): @@ -73,6 +91,13 @@ def __init__( self.usage = usage self.angle = angle self.phase = phase + + self.pulse_freq = pulse_freq # to be removed in the future + self.freq_offset = freq_offset + self.duration = duration + self.grad = grad + self.off_res = off_res + self.shim_array = shim_array self.selective = selective @@ -82,6 +107,13 @@ def cpu(self) -> Pulse: self.usage, torch.as_tensor(self.angle, dtype=torch.float32).cpu(), torch.as_tensor(self.phase, dtype=torch.float32).cpu(), + + torch.as_tensor(self.pulse_freq, dtype=torch.float32).cpu(), # to be removed in the future + torch.as_tensor(self.freq_offset, dtype=torch.float32).cpu(), + torch.as_tensor(self.duration, dtype=torch.float32).cpu(), + torch.as_tensor(self.grad, dtype=torch.float32).cpu(), + self.off_res, + torch.as_tensor(self.shim_array, dtype=torch.float32).cpu(), self.selective ) @@ -92,6 +124,13 @@ def cuda(self, device: int | None = None) -> Pulse: self.usage, torch.as_tensor(self.angle, dtype=torch.float32).cuda(device), torch.as_tensor(self.phase, dtype=torch.float32).cuda(device), + + torch.as_tensor(self.pulse_freq, dtype=torch.float32).cuda(device), # to be removed in the future + torch.as_tensor(self.freq_offset, dtype=torch.float32).cuda(device), + torch.as_tensor(self.duration, dtype=torch.float32).cuda(device), + torch.as_tensor(self.grad, dtype=torch.float32).cuda(device), + self.off_res, + torch.as_tensor(self.shim_array, dtype=torch.float32).cuda(device), self.selective ) @@ -108,6 +147,11 @@ def zero(cls): PulseUsage.UNDEF, torch.zeros(1, dtype=torch.float32), torch.zeros(1, dtype=torch.float32), + torch.zeros(1, dtype=torch.float32), # to be removed in the future + torch.zeros(1, dtype=torch.float32), + torch.zeros(1, dtype=torch.float32), + torch.zeros(3, dtype=torch.float32), + False, torch.asarray([[1, 0]], dtype=torch.float32), True ) @@ -118,6 +162,13 @@ def clone(self) -> Pulse: self.usage, self.angle.clone(), self.phase.clone(), + + self.pulse_freq.clone(), # to be removed in the future + self.freq_offset.clone(), + self.duration.clone(), + self.grad.clone(), + self.off_res, + self.shim_array.clone(), self.selective ) @@ -190,6 +241,11 @@ def __init__( self.gradm = gradm self.adc_phase = adc_phase self.adc_usage = adc_usage + # Per-event integer labels (pulseq LABELSET/LABELINC state at the time + # each ADC fires). Populated by the pulseq_rs importer; empty otherwise. + # Each tensor has shape (event_count,); values at non-ADC events are + # meaningless and should be masked by adc_usage > 0. + self.adc_labels: dict[str, torch.Tensor] = {} def cuda(self, device: int | None = None) -> Repetition: """Move this repetition to the specified CUDA device and return it.""" @@ -415,6 +471,96 @@ def get_duration(self) -> float: """Calculate the total duration of self in seconds.""" return sum(rep.event_time.sum().item() for rep in self) + def get_adc_labels(self, name: str) -> torch.Tensor: + """Return the values of a pulseq label for every measured ADC sample. + + The label is read from each event with ``adc_usage > 0`` and the + result is concatenated across all repetitions, giving a 1-D + ``int32`` tensor whose length equals the total number of measured + samples in the sequence. Useful for reconstruction code that needs + the per-sample ``lin`` / ``par`` / ``slc`` / … indices. + + Parameters + ---------- + name : str + Label name as used by pulseq: one of ``"slc"``, ``"seg"``, + ``"rep"``, ``"avg"``, ``"set"``, ``"eco"``, ``"phs"``, ``"lin"``, + ``"par"``, ``"acq"`` (counters), or ``"nav"``, ``"rev"``, + ``"sms"``, ``"ref"``, ``"ima"``, ``"off"``, ``"noise"`` (flags, + returned as ``0`` / ``1``). + + Raises + ------ + KeyError + If the sequence has no labels (i.e. it wasn't imported with the + ``pulseq_rs`` backend) or the requested name doesn't exist. + """ + chunks: list[torch.Tensor] = [] + for rep in self: + if name not in rep.adc_labels: + raise KeyError( + f"label {name!r} not available on this sequence; ensure " + f"it was imported with backend='pulseq_rs'" + ) + mask = rep.adc_usage > 0 + chunks.append(rep.adc_labels[name][mask]) + if not chunks: + return torch.zeros(0, dtype=torch.int32) + return torch.cat(chunks) + + def get_label_changes(self, name: str) -> list[tuple[int, int]]: + """Return repetition-wise changes of a pulseq label. + + Walks the sequence in order, collapses each repetition to the single + value of ``name`` across all of its ADC samples, and emits a + ``(rep_index, new_value)`` pair whenever that value differs from the + previous one. The first repetition that carries an ADC is always + emitted as the baseline. + + Repetitions without any ADC are skipped (they have no opinion on + the label state). If a single repetition contains ADCs with + different values for ``name``, this is treated as illegal — the + intended use case is splitting the sequence on label changes, which + requires one value per repetition. + + Parameters + ---------- + name : str + See :meth:`get_adc_labels`. + + Raises + ------ + ValueError + If any repetition contains multiple distinct values for ``name`` + among its ADC samples. + KeyError + If the label is not available on this sequence. + """ + changes: list[tuple[int, int]] = [] + prev: Optional[int] = None + for i, rep in enumerate(self): + if name not in rep.adc_labels: + raise KeyError( + f"label {name!r} not available on this sequence; ensure " + f"it was imported with backend='pulseq_rs'" + ) + mask = rep.adc_usage > 0 + if not mask.any(): + continue + vals = rep.adc_labels[name][mask] + uniq = torch.unique(vals) + if uniq.numel() != 1: + raise ValueError( + f"label {name!r} has multiple values " + f"{uniq.tolist()} within repetition {i}; cannot collapse " + f"to a single per-rep value" + ) + v = int(uniq.item()) + if v != prev: + changes.append((i, v)) + prev = v + return changes + @classmethod def import_file(cls, file_name: str, exact_trajectories: bool = True, @@ -422,6 +568,12 @@ def import_file(cls, file_name: str, default_shim: torch.Tensor = torch.asarray([[1, 0]], dtype=torch.float32), ref_voltage: float = 300.0, resolution: Optional[int] = None, + backend: Literal["pydisseqt", "pulseq_rs"] = "pydisseqt", + larmor_hz: Optional[float] = None, + fov_scale: Optional[float] = None, + fov_pos: Optional[tuple[float, float, float]] = None, + fov_rot: Optional[tuple[float, float, float, float]] = None, + soft_delays: Optional[dict[str, float]] = None, ) -> Sequence: """Import a pulseq .seq file or a bundle of .dsv files. @@ -446,12 +598,51 @@ def import_file(cls, file_name: str, .dsv files do not contain data for the number of ADC samples. This is used to specify the number of samples per ADC block. If false, uses the .dsv time step as ADC dwell time + backend : "pydisseqt" | "pulseq_rs" + Parser used to read the .seq file. ``"pydisseqt"`` (default) is the + legacy path. ``"pulseq_rs"`` uses the Rust pulseq-rs parser via the + bundled extension; .dsv files always fall back to pydisseqt. + larmor_hz, fov_scale, fov_pos, fov_rot, soft_delays + Forwarded to the pulseq-rs interpreter when ``backend="pulseq_rs"``; + ignored for the pydisseqt backend. Returns ------- mr0.Sequence The imported file as mr0 Sequence """ + if not file_name.endswith(".seq") and backend == "pulseq_rs": + warnings.warn( + "pulseq_rs backend only supports .seq files; falling back " + "to pydisseqt for DSV input.", + stacklevel=2, + ) + backend = "pydisseqt" + + if backend == "pulseq_rs": + return cls._import_pulseq_rs( + file_name, exact_trajectories, print_stats, default_shim, + larmor_hz=larmor_hz, fov_scale=fov_scale, + fov_pos=fov_pos, fov_rot=fov_rot, soft_delays=soft_delays, + ) + elif backend == "pydisseqt": + return cls._import_pydisseqt( + file_name, exact_trajectories, print_stats, default_shim, + ref_voltage, resolution, + ) + else: + raise ValueError( + f"unknown backend {backend!r}; expected 'pydisseqt' or 'pulseq_rs'" + ) + + @classmethod + def _import_pydisseqt(cls, file_name: str, + exact_trajectories: bool, + print_stats: bool, + default_shim: torch.Tensor, + ref_voltage: float, + resolution: Optional[int], + ) -> Sequence: start = time() if file_name.endswith(".seq"): parser = pydisseqt.load_pulseq(file_name) @@ -488,6 +679,10 @@ def pulse_usage(angle: float) -> PulseUsage: # Fetch additional data needed for building the mr0 sequence pulse = parser.integrate_one(pulses[i][0], pulses[i][1]).pulse shim = parser.sample_one(rep_start).pulse.shim + + # load pulse frequency-offset needed for potential treatment off off-resonance + frequency = parser.sample_one(rep_start).pulse.amplitude # this only works for block pulses! should be removed in the future + frequency_offset = parser.sample_one(rep_start).pulse.frequency adcs = parser.events("adc", rep_start, rep_end) @@ -549,17 +744,25 @@ def pulse_usage(angle: float) -> PulseUsage: ) # -- Now we build the mr0 Sequence repetition -- - rep = seq.new_rep(event_count) + rep.event_time[:] = torch.as_tensor(np.diff(abs_times)) + rep.pulse.angle = pulse.angle - rep.pulse.phase = pulse.phase + rep.pulse.phase = pulse.phase + + # provide frequency and frequency-offset to pulse object needed for potential treatment off off-resonance + rep.pulse.pulse_freq = 2*torch.pi * frequency # rad/s # may only work for block-pulses + rep.pulse.freq_offset = frequency_offset # Hz + if rep.pulse.freq_offset != 0: + rep.pulse.off_res = True + rep.pulse.usage = pulse_usage(pulse.angle) if shim is None: rep.pulse.shim_array = default_shim.clone() else: rep.pulse.shim_array = torch.as_tensor(shim) - rep.event_time[:] = torch.as_tensor(np.diff(abs_times)) + #rep.event_time[:] = torch.as_tensor(np.diff(abs_times)) rep.gradm[:, 0] = torch.as_tensor(moments.gradient.x) rep.gradm[:, 1] = torch.as_tensor(moments.gradient.y) @@ -574,6 +777,247 @@ def pulse_usage(angle: float) -> PulseUsage: print(f"Converting the sequence to mr0 took {time() - start} s") return seq + """credits @ Claude""" + @classmethod + def _import_pulseq_rs(cls, file_name: str, + exact_trajectories: bool, + print_stats: bool, + default_shim: torch.Tensor, + *, + larmor_hz: Optional[float] = None, + fov_scale: Optional[float] = None, + fov_pos: Optional[tuple[float, float, float]] = None, + fov_rot: Optional[tuple[float, float, float, float]] = None, + soft_delays: Optional[dict[str, float]] = None, + ) -> Sequence: + from . import _prepass # local import to avoid touching module init order + + start = time() + interp = _prepass.load_pulseq_rs( + file_name, + larmor_hz=larmor_hz, + fov_scale=fov_scale, + fov_pos=fov_pos, + fov_rot=fov_rot, + soft_delays=soft_delays, + ) + if print_stats: + print(f"Importing the .seq file took {time() - start} s") + start = time() + + blocks = list(interp.blocks) + duration = interp.duration + + # Cache shape-times arrays per block so events_axis_in() doesn't keep + # rebuilding them through the FFI. + block_starts = [b.start for b in blocks] + block_ends = [b.start + b.duration for b in blocks] + grad_breakpoints = { # absolute times of every breakpoint per (block, axis) + "x": [None] * len(blocks), + "y": [None] * len(blocks), + "z": [None] * len(blocks), + } + for j, b in enumerate(blocks): + for axis in ("x", "y", "z"): + g = getattr(b, "g" + axis) + if g is None: + continue + t_off = b.start + g.delay + grad_breakpoints[axis][j] = [t_off + t for t in g.shape_times()] + + adc_times_per_block = [None] * len(blocks) + adc_phases_per_block = [None] * len(blocks) + # Labels live at the block (ADC) level, so one dict per ADC block is + # enough — every sample inside that block shares the same snapshot. + adc_labels_per_block: list[Optional[dict[str, int]]] = [None] * len(blocks) + for j, b in enumerate(blocks): + if b.adc is None: + continue + adc_times_per_block[j] = [b.start + t for t in b.adc.sample_times()] + adc_phases_per_block[j] = list(b.adc.sample_phases()) + adc_labels_per_block[j] = b.adc.labels() + + label_names: list[str] = [] + for lbl in adc_labels_per_block: + if lbl is not None: + label_names = list(lbl.keys()) + break + + def pulse_usage(angle: float) -> PulseUsage: + if abs(angle) < 100 * np.pi / 180: + return PulseUsage.EXCIT + else: + return PulseUsage.REFOC + + def events_axis_in(t0: float, t1: float, axis: str) -> list[float]: + out: list[float] = [] + bp = grad_breakpoints[axis] + for j in range(len(blocks)): + if block_starts[j] >= t1 or block_ends[j] <= t0: + continue + times = bp[j] + if times is None: + continue + for t in times: + if t0 <= t <= t1: + out.append(t) + return out + + def adcs_in(t0: float, t1: float) -> tuple[ + list[float], list[float], list[dict[str, int]] + ]: + times_out: list[float] = [] + phases_out: list[float] = [] + labels_out: list[dict[str, int]] = [] + for j in range(len(blocks)): + if block_starts[j] >= t1 or block_ends[j] <= t0: + continue + ts = adc_times_per_block[j] + if ts is None: + continue + ph = adc_phases_per_block[j] + lbl = adc_labels_per_block[j] + assert ph is not None and lbl is not None + for k, t in enumerate(ts): + if t0 <= t <= t1: + times_out.append(t) + phases_out.append(ph[k]) + labels_out.append(lbl) + return times_out, phases_out, labels_out + + def integrate_axis(axis: str, t0: float, t1: float) -> float: + if t1 <= t0: + return 0.0 + moment = 0.0 + for j in range(len(blocks)): + if block_starts[j] >= t1 or block_ends[j] <= t0: + continue + g = getattr(blocks[j], "g" + axis) + if g is None: + continue + lo = max(t0, block_starts[j]) - block_starts[j] + hi = min(t1, block_ends[j]) - block_starts[j] + if hi > lo: + moment += g.integrate(lo, hi) + return moment + + seq = cls(normalized_grads=False) + + # Discover RF pulses: each RF-bearing block produces one (start, end). + pulses: list[tuple[float, float, int]] = [] # (pulse_start, pulse_end, block_idx) + for j, b in enumerate(blocks): + if b.rf is None: + continue + ps = b.start + b.rf.delay + pe = ps + b.rf.shape_duration + pulses.append((ps, pe, j)) + + for i in range(len(pulses)): + ps, pe, b_idx = pulses[i] + rep_start = 0.5 * (ps + pe) + if i + 1 < len(pulses): + rep_end = 0.5 * (pulses[i + 1][0] + pulses[i + 1][1]) + else: + rep_end = duration + + rf = blocks[b_idx].rf + angle, phase = rf.integrate(0.0, rf.shape_duration) + + # Shim handling: pulseq-rs always returns a list, with [(1.0, 0.0)] + # meaning "no shim" - in that case fall back to the default. + shims = rf.shims + if (len(shims) == 1 + and abs(shims[0][0] - 1.0) < 1e-12 + and abs(shims[0][1]) < 1e-12): + shim_arr = default_shim.clone() + else: + shim_arr = torch.as_tensor(shims, dtype=torch.float32) + + adcs, adc_phases, adc_label_snapshots = adcs_in(rep_start, rep_end) + + if exact_trajectories: + first = pe + last = pulses[i + 1][0] if i + 1 < len(pulses) else rep_end + eps = 1e-6 + precision = 6 + + if len(adcs) > 0: + grad_before = sorted({round(t, precision) for t in ( + events_axis_in(first + eps, adcs[0] - eps, "x") + + events_axis_in(first + eps, adcs[0] - eps, "y") + + events_axis_in(first + eps, adcs[0] - eps, "z") + )}) + grad_after = sorted({round(t, precision) for t in ( + events_axis_in(adcs[-1] + eps, last - eps, "x") + + events_axis_in(adcs[-1] + eps, last - eps, "y") + + events_axis_in(adcs[-1] + eps, last - eps, "z") + )}) + if i == len(pulses) - 1: + abs_times = [rep_start, first] + grad_before + adcs + else: + abs_times = ([rep_start, first] + grad_before + adcs + + grad_after + [last, rep_end]) + adc_start = 2 + len(grad_before) - 1 + else: + grad = sorted({round(t, precision) for t in ( + events_axis_in(first + eps, last - eps, "x") + + events_axis_in(first + eps, last - eps, "y") + + events_axis_in(first + eps, last - eps, "z") + )}) + if i == len(pulses) - 1: + abs_times = [rep_start, first] + grad + else: + abs_times = [rep_start, first] + grad + [last, rep_end] + adc_start = None + else: + abs_times = [rep_start] + adcs + [rep_end] + adc_start = 0 + + event_count = len(abs_times) - 1 + + if print_stats: + print( + f"Rep. {i + 1}: {event_count} samples, of which " + f"{len(adcs)} are ADC (starting at {adc_start})" + ) + + mom_x = np.empty(event_count, dtype=np.float64) + mom_y = np.empty(event_count, dtype=np.float64) + mom_z = np.empty(event_count, dtype=np.float64) + for k in range(event_count): + t0 = abs_times[k] + t1 = abs_times[k + 1] + mom_x[k] = integrate_axis("x", t0, t1) + mom_y[k] = integrate_axis("y", t0, t1) + mom_z[k] = integrate_axis("z", t0, t1) + + rep = seq.new_rep(event_count) + rep.pulse.angle = torch.as_tensor(angle) + rep.pulse.phase = torch.as_tensor(phase) + rep.pulse.usage = pulse_usage(angle) + rep.pulse.shim_array = shim_arr + + rep.event_time[:] = torch.as_tensor(np.diff(abs_times)) + rep.gradm[:, 0] = torch.as_tensor(mom_x) + rep.gradm[:, 1] = torch.as_tensor(mom_y) + rep.gradm[:, 2] = torch.as_tensor(mom_z) + + if adc_start is not None: + phases_t = np.pi / 2 - torch.as_tensor(adc_phases) + rep.adc_usage[adc_start:adc_start + len(adcs)] = 1 + rep.adc_phase[adc_start:adc_start + len(adcs)] = phases_t + + if label_names and adc_label_snapshots and adc_start is not None: + for name in label_names: + t = torch.zeros(event_count, dtype=torch.int32) + for k, snap in enumerate(adc_label_snapshots): + t[adc_start + k] = snap[name] + rep.adc_labels[name] = t + + if print_stats: + print(f"Converting the sequence to mr0 took {time() - start} s") + return seq + @classmethod def from_seq_file(cls, file_name: str) -> Sequence: """Import a sequence from a pulseq .seq file. diff --git a/python/MRzeroCore/simulation/main_pass.py b/python/MRzeroCore/simulation/main_pass.py index cc5e0cf..0f39ed3 100644 --- a/python/MRzeroCore/simulation/main_pass.py +++ b/python/MRzeroCore/simulation/main_pass.py @@ -5,6 +5,7 @@ from ..phantom.sim_data import SimData from .pre_pass import Graph import numpy as np +from tqdm.auto import tqdm # NOTE: return encoding and magnetization is currently missing. If we want to @@ -103,9 +104,10 @@ def execute_graph(graph: Graph, graph[0][0].kt_vec = torch.zeros(4, device=data.device) mag_adc = [] - for i, (dists, rep) in enumerate(zip(graph[1:], seq)): + pbar = tqdm(total=len(seq), desc="Calculating repetitions", disable=not print_progress) + for dists, rep in zip(graph[1:], seq): if print_progress: - print(f"\rCalculating repetition {i+1} / {len(seq)}", end='') + pbar.update(1) angle = torch.as_tensor(rep.pulse.angle) phase = torch.as_tensor(rep.pulse.phase) @@ -278,7 +280,8 @@ def calc_mag(ancestor: tuple) -> torch.Tensor: ancestor[1].mag = None if print_progress: - print(" - done") + pbar.close() + print('Done.') if return_mag_adc: return torch.cat(signal), mag_adc diff --git a/python/MRzeroCore/simulation/sig_to_mrd.py b/python/MRzeroCore/simulation/sig_to_mrd.py index 36b2aed..45aec2b 100644 --- a/python/MRzeroCore/simulation/sig_to_mrd.py +++ b/python/MRzeroCore/simulation/sig_to_mrd.py @@ -264,32 +264,33 @@ def s_to_ms(x_in_s): seq_fov = [m_to_mm(val) for val in seq.definitions.get("FOV", [1, 1, 1])] seq_labels = seq.evaluate_labels(evolution="adc") - mrd_enc_params = ismrmrd.xsd.encodingType() - - mrd_enc_params.encodedSpace = ismrmrd.xsd.encodingSpaceType( - matrixSize=ismrmrd.xsd.matrixSizeType( - x=seq_res[0], - y=seq_res[1], - z=seq_res[2], - ), - fieldOfView_mm=ismrmrd.xsd.fieldOfViewMm( - x=seq_fov[0], - y=seq_fov[1], - z=seq_fov[2], - ), - ) - - mrd_enc_params.reconSpace = ismrmrd.xsd.encodingSpaceType( - matrixSize=ismrmrd.xsd.matrixSizeType( - x=seq_res[0], - y=seq_res[1], - z=seq_res[2], + mrd_enc_params = ismrmrd.xsd.encodingType( + encodedSpace=ismrmrd.xsd.encodingSpaceType( + matrixSize=ismrmrd.xsd.matrixSizeType( + x=seq_res[0], + y=seq_res[1], + z=seq_res[2], + ), + fieldOfView_mm=ismrmrd.xsd.fieldOfViewMm( + x=seq_fov[0], + y=seq_fov[1], + z=seq_fov[2], + ), ), - fieldOfView_mm=ismrmrd.xsd.fieldOfViewMm( - x=seq_fov[0], - y=seq_fov[1], - z=seq_fov[2], + reconSpace=ismrmrd.xsd.encodingSpaceType( + matrixSize=ismrmrd.xsd.matrixSizeType( + x=seq_res[0], + y=seq_res[1], + z=seq_res[2], + ), + fieldOfView_mm=ismrmrd.xsd.fieldOfViewMm( + x=seq_fov[0], + y=seq_fov[1], + z=seq_fov[2], + ), ), + encodingLimits=_labels_to_encodinglimits(seq_labels), + trajectory=ismrmrd.xsd.trajectoryType.OTHER ) if verbose > 4: @@ -299,16 +300,17 @@ def s_to_ms(x_in_s): print( f"Wrote encode/recon RES to mrd header with {mrd_enc_params.encodedSpace.fieldOfView_mm}/{mrd_enc_params.reconSpace.fieldOfView_mm}" ) - - mrd_enc_params.encodingLimits = _labels_to_encodinglimits(seq_labels) # The Lamour frequency is a required field in the ISMRMRD header - exp = ismrmrd.xsd.experimentalConditionsType() - exp.H1resonanceFrequency_Hz = int(seq.system.B0 * 42.5764 * 1e6) + exp = ismrmrd.xsd.experimentalConditionsType( + H1resonanceFrequency_Hz=int(seq.system.B0 * 42.5764 * 1e6) + ) mrd_head = ismrmrd.xsd.ismrmrdHeader( experimentalConditions=exp, - measurementInformation=ismrmrd.xsd.measurementInformationType(), + measurementInformation=ismrmrd.xsd.measurementInformationType( + patientPosition=ismrmrd.xsd.patientPositionType.HFS + ), acquisitionSystemInformation=ismrmrd.xsd.acquisitionSystemInformationType(), sequenceParameters=mrd_seq_params, encoding=[mrd_enc_params], diff --git a/src/lib.rs b/src/lib.rs index eb63e41..e006ffd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ use pyo3::prelude::*; use pyo3::types::*; use pyo3::{wrap_pyfunction, PyTraverseError}; mod pre_pass; +mod seq_import; use pre_pass::{comp_graph, Repetition}; use std::{collections::HashMap, slice::from_raw_parts, time::Instant}; @@ -288,5 +289,12 @@ fn _prepass(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(compute_graph, m)?)?; m.add_class::()?; + m.add_function(wrap_pyfunction!(seq_import::load_pulseq_rs, m)?)?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + Ok(()) } diff --git a/src/seq_import.rs b/src/seq_import.rs new file mode 100644 index 0000000..6eeb972 --- /dev/null +++ b/src/seq_import.rs @@ -0,0 +1,513 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use num_complex::Complex64; +use pulseq_rs::int::{self, Quaternion, Transform}; +use pulseq_rs::raw::RfUse; +use pulseq_rs::seq; +use pyo3::exceptions::{PyRuntimeError, PyValueError}; +use pyo3::prelude::*; + +// 1H @ 3 T - MR-zero is usually run at 3 T. +const DEFAULT_LARMOR_HZ: f64 = 3.0 * 42_577_468.8; + +/// Interpreted pulseq sequence: a list of blocks with timing in absolute seconds. +/// Each block carries its own start time so Python doesn't have to re-scan. +#[pyclass(module = "_prepass")] +pub struct PyInterpSeq { + #[pyo3(get)] + name: Option, + #[pyo3(get)] + fov: [f64; 3], + #[pyo3(get)] + duration: f64, + #[pyo3(get)] + blocks: Py, +} + +#[pymethods] +impl PyInterpSeq { + fn __len__(&self, py: Python) -> usize { + self.blocks.as_ref(py).len() + } + + fn __repr__(&self) -> String { + format!( + "PyInterpSeq(name={:?}, fov={:?}, duration={} s, n_blocks={})", + self.name, + self.fov, + self.duration, + Python::with_gil(|py| self.blocks.as_ref(py).len()) + ) + } +} + +#[pyclass(module = "_prepass")] +pub struct PyBlock { + #[pyo3(get)] + start: f64, + #[pyo3(get)] + duration: f64, + #[pyo3(get)] + rf: Option>, + #[pyo3(get)] + gx: Option>, + #[pyo3(get)] + gy: Option>, + #[pyo3(get)] + gz: Option>, + #[pyo3(get)] + adc: Option>, +} + +#[pymethods] +impl PyBlock { + fn __repr__(&self) -> String { + let mut parts: Vec<&str> = Vec::new(); + if self.rf.is_some() { + parts.push("rf"); + } + if self.gx.is_some() { + parts.push("gx"); + } + if self.gy.is_some() { + parts.push("gy"); + } + if self.gz.is_some() { + parts.push("gz"); + } + if self.adc.is_some() { + parts.push("adc"); + } + format!( + "PyBlock(start={}, duration={}, events=[{}])", + self.start, + self.duration, + parts.join(", ") + ) + } +} + +#[pyclass(module = "_prepass")] +pub struct PyRf { + #[pyo3(get)] + amp: f64, + #[pyo3(get)] + phase: f64, + #[pyo3(get)] + freq: f64, + #[pyo3(get)] + delay: f64, + #[pyo3(get)] + center: f64, + #[pyo3(get)] + rf_use: &'static str, + #[pyo3(get)] + shape_duration: f64, + /// Per-channel shim weights, each as (magnitude, phase [rad]). A missing + /// shim is exposed as `[(1.0, 0.0)]`, matching pulseq-rs. + #[pyo3(get)] + shims: Vec<(f64, f64)>, + shape: Arc>, +} + +#[pymethods] +impl PyRf { + /// Sparse breakpoint times of the underlying complex shape, in seconds. + fn shape_times(&self) -> Vec { + self.shape.time.clone() + } + + /// Complex shape values at the breakpoints, returned as `(re, im)`. + fn shape_amp(&self) -> (Vec, Vec) { + let n = self.shape.amp.len(); + let mut re = Vec::with_capacity(n); + let mut im = Vec::with_capacity(n); + for c in &self.shape.amp { + re.push(c.re); + im.push(c.im); + } + (re, im) + } + + /// Integrate the pulse over `[t0, t1]` (block-relative seconds). + /// Returns `(flip [rad], phase [rad])`. The shape carries amp×exp(i·phase) + /// in Hz; the flip integral is `2π · amp · ∫|shape| dt` over the window, + /// and the phase is the argument of the complex moment (so this matches + /// `pydisseqt.parser.integrate_one(...).pulse`). + fn integrate(&self, t0: f64, t1: f64) -> (f64, f64) { + let (re, im) = integrate_complex_shape(self.shape.as_ref(), t0, t1); + let mom = Complex64::new(re, im); + let flip = 2.0 * std::f64::consts::PI * self.amp * mom.norm(); + // Add the constant RF phase offset and the optional in-shape phase. + let phase = if mom.norm() > 0.0 { + self.phase + mom.arg() + } else { + self.phase + }; + (flip, phase) + } + + fn __repr__(&self) -> String { + format!( + "PyRf(amp={} Hz, phase={} rad, freq={} Hz, delay={} s, dur={} s, use={}, shims={})", + self.amp, + self.phase, + self.freq, + self.delay, + self.shape_duration, + self.rf_use, + self.shims.len() + ) + } +} + +#[pyclass(module = "_prepass")] +pub struct PyGradient { + /// FOV-scaled gradient amplitude `[Hz/m]`. The shape stores normalised values + /// in `[-1, 1]`; multiply by `amp` to get instantaneous strength. + #[pyo3(get)] + amp: f64, + #[pyo3(get)] + delay: f64, + #[pyo3(get)] + shape_duration: f64, + shape: Arc>, +} + +#[pymethods] +impl PyGradient { + fn shape_times(&self) -> Vec { + self.shape.time.clone() + } + + fn shape_amp(&self) -> Vec { + self.shape.amp.clone() + } + + /// Sample the gradient strength `[Hz/m]` at a block-relative time `t`. + /// Returns 0 outside `[delay, delay + shape_duration]`. + fn sample(&self, t: f64) -> f64 { + let rel = t - self.delay; + if rel < 0.0 || rel > self.shape.duration { + return 0.0; + } + self.amp * self.shape.interpolate(rel) + } + + /// Integrate the gradient moment over `[t0, t1]` (block-relative seconds). + /// Returns moment in `[Hz·s / m] = [cycles / m]`, matching what + /// `pydisseqt.parser.integrate(...).gradient.{x,y,z}` returns. + fn integrate(&self, t0: f64, t1: f64) -> f64 { + // Shift to shape-relative time. + let s0 = t0 - self.delay; + let s1 = t1 - self.delay; + self.amp * integrate_real_shape(self.shape.as_ref(), s0, s1) + } + + fn __repr__(&self) -> String { + format!( + "PyGradient(amp={} Hz/m, delay={} s, dur={} s, n_samples={})", + self.amp, + self.delay, + self.shape_duration, + self.shape.amp.len() + ) + } +} + +#[pyclass(module = "_prepass")] +pub struct PyAdc { + #[pyo3(get)] + num: u32, + #[pyo3(get)] + dwell: f64, + #[pyo3(get)] + delay: f64, + #[pyo3(get)] + freq: f64, + #[pyo3(get)] + phase: f64, + phase_shape: Option>>, + labels: int::Labels, +} + +#[pymethods] +impl PyAdc { + /// Block-relative sample times `[s]`: `delay + (n + 0.5) * dwell`. + fn sample_times(&self) -> Vec { + (0..self.num) + .map(|n| self.delay + (n as f64 + 0.5) * self.dwell) + .collect() + } + + /// Per-sample phases `[rad]`: base `phase` plus the optional in-shape + /// modulation evaluated at each sample's relative time. + fn sample_phases(&self) -> Vec { + let n = self.num as usize; + let mut out = vec![self.phase; n]; + if let Some(ps) = self.phase_shape.as_ref() { + for (i, val) in out.iter_mut().enumerate() { + let t = (i as f64 + 0.5) * self.dwell; + *val += ps.interpolate(t); + } + } + out + } + + /// Snapshot of the pulseq label state at the time this ADC fires. + /// Counters (slc, seg, rep, …) come back as `i32`; boolean flags + /// (nav, rev, …) are returned as `0` / `1` so callers can pack them + /// into a single tensor type. + fn labels(&self) -> HashMap<&'static str, i32> { + let l = self.labels; + let mut m = HashMap::with_capacity(17); + m.insert("slc", l.slc); + m.insert("seg", l.seg); + m.insert("rep", l.rep); + m.insert("avg", l.avg); + m.insert("set", l.set); + m.insert("eco", l.eco); + m.insert("phs", l.phs); + m.insert("lin", l.lin); + m.insert("par", l.par); + m.insert("acq", l.acq); + m.insert("nav", l.nav as i32); + m.insert("rev", l.rev as i32); + m.insert("sms", l.sms as i32); + m.insert("ref", l.ref_ as i32); + m.insert("ima", l.ima as i32); + m.insert("off", l.off as i32); + m.insert("noise", l.noise as i32); + m + } + + fn __repr__(&self) -> String { + format!( + "PyAdc(num={}, dwell={} s, delay={} s, freq={} Hz, phase={} rad)", + self.num, self.dwell, self.delay, self.freq, self.phase + ) + } +} + +#[allow(clippy::too_many_arguments)] +#[pyfunction] +#[pyo3(signature = ( + path, + *, + larmor_hz = None, + fov_scale = None, + fov_pos = None, + fov_rot = None, + soft_delays = None, +))] +pub fn load_pulseq_rs( + py: Python, + path: &str, + larmor_hz: Option, + fov_scale: Option, + fov_pos: Option<[f64; 3]>, + fov_rot: Option<[f64; 4]>, + soft_delays: Option>, +) -> PyResult> { + let seq = seq::Sequence::from_file(path) + .map_err(|e| PyValueError::new_err(format!("pulseq-rs parse error: {e}")))?; + + let transform = Transform { + scale: fov_scale.unwrap_or(1.0), + rotation: Quaternion(fov_rot.unwrap_or([1.0, 0.0, 0.0, 0.0])), + position: fov_pos.unwrap_or([0.0, 0.0, 0.0]), + }; + + let (int_seq, warnings) = int::Sequence::from_seq( + &seq, + transform, + larmor_hz.unwrap_or(DEFAULT_LARMOR_HZ), + soft_delays.unwrap_or_default(), + ) + .map_err(|e| PyValueError::new_err(format!("pulseq-rs interpreter error: {e}")))?; + + let user_warning = py.get_type_bound::(); + for w in warnings { + PyErr::warn_bound(py, &user_warning, &w.to_string(), 0)?; + } + + let blocks = pyo3::types::PyList::empty_bound(py); + let mut t_start = 0.0_f64; + let mut total = 0.0_f64; + + for block in &int_seq.blocks { + let rf = block.rf.as_ref().map(|x| rf_to_py(py, x)).transpose()?; + let gx = block.gx.as_ref().map(|x| grad_to_py(py, x)).transpose()?; + let gy = block.gy.as_ref().map(|x| grad_to_py(py, x)).transpose()?; + let gz = block.gz.as_ref().map(|x| grad_to_py(py, x)).transpose()?; + let adc = block.adc.as_ref().map(|x| adc_to_py(py, x)).transpose()?; + + let py_block = PyBlock { + start: t_start, + duration: block.duration, + rf, + gx, + gy, + gz, + adc, + }; + blocks.append(Py::new(py, py_block)?)?; + t_start += block.duration; + total = t_start; + } + + let seq_obj = PyInterpSeq { + name: int_seq.name.clone(), + fov: int_seq.fov, + duration: total, + blocks: blocks.into(), + }; + Py::new(py, seq_obj) + .map_err(|e| PyRuntimeError::new_err(format!("failed to wrap PyInterpSeq: {e}"))) +} + +fn rf_to_py(py: Python, rf: &int::Rf) -> PyResult> { + let rf_use = match rf.rf_use { + RfUse::Excitation => "excitation", + RfUse::Refocusing => "refocusing", + RfUse::Inversion => "inversion", + RfUse::Saturation => "saturation", + RfUse::Preparation => "preparation", + RfUse::Other => "other", + RfUse::Undefined => "undefined", + }; + let shims: Vec<(f64, f64)> = rf.shims.iter().map(|c| (c.norm(), c.arg())).collect(); + Py::new( + py, + PyRf { + amp: rf.amp, + phase: rf.phase, + freq: rf.freq, + delay: rf.delay, + center: rf.center, + rf_use, + shape_duration: rf.shape.duration, + shims, + shape: rf.shape.clone(), + }, + ) +} + +fn grad_to_py(py: Python, g: &int::Gradient) -> PyResult> { + Py::new( + py, + PyGradient { + amp: g.amp, + delay: g.delay, + shape_duration: g.shape.duration, + shape: g.shape.clone(), + }, + ) +} + +fn adc_to_py(py: Python, adc: &int::Adc) -> PyResult> { + Py::new( + py, + PyAdc { + num: adc.num, + dwell: adc.dwell, + delay: adc.delay, + freq: adc.freq, + phase: adc.phase, + phase_shape: adc.phase_shape.clone(), + labels: adc.labels, + }, + ) +} + +/// Trapezoidal integration of a sparse piecewise-linear real shape over +/// `[t0, t1]` (shape-relative seconds). Values outside the shape's time +/// support are treated as `shape.amp[0]` and `*shape.amp.last()` (matching +/// `Shape::interpolate`). +fn integrate_real_shape(shape: &int::Shape, t0: f64, t1: f64) -> f64 { + if t1 <= t0 { + return 0.0; + } + let t0 = t0.clamp(0.0, shape.duration); + let t1 = t1.clamp(0.0, shape.duration); + if t1 <= t0 { + return 0.0; + } + let times = &shape.time; + let amps = &shape.amp; + let n = times.len(); + if n == 0 { + return 0.0; + } + if n == 1 { + return amps[0] * (t1 - t0); + } + + let v0 = shape.interpolate(t0); + let v1 = shape.interpolate(t1); + + let mut acc = 0.0; + let mut prev_t = t0; + let mut prev_v = v0; + for i in 0..n { + let ti = times[i]; + if ti <= prev_t { + continue; + } + if ti >= t1 { + break; + } + let vi = amps[i]; + acc += 0.5 * (prev_v + vi) * (ti - prev_t); + prev_t = ti; + prev_v = vi; + } + acc += 0.5 * (prev_v + v1) * (t1 - prev_t); + acc +} + +/// Same as `integrate_real_shape`, but for complex shapes. Used to integrate +/// the RF pulse to a complex moment whose magnitude → flip and arg → phase. +fn integrate_complex_shape(shape: &int::Shape, t0: f64, t1: f64) -> (f64, f64) { + if t1 <= t0 { + return (0.0, 0.0); + } + let t0 = t0.clamp(0.0, shape.duration); + let t1 = t1.clamp(0.0, shape.duration); + if t1 <= t0 { + return (0.0, 0.0); + } + let times = &shape.time; + let amps = &shape.amp; + let n = times.len(); + if n == 0 { + return (0.0, 0.0); + } + if n == 1 { + let c = amps[0] * (t1 - t0); + return (c.re, c.im); + } + + let v0 = shape.interpolate(t0); + let v1 = shape.interpolate(t1); + + let mut acc = Complex64::new(0.0, 0.0); + let mut prev_t = t0; + let mut prev_v = v0; + for i in 0..n { + let ti = times[i]; + if ti <= prev_t { + continue; + } + if ti >= t1 { + break; + } + let vi = amps[i]; + acc += (prev_v + vi) * (0.5 * (ti - prev_t)); + prev_t = ti; + prev_v = vi; + } + acc += (prev_v + v1) * (0.5 * (t1 - prev_t)); + (acc.re, acc.im) +}