diff --git a/python/MRzeroCore/__init__.py b/python/MRzeroCore/__init__.py index 110aa1b..4383f43 100644 --- a/python/MRzeroCore/__init__.py +++ b/python/MRzeroCore/__init__.py @@ -11,7 +11,7 @@ from .phantom.custom_voxel_phantom import CustomVoxelPhantom from .phantom.sim_data import SimData from .phantom.brainweb import generate_brainweb_phantoms -from .phantom.nifti_phantom import NiftiPhantom, NiftiTissue, NiftiRef, NiftiMapping +from .phantom.nifti_phantom import NiftiPhantom, NiftiTissue, NiftiRef, NiftiMapping, ResliceConfig from .phantom.tissue_dict import TissueDict from .simulation.isochromat_sim import isochromat_sim from .simulation.pre_pass import compute_graph, compute_graph_ext, Graph diff --git a/python/MRzeroCore/phantom/custom_voxel_phantom.py b/python/MRzeroCore/phantom/custom_voxel_phantom.py index 3737f4a..65bc84a 100644 --- a/python/MRzeroCore/phantom/custom_voxel_phantom.py +++ b/python/MRzeroCore/phantom/custom_voxel_phantom.py @@ -129,6 +129,8 @@ def build(self) -> SimData: # TODO: until the dephasing func fix is here, this only works on the # device self.voxel_size happens to be on size = self.voxel_pos.max(0).values - self.voxel_pos.min(0).values + + affine = torch.eye(3,4) return SimData( self.PD, @@ -140,6 +142,7 @@ def build(self) -> SimData: self.B1[None, :], torch.ones(1, self.PD.numel()), size, + affine, self.voxel_pos, torch.tensor([float('inf'), float('inf'), float('inf')]), build_dephasing_func(self.voxel_shape, self.voxel_size), diff --git a/python/MRzeroCore/phantom/nifti_phantom.py b/python/MRzeroCore/phantom/nifti_phantom.py index 8b835b7..e80be37 100644 --- a/python/MRzeroCore/phantom/nifti_phantom.py +++ b/python/MRzeroCore/phantom/nifti_phantom.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Literal, Any from pathlib import Path @@ -65,6 +65,20 @@ def to_dict(self) -> dict[str, float]: return {"gyro": self.gyro, "B0": self.B0} +@dataclass +class ResliceConfig: + """Target resampling grid declared inside the phantom JSON.""" + resolution: list # [nx, ny, nz] — output voxel count + affine: list # 3×4 NIfTI-style affine in mm (rotation×voxel_size | origin) + + @classmethod + def from_dict(cls, config: dict): + return cls(resolution=config["resolution"], affine=config["affine"]) + + def to_dict(self) -> dict: + return {"resolution": self.resolution, "affine": self.affine} + + @dataclass class NiftiRef: file_name: Path @@ -162,6 +176,7 @@ class NiftiPhantom: units: PhantomUnits system: PhantomSystem tissues: dict[str, NiftiTissue] + reslice_to: ResliceConfig | None = field(default=None) @classmethod def default(cls, gyro=42.5764, B0=3.0): @@ -177,28 +192,37 @@ def load(cls, path: Path | str): def save(self, path: Path | str): import json + import re import os path = Path(path) os.makedirs(path.parent, exist_ok=True) + text = json.dumps(self.to_dict(), indent=2) + # Compact arrays of numbers onto a single line + text = re.sub( + r'\[(\n\s+[-\d.eE+]+,?)+\n\s*\]', + lambda m: '[' + ', '.join(re.findall(r'[-\d.eE+]+', m.group(0))) + ']', + text + ) with open(path, "w") as f: - json.dump(self.to_dict(), f, indent=2) + f.write(text) @classmethod def from_dict(cls, config: dict): - assert config["file_type"] == "nifti_phantom_v1" units = PhantomUnits.from_dict(config["units"]) system = PhantomSystem.from_dict(config["system"]) tissues = { name: NiftiTissue.from_dict(tissue) for name, tissue in config["tissues"].items() } + reslice_to = (ResliceConfig.from_dict(config["reslice_to"]) + if "reslice_to" in config else None) - return cls(units, system, tissues) + return cls(units, system, tissues, reslice_to) def to_dict(self) -> dict: - return { + d = { "file_type": self.file_type, "units": self.units.to_dict(), "system": self.system.to_dict(), @@ -206,3 +230,6 @@ def to_dict(self) -> dict: name: tissue.to_dict() for name, tissue in self.tissues.items() }, } + if self.reslice_to is not None: + d["reslice_to"] = self.reslice_to.to_dict() + return d diff --git a/python/MRzeroCore/phantom/sim_data.py b/python/MRzeroCore/phantom/sim_data.py index 8e1a38d..eb23b5e 100644 --- a/python/MRzeroCore/phantom/sim_data.py +++ b/python/MRzeroCore/phantom/sim_data.py @@ -35,6 +35,8 @@ class are nothing but the data needed for simulation, so it can describe size : torch.Tensor Physical size of the phantom. If a sequence with normalized gradients is simulated, size is used to scale them to match the phantom. + affine : torch.Tensor + Affine matrix of the phantom data, in millimeters. avg_B1_trig : torch.Tensor (361, 3) values containing the PD-weighted avg of sin/cos/sin²(B1*flip) voxel_pos : torch.Tensor @@ -61,6 +63,7 @@ def __init__( B1: torch.Tensor, coil_sens: torch.Tensor, size: torch.Tensor, + affine: torch.Tensor, voxel_pos: torch.Tensor, nyquist: torch.Tensor, dephasing_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], @@ -100,6 +103,7 @@ def __init__( self.tissue_masks = {} self.coil_sens = coil_sens.clone() self.size = size.clone() + self.affine = affine.clone() self.voxel_pos = voxel_pos.clone() self.avg_B1_trig = calc_avg_B1_trig(B1, PD) self.nyquist = nyquist.clone() @@ -125,6 +129,7 @@ def cuda(self) -> SimData: self.B1.cuda(), self.coil_sens.cuda(), self.size.cuda(), + self.affine.cuda(), self.voxel_pos.cuda(), self.nyquist.cuda(), self.dephasing_func, @@ -152,6 +157,7 @@ def cpu(self) -> SimData: self.B1.cpu(), self.coil_sens.cpu(), self.size.cpu(), + self.affine.cpu(), self.voxel_pos.cpu(), self.nyquist.cpu(), self.dephasing_func, diff --git a/python/MRzeroCore/phantom/tissue_dict.py b/python/MRzeroCore/phantom/tissue_dict.py index 0318006..214f750 100644 --- a/python/MRzeroCore/phantom/tissue_dict.py +++ b/python/MRzeroCore/phantom/tissue_dict.py @@ -1,10 +1,11 @@ from .voxel_grid_phantom import VoxelGridPhantom from .sim_data import SimData -from .nifti_phantom import NiftiPhantom, NiftiTissue, NiftiRef, NiftiMapping +from .nifti_phantom import NiftiPhantom, NiftiTissue, NiftiRef, NiftiMapping, ResliceConfig from pathlib import Path import torch import numpy as np -from typing import Literal +import matplotlib.pyplot as plt +from typing import Literal, Self from functools import lru_cache @@ -40,7 +41,7 @@ def load(cls, path: Path | str, config: NiftiPhantom | None = None): # units are supported (conversion factor 1); this might change in the future return TissueDict({ - name: load_tissue(tissue, base_dir) + name: load_tissue(tissue, base_dir, reslice=config.reslice_to) for name, tissue in config.tissues.items() }) @@ -106,20 +107,13 @@ def set(value): tissues = {tissue: save_tissue(self[tissue]) for tissue in self.keys()} # Write the NIfTIs - size = np.asarray(next(iter(self.values())).size) - vs = 1000 * size / np.asarray(density[0].shape) - affine = np.array( - [ - [+vs[0], 0, 0, -size[0] / 2 * 1000], - [0, +vs[1], 0, -size[1] / 2 * 1000], - [0, 0, +vs[2], -size[2] / 2 * 1000], - [0, 0, 0, 0], # Row ignored - ] - ) + affine = np.asarray(next(iter(self.values())).affine) + if affine.shape == (3, 4): + affine = np.vstack([affine, [0, 0, 0, 1]], dtype=np.float32) def save_nifti(prop, name): if len(prop) > 0: - ext = f"-{name}" if name != "density" else "" + ext = f"_{name}" if name != "density" else "" file_name = base_dir / f"{base_name}{ext}.nii.gz" data = np.stack(prop, -1) @@ -135,8 +129,14 @@ def save_nifti(prop, name): save_nifti(B1_tx, "B1+") save_nifti(B1_rx, "B1-") + reslice_to = { + "resolution": list(next(iter(self.values())).PD.shape), + "affine": affine[:3, :4].tolist() + } + config = NiftiPhantom.default(gyro, B0) config.tissues = tissues + config.reslice_to = ResliceConfig.from_dict(reslice_to) config.save(path_to_json) def interpolate(self, x: int, y: int, z: int): @@ -165,6 +165,7 @@ def combine(self) -> VoxelGridPhantom: combined.D = sum(seg * p.D for seg, p in zip(segmentation, phantoms)) combined.B0 = sum(seg * p.B0 for seg, p in zip(segmentation, phantoms)) combined.B1 = sum(seg[None, ...] * p.B1 for seg, p in zip(segmentation, phantoms)) + combined.coil_sens = sum(seg[None, ...] * p.coil_sens for seg, p in zip(segmentation, phantoms)) return combined @@ -184,38 +185,142 @@ def build(self, PD_threshold: float = 1e-6, "coil_sens": torch.cat([obj.coil_sens for obj in data_list], 1), "voxel_pos": torch.cat([obj.voxel_pos for obj in data_list], 0), "size": data_list[0].size, + "affine": data_list[0].affine, "nyquist": data_list[0].nyquist, "dephasing_func": data_list[0].dephasing_func, + "recover_func": lambda data: self.recover(data), + "tissue_masks": dict(zip(self.keys(), [obj.tissue_masks["combined"] for obj in data_list])) } - # Only add tissue_masks if any object has it non-empty - if any(obj.tissue_masks for obj in data_list): - kwargs["tissue_masks"] = torch.stack([obj.tissue_masks for obj in data_list]) - return SimData(**kwargs) + def recover(self, sim_data: SimData) -> Self: + """Provided to :class:`SimData` to reverse the ``build()``""" + + assert sim_data.tissue_masks is not None + + tissues = list(sim_data.tissue_masks.keys()) + tissue_begin = 0 # first tissue starts at index 0 in sparse tensors + + def to_full(sparse): + assert sparse.ndim < 3 + + if sparse.ndim == 2: + full = torch.zeros( + [sparse.shape[0], *mask.shape], dtype=sparse.dtype, device=mask.device) + full[:, mask] = sparse + else: + full = torch.zeros(mask.shape, device=mask.device) + full[mask] = sparse + return full + + data_list = [] + for tissue in tissues: + mask = sim_data.tissue_masks[tissue] + mask = mask.to(sim_data.device) + + tissue_end = tissue_begin + torch.sum(mask) # index of tissue end in sparse tensors + + data_list.append(VoxelGridPhantom( + to_full(sim_data.PD[tissue_begin:tissue_end]), + to_full(sim_data.T1[tissue_begin:tissue_end]), + to_full(sim_data.T2[tissue_begin:tissue_end]), + to_full(sim_data.T2dash[tissue_begin:tissue_end]), + to_full(sim_data.D[tissue_begin:tissue_end]), + to_full(sim_data.B0[tissue_begin:tissue_end]), + to_full(sim_data.B1[:, tissue_begin:tissue_end]), + to_full(sim_data.coil_sens[:, tissue_begin:tissue_end]), + sim_data.size, + sim_data.affine, + ) + ) + tissue_begin = tissue_end # next tissue in sparse tensors starts where last ended + + return TissueDict(dict(zip(tissues, data_list))) + + def plot(self, tissue="all", plot_masks=False, plot_slice="center", time_unit='s') -> None: + """ + Plots the individual tissues of the PhantomDict + + Parameters + ---------- + tissue : str, default="all" + Specifies which tissue(s) to plot. + - ``"all"`` : Plot all tissues stored in the PhantomDict, one after another. + - ``"combined"`` : Plot a combined phantom created from all tissues using :meth:`combine`. + - any other string is interpreted as a key identifying a single tissue stored in the PhantomDict. + plot_masks : bool + Plot tissue masks (assumes they exist) + slice : str | int + If int, the specified slice is plotted. "center" plots the center + slice and "all" plots all slices as a grid. + time_unit : str + Time unit to use for T1, T2, and T2' maps (default: 's'). Supported 's' and 'ms'. + """ + + if tissue == "all": + print("Plot combined phatom") + self.combine().plot(plot_masks, plot_slice, time_unit) + + fignum = max(plt.get_fignums(), default=1) + + for name, t in self.items(): + print("Plot tissue: ", name) + t.plot(plot_masks, plot_slice, time_unit, f"Figure {fignum} - {name}") + + elif tissue == "combined": + print("Plot combined tissue phantom") + self.combine().plot(plot_masks, plot_slice, time_unit) + + else: + print("Plot tissue: ", tissue) + self[tissue].plot(plot_masks, plot_slice, time_unit) # ============================ # Helpers for importing NIfTIs # ============================ -def load_tissue(config: NiftiTissue, base_dir: Path) -> VoxelGridPhantom: - density, affine = load_file_ref(base_dir, config.density) - size = np.abs(density.shape @ affine[:3, :3]) / 1000 # affine is in mm +def load_tissue(config: NiftiTissue, base_dir: Path, + reslice: ResliceConfig | None = None) -> VoxelGridPhantom: + density, nifti_affine = load_file_ref(base_dir, config.density) def lp(cfg): - return torch.as_tensor(load_property(cfg, base_dir, density, affine)) + return load_property(cfg, base_dir, density, nifti_affine) + + T1 = lp(config.T1) + T2 = lp(config.T2) + T2dash = lp(config.T2dash) + ADC = lp(config.ADC) + B0 = lp(config.dB0) + B1 = [lp(cfg) for cfg in config.B1_tx] + coil = [lp(cfg) for cfg in config.B1_rx] + + if reslice is None: + target_shape = density.shape + aff_mm = nifti_affine[:3, :] + else: + target_shape = tuple(reslice.resolution) + aff_mm = np.array(reslice.affine, dtype=float) + + def rs(arr): + return _resample_nifti(arr, nifti_affine, target_shape, aff_mm) + size = target_shape * np.linalg.norm(aff_mm[:3, :3], axis=0) /1000 # np.abs(target_shape @ aff_mm[:3, :3]) / 1000 + density = rs(density) + T1, T2, T2dash, ADC, B0 = rs(T1), rs(T2), rs(T2dash), rs(ADC), rs(B0) + B1 = [rs(b) for b in B1] + coil = [rs(c) for c in coil] return VoxelGridPhantom( PD=torch.as_tensor(density), size=torch.as_tensor(size), - T1=lp(config.T1), - T2=lp(config.T2), - T2dash=lp(config.T2dash), - D=lp(config.ADC), - B0=lp(config.dB0), - B1=torch.stack([lp(cfg) for cfg in config.B1_tx], 0), - coil_sens=torch.stack([lp(cfg) for cfg in config.B1_rx], 0), + T1=torch.as_tensor(T1), + T2=torch.as_tensor(T2), + T2dash=torch.as_tensor(T2dash), + D=torch.as_tensor(ADC), + B0=torch.as_tensor(B0), + B1=torch.stack([torch.as_tensor(b) for b in B1], 0), + coil_sens=torch.stack([torch.as_tensor(c) for c in coil], 0), + affine=torch.as_tensor(aff_mm), ) @@ -273,3 +378,38 @@ def _load_cached(file_name): import nibabel img = nibabel.loadsave.load(file_name) return np.asarray(img.dataobj), img.get_sform() + + +def _resample_nifti(data: np.ndarray, nifti_affine: np.ndarray, + target_shape: tuple, + target_affine_mm: np.ndarray) -> np.ndarray: + """Resample a 3D array onto a target grid via trilinear interpolation. + + Parameters + ---------- + data: + Source 3D numpy array (native NIfTI voxel space). + nifti_affine: + 4×4 sform affine of the source NIfTI (mm units). + target_shape: + Output shape ``(nx, ny, nz)``. + target_affine_mm: + 3×4 or 4×4 NIfTI-style affine of the target grid in mm. + Maps target voxel ``[i, j, k]`` to physical coordinates in mm. + """ + from scipy.ndimage import affine_transform + + A = np.array(target_affine_mm, dtype=float) + A_rot = A[:3, :3] + A_trans = A[:3, 3] + + A_nifti_inv = np.linalg.inv(nifti_affine[:3, :3]) + M = A_nifti_inv @ A_rot + o = A_nifti_inv @ (A_trans - nifti_affine[:3, 3]) + + kwargs = dict(output_shape=tuple(target_shape), order=1, + mode='constant', cval=0.0) + if np.iscomplexobj(data): + return (affine_transform(data.real, M, offset=o, **kwargs) + + 1j * affine_transform(data.imag, M, offset=o, **kwargs)) + return affine_transform(data.astype(float), M, offset=o, **kwargs) diff --git a/python/MRzeroCore/phantom/voxel_grid_phantom.py b/python/MRzeroCore/phantom/voxel_grid_phantom.py index 9e97764..2d52916 100644 --- a/python/MRzeroCore/phantom/voxel_grid_phantom.py +++ b/python/MRzeroCore/phantom/voxel_grid_phantom.py @@ -84,6 +84,8 @@ class VoxelGridPhantom: (coil_count, sx, sy, sz) tensor of coil sensitivities size : torch.Tensor Size of the data, in meters. + affine : torch.Tensor + Affine matrix of the phantom data, in millimeters. tissue_masks : Dict[str, torch.Tensor] | None Segmentation masks for different tissues. The keys are the tissue names """ @@ -99,6 +101,7 @@ def __init__( B1: torch.Tensor, coil_sens: torch.Tensor, size: torch.Tensor, + affine: torch.Tensor, phantom_motion=None, voxel_motion=None, tissue_masks: Optional[Dict[str,torch.Tensor]] = None, @@ -120,6 +123,7 @@ def __init__( self.tissue_masks = {} self.coil_sens = torch.as_tensor(coil_sens, dtype=torch.complex64) self.size = torch.as_tensor(size, dtype=torch.float32) + self.affine = torch.as_tensor(affine, dtype=torch.float32) self.phantom_motion = phantom_motion self.voxel_motion = voxel_motion @@ -138,23 +142,24 @@ def build(self, PD_threshold: float = 1e-6, shape = torch.tensor(mask.shape) pos_x, pos_y, pos_z = torch.meshgrid( - self.size[0] * - torch.fft.fftshift(torch.fft.fftfreq( - int(shape[0]), device=self.PD.device)), - self.size[1] * - torch.fft.fftshift(torch.fft.fftfreq( - int(shape[1]), device=self.PD.device)), - self.size[2] * - torch.fft.fftshift(torch.fft.fftfreq( - int(shape[2]), device=self.PD.device)), + torch.arange( + int(shape[0]), dtype=torch.float32, device=self.PD.device), + torch.arange( + int(shape[1]), dtype=torch.float32, device=self.PD.device), + torch.arange( + int(shape[2]), dtype=torch.float32, device=self.PD.device), indexing="ij" ) - - voxel_pos = torch.stack([ - pos_x[mask].flatten(), - pos_y[mask].flatten(), - pos_z[mask].flatten() - ], dim=1) + + pos = torch.stack([pos_x, pos_y, pos_z], dim=-1) + + pos_rot = torch.einsum( + 'ij,xyzj->xyzi', + self.affine[:3,:3] / 1000, + pos + ) + self.affine[None, None, None, :3,3] / 1000 + + voxel_pos = pos_rot[mask] if voxel_shape == "box": def dephasing_func(t, n): return sinc(t, 0.5 / n) @@ -178,8 +183,9 @@ def dephasing_func(t, _): return identity(t) self.B1[:, mask], self.coil_sens[:, mask], self.size, + self.affine, voxel_pos, - torch.as_tensor(shape, device=self.PD.device) / 2 / self.size, + 500 / torch.linalg.norm(self.affine[:3,:3], dim=0), dephasing_func, recover_func=lambda data: recover(mask, data), phantom_motion=self.phantom_motion, @@ -210,7 +216,13 @@ def load(cls, file_name: str) -> VoxelGridPhantom: size = torch.tensor(data['FOV'], dtype=torch.float) except KeyError: size = torch.tensor([0.192, 0.192, 0.192]) - + + affine = torch.eye(3,4) + affine[0, 0] = size[0] / PD.shape[0] * 1000 + affine[1, 1] = size[1] / PD.shape[1] * 1000 + affine[2, 2] = size[2] / PD.shape[2] * 1000 + affine[:, 3] = -size / 2 * 1000 + tissue_masks = { key: torch.tensor(mask) for key, mask in data.items() @@ -222,7 +234,7 @@ def load(cls, file_name: str) -> VoxelGridPhantom: return cls( PD, T1, T2, T2dash, D, B0, B1, - torch.ones(1, *PD.shape), size, + torch.ones(1, *PD.shape), size, affine, tissue_masks=tissue_masks ) @@ -283,7 +295,13 @@ def load_mat( T2dash = torch.full_like(data[..., 0], T2dash) if isinstance(D, float): D = torch.full_like(data[..., 0], D) - + + affine = torch.eye(3,4) + affine[0, 0] = size[0] / data[..., 0].shape[0] * 1000 + affine[1, 1] = size[1] / data[..., 0].shape[1] * 1000 + affine[2, 2] = size[2] / data[..., 0].shape[2] * 1000 + affine[:, 3] = -size / 2 * 1000 + return cls( data[..., 0], # PD data[..., 1], # T1 @@ -294,6 +312,7 @@ def load_mat( data[..., 4][None, ...], # B1 coil_sens=torch.ones(1, *data.shape[:-1]), size=torch.as_tensor(size), + affine=affine, ) def slices(self, slices: list[int]) -> VoxelGridPhantom: @@ -331,6 +350,7 @@ def select_multicoil(tensor: torch.Tensor): select_multicoil(self.B1), select_multicoil(self.coil_sens), self.size.clone(), + self.affine.clone(), tissue_masks={ key: mask[..., slices] for key, mask in self.tissue_masks.items() }, @@ -373,6 +393,7 @@ def scale(map: torch.Tensor) -> torch.Tensor: scale(self.B1.squeeze()).unsqueeze(0), scale(self.coil_sens.squeeze()).unsqueeze(0), self.size.clone(), + self.affine.clone(), tissue_masks={ key: scale(mask) for key, mask in self.tissue_masks.items() } @@ -438,10 +459,11 @@ def resample_masks(tensors: Dict) -> Optional[Dict]: resample_multicoil(self.B1), resample_multicoil(self.coil_sens), self.size.clone(), + self.affine.clone(), tissue_masks=resample_masks(self.tissue_masks) ) - def plot(self, plot_masks=False, plot_slice="center", time_unit='s') -> None: + def plot(self, plot_masks=False, plot_slice="center", time_unit='s', title=None) -> None: """ Print and plot all data stored in this phantom. @@ -454,6 +476,8 @@ def plot(self, plot_masks=False, plot_slice="center", time_unit='s') -> None: slice and "all" plots all slices as a grid. time_unit : str Time unit to use for T1, T2, and T2' maps (default: 's'). Supported 's' and 'ms'. + title : None | str + Title of the plot if given. Default is None. """ print("VoxelGridPhantom") print(f"size = {self.size}") @@ -487,8 +511,10 @@ def plot(self, plot_masks=False, plot_slice="center", time_unit='s') -> None: # Calculate the grid size based on the number of plots cols = 3 rows = int(np.ceil(num_plots / cols)) - - plt.figure(figsize=(12, rows * 3)) + + plt.figure(figsize=(12, rows * 3), num=title) + if title: + plt.suptitle(title) # Plot the basic maps plt.subplot(rows, cols, 1) @@ -568,7 +594,8 @@ def to_full(sparse): to_full(sim_data.B0), to_full(sim_data.B1), to_full(sim_data.coil_sens), - sim_data.size + sim_data.size, + sim_data.affine, )