diff --git a/.gitignore b/.gitignore
index 44c7f5a5..2d4f7cd5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -155,6 +155,14 @@ checkpoints/*
dist/
build/
*.egg-info/
+*.dot
+
+# Ignore h5 and xmf formats
+*.h5
+*.xmf
+
+# Ignore CSV files
+*.csv
# USD files
*.usd
@@ -162,4 +170,4 @@ build/
*.usdc
*.usd.gz
*.usd.zip
-*.usd.bz2
\ No newline at end of file
+*.usd.bz2
diff --git a/README.md b/README.md
index 97b2082a..65f312cc 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@
# XLB: A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning
-XLB is a fully differentiable 2D/3D Lattice Boltzmann Method (LBM) library that leverages hardware acceleration. It supports [JAX](https://github.com/google/jax) and [NVIDIA Warp](https://github.com/NVIDIA/warp) backends, and is specifically designed to solve fluid dynamics problems in a computationally efficient and differentiable manner. Its unique combination of features positions it as an exceptionally suitable tool for applications in physics-based machine learning. With the new Warp backend, XLB now offers state-of-the-art performance for even faster simulations.
+XLB is a fully differentiable 2D/3D Lattice Boltzmann Method (LBM) library that leverages hardware acceleration. It supports [JAX](https://github.com/google/jax), [NVIDIA Warp](https://github.com/NVIDIA/warp), and [Neon](https://github.com/Autodesk/Neon) backends, and is specifically designed to solve fluid dynamics problems in a computationally efficient and differentiable manner. Its unique combination of features positions it as an exceptionally suitable tool for applications in physics-based machine learning. With the Warp backend, XLB offers state-of-the-art single-GPU performance, and with the new Neon backend it extends to multi-GPU (single-resolution). More importantly, the Neon backend provides grid refinement capabilities for multi-resolution simulations.
## Getting Started
To get started with XLB, you can install it using pip. There are different installation options depending on your hardware and needs:
@@ -28,6 +28,16 @@ This installation is for the JAX backend with TPU support:
pip install "xlb[tpu]"
```
+### Installation with Neon support
+Neon backend enables multi-GPU dense and single-GPU multi-resolution representations. Neon depends on a custom fork of warp-lang, so any existing warp installation must be removed before installing Neon. The Python interface for Neon can be installed from a pre-built wheel hosted on GitHub. Note that the wheel currently requires GLIBC >= 2.38 (e.g., Ubuntu 24.04 or later).
+
+```bash
+pip uninstall warp-lang
+pip install https://github.com/Autodesk/Neon/releases/download/v0.5.2a1/neon_gpu-0.5.2a1-cp312-cp312-linux_x86_64.whl
+```
+
+
+
### Notes:
- For Mac users: Use the basic CPU installation command as JAX's GPU support is not available on MacOS
- The NVIDIA Warp backend is included in all installation options and supports CUDA automatically when available
@@ -63,11 +73,36 @@ If you use XLB in your research, please cite the following paper:
}
```
+If you use the grid refinement capabilities in your work, please also cite:
+
+```
+@inproceedings{mahmoud2024optimized,
+ title={Optimized {GPU} implementation of grid refinement in lattice {Boltzmann} method},
+ author={Mahmoud, Ahmed H and Salehipour, Hesam and Meneghin, Massimiliano},
+ booktitle={2024 IEEE International Parallel and Distributed Processing Symposium (IPDPS)},
+ pages={398--407},
+ year={2024},
+ organization={IEEE}
+}
+
+@inproceedings{meneghin2022neon,
+ title={Neon: A Multi-{GPU} Programming Model for Grid-based Computations},
+ author={Meneghin, Massimiliano and Mahmoud, Ahmed H. and Jayaraman, Pradeep Kumar and Morris, Nigel J. W.},
+ booktitle={Proceedings of the 36th IEEE International Parallel and Distributed Processing Symposium},
+ pages={817--827},
+ year={2022},
+ month={june},
+ doi={10.1109/IPDPS53621.2022.00084},
+ url={https://escholarship.org/uc/item/9fz7k633}
+}
+```
+
## Key Features
-- **Multiple Backend Support:** XLB now includes support for multiple backends including JAX and NVIDIA Warp, providing *state-of-the-art* performance for lattice Boltzmann simulations. Currently, only single GPU is supported for the Warp backend.
+- **Multiple Backend Support:** XLB includes support for JAX, NVIDIA Warp, and Neon backends, providing *state-of-the-art* performance for lattice Boltzmann simulations. The Warp backend targets single-GPU runs, while the Neon backend enables multi-GPU single-resolution and single-GPU multi-resolution simulations.
+- **Multi-Resolution Grid Refinement:** Mesh refinement with nested cuboid grids and multiple kernel-fusion strategies for optimal performance on the Neon backend.
- **Integration with JAX Ecosystem:** The library can be easily integrated with JAX's robust ecosystem of machine learning libraries such as [Flax](https://github.com/google/flax), [Haiku](https://github.com/deepmind/dm-haiku), [Optax](https://github.com/deepmind/optax), and many more.
- **Differentiable LBM Kernels:** XLB provides differentiable LBM kernels that can be used in differentiable physics and deep learning applications.
-- **Scalability:** XLB is capable of scaling on distributed multi-GPU systems using the JAX backend, enabling the execution of large-scale simulations on hundreds of GPUs with billions of cells.
+- **Scalability:** XLB is capable of scaling on distributed multi-GPU systems using the JAX backend or the Neon backend, enabling the execution of large-scale simulations on hundreds of GPUs with billions of cells.
- **Support for Various LBM Boundary Conditions and Kernels:** XLB supports several LBM boundary conditions and collision kernels.
- **User-Friendly Interface:** Written entirely in Python, XLB emphasizes a highly accessible interface that allows users to extend the library with ease and quickly set up and run new simulations.
- **Leverages JAX Array and Shardmap:** The library incorporates the new JAX array unified array type and JAX shardmap, providing users with a numpy-like interface. This allows users to focus solely on the semantics, leaving performance optimizations to the compiler.
@@ -103,7 +138,7 @@ If you use XLB in your research, please cite the following paper:
- Airflow in to, out of, and within a building (~400 million cells)
+ Airflow into, out of, and within a building (~400 million cells)
@@ -128,6 +163,7 @@ The stages of a fluid density field from an initial state to the emergence of th
- BGK collision model (Standard LBM collision model)
- KBC collision model (unconditionally stable for flows with high Reynolds number)
+- Smagorinsky LES sub-grid model for turbulence modelling
### Machine Learning
@@ -143,21 +179,25 @@ The stages of a fluid density field from an initial state to the emergence of th
### Compute Capabilities
- Single GPU support for the Warp backend with state-of-the-art performance
+- Multi-GPU support using the Neon backend with single-resolution grids
+- Grid refinement support on single-GPU using the Neon backend
- Distributed Multi-GPU support using the JAX backend
- Mixed-Precision support (store vs compute)
+- Multiple kernel-fusion performance strategies for multi-resolution simulations
- Out-of-core support (coming soon)
### Output
- Binary and ASCII VTK output (based on PyVista library)
+- HDF5/XDMF output for multi-resolution data (with gzip compression)
- In-situ rendering using [PhantomGaze](https://github.com/loliverhennigh/PhantomGaze) library
- [Orbax](https://github.com/google/orbax)-based distributed asynchronous checkpointing
-- Image Output
+- Image Output (including multi-resolution slice images)
- 3D mesh voxelizer using trimesh
### Boundary conditions
-- **Equilibrium BC:** In this boundary condition, the fluid populations are assumed to be in at equilibrium. Can be used to set prescribed velocity or pressure.
+- **Equilibrium BC:** In this boundary condition, the fluid populations are assumed to be at equilibrium. Can be used to set prescribed velocity or pressure.
- **Full-Way Bounceback BC:** In this boundary condition, the velocity of the fluid populations is reflected back to the fluid side of the boundary, resulting in zero fluid velocity at the boundary.
@@ -171,17 +211,22 @@ The stages of a fluid density field from an initial state to the emergence of th
- **Interpolated Bounceback BC:** Interpolated bounce-back boundary condition for representing curved boundaries.
+- **Hybrid BC:** Combines regularized and bounce-back methods with optional wall-distance interpolation for improved accuracy on curved geometries.
+
+- **Grad's Approximation BC:** Boundary condition based on Grad's approximation of the non-equilibrium distribution.
+
## Roadmap
-### Work in Progress (WIP)
-*Note: Some of the work-in-progress features can be found in the branches of the XLB repository. For contributions to these features, please reach out.*
+### Recently Completed
- - 🌐 **Grid Refinement:** Implementing adaptive mesh refinement techniques for enhanced simulation accuracy.
+ - ✅ **Grid Refinement:** Multi-resolution LBM with nested cuboid grids and multiple kernel-fusion strategies via the Neon backend.
- - 💾 **Out-of-Core Computations:** Enabling simulations that exceed available GPU memory, suitable for CPU+GPU coherent memory models such as NVIDIA's Grace Superchips (coming soon).
+ - ✅ **Multi-GPU Acceleration using [Neon](https://github.com/Autodesk/Neon) + Warp:** Multi-GPU support through Neon's data structures with Warp-based kernels for single-resolution settings.
+### Work in Progress (WIP)
+*Note: Some of the work-in-progress features can be found in the branches of the XLB repository. For contributions to these features, please reach out.*
-- ⚡ **Multi-GPU Acceleration using [Neon](https://github.com/Autodesk/Neon) + Warp:** Using Neon's data structure for improved scaling.
+ - 💾 **Out-of-Core Computations:** Enabling simulations that exceed available GPU memory, suitable for CPU+GPU coherent memory models such as NVIDIA's Grace Superchips (coming soon).
- 🗜️ **GPU Accelerated Lossless Compression and Decompression**: Implementing high-performance lossless compression and decompression techniques for larger-scale simulations and improved performance.
diff --git a/examples/cfd/data/ahmed.json b/examples/cfd/data/ahmed.json
new file mode 100644
index 00000000..59ee02e7
--- /dev/null
+++ b/examples/cfd/data/ahmed.json
@@ -0,0 +1,22 @@
+{
+ "_comment": "Ahmed Car Model, slant - angle = 25 degree. Profiles on symmetry plane (y=0) covering entire field. Origin of coordinate system: x=0: end of the car, y=0: symmetry plane, z=0: ground plane S.Becker/H. Lienhart/C Stoots, Institute of Fluid Mechanics, University Erlangen-Nuremberg, Erlangen, Germany, Coordinates in meters need to convert to voxels, Velocity data in m/s",
+ "data": {
+ "-1.162" : { "x-velocity" : [26.995,29.825,29.182,28.488,27.703,26.988,26.456,26.163,26.190,26.523,27.083,28.033,29.131,30.429,31.747,33.036,34.268,35.354,36.312,37.083,37.770,38.484,39.033,39.447,39.839,40.086,40.268,40.380,40.451], "height" : [0.028,0.048,0.068,0.088,0.108,0.128,0.148,0.168,0.188,0.208,0.228,0.248,0.268,0.288,0.308,0.328,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.7388]},
+ "-1.062" : { "x-velocity" : [30.307,28.962,25.812,21.232,15.848,10.812,7.459,6.080,5.845,6.196,7.428,10.456,15.718,22.129,28.090,32.707,35.888,37.891,39.071,39.840,40.261,40.604,40.767,40.820,40.870,40.890,40.907,40.871,40.853], "height" : [0.028,0.048,0.068,0.088,0.108,0.128,0.148,0.168,0.188,0.208,0.228,0.248,0.268,0.288,0.308,0.328,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]},
+ "-0.962" : { "x-velocity" : [52.216,51.303,50.196,48.833,47.728,46.790,45.514,44.222,43.379,42.829,42.322,42.056,41.876,41.706,41.584], "height" : [0.363,0.368,0.378,0.388,0.398,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]},
+ "-0.862" : { "x-velocity" : [46.589,46.538,46.228,46.033,45.810,45.554,45.056,44.369,43.789,43.275,42.789,42.344,42.148,41.913,41.720], "height" : [0.363,0.368,0.378,0.388,0.398,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]},
+ "-0.562" : { "x-velocity" : [43.237,43.262,43.248,43.225,43.183,43.145,43.083,43.030,42.904,42.776,42.685,42.434,42.358,42.197,42.042], "height" : [0.363,0.368,0.378,0.388,0.398,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]},
+ "-0.362" : { "x-velocity" : [44.493,44.491,44.443,44.379,44.297,44.215,44.067,43.867,43.577,43.306,43.061,42.689,42.527,42.293,42.105], "height" : [0.363,0.368,0.378,0.388,0.398,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]},
+ "-0.212" : { "x-velocity" : [49.202,48.429,47.805,46.697,45.883,44.913,44.195,43.650,43.130,42.677,42.432,42.154,41.961], "height" : [0.368,0.378,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]},
+ "-0.162" : { "x-velocity" : [50.511,49.784,48.894,48.103,47.468,46.322,45.563,44.581,43.933,43.383,42.905,42.505,42.293,42.042,41.863], "height" : [0.348,0.358,0.368,0.378,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]},
+ "-0.112" : { "x-velocity" : [27.615,35.449,41.526,46.068,46.277,46.038,45.774,45.505,45.237,44.701,44.326,43.765,43.284,42.890,42.529,42.247,42.082,41.880,41.732], "height" : [0.318,0.323,0.328,0.338,0.348,0.358,0.368,0.378,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]},
+ "-0.062" : { "x-velocity" : [22.891,27.789,32.292,36.568,39.533,41.426,42.371,42.971,43.030,43.081,43.074,43.065,43.039,42.996,42.908,42.665,42.456,42.294,42.105,41.929,41.827,41.660,41.546], "height" : [0.298,0.303,0.308,0.313,0.318,0.323,0.328,0.338,0.348,0.358,0.368,0.378,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]},
+ "-0.012" : { "x-velocity" : [23.304,26.317,29.429,32.341,34.923,37.106,38.673,39.841,40.447,40.780,40.973,41.085,41.193,41.282,41.359,41.442,41.522,41.699,41.737,41.749,41.724,41.714,41.642,41.574,41.518,41.431,41.366], "height" : [0.278,0.283,0.288,0.293,0.298,0.303,0.308,0.313,0.318,0.323,0.328,0.338,0.348,0.358,0.368,0.378,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]},
+ "0.038" : { "x-velocity" : [42.752,37.392,15.320,-4.501,-8.079,-8.892,-8.420,-7.027,-5.143,-2.903,-0.936,0.927,2.200,3.099,3.622,4.026,4.280,4.520,5.620,8.938,13.913,17.872,21.148,24.814,29.075,33.188,36.424,38.490,39.388,39.675,39.794,39.911,40.007,40.219,40.425,40.643,40.757,40.896,40.994,41.058,41.124,41.127,41.143,41.106,41.080], "height" : [0.028,0.038,0.048,0.058,0.068,0.078,0.088,0.098,0.108,0.118,0.128,0.138,0.148,0.158,0.168,0.178,0.188,0.198,0.208,0.218,0.228,0.238,0.248,0.258,0.268,0.278,0.288,0.298,0.308,0.318,0.328,0.338,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]},
+ "0.088" : { "x-velocity" : [41.859,35.830,22.660,7.745,-5.808,-12.650,-14.748,-13.756,-10.659,-6.484,-2.121,1.303,3.672,5.441,7.066,9.157,11.613,14.620,17.662,20.639,23.565,26.437,29.484,32.441,35.024,36.938,37.938,38.377,38.595,38.728,38.856,38.976,39.133,39.438,39.749,39.975,40.129,40.344,40.499,40.649,40.783,40.853,40.927,40.945,40.960], "height" : [0.028,0.038,0.048,0.058,0.068,0.078,0.088,0.098,0.108,0.118,0.128,0.138,0.148,0.158,0.168,0.178,0.188,0.198,0.208,0.218,0.228,0.238,0.248,0.258,0.268,0.278,0.288,0.298,0.308,0.318,0.328,0.338,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]},
+ "0.138" : { "x-velocity" : [36.223,32.501,24.752,14.281,2.799,-6.218,-10.908,-11.892,-9.708,-5.258,-0.140,4.331,7.882,10.995,13.961,16.699,19.477,22.063,24.651,27.081,29.524,31.950,34.043,35.594,36.506,37.053,37.386,37.614,37.832,38.032,38.214,38.397,38.575,38.940,39.298,39.533,39.749,40.028,40.206,40.404,40.580,40.691,40.803,40.858,40.921], "height" : [0.028,0.038,0.048,0.058,0.068,0.078,0.088,0.098,0.108,0.118,0.128,0.138,0.148,0.158,0.168,0.178,0.188,0.198,0.208,0.218,0.228,0.238,0.248,0.258,0.268,0.278,0.288,0.298,0.308,0.318,0.328,0.338,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]},
+ "0.188" : { "x-velocity" : [29.417,27.755,23.967,18.261,11.662,5.405,0.676,-0.652,0.937,4.261,7.958,11.427,14.366,17.138,19.735,22.151,24.577,26.883,29.165,31.111,32.781,34.072,34.893,35.524,35.974,36.329,36.604,36.872,37.138,37.402,37.673,37.900,38.112,38.518,38.829,39.088,39.326,39.639,39.871,40.096,40.275,40.423,40.523,40.603,40.687], "height" : [0.028,0.038,0.048,0.058,0.068,0.078,0.088,0.098,0.108,0.118,0.128,0.138,0.148,0.158,0.168,0.178,0.188,0.198,0.208,0.218,0.228,0.238,0.248,0.258,0.268,0.278,0.288,0.298,0.308,0.318,0.328,0.338,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]},
+ "0.238" : { "x-velocity" : [24.405,24.168,22.782,20.196,16.970,13.937,12.137,11.757,12.851,14.649,16.780,18.995,21.070,23.335,25.280,27.468,29.262,30.832,32.133,33.102,33.856,34.473,34.922,35.340,35.698,36.039,36.336,36.629,36.906,37.193,37.454,37.691,37.929,38.329,38.611,38.875,39.126,39.414,39.677,39.917,40.097,40.259,40.380,40.478,40.568], "height" : [0.028,0.038,0.048,0.058,0.068,0.078,0.088,0.098,0.108,0.118,0.128,0.138,0.148,0.158,0.168,0.178,0.188,0.198,0.208,0.218,0.228,0.238,0.248,0.258,0.268,0.278,0.288,0.298,0.308,0.318,0.328,0.338,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]},
+ "0.288" : { "x-velocity" : [21.489,22.225,22.127,21.456,20.404,19.743,19.541,19.909,21.002,22.381,24.018,25.670,27.421,28.998,30.371,31.523,32.406,33.111,33.670,34.155,34.532,34.893,35.240,35.567,35.875,36.158,36.437,36.708,36.974,37.230,37.473,37.709,37.932,38.266,38.515,38.773,39.008,39.270,39.562,39.782,39.962,40.148,40.266,40.369,40.475], "height" : [0.028,0.038,0.048,0.058,0.068,0.078,0.088,0.098,0.108,0.118,0.128,0.138,0.148,0.158,0.168,0.178,0.188,0.198,0.208,0.218,0.228,0.238,0.248,0.258,0.268,0.278,0.288,0.298,0.308,0.318,0.328,0.338,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}
+ }
+}
\ No newline at end of file
diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py
index 2872d7fe..1616d32a 100644
--- a/examples/cfd/flow_past_sphere_3d.py
+++ b/examples/cfd/flow_past_sphere_3d.py
@@ -20,7 +20,7 @@
omega = 1.6
grid_shape = (512 // 2, 128 // 2, 128 // 2)
-compute_backend = ComputeBackend.WARP
+compute_backend = ComputeBackend.JAX
precision_policy = PrecisionPolicy.FP32FP32
velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=compute_backend)
u_max = 0.04
@@ -51,7 +51,7 @@
z = np.arange(grid_shape[2])
X, Y, Z = np.meshgrid(x, y, z, indexing="ij")
indices = np.where((X - grid_shape[0] // 6) ** 2 + (Y - grid_shape[1] // 2) ** 2 + (Z - grid_shape[2] // 2) ** 2 < sphere_radius**2)
-sphere = [tuple(indices[i]) for i in range(velocity_set.d)]
+sphere = [tuple(indices[i].tolist()) for i in range(velocity_set.d)]
# Define Boundary Conditions
@@ -80,21 +80,25 @@ def bc_profile_jax():
return bc_profile_jax
- elif compute_backend == ComputeBackend.WARP:
+ else:
+ wp_dtype = precision_policy.compute_precision.wp_dtype
+ H_y = wp_dtype(grid_shape[1] - 1) # Height in y direction
+ H_z = wp_dtype(grid_shape[2] - 1) # Height in z direction
+ two = wp_dtype(2.0)
@wp.func
def bc_profile_warp(index: wp.vec3i):
# Poiseuille flow profile: parabolic velocity distribution
- y = wp.float32(index[1])
- z = wp.float32(index[2])
+ y = wp_dtype(index[1])
+ z = wp_dtype(index[2])
# Calculate normalized distance from center
- y_center = y - (H_y / 2.0)
- z_center = z - (H_z / 2.0)
- r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0
+ y_center = y - (H_y / two)
+ z_center = z - (H_z / two)
+ r_squared = (two * y_center / H_y) ** two + (two * z_center / H_z) ** two
# Parabolic profile: u = u_max * (1 - r²)
- return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), length=1)
+ return wp.vec(wp_dtype(u_max) * wp.max(wp_dtype(0.0), wp_dtype(1.0) - r_squared), length=1)
return bc_profile_warp
@@ -122,15 +126,33 @@ def bc_profile_warp(index: wp.vec3i):
precision_policy=precision_policy,
velocity_set=xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=ComputeBackend.JAX),
)
+to_jax = xlb.utils.ToJAX("populations", velocity_set.q, grid_shape)
+
+# Setup Momentum Transfer for Force Calculation
+from xlb.operator.force.momentum_transfer import MomentumTransfer
+
+momentum_transfer = MomentumTransfer(bc_sphere, compute_backend=compute_backend)
+sphere_cross_section = np.pi * sphere_radius**2
# Post-Processing Function
-def post_process(step, f_current):
+def post_process(step, f_0, f_1):
+ wp.synchronize()
+
+ # Compute lift and drag
+ boundary_force = momentum_transfer(f_0, f_1, bc_mask, missing_mask)
+ drag = boundary_force[0] # x-direction
+ lift = boundary_force[2]
+ cd = 2.0 * drag / (u_max**2 * sphere_cross_section)
+ cl = 2.0 * lift / (u_max**2 * sphere_cross_section)
+ print(f"CD={cd}, CL={cl}")
+
# Convert to JAX array if necessary
- if not isinstance(f_current, jnp.ndarray):
- f_current = wp.to_jax(f_current)
+ if not isinstance(f_0, jnp.ndarray):
+ f_0 = to_jax(f_0)
+ wp.synchronize()
- rho, u = macro(f_current)
+ rho, u = macro(f_0)
# Remove boundary cells
u = u[:, 1:-1, 1:-1, 1:-1]
@@ -158,9 +180,7 @@ def post_process(step, f_current):
f_0, f_1 = f_1, f_0 # Swap the buffers
if step % post_process_interval == 0 or step == num_steps - 1:
- if compute_backend == ComputeBackend.WARP:
- wp.synchronize()
- post_process(step, f_0)
+ post_process(step, f_0, f_1)
end_time = time.time()
elapsed = end_time - start_time
print(f"Completed step {step}. Time elapsed for {post_process_interval} steps: {elapsed:.6f} seconds.")
diff --git a/examples/cfd/multires_flow_past_sphere_3d.py b/examples/cfd/multires_flow_past_sphere_3d.py
new file mode 100644
index 00000000..ea510a29
--- /dev/null
+++ b/examples/cfd/multires_flow_past_sphere_3d.py
@@ -0,0 +1,290 @@
+"""
+3D fow past a sphere with multi-resolution LBM.
+
+Demonstrates the multi-resolution Neon backend for a 3-D Poiseuille-
+inlet flow past a sphere inside a nested cuboid multi-resolution domain.
+Uses AABB-Close voxelization with halfway bounce-back on the sphere surface and
+computes lift/drag via momentum transfer.
+"""
+
+import neon
+import warp as wp
+import numpy as np
+import time
+
+import xlb
+from xlb.compute_backend import ComputeBackend
+from xlb.precision_policy import PrecisionPolicy
+from xlb.grid import multires_grid_factory
+from xlb.operator.boundary_condition import (
+ FullwayBounceBackBC,
+ HalfwayBounceBackBC,
+ RegularizedBC,
+ ExtrapolationOutflowBC,
+ DoNothingBC,
+ ZouHeBC,
+ HybridBC,
+)
+from xlb.operator.boundary_masker import MeshVoxelizationMethod
+from xlb.utils.mesher import make_cuboid_mesh, prepare_sparsity_pattern
+from xlb.operator.force import MultiresMomentumTransfer
+
+
+def generate_cuboid_mesh(stl_filename, num_finest_voxels_across_part):
+ """
+ Generate a cuboid mesh based on the provided voxel size and domain multipliers.
+ """
+ import trimesh
+ import os
+
+ # Domain multipliers for each refinement level
+ # First entry should be full domain size
+ # Domain multipliers
+ domainMultiplier = [
+ [7, 22, 7, 7, 7, 7], # -x, x, -y, y, -z, z (sphere at 1/4 domain from inlet)
+ [3, 12, 5, 5, 5, 5], # -x, x, -y, y, -z, z (wake-biased)
+ [2, 8, 4, 4, 4, 4],
+ [1, 5, 2, 2, 2, 2],
+ # [1, 2, 1, 1, 1, 1],
+ # [0.4, 1, 0.4, 0.4, 0.4, 0.4],
+ # [0.2, 0.4, 0.2, 0.2, 0.2, 0.2],
+ ]
+
+ # Load the mesh
+ mesh = trimesh.load_mesh(stl_filename, process=False)
+ assert not mesh.is_empty, ValueError("Loaded mesh is empty or invalid.")
+
+ # Compute original bounds
+ # Find voxel size and sphere radius
+ min_bound = mesh.vertices.min(axis=0)
+ max_bound = mesh.vertices.max(axis=0)
+ partSize = max_bound - min_bound
+
+ # smallest voxel size
+ voxel_size = min(partSize) / num_finest_voxels_across_part
+
+ # Compute translation to put mesh into first octant of that domain—
+ shift = np.array(
+ [
+ domainMultiplier[0][0] * partSize[0] - min_bound[0],
+ domainMultiplier[0][2] * partSize[1] - min_bound[1],
+ domainMultiplier[0][4] * partSize[2] - min_bound[2],
+ ],
+ dtype=float,
+ )
+
+ # Apply translation and save out temp stl
+ mesh.apply_translation(shift)
+ _ = mesh.vertex_normals
+ mesh_vertices = np.asarray(mesh.vertices) / voxel_size
+ mesh.export("temp.stl")
+
+ # Mesh based on temp stl
+ level_data = make_cuboid_mesh(
+ voxel_size,
+ domainMultiplier,
+ "temp.stl",
+ )
+ grid_shape_finest = tuple([i * 2 ** (len(level_data) - 1) for i in level_data[-1][0].shape])
+ print(f"Full shape based on finest voxels size is {grid_shape_finest}")
+ os.remove("temp.stl")
+ return level_data, mesh_vertices, tuple([int(a) for a in grid_shape_finest])
+
+
+# -------------------------- Simulation Setup --------------------------
+
+# The following parameters define the resolution of the voxelized grid
+sphere_radius = 5
+num_finest_voxels_across_part = 2 * sphere_radius
+
+# Other setup parameters
+Re = 5000.0
+compute_backend = ComputeBackend.NEON
+precision_policy = PrecisionPolicy.FP32FP32
+velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=compute_backend)
+u_max = 0.04
+num_steps = 10000
+post_process_interval = 1000
+
+# Initialize XLB
+xlb.init(
+ velocity_set=velocity_set,
+ default_backend=compute_backend,
+ default_precision_policy=precision_policy,
+)
+
+# Generate the cuboid mesh and sphere vertices
+stl_filename = "../stl-files/sphere.stl"
+level_data, sphere, grid_shape_finest = generate_cuboid_mesh(stl_filename, num_finest_voxels_across_part)
+
+
+# Define exporter object for hdf5 output
+from xlb.utils import MultiresIO
+
+# Define an exporter for the multiresolution data
+exporter = MultiresIO({"velocity": 3, "density": 1}, level_data)
+
+# Prepare the sparsity pattern and origins from the level data
+sparsity_pattern, level_origins = prepare_sparsity_pattern(level_data)
+
+# get the number of levels
+num_levels = len(level_data)
+
+# Create the multires grid
+grid = multires_grid_factory(
+ grid_shape_finest,
+ velocity_set=velocity_set,
+ sparsity_pattern_list=sparsity_pattern,
+ sparsity_pattern_origins=[neon.Index_3d(*box_origin) for box_origin in level_origins],
+)
+
+# Define Boundary Indices
+coarsest_level = grid.count_levels - 1
+box = grid.bounding_box_indices(shape=grid.level_to_shape(coarsest_level))
+box_no_edge = grid.bounding_box_indices(shape=grid.level_to_shape(coarsest_level), remove_edges=True)
+inlet = box_no_edge["left"]
+outlet = box_no_edge["right"]
+walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(velocity_set.d)]
+walls = np.unique(np.array(walls), axis=-1).tolist()
+
+
+# Define Boundary Conditions
+def bc_profile():
+ """Build a Warp function for a Poiseuille parabolic inlet velocity profile."""
+ assert compute_backend == ComputeBackend.NEON
+ # IMPORTANT NOTE: the user defined functional must be defined in terms of the indices at the finest level
+ _, ny, nz = grid_shape_finest
+ dtype = precision_policy.compute_precision.wp_dtype
+ H_y = dtype(ny) # Length in y direction (finest level)
+ H_z = dtype(nz) # Length in z direction (finest level)
+ two = dtype(2.0)
+ one = dtype(1.0)
+ zero = dtype(0.0)
+ u_max_wp = dtype(u_max)
+ _u_vec = wp.vec(velocity_set.d, dtype=dtype)
+
+ @wp.func
+ def bc_profile_warp(index: wp.vec3i):
+ # Poiseuille flow profile: parabolic velocity distribution
+ y = dtype(index[1])
+ z = dtype(index[2])
+
+ # Calculate normalized distance from center
+ y_center = y - (H_y / two)
+ z_center = z - (H_z / two)
+ r_squared = (two * y_center / H_y) ** two + (two * z_center / H_z) ** two
+
+ # Parabolic profile: u = u_max * (1 - r²)
+ # Note that unlike RegularizedBC and ZouHeBC which only accept normal velocity, hybridBC accepts the full velocity vector
+
+ # For hybridBC
+ # return _u_vec(u_max_wp * wp.max(zero, one - r_squared), zero, zero)
+
+ # For Regularized and ZouHe
+ return wp.vec(u_max_wp * wp.max(zero, one - r_squared), length=1)
+
+ return bc_profile_warp
+
+
+# Convert bc indices to a list of list (first entry corresponds to the finest level)
+inlet = [[] for _ in range(num_levels - 1)] + [inlet]
+outlet = [[] for _ in range(num_levels - 1)] + [outlet]
+walls = [[] for _ in range(num_levels - 1)] + [walls]
+
+# Initialize Boundary Conditions
+bc_left = RegularizedBC("velocity", profile=bc_profile(), indices=inlet)
+# Alternatives:
+# bc_left = HybridBC(bc_method="bounceback_regularized", profile=bc_profile(), indices=inlet)
+# bc_left = RegularizedBC("velocity", prescribed_value=(u_max, 0.0, 0.0), indices=inlet)
+bc_walls = FullwayBounceBackBC(indices=walls)
+bc_outlet = DoNothingBC(indices=outlet)
+# bc_outlet = ExtrapolationOutflowBC(indices=outlet)
+bc_sphere = HybridBC(
+ bc_method="nonequilibrium_regularized", mesh_vertices=sphere, voxelization_method=MeshVoxelizationMethod("AABB"), use_mesh_distance=True
+)
+# bc_sphere = HalfwayBounceBackBC(mesh_vertices=sphere, voxelization_method=MeshVoxelizationMethod('AABB'))
+
+boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere]
+
+# Configure the simulation relaxation time
+visc = u_max * num_finest_voxels_across_part / Re
+omega_finest = 1.0 / (3.0 * visc + 0.5)
+
+# Make initializer operator
+from xlb.helper.initializers import CustomMultiresInitializer
+
+initializer = CustomMultiresInitializer(
+ bc_id=bc_outlet.id,
+ constant_velocity_vector=(u_max, 0.0, 0.0),
+ velocity_set=velocity_set,
+ precision_policy=precision_policy,
+ compute_backend=compute_backend,
+)
+
+# Define a multi-resolution simulation manager
+sim = xlb.helper.MultiresSimulationManager(
+ omega_finest=omega_finest,
+ grid=grid,
+ boundary_conditions=boundary_conditions,
+ collision_type="KBC",
+ initializer=initializer,
+ mres_perf_opt=xlb.mres_perf_optimization_type.MresPerfOptimizationType.FUSION_AT_FINEST,
+)
+
+# Setup Momentum Transfer for Force Calculation
+bc_sphere = boundary_conditions[-1]
+momentum_transfer = MultiresMomentumTransfer(bc_sphere, mres_perf_opt=sim.mres_perf_opt, compute_backend=compute_backend)
+
+
+def print_lift_drag(sim):
+ """Compute and print drag and lift coefficients from the simulation state."""
+ boundary_force = momentum_transfer(sim.f_0, sim.f_1, sim.bc_mask, sim.missing_mask)
+ drag = boundary_force[0] # x-direction
+ lift = boundary_force[2]
+ sphere_cross_section = np.pi * sphere_radius**2
+ u_avg = 0.5 * u_max
+ cd = 2.0 * drag / (u_avg**2 * sphere_cross_section)
+ cl = 2.0 * lift / (u_avg**2 * sphere_cross_section)
+ print(f"\tCD={cd}, CL={cl}")
+
+
+# -------------------------- Simulation Loop --------------------------
+
+wp.synchronize()
+start_time = time.time()
+for step in range(num_steps):
+ sim.step()
+
+ if step % post_process_interval == 0 or step == num_steps - 1:
+ # # Export VTK for comparison
+ # tic_write = time.perf_counter()
+ # sim.export_macroscopic("multires_flow_over_sphere_3d_")
+ # toc_write = time.perf_counter()
+ # print(f"\tVTK file written in {toc_write - tic_write:0.1f} seconds")
+
+ # Call the Macroscopic operator to compute macroscopic fields
+ wp.synchronize()
+ sim.macro(sim.f_0, sim.bc_mask, sim.rho, sim.u, streamId=0)
+
+ # Call the exporter to save the current state
+ nx, ny, nz = grid_shape_finest
+ filename = f"multires_flow_past_sphere_3d_{step:04d}"
+ wp.synchronize()
+ exporter.to_hdf5(filename, {"velocity": sim.u, "density": sim.rho}, compression="gzip", compression_opts=2)
+ exporter.to_slice_image(
+ filename,
+ {"velocity": sim.u},
+ plane_point=(nx // 2, ny // 2, nz // 2),
+ plane_normal=(0, 0, 1),
+ grid_res=256,
+ slice_thickness=2 ** (num_levels - 1),
+ bounds=(0.1, 0.6, 0.3, 0.7),
+ )
+
+ # Print lift and drag coefficients
+ print_lift_drag(sim)
+ wp.synchronize()
+ end_time = time.time()
+ elapsed = end_time - start_time
+ print(f"\tCompleted step {step}. Time elapsed for {post_process_interval} steps: {elapsed:.6f} seconds.")
+ start_time = time.time()
diff --git a/examples/cfd/multires_windtunnel_3d.py b/examples/cfd/multires_windtunnel_3d.py
new file mode 100644
index 00000000..d94d4357
--- /dev/null
+++ b/examples/cfd/multires_windtunnel_3d.py
@@ -0,0 +1,575 @@
+"""
+Ahmed body aerodynamics with multi-resolution LBM.
+
+Simulates turbulent flow around the Ahmed body (25-degree slant angle)
+using the XLB multi-resolution Neon backend. Computes drag and lift
+coefficients via momentum transfer and exports HDF5/XDMF data for
+post-processing.
+"""
+
+import neon
+import warp as wp
+import numpy as np
+import time
+import os
+import matplotlib.pyplot as plt
+import trimesh
+import shutil
+
+import xlb
+from xlb.compute_backend import ComputeBackend
+from xlb.precision_policy import PrecisionPolicy
+from xlb.grid import multires_grid_factory
+from xlb.operator.boundary_condition import (
+ DoNothingBC,
+ HybridBC,
+ RegularizedBC,
+)
+from xlb.operator.boundary_masker import MeshVoxelizationMethod
+from xlb.utils.mesher import prepare_sparsity_pattern, make_cuboid_mesh, MultiresIO
+from xlb.utils import UnitConvertor
+from xlb.operator.force import MultiresMomentumTransfer
+from xlb.helper.initializers import CustomMultiresInitializer
+
+wp.clear_kernel_cache()
+wp.config.quiet = True
+
+# User Configuration
+# =================
+# Physical and simulation parameters
+wind_speed_lbm = 0.05 # Lattice velocity
+wind_speed_mps = 38.0 # Physical inlet velocity in m/s (user input)
+flow_passes = 2 # Domain flow passes
+kinematic_viscosity = 1.508e-5 # Kinematic viscosity of air in m^2/s 1.508e-5
+voxel_size = 0.005 # Finest voxel size in meters
+
+# STL filename
+stl_filename = "../stl-files/Ahmed_25_NoLegs.stl"
+script_name = "Ahmed"
+
+# I/O settings
+print_interval_percentage = 1 # Print every 1% of iterations
+file_output_crossover_percentage = 10 # Crossover at 50% of iterations
+num_file_outputs_pre_crossover = 20 # Outputs before crossover
+num_file_outputs_post_crossover = 5 # Outputs after crossover
+
+# Other setup parameters
+compute_backend = ComputeBackend.NEON
+precision_policy = PrecisionPolicy.FP32FP32
+velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=compute_backend)
+
+
+def generate_cuboid_mesh(stl_filename, voxel_size):
+ """
+ Alternative cuboid mesh generation based on Apolo's method with domain multipliers per level.
+ """
+ # Domain multipliers for each refinement level
+ domain_multiplier = [
+ [3.0, 4.0, 2.5, 2.5, 0.0, 4.0], # -x, x, -y, y, -z, z
+ [1.2, 1.25, 1.75, 1.75, 0.0, 1.5],
+ [0.8, 1.0, 1.25, 1.25, 0.0, 1.2],
+ [0.5, 0.65, 0.6, 0.60, 0.0, 0.6],
+ [0.25, 0.25, 0.25, 0.25, 0.0, 0.25],
+ ]
+
+ # Load the mesh
+ mesh = trimesh.load_mesh(stl_filename, process=False)
+ if mesh.is_empty:
+ raise ValueError("Loaded mesh is empty or invalid.")
+
+ # Compute original bounds
+ min_bound = mesh.vertices.min(axis=0)
+ max_bound = mesh.vertices.max(axis=0)
+ partSize = max_bound - min_bound
+ x0 = max_bound[0] # End of car for Ahmed
+
+ # Compute translation to put mesh into first octant of the domain
+ stl_shift = np.array(
+ [
+ domain_multiplier[0][0] * partSize[0] - min_bound[0],
+ domain_multiplier[0][2] * partSize[1] - min_bound[1],
+ domain_multiplier[0][4] * partSize[2] - min_bound[2],
+ ],
+ dtype=float,
+ )
+
+ # Apply translation and save out temp STL
+ mesh.apply_translation(stl_shift)
+ _ = mesh.vertex_normals
+ mesh_vertices = np.asarray(mesh.vertices)
+ mesh.export("temp.stl")
+
+ # Generate mesh using make_cuboid_mesh
+ level_data = make_cuboid_mesh(
+ voxel_size,
+ domain_multiplier,
+ "temp.stl",
+ )
+
+ num_levels = len(level_data)
+ grid_shape_finest = tuple([int(i * 2 ** (num_levels - 1)) for i in level_data[-1][0].shape])
+ print(f"Full shape based on finest voxel size is {grid_shape_finest}")
+ os.remove("temp.stl")
+
+ return (
+ level_data,
+ mesh_vertices,
+ tuple([int(a) for a in grid_shape_finest]),
+ stl_shift,
+ x0,
+ )
+
+
+# Boundary Conditions Setup
+# =========================
+def setup_boundary_conditions(grid, level_data, body_vertices, wind_speed_mps):
+ """
+ Set up boundary conditions for the simulation.
+ """
+ # Convert wind speed to lattice units
+ wind_speed_lbm = unit_convertor.velocity_to_lbm(wind_speed_mps)
+
+ left_indices = grid.boundary_indices_across_levels(level_data, box_side="left", remove_edges=True)
+ right_indices = grid.boundary_indices_across_levels(level_data, box_side="right", remove_edges=True)
+ top_indices = grid.boundary_indices_across_levels(level_data, box_side="top", remove_edges=False)
+ bottom_indices = grid.boundary_indices_across_levels(level_data, box_side="bottom", remove_edges=False)
+ front_indices = grid.boundary_indices_across_levels(level_data, box_side="front", remove_edges=False)
+ back_indices = grid.boundary_indices_across_levels(level_data, box_side="back", remove_edges=False)
+
+ # Initialize boundary conditions
+ bc_inlet = RegularizedBC("velocity", prescribed_value=(wind_speed_lbm, 0.0, 0.0), indices=left_indices)
+ bc_outlet = DoNothingBC(indices=right_indices)
+ bc_top = HybridBC(bc_method="nonequilibrium_regularized", indices=top_indices)
+ bc_bottom = HybridBC(bc_method="nonequilibrium_regularized", indices=bottom_indices)
+ bc_front = HybridBC(bc_method="nonequilibrium_regularized", indices=front_indices)
+ bc_back = HybridBC(bc_method="nonequilibrium_regularized", indices=back_indices)
+ bc_body = HybridBC(
+ bc_method="nonequilibrium_regularized",
+ mesh_vertices=unit_convertor.length_to_lbm(body_vertices),
+ voxelization_method=MeshVoxelizationMethod("AABB_CLOSE", close_voxels=4),
+ use_mesh_distance=True,
+ )
+
+ return [bc_top, bc_bottom, bc_front, bc_back, bc_inlet, bc_outlet, bc_body]
+
+
+# Simulation Initialization
+# =========================
+def initialize_simulation(
+ grid, boundary_conditions, omega_finest, initializer, collision_type="KBC", mres_perf_opt=xlb.MresPerfOptimizationType.FUSION_AT_FINEST
+):
+ """
+ Initialize the multiresolution simulation manager.
+ """
+ sim = xlb.helper.MultiresSimulationManager(
+ omega_finest=omega_finest,
+ grid=grid,
+ boundary_conditions=boundary_conditions,
+ collision_type=collision_type,
+ initializer=initializer,
+ mres_perf_opt=mres_perf_opt,
+ )
+ return sim
+
+
+# Utility Functions
+# =================
+def compute_force_coefficients(sim, step, momentum_transfer, wind_speed_lbm, reference_area):
+ """
+ Calculate and print lift and drag coefficients.
+ """
+ boundary_force = momentum_transfer(sim.f_0, sim.f_1, sim.bc_mask, sim.missing_mask)
+ drag = boundary_force[0]
+ lift = boundary_force[2]
+ cd = 2.0 * drag / (wind_speed_lbm**2 * reference_area)
+ cl = 2.0 * lift / (wind_speed_lbm**2 * reference_area)
+ if np.isnan(cd) or np.isnan(cl):
+ print(f"NaN detected in coefficients at step {step}")
+ raise ValueError(f"NaN detected in coefficients at step {step}: Cd={cd}, Cl={cl}")
+ drag_values.append([cd, cl])
+ return cd, cl, drag
+
+
+def plot_force_coefficients(drag_values, output_dir, print_interval, script_name, percentile_range=(15, 85), use_log_scale=False):
+ """
+ Plot CD and CL over time and save the plot to the output directory.
+ """
+ drag_values_array = np.array(drag_values)
+ steps = np.arange(0, len(drag_values) * print_interval, print_interval)
+ cd_values = drag_values_array[:, 0]
+ cl_values = drag_values_array[:, 1]
+ y_min = min(np.percentile(cd_values, percentile_range[0]), np.percentile(cl_values, percentile_range[0]))
+ y_max = max(np.percentile(cd_values, percentile_range[1]), np.percentile(cl_values, percentile_range[1]))
+ padding = (y_max - y_min) * 0.1
+ y_min, y_max = y_min - padding, y_max + padding
+ if use_log_scale:
+ y_min = max(y_min, 1e-6)
+ plt.figure(figsize=(10, 6))
+ plt.plot(steps, cd_values, label="Drag Coefficient (Cd)", color="blue")
+ plt.plot(steps, cl_values, label="Lift Coefficient (Cl)", color="red")
+ plt.xlabel("Simulation Step")
+ plt.ylabel("Coefficient")
+ plt.title(f"{script_name}: Drag and Lift Coefficients Over Time")
+ plt.legend()
+ plt.grid(True)
+ plt.ylim(y_min, y_max)
+ if use_log_scale:
+ plt.yscale("log")
+ plt.savefig(os.path.join(output_dir, "drag_lift_plot.png"))
+ plt.close()
+
+
+def compute_voxel_statistics(sim, bc_mask_exporter, sparsity_pattern, boundary_conditions, unit_convertor):
+ """
+ Compute active/solid voxels, totals, lattice updates, and reference area based on simulation data.
+ """
+ fields_data = bc_mask_exporter.get_fields_data({"bc_mask": sim.bc_mask})
+ bc_mask_data = fields_data["bc_mask_0"]
+ level_id_field = bc_mask_exporter.level_id_field
+
+ # Compute solid voxels per level (assuming 255 is the solid marker)
+ solid_voxels = []
+ for lvl in range(num_levels):
+ level_mask = level_id_field == lvl
+ solid_voxels.append(np.sum(bc_mask_data[level_mask] == 255))
+
+ # Compute active voxels (total non-zero in sparsity minus solids)
+ active_voxels = [np.count_nonzero(mask) for mask in sparsity_pattern]
+ active_voxels = [max(0, active_voxels[lvl] - solid_voxels[lvl]) for lvl in range(num_levels)]
+
+ # Totals
+ total_voxels = sum(active_voxels)
+ total_lattice_updates_per_step = sum(active_voxels[lvl] * (2 ** (num_levels - 1 - lvl)) for lvl in range(num_levels))
+
+ # Compute reference area (projected on YZ plane at finest level)
+ finest_level = 0
+ mask_finest = level_id_field == finest_level
+ bc_mask_finest = bc_mask_data[mask_finest]
+ active_indices_finest = np.argwhere(sparsity_pattern[0])
+ bc_body_id = boundary_conditions[-1].id # Assuming last BC is bc_body
+ solid_voxels_indices = active_indices_finest[bc_mask_finest == bc_body_id]
+ unique_jk = np.unique(solid_voxels_indices[:, 1:3], axis=0)
+ reference_area = unique_jk.shape[0]
+ reference_area_physical = reference_area * unit_convertor.reference_length**2
+
+ return {
+ "active_voxels": active_voxels,
+ "solid_voxels": solid_voxels,
+ "total_voxels": total_voxels,
+ "total_lattice_updates_per_step": total_lattice_updates_per_step,
+ "reference_area": reference_area,
+ "reference_area_physical": reference_area_physical,
+ }
+
+
+def plot_data(x0, output_dir, delta_x_coarse, sim, IOexporter, prefix="Ahmed"):
+ """
+ Ahmed Car Model, slant - angle = 25 degree
+ Profiles on symmetry plane (y=0) covering entire field
+ Origin of coordinate system:
+ x=0: end of the car, y=0: symmetry plane, z=0: ground plane
+
+ S.Becker/H. Lienhart/C.Stoots
+ Insitute of Fluid Mechanics
+ University Erlangen-Nuremberg
+ Erlangen, Germany
+ Coordaintes in meters need to convert to voxels
+ Velocity data in m/s
+ """
+
+ def _load_sim_line(csv_path):
+ """
+ Read a CSV exported by IOexporter.to_line without pandas.
+ Returns (z, Ux).
+ """
+ # Read with header as column names
+ data = np.genfromtxt(
+ csv_path,
+ delimiter=",",
+ names=True,
+ autostrip=True,
+ dtype=None,
+ encoding="utf-8",
+ )
+ if data.size == 0:
+ raise ValueError(f"No data in {csv_path}")
+
+ z = np.asarray(data["z"], dtype=float)
+ ux = np.asarray(data["value"], dtype=float)
+ return z, ux
+
+ # Load reference data
+ import json
+
+ ref_data_path = "examples/cfd/data/ahmed.json"
+ with open(ref_data_path, "r") as file:
+ data = json.load(file)
+
+ for x_str in data["data"].keys():
+ # Extract reference horizontal velocity in m/s and its corresponding height in m
+ refX = np.array(data["data"][x_str]["x-velocity"])
+ refY = np.array(data["data"][x_str]["height"])
+
+ # From reference x0 (rear of body) find x1 for plot
+ x_pos = float(x_str)
+ x1 = x0 + x_pos
+
+ print(f" x1 is {x1}")
+ sim.macro(sim.f_0, sim.bc_mask, sim.rho, sim.u, streamId=0)
+ filename = os.path.join(output_dir, f"{prefix}_{x_str}")
+ wp.synchronize()
+ IOexporter.to_line(
+ filename,
+ {"velocity": sim.u},
+ start_point=(x1, 0, 0),
+ end_point=(x1, 0, 0.8),
+ resolution=250,
+ component=0,
+ radius=delta_x_coarse, # needed with model units
+ )
+ # read the CSV written by the exporter
+ csv_path = filename + "_velocity_0.csv"
+ print(f"CSV path is {csv_path}")
+
+ try:
+ sim_z, sim_ux = _load_sim_line(csv_path)
+ except Exception as e:
+ print(f"Failed to read {csv_path}: {e}")
+ continue
+
+ # plot reference vs simulation
+ plt.figure(figsize=(4.5, 6))
+ plt.plot(refX, refY, "o", mfc="none", label="Experimental)")
+ plt.plot(sim_ux, sim_z, "-", lw=2, label="Simulation")
+ plt.xlim(np.min(refX) * 0.9, np.max(refX) * 1.1)
+ plt.ylim(np.min(refY), np.max(refY))
+ plt.xlabel("Ux [m/s]")
+ plt.ylabel("z [m]")
+ plt.title(f"Velocity Plot at {x_pos:+.3f}")
+ plt.grid(True, alpha=0.3)
+ plt.legend()
+ plt.tight_layout()
+ plt.savefig(filename + ".png", dpi=150)
+ plt.close()
+
+
+# Main Script
+# ===========
+# Initialize XLB
+
+xlb.init(
+ velocity_set=velocity_set,
+ default_backend=compute_backend,
+ default_precision_policy=precision_policy,
+)
+
+# Generate mesh
+level_data, body_vertices, grid_shape_zip, stl_shift, x0 = generate_cuboid_mesh(stl_filename, voxel_size)
+
+# Prepare the sparsity pattern and origins from the level data
+sparsity_pattern, level_origins = prepare_sparsity_pattern(level_data)
+
+# Define a unit convertor
+unit_convertor = UnitConvertor(
+ velocity_lbm_unit=wind_speed_lbm,
+ velocity_physical_unit=wind_speed_mps,
+ voxel_size_physical_unit=voxel_size,
+)
+
+# Calculate lattice parameters
+num_levels = len(level_data)
+delta_x_coarse = voxel_size * 2 ** (num_levels - 1)
+nu_lattice = unit_convertor.viscosity_to_lbm(kinematic_viscosity)
+omega_finest = 1.0 / (3.0 * nu_lattice + 0.5)
+
+# Create output directory
+current_dir = os.path.join(os.path.dirname(__file__))
+output_dir = os.path.join(current_dir, script_name)
+if os.path.exists(output_dir):
+ shutil.rmtree(output_dir)
+os.makedirs(output_dir)
+
+# Define exporter objects
+field_name_cardinality_dict = {"velocity": 3, "density": 1}
+h5exporter = MultiresIO(
+ field_name_cardinality_dict,
+ level_data,
+ offset=-stl_shift,
+ unit_convertor=unit_convertor,
+)
+bc_mask_exporter = MultiresIO(
+ {"bc_mask": 1},
+ level_data,
+ offset=-stl_shift,
+ unit_convertor=unit_convertor,
+)
+
+# Create grid
+grid = multires_grid_factory(
+ grid_shape_zip,
+ velocity_set=velocity_set,
+ sparsity_pattern_list=sparsity_pattern,
+ sparsity_pattern_origins=[neon.Index_3d(*box_origin) for box_origin in level_origins],
+)
+
+# Calculate num_steps
+coarsest_level = grid.count_levels - 1
+grid_shape_x_coarsest = grid.level_to_shape(coarsest_level)[0]
+num_steps = int(flow_passes * (grid_shape_x_coarsest / wind_speed_lbm))
+
+# Calculate print and file output intervals
+print_interval = max(1, int(num_steps * (print_interval_percentage / 100.0)))
+crossover_step = int(num_steps * (file_output_crossover_percentage / 100.0))
+file_output_interval_pre_crossover = (
+ max(1, int(crossover_step / num_file_outputs_pre_crossover)) if num_file_outputs_pre_crossover > 0 else num_steps + 1
+)
+file_output_interval_post_crossover = (
+ max(1, int((num_steps - crossover_step) / num_file_outputs_post_crossover)) if num_file_outputs_post_crossover > 0 else num_steps + 1
+)
+
+# Setup boundary conditions
+boundary_conditions = setup_boundary_conditions(grid, level_data, body_vertices, wind_speed_mps)
+
+# Create initializer
+wind_speed_lbm = unit_convertor.velocity_to_lbm(wind_speed_mps)
+initializer = CustomMultiresInitializer(
+ bc_id=boundary_conditions[-2].id, # bc_outlet
+ constant_velocity_vector=(wind_speed_lbm, 0.0, 0.0),
+ velocity_set=velocity_set,
+ precision_policy=precision_policy,
+ compute_backend=compute_backend,
+)
+
+# Initialize simulation
+sim = initialize_simulation(grid, boundary_conditions, omega_finest, initializer)
+
+# Compute voxel statistics and reference area
+stats = compute_voxel_statistics(sim, bc_mask_exporter, sparsity_pattern, boundary_conditions, unit_convertor)
+active_voxels = stats["active_voxels"]
+solid_voxels = stats["solid_voxels"]
+total_voxels = stats["total_voxels"]
+total_lattice_updates_per_step = stats["total_lattice_updates_per_step"]
+reference_area = stats["reference_area"]
+reference_area_physical = stats["reference_area_physical"]
+
+# Save initial bc_mask
+filename = os.path.join(output_dir, f"{script_name}_initial_bc_mask")
+try:
+ bc_mask_exporter.to_hdf5(filename, {"bc_mask": sim.bc_mask}, compression="gzip", compression_opts=0)
+ xmf_filename = f"{filename}.xmf"
+ hdf5_basename = f"{script_name}_initial_bc_mask.h5"
+except Exception as e:
+ print(f"Error during initial bc_mask output: {e}")
+wp.synchronize()
+
+
+# Setup momentum transfer
+momentum_transfer = MultiresMomentumTransfer(
+ boundary_conditions[-1],
+ mres_perf_opt=xlb.MresPerfOptimizationType.FUSION_AT_FINEST,
+ compute_backend=compute_backend,
+)
+
+# Print simulation info
+print("\n" + "=" * 50 + "\n")
+print(f"Number of flow passes: {flow_passes}")
+print(f"Calculated iterations: {num_steps:,}")
+print(f"Finest voxel size: {voxel_size} meters")
+print(f"Coarsest voxel size: {delta_x_coarse} meters")
+print(f"Total voxels: {sum(np.count_nonzero(mask) for mask in sparsity_pattern):,}")
+print(f"Total active voxels: {total_voxels:,}")
+print(f"Active voxels per level: {[int(v) for v in active_voxels]}")
+print(f"Solid voxels per level: {[int(v) for v in solid_voxels]}")
+print(f"Total lattice updates per global step: {total_lattice_updates_per_step:,}")
+print(f"Number of refinement levels: {num_levels}")
+print(f"Physical inlet velocity: {wind_speed_mps:.4f} m/s")
+print(f"Lattice velocity (ulb): {wind_speed_lbm}")
+print(f"Computed reference area (bc_mask): {reference_area} lattice units")
+print(f"Physical reference area (bc_mask): {reference_area_physical:.6f} m^2")
+print("\n" + "=" * 50 + "\n")
+
+# -------------------------- Simulation Loop --------------------------
+wp.synchronize()
+start_time = time.time()
+compute_time = 0.0
+steps_since_last_print = 0
+drag_values = []
+
+for step in range(num_steps):
+ step_start = time.time()
+ sim.step()
+ wp.synchronize()
+ compute_time += time.time() - step_start
+ steps_since_last_print += 1
+ if step % print_interval == 0 or step == num_steps - 1:
+ sim.macro(sim.f_0, sim.bc_mask, sim.rho, sim.u, streamId=0)
+ wp.synchronize()
+ cd, cl, drag = compute_force_coefficients(sim, step, momentum_transfer, wind_speed_lbm, reference_area)
+ filename = os.path.join(output_dir, f"{script_name}_{step:04d}")
+ h5exporter.to_hdf5(filename, {"velocity": sim.u, "density": sim.rho}, compression="gzip", compression_opts=0)
+ h5exporter.to_slice_image(
+ filename,
+ {"velocity": sim.u},
+ plane_point=(1, 0, 0),
+ plane_normal=(0, 1, 0),
+ grid_res=2000,
+ bounds=(0.25, 0.75, 0, 0.5),
+ show_axes=False,
+ show_colorbar=False,
+ slice_thickness=delta_x_coarse, # needed when using model units
+ )
+ end_time = time.time()
+ elapsed = end_time - start_time
+ total_lattice_updates = total_lattice_updates_per_step * steps_since_last_print
+ MLUPS = total_lattice_updates / compute_time / 1e6 if compute_time > 0 else 0.0
+ current_flow_passes = step * wind_speed_lbm / grid_shape_x_coarsest
+ remaining_steps = num_steps - step - 1
+ time_remaining = 0.0 if MLUPS == 0 else (total_lattice_updates_per_step * remaining_steps) / (MLUPS * 1e6)
+ hours, rem = divmod(time_remaining, 3600)
+ minutes, seconds = divmod(rem, 60)
+ time_remaining_str = f"{int(hours):02d}h {int(minutes):02d}m {int(seconds):02d}s"
+ percent_complete = (step + 1) / num_steps * 100
+ print(f"Completed step {step}/{num_steps} ({percent_complete:.2f}% complete)")
+ print(f" Flow Passes: {current_flow_passes:.2f}")
+ print(f" Time elapsed: {elapsed:.1f}s, Compute time: {compute_time:.1f}s, ETA: {time_remaining_str}")
+ print(f" MLUPS: {MLUPS:.1f}")
+ print(f" Cd={cd:.3f}, Cl={cl:.3f}, Drag Force (lattice units)={drag:.3f}")
+ start_time = time.time()
+ compute_time = 0.0
+ steps_since_last_print = 0
+ file_output_interval = file_output_interval_pre_crossover if step < crossover_step else file_output_interval_post_crossover
+ if step % file_output_interval == 0 or step == num_steps - 1:
+ sim.macro(sim.f_0, sim.bc_mask, sim.rho, sim.u, streamId=0)
+ filename = os.path.join(output_dir, f"{script_name}_{step:04d}")
+ try:
+ h5exporter.to_hdf5(filename, {"velocity": sim.u, "density": sim.rho}, compression="gzip", compression_opts=0)
+ xmf_filename = f"{filename}.xmf"
+ hdf5_basename = f"{script_name}_{step:04d}.h5"
+ except Exception as e:
+ print(f"Error during file output at step {step}: {e}")
+ wp.synchronize()
+ if step == num_steps - 1:
+ plot_data(x0, output_dir, delta_x_coarse, sim, h5exporter, prefix="Ahmed")
+
+# Save drag and lift data to CSV
+if len(drag_values) > 0:
+ with open(os.path.join(output_dir, "drag_lift.csv"), "w") as fd:
+ fd.write("Step,Cd,Cl\n")
+ for i, (cd, cl) in enumerate(drag_values):
+ fd.write(f"{i * print_interval},{cd},{cl}\n")
+ plot_force_coefficients(drag_values, output_dir, print_interval, script_name)
+
+# Calculate and print average Cd and Cl for the last 50%
+drag_values_array = np.array(drag_values)
+if len(drag_values) > 0:
+ start_index = len(drag_values) // 2
+ last_half = drag_values_array[start_index:, :]
+ avg_cd = np.mean(last_half[:, 0])
+ avg_cl = np.mean(last_half[:, 1])
+ print(f"Average Drag Coefficient (Cd) for last 50%: {avg_cd:.6f}")
+ print(f"Average Lift Coefficient (Cl) for last 50%: {avg_cl:.6f}")
+ print(f"Experimental Drag Coefficient (Cd): {0.3088}")
+ print(f"Error Drag Coefficient (Cd): {((avg_cd - 0.3088) / 0.3088) * 100:.2f}%")
+
+else:
+ print("No drag or lift data collected.")
diff --git a/examples/cfd/rotating_sphere_3d.py b/examples/cfd/rotating_sphere_3d.py
new file mode 100644
index 00000000..7a0b10a4
--- /dev/null
+++ b/examples/cfd/rotating_sphere_3d.py
@@ -0,0 +1,327 @@
+"""
+Rotating sphere 3-D example (single-resolution).
+
+Simulates flow past a sphere rotating about the y-axis using the
+halfway bounce-back BC with a prescribed rotational-velocity profile.
+Computes drag and lift coefficients over time and saves VTK snapshots.
+"""
+
+import xlb
+import trimesh
+import time
+import warp as wp
+import numpy as np
+import jax.numpy as jnp
+from typing import Any
+
+from xlb.compute_backend import ComputeBackend
+from xlb.precision_policy import PrecisionPolicy
+from xlb.grid import grid_factory
+from xlb.operator.stepper import IncompressibleNavierStokesStepper
+from xlb.operator.boundary_condition import (
+ HalfwayBounceBackBC,
+ FullwayBounceBackBC,
+ RegularizedBC,
+ DoNothingBC,
+ HybridBC,
+)
+from xlb.operator.force.momentum_transfer import MomentumTransfer
+from xlb.operator.macroscopic import Macroscopic
+from xlb.utils import save_fields_vtk, save_image
+import matplotlib.pyplot as plt
+from xlb.operator.equilibrium import QuadraticEquilibrium
+from xlb.operator import Operator
+from xlb.velocity_set.velocity_set import VelocitySet
+from xlb.operator.boundary_masker import MeshVoxelizationMethod
+
+# -------------------------- Simulation Setup --------------------------
+
+# Grid parameters
+wp.clear_kernel_cache()
+diam = 32
+grid_size_x, grid_size_y, grid_size_z = 10 * diam, 7 * diam, 7 * diam
+grid_shape = (grid_size_x, grid_size_y, grid_size_z)
+
+# Simulation Configuration
+compute_backend = ComputeBackend.WARP
+precision_policy = PrecisionPolicy.FP32FP32
+
+velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=compute_backend)
+wind_speed = 0.04
+num_steps = 100000
+print_interval = 1000
+post_process_interval = 1000
+
+# Physical Parameters
+Re = 200.0
+visc = wind_speed * diam / Re
+omega = 1.0 / (3.0 * visc + 0.5)
+
+# Rotational speed parameters (see [1] which discusses the problem in terms of 2 non-dimensional parameters: Re and Omega)
+# [1] J. Fluid Mech. (2016), vol. 807, pp. 62–86. c© Cambridge University Press 2016 doi:10.1017/jfm.2016.596
+# \Omega = \omega * D / (2 U_\infty) where Omega is non-dimensional and omega is dimensional.
+rot_rate_nondim = -0.2
+rot_rate = 2.0 * wind_speed * rot_rate_nondim / diam
+
+# Print simulation info
+print("\n" + "=" * 50 + "\n")
+print("Simulation Configuration:")
+print(f"Grid size: {grid_size_x} x {grid_size_y} x {grid_size_z}")
+print(f"Backend: {compute_backend}")
+print(f"Velocity set: {velocity_set}")
+print(f"Precision policy: {precision_policy}")
+print(f"Prescribed velocity: {wind_speed}")
+print(f"Reynolds number: {Re}")
+print(f"Max iterations: {num_steps}")
+print("\n" + "=" * 50 + "\n")
+
+# Initialize XLB
+xlb.init(
+ velocity_set=velocity_set,
+ default_backend=compute_backend,
+ default_precision_policy=precision_policy,
+)
+
+# Create Grid
+grid = grid_factory(grid_shape, compute_backend=compute_backend)
+
+# Bounding box indices
+box = grid.bounding_box_indices()
+box_no_edge = grid.bounding_box_indices(remove_edges=True)
+inlet = box_no_edge["left"]
+outlet = box["right"]
+walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(velocity_set.d)]
+walls = np.unique(np.array(walls), axis=-1).tolist()
+
+# Load the mesh (replace with your own mesh)
+stl_filename = "../stl-files/sphere.stl"
+mesh = trimesh.load_mesh(stl_filename, process=False)
+mesh_vertices = mesh.vertices
+
+# Transform the mesh points to be located in the right position in the wind tunnel
+mesh_vertices -= mesh_vertices.min(axis=0)
+mesh_extents = mesh_vertices.max(axis=0)
+length_phys_unit = mesh_extents.max()
+length_lbm_unit = grid_shape[1] / 7
+dx = length_phys_unit / length_lbm_unit
+mesh_vertices = mesh_vertices / dx
+shift = np.array([grid_shape[0] / 3, (grid_shape[1] - mesh_extents[1] / dx) / 2, (grid_shape[2] - mesh_extents[2] / dx) / 2])
+sphere = mesh_vertices + shift
+diam = np.max(sphere.max(axis=0) - sphere.min(axis=0))
+sphere_cross_section = np.pi * diam**2 / 4.0
+
+
+# Define rotating boundary profile
+def bc_profile():
+ """Build a Warp function returning the rotational wall velocity at a voxel."""
+ dtype = precision_policy.compute_precision.wp_dtype
+ _u_vec = wp.vec(velocity_set.d, dtype=dtype)
+ angular_velocity = _u_vec(0.0, rot_rate, 0.0)
+ origin_np = shift + diam / 2
+ origin_wp = _u_vec(origin_np[0], origin_np[1], origin_np[2])
+
+ @wp.func
+ def bc_profile_warp(index: wp.vec3i):
+ x = dtype(index[0])
+ y = dtype(index[1])
+ z = dtype(index[2])
+ surface_coord = _u_vec(x, y, z) - origin_wp
+ return wp.cross(angular_velocity, surface_coord)
+
+ return bc_profile_warp
+
+
+# Define boundary conditions
+bc_left = RegularizedBC("velocity", prescribed_value=(wind_speed, 0.0, 0.0), indices=inlet)
+bc_do_nothing = DoNothingBC(indices=outlet)
+# bc_sphere = HalfwayBounceBackBC(mesh_vertices=sphere, voxelization_method="ray", profile=bc_profile())
+bc_sphere = HybridBC(
+ bc_method="nonequilibrium_regularized",
+ mesh_vertices=sphere,
+ use_mesh_distance=True,
+ voxelization_method=MeshVoxelizationMethod("RAY"),
+ profile=bc_profile(),
+)
+# Not assining BC for walls makes them periodic.
+boundary_conditions = [bc_left, bc_do_nothing, bc_sphere]
+
+
+# Setup Stepper
+stepper = IncompressibleNavierStokesStepper(
+ grid=grid,
+ boundary_conditions=boundary_conditions,
+ collision_type="KBC",
+)
+
+# Make initializer operator
+from xlb.helper.initializers import CustomInitializer
+
+initializer = CustomInitializer(
+ bc_id=bc_do_nothing.id,
+ constant_velocity_vector=(wind_speed, 0.0, 0.0),
+ velocity_set=velocity_set,
+ precision_policy=precision_policy,
+ compute_backend=compute_backend,
+)
+
+# Prepare Fields
+f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields(initializer=initializer)
+
+
+# -------------------------- Helper Functions --------------------------
+
+
+def plot_coefficient(time_steps, coefficients, prefix="drag"):
+ """
+ Plot the drag coefficient with various moving averages.
+
+ Args:
+ time_steps (list): List of time steps.
+ coefficients (list): List of force coefficients.
+ """
+ # Convert lists to numpy arrays for processing
+ time_steps_np = np.array(time_steps)
+ coefficients_np = np.array(coefficients)
+
+ # Define moving average windows
+ windows = [10, 100, 1000, 10000, 100000]
+ labels = ["MA 10", "MA 100", "MA 1,000", "MA 10,000", "MA 100,000"]
+
+ plt.figure(figsize=(12, 8))
+ plt.plot(time_steps_np, coefficients_np, label="Raw", alpha=0.5)
+
+ for window, label in zip(windows, labels):
+ if len(coefficients_np) >= window:
+ ma = np.convolve(coefficients_np, np.ones(window) / window, mode="valid")
+ plt.plot(time_steps_np[window - 1 :], ma, label=label)
+
+ plt.ylim(-1.0, 1.0)
+ plt.legend()
+ plt.xlabel("Time step")
+ plt.ylabel("Drag coefficient")
+ plt.title("Drag Coefficient Over Time with Moving Averages")
+ plt.savefig(prefix + "_ma.png")
+ plt.close()
+
+
+def post_process(
+ step,
+ f_0,
+ f_1,
+ grid_shape,
+ macro,
+ momentum_transfer,
+ missing_mask,
+ bc_mask,
+ wind_speed,
+ car_cross_section,
+ drag_coefficients,
+ lift_coefficients,
+ time_steps,
+):
+ """Compute macroscopic fields, force coefficients, and save VTK output."""
+ """
+ Post-process simulation data: save fields, compute forces, and plot drag coefficient.
+
+ Args:
+ step (int): Current time step.
+ f_current: Current distribution function.
+ grid_shape (tuple): Shape of the grid.
+ macro: Macroscopic operator object.
+ momentum_transfer: MomentumTransfer operator object.
+ missing_mask: Missing mask from stepper.
+ bc_mask: Boundary condition mask from stepper.
+ wind_speed (float): Prescribed wind speed.
+ car_cross_section (float): Cross-sectional area of the car.
+ drag_coefficients (list): List to store drag coefficients.
+ lift_coefficients (list): List to store lift coefficients.
+ time_steps (list): List to store time steps.
+ """
+ wp.synchronize()
+ # Convert to JAX array if necessary
+ if not isinstance(f_0, jnp.ndarray):
+ f_0_jax = wp.to_jax(f_0)
+ else:
+ f_0_jax = f_0
+
+ # Compute macroscopic quantities
+ rho, u = macro(f_0_jax)
+
+ # Remove boundary cells
+ u = u[:, 1:-1, 1:-1, 1:-1]
+ u_magnitude = jnp.sqrt(u[0] ** 2 + u[1] ** 2 + u[2] ** 2)
+
+ fields = {"ux": u[0], "uy": u[1], "uz": u[2], "u_magnitude": u_magnitude}
+
+ # Save fields in VTK format
+ # save_fields_vtk(fields, timestep=step)
+
+ # Save the u_magnitude slice at the mid y-plane
+ mid_y = grid_shape[1] // 2
+ save_image(fields["u_magnitude"][:, mid_y, :], timestep=step)
+
+ # Compute lift and drag
+ boundary_force = momentum_transfer(f_0, f_1, bc_mask, missing_mask)
+ drag = boundary_force[0] # x-direction
+ lift = boundary_force[2]
+ cd = 2.0 * drag / (wind_speed**2 * car_cross_section)
+ cl = 2.0 * lift / (wind_speed**2 * car_cross_section)
+ print(f"CD={cd}, CL={cl}")
+ drag_coefficients.append(cd)
+ lift_coefficients.append(cl)
+ time_steps.append(step)
+
+ # Plot drag coefficient
+ plot_coefficient(time_steps, drag_coefficients, prefix="drag")
+ plot_coefficient(time_steps, lift_coefficients, prefix="lift")
+
+
+# Setup Momentum Transfer for Force Calculation
+bc_car = boundary_conditions[-1]
+momentum_transfer = MomentumTransfer(bc_car, compute_backend=compute_backend)
+
+# Define Macroscopic Calculation
+macro = Macroscopic(
+ compute_backend=ComputeBackend.JAX,
+ precision_policy=precision_policy,
+ velocity_set=xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=ComputeBackend.JAX),
+)
+
+# Initialize Lists to Store Coefficients and Time Steps
+time_steps = []
+drag_coefficients = []
+lift_coefficients = []
+
+# -------------------------- Simulation Loop --------------------------
+
+start_time = time.time()
+for step in range(num_steps):
+ # Perform simulation step
+ f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, step)
+ f_0, f_1 = f_1, f_0 # Swap the buffers
+
+ # Print progress at intervals
+ if step % print_interval == 0:
+ elapsed_time = time.time() - start_time
+ print(f"Iteration: {step}/{num_steps} | Time elapsed: {elapsed_time:.2f}s")
+ start_time = time.time()
+
+ # Post-process at intervals and final step
+ if (step % post_process_interval == 0) or (step == num_steps - 1):
+ post_process(
+ step,
+ f_0,
+ f_1,
+ grid_shape,
+ macro,
+ momentum_transfer,
+ missing_mask,
+ bc_mask,
+ wind_speed,
+ sphere_cross_section,
+ drag_coefficients,
+ lift_coefficients,
+ time_steps,
+ )
+
+print("Simulation completed successfully.")
diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py
index 074570b8..9909132c 100644
--- a/examples/cfd/windtunnel_3d.py
+++ b/examples/cfd/windtunnel_3d.py
@@ -10,6 +10,7 @@
FullwayBounceBackBC,
RegularizedBC,
ExtrapolationOutflowBC,
+ HybridBC,
)
from xlb.operator.force.momentum_transfer import MomentumTransfer
from xlb.operator.macroscopic import Macroscopic
@@ -18,6 +19,7 @@
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
+from xlb.operator.boundary_masker import MeshVoxelizationMethod
# -------------------------- Simulation Setup --------------------------
@@ -74,6 +76,7 @@
# Load the mesh (replace with your own mesh)
stl_filename = "../stl-files/DrivAer-Notchback.stl"
+voxelization_method = MeshVoxelizationMethod("RAY")
mesh = trimesh.load_mesh(stl_filename, process=False)
mesh_vertices = mesh.vertices
@@ -84,7 +87,15 @@
length_lbm_unit = grid_shape[0] / 4
dx = length_phys_unit / length_lbm_unit
mesh_vertices = mesh_vertices / dx
-shift = np.array([grid_shape[0] / 4, (grid_shape[1] - mesh_extents[1] / dx) / 2, 0.0])
+
+# Depending on the voxelization method, shift_z ensures the bottom ground does not intersect with the voxelized mesh
+# Any smaller shift value would lead to large lift computations due to the initial equilibrium distributions. Bigger
+# values would be fine but leave a gap between surfaces that are supposed to touch.
+if voxelization_method in (MeshVoxelizationMethod("RAY"), MeshVoxelizationMethod("WINDING")):
+ shift_z = 2
+elif voxelization_method in (MeshVoxelizationMethod("AABB"), MeshVoxelizationMethod("AABB_CLOSE", close_voxels=3)):
+ shift_z = 3
+shift = np.array([grid_shape[0] / 4, (grid_shape[1] - mesh_extents[1] / dx) / 2, shift_z])
car_vertices = mesh_vertices + shift
car_cross_section = np.prod(mesh_extents[1:]) / dx**2
@@ -92,15 +103,22 @@
bc_left = RegularizedBC("velocity", prescribed_value=(wind_speed, 0.0, 0.0), indices=inlet)
bc_walls = FullwayBounceBackBC(indices=walls)
bc_do_nothing = ExtrapolationOutflowBC(indices=outlet)
-bc_car = HalfwayBounceBackBC(mesh_vertices=car_vertices)
+bc_car = HalfwayBounceBackBC(mesh_vertices=car_vertices, voxelization_method=voxelization_method)
+# bc_car = HybridBC(bc_method="nonequilibrium_regularized", mesh_vertices=car_vertices,
+# voxelization_method=voxelization_method, use_mesh_distance=True)
boundary_conditions = [bc_walls, bc_left, bc_do_nothing, bc_car]
+# Configure backend options:
+# backend_config = {"occ": neon.SkeletonConfig.OCC.from_string("standard"), "device_list": [0, 1]} if compute_backend == ComputeBackend.NEON else {}
+backend_config = {}
+
# Setup Stepper
stepper = IncompressibleNavierStokesStepper(
grid=grid,
boundary_conditions=boundary_conditions,
collision_type="KBC",
+ backend_config=backend_config,
)
# Prepare Fields
@@ -110,28 +128,28 @@
# -------------------------- Helper Functions --------------------------
-def plot_drag_coefficient(time_steps, drag_coefficients):
+def plot_coefficient(time_steps, coefficients, prefix="drag"):
"""
Plot the drag coefficient with various moving averages.
Args:
time_steps (list): List of time steps.
- drag_coefficients (list): List of drag coefficients.
+ coefficients (list): List of force coefficients.
"""
# Convert lists to numpy arrays for processing
time_steps_np = np.array(time_steps)
- drag_coefficients_np = np.array(drag_coefficients)
+ coefficients_np = np.array(coefficients)
# Define moving average windows
windows = [10, 100, 1000, 10000, 100000]
labels = ["MA 10", "MA 100", "MA 1,000", "MA 10,000", "MA 100,000"]
plt.figure(figsize=(12, 8))
- plt.plot(time_steps_np, drag_coefficients_np, label="Raw", alpha=0.5)
+ plt.plot(time_steps_np, coefficients_np, label="Raw", alpha=0.5)
for window, label in zip(windows, labels):
- if len(drag_coefficients_np) >= window:
- ma = np.convolve(drag_coefficients_np, np.ones(window) / window, mode="valid")
+ if len(coefficients_np) >= window:
+ ma = np.convolve(coefficients_np, np.ones(window) / window, mode="valid")
plt.plot(time_steps_np[window - 1 :], ma, label=label)
plt.ylim(-1.0, 1.0)
@@ -139,7 +157,7 @@ def plot_drag_coefficient(time_steps, drag_coefficients):
plt.xlabel("Time step")
plt.ylabel("Drag coefficient")
plt.title("Drag Coefficient Over Time with Moving Averages")
- plt.savefig("drag_coefficient_ma.png")
+ plt.savefig(prefix + "_ma.png")
plt.close()
@@ -177,7 +195,7 @@ def post_process(
"""
# Convert to JAX array if necessary
if not isinstance(f_0, jnp.ndarray):
- f_0_jax = wp.to_jax(f_0)
+ f_0_jax = to_jax(f_0)
else:
f_0_jax = f_0
@@ -203,12 +221,14 @@ def post_process(
lift = boundary_force[2]
cd = 2.0 * drag / (wind_speed**2 * car_cross_section)
cl = 2.0 * lift / (wind_speed**2 * car_cross_section)
+ print(f"CD={cd}, CL={cl}")
drag_coefficients.append(cd)
lift_coefficients.append(cl)
time_steps.append(step)
# Plot drag coefficient
- plot_drag_coefficient(time_steps, drag_coefficients)
+ plot_coefficient(time_steps, drag_coefficients, prefix="drag")
+ plot_coefficient(time_steps, lift_coefficients, prefix="lift")
# Setup Momentum Transfer for Force Calculation
@@ -221,6 +241,7 @@ def post_process(
precision_policy=precision_policy,
velocity_set=xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=ComputeBackend.JAX),
)
+to_jax = xlb.utils.ToJAX("populations", velocity_set.q, grid_shape)
# Initialize Lists to Store Coefficients and Time Steps
time_steps = []
@@ -237,7 +258,7 @@ def post_process(
# Print progress at intervals
if step % print_interval == 0:
- if compute_backend == ComputeBackend.WARP:
+ if compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON]:
wp.synchronize()
elapsed_time = time.time() - start_time
print(f"Iteration: {step}/{num_steps} | Time elapsed: {elapsed_time:.2f}s")
diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py
index 409a8d59..66945e89 100644
--- a/examples/performance/mlups_3d.py
+++ b/examples/performance/mlups_3d.py
@@ -9,21 +9,109 @@
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC
from xlb.distribute import distribute
+from xlb.operator.macroscopic import Macroscopic
+
# -------------------------- Simulation Setup --------------------------
def parse_arguments():
- parser = argparse.ArgumentParser(description="MLUPS for 3D Lattice Boltzmann Method Simulation (BGK)")
- parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid")
- parser.add_argument("num_steps", type=int, help="Number of timesteps for the simulation")
- parser.add_argument("compute_backend", type=str, help="Backend for the simulation (jax or warp)")
- parser.add_argument("precision", type=str, help="Precision for the simulation (e.g., fp32/fp32)")
- return parser.parse_args()
+ # Define valid options for consistency
+ COMPUTE_BACKENDS = ["neon", "warp", "jax"]
+ PRECISION_OPTIONS = ["fp32/fp32", "fp64/fp64", "fp64/fp32", "fp32/fp16"]
+ VELOCITY_SETS = ["D3Q19", "D3Q27"]
+ COLLISION_MODELS = ["BGK", "KBC"]
+ OCC_OPTIONS = ["standard", "none"]
+
+ parser = argparse.ArgumentParser(
+ description="MLUPS Benchmark for 3D Lattice Boltzmann Method Simulation",
+ epilog=f"""
+Examples:
+ %(prog)s 100 1000 neon fp32/fp32
+ %(prog)s 200 500 neon fp64/fp64 --collision_model KBC --velocity_set D3Q27
+ %(prog)s 150 2000 neon fp32/fp32 --gpu_devices=[0,1,2] --measure_scalability --report
+ %(prog)s 100 1000 neon fp32/fp32 --repetitions 5 --export_final_velocity
+ """,
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+
+ # Positional arguments
+ parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid (e.g., 100)")
+ parser.add_argument("num_steps", type=int, help="Number of timesteps for the simulation (e.g., 1000)")
+ parser.add_argument("compute_backend", type=str, choices=COMPUTE_BACKENDS, help=f"Backend for the simulation ({', '.join(COMPUTE_BACKENDS)})")
+ parser.add_argument("precision", type=str, choices=PRECISION_OPTIONS, help=f"Precision for the simulation ({', '.join(PRECISION_OPTIONS)})")
+
+ # Optional arguments
+ parser.add_argument("--gpu_devices", type=str, default=None, help="CUDA devices to use for Neon backend (e.g., [0,1,2] or [0])")
+ parser.add_argument(
+ "--velocity_set",
+ type=str,
+ default="D3Q19",
+ choices=VELOCITY_SETS,
+ help=f"Lattice velocity set (default: D3Q19, choices: {', '.join(VELOCITY_SETS)})",
+ )
+ parser.add_argument(
+ "--collision_model",
+ type=str,
+ default="BGK",
+ choices=COLLISION_MODELS,
+ help=f"Collision model (default: BGK, choices: {', '.join(COLLISION_MODELS)}, KBC requires D3Q27)",
+ )
+ parser.add_argument(
+ "--occ",
+ type=str,
+ default="standard",
+ choices=OCC_OPTIONS,
+ help=f"Overlapping Communication and Computation strategy (default: standard, choices: {', '.join(OCC_OPTIONS)})",
+ )
+ parser.add_argument("--report", action="store_true", help="Generate Neon performance report")
+ parser.add_argument("--export_final_velocity", action="store_true", help="Export final velocity field to VTI file")
+ parser.add_argument("--measure_scalability", action="store_true", help="Measure performance across different GPU counts")
+ parser.add_argument(
+ "--repetitions", type=int, default=1, metavar="N", help="Number of simulation repetitions for statistical analysis (default: 1)"
+ )
+ args = parser.parse_args()
-def setup_simulation(args):
- compute_backend = ComputeBackend.JAX if args.compute_backend == "jax" else ComputeBackend.WARP
+ # Parse gpu_devices string to list
+ if args.gpu_devices is not None:
+ try:
+ import ast
+
+ args.gpu_devices = ast.literal_eval(args.gpu_devices)
+ if not isinstance(args.gpu_devices, list):
+ args.gpu_devices = [args.gpu_devices] # Handle single integer case
+ except (ValueError, SyntaxError):
+ raise ValueError("Invalid gpu_devices format. Use format like [0,1,2] or [0]")
+
+ # Validate and convert compute backend
+ compute_backend_map = {
+ "jax": ComputeBackend.JAX,
+ "warp": ComputeBackend.WARP,
+ "neon": ComputeBackend.NEON,
+ }
+ compute_backend = compute_backend_map.get(args.compute_backend)
+ if compute_backend is None:
+ raise ValueError(f"Invalid compute backend '{args.compute_backend}'. Use: {', '.join(COMPUTE_BACKENDS)}")
+ args.compute_backend = compute_backend
+
+ # Handle GPU devices for Neon backend
+ if args.compute_backend == ComputeBackend.NEON:
+ if args.gpu_devices is None:
+ print("[INFO] No GPU devices specified. Using default device 0.")
+ args.gpu_devices = [0]
+
+ import neon
+
+ occ_enum = neon.SkeletonConfig.OCC.from_string(args.occ)
+ args.occ_enum = occ_enum # Store the enum for Neon
+ args.occ_display = args.occ # Store the original string for display
+ else:
+ if args.gpu_devices is not None:
+ raise ValueError(f"--gpu_devices can only be used with Neon backend, not {args.compute_backend.name}")
+ args.gpu_devices = [0] # Default for non-Neon backends
+
+ # Checking precision policy
precision_policy_map = {
"fp32/fp32": PrecisionPolicy.FP32FP32,
"fp64/fp64": PrecisionPolicy.FP64FP64,
@@ -32,18 +120,77 @@ def setup_simulation(args):
}
precision_policy = precision_policy_map.get(args.precision)
if precision_policy is None:
- raise ValueError("Invalid precision specified.")
+ raise ValueError(f"Invalid precision '{args.precision}'. Use: {', '.join(PRECISION_OPTIONS)}")
+ args.precision_policy = precision_policy
+
+ # Validate collision model and velocity set compatibility
+ if args.collision_model == "KBC" and args.velocity_set != "D3Q27":
+ raise ValueError("KBC collision model requires D3Q27 velocity set. Use --velocity_set D3Q27")
+
+ if args.velocity_set == "D3Q19":
+ velocity_set = xlb.velocity_set.D3Q19(precision_policy=args.precision_policy, compute_backend=compute_backend)
+ elif args.velocity_set == "D3Q27":
+ velocity_set = xlb.velocity_set.D3Q27(precision_policy=args.precision_policy, compute_backend=compute_backend)
+ args.velocity_set = velocity_set
+
+ print_args(args)
+
+ return args
+
+
+def print_args(args):
+ """Print simulation configuration in a clean, organized format"""
+ print("\n" + "=" * 70)
+ print(" SIMULATION CONFIGURATION")
+ print("=" * 70)
+
+ # Grid and simulation parameters
+ print("GRID & SIMULATION:")
+ print(f" Grid Size: {args.cube_edge}³ ({args.cube_edge:,} × {args.cube_edge:,} × {args.cube_edge:,})")
+ print(f" Total Lattice Points: {args.cube_edge**3:,}")
+ print(f" Time Steps: {args.num_steps:,}")
+ print(f" Repetitions: {args.repetitions}")
+ # Computational settings
+ print("\nCOMPUTATIONAL SETTINGS:")
+ print(f" Compute Backend: {args.compute_backend.name}")
+ print(f" Precision Policy: {args.precision}")
+ print(f" Velocity Set: {args.velocity_set.__class__.__name__}")
+ print(f" Collision Model: {args.collision_model}")
+
+ # Backend-specific settings
+ if args.compute_backend.name == "NEON":
+ print("\nNEON BACKEND SETTINGS:")
+ print(f" GPU Devices: {args.gpu_devices}")
+ print(f" OCC Strategy: {args.occ_display}")
+
+ # Output options
+ print("\nOUTPUT OPTIONS:")
+ print(f" Generate Report: {'Yes' if args.report else 'No'}")
+ print(f" Measure Scalability: {'Yes' if args.measure_scalability else 'No'}")
+ print(f" Export Velocity: {'Yes' if args.export_final_velocity else 'No'}")
+
+ print("=" * 70)
+ print("Starting simulation...\n")
+
+
+def init_xlb(args):
xlb.init(
- velocity_set=xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=compute_backend),
- default_backend=compute_backend,
- default_precision_policy=precision_policy,
+ velocity_set=args.velocity_set,
+ default_backend=args.compute_backend,
+ default_precision_policy=args.precision_policy,
)
- return compute_backend, precision_policy
+ options = None
+ if args.compute_backend == ComputeBackend.NEON:
+ neon_options = {"occ": args.occ_enum, "device_list": args.gpu_devices}
+ options = neon_options
+ return args.compute_backend, args.precision_policy, options
-def run_simulation(compute_backend, precision_policy, grid_shape, num_steps):
- grid = grid_factory(grid_shape)
+def run_simulation(
+ compute_backend, precision_policy, grid_shape, num_steps, options, export_final_velocity, repetitions, num_devices, collision_model
+):
+ grid = grid_factory(grid_shape, backend_config=options)
box = grid.bounding_box_indices()
box_no_edge = grid.bounding_box_indices(remove_edges=True)
@@ -59,7 +206,8 @@ def run_simulation(compute_backend, precision_policy, grid_shape, num_steps):
stepper = IncompressibleNavierStokesStepper(
grid=grid,
boundary_conditions=boundary_conditions,
- collision_type="BGK",
+ collision_type=collision_model,
+ backend_config=options,
)
# Distribute if using JAX
@@ -74,14 +222,44 @@ def run_simulation(compute_backend, precision_policy, grid_shape, num_steps):
omega = 1.0
f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields()
- start_time = time.time()
- for i in range(num_steps):
+ warmup_iterations = 10
+ # Warp-up iterations
+ for i in range(warmup_iterations):
f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, i)
f_0, f_1 = f_1, f_0
wp.synchronize()
- elapsed_time = time.time() - start_time
+ export_num_steps = warmup_iterations
+
+ elapsed_time_list = []
+ for i in range(repetitions):
+ start_time = time.time()
+ for i in range(num_steps):
+ f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, i)
+ f_0, f_1 = f_1, f_0
+ wp.synchronize()
+ elapsed_time = time.time() - start_time
+ elapsed_time_list.append(elapsed_time)
+ export_num_steps += num_steps
+
+ # Define Macroscopic Calculation
+ macro = Macroscopic(
+ compute_backend=compute_backend,
+ precision_policy=precision_policy,
+ velocity_set=xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=compute_backend),
+ )
+
+ if compute_backend == ComputeBackend.NEON:
+ if export_final_velocity:
+ rho = grid.create_field(cardinality=1, dtype=precision_policy.store_precision)
+ u = grid.create_field(cardinality=3, dtype=precision_policy.store_precision)
- return elapsed_time
+ macro(f_0, rho, u)
+ wp.synchronize()
+ u.update_host(0)
+ wp.synchronize()
+ u.export_vti(f"mlups_3d_size_{grid_shape[0]}_dev_{num_devices}_step_{export_num_steps}.vti", "u")
+
+ return elapsed_time_list
def calculate_mlups(cube_edge, num_steps, elapsed_time):
@@ -90,15 +268,293 @@ def calculate_mlups(cube_edge, num_steps, elapsed_time):
return mlups
+def print_summary_with_stats(args, stats):
+ """Print comprehensive simulation summary with statistics from multiple repetitions"""
+ total_lattice_points = args.cube_edge**3
+ total_lattice_updates = total_lattice_points * args.num_steps
+
+ mean_mlups = stats["mean_mlups"]
+ std_mlups = stats["std_dev_mlups"]
+ mean_elapsed_time = stats["mean_elapsed_time"]
+ std_elapsed_time = stats["std_dev_elapsed_time"]
+
+ print("\n\n\n" + "=" * 70)
+ print(" SIMULATION SUMMARY")
+ print("=" * 70)
+
+ # Simulation Parameters
+ print("SIMULATION PARAMETERS:")
+ print("-" * 25)
+ print(f" Grid Size: {args.cube_edge}³ ({args.cube_edge:,} × {args.cube_edge:,} × {args.cube_edge:,})")
+ print(f" Total Lattice Points: {total_lattice_points:,}")
+ print(f" Time Steps: {args.num_steps:,}")
+ print(f" Total Lattice Updates: {total_lattice_updates:,}")
+ print(f" Repetitions: {args.repetitions}")
+ print(f" Compute Backend: {args.compute_backend.name}")
+ print(f" Precision Policy: {args.precision}")
+ print(f" Velocity Set: {args.velocity_set.__class__.__name__}")
+ print(f" Collision Model: {args.collision_model}")
+ print(f" Generate Report: {'Yes' if args.report else 'No'}")
+ print(f" Measure Scalability: {'Yes' if args.measure_scalability else 'No'}")
+
+ if args.compute_backend.name == "NEON":
+ print(f" GPU Devices: {args.gpu_devices}")
+ occ_display = args.occ_display
+ print(f" OCC Strategy: {occ_display}")
+
+ print()
+
+ # Raw Data (if multiple repetitions)
+ if args.repetitions > 1:
+ print("RAW MEASUREMENT DATA:")
+ print("-" * 21)
+ print(f"{'Run':<6} {'Elapsed Time (s)':<18} {'MLUPs':<12} {'Time/Step (ms)':<15}")
+ print("-" * 53)
+
+ raw_elapsed_times = stats["raw_elapsed_times"]
+ raw_mlups = stats["raw_mlups"]
+
+ for i, (elapsed_time, mlups) in enumerate(zip(raw_elapsed_times, raw_mlups)):
+ time_per_step = elapsed_time / args.num_steps * 1000
+ print(f"{i + 1:<6} {elapsed_time:<18.3f} {mlups:<12.2f} {time_per_step:<15.3f}")
+
+ print("-" * 53)
+ print()
+
+ # Performance Results (Statistical Summary)
+ print("PERFORMANCE RESULTS:")
+ print("-" * 20)
+ if args.repetitions > 1:
+ print(f" Time in main loop: {mean_elapsed_time:.3f} ± {std_elapsed_time:.3f} seconds")
+ print(f" MLUPs: {mean_mlups:.2f} ± {std_mlups:.2f}")
+ print(f" Time per LBM step: {mean_elapsed_time / args.num_steps * 1000:.3f} ± {std_elapsed_time / args.num_steps * 1000:.3f} ms")
+ else:
+ print(f" Time in main loop: {mean_elapsed_time:.3f} seconds")
+ print(f" MLUPs: {mean_mlups:.2f}")
+ print(f" Time per LBM step: {mean_elapsed_time / args.num_steps * 1000:.3f} ms")
+
+ if args.compute_backend.name == "NEON" and len(args.gpu_devices) > 1:
+ mlups_per_gpu = mean_mlups / len(args.gpu_devices)
+ if args.repetitions > 1:
+ mlups_per_gpu_std = std_mlups / len(args.gpu_devices)
+ print(f" MLUPs per GPU: {mlups_per_gpu:.2f} ± {mlups_per_gpu_std:.2f}")
+ else:
+ print(f" MLUPs per GPU: {mlups_per_gpu:.2f}")
+
+ print("=" * 70)
+
+
+def print_scalability_summary(args, stats_list):
+ """Print comprehensive scalability summary with MLUPs statistics for different GPU counts"""
+ total_lattice_points = args.cube_edge**3
+ total_lattice_updates = total_lattice_points * args.num_steps
+
+ print("\n\n\n" + "=" * 95)
+ print(" SCALABILITY ANALYSIS")
+ print("=" * 95)
+
+ # Simulation Parameters
+ print("SIMULATION PARAMETERS:")
+ print("-" * 25)
+ print(f" Grid Size: {args.cube_edge}³ ({args.cube_edge:,} × {args.cube_edge:,} × {args.cube_edge:,})")
+ print(f" Total Lattice Points: {total_lattice_points:,}")
+ print(f" Time Steps: {args.num_steps:,}")
+ print(f" Total Lattice Updates: {total_lattice_updates:,}")
+ print(f" Repetitions: {args.repetitions}")
+ print(f" Compute Backend: {args.compute_backend.name}")
+ print(f" Precision Policy: {args.precision}")
+ print(f" Velocity Set: {args.velocity_set.__class__.__name__}")
+ print(f" Collision Model: {args.collision_model}")
+
+ if args.compute_backend.name == "NEON":
+ occ_display = args.occ_display
+ print(f" OCC Strategy: {occ_display}")
+ print(f" Available GPU Devices: {args.gpu_devices}")
+
+ print()
+
+ # Extract mean MLUPs for calculations
+ mlups_means = [stats["mean_mlups"] for stats in stats_list]
+ baseline_mlups = mlups_means[0] if mlups_means else 0
+
+ # Scalability Results
+ print("SCALABILITY RESULTS:")
+ print("-" * 20)
+ print(f"{'GPUs':<6} {'MLUPs (mean±std)':<18} {'Speedup':<10} {'Efficiency':<12} {'MLUPs/GPU':<12}")
+ print("-" * 68)
+
+ for i, stats in enumerate(stats_list):
+ num_gpus = i + 1
+ mean_mlups = stats["mean_mlups"]
+ std_mlups = stats["std_dev_mlups"]
+ speedup = mean_mlups / baseline_mlups if baseline_mlups > 0 else 0
+ efficiency = (speedup / num_gpus) if num_gpus > 0 else 0
+ mlups_per_gpu = mean_mlups / num_gpus if num_gpus > 0 else 0
+
+ # Format MLUPs with standard deviation
+ if args.repetitions > 1:
+ mlups_str = f"{mean_mlups:.2f}±{std_mlups:.2f}"
+ else:
+ mlups_str = f"{mean_mlups:.2f}"
+
+ print(f"{num_gpus:<6} {mlups_str:<18} {speedup:<10.2f} {efficiency:<11.3f} {mlups_per_gpu:<12.2f}")
+
+ print("-" * 68)
+
+ # Summary Statistics
+ if len(stats_list) > 1:
+ max_mlups = max(mlups_means)
+ max_mlups_idx = mlups_means.index(max_mlups)
+ max_speedup = max_mlups / baseline_mlups if baseline_mlups > 0 else 0
+ best_efficiency_idx = 0
+ best_efficiency = 0.0
+
+ for i, mean_mlups in enumerate(mlups_means):
+ num_gpus = i + 1
+ speedup = mean_mlups / baseline_mlups if baseline_mlups > 0 else 0
+ efficiency = (speedup / num_gpus) if num_gpus > 0 else 0
+ if efficiency > best_efficiency:
+ best_efficiency = efficiency
+ best_efficiency_idx = i
+
+ print()
+ print("SUMMARY STATISTICS:")
+ print("-" * 19)
+ print(f" Best Performance: {max_mlups:.2f} MLUPs ({max_mlups_idx + 1} GPUs)")
+ if args.repetitions > 1:
+ max_std = stats_list[max_mlups_idx]["std_dev_mlups"]
+ print(f" Performance Std Dev: ±{max_std:.2f} MLUPs")
+ print(f" Maximum Speedup: {max_speedup:.2f}x")
+ print(f" Best Efficiency: {best_efficiency:.3f} ({best_efficiency_idx + 1} GPUs)")
+ print(f" Scalability Range: 1-{len(stats_list)} GPUs")
+
+ print("=" * 95)
+
+
+def report(args, stats):
+ import neon
+ import sys
+
+ report = neon.Report("LBM MLUPS LDC")
+
+ # Save the full command line
+ command_line = " ".join(sys.argv)
+ report.add_member("command_line", command_line)
+
+ report.add_member("velocity_set", args.velocity_set.__class__.__name__)
+ report.add_member("compute_backend", args.compute_backend.name)
+ report.add_member("precision_policy", args.precision)
+ report.add_member("collision_model", args.collision_model)
+ report.add_member("grid_size", args.cube_edge)
+ report.add_member("num_steps", args.num_steps)
+ report.add_member("repetitions", args.repetitions)
+
+ # Statistical measures
+ report.add_member("mean_elapsed_time", stats["mean_elapsed_time"])
+ report.add_member("mean_mlups", stats["mean_mlups"])
+ report.add_member("std_dev_elapsed_time", stats["std_dev_elapsed_time"])
+ report.add_member("std_dev_mlups", stats["std_dev_mlups"])
+
+ # Raw data vectors (if multiple repetitions)
+ if args.repetitions > 1:
+ report.add_member_vector("raw_elapsed_times", stats["raw_elapsed_times"])
+ report.add_member_vector("raw_mlups", stats["raw_mlups"])
+
+ # Legacy fields for backwards compatibility
+ report.add_member("elapsed_time", stats["mean_elapsed_time"])
+ report.add_member("mlups", stats["mean_mlups"])
+
+ report.add_member("occ", args.occ_display)
+ report.add_member_vector("gpu_devices", args.gpu_devices)
+ report.add_member("num_devices", len(args.gpu_devices))
+ report.add_member("measure_scalability", args.measure_scalability)
+
+ # Generate report name following the convention: script_name + parameters
+ report_name = "mlups_3d"
+ report_name += f"_velocity_set_{args.velocity_set.__class__.__name__}"
+ report_name += f"_compute_backend_{args.compute_backend.name}"
+ report_name += f"_precision_policy_{args.precision.replace('/', '_')}"
+ report_name += f"_collision_model_{args.collision_model}"
+ report_name += f"_grid_size_{args.cube_edge}"
+ report_name += f"_num_steps_{args.num_steps}"
+
+ if args.compute_backend.name == "NEON":
+ report_name += f"_occ_{args.occ_display}"
+ report_name += f"_num_devices_{len(args.gpu_devices)}"
+
+ if args.repetitions > 1:
+ report_name += f"_repetitions_{args.repetitions}"
+
+ report.write(report_name, True)
+
+
# -------------------------- Simulation Loop --------------------------
-args = parse_arguments()
-compute_backend, precision_policy = setup_simulation(args)
-grid_shape = (args.cube_edge, args.cube_edge, args.cube_edge)
-elapsed_time = run_simulation(compute_backend=compute_backend, precision_policy=precision_policy, grid_shape=grid_shape, num_steps=args.num_steps)
+def benchmark(args):
+ compute_backend, precision_policy, options = init_xlb(args)
+ grid_shape = (args.cube_edge, args.cube_edge, args.cube_edge)
+
+ elapsed_time_list = []
+ mlups_list = []
+ elapsed_time_list = run_simulation(
+ compute_backend=compute_backend,
+ precision_policy=precision_policy,
+ grid_shape=grid_shape,
+ num_steps=args.num_steps,
+ options=options,
+ export_final_velocity=args.export_final_velocity,
+ repetitions=args.repetitions,
+ num_devices=len(args.gpu_devices),
+ collision_model=args.collision_model,
+ )
+
+ for elapsed_time in elapsed_time_list:
+ mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time)
+ mlups_list.append(mlups)
+
+ mean_mlups = np.mean(mlups_list)
+ std_dev_mlups = np.std(mlups_list)
+ mean_elapsed_time = np.mean(elapsed_time_list)
+ std_dev_elapsed_time = np.std(elapsed_time_list)
+
+ stats = {
+ "mean_mlups": mean_mlups,
+ "std_dev_mlups": std_dev_mlups,
+ "mean_elapsed_time": mean_elapsed_time,
+ "std_dev_elapsed_time": std_dev_elapsed_time,
+ "num_devices": len(args.gpu_devices),
+ "raw_mlups": mlups_list,
+ "raw_elapsed_times": elapsed_time_list,
+ }
+ # Generate report if requested
+ if args.report:
+ report(args, stats)
+ print("Report generated successfully.")
+
+ return stats
+
+
+def main():
+ args = parse_arguments()
+ if not args.measure_scalability:
+ stats = benchmark(args)
+ # For single run, print_summary expects individual values with additional stats
+ print_summary_with_stats(args, stats)
+ return
+
+ stats_list = []
+ for num_devices in range(1, len(args.gpu_devices) + 1):
+ import copy
+
+ args_copy = copy.deepcopy(args)
+ args_copy.gpu_devices = args_copy.gpu_devices[:num_devices]
+ stats = benchmark(args_copy)
+ stats_list.append(stats)
+
+ # Print comprehensive scalability analysis
+ print_scalability_summary(args, stats_list)
-mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time)
-print(f"Simulation completed in {elapsed_time:.2f} seconds")
-print(f"MLUPs: {mlups:.2f}")
+if __name__ == "__main__":
+ main()
diff --git a/examples/performance/mlups_3d_multires.py b/examples/performance/mlups_3d_multires.py
new file mode 100644
index 00000000..ba71985b
--- /dev/null
+++ b/examples/performance/mlups_3d_multires.py
@@ -0,0 +1,402 @@
+"""
+MLUPS benchmark for the multi-resolution LBM solver.
+
+Runs a lid-driven cavity simulation on a multi-resolution Neon grid and
+reports the Equivalent Million Lattice Updates Per Second (EMLUPS).
+
+Usage::
+
+ python mlups_3d_multires.py neon \\
+ [options]
+
+Example::
+
+ python mlups_3d_multires.py 100 1000 neon fp32/fp32 2 NAIVE_COLLIDE_STREAM
+"""
+
+import xlb
+import argparse
+import time
+import warp as wp
+import numpy as np
+import neon
+
+from xlb.compute_backend import ComputeBackend
+from xlb.precision_policy import PrecisionPolicy
+from xlb.grid import multires_grid_factory
+from xlb.operator.stepper import MultiresIncompressibleNavierStokesStepper
+from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC
+from xlb.mres_perf_optimization_type import MresPerfOptimizationType
+
+
+def parse_arguments():
+ """Parse and validate command-line arguments."""
+ parser = argparse.ArgumentParser(
+ description="MLUPS for 3D Lattice Boltzmann Method Simulation with Multi-resolution Grid",
+ epilog="""
+Examples:
+ %(prog)s 100 1000 neon fp32/fp32 2 NAIVE_COLLIDE_STREAM
+ %(prog)s 200 500 neon fp64/fp64 3 FUSION_AT_FINEST --report
+ %(prog)s 50 2000 neon fp32/fp16 2 NAIVE_COLLIDE_STREAM --export_final_velocity
+
+Valid values:
+ compute_backend: neon
+ precision: fp32/fp32, fp64/fp64, fp64/fp32, fp32/fp16
+ mres_perf_opt: NAIVE_COLLIDE_STREAM, FUSION_AT_FINEST
+ velocity_set: D3Q19, D3Q27
+ collision_model: BGK, KBC
+ """,
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+
+ # Positional arguments
+ parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid (e.g., 100)")
+ parser.add_argument("num_steps", type=int, help="Number of timesteps for the simulation (e.g., 1000)")
+ parser.add_argument("compute_backend", type=str, help="Backend for the simulation (neon)")
+ parser.add_argument("precision", type=str, help="Precision for the simulation (fp32/fp32, fp64/fp64, fp64/fp32, fp32/fp16)")
+ parser.add_argument("num_levels", type=int, help="Number of levels for the multiresolution grid (e.g., 2)")
+ parser.add_argument(
+ "mres_perf_opt",
+ type=MresPerfOptimizationType.from_string,
+ help="Multi-resolution performance optimization strategy (NAIVE_COLLIDE_STREAM, FUSION_AT_FINEST)",
+ )
+
+ # Optional arguments
+ parser.add_argument("--num_devices", type=int, default=0, help="Number of devices for the simulation (default: 0)")
+ parser.add_argument("--velocity_set", type=str, default="D3Q19", help="Lattice type: D3Q19 or D3Q27 (default: D3Q19)")
+ parser.add_argument("--collision_model", type=str, default="BGK", help="Collision model: BGK or KBC (default: BGK)")
+
+ parser.add_argument("--report", action="store_true", help="Generate a neon report file (default: disabled)")
+ parser.add_argument("--export_final_velocity", action="store_true", help="Export the final velocity field to a vti file (default: disabled)")
+
+ try:
+ args = parser.parse_args()
+ except SystemExit:
+ # Re-raise with custom message
+ print("\n" + "=" * 60)
+ print("USAGE EXAMPLES:")
+ print("=" * 60)
+ print("python mlups_3d_multires.py 100 1000 neon fp32/fp32 2 NAIVE_COLLIDE_STREAM")
+ print("python mlups_3d_multires.py 200 500 neon fp64/fp64 3 FUSION_AT_FINEST --report")
+ print("\nVALID VALUES:")
+ print(" compute_backend: neon")
+ print(" precision: fp32/fp32, fp64/fp64, fp64/fp32, fp32/fp16")
+ print(" mres_perf_opt: NAIVE_COLLIDE_STREAM, FUSION_AT_FINEST")
+ print(" velocity_set: D3Q19, D3Q27")
+ print(" collision_model: BGK, KBC")
+ print("=" * 60)
+ raise
+
+ print_args(args)
+
+ if args.compute_backend != "neon":
+ raise ValueError("Invalid compute backend specified. Use 'neon' which supports multi-resolution!")
+
+ if args.collision_model not in ["BGK", "KBC"]:
+ raise ValueError("Invalid collision model specified. Use 'BGK' or 'KBC'.")
+
+ return args
+
+
+def print_args(args):
+ """Print the simulation configuration to stdout."""
+ # Print simulation configuration
+ print("=" * 60)
+ print(" 3D LATTICE BOLTZMANN SIMULATION CONFIG")
+ print("=" * 60)
+ print(f"Grid Size: {args.cube_edge}³ ({args.cube_edge:,} × {args.cube_edge:,} × {args.cube_edge:,})")
+ print(f"Total Lattice Points: {args.cube_edge**3:,}")
+ print(f"Time Steps: {args.num_steps:,}")
+ print(f"Number Levels: {args.num_levels}")
+ print(f"Compute Backend: {args.compute_backend}")
+ print(f"Precision Policy: {args.precision}")
+ print(f"Velocity Set: {args.velocity_set}")
+ print(f"Collision Model: {args.collision_model}")
+ print(f"Mres Perf Opt: {args.mres_perf_opt}")
+ print(f"Generate Report: {'Yes' if args.report else 'No'}")
+ print(f"Export Velocity: {'Yes' if args.export_final_velocity else 'No'}")
+
+ print("=" * 60)
+ print("Starting simulation...")
+ print()
+
+
+def setup_simulation(args):
+ """Initialize XLB globals (velocity set, backend, precision) from CLI args.
+
+ Returns
+ -------
+ VelocitySet
+ The configured lattice velocity set.
+ """
+ compute_backend = None
+ if args.compute_backend == "neon":
+ compute_backend = ComputeBackend.NEON
+ else:
+ raise ValueError("Invalid compute backend specified. Use 'neon' which supports multi-resolution!")
+
+ precision_policy_map = {
+ "fp32/fp32": PrecisionPolicy.FP32FP32,
+ "fp64/fp64": PrecisionPolicy.FP64FP64,
+ "fp64/fp32": PrecisionPolicy.FP64FP32,
+ "fp32/fp16": PrecisionPolicy.FP32FP16,
+ }
+ precision_policy = precision_policy_map.get(args.precision)
+ if precision_policy is None:
+ raise ValueError("Invalid precision")
+
+ velocity_set = None
+ if args.velocity_set == "D3Q19":
+ velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=compute_backend)
+ elif args.velocity_set == "D3Q27":
+ velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=compute_backend)
+ if velocity_set is None:
+ raise ValueError("Invalid velocity set")
+
+ xlb.init(
+ velocity_set=velocity_set,
+ default_backend=compute_backend,
+ default_precision_policy=precision_policy,
+ )
+
+ return velocity_set
+
+
+def ldc_multires_setup(grid_shape, velocity_set, num_levels):
+ """Lid-driven cavity with refinement peeling inward from the boundary.
+
+ Each finer level covers only the outermost shell of its parent,
+ concentrating resolution near the walls.
+
+ Parameters
+ ----------
+ grid_shape : tuple of int
+ Domain size at the finest level.
+ velocity_set : VelocitySet
+ Lattice velocity set.
+ num_levels : int
+ Number of refinement levels.
+
+ Returns
+ -------
+ grid : NeonMultiresGrid
+ lid : list of index arrays (per level)
+ walls : list of index arrays (per level)
+ """
+
+ def peel(dim, idx, peel_level, outwards):
+ if outwards:
+ xIn = idx.x <= peel_level or idx.x >= dim.x - 1 - peel_level
+ yIn = idx.y <= peel_level or idx.y >= dim.y - 1 - peel_level
+ zIn = idx.z <= peel_level or idx.z >= dim.z - 1 - peel_level
+ return xIn or yIn or zIn
+ else:
+ xIn = idx.x >= peel_level and idx.x <= dim.x - 1 - peel_level
+ yIn = idx.y >= peel_level and idx.y <= dim.y - 1 - peel_level
+ zIn = idx.z >= peel_level and idx.z <= dim.z - 1 - peel_level
+ return xIn and yIn and zIn
+
+ dim = neon.Index_3d(grid_shape[0], grid_shape[1], grid_shape[2])
+
+ def get_peeled_np(level, width):
+ divider = 2**level
+ m = neon.Index_3d(dim.x // divider, dim.y // divider, dim.z // divider)
+ if level == 0:
+ m = dim
+
+ mask = np.zeros((m.x, m.y, m.z), dtype=int)
+ mask = np.ascontiguousarray(mask, dtype=np.int32)
+ # loop over all the elements in mask and set to one any that have x=0 or y=0 or z=0
+ for i in range(m.x):
+ for j in range(m.y):
+ for k in range(m.z):
+ idx = neon.Index_3d(i, j, k)
+ val = 0
+ if peel(m, idx, width, True):
+ val = 1
+ mask[i, j, k] = val
+ return mask
+
+ def get_levels(num_levels):
+ levels = []
+ for i in range(num_levels - 1):
+ l = get_peeled_np(i, 8)
+ levels.append(l)
+ lastLevel = num_levels - 1
+ divider = 2**lastLevel
+ m = neon.Index_3d(dim.x // divider + 1, dim.y // divider + 1, dim.z // divider + 1)
+ lastLevel = np.ones((m.x, m.y, m.z), dtype=int)
+ lastLevel = np.ascontiguousarray(lastLevel, dtype=np.int32)
+ levels.append(lastLevel)
+ return levels
+
+ levels = get_levels(num_levels)
+
+ grid = multires_grid_factory(
+ grid_shape,
+ velocity_set=velocity_set,
+ sparsity_pattern_list=levels,
+ sparsity_pattern_origins=[neon.Index_3d(0, 0, 0)] * len(levels),
+ )
+
+ box = grid.bounding_box_indices()
+ box_no_edge = grid.bounding_box_indices(remove_edges=True)
+ lid = box_no_edge["top"]
+ walls = [box["bottom"][i] + box["left"][i] + box["right"][i] + box["front"][i] + box["back"][i] for i in range(len(grid.shape))]
+ walls = np.unique(np.array(walls), axis=-1).tolist()
+ # convert bc indices to a list of list, where the first entry of the list corresponds to the finest level
+ lid = [lid] + [[] for _ in range(num_levels - 1)]
+ walls = [walls] + [[] for _ in range(num_levels - 1)]
+ return grid, lid, walls
+
+
+def run(
+ velocity_set,
+ grid_shape,
+ num_steps,
+ num_levels,
+ collision_model,
+ export_final_velocity,
+ mres_perf_opt,
+):
+ """Set up and execute the benchmark simulation.
+
+ Returns
+ -------
+ dict
+ ``{"time": elapsed_seconds, "num_levels": int}``
+ """
+ # Create grid and setup boundary conditions
+ grid, lid, walls = ldc_multires_setup(grid_shape, velocity_set, num_levels)
+
+ prescribed_vel = 0.1
+ boundary_conditions = [
+ EquilibriumBC(rho=1.0, u=(prescribed_vel, 0.0, 0.0), indices=lid),
+ FullwayBounceBackBC(indices=walls),
+ ]
+
+ # Problem parameters
+ Re = 5000.0
+ clength = grid_shape[0] - 1
+ visc = prescribed_vel * clength / Re
+ omega_finest = 1.0 / (3.0 * visc + 0.5)
+
+ # Define a multi-resolution simulation manager
+ sim = xlb.helper.MultiresSimulationManager(
+ omega_finest=omega_finest,
+ grid=grid,
+ boundary_conditions=boundary_conditions,
+ collision_type=collision_model,
+ mres_perf_opt=mres_perf_opt,
+ )
+
+ # sim.export_macroscopic("Initial_")
+ # sim.step()
+
+ print("start timing")
+ wp.synchronize()
+ start_time = time.time()
+
+ if num_levels == 1:
+ num_steps = num_steps // 2
+
+ for i in range(num_steps):
+ sim.step()
+ # if i % 1000 == 0:
+ # print(f"step {i}")
+ # sim.export_macroscopic("u_lid_driven_cavity_")
+ wp.synchronize()
+ t = time.time() - start_time
+ print(f"Timing {t}")
+
+ if export_final_velocity:
+ sim.export_macroscopic("u_lid_driven_cavity_")
+
+ # sim.export_macroscopic("u_lid_driven_cavity_")
+ num_levels = grid.count_levels
+ return {"time": t, "num_levels": num_levels}
+
+
+def calculate_mlups(cube_edge, num_steps, elapsed_time, num_levels):
+ """Compute the Equivalent Million Lattice Updates Per Second (EMLUPS).
+
+ The metric accounts for the fact that finer levels are stepped
+ 2^(num_levels-1) times per coarsest-level step.
+
+ Returns
+ -------
+ dict
+ ``{"EMLUPS": float, "finer_steps": int}``
+ """
+ num_step_finer = num_steps * 2 ** (num_levels - 1)
+ total_lattice_updates = cube_edge**3 * num_step_finer
+ mlups = (total_lattice_updates / elapsed_time) / 1e6
+ return {"EMLUPS": mlups, "finer_steps": num_step_finer}
+
+ # # remove boundary cells
+ # rho = rho[:, 1:-1, 1:-1, 1:-1]
+ # u = u[:, 1:-1, 1:-1, 1:-1]
+ # u_magnitude = (u[0] ** 2 + u[1] ** 2) ** 0.5
+ #
+ # fields = {"rho": rho[0], "u_x": u[0], "u_y": u[1], "u_magnitude": u_magnitude}
+ #
+ # # save_fields_vtk(fields, timestep=i, prefix="lid_driven_cavity")
+ # ny=fields["u_magnitude"].shape[1]
+ # from xlb.utils import save_image
+ # save_image(fields["u_magnitude"][:, ny//2, :], timestep=i, prefix="lid_driven_cavity")
+
+
+def generate_report(args, stats, mlups_stats):
+ """Generate a neon report file with simulation parameters and results"""
+ import neon
+ import sys
+
+ report = neon.Report("LBM MLUPS Multiresolution LDC")
+
+ # Save the full command line
+ command_line = " ".join(sys.argv)
+ report.add_member("command_line", command_line)
+
+ report.add_member("velocity_set", args.velocity_set)
+ report.add_member("compute_backend", args.compute_backend)
+ report.add_member("precision_policy", args.precision)
+ report.add_member("collision_model", args.collision_model)
+ report.add_member("grid_size", args.cube_edge)
+ report.add_member("num_steps", args.num_steps)
+ report.add_member("num_levels", stats["num_levels"])
+ report.add_member("finer_steps", mlups_stats["finer_steps"])
+
+ # Performance metrics
+ report.add_member("elapsed_time", stats["time"])
+ report.add_member("emlups", mlups_stats["EMLUPS"])
+
+ report_name = f"mlups_3d_multires_size_{args.cube_edge}_levels_{stats['num_levels']}"
+ report.write(report_name, True)
+ print("Report generated successfully.")
+
+
+def main():
+ args = parse_arguments()
+ velocity_set = setup_simulation(args)
+ grid_shape = (args.cube_edge, args.cube_edge, args.cube_edge)
+ stats = run(
+ velocity_set, grid_shape, args.num_steps, args.num_levels, args.collision_model, args.export_final_velocity, mres_perf_opt=args.mres_perf_opt
+ )
+ mlups_stats = calculate_mlups(args.cube_edge, args.num_steps, stats["time"], stats["num_levels"])
+
+ print(f"Simulation completed in {stats['time']:.2f} seconds")
+ print(f"Number of levels {stats['num_levels']}")
+ print(f"Cube edge {args.cube_edge}")
+ print(f"Coarse Iterations {args.num_steps}")
+ finer_steps = mlups_stats["finer_steps"]
+ print(f"Fine Iterations {finer_steps}")
+ EMLUPS = mlups_stats["EMLUPS"]
+ print(f"EMLUPs: {EMLUPS:.2f}")
+
+ # Generate report if requested
+ if args.report:
+ generate_report(args, stats, mlups_stats)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/requirements.txt b/requirements.txt
index 0693dd31..a3e9bbcb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,6 +7,8 @@ trimesh
warp-lang
numpy-stl
pydantic
+nvtx
pytest
ruff
-usd-core
\ No newline at end of file
+usd-core
+h5py
\ No newline at end of file
diff --git a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py
index 07e68cf4..711c34bf 100644
--- a/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py
+++ b/tests/boundary_conditions/bc_equilibrium/test_bc_equilibrium_warp.py
@@ -30,7 +30,7 @@ def test_bc_equilibrium_warp(dim, velocity_set, grid_shape):
my_grid = grid_factory(grid_shape)
velocity_set = DefaultConfig.velocity_set
- missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL)
+ missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.UINT8)
bc_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8)
diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py
index 3cc15cb3..4f5d6757 100644
--- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py
+++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py
@@ -33,7 +33,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape):
my_grid = grid_factory(grid_shape)
velocity_set = DefaultConfig.velocity_set
- missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL)
+ missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.UINT8)
bc_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8)
diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py
index 4ec0639e..cb012cee 100644
--- a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py
+++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py
@@ -32,7 +32,7 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape):
my_grid = grid_factory(grid_shape)
velocity_set = DefaultConfig.velocity_set
- missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL)
+ missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.UINT8)
bc_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8)
@@ -61,7 +61,7 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape):
bc_mask,
missing_mask,
)
- assert missing_mask.dtype == xlb.Precision.BOOL.wp_dtype
+ assert missing_mask.dtype == xlb.Precision.UINT8.wp_dtype
assert bc_mask.dtype == xlb.Precision.UINT8.wp_dtype
diff --git a/tests/kernels/collision/test_bgk_collision_jax.py b/tests/kernels/collision/test_bgk_collision_jax.py
index f3f4308f..3d6ea901 100644
--- a/tests/kernels/collision/test_bgk_collision_jax.py
+++ b/tests/kernels/collision/test_bgk_collision_jax.py
@@ -28,7 +28,7 @@ def init_xlb_env(velocity_set):
(3, xlb.velocity_set.D3Q27, (50, 50, 50), 1.0),
],
)
-def test_bgk_ollision(dim, velocity_set, grid_shape, omega):
+def test_bgk_collision(dim, velocity_set, grid_shape, omega):
init_xlb_env(velocity_set)
my_grid = grid_factory(grid_shape)
@@ -45,7 +45,7 @@ def test_bgk_ollision(dim, velocity_set, grid_shape, omega):
f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q)
- f_out = compute_collision(f_orig, f_eq, rho, u, omega)
+ f_out = compute_collision(f_orig, f_eq, omega)
assert jnp.allclose(f_out, f_orig - omega * (f_orig - f_eq))
diff --git a/tests/kernels/collision/test_bgk_collision_warp.py b/tests/kernels/collision/test_bgk_collision_warp.py
index aa51ea1d..fa6884b2 100644
--- a/tests/kernels/collision/test_bgk_collision_warp.py
+++ b/tests/kernels/collision/test_bgk_collision_warp.py
@@ -44,7 +44,7 @@ def test_bgk_collision_warp(dim, velocity_set, grid_shape, omega):
f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q)
f_out = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q)
- f_out = compute_collision(f_orig, f_eq, f_out, rho, u, omega)
+ f_out = compute_collision(f_orig, f_eq, f_out, omega)
f_eq = f_eq.numpy()
f_out = f_out.numpy()
diff --git a/xlb/__init__.py b/xlb/__init__.py
index d03a368a..02b7a994 100644
--- a/xlb/__init__.py
+++ b/xlb/__init__.py
@@ -9,6 +9,7 @@
from xlb.compute_backend import ComputeBackend as ComputeBackend
from xlb.precision_policy import PrecisionPolicy as PrecisionPolicy, Precision as Precision
from xlb.physics_type import PhysicsType as PhysicsType
+from xlb.mres_perf_optimization_type import MresPerfOptimizationType as MresPerfOptimizationType
# Config
from .default_config import init as init, DefaultConfig as DefaultConfig
diff --git a/xlb/cell_type.py b/xlb/cell_type.py
new file mode 100644
index 00000000..ea082e10
--- /dev/null
+++ b/xlb/cell_type.py
@@ -0,0 +1,11 @@
+# Boundary-mask constants for the bc_mask field.
+# Each voxel in the domain carries a uint8 tag in bc_mask that encodes its role:
+# BC_NONE — regular fluid voxel (no boundary condition)
+# BC_SFV — Simple Fluid Voxel: fluid cell not involved in any BC,
+# explosion, or coalescence (used for fast-path kernels)
+# BC_SOLID — solid / obstacle voxel (skipped by all LBM operators)
+# Registered boundary conditions receive IDs in the range [1, 253].
+
+BC_NONE = 0
+BC_SFV = 254
+BC_SOLID = 255
diff --git a/xlb/compute_backend.py b/xlb/compute_backend.py
index 60da2912..d53ff8a4 100644
--- a/xlb/compute_backend.py
+++ b/xlb/compute_backend.py
@@ -1,8 +1,18 @@
-# Enum used to keep track of the compute backends
+"""
+Compute-backend enumeration for XLB.
+"""
from enum import Enum, auto
class ComputeBackend(Enum):
+ """Available compute backends.
+
+ ``JAX`` — single-res, multi-GPU/TPU via JAX.
+ ``WARP`` — single-res, single-GPU CUDA via NVIDIA Warp.
+ ``NEON`` — single-res and multi-res, single-GPU CUDA via Neon (uses Warp kernels internally).
+ """
+
JAX = auto()
WARP = auto()
+ NEON = auto()
diff --git a/xlb/default_config.py b/xlb/default_config.py
index 3823353c..078128b8 100644
--- a/xlb/default_config.py
+++ b/xlb/default_config.py
@@ -1,4 +1,11 @@
-import jax
+"""
+Global configuration for XLB.
+
+Call :func:`init` once at the start of every script to select the velocity
+set, compute backend, and precision policy. All operators read their
+defaults from :class:`DefaultConfig` when explicit arguments are omitted.
+"""
+
from xlb.compute_backend import ComputeBackend
from dataclasses import dataclass
from xlb.precision_policy import PrecisionPolicy
@@ -6,12 +13,40 @@
@dataclass
class DefaultConfig:
+ """Singleton holding the active global configuration.
+
+ Attributes are set by :func:`init` and read by operators, grids, and
+ helpers throughout XLB.
+
+ Attributes
+ ----------
+ default_precision_policy : PrecisionPolicy or None
+ Active precision policy (compute / store dtype pair).
+ velocity_set : VelocitySet or None
+ Active lattice velocity set.
+ default_backend : ComputeBackend or None
+ Active compute backend.
+ """
+
default_precision_policy = None
velocity_set = None
default_backend = None
def init(velocity_set, default_backend, default_precision_policy):
+ """Initialize the global XLB configuration.
+
+ Must be called before creating any grid, operator, or field.
+
+ Parameters
+ ----------
+ velocity_set : VelocitySet
+ Lattice velocity set (e.g. ``D3Q19``).
+ default_backend : ComputeBackend
+ Compute backend to use (JAX, WARP, or NEON).
+ default_precision_policy : PrecisionPolicy
+ Precision policy for compute and storage dtypes.
+ """
DefaultConfig.velocity_set = velocity_set
DefaultConfig.default_backend = default_backend
DefaultConfig.default_precision_policy = default_precision_policy
@@ -20,6 +55,23 @@ def init(velocity_set, default_backend, default_precision_policy):
import warp as wp
wp.init() # TODO: Must be removed in the future versions of WARP
+ elif default_backend == ComputeBackend.NEON:
+ import warp as wp
+ import neon
+
+ # wp.config.mode = "release"
+ # wp.config.llvm_cuda = False
+ # wp.config.verbose = True
+ # wp.verbose_warnings = True
+
+ wp.init()
+
+ # It's a good idea to always clear the kernel cache when developing new native or codegen features
+ wp.build.clear_kernel_cache()
+
+ # !!! DO THIS BEFORE DEFINING/USING ANY KERNELS WITH CUSTOM TYPES
+ neon.init()
+
elif default_backend == ComputeBackend.JAX:
check_backend_support()
else:
@@ -27,10 +79,14 @@ def init(velocity_set, default_backend, default_precision_policy):
def default_backend() -> ComputeBackend:
+ """Return the currently configured compute backend."""
return DefaultConfig.default_backend
def check_backend_support():
+ """Print a summary of available JAX hardware accelerators."""
+ import jax
+
if jax.devices()[0].platform == "gpu":
gpus = jax.devices("gpu")
if len(gpus) > 1:
diff --git a/xlb/grid/__init__.py b/xlb/grid/__init__.py
index f242c387..c4a80a8b 100644
--- a/xlb/grid/__init__.py
+++ b/xlb/grid/__init__.py
@@ -1,4 +1,5 @@
from xlb.grid.grid import grid_factory as grid_factory
+from xlb.grid.grid import multires_grid_factory as multires_grid_factory
from xlb.grid.warp_grid import WarpGrid
from xlb.grid.jax_grid import JaxGrid
diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py
index 53139fc1..7ae9236f 100644
--- a/xlb/grid/grid.py
+++ b/xlb/grid/grid.py
@@ -1,17 +1,55 @@
+"""
+Grid abstraction and factory functions for XLB.
+
+Defines the :class:`Grid` abstract base class that every backend-specific
+grid must implement, plus two factory helpers:
+
+* :func:`grid_factory` — creates a single-resolution grid for any backend.
+* :func:`multires_grid_factory` — creates a multi-resolution grid (Neon only).
+"""
+
from abc import ABC, abstractmethod
-from typing import Tuple
+from typing import Tuple, List
import numpy as np
from xlb import DefaultConfig
from xlb.compute_backend import ComputeBackend
-def grid_factory(shape: Tuple[int, ...], compute_backend: ComputeBackend = None):
+def grid_factory(
+ shape: Tuple[int, ...],
+ compute_backend: ComputeBackend = None,
+ velocity_set=None,
+ backend_config=None,
+):
+ """Create a single-resolution grid for the specified backend.
+
+ Parameters
+ ----------
+ shape : tuple of int
+ Domain dimensions, e.g. ``(nx, ny, nz)``.
+ compute_backend : ComputeBackend, optional
+ Backend to use. Defaults to ``DefaultConfig.default_backend``.
+ velocity_set : VelocitySet, optional
+ Required for the Neon backend.
+ backend_config : dict, optional
+ Backend-specific configuration (Neon only).
+
+ Returns
+ -------
+ Grid
+ A backend-specific grid instance.
+ """
compute_backend = compute_backend or DefaultConfig.default_backend
+ velocity_set = velocity_set or DefaultConfig.velocity_set
if compute_backend == ComputeBackend.WARP:
from xlb.grid.warp_grid import WarpGrid
return WarpGrid(shape)
+ elif compute_backend == ComputeBackend.NEON:
+ from xlb.grid.neon_grid import NeonGrid
+
+ return NeonGrid(shape=shape, velocity_set=velocity_set, backend_config=backend_config)
elif compute_backend == ComputeBackend.JAX:
from xlb.grid.jax_grid import JaxGrid
@@ -20,8 +58,67 @@ def grid_factory(shape: Tuple[int, ...], compute_backend: ComputeBackend = None)
raise ValueError(f"Compute backend {compute_backend} is not supported")
+def multires_grid_factory(
+ shape: Tuple[int, ...],
+ compute_backend: ComputeBackend = None,
+ velocity_set=None,
+ sparsity_pattern_list: List[np.ndarray] = [],
+ sparsity_pattern_origins=[],
+):
+ import neon
+
+ """Create a multi-resolution grid (Neon backend only).
+
+ Parameters
+ ----------
+ shape : tuple of int
+ Bounding-box dimensions at the finest level.
+ compute_backend : ComputeBackend, optional
+ Must be ``ComputeBackend.NEON``.
+ velocity_set : VelocitySet, optional
+ Lattice velocity set.
+ sparsity_pattern_list : list of np.ndarray
+ Active-voxel masks, one per level (finest first).
+ sparsity_pattern_origins : list of neon.Index_3d
+ Origin of each level's pattern in finest-level coordinates.
+
+ Returns
+ -------
+ NeonMultiresGrid
+ A multi-resolution Neon grid.
+ """
+ compute_backend = compute_backend or DefaultConfig.default_backend
+ velocity_set = velocity_set or DefaultConfig.velocity_set
+ if compute_backend == ComputeBackend.NEON:
+ from xlb.grid.multires_grid import NeonMultiresGrid
+
+ return NeonMultiresGrid(
+ shape=shape, velocity_set=velocity_set, sparsity_pattern_list=sparsity_pattern_list, sparsity_pattern_origins=sparsity_pattern_origins
+ )
+ else:
+ raise ValueError(f"Compute backend {compute_backend} is not supported for multires grid")
+
+
class Grid(ABC):
- def __init__(self, shape: Tuple[int, ...], compute_backend: ComputeBackend):
+ """Abstract base class for all XLB computational grids.
+
+ Subclasses must implement :meth:`_initialize_backend` to set up the
+ backend-specific data structures and :meth:`create_field` (not
+ enforced by ABC but expected by all operators).
+
+ Parameters
+ ----------
+ shape : tuple of int
+ Domain dimensions.
+ compute_backend : ComputeBackend
+ The compute backend this grid is associated with.
+ """
+
+ def __init__(
+ self,
+ shape: Tuple[int, ...],
+ compute_backend: ComputeBackend,
+ ):
self.shape = shape
self.dim = len(shape)
self.compute_backend = compute_backend
@@ -31,7 +128,11 @@ def __init__(self, shape: Tuple[int, ...], compute_backend: ComputeBackend):
def _initialize_backend(self):
pass
- def bounding_box_indices(self, remove_edges=False):
+ def get_compute_backend(self):
+ """Return the compute backend associated with this grid."""
+ return self.compute_backend
+
+ def bounding_box_indices(self, shape=None, remove_edges=False):
"""
This function calculates the indices of the bounding box of a 2D or 3D grid.
The bounding box is defined as the set of grid points on the outer edge of the grid.
@@ -49,9 +150,13 @@ def bounding_box_indices(self, remove_edges=False):
are numpy arrays of indices corresponding to each face.
"""
+ # If shape is not give, use self.shape
+ if shape is None:
+ shape = self.shape
+
# Get the shape of the grid
origin = np.array([0, 0, 0])
- bounds = np.array(self.shape)
+ bounds = np.array(shape)
if remove_edges:
origin += 1
bounds -= 1
@@ -60,11 +165,11 @@ def bounding_box_indices(self, remove_edges=False):
dim = len(bounds)
# Generate bounding box indices for each face
- grid = np.indices(self.shape)
+ grid = np.indices(shape)
boundingBoxIndices = {}
if dim == 2:
- nx, ny = self.shape
+ nx, ny = shape
boundingBoxIndices = {
"bottom": grid[:, slice_x, 0],
"top": grid[:, slice_x, ny - 1],
@@ -72,7 +177,7 @@ def bounding_box_indices(self, remove_edges=False):
"right": grid[:, nx - 1, slice_y],
}
elif dim == 3:
- nx, ny, nz = self.shape
+ nx, ny, nz = shape
slice_z = slice(origin[2], bounds[2])
boundingBoxIndices = {
"bottom": grid[:, slice_x, slice_y, 0].reshape(3, -1),
diff --git a/xlb/grid/multires_grid.py b/xlb/grid/multires_grid.py
new file mode 100644
index 00000000..582dd27e
--- /dev/null
+++ b/xlb/grid/multires_grid.py
@@ -0,0 +1,225 @@
+"""
+Multi-resolution sparse grid backed by the Neon ``mGrid`` runtime.
+
+This module wraps ``neon.multires.mGrid`` and exposes it through the
+:class:`Grid` interface. The grid is hierarchical: level 0 is the finest
+and level *N-1* is the coarsest. Each coarser level has half the
+resolution of the level below it (refinement factor 2).
+"""
+
+import numpy as np
+import warp as wp
+import neon
+from .grid import Grid
+from xlb.precision_policy import Precision
+from xlb.compute_backend import ComputeBackend
+from typing import Literal, List
+from xlb import DefaultConfig
+
+
+class NeonMultiresGrid(Grid):
+ """Hierarchical multi-resolution grid on the Neon backend.
+
+ Wraps ``neon.multires.mGrid``. Each level is described by a boolean
+ sparsity pattern (active-voxel mask) and an integer origin that
+ places it within the finest-level coordinate system.
+
+ Parameters
+ ----------
+ shape : tuple of int
+ Bounding-box dimensions at the **finest** level ``(nx, ny, nz)``.
+ velocity_set : VelocitySet
+ Lattice velocity set defining neighbour connectivity.
+ sparsity_pattern_list : list of np.ndarray
+ One boolean/int array per level indicating which voxels are active.
+ Index 0 = finest level, index *N-1* = coarsest.
+ sparsity_pattern_origins : list of neon.Index_3d
+ Origin offset for each level's pattern in the finest-level
+ coordinate system.
+ """
+
+ def __init__(
+ self,
+ shape,
+ velocity_set,
+ sparsity_pattern_list: List[np.ndarray],
+ sparsity_pattern_origins: List[neon.Index_3d],
+ ):
+ self.bk = None
+ self.dim = None
+ self.grid = None
+ self.velocity_set = velocity_set
+ self.sparsity_pattern_list = sparsity_pattern_list
+ self.sparsity_pattern_origins = sparsity_pattern_origins
+ self.count_levels = len(sparsity_pattern_list)
+ self.refinement_factor = 2
+
+ super().__init__(shape, ComputeBackend.NEON)
+
+ def _get_velocity_set(self):
+ return self.velocity_set
+
+ def _initialize_backend(self):
+ # FIXME@max: for now we hardcode the number of devices to 0
+ num_devs = 1
+ dev_idx_list = list(range(num_devs))
+
+ if len(self.shape) == 2:
+ import py_neon
+
+ self.dim = py_neon.Index_3d(self.shape[0], 1, self.shape[1])
+ self.neon_stencil = []
+ for q in range(self.velocity_set.q):
+ xval, yval = self.velocity_set._c[:, q]
+ self.neon_stencil.append([xval, 0, yval])
+
+ else:
+ self.dim = neon.Index_3d(self.shape[0], self.shape[1], self.shape[2])
+
+ self.neon_stencil = []
+ for q in range(self.velocity_set.q):
+ xval, yval, zval = self.velocity_set._c[:, q]
+ self.neon_stencil.append([xval, yval, zval])
+
+ self.bk = neon.Backend(runtime=neon.Backend.Runtime.stream, dev_idx_list=dev_idx_list)
+
+ self.grid = neon.multires.mGrid(
+ backend=self.bk,
+ dim=self.dim,
+ sparsity_pattern_list=self.sparsity_pattern_list,
+ sparsity_pattern_origins=self.sparsity_pattern_origins,
+ stencil=self.neon_stencil,
+ )
+ # Print grid stats about voxel distribution between levels.
+ self.grid.print_info()
+ pass
+
+ def create_field(
+ self,
+ cardinality: int,
+ dtype: Literal[Precision.FP32, Precision.FP64, Precision.FP16] = None,
+ fill_value=None,
+ neon_memory_type: neon.MemoryType = neon.MemoryType.host_device(),
+ ):
+ """Allocate a new multi-resolution Neon field.
+
+ The field spans all grid levels. Each level is either zero-filled
+ or filled with *fill_value*.
+
+ Parameters
+ ----------
+ cardinality : int
+ Number of components per voxel.
+ dtype : Precision, optional
+ Element precision. Defaults to the store precision from the
+ global config.
+ fill_value : float, optional
+ Value to fill every element with. ``None`` means zero.
+ neon_memory_type : neon.MemoryType
+ Memory residency (host, device, or both).
+
+ Returns
+ -------
+ neon.multires.mField
+ The newly allocated multi-resolution field.
+ """
+ dtype = dtype.wp_dtype if dtype else DefaultConfig.default_precision_policy.store_precision.wp_dtype
+ field = self.grid.new_field(
+ cardinality=cardinality,
+ dtype=dtype,
+ memory_type=neon_memory_type,
+ )
+ for l in range(self.count_levels):
+ if fill_value is None:
+ field.zero_run(l, stream_idx=0)
+ else:
+ field.fill_run(level=l, value=fill_value, stream_idx=0)
+ return field
+
+ def get_neon_backend(self):
+ """Return the underlying ``neon.Backend`` instance."""
+ return self.bk
+
+ def level_to_shape(self, level):
+ """Return the bounding-box shape at the given grid level.
+
+ Level 0 is the finest and has shape ``self.shape``. Each subsequent
+ level halves each dimension.
+ """
+ # level = 0 corresponds to the finest level
+ return tuple(x // self.refinement_factor**level for x in self.shape)
+
+ def boundary_indices_across_levels(self, level_data, box_side: str = "front", remove_edges: bool = False):
+ """
+ Get indices for creating a boundary condition on the specified box side that crosses multiples levels of a multiresolution grid.
+ The indices are returned as a list of lists, where each sublist corresponds to a level
+
+ Parameters
+ ----------
+ - level_data: Level data containing the origins and sparsity patterns for each level as prepared by mesher/make_cuboid_mesh function!
+ - box_side: The side of the bounding box to get indices for (default is "front").
+ returns:
+ - A list of lists, where each sublist contains the indices for the boundary condition at that level.
+ """
+ num_levels = len(level_data)
+ bc_indices_list = []
+ d = self.velocity_set.d # Dimensionality (2 or 3)
+
+ # Define side configurations (adjust if your conventions differ)
+ if d == 3:
+ side_config = {
+ "left": {"dim": 0, "value": 0},
+ "right": {"dim": 0, "value": lambda s: s[0] - 1},
+ "front": {"dim": 1, "value": 0},
+ "back": {"dim": 1, "value": lambda s: s[1] - 1},
+ "bottom": {"dim": 2, "value": 0},
+ "top": {"dim": 2, "value": lambda s: s[2] - 1},
+ }
+ elif d == 2:
+ side_config = {
+ "left": {"dim": 0, "value": 0},
+ "right": {"dim": 0, "value": lambda s: s[0] - 1},
+ "bottom": {"dim": 1, "value": 0},
+ "top": {"dim": 1, "value": lambda s: s[1] - 1},
+ }
+ else:
+ raise ValueError(f"Unsupported dimensionality: {d}")
+
+ if box_side not in side_config:
+ raise ValueError(f"Unsupported box_side: {box_side}")
+
+ for level in range(num_levels):
+ mask = level_data[level][0]
+ origin = level_data[level][2] # Assume np.array of shape (d,)
+ grid_shape = self.level_to_shape(level) # tuple of length d
+
+ conf = side_config[box_side]
+ dim_idx = conf["dim"]
+ grid_bounds = conf["value"](grid_shape) if callable(conf["value"]) else conf["value"]
+
+ # Get local indices of active voxels
+ local_coords = np.nonzero(mask) # Tuple of d arrays, each of length num_active
+ if not local_coords[0].size:
+ bc_indices_list.append([])
+ continue
+
+ # Compute global coords (list of d arrays)
+ global_coords = [local_coords[i] + origin[i] for i in range(d)]
+
+ # Filter: must match grid_bounds along the dimension associated with the selected box_side
+ cond = global_coords[dim_idx] == grid_bounds
+
+ # If remove_edges, exclude perimeter of the face
+ if remove_edges:
+ for i in range(d):
+ if i != dim_idx:
+ cond &= (global_coords[i] > 0) & (global_coords[i] < grid_shape[i] - 1)
+
+ # Collect filtered indices
+ if np.any(cond):
+ active_bc = [gc[cond] for gc in global_coords]
+ bc_indices_list.append([arr.tolist() for arr in active_bc])
+ else:
+ bc_indices_list.append([])
+
+ return bc_indices_list
diff --git a/xlb/grid/neon_grid.py b/xlb/grid/neon_grid.py
new file mode 100644
index 00000000..e92eb7ff
--- /dev/null
+++ b/xlb/grid/neon_grid.py
@@ -0,0 +1,137 @@
+"""
+Single-resolution dense grid backed by the Neon multi-GPU runtime.
+
+This module wraps ``neon.dense.dGrid`` and exposes it through the
+:class:`Grid` interface so that XLB operators can allocate and operate on
+fields transparently.
+"""
+
+import neon
+from .grid import Grid
+from xlb.precision_policy import Precision
+from xlb.compute_backend import ComputeBackend
+from typing import Literal
+from xlb import DefaultConfig
+
+
+class NeonGrid(Grid):
+ """Dense single-resolution grid on the Neon backend.
+
+ Wraps a ``neon.dense.dGrid``. The grid is initialized with the LBM
+ stencil derived from the provided *velocity_set* so that Neon can
+ set up the correct halo exchanges for neighbour communication.
+
+ Parameters
+ ----------
+ shape : tuple of int
+ Bounding-box dimensions of the domain ``(nx, ny, nz)`` (or
+ ``(nx, ny)`` for 2-D).
+ velocity_set : VelocitySet
+ Lattice velocity set whose stencil defines neighbour connectivity.
+ backend_config : dict, optional
+ Neon backend configuration. Must contain ``"device_list"`` (list
+ of GPU device indices). Defaults to ``{"device_list": [0]}``.
+ """
+
+ def __init__(
+ self,
+ shape,
+ velocity_set,
+ backend_config=None,
+ ):
+ from .warp_grid import WarpGrid
+
+ if backend_config is None:
+ backend_config = {
+ "device_list": [0],
+ "skeleton_config": neon.SkeletonConfig.OCC.none(),
+ }
+
+ # check that the config dictionary has the required keys
+ required_keys = ["device_list"]
+ for key in required_keys:
+ if key not in backend_config:
+ raise ValueError(f"backend_config must contain a '{key}' key")
+
+ # check that the device list is a list of integers
+ if not isinstance(backend_config["device_list"], list):
+ raise ValueError("backend_config['device_list'] must be a list of integers")
+ for device in backend_config["device_list"]:
+ if not isinstance(device, int):
+ raise ValueError("backend_config['device_list'] must be a list of integers")
+
+ self.config = backend_config
+ self.bk = None
+ self.dim = None
+ self.grid = None
+ self.velocity_set = velocity_set
+
+ super().__init__(shape, ComputeBackend.NEON)
+
+ def _get_velocity_set(self):
+ return self.velocity_set
+
+ def _initialize_backend(self):
+ dev_idx_list = self.config["device_list"]
+
+ if len(self.shape) == 2:
+ import py_neon
+
+ self.dim = py_neon.Index_3d(self.shape[0], 1, self.shape[1])
+ self.neon_stencil = []
+ for q in range(self.velocity_set.q):
+ xval, yval = self.velocity_set._c[:, q]
+ self.neon_stencil.append([xval, 0, yval])
+
+ else:
+ self.dim = neon.Index_3d(self.shape[0], self.shape[1], self.shape[2])
+
+ self.neon_stencil = []
+ for q in range(self.velocity_set.q):
+ xval, yval, zval = self.velocity_set._c[:, q]
+ self.neon_stencil.append([xval, yval, zval])
+
+ self.bk = neon.Backend(runtime=neon.Backend.Runtime.stream, dev_idx_list=dev_idx_list)
+ self.bk.info_print()
+ self.grid = neon.dense.dGrid(backend=self.bk, dim=self.dim, sparsity=None, stencil=self.neon_stencil)
+ pass
+
+ def create_field(
+ self,
+ cardinality: int,
+ dtype: Literal[Precision.FP32, Precision.FP64, Precision.FP16] = None,
+ fill_value=None,
+ ):
+ """Allocate a new Neon field on this grid.
+
+ Parameters
+ ----------
+ cardinality : int
+ Number of components per voxel (e.g. ``q`` for populations).
+ dtype : Precision, optional
+ Element precision. Defaults to the store precision from the
+ global config.
+ fill_value : float, optional
+ If provided every element is set to this value; otherwise the
+ field is zero-initialized.
+
+ Returns
+ -------
+ neon.dense.dField
+ The newly allocated field.
+ """
+ dtype = dtype.wp_dtype if dtype else DefaultConfig.default_precision_policy.store_precision.wp_dtype
+ field = self.grid.new_field(
+ cardinality=cardinality,
+ dtype=dtype,
+ )
+
+ if fill_value is None:
+ field.zero_run(stream_idx=0)
+ else:
+ field.fill_run(value=fill_value, stream_idx=0)
+ return field
+
+ def get_neon_backend(self):
+ """Return the underlying ``neon.Backend`` instance."""
+ return self.bk
diff --git a/xlb/helper/__init__.py b/xlb/helper/__init__.py
index aa6dc961..452cfbce 100644
--- a/xlb/helper/__init__.py
+++ b/xlb/helper/__init__.py
@@ -1,6 +1,7 @@
-from xlb.helper.nse_solver import create_nse_fields
-from xlb.helper.initializers import initialize_eq
+from xlb.helper.nse_fields import create_nse_fields
+from xlb.helper.initializers import initialize_eq, initialize_multires_eq, CustomInitializer, CustomMultiresInitializer
from xlb.helper.check_boundary_overlaps import check_bc_overlaps
+from xlb.helper.simulation_manager import MultiresSimulationManager
from xlb.helper.ibm_helper import (
reconstruct_mesh_from_vertices_and_faces,
transform_mesh,
diff --git a/xlb/helper/initializers.py b/xlb/helper/initializers.py
index 487d2cfa..bda54e11 100644
--- a/xlb/helper/initializers.py
+++ b/xlb/helper/initializers.py
@@ -1,8 +1,57 @@
+"""
+Initializers for distribution function fields.
+
+Provides helper functions and Operator subclasses that populate
+distribution-function fields with equilibrium values. Two usage patterns
+are supported:
+
+* **Functional helpers** (`initialize_eq`, `initialize_multires_eq`) —
+ one-shot initialization used during simulation setup.
+* **Operator classes** (`CustomInitializer`, `CustomMultiresInitializer`) —
+ reusable operators that can target the whole domain or a single boundary
+ condition region, with support for JAX, Warp, and Neon backends.
+"""
+
+import warp as wp
+from typing import Any
+from xlb import DefaultConfig
+from xlb.operator import Operator
+from xlb.velocity_set import VelocitySet
from xlb.compute_backend import ComputeBackend
from xlb.operator.equilibrium import QuadraticEquilibrium
+from xlb.operator.equilibrium import MultiresQuadraticEquilibrium
def initialize_eq(f, grid, velocity_set, precision_policy, compute_backend, rho=None, u=None):
+ """Initialize a distribution-function field to equilibrium.
+
+ Computes the quadratic equilibrium for the given density and velocity
+ fields and writes it into *f*. When *rho* or *u* are ``None`` the
+ defaults are uniform density 1 and zero velocity.
+
+ Parameters
+ ----------
+ f : field
+ Distribution-function field to populate (modified in-place for
+ Warp / Neon backends; replaced for JAX).
+ grid : Grid
+ Computational grid used to allocate temporary fields.
+ velocity_set : VelocitySet
+ Lattice velocity set (e.g. D3Q19).
+ precision_policy : PrecisionPolicy
+ Precision policy for compute / store dtypes.
+ compute_backend : ComputeBackend
+ Active compute backend (JAX, WARP, or NEON).
+ rho : field, optional
+ Density field. Defaults to uniform 1.0.
+ u : field, optional
+ Velocity field. Defaults to uniform 0.0.
+
+ Returns
+ -------
+ field
+ The initialized distribution-function field.
+ """
if rho is None:
rho = grid.create_field(cardinality=1, fill_value=1.0, dtype=precision_policy.compute_precision)
if u is None:
@@ -11,10 +60,250 @@ def initialize_eq(f, grid, velocity_set, precision_policy, compute_backend, rho=
if compute_backend == ComputeBackend.JAX:
f = equilibrium(rho, u)
-
elif compute_backend == ComputeBackend.WARP:
f = equilibrium(rho, u, f)
+ elif compute_backend == ComputeBackend.NEON:
+ f = equilibrium(rho, u, f)
+ else:
+ raise NotImplementedError(f"Backend {compute_backend} not implemented")
del rho, u
return f
+
+
+def initialize_multires_eq(f, grid, velocity_set, precision_policy, backend, rho, u):
+ """Initialize a multi-resolution distribution-function field to equilibrium.
+
+ Parameters
+ ----------
+ f : field
+ Multi-resolution distribution-function field to populate.
+ grid : NeonMultiresGrid
+ Multi-resolution grid.
+ velocity_set : VelocitySet
+ Lattice velocity set.
+ precision_policy : PrecisionPolicy
+ Precision policy.
+ backend : ComputeBackend
+ Compute backend (expected to be NEON).
+ rho : field
+ Density field across all grid levels.
+ u : field
+ Velocity field across all grid levels.
+
+ Returns
+ -------
+ field
+ The initialized multi-resolution distribution-function field.
+ """
+ equilibrium = MultiresQuadraticEquilibrium()
+ return equilibrium(rho, u, f, stream=0)
+
+
+class CustomInitializer(Operator):
+ """Operator that initializes distribution functions to equilibrium.
+
+ When ``bc_id == -1`` (default) the entire domain is initialized with the
+ given constant velocity and density. Otherwise only voxels whose
+ ``bc_mask`` matches *bc_id* are set while the rest receive the
+ weight-only equilibrium (zero velocity, unit density).
+
+ Supports JAX, Warp, and Neon backends.
+
+ Parameters
+ ----------
+ constant_velocity_vector : list of float
+ Macroscopic velocity [ux, uy, uz] used for initialization.
+ constant_density : float
+ Macroscopic density used for initialization.
+ bc_id : int
+ Boundary-condition ID to target. ``-1`` means the whole domain.
+ initialization_operator : Operator, optional
+ Equilibrium operator to use. Defaults to ``QuadraticEquilibrium``.
+ velocity_set : VelocitySet, optional
+ precision_policy : PrecisionPolicy, optional
+ compute_backend : ComputeBackend, optional
+ """
+
+ def __init__(
+ self,
+ constant_velocity_vector=[0.0, 0.0, 0.0],
+ constant_density: float = 1.0,
+ bc_id: int = -1,
+ initialization_operator=None,
+ velocity_set: VelocitySet = None,
+ precision_policy=None,
+ compute_backend=None,
+ ):
+ self.bc_id = bc_id
+ self.constant_velocity_vector = constant_velocity_vector
+ self.constant_density = constant_density
+ if initialization_operator is None:
+ compute_backend = compute_backend or DefaultConfig.default_backend
+ self.initialization_operator = QuadraticEquilibrium(
+ velocity_set=velocity_set or DefaultConfig.velocity_set,
+ precision_policy=precision_policy or DefaultConfig.precision_policy,
+ compute_backend=compute_backend if compute_backend == ComputeBackend.JAX else ComputeBackend.WARP,
+ )
+ super().__init__(velocity_set, precision_policy, compute_backend)
+
+ @Operator.register_backend(ComputeBackend.JAX)
+ def jax_implementation(self, bc_mask, f_field):
+ from xlb.grid import grid_factory
+ import jax.numpy as jnp
+
+ grid_shape = f_field.shape[1:]
+ grid = grid_factory(grid_shape)
+ rho_init = grid.create_field(cardinality=1, fill_value=self.constant_density, dtype=self.precision_policy.compute_precision)
+ u_init = grid.create_field(cardinality=self.velocity_set.d, fill_value=0.0, dtype=self.precision_policy.compute_precision)
+ _vel = jnp.array(self.constant_velocity_vector)[(...,) + (None,) * self.velocity_set.d]
+ if self.bc_id == -1:
+ u_init += _vel
+ else:
+ u_init = jnp.where(bc_mask[0] == self.bc_id, u_init + _vel, u_init)
+ return self.initialization_operator(rho_init, u_init)
+
+ def _construct_warp(self):
+ _q = self.velocity_set.q
+ _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
+ _u = _u_vec(self.constant_velocity_vector[0], self.constant_velocity_vector[1], self.constant_velocity_vector[2])
+ _rho = self.compute_dtype(self.constant_density)
+ _w = self.velocity_set.w
+ bc_id = self.bc_id
+
+ @wp.func
+ def functional_local(index: Any, bc_mask: Any, f_field: Any):
+ # Check if the index corresponds to the outlet
+ if self.read_field(bc_mask, index, 0) == bc_id:
+ _f_init = self.initialization_operator.warp_functional(_rho, _u)
+ for l in range(_q):
+ self.write_field(f_field, index, l, self.store_dtype(_f_init[l]))
+ else:
+ # In the rest of the domain, we assume zero velocity and equilibrium distribution.
+ for l in range(_q):
+ self.write_field(f_field, index, l, self.store_dtype(_w[l]))
+
+ @wp.func
+ def functional_domain(index: Any, bc_mask: Any, f_field: Any):
+ # If bc_id is -1, initialize the entire domain according to the custom initialization operator for the given velocity
+ _f_init = self.initialization_operator.warp_functional(_rho, _u)
+ for l in range(_q):
+ self.write_field(f_field, index, l, self.store_dtype(_f_init[l]))
+
+ # Set the functional based on whether we are initializing a specific BC or the entire domain
+ functional = functional_local if self.bc_id != -1 else functional_domain
+
+ # Construct the warp kernel
+ @wp.kernel
+ def kernel(
+ bc_mask: wp.array4d(dtype=wp.uint8),
+ f_field: wp.array4d(dtype=Any),
+ ):
+ # Get the global index
+ i, j, k = wp.tid()
+ index = wp.vec3i(i, j, k)
+
+ # Set the velocity at the outlet (i.e. where i = nx-1)
+ functional(index, bc_mask, f_field)
+
+ return functional, kernel
+
+ @Operator.register_backend(ComputeBackend.WARP)
+ def warp_implementation(self, bc_mask, f_field):
+ # Launch the warp kernel
+ wp.launch(
+ self.warp_kernel,
+ inputs=[bc_mask, f_field],
+ dim=f_field.shape[1:],
+ )
+ return f_field
+
+ def _construct_neon(self):
+ import neon
+
+ # Use the warp functional for the NEON backend
+ functional, _ = self._construct_warp()
+
+ @neon.Container.factory(name="CustomInitializer")
+ def container(
+ bc_mask: Any,
+ f_field: Any,
+ ):
+ def launcher(loader: neon.Loader):
+ loader.set_grid(f_field.get_grid())
+ f_field_pn = loader.get_write_handle(f_field)
+ bc_mask_pn = loader.get_read_handle(bc_mask)
+
+ @wp.func
+ def kernel(index: Any):
+ # apply the functional
+ functional(index, bc_mask_pn, f_field_pn)
+
+ loader.declare_kernel(kernel)
+
+ return launcher
+
+ return _, container
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, bc_mask, f_field, stream=0):
+ # Launch the neon container
+ c = self.neon_container(bc_mask, f_field)
+ c.run(stream, container_runtime=neon.Container.ContainerRuntime.neon)
+ return f_field
+
+
+class CustomMultiresInitializer(CustomInitializer):
+ """Multi-resolution variant of :class:`CustomInitializer`.
+
+ Iterates over all grid levels and initializes distribution functions
+ using the Neon multi-resolution container API.
+ """
+
+ def __init__(
+ self,
+ constant_velocity_vector=[0.0, 0.0, 0.0],
+ constant_density: float = 1.0,
+ bc_id: int = -1,
+ initialization_operator=None,
+ velocity_set: VelocitySet = None,
+ precision_policy=None,
+ compute_backend=None,
+ ):
+ super().__init__(constant_velocity_vector, constant_density, bc_id, initialization_operator, velocity_set, precision_policy, compute_backend)
+
+ def _construct_neon(self):
+ # Use the warp functional for the NEON backend
+ functional, _ = self._construct_warp()
+
+ @neon.Container.factory(name="CustomMultiresInitializer")
+ def container(
+ bc_mask: Any,
+ f_field: Any,
+ level: Any,
+ ):
+ def launcher(loader: neon.Loader):
+ loader.set_mres_grid(f_field.get_grid(), level)
+ f_field_pn = loader.get_mres_write_handle(f_field)
+ bc_mask_pn = loader.get_mres_read_handle(bc_mask)
+
+ @wp.func
+ def kernel(index: Any):
+ # apply the functional
+ functional(index, bc_mask_pn, f_field_pn)
+
+ loader.declare_kernel(kernel)
+
+ return launcher
+
+ return _, container
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, bc_mask, f_field, stream=0):
+ grid = bc_mask.get_grid()
+ for level in range(grid.num_levels):
+ # Launch the neon container
+ c = self.neon_container(bc_mask, f_field, level)
+ c.run(stream, container_runtime=neon.Container.ContainerRuntime.neon)
+ return f_field
diff --git a/xlb/helper/nse_solver.py b/xlb/helper/nse_fields.py
similarity index 70%
rename from xlb/helper/nse_solver.py
rename to xlb/helper/nse_fields.py
index 361dc6e3..81e01006 100644
--- a/xlb/helper/nse_solver.py
+++ b/xlb/helper/nse_fields.py
@@ -1,6 +1,15 @@
+"""
+Factory function for creating the standard Navier-Stokes field arrays.
+
+Returns the distribution-function pair (*f_0*, *f_1*), the boundary-
+condition mask, and the missing-population mask, all allocated on the
+given grid and backend.
+"""
+
from xlb import DefaultConfig
from xlb.grid import grid_factory
from xlb.precision_policy import Precision
+from xlb.compute_backend import ComputeBackend
from typing import Tuple
@@ -30,12 +39,17 @@ def create_nse_fields(
if grid is None:
if grid_shape is None:
raise ValueError("grid_shape must be provided when grid is None")
- grid = grid_factory(grid_shape, compute_backend=compute_backend)
+ grid = grid_factory(grid_shape, compute_backend=compute_backend, velocity_set=velocity_set)
# Create fields
f_0 = grid.create_field(cardinality=velocity_set.q, dtype=precision_policy.store_precision)
f_1 = grid.create_field(cardinality=velocity_set.q, dtype=precision_policy.store_precision)
- missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=Precision.BOOL)
bc_mask = grid.create_field(cardinality=1, dtype=Precision.UINT8)
+ if compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON]:
+ # For WARP and NEON, we use UINT8 for missing mask
+ missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=Precision.UINT8)
+ else:
+ # For JAX, we use bool for missing mask
+ missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=Precision.BOOL)
return grid, f_0, f_1, missing_mask, bc_mask
diff --git a/xlb/helper/simulation_manager.py b/xlb/helper/simulation_manager.py
new file mode 100644
index 00000000..ad60b939
--- /dev/null
+++ b/xlb/helper/simulation_manager.py
@@ -0,0 +1,244 @@
+"""
+High-level simulation manager for multi-resolution LBM on the Neon backend.
+
+:class:`MultiresSimulationManager` orchestrates the complete simulation
+lifecycle: field allocation, boundary-condition setup, coalescence-factor
+precomputation, and the recursive time-stepping skeleton that correctly
+interleaves coarse and fine grid updates.
+"""
+
+import warp as wp
+from xlb.operator.stepper import MultiresIncompressibleNavierStokesStepper
+from xlb.operator.macroscopic import MultiresMacroscopic
+from xlb.mres_perf_optimization_type import MresPerfOptimizationType
+
+
+class MultiresSimulationManager(MultiresIncompressibleNavierStokesStepper):
+ """Orchestrates multi-resolution LBM simulations on the Neon backend.
+
+ Inherits from :class:`MultiresIncompressibleNavierStokesStepper` and
+ adds field management, omega computation across levels, and the
+ recursive skeleton builder that encodes the multi-resolution
+ time-stepping order.
+
+ Parameters
+ ----------
+ omega_finest : float
+ Relaxation parameter at the finest grid level.
+ grid : NeonMultiresGrid
+ Multi-resolution grid.
+ boundary_conditions : list of BoundaryCondition
+ Boundary conditions to apply.
+ collision_type : str
+ ``"BGK"`` or ``"KBC"``.
+ forcing_scheme : str
+ Forcing scheme (used only when *force_vector* is given).
+ force_vector : array-like, optional
+ External body force.
+ initializer : Operator, optional
+ Custom initializer for distribution functions. If ``None``
+ the default equilibrium initialization is used.
+ mres_perf_opt : MresPerfOptimizationType
+ Performance optimization strategy.
+ """
+
+ def __init__(
+ self,
+ omega_finest,
+ grid,
+ boundary_conditions=[],
+ collision_type="BGK",
+ forcing_scheme="exact_difference",
+ force_vector=None,
+ initializer=None,
+ mres_perf_opt: MresPerfOptimizationType = MresPerfOptimizationType.NAIVE_COLLIDE_STREAM,
+ ):
+ super().__init__(grid, boundary_conditions, collision_type, forcing_scheme, force_vector)
+
+ self.initializer = initializer
+ self.count_levels = grid.count_levels
+ self.omega_list = [self.compute_omega(omega_finest, level) for level in range(self.count_levels)]
+ self.mres_perf_opt = mres_perf_opt
+ # Create fields
+ self.rho = grid.create_field(cardinality=1, dtype=self.precision_policy.store_precision)
+ self.u = grid.create_field(cardinality=3, dtype=self.precision_policy.store_precision)
+ self.coalescence_factor = grid.create_field(cardinality=self.velocity_set.q, dtype=self.precision_policy.store_precision)
+
+ for level in range(self.count_levels):
+ self.u.fill_run(level, 0.0, 0)
+ self.rho.fill_run(level, 1.0, 0)
+ self.coalescence_factor.fill_run(level, 0.0, 0)
+
+ # Prepare fields
+ self.f_0, self.f_1, self.bc_mask, self.missing_mask = self.prepare_fields(self.rho, self.u, self.initializer)
+ self.prepare_coalescence_count(coalescence_factor=self.coalescence_factor, bc_mask=self.bc_mask)
+
+ self.iteration_idx = -1
+ self.macro = MultiresMacroscopic(
+ compute_backend=self.compute_backend,
+ precision_policy=self.precision_policy,
+ velocity_set=self.velocity_set,
+ )
+
+ # Construct the stepper skeleton
+ self._construct_stepper_skeleton()
+
+ def compute_omega(self, omega_finest, level):
+ """
+ Compute the relaxation parameter omega at a given grid level based on the finest level omega.
+ We select a refinement ratio of 2 where a coarse cell at level L is uniformly divided into 2^d cells
+ where d is the dimension. to arrive at level L - 1, or in other words ∆x_{L-1} = ∆x_L/2.
+ For neighboring cells that interface two grid levels, a maximum jump in grid level of ∆L = 1 is
+ allowed. Due to acoustic scaling which requires the speed of sound cs to remain constant across various grid levels,
+ ∆tL ∝ ∆xL and hence ∆t_{L-1} = ∆t_{L}/2. In addition, the fluid viscosity \nu must also remain constant on each
+ grid level which leads to the following relationship for the relaxation parameter omega at grid level L base
+ on the finest grid level omega_finest.
+
+ Args:
+ omega_finest: Relaxation parameter at the finest grid level.
+ level: Current grid level (0-indexed, with 0 being the finest level).
+
+ Returns:
+ Relaxation parameter omega at the specified grid level.
+ """
+ omega0 = omega_finest
+ return 2 ** (level + 1) * omega0 / ((2**level - 1.0) * omega0 + 2.0)
+
+ def export_macroscopic(self, fname_prefix):
+ """Compute macroscopic fields and export velocity to a VTI file.
+
+ Parameters
+ ----------
+ fname_prefix : str
+ Output filename prefix. The iteration index is appended
+ automatically (e.g. ``"u_"`` → ``"u_42.vti"``).
+ """
+ print(f"exporting macroscopic: #levels {self.count_levels}")
+ self.macro(self.f_0, self.bc_mask, self.rho, self.u, streamId=0)
+
+ wp.synchronize()
+ self.u.update_host(0)
+ wp.synchronize()
+ self.u.export_vti(f"{fname_prefix}{self.iteration_idx}.vti", "u")
+ print("DONE exporting macroscopic")
+
+ return
+
+ def step(self):
+ """Advance the simulation by one coarsest-level timestep.
+
+ Internally this executes the pre-compiled Neon skeleton which
+ performs the correct number of sub-steps at each finer level
+ according to the acoustic-scaling time refinement ratio.
+ """
+ self.iteration_idx = self.iteration_idx + 1
+ self.sk.run()
+
+ def _build_recursion(self, level, app, config):
+ """Unified multi-resolution recursion builder.
+
+ config keys:
+ finest_ops: list of (op_name, swap_fields, extra_kwargs) for level 0,
+ or None to treat level 0 like any coarse level.
+ coarse_collide_ops: list of op_names for coarse collision.
+ coarse_stream_ops: list of (op_name, extra_kwargs) for coarse streaming.
+ fuse_finest: if True, recurse once (not twice) when child is at level 0.
+ """
+ if level < 0:
+ return
+
+ omega = self.omega_list[level]
+ fields = dict(f_0_fd=self.f_0, f_1_fd=self.f_1, bc_mask_fd=self.bc_mask, missing_mask_fd=self.missing_mask)
+ fields_swapped = dict(f_0_fd=self.f_1, f_1_fd=self.f_0, bc_mask_fd=self.bc_mask, missing_mask_fd=self.missing_mask)
+
+ if level == 0 and config["finest_ops"] is not None:
+ for op_name, swap, extra in config["finest_ops"]:
+ base = fields_swapped if swap else fields
+ self.add_to_app(app=app, op_name=op_name, level=level, **base, omega=omega, **extra)
+ return
+
+ for op_name in config["coarse_collide_ops"]:
+ self.add_to_app(app=app, op_name=op_name, level=level, **fields, omega=omega, timestep=0)
+
+ if config["fuse_finest"] and level - 1 == 0:
+ self._build_recursion(level - 1, app, config)
+ else:
+ self._build_recursion(level - 1, app, config)
+ self._build_recursion(level - 1, app, config)
+
+ for op_name, extra in config["coarse_stream_ops"]:
+ self.add_to_app(app=app, op_name=op_name, level=level, **fields_swapped, **extra)
+
+ def _construct_stepper_skeleton(self):
+ import neon
+
+ """Build the Neon skeleton that encodes the recursive time-stepping order.
+
+ The skeleton is a list of Neon container invocations that, when
+ executed in sequence, perform one coarsest-level timestep with the
+ correct sub-cycling at finer levels. The structure depends on
+ ``self.mres_perf_opt``.
+ """
+ self.app = []
+
+ stream_abc = {"omega": self.coalescence_factor, "timestep": 0}
+
+ # Finest-level op descriptors: (op_name, swap_f0_f1, extra_kwargs)
+ fused_pull_finest = [
+ ("finest_fused_pull", False, {"timestep": 0, "is_f1_the_explosion_src_field": True}),
+ ("finest_fused_pull", True, {"timestep": 0, "is_f1_the_explosion_src_field": False}),
+ ]
+ sfv_fused_pull_finest = [
+ ("CFV_finest_fused_pull", False, {"timestep": 0, "is_f1_the_explosion_src_field": True}),
+ ("SFV_finest_fused_pull", False, {}),
+ ("CFV_finest_fused_pull", True, {"timestep": 0, "is_f1_the_explosion_src_field": False}),
+ ("SFV_finest_fused_pull", True, {}),
+ ]
+
+ configs = {
+ MresPerfOptimizationType.NAIVE_COLLIDE_STREAM: {
+ "finest_ops": None,
+ "coarse_collide_ops": ["collide_coarse"],
+ "coarse_stream_ops": [("stream_coarse_step_ABC", stream_abc)],
+ "fuse_finest": False,
+ },
+ MresPerfOptimizationType.FUSION_AT_FINEST: {
+ "finest_ops": fused_pull_finest,
+ "coarse_collide_ops": ["collide_coarse"],
+ "coarse_stream_ops": [("stream_coarse_step_ABC", stream_abc)],
+ "fuse_finest": True,
+ },
+ MresPerfOptimizationType.FUSION_AT_FINEST_SFV: {
+ "finest_ops": sfv_fused_pull_finest,
+ "coarse_collide_ops": ["collide_coarse"],
+ "coarse_stream_ops": [("stream_coarse_step_ABC", stream_abc)],
+ "fuse_finest": True,
+ },
+ MresPerfOptimizationType.FUSION_AT_FINEST_SFV_ALL: {
+ "finest_ops": sfv_fused_pull_finest,
+ "coarse_collide_ops": ["CFV_collide_coarse", "SFV_collide_coarse"],
+ "coarse_stream_ops": [("SFV_stream_coarse_step_ABC", stream_abc), ("SFV_stream_coarse_step", {})],
+ "fuse_finest": True,
+ },
+ }
+
+ config = configs.get(self.mres_perf_opt)
+ if config is None:
+ raise ValueError(f"Unknown optimization level: {self.mres_perf_opt}")
+
+ # Pre-recursion SFV mask setup
+ if self.mres_perf_opt == MresPerfOptimizationType.FUSION_AT_FINEST_SFV:
+ wp.synchronize()
+ self.neon_container["SFV_reset_bc_mask"](0, self.f_0, self.f_1, self.bc_mask, self.bc_mask).run(0)
+ wp.synchronize()
+ elif self.mres_perf_opt == MresPerfOptimizationType.FUSION_AT_FINEST_SFV_ALL:
+ wp.synchronize()
+ for l in range(self.f_0.get_grid().num_levels):
+ self.neon_container["SFV_reset_bc_mask"](l, self.f_0, self.f_1, self.bc_mask, self.bc_mask).run(0)
+ wp.synchronize()
+
+ self._build_recursion(self.count_levels - 1, self.app, config)
+
+ bk = self.grid.get_neon_backend()
+ self.sk = neon.Skeleton(backend=bk)
+ self.sk.sequence("mres_nse_stepper", self.app)
diff --git a/xlb/mres_perf_optimization_type.py b/xlb/mres_perf_optimization_type.py
new file mode 100644
index 00000000..797699f5
--- /dev/null
+++ b/xlb/mres_perf_optimization_type.py
@@ -0,0 +1,78 @@
+"""
+Multi-resolution performance-optimization strategies.
+
+Defines the kernel-fusion levels available for the multi-resolution LBM
+stepper and provides CLI argument parsing helpers.
+"""
+
+import argparse
+from enum import Enum, auto
+
+
+class MresPerfOptimizationType(Enum):
+ """
+ Enumeration of available optimization strategies for the LBM solver.
+
+ Supports parsing from either the enum member name (case-insensitive)
+ or its integer value, and provides a method to build the CLI parser.
+ """
+
+ NAIVE_COLLIDE_STREAM = auto()
+ FUSION_AT_FINEST = auto()
+ FUSION_AT_FINEST_SFV = auto()
+ FUSION_AT_FINEST_SFV_ALL = auto()
+
+ @staticmethod
+ def from_string(value: str) -> "MresPerfOptimizationType":
+ """
+ Parse a string to an OptimizationType.
+
+ Accepts either the enum member name (case-insensitive) or its integer value.
+
+ Args:
+ value: The enum name (e.g. 'naive_collide_stream') or integer value (e.g. '0').
+
+ Returns:
+ An OptimizationType member.
+
+ Raises:
+ argparse.ArgumentTypeError: If the input is invalid.
+ """
+ # Attempt to parse by name (case-insensitive)
+ key = value.strip().upper()
+ if key in MresPerfOptimizationType.__members__:
+ return MresPerfOptimizationType[key]
+
+ # Attempt to parse by integer value
+ try:
+ int_value = int(value)
+ return MresPerfOptimizationType(int_value)
+ except (ValueError, KeyError):
+ valid_options = ", ".join(f"{member.name}({member.value})" for member in MresPerfOptimizationType)
+ raise argparse.ArgumentTypeError(f"Invalid OptimizationType {value!r}. Choose from: {valid_options}.")
+
+ def __str__(self) -> str:
+ """
+ Return a human-readable string for the enum member.
+ """
+ return self.name
+
+ @staticmethod
+ def build_arg_parser() -> argparse.ArgumentParser:
+ """
+ Create and configure the argument parser with optimization option.
+
+ Returns:
+ A configured ArgumentParser instance.
+ """
+ parser = argparse.ArgumentParser(description="Run the LBM multiresolution simulation with specified optimizations.")
+ # Dynamically generate help text from enum members
+ valid_options = ", ".join(f"{member.name}({member.value})" for member in MresPerfOptimizationType)
+ parser.add_argument(
+ "-o",
+ "--optimization",
+ type=MresPerfOptimizationType.from_string,
+ default=MresPerfOptimizationType.NAIVE_COLLIDE_STREAM,
+ help=f"Select optimization strategy: {valid_options}",
+ )
+ return parser
diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py
index 7c87f58c..8be2f226 100644
--- a/xlb/operator/boundary_condition/__init__.py
+++ b/xlb/operator/boundary_condition/__init__.py
@@ -1,4 +1,4 @@
-from xlb.operator.boundary_condition.helper_functions_bc import HelperFunctionsBC
+from xlb.operator.boundary_condition.helper_functions_bc import HelperFunctionsBC, EncodeAuxiliaryData, MultiresEncodeAuxiliaryData
from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition
from xlb.operator.boundary_condition.boundary_condition_registry import BoundaryConditionRegistry
from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC
@@ -8,4 +8,4 @@
from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC
from xlb.operator.boundary_condition.bc_regularized import RegularizedBC
from xlb.operator.boundary_condition.bc_extrapolation_outflow import ExtrapolationOutflowBC
-from xlb.operator.boundary_condition.bc_grads_approximation import GradsApproximationBC
+from xlb.operator.boundary_condition.bc_hybrid import HybridBC
diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py
index aeefd788..7c5ae0b5 100644
--- a/xlb/operator/boundary_condition/bc_do_nothing.py
+++ b/xlb/operator/boundary_condition/bc_do_nothing.py
@@ -1,5 +1,8 @@
"""
-Base class for boundary conditions in a LBM simulation.
+Do-nothing boundary condition.
+
+Skips the streaming step at tagged boundary voxels, leaving the
+populations unchanged.
"""
import jax.numpy as jnp
@@ -16,6 +19,7 @@
ImplementationStep,
BoundaryCondition,
)
+from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod
class DoNothingBC(BoundaryCondition):
@@ -31,6 +35,7 @@ def __init__(
compute_backend: ComputeBackend = None,
indices=None,
mesh_vertices=None,
+ voxelization_method: MeshVoxelizationMethod = None,
):
super().__init__(
ImplementationStep.STREAMING,
@@ -39,6 +44,7 @@ def __init__(
compute_backend,
indices,
mesh_vertices,
+ voxelization_method,
)
@Operator.register_backend(ComputeBackend.JAX)
@@ -74,3 +80,12 @@ def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
dim=f_pre.shape[1:],
)
return f_post
+
+ def _construct_neon(self):
+ functional, _ = self._construct_warp()
+ return functional, None
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, f_pre, f_post, bc_mask, missing_mask):
+ # raise exception as this feature is not implemented yet
+ raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.")
diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py
index 85cfd653..85ebe92b 100644
--- a/xlb/operator/boundary_condition/bc_equilibrium.py
+++ b/xlb/operator/boundary_condition/bc_equilibrium.py
@@ -12,18 +12,30 @@
from xlb.velocity_set.velocity_set import VelocitySet
from xlb.precision_policy import PrecisionPolicy
from xlb.compute_backend import ComputeBackend
-from xlb.operator.equilibrium.equilibrium import Equilibrium
-from xlb.operator.equilibrium import QuadraticEquilibrium
+from xlb.operator.equilibrium import Equilibrium, QuadraticEquilibrium
from xlb.operator.operator import Operator
from xlb.operator.boundary_condition.boundary_condition import (
ImplementationStep,
BoundaryCondition,
)
+from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod
class EquilibriumBC(BoundaryCondition):
- """
- Full Bounce-back boundary condition for a lattice Boltzmann method simulation.
+ """Equilibrium boundary condition.
+
+ Sets populations at tagged voxels to the equilibrium distribution
+ computed from the prescribed macroscopic density *rho* and velocity
+ *u*. Commonly used as an inlet or outlet condition.
+
+ Parameters
+ ----------
+ rho : float
+ Prescribed macroscopic density.
+ u : tuple of float
+ Prescribed macroscopic velocity ``(ux, uy, uz)``.
+ equilibrium_operator : Operator, optional
+ Equilibrium operator. Defaults to ``QuadraticEquilibrium``.
"""
def __init__(
@@ -36,6 +48,7 @@ def __init__(
compute_backend: ComputeBackend = None,
indices=None,
mesh_vertices=None,
+ voxelization_method: MeshVoxelizationMethod = None,
):
# Store the equilibrium information
self.rho = rho
@@ -53,6 +66,7 @@ def __init__(
compute_backend,
indices,
mesh_vertices,
+ voxelization_method,
)
@Operator.register_backend(ComputeBackend.JAX)
@@ -90,8 +104,17 @@ def functional(
return functional, kernel
+ def _construct_neon(self):
+ # Redefine the equilibrium operators for the neon backend
+ # This is because the neon backend relies on the warp functionals for its operations.
+ self.equilibrium_operator = QuadraticEquilibrium(compute_backend=ComputeBackend.WARP)
+
+ # Use the warp functional for the NEON backend
+ functional, _ = self._construct_warp()
+ return functional, None
+
@Operator.register_backend(ComputeBackend.WARP)
- def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
+ def warp_launch(self, f_pre, f_post, bc_mask, missing_mask):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py
index 884e691e..a1a26e0b 100644
--- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py
+++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py
@@ -1,5 +1,14 @@
"""
-Base class for boundary conditions in a LBM simulation.
+Extrapolation outflow boundary condition.
+
+Uses first-order extrapolation from the interior to set the unknown
+populations at outflow boundaries, avoiding strong wave reflections.
+
+Reference
+---------
+Geier, M. et al. (2015). "The cumulant lattice Boltzmann equation in
+three dimensions: Theory and validation." *Computers & Mathematics
+with Applications*, 70(4), 507-547.
"""
import jax.numpy as jnp
@@ -19,6 +28,7 @@
ImplementationStep,
BoundaryCondition,
)
+from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod
class ExtrapolationOutflowBC(BoundaryCondition):
@@ -42,6 +52,7 @@ def __init__(
compute_backend: ComputeBackend = None,
indices=None,
mesh_vertices=None,
+ voxelization_method: MeshVoxelizationMethod = None,
):
# Call the parent constructor
super().__init__(
@@ -51,16 +62,20 @@ def __init__(
compute_backend,
indices,
mesh_vertices,
+ voxelization_method,
)
# find and store the normal vector using indices
- self._get_normal_vec(indices)
+ if self.compute_backend == ComputeBackend.JAX:
+ self._get_normal_vectors(indices)
# Unpack the two warp functionals needed for this BC!
if self.compute_backend == ComputeBackend.WARP:
- self.warp_functional, self.update_bc_auxilary_data = self.warp_functional
+ self.warp_functional, self.assemble_auxiliary_data = self.warp_functional
+ elif self.compute_backend == ComputeBackend.NEON:
+ self.neon_functional, self.assemble_auxiliary_data = self.neon_functional
- def _get_normal_vec(self, indices):
+ def _get_normal_vectors(self, indices):
# Get the frequency count and most common element directly
freq_counts = [Counter(coord).most_common(1)[0] for coord in indices]
@@ -89,9 +104,10 @@ def _roll(self, fld, vec):
return jnp.roll(fld, (vec[0], vec[1], vec[2]), axis=(1, 2, 3))
@partial(jit, static_argnums=(0,), inline=True)
- def update_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask):
+ def assemble_auxiliary_data(self, f_pre, f_post, bc_mask, missing_mask):
"""
- Update the auxilary distribution functions for the boundary condition.
+ Prepare time-dependent dynamic data for imposing the boundary condition in the next iteration after streaming.
+ We use directions that leave the domain for storing this prepared data.
Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision
"""
sound_speed = 1.0 / jnp.sqrt(3.0)
@@ -102,7 +118,7 @@ def update_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask):
# Roll boundary mask in the opposite of the normal vector to mask its next immediate neighbour
neighbour = self._roll(boundary, -self.normal)
- # gather post-streaming values associated with previous time-step to construct the auxilary data for BC
+ # gather post-streaming values associated with previous time-step to construct the required data for BC
fpop = jnp.where(boundary, f_pre, f_post)
fpop_neighbour = jnp.where(neighbour, f_pre, f_post)
@@ -168,7 +184,7 @@ def functional(
return _f
@wp.func
- def update_bc_auxilary_data(
+ def assemble_auxiliary_data_warp(
index: Any,
timestep: Any,
missing_mask: Any,
@@ -177,9 +193,9 @@ def update_bc_auxilary_data(
_f_pre: Any,
_f_post: Any,
):
- # Update the auxilary data for this BC using the neighbour's populations stored in f_aux and
- # f_pre (post-streaming values of the current voxel). We use directions that leave the domain
- # for storing this prepared data.
+ # Prepare time-dependent dynamic data for imposing the boundary condition in the next iteration after streaming.
+ # We use directions that leave the domain for storing this prepared data.
+ # Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision
_f = _f_post
nv = get_normal_vectors(missing_mask)
for l in range(self.velocity_set.q):
@@ -194,9 +210,42 @@ def update_bc_auxilary_data(
_f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * _f_pre[l] + sound_speed * f_aux
return _f
+ @wp.func
+ def assemble_auxiliary_data_neon(
+ index: Any,
+ timestep: Any,
+ missing_mask: Any,
+ f_0: Any,
+ f_1: Any,
+ _f_pre: Any,
+ _f_post: Any,
+ level: Any = 0,
+ ):
+ # Prepare time-dependent dynamic data for imposing the boundary condition in the next iteration after streaming.
+ # We use directions that leave the domain for storing this prepared data.
+ # Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision
+ _f = _f_post
+ nv = get_normal_vectors(missing_mask)
+ for lattice_dir in range(self.velocity_set.q):
+ if missing_mask[lattice_dir] == wp.uint8(1):
+ # f_0 is the post-collision values of the current time-step
+ # Get pull index associated with the "neighbours" pull_index
+ offset = wp.vec3i(-_c[0, lattice_dir], -_c[1, lattice_dir], -_c[2, lattice_dir])
+ for d in range(self.velocity_set.d):
+ offset[d] = offset[d] - nv[d]
+ offset_pull_index = wp.neon_ngh_idx(wp.int8(offset[0]), wp.int8(offset[1]), wp.int8(offset[2]))
+
+ # The following is the post-streaming values of the neighbor cell
+ # This function reads a field value at a given neighboring index and direction.
+ unused_is_valid = wp.bool(False)
+ f_aux = self.compute_dtype(wp.neon_read_ngh(f_0, index, offset_pull_index, lattice_dir, self.store_dtype(0.0), unused_is_valid))
+ _f[_opp_indices[lattice_dir]] = (self.compute_dtype(1.0) - sound_speed) * _f_pre[lattice_dir] + sound_speed * f_aux
+ return _f
+
kernel = self._construct_kernel(functional)
+ assemble_auxiliary_data = assemble_auxiliary_data_warp if self.compute_backend == ComputeBackend.WARP else assemble_auxiliary_data_neon
- return (functional, update_bc_auxilary_data), kernel
+ return (functional, assemble_auxiliary_data), kernel
@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, _f_pre, _f_post, bc_mask, missing_mask):
@@ -207,3 +256,12 @@ def warp_implementation(self, _f_pre, _f_post, bc_mask, missing_mask):
dim=_f_pre.shape[1:],
)
return _f_post
+
+ def _construct_neon(self):
+ functional, _ = self._construct_warp()
+ return functional, None
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, f_pre, f_post, bc_mask, missing_mask):
+ # raise exception as this feature is not implemented yet
+ raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.")
diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py
index 995e2ff9..4b7f8f0f 100644
--- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py
+++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py
@@ -1,5 +1,8 @@
"""
-Base class for boundary conditions in a LBM simulation.
+Full-way bounce-back boundary condition.
+
+Reverses every population at tagged solid voxels, effectively
+imposing a no-slip wall located *on* the grid node.
"""
import jax.numpy as jnp
@@ -17,6 +20,7 @@
BoundaryCondition,
ImplementationStep,
)
+from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod
class FullwayBounceBackBC(BoundaryCondition):
@@ -31,6 +35,7 @@ def __init__(
compute_backend: ComputeBackend = None,
indices=None,
mesh_vertices=None,
+ voxelization_method: MeshVoxelizationMethod = None,
):
super().__init__(
ImplementationStep.COLLISION,
@@ -39,6 +44,7 @@ def __init__(
compute_backend,
indices,
mesh_vertices,
+ voxelization_method,
)
@Operator.register_backend(ComputeBackend.JAX)
@@ -84,3 +90,7 @@ def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
dim=f_pre.shape[1:],
)
return f_post
+
+ def _construct_neon(self):
+ functional, _ = self._construct_warp()
+ return functional, None
diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py
deleted file mode 100644
index 22fbb4ec..00000000
--- a/xlb/operator/boundary_condition/bc_grads_approximation.py
+++ /dev/null
@@ -1,321 +0,0 @@
-"""
-Base class for boundary conditions in a LBM simulation.
-"""
-
-import jax.numpy as jnp
-from jax import jit
-import jax.lax as lax
-from functools import partial
-import warp as wp
-from typing import Any
-from collections import Counter
-import numpy as np
-
-from xlb.velocity_set.velocity_set import VelocitySet
-from xlb.precision_policy import PrecisionPolicy
-from xlb.compute_backend import ComputeBackend
-from xlb.operator.operator import Operator
-from xlb.operator.macroscopic import Macroscopic
-from xlb.operator.macroscopic.zero_moment import ZeroMoment
-from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux
-from xlb.operator.equilibrium import QuadraticEquilibrium
-from xlb.operator.boundary_condition.boundary_condition import (
- ImplementationStep,
- BoundaryCondition,
-)
-
-
-class GradsApproximationBC(BoundaryCondition):
- """
- Purpose: Using Grad's approximation to represent fpop based on macroscopic inputs used for outflow [1] and
- Dirichlet BCs [2]
- [1] S. Chikatamarla, S. Ansumali, and I. Karlin, "Grad's approximation for missing data in lattice Boltzmann
- simulations", Europhys. Lett. 74, 215 (2006).
- [2] Dorschner, B., Chikatamarla, S. S., Bösch, F., & Karlin, I. V. (2015). Grad's approximation for moving and
- stationary walls in entropic lattice Boltzmann simulations. Journal of Computational Physics, 295, 340-354.
-
- """
-
- def __init__(
- self,
- velocity_set: VelocitySet = None,
- precision_policy: PrecisionPolicy = None,
- compute_backend: ComputeBackend = None,
- indices=None,
- mesh_vertices=None,
- ):
- # TODO: the input velocity must be suitably stored elesewhere when mesh is moving.
- self.u = (0, 0, 0)
-
- # Call the parent constructor
- super().__init__(
- ImplementationStep.STREAMING,
- velocity_set,
- precision_policy,
- compute_backend,
- indices,
- mesh_vertices,
- )
-
- # Instantiate the operator for computing macroscopic values
- self.macroscopic = Macroscopic()
- self.zero_moment = ZeroMoment()
- self.equilibrium = QuadraticEquilibrium()
- self.momentum_flux = MomentumFlux()
-
- # This BC needs implicit distance to the mesh
- self.needs_mesh_distance = True
-
- # If this BC is defined using indices, it would need padding in order to find missing directions
- # when imposed on a geometry that is in the domain interior
- if self.mesh_vertices is None:
- assert self.indices is not None
- self.needs_padding = True
-
- # Raise error if used for 2d examples:
- if self.velocity_set.d == 2:
- raise NotImplementedError("This BC is not implemented in 2D!")
-
- # if indices is not None:
- # # this BC would be limited to stationary boundaries
- # # assert mesh_vertices is None
- # if mesh_vertices is not None:
- # # this BC would be applicable for stationary and moving boundaries
- # assert indices is None
- # if mesh_velocity_function is not None:
- # # mesh is moving and/or deforming
-
- assert self.compute_backend == ComputeBackend.WARP, "This BC is currently only implemented with the Warp backend!"
-
- @Operator.register_backend(ComputeBackend.JAX)
- @partial(jit, static_argnums=(0))
- def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
- # TODO
- raise NotImplementedError(f"Operation {self.__class__.__name} not implemented in JAX!")
- return
-
- def _construct_warp(self):
- # Set local variables and constants
- _c = self.velocity_set.c
- _q = self.velocity_set.q
- _d = self.velocity_set.d
- _w = self.velocity_set.w
- _qi = self.velocity_set.qi
- _opp_indices = self.velocity_set.opp_indices
- _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
- _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
- _u_wall = _u_vec(self.u[0], self.u[1], self.u[2]) if _d == 3 else _u_vec(self.u[0], self.u[1])
- # diagonal = wp.vec3i(0, 3, 5) if _d == 3 else wp.vec2i(0, 2)
-
- @wp.func
- def regularize_fpop(
- missing_mask: Any,
- rho: Any,
- u: Any,
- fpop: Any,
- ):
- """
- Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop.
- """
- # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq}
- feq = self.equilibrium.warp_functional(rho, u)
- f_neq = fpop - feq
- PiNeq = self.momentum_flux.warp_functional(f_neq)
-
- # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq)
- nt = _d * (_d + 1) // 2
- for l in range(_q):
- QiPi1 = self.compute_dtype(0.0)
- for t in range(nt):
- QiPi1 += _qi[l, t] * PiNeq[t]
-
- # assign all populations based on eq 45 of Latt et al (2008)
- # fneq ~ f^1
- fpop1 = self.compute_dtype(4.5) * _w[l] * QiPi1
- fpop[l] = feq[l] + fpop1
- return fpop
-
- @wp.func
- def grads_approximate_fpop(
- missing_mask: Any,
- rho: Any,
- u: Any,
- f_post: Any,
- ):
- # Purpose: Using Grad's approximation to represent fpop based on macroscopic inputs used for outflow [1] and
- # Dirichlet BCs [2]
- # [1] S. Chikatax`marla, S. Ansumali, and I. Karlin, "Grad's approximation for missing data in lattice Boltzmann
- # simulations", Europhys. Lett. 74, 215 (2006).
- # [2] Dorschner, B., Chikatamarla, S. S., Bösch, F., & Karlin, I. V. (2015). Grad's approximation for moving and
- # stationary walls in entropic lattice Boltzmann simulations. Journal of Computational Physics, 295, 340-354.
-
- # Note: See also self.regularize_fpop function which is somewhat similar.
-
- # Compute pressure tensor Pi using all f_post-streaming values
- Pi = self.momentum_flux.warp_functional(f_post)
-
- # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq)
- nt = _d * (_d + 1) // 2
- for l in range(_q):
- # if missing_mask[l] == wp.uint8(1):
- QiPi = self.compute_dtype(0.0)
- for t in range(nt):
- if t == 0 or t == 3 or t == 5:
- QiPi += _qi[l, t] * (Pi[t] - rho / self.compute_dtype(3.0))
- else:
- QiPi += _qi[l, t] * Pi[t]
-
- # Compute c.u
- cu = self.compute_dtype(0.0)
- for d in range(self.velocity_set.d):
- if _c[d, l] == 1:
- cu += u[d]
- elif _c[d, l] == -1:
- cu -= u[d]
- cu *= self.compute_dtype(3.0)
-
- # change f_post using the Grad's approximation
- f_post[l] = rho * _w[l] * (self.compute_dtype(1.0) + cu) + _w[l] * self.compute_dtype(4.5) * QiPi
-
- return f_post
-
- # Construct the functionals for this BC
- @wp.func
- def functional_method1(
- index: Any,
- timestep: Any,
- missing_mask: Any,
- f_0: Any,
- f_1: Any,
- f_pre: Any,
- f_post: Any,
- ):
- # NOTE: this BC has been reformulated to become entirely local and so has differences compared to the original paper.
- # Here we use the current time-step populations (f_pre = f_post_collision and f_post = f_post_streaming).
- one = self.compute_dtype(1.0)
- for l in range(_q):
- # If the mask is missing then take the opposite index
- if missing_mask[l] == wp.uint8(1):
- # The implicit distance to the boundary or "weights" have been stored in known directions of f_1
- # weight = f_1[_opp_indices[l], index[0], index[1], index[2]]
- weight = self.compute_dtype(0.5)
-
- # Use differentiable interpolated BB to find f_missing:
- f_post[l] = ((one - weight) * f_post[_opp_indices[l]] + weight * (f_pre[l] + f_pre[_opp_indices[l]])) / (one + weight)
-
- # # Add contribution due to moving_wall to f_missing as is usual in regular Bouzidi BC
- # cu = self.compute_dtype(0.0)
- # for d in range(_d):
- # if _c[d, l] == 1:
- # cu += _u_wall[d]
- # elif _c[d, l] == -1:
- # cu -= _u_wall[d]
- # cu *= self.compute_dtype(-6.0) * _w[l]
- # f_post[l] += cu
-
- # Compute density, velocity using all f_post-streaming values
- rho, u = self.macroscopic.warp_functional(f_post)
-
- # Compute Grad's appriximation using full equation as in Eq (10) of Dorschner et al.
- f_post = regularize_fpop(missing_mask, rho, u, f_post)
- # f_post = grads_approximate_fpop(missing_mask, rho, u, f_post)
- return f_post
-
- # Construct the functionals for this BC
- @wp.func
- def functional_method2(
- index: Any,
- timestep: Any,
- missing_mask: Any,
- f_0: Any,
- f_1: Any,
- f_pre: Any,
- f_post: Any,
- ):
- # NOTE: this BC has been reformulated to become entirely local and so has differences compared to the original paper.
- # Here we use the current time-step populations (f_pre = f_post_collision and f_post = f_post_streaming).
- # NOTE: f_aux should contain populations at "x_f" (see their fig 1) in the missign direction of the BC which amounts
- # to post-collision values being pulled from appropriate cells like ExtrapolationBC
- #
- # here I need to compute all terms in Eq (10)
- # Strategy:
- # 1) "weights" should have been stored somewhere to be used here.
- # 2) Given "weights", "u_w" (input to the BC) and "u_f" (computed from f_aux), compute "u_target" as per Eq (14)
- # NOTE: in the original paper "u_target" is associated with the previous time step not current time.
- # 3) Given "weights" use differentiable interpolated BB to find f_missing as I had before:
- # fmissing = ((1. - weights) * f_poststreaming_iknown + weights * (f_postcollision_imissing + f_postcollision_iknown)) / (1.0 + weights)
- # 4) Add contribution due to u_w to f_missing as is usual in regular Bouzidi BC (ie. -6.0 * self.lattice.w * jnp.dot(self.vel, c)
- # 5) Compute rho_target = \sum(f_ibb) based on these values
- # 6) Compute feq using feq = self.equilibrium(rho_target, u_target)
- # 7) Compute Pi_neq and Pi_eq using all f_post-streaming values as per:
- # Pi_neq = self.momentum_flux(fneq) and Pi_eq = self.momentum_flux(feq)
- # 8) Compute Grad's appriximation using full equation as in Eq (10)
- # NOTE: this is very similar to the regularization procedure.
-
- _f_nbr = _f_vec()
- u_target = _u_vec(0.0, 0.0, 0.0) if _d == 3 else _u_vec(0.0, 0.0)
- num_missing = 0
- one = self.compute_dtype(1.0)
- for l in range(_q):
- # If the mask is missing then take the opposite index
- if missing_mask[l] == wp.uint8(1):
- # Find the neighbour and its velocity value
- for ll in range(_q):
- # f_0 is the post-collision values of the current time-step
- # Get index associated with the fluid neighbours
- fluid_nbr_index = type(index)()
- for d in range(_d):
- fluid_nbr_index[d] = index[d] + _c[d, l]
- # The following is the post-collision values of the fluid neighbor cell
- _f_nbr[ll] = self.compute_dtype(f_0[ll, fluid_nbr_index[0], fluid_nbr_index[1], fluid_nbr_index[2]])
-
- # Compute the velocity vector at the fluid neighbouring cells
- _, u_f = self.macroscopic.warp_functional(_f_nbr)
-
- # Record the number of missing directions
- num_missing += 1
-
- # The implicit distance to the boundary or "weights" have been stored in known directions of f_1
- weight = f_1[_opp_indices[l], index[0], index[1], index[2]]
-
- # Given "weights", "u_w" (input to the BC) and "u_f" (computed from f_aux), compute "u_target" as per Eq (14)
- for d in range(_d):
- u_target[d] += (weight * u_f[d] + _u_wall[d]) / (one + weight)
-
- # Use differentiable interpolated BB to find f_missing:
- f_post[l] = ((one - weight) * f_post[_opp_indices[l]] + weight * (f_pre[l] + f_pre[_opp_indices[l]])) / (one + weight)
-
- # Add contribution due to moving_wall to f_missing as is usual in regular Bouzidi BC
- cu = self.compute_dtype(0.0)
- for d in range(_d):
- if _c[d, l] == 1:
- cu += _u_wall[d]
- elif _c[d, l] == -1:
- cu -= _u_wall[d]
- cu *= self.compute_dtype(-6.0) * _w[l]
- f_post[l] += cu
-
- # Compute rho_target = \sum(f_ibb) based on these values
- rho_target = self.zero_moment.warp_functional(f_post)
- for d in range(_d):
- u_target[d] /= num_missing
-
- # Compute Grad's appriximation using full equation as in Eq (10) of Dorschner et al.
- f_post = grads_approximate_fpop(missing_mask, rho_target, u_target, f_post)
- return f_post
-
- functional = functional_method1
-
- kernel = self._construct_kernel(functional)
-
- return functional, kernel
-
- @Operator.register_backend(ComputeBackend.WARP)
- def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
- # Launch the warp kernel
- wp.launch(
- self.warp_kernel,
- inputs=[f_pre, f_post, bc_mask, missing_mask],
- dim=f_pre.shape[1:],
- )
- return f_post
diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py
index 8ede0c8b..3ee31584 100644
--- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py
+++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py
@@ -1,5 +1,10 @@
"""
-Base class for boundary conditions in a LBM simulation.
+Halfway bounce-back boundary condition.
+
+Implements the standard halfway bounce-back scheme where the no-slip
+wall is located halfway between a solid node and a fluid node.
+Optionally supports prescribed wall velocity (moving walls) and
+interpolated variants that use wall-distance data.
"""
import jax.numpy as jnp
@@ -7,7 +12,8 @@
import jax.lax as lax
from functools import partial
import warp as wp
-from typing import Any
+from typing import Any, Union, Tuple, Callable
+import numpy as np
from xlb.velocity_set.velocity_set import VelocitySet
from xlb.precision_policy import PrecisionPolicy
@@ -16,7 +22,9 @@
from xlb.operator.boundary_condition.boundary_condition import (
ImplementationStep,
BoundaryCondition,
+ HelperFunctionsBC,
)
+from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod
class HalfwayBounceBackBC(BoundaryCondition):
@@ -33,6 +41,9 @@ def __init__(
compute_backend: ComputeBackend = None,
indices=None,
mesh_vertices=None,
+ voxelization_method: MeshVoxelizationMethod = None,
+ profile: Callable = None,
+ prescribed_value: Union[float, Tuple[float, ...], np.ndarray] = None,
):
# Call the parent constructor
super().__init__(
@@ -42,24 +53,86 @@ def __init__(
compute_backend,
indices,
mesh_vertices,
+ voxelization_method,
)
# This BC needs padding for finding missing directions when imposed on a geometry that is in the domain interior
self.needs_padding = True
+ # This BC class accepts both constant prescribed values of velocity with keyword "prescribed_value" or
+ # velocity profiles given by keyword "profile" which must be a callable function.
+ self.profile = profile
+
+ # A flag to enable moving wall treatment when either "prescribed_value" or "profile" are provided.
+ self.needs_moving_wall_treatment = False
+
+ if (profile is not None) or (prescribed_value is not None):
+ self.needs_moving_wall_treatment = True
+
+ # Handle no-slip BCs if neither prescribed_value or profile are provided.
+ if prescribed_value is None and profile is None:
+ print(f"WARNING! Assuming no-slip condition for BC type = {self.__class__.__name__}!")
+ prescribed_value = [0] * self.velocity_set.d
+
+ # Handle prescribed value if provided
+ if prescribed_value is not None:
+ if profile is not None:
+ raise ValueError("Cannot specify both profile and prescribed_value")
+
+ # Ensure prescribed_value is a NumPy array of floats
+ if isinstance(prescribed_value, (tuple, list, np.ndarray)):
+ prescribed_value = np.asarray(prescribed_value, dtype=np.float64)
+ else:
+ raise ValueError("Velocity prescribed_value must be a tuple, list, or array")
+
+ # Create a constant prescribed profile function
+ if self.compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON]:
+ if self.velocity_set.d == 2:
+ prescribed_value = np.array([prescribed_value[0], prescribed_value[1], 0.0], dtype=np.float64)
+ prescribed_value = wp.vec(3, dtype=self.compute_dtype)(prescribed_value)
+ self.profile = self._create_constant_prescribed_profile(prescribed_value)
+
+ def _create_constant_prescribed_profile(self, prescribed_value):
+ _u_vec = wp.vec(3, dtype=self.compute_dtype)
+
+ @wp.func
+ def prescribed_profile_warp(index: Any, time: Any):
+ return _u_vec(prescribed_value[0], prescribed_value[1], prescribed_value[2])
+
+ def prescribed_profile_jax():
+ return jnp.array(prescribed_value, dtype=self.precision_policy.store_precision.jax_dtype).reshape(-1, 1)
+
+ if self.compute_backend == ComputeBackend.JAX:
+ return prescribed_profile_jax
+ elif self.compute_backend == ComputeBackend.WARP:
+ return prescribed_profile_warp
+ elif self.compute_backend == ComputeBackend.NEON:
+ return prescribed_profile_warp
+
@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0))
def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
boundary = bc_mask == self.id
new_shape = (self.velocity_set.q,) + boundary.shape[1:]
boundary = lax.broadcast_in_dim(boundary, new_shape, tuple(range(self.velocity_set.d + 1)))
- return jnp.where(
- jnp.logical_and(missing_mask, boundary),
- f_pre[self.velocity_set.opp_indices],
- f_post,
- )
+
+ # Add contribution due to moving_wall to f_missing
+ moving_wall_component = 0.0
+ if self.needs_moving_wall_treatment:
+ u_wall = self.profile()
+ cu = self.velocity_set.w[:, None] * jnp.tensordot(self.velocity_set.c, u_wall, axes=(0, 0))
+ cu = cu.reshape((-1,) + (1,) * (len(f_post[1:].shape) - 1))
+ moving_wall_component = 6.0 * cu
+
+ # Apply the halfway bounce-back condition
+ f_post = jnp.where(jnp.logical_and(missing_mask, boundary), f_pre[self.velocity_set.opp_indices] + moving_wall_component, f_post)
+
+ return f_post
def _construct_warp(self):
+ # load helper functions. Explicitly using the WARP backend for helper functions as it may also be called by the Neon backend.
+ bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=ComputeBackend.WARP)
+
# Set local constants
_opp_indices = self.velocity_set.opp_indices
@@ -74,6 +147,9 @@ def functional(
f_pre: Any,
f_post: Any,
):
+ # Get wall velocity
+ u_wall = self.profile(index, timestep)
+
# Post-streaming values are only modified at missing direction
_f = f_post
for l in range(self.velocity_set.q):
@@ -82,6 +158,10 @@ def functional(
# Get the pre-streaming distribution function in oppisite direction
_f[l] = f_pre[_opp_indices[l]]
+ # Add contribution due to moving_wall to f_missing
+ if wp.static(self.needs_moving_wall_treatment):
+ _f[l] += bc_helper.moving_wall_fpop_correction(u_wall, l)
+
return _f
kernel = self._construct_kernel(functional)
@@ -97,3 +177,12 @@ def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
dim=f_pre.shape[1:],
)
return f_post
+
+ def _construct_neon(self):
+ functional, _ = self._construct_warp()
+ return functional, None
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, f_pre, f_post, bc_mask, missing_mask):
+ # raise exception as this feature is not implemented yet
+ raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.")
diff --git a/xlb/operator/boundary_condition/bc_hybrid.py b/xlb/operator/boundary_condition/bc_hybrid.py
new file mode 100644
index 00000000..62584160
--- /dev/null
+++ b/xlb/operator/boundary_condition/bc_hybrid.py
@@ -0,0 +1,391 @@
+"""
+Hybrid boundary condition combining interpolated bounce-back with regularization.
+
+Provides three wall-treatment strategies, selectable via *bc_method*:
+
+* ``"bounceback_regularized"`` — interpolated bounce-back + Latt regularization.
+* ``"bounceback_grads"`` — interpolated bounce-back + Grad's approximation.
+* ``"nonequilibrium_regularized"`` — Tao non-equilibrium bounce-back + Latt
+ regularization.
+
+All variants optionally support:
+
+* Moving walls (via *prescribed_value* or *profile*).
+* Curved boundaries with fractional distance to the mesh surface (via
+ *use_mesh_distance*).
+"""
+
+import inspect
+from jax import jit
+from functools import partial
+import warp as wp
+from typing import Any, Union, Tuple, Callable
+import numpy as np
+
+from xlb.velocity_set.velocity_set import VelocitySet
+from xlb.precision_policy import PrecisionPolicy
+from xlb.compute_backend import ComputeBackend
+from xlb.operator.operator import Operator
+from xlb.operator.macroscopic import Macroscopic
+from xlb.operator.equilibrium import QuadraticEquilibrium
+from xlb.operator.boundary_condition.boundary_condition import (
+ ImplementationStep,
+ BoundaryCondition,
+ HelperFunctionsBC,
+)
+from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod
+
+
+class HybridBC(BoundaryCondition):
+ """
+ The hybrid BC methods in this boundary condition have been originally developed by H. Salehipour and are inspired from
+ various previous publications, in particular [1]. The reformulations are aimed to provide local formulations that are
+ computationally efficient and numerically stable at high Reynolds numbers.
+
+ [1] Dorschner, B., Chikatamarla, S. S., Bösch, F., & Karlin, I. V. (2015). Grad's approximation for moving and
+ stationary walls in entropic lattice Boltzmann simulations. Journal of Computational Physics, 295, 340-354.
+ """
+
+ def __init__(
+ self,
+ bc_method,
+ profile: Callable = None,
+ prescribed_value: Union[float, Tuple[float, ...], np.ndarray] = None,
+ velocity_set: VelocitySet = None,
+ precision_policy: PrecisionPolicy = None,
+ compute_backend: ComputeBackend = None,
+ indices=None,
+ mesh_vertices=None,
+ voxelization_method: MeshVoxelizationMethod = None,
+ use_mesh_distance=False,
+ ):
+ """
+ Parameters
+ ----------
+ bc_method : str
+ Wall-treatment strategy. One of ``"bounceback_regularized"``,
+ ``"bounceback_grads"``, or ``"nonequilibrium_regularized"``.
+ profile : callable, optional
+ Warp function ``(index) -> u_vec`` or ``(index, timestep) -> u_vec``
+ defining the wall velocity. Mutually exclusive with *prescribed_value*.
+ prescribed_value : float or array-like, optional
+ Constant wall velocity vector. Mutually exclusive with *profile*.
+ If neither is given, a no-slip wall is assumed.
+ velocity_set : VelocitySet, optional
+ precision_policy : PrecisionPolicy, optional
+ compute_backend : ComputeBackend, optional
+ indices : list of array-like, optional
+ Boundary voxel indices (use this **or** *mesh_vertices*, not both).
+ mesh_vertices : np.ndarray, optional
+ Mesh triangle vertices for mesh-based voxelization.
+ voxelization_method : MeshVoxelizationMethod, optional
+ Voxelization strategy (AABB, RAY, AABB_CLOSE, etc.).
+ use_mesh_distance : bool
+ If ``True``, fractional distances to the mesh surface are
+ computed and stored for interpolated boundary schemes.
+ """
+ assert bc_method in [
+ "bounceback_regularized",
+ "bounceback_grads",
+ "nonequilibrium_regularized",
+ ], f"type = {bc_method} not supported! Use 'bounceback_regularized', 'bounceback_grads' or 'nonequilibrium_regularized'."
+ self.bc_method = bc_method
+
+ # Call the parent constructor
+ super().__init__(
+ ImplementationStep.STREAMING,
+ velocity_set,
+ precision_policy,
+ compute_backend,
+ indices,
+ mesh_vertices,
+ voxelization_method,
+ )
+
+ # Raise error if used for 2d examples:
+ if self.velocity_set.d == 2:
+ raise NotImplementedError("This BC is not implemented in 2D!")
+
+ # Check if the compute backend is Warp
+ assert self.compute_backend in (ComputeBackend.WARP, ComputeBackend.NEON), "This BC is currently not supported by JAX backend!"
+
+ # Instantiate the operator for computing macroscopic values
+ # Explicitly using the WARP backend for these operators as they may also be called by the Neon backend.
+ self.macroscopic = Macroscopic(compute_backend=ComputeBackend.WARP)
+ self.equilibrium = QuadraticEquilibrium(compute_backend=ComputeBackend.WARP)
+
+ # Define BC helper functions. Explicitly using the WARP backend for helper functions as it may also be called by the Neon backend.
+ self.bc_helper = HelperFunctionsBC(
+ velocity_set=self.velocity_set,
+ precision_policy=self.precision_policy,
+ compute_backend=ComputeBackend.WARP,
+ distance_decoder_function=self._construct_distance_decoder_function(),
+ )
+
+ # A flag to enable moving wall treatment when either "prescribed_value" or "profile" are provided.
+ self.needs_moving_wall_treatment = False
+
+ if (profile is not None) or (prescribed_value is not None):
+ self.needs_moving_wall_treatment = True
+
+ # Handle no-slip BCs if neither prescribed_value or profile are provided.
+ if prescribed_value is None and profile is None:
+ print(f"WARNING! Assuming no-slip condition for BC type = {self.__class__.__name__}_{self.bc_method}!")
+ prescribed_value = [0, 0, 0]
+
+ # Handle prescribed value if provided
+ if prescribed_value is not None:
+ assert profile is None, "Cannot specify both profile and prescribed_value"
+
+ # Ensure prescribed_value is a NumPy array of floats
+ if isinstance(prescribed_value, (tuple, list, np.ndarray)):
+ prescribed_value = np.asarray(prescribed_value, dtype=np.float64)
+ else:
+ raise ValueError("Velocity prescribed_value must be a tuple, list, or array")
+
+ # Handle 2D velocity sets
+ if self.velocity_set.d == 2:
+ assert len(prescribed_value) == 2, "For 2D velocity set, prescribed_value must be a tuple or array of length 2!"
+ prescribed_value = np.array([prescribed_value[0], prescribed_value[1], 0.0], dtype=np.float64)
+
+ # create a constant prescribed profile
+ _u_vec = wp.vec(3, dtype=self.compute_dtype)
+ prescribed_value = _u_vec(prescribed_value)
+
+ @wp.func
+ def prescribed_profile_warp(index: Any):
+ return _u_vec(prescribed_value[0], prescribed_value[1], prescribed_value[2])
+
+ profile = prescribed_profile_warp
+
+ # Inspect the function signature and add time parameter if needed
+ self.is_time_dependent = False
+ sig = inspect.signature(profile)
+ if len(sig.parameters) > 1:
+ # We assume the profile function takes only the index as input and is hence time-independent.
+ # In case it is defined with more than 1 input, we assume the second input is time and create
+ # a wrapper function that also accepts time as a parameter.
+ self.is_time_dependent = True
+
+ # This BC class accepts both constant prescribed values of velocity with keyword "prescribed_value" or
+ # velocity profiles given by keyword "profile" which must be a callable function.
+ self.profile = profile
+
+ # Set whether this BC needs mesh distance
+ self.needs_mesh_distance = use_mesh_distance
+
+ # This BC needs normalized distance to the mesh
+ if self.needs_mesh_distance:
+ # This BC needs auxiliary data recovery after streaming
+ self.needs_aux_recovery = True
+
+ # If this BC is defined using indices, it would need padding in order to find missing directions
+ # when imposed on a geometry that is in the domain interior
+ if self.mesh_vertices is None:
+ assert self.indices is not None
+ assert self.needs_mesh_distance is False, 'To use mesh distance, please provide the mesh vertices using keyword "mesh_vertices"!'
+ assert self.voxelization_method is None, "Voxelization method is only applicable when using mesh vertices!"
+ self.needs_padding = True
+ else:
+ assert self.indices is None, "Cannot use indices with mesh vertices! Please provide mesh vertices only."
+
+ # Define the profile functional
+ self.profile_functional = self._construct_profile_functional()
+
+ @Operator.register_backend(ComputeBackend.JAX)
+ @partial(jit, static_argnums=(0))
+ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
+ raise NotImplementedError(f"Operation {self.__class__.__name__} not implemented in JAX!")
+
+ def _construct_distance_decoder_function(self):
+ """
+ Constructs the distance decoder function for this BC.
+ """
+ # Get the opposite indices for the velocity set
+ _opp_indices = self.velocity_set.opp_indices
+
+ # Define the distance decoder function for this BC
+ @wp.func
+ def distance_decoder_function(f_1: Any, index: Any, direction: Any):
+ return self.read_field(f_1, index, _opp_indices[direction])
+
+ return distance_decoder_function
+
+ def _construct_profile_functional(self):
+ """
+ Get the profile functional for this BC.
+ TODO@Hesam:
+ Right now, we can impose a profile on a boundary which requires mesh-distance only if that boundary lives on the finest level.
+ In order to extract "level" from the "neon_field_hdl" we can use the function wp.neon_level(neon_field_hdl). This will allow us
+ to do the following and get rid of the above limitation.
+ cIdx = wp.neon_global_idx(field_neon_hdl, index)
+ gx = wp.neon_get_x(cIdx) // 2 ** level
+ gy = wp.neon_get_y(cIdx) // 2 ** level
+ gz = wp.neon_get_z(cIdx) // 2 ** level
+ """
+
+ @wp.func
+ def profile_functional_neon(f_1: Any, index: Any, timestep: Any):
+ # Convert neon index to warp index
+ warp_index = self.bc_helper.neon_index_to_warp(f_1, index)
+ if wp.static(self.is_time_dependent):
+ return self.profile(warp_index, timestep)
+ else:
+ return self.profile(warp_index)
+
+ @wp.func
+ def profile_functional_warp(f_1: Any, index: Any, timestep: Any):
+ if wp.static(self.is_time_dependent):
+ return self.profile(index, timestep)
+ else:
+ return self.profile(index)
+
+ return profile_functional_warp if self.compute_backend == ComputeBackend.WARP else profile_functional_neon
+
+ def _construct_warp(self):
+ # Construct the functionals for this BC
+ @wp.func
+ def hybrid_bounceback_regularized(
+ index: Any,
+ timestep: Any,
+ _missing_mask: Any,
+ f_0: Any,
+ f_1: Any,
+ f_pre: Any,
+ f_post: Any,
+ ):
+ # Using regularization technique [1] to represent fpop using macroscopic values derived from interpolated bounceback scheme of [2].
+ # missing data in lattice Boltzmann.
+ # [1] Latt, J., Chopard, B., Malaspinas, O., Deville, M., Michler, A., 2008. Straight velocity
+ # boundaries in the lattice Boltzmann method. Physical Review E 77, 056703.
+ # [2] Yu, D., Mei, R., Shyy, W., 2003. A unified boundary treatment in lattice boltzmann method,
+ # in: 41st aerospace sciences meeting and exhibit, p. 953.
+
+ # Apply interpolated bounceback first to find missing populations at the boundary
+ u_wall = self.profile_functional(f_1, index, timestep)
+ f_post = self.bc_helper.interpolated_bounceback(
+ index,
+ _missing_mask,
+ f_0,
+ f_1,
+ f_pre,
+ f_post,
+ u_wall,
+ wp.static(self.needs_moving_wall_treatment),
+ wp.static(self.needs_mesh_distance),
+ )
+
+ # Compute density, velocity using all f_post-streaming values
+ rho, u = self.macroscopic.warp_functional(f_post)
+
+ # Regularize the resulting populations
+ feq = self.equilibrium.warp_functional(rho, u)
+ f_post = self.bc_helper.regularize_fpop(f_post, feq)
+ return f_post
+
+ @wp.func
+ def hybrid_bounceback_grads(
+ index: Any,
+ timestep: Any,
+ _missing_mask: Any,
+ f_0: Any,
+ f_1: Any,
+ f_pre: Any,
+ f_post: Any,
+ ):
+ # Using Grad's approximation [1] to represent fpop using macroscopic values derived from interpolated bounceback scheme of [2].
+ # missing data in lattice Boltzmann.
+ # [1] Dorschner, B., Chikatamarla, S. S., Bösch, F., & Karlin, I. V. (2015). Grad's approximation for moving and
+ # stationary walls in entropic lattice Boltzmann simulations. Journal of Computational Physics, 295, 340-354.
+ # [2] Yu, D., Mei, R., Shyy, W., 2003. A unified boundary treatment in lattice boltzmann method,
+ # in: 41st aerospace sciences meeting and exhibit, p. 953.
+
+ # Apply interpolated bounceback first to find missing populations at the boundary
+ u_wall = self.profile_functional(f_1, index, timestep)
+ f_post = self.bc_helper.interpolated_bounceback(
+ index,
+ _missing_mask,
+ f_0,
+ f_1,
+ f_pre,
+ f_post,
+ u_wall,
+ wp.static(self.needs_moving_wall_treatment),
+ wp.static(self.needs_mesh_distance),
+ )
+
+ # Compute density, velocity using all f_post-streaming values
+ rho, u = self.macroscopic.warp_functional(f_post)
+
+ # Compute Grad's approximation using full equation as in Eq (10) of Dorschner et al.
+ f_post = self.bc_helper.grads_approximate_fpop(_missing_mask, rho, u, f_post)
+ return f_post
+
+ @wp.func
+ def hybrid_nonequilibrium_regularized(
+ index: Any,
+ timestep: Any,
+ _missing_mask: Any,
+ f_0: Any,
+ f_1: Any,
+ f_pre: Any,
+ f_post: Any,
+ ):
+ # This boundary condition uses the method of Tao et al (2018) [1] to get unknown populations on curved boundaries (denoted here by
+ # interpolated_nonequilibrium_bounceback method). To further stabilize this BC, we add regularization technique of [2].
+ # [1] Tao, Shi, et al. "One-point second-order curved boundary condition for lattice Boltzmann simulation of suspended particles."
+ # Computers & Mathematics with Applications 76.7 (2018): 1593-1607.
+ # [2] Latt, J., Chopard, B., Malaspinas, O., Deville, M., Michler, A., 2008. Straight velocity
+ # boundaries in the lattice Boltzmann method. Physical Review E 77, 056703.
+
+ # Apply interpolated bounceback first to find missing populations at the boundary
+ u_wall = self.profile_functional(f_1, index, timestep)
+ f_post = self.bc_helper.interpolated_nonequilibrium_bounceback(
+ index,
+ _missing_mask,
+ f_0,
+ f_1,
+ f_pre,
+ f_post,
+ u_wall,
+ wp.static(self.needs_moving_wall_treatment),
+ wp.static(self.needs_mesh_distance),
+ )
+
+ # Compute density, velocity using all f_post-streaming values
+ rho, u = self.macroscopic.warp_functional(f_post)
+
+ # Regularize the resulting populations
+ feq = self.equilibrium.warp_functional(rho, u)
+ f_post = self.bc_helper.regularize_fpop(f_post, feq)
+ return f_post
+
+ if self.bc_method == "bounceback_regularized":
+ functional = hybrid_bounceback_regularized
+ elif self.bc_method == "bounceback_grads":
+ functional = hybrid_bounceback_grads
+ elif self.bc_method == "nonequilibrium_regularized":
+ functional = hybrid_nonequilibrium_regularized
+
+ kernel = self._construct_kernel(functional)
+
+ return functional, kernel
+
+ @Operator.register_backend(ComputeBackend.WARP)
+ def warp_implementation(self, f_pre, f_post, bc_mask, _missing_mask):
+ # Launch the warp kernel
+ wp.launch(
+ self.warp_kernel,
+ inputs=[f_pre, f_post, bc_mask, _missing_mask],
+ dim=f_pre.shape[1:],
+ )
+ return f_post
+
+ def _construct_neon(self):
+ functional, _ = self._construct_warp()
+ return functional, None
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, f_pre, f_post, bc_mask, missing_mask):
+ # raise exception as this feature is not implemented yet
+ raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.")
diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py
index 8d741798..1ea1f84a 100644
--- a/xlb/operator/boundary_condition/bc_regularized.py
+++ b/xlb/operator/boundary_condition/bc_regularized.py
@@ -1,5 +1,13 @@
"""
-Base class for boundary conditions in a LBM simulation.
+Regularized boundary condition.
+
+A non-equilibrium bounce-back scheme with additional regularization of the
+distribution function. Applicable as velocity or pressure boundary conditions.
+
+Reference
+---------
+Latt, J. et al. (2008). "Straight velocity boundaries in the lattice
+Boltzmann method." *Physical Review E*, 77(5), 056703.
"""
import jax.numpy as jnp
@@ -7,7 +15,7 @@
import jax.lax as lax
from functools import partial
import warp as wp
-from typing import Any, Union, Tuple
+from typing import Any, Union, Tuple, Callable
import numpy as np
from xlb.velocity_set.velocity_set import VelocitySet
@@ -16,6 +24,7 @@
from xlb.operator.operator import Operator
from xlb.operator.boundary_condition import ZouHeBC, HelperFunctionsBC
from xlb.operator.macroscopic import SecondMoment as MomentumFlux
+from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod
class RegularizedBC(ZouHeBC):
@@ -43,13 +52,14 @@ class RegularizedBC(ZouHeBC):
def __init__(
self,
bc_type,
- profile=None,
+ profile: Callable = None,
prescribed_value: Union[float, Tuple[float, ...], np.ndarray] = None,
velocity_set: VelocitySet = None,
precision_policy: PrecisionPolicy = None,
compute_backend: ComputeBackend = None,
indices=None,
mesh_vertices=None,
+ voxelization_method: MeshVoxelizationMethod = None,
):
# Call the parent constructor
super().__init__(
@@ -61,6 +71,7 @@ def __init__(
compute_backend,
indices,
mesh_vertices,
+ voxelization_method,
)
self.momentum_flux = MomentumFlux()
@@ -124,18 +135,17 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
return f_post
def _construct_warp(self):
- # load helper functions
- bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend)
+ # load helper functions. Always use warp backend for helper functions as it may also be called by the Neon backend.
+ bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=ComputeBackend.WARP)
# Set local constants
_d = self.velocity_set.d
- _q = self.velocity_set.q
- _opp_indices = self.velocity_set.opp_indices
+ lattice_central_index = self.velocity_set.center_index
@wp.func
def functional_velocity(
index: Any,
timestep: Any,
- missing_mask: Any,
+ _missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
@@ -145,16 +155,16 @@ def functional_velocity(
_f = f_post
# Find normal vector
- normals = bc_helper.get_normal_vectors(missing_mask)
+ normals = bc_helper.get_normal_vectors(_missing_mask)
# Find the value of u from the missing directions
# Since we are only considering normal velocity, we only need to find one value (stored at the center of f_1)
# Create velocity vector by multiplying the prescribed value with the normal vector
- prescribed_value = self.compute_dtype(f_1[0, index[0], index[1], index[2]])
+ prescribed_value = self.decoder_functional(f_1, index, _missing_mask)[0]
_u = -prescribed_value * normals
# calculate rho
- fsum = bc_helper.get_bc_fsum(_f, missing_mask)
+ fsum = bc_helper.get_bc_fsum(_f, _missing_mask)
unormal = self.compute_dtype(0.0)
for d in range(_d):
unormal += _u[d] * normals[d]
@@ -162,7 +172,7 @@ def functional_velocity(
# impose non-equilibrium bounceback
feq = self.equilibrium_operator.warp_functional(_rho, _u)
- _f = bc_helper.bounceback_nonequilibrium(_f, feq, missing_mask)
+ _f = bc_helper.bounceback_nonequilibrium(_f, feq, _missing_mask)
# Regularize the boundary fpop
_f = bc_helper.regularize_fpop(_f, feq)
@@ -172,7 +182,7 @@ def functional_velocity(
def functional_pressure(
index: Any,
timestep: Any,
- missing_mask: Any,
+ _missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
@@ -182,20 +192,20 @@ def functional_pressure(
_f = f_post
# Find normal vector
- normals = bc_helper.get_normal_vectors(missing_mask)
+ normals = bc_helper.get_normal_vectors(_missing_mask)
# Find the value of rho from the missing directions
# Since we need only one scalar value, we only need to find one value (stored at the center of f_1)
- _rho = self.compute_dtype(f_1[0, index[0], index[1], index[2]])
+ _rho = self.decoder_functional(f_1, index, _missing_mask)[0]
# calculate velocity
- fsum = bc_helper.get_bc_fsum(_f, missing_mask)
+ fsum = bc_helper.get_bc_fsum(_f, _missing_mask)
unormal = -self.compute_dtype(1.0) + fsum / _rho
_u = unormal * normals
# impose non-equilibrium bounceback
feq = self.equilibrium_operator.warp_functional(_rho, _u)
- _f = bc_helper.bounceback_nonequilibrium(_f, feq, missing_mask)
+ _f = bc_helper.bounceback_nonequilibrium(_f, feq, _missing_mask)
# Regularize the boundary fpop
_f = bc_helper.regularize_fpop(_f, feq)
diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py
index 5cad5048..9865585c 100644
--- a/xlb/operator/boundary_condition/bc_zouhe.py
+++ b/xlb/operator/boundary_condition/bc_zouhe.py
@@ -1,5 +1,14 @@
"""
-Base class for boundary conditions in a LBM simulation.
+Zou-He boundary condition.
+
+Sets unknown populations at velocity or pressure boundaries using
+mass and momentum conservation combined with non-equilibrium
+bounce-back. Commonly used for inlets and outlets.
+
+Reference
+---------
+Zou, Q. & He, X. (1997). "On pressure and velocity boundary conditions
+for the lattice Boltzmann BGK model." *Physics of Fluids*, 9(6), 1591.
"""
import jax.numpy as jnp
@@ -7,7 +16,7 @@
import jax.lax as lax
from functools import partial
import warp as wp
-from typing import Any, Union, Tuple
+from typing import Any, Union, Tuple, Callable
import numpy as np
from xlb.velocity_set.velocity_set import VelocitySet
@@ -20,6 +29,8 @@
)
from xlb.operator.boundary_condition import HelperFunctionsBC
from xlb.operator.equilibrium import QuadraticEquilibrium
+from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod
+from xlb.operator.boundary_condition.helper_functions_bc import EncodeAuxiliaryData
class ZouHeBC(BoundaryCondition):
@@ -37,20 +48,20 @@ class ZouHeBC(BoundaryCondition):
def __init__(
self,
bc_type,
- profile=None,
+ profile: Callable = None,
prescribed_value: Union[float, Tuple[float, ...], np.ndarray] = None,
velocity_set: VelocitySet = None,
precision_policy: PrecisionPolicy = None,
compute_backend: ComputeBackend = None,
indices=None,
mesh_vertices=None,
+ voxelization_method: MeshVoxelizationMethod = None,
):
# Important Note: it is critical to add id inside __init__ for this BC because different instantiations of this BC
# may have different types (velocity or pressure).
assert bc_type in ["velocity", "pressure"], f"type = {bc_type} not supported! Use 'pressure' or 'velocity'."
self.bc_type = bc_type
self.equilibrium_operator = QuadraticEquilibrium()
- self.profile = profile
# Call the parent constructor
super().__init__(
@@ -60,28 +71,29 @@ def __init__(
compute_backend,
indices,
mesh_vertices,
+ voxelization_method,
)
+ # This BC class accepts both constant prescribed values of velocity with keyword "prescribed_value" or
+ # velocity profiles given by keyword "profile" which must be a callable function.
+ self.profile = profile
+
# Handle prescribed value if provided
if prescribed_value is not None:
if profile is not None:
raise ValueError("Cannot specify both profile and prescribed_value")
- # Convert input to numpy array for validation
- if isinstance(prescribed_value, (tuple, list)):
- prescribed_value = np.array(prescribed_value, dtype=np.float64)
- elif isinstance(prescribed_value, (int, float)):
- if bc_type == "pressure":
+ # Ensure prescribed_value is a NumPy array of floats
+ if bc_type == "velocity":
+ if isinstance(prescribed_value, (tuple, list, np.ndarray)):
+ prescribed_value = np.asarray(prescribed_value, dtype=np.float64)
+ else:
+ raise ValueError("Velocity prescribed_value must be a tuple, list, or array-like")
+ elif bc_type == "pressure":
+ if isinstance(prescribed_value, (int, float)):
prescribed_value = float(prescribed_value)
else:
- raise ValueError("Velocity prescribed_value must be a tuple or array")
- elif isinstance(prescribed_value, np.ndarray):
- prescribed_value = prescribed_value.astype(np.float64)
-
- # Validate prescribed value
- if bc_type == "velocity":
- if not isinstance(prescribed_value, np.ndarray):
- raise ValueError("Velocity prescribed_value must be an array-like")
+ raise ValueError("Pressure prescribed_value must be a scalar (int or float)")
# Check for non-zero elements - only one element should be non-zero
non_zero_count = np.count_nonzero(prescribed_value)
@@ -93,21 +105,38 @@ def __init__(
# a single non-zero number associated with pressure BC OR
# a vector of zeros associated with no-slip BC.
# Accounting for all scenarios here.
- if self.compute_backend is ComputeBackend.WARP:
+ if self.compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON]:
idx = np.nonzero(prescribed_value)[0]
prescribed_value = prescribed_value[idx][0] if idx.size else 0.0
prescribed_value = self.precision_policy.store_precision.wp_dtype(prescribed_value)
self.prescribed_value = prescribed_value
self.profile = self._create_constant_prescribed_profile()
- # This BC needs auxilary data initialization before streaming
- self.needs_aux_init = True
+ if self.compute_backend == ComputeBackend.JAX:
+ self.prescribed_values = self.profile()
+ else:
+ # This BC needs auxiliary data initialization before streaming
+ self.needs_aux_init = True
+
+ # This BC needs auxiliary data recovery after streaming
+ self.needs_aux_recovery = True
- # This BC needs auxilary data recovery after streaming
- self.needs_aux_recovery = True
+ # This BC needs one auxiliary data for the density or normal velocity
+ self.num_of_aux_data = 1
- # This BC needs one auxilary data for the density or normal velocity
- self.num_of_aux_data = 1
+ # Create the encoder operator for storing the auxiliary data
+ encode_auxiliary_data = EncodeAuxiliaryData(
+ self.id,
+ self.num_of_aux_data,
+ self.profile,
+ velocity_set=self.velocity_set,
+ precision_policy=self.precision_policy,
+ compute_backend=self.compute_backend,
+ )
+
+ # get decoder functional
+ functional_dict, _ = encode_auxiliary_data._construct_warp()
+ self.decoder_functional = functional_dict["decoder"]
# This BC needs padding for finding missing directions when imposed on a geometry that is in the domain interior
self.needs_padding = True
@@ -126,6 +155,8 @@ def prescribed_profile_jax():
return prescribed_profile_jax
elif self.compute_backend == ComputeBackend.WARP:
return prescribed_profile_warp
+ elif self.compute_backend == ComputeBackend.NEON:
+ return prescribed_profile_warp
@partial(jit, static_argnums=(0,), inline=True)
def _get_known_middle_mask(self, missing_mask):
@@ -269,12 +300,11 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask):
return f_post
def _construct_warp(self):
- # load helper functions
- bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend)
+ # load helper functions. Always use warp backend for helper functions as it may also be called by the Neon backend.
+ bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=ComputeBackend.WARP)
+
# Set local constants
_d = self.velocity_set.d
- _q = self.velocity_set.q
- _opp_indices = self.velocity_set.opp_indices
@wp.func
def functional_velocity(
@@ -299,7 +329,7 @@ def functional_velocity(
# Find the value of u from the missing directions
# Since we are only considering normal velocity, we only need to find one value (stored at the center of f_1)
# Create velocity vector by multiplying the prescribed value with the normal vector
- prescribed_value = f_1[0, index[0], index[1], index[2]]
+ prescribed_value = self.decoder_functional(f_1, index, _missing_mask)[0]
_u = -prescribed_value * normals
for d in range(_d):
@@ -330,7 +360,7 @@ def functional_pressure(
# Find the value of rho from the missing directions
# Since we need only one scalar value, we only need to find one value (stored at the center of f_1)
- _rho = f_1[0, index[0], index[1], index[2]]
+ _rho = self.decoder_functional(f_1, index, _missing_mask)[0]
# calculate velocity
fsum = bc_helper.get_bc_fsum(_f, _missing_mask)
@@ -360,3 +390,15 @@ def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask):
dim=f_pre.shape[1:],
)
return f_post
+
+ def _construct_neon(self):
+ # Redefine the quadratic eq operator for the neon backend
+ # This is because the neon backend relies on the warp functionals for its operations.
+ self.equilibrium_operator = QuadraticEquilibrium(compute_backend=ComputeBackend.WARP)
+ functional, _ = self._construct_warp()
+ return functional, None
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, f_pre, f_post, bc_mask, missing_mask):
+ # raise exception as this feature is not implemented yet
+ raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.")
diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py
index 0b8a93c6..4ac2a96a 100644
--- a/xlb/operator/boundary_condition/boundary_condition.py
+++ b/xlb/operator/boundary_condition/boundary_condition.py
@@ -1,5 +1,9 @@
"""
-Base class for boundary conditions in a LBM simulation.
+Base class for boundary conditions in a Lattice Boltzmann simulation.
+
+Every concrete BC inherits from :class:`BoundaryCondition`, which provides
+a registration mechanism, helper-function access, and the boilerplate
+needed to encode auxiliary data into the ``f_1`` buffer.
"""
from enum import Enum, auto
@@ -7,8 +11,7 @@
from typing import Any
from jax import jit
from functools import partial
-import jax
-import jax.numpy as jnp
+import numpy as np
from xlb.velocity_set.velocity_set import VelocitySet
from xlb.precision_policy import PrecisionPolicy
@@ -17,17 +20,39 @@
from xlb import DefaultConfig
from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry
from xlb.operator.boundary_condition import HelperFunctionsBC
+from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod
-# Enum for implementation step
class ImplementationStep(Enum):
+ """At which algorithmic stage the boundary condition is applied."""
+
COLLISION = auto()
STREAMING = auto()
class BoundaryCondition(Operator):
- """
- Base class for boundary conditions in a LBM simulation.
+ """Abstract base class for all LBM boundary conditions.
+
+ Each BC is registered with a unique numeric *id* and annotated with:
+
+ * ``implementation_step`` - whether it executes after streaming or after
+ collision.
+ * ``needs_aux_recovery`` / ``needs_aux_init`` - whether the BC stores
+ auxiliary data in the ``f_1`` distribution buffer.
+
+ Parameters
+ ----------
+ implementation_step : ImplementationStep
+ Phase in the LBM algorithm where this BC is applied.
+ velocity_set : VelocitySet, optional
+ precision_policy : PrecisionPolicy, optional
+ compute_backend : ComputeBackend, optional
+ indices : array-like, optional
+ Explicit voxel indices for this BC.
+ mesh_vertices : array-like, optional
+ Mesh vertices for geometry-based BCs.
+ voxelization_method : MeshVoxelizationMethod, optional
+ Voxelization strategy when *mesh_vertices* is provided.
"""
def __init__(
@@ -38,6 +63,7 @@ def __init__(
compute_backend: ComputeBackend = None,
indices=None,
mesh_vertices=None,
+ voxelization_method: MeshVoxelizationMethod = None,
):
self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__ + "_" + str(hash(self)))
velocity_set = velocity_set or DefaultConfig.velocity_set
@@ -54,28 +80,65 @@ def __init__(
self.implementation_step = implementation_step
# A flag to indicate whether bc indices need to be padded in both normal directions to identify missing directions
- # when inside/outside of the geoemtry is not known
+ # when inside/outside of the geometry is not known
self.needs_padding = False
- # A flag for BCs that need implicit boundary distance between the grid and a mesh (to be set to True if applicable inside each BC)
+ # A flag for BCs that need normalized distance between the grid and a mesh (to be set to True if applicable inside each BC)
self.needs_mesh_distance = False
- # A flag for BCs that need auxilary data initialization before stepper
+ # A flag for BCs that need auxiliary data initialization before stepper
self.needs_aux_init = False
- # A flag to track if the BC is initialized with auxilary data
+ # A flag to track if the BC is initialized with auxiliary data
self.is_initialized_with_aux_data = False
- # Number of auxilary data needed for the BC (for prescribed values)
+ # Number of auxiliary data needed for the BC (for prescribed values)
self.num_of_aux_data = 0
- # A flag for BCs that need auxilary data recovery after streaming
+ # A flag for BCs that need auxiliary data recovery after streaming
self.needs_aux_recovery = False
+ # Voxelization method. For BC's specified on a mesh, the user can specify the voxelization scheme.
+ # Currently we support three methods based on (a) aabb method (b) ray casting and (c) winding number.
+ self.voxelization_method = voxelization_method
+
+ # Construct a default warp functional for assembling auxiliary data if needed
+ if self.compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON]:
+
+ @wp.func
+ def assemble_auxiliary_data(
+ index: Any,
+ timestep: Any,
+ missing_mask: Any,
+ f_0: Any,
+ f_1: Any,
+ f_pre: Any,
+ f_post: Any,
+ level: Any = 0,
+ ):
+ return f_post
+
+ self.assemble_auxiliary_data = assemble_auxiliary_data
+
+ def pad_indices(self):
+ """
+ This method pads the indices to ensure that the boundary condition can be applied correctly.
+ It is used to find missing directions in indices_boundary_masker when the BC is imposed on a
+ geometry that is in the domain interior.
+ """
+ _d = self.velocity_set.d
+ bc_indices = np.array(self.indices)
+ lattice_velocity_np = self.velocity_set._c
+ if self.needs_padding:
+ bc_indices_padded = bc_indices[:, :, None] + lattice_velocity_np[:, None, :]
+ return np.unique(bc_indices_padded.reshape(_d, -1), axis=1)
+ else:
+ return bc_indices
+
@partial(jit, static_argnums=(0,), inline=True)
- def update_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask):
+ def assemble_auxiliary_data(self, f_pre, f_post, bc_mask, missing_mask):
"""
- A placeholder function for prepare the auxilary distribution functions for the boundary condition.
+ A placeholder function for prepare the auxiliary distribution functions for the boundary condition.
currently being called after collision only.
"""
return f_post
@@ -94,14 +157,14 @@ def kernel(
f_pre: wp.array4d(dtype=Any),
f_post: wp.array4d(dtype=Any),
bc_mask: wp.array4d(dtype=wp.uint8),
- missing_mask: wp.array4d(dtype=wp.bool),
+ missing_mask: wp.array4d(dtype=wp.uint8),
):
# Get the global index
i, j, k = wp.tid()
index = wp.vec3i(i, j, k)
# read tid data
- _f_pre, _f_post, _boundary_id, _missing_mask = bc_helper.get_thread_data(f_pre, f_post, bc_mask, missing_mask, index)
+ _f_pre, _f_post, _boundary_id, _missing_mask = bc_helper.get_bc_thread_data(f_pre, f_post, bc_mask, missing_mask, index)
# Apply the boundary condition
if _boundary_id == _id:
@@ -115,61 +178,3 @@ def kernel(
f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l])
return kernel
-
- def _construct_aux_data_init_kernel(self, functional):
- """
- Constructs the warp kernel for the auxilary data recovery.
- """
- bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend)
-
- _id = wp.uint8(self.id)
- _opp_indices = self.velocity_set.opp_indices
- _num_of_aux_data = self.num_of_aux_data
-
- # Construct the warp kernel
- @wp.kernel
- def aux_data_init_kernel(
- f_0: wp.array4d(dtype=Any),
- f_1: wp.array4d(dtype=Any),
- bc_mask: wp.array4d(dtype=wp.uint8),
- missing_mask: wp.array4d(dtype=wp.bool),
- ):
- # Get the global index
- i, j, k = wp.tid()
- index = wp.vec3i(i, j, k)
-
- # read tid data
- _f_0, _f_1, _boundary_id, _missing_mask = bc_helper.get_thread_data(f_0, f_1, bc_mask, missing_mask, index)
-
- # Apply the functional
- if _boundary_id == _id:
- # prescribed_values is a q-sized vector of type wp.vec
- prescribed_values = functional(index)
- # Write the result for all q directions, but only store up to num_of_aux_data
- # TODO: Somehow raise an error if the number of prescribed values does not match the number of missing directions
-
- # The first BC auxiliary data is stored in the zero'th index of f_1 associated with its center.
- f_1[0, index[0], index[1], index[2]] = self.store_dtype(prescribed_values[0])
- counter = wp.int32(1)
-
- # The other remaining BC auxiliary data are stored in missing directions of f_1.
- for l in range(1, self.velocity_set.q):
- if _missing_mask[l] == wp.uint8(1) and counter < _num_of_aux_data:
- f_1[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(prescribed_values[counter])
- counter += 1
-
- return aux_data_init_kernel
-
- def aux_data_init(self, f_0, f_1, bc_mask, missing_mask):
- if self.compute_backend == ComputeBackend.WARP:
- # Launch the warp kernel
- wp.launch(
- self._construct_aux_data_init_kernel(self.profile),
- inputs=[f_0, f_1, bc_mask, missing_mask],
- dim=f_0.shape[1:],
- )
- elif self.compute_backend == ComputeBackend.JAX:
- # We don't use boundary aux encoding/decoding in JAX
- self.prescribed_values = self.profile()
- self.is_initialized_with_aux_data = True
- return f_0, f_1
diff --git a/xlb/operator/boundary_condition/helper_functions_bc.py b/xlb/operator/boundary_condition/helper_functions_bc.py
index 6f8e768b..c613573b 100644
--- a/xlb/operator/boundary_condition/helper_functions_bc.py
+++ b/xlb/operator/boundary_condition/helper_functions_bc.py
@@ -1,11 +1,44 @@
-from xlb import DefaultConfig, ComputeBackend
-from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux
+"""
+Warp/Neon helper functions shared by multiple boundary conditions.
+
+:class:`HelperFunctionsBC` exposes ``@wp.func`` helpers for bounce-back,
+regularization, Grad's approximation, moving-wall corrections,
+interpolated BCs, and BC thread-data loading. These are used as building
+blocks by the concrete BC classes.
+
+Also contains :class:`EncodeAuxiliaryData` and
+:class:`MultiresEncodeAuxiliaryData` operators for writing user-prescribed
+BC profiles into the ``f_1`` buffer during initialization.
+"""
+
+import inspect
+from typing import Any, Callable
+
import warp as wp
-from typing import Any
+
+from xlb.velocity_set.velocity_set import VelocitySet
+from xlb.precision_policy import PrecisionPolicy
+from xlb import DefaultConfig, ComputeBackend
+from xlb.operator.operator import Operator
+from xlb.operator.macroscopic import SecondMoment as MomentumFlux
+from xlb.operator.macroscopic import Macroscopic
+from xlb.operator.equilibrium import QuadraticEquilibrium
class HelperFunctionsBC(object):
- def __init__(self, velocity_set=None, precision_policy=None, compute_backend=None):
+ """Collection of Warp/Neon ``@wp.func`` helpers for boundary conditions.
+
+ Parameters
+ ----------
+ velocity_set : VelocitySet, optional
+ precision_policy : PrecisionPolicy, optional
+ compute_backend : ComputeBackend, optional
+ Must be ``WARP`` or ``NEON`` (JAX not supported).
+ distance_decoder_function : callable, optional
+ Function to decode wall-distance data for interpolated BCs.
+ """
+
+ def __init__(self, velocity_set=None, precision_policy=None, compute_backend=None, distance_decoder_function=None):
if compute_backend == ComputeBackend.JAX:
raise ValueError("This helper class contains helper functions only for the WARP implementation of some BCs not JAX!")
@@ -13,6 +46,7 @@ def __init__(self, velocity_set=None, precision_policy=None, compute_backend=Non
self.velocity_set = velocity_set or DefaultConfig.velocity_set
self.precision_policy = precision_policy or DefaultConfig.default_precision_policy
self.compute_backend = compute_backend or DefaultConfig.default_backend
+ self.distance_decoder_function = distance_decoder_function
# Set the compute and Store dtypes
compute_dtype = self.precision_policy.compute_precision.wp_dtype
@@ -30,15 +64,21 @@ def __init__(self, velocity_set=None, precision_policy=None, compute_backend=Non
_f_vec = wp.vec(_q, dtype=compute_dtype)
_missing_mask_vec = wp.vec(_q, dtype=wp.uint8) # TODO fix vec bool
+ # Define the operator needed for computing equilibrium
+ equilibrium = QuadraticEquilibrium(velocity_set, precision_policy, compute_backend)
+
+ # Define the operator needed for computing macroscopic variables
+ macroscopic = Macroscopic(velocity_set, precision_policy, compute_backend)
+
# Define the operator needed for computing the momentum flux
momentum_flux = MomentumFlux(velocity_set, precision_policy, compute_backend)
@wp.func
- def get_thread_data(
+ def get_bc_thread_data(
f_pre: wp.array4d(dtype=Any),
f_post: wp.array4d(dtype=Any),
bc_mask: wp.array4d(dtype=wp.uint8),
- missing_mask: wp.array4d(dtype=wp.bool),
+ missing_mask: wp.array4d(dtype=wp.uint8),
index: wp.vec3i,
):
# Get the boundary id and missing mask
@@ -58,41 +98,62 @@ def get_thread_data(
_missing_mask[l] = wp.uint8(0)
return _f_pre, _f_post, _boundary_id, _missing_mask
+ @wp.func
+ def neon_get_bc_thread_data(
+ f_pre_pn: Any,
+ f_post_pn: Any,
+ bc_mask_pn: Any,
+ missing_mask_pn: Any,
+ index: Any,
+ ):
+ # Get the boundary id and missing mask
+ _f_pre = _f_vec()
+ _f_post = _f_vec()
+ _boundary_id = wp.neon_read(bc_mask_pn, index, 0)
+ _missing_mask = _missing_mask_vec()
+ for l in range(_q):
+ # q-sized vector of populations
+ _f_pre[l] = compute_dtype(wp.neon_read(f_pre_pn, index, l))
+ _f_post[l] = compute_dtype(wp.neon_read(f_post_pn, index, l))
+ _missing_mask[l] = wp.neon_read(missing_mask_pn, index, l)
+
+ return _f_pre, _f_post, _boundary_id, _missing_mask
+
@wp.func
def get_bc_fsum(
fpop: Any,
- missing_mask: Any,
+ _missing_mask: Any,
):
fsum_known = compute_dtype(0.0)
fsum_middle = compute_dtype(0.0)
for l in range(_q):
- if missing_mask[_opp_indices[l]] == wp.uint8(1):
+ if _missing_mask[_opp_indices[l]] == wp.uint8(1):
fsum_known += compute_dtype(2.0) * fpop[l]
- elif missing_mask[l] != wp.uint8(1):
+ elif _missing_mask[l] != wp.uint8(1):
fsum_middle += fpop[l]
return fsum_known + fsum_middle
@wp.func
def get_normal_vectors(
- missing_mask: Any,
+ _missing_mask: Any,
):
if wp.static(_d == 3):
for l in range(_q):
- if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
+ if _missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l])
else:
for l in range(_q):
- if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
+ if _missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1:
return -_u_vec(_c_float[0, l], _c_float[1, l])
@wp.func
def bounceback_nonequilibrium(
fpop: Any,
feq: Any,
- missing_mask: Any,
+ _missing_mask: Any,
):
for l in range(_q):
- if missing_mask[l] == wp.uint8(1):
+ if _missing_mask[l] == wp.uint8(1):
fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]]
return fpop
@@ -121,8 +182,461 @@ def regularize_fpop(
fpop[l] = feq[l] + fpop1
return fpop
- self.get_thread_data = get_thread_data
+ @wp.func
+ def grads_approximate_fpop(
+ _missing_mask: Any,
+ rho: Any,
+ u: Any,
+ f_post: Any,
+ ):
+ # Purpose: Using Grad's approximation to represent fpop based on macroscopic inputs used for outflow [1] and
+ # Dirichlet BCs [2]
+ # [1] S. Chikatax`marla, S. Ansumali, and I. Karlin, "Grad's approximation for missing data in lattice Boltzmann
+ # simulations", Europhys. Lett. 74, 215 (2006).
+ # [2] Dorschner, B., Chikatamarla, S. S., Bösch, F., & Karlin, I. V. (2015). Grad's approximation for moving and
+ # stationary walls in entropic lattice Boltzmann simulations. Journal of Computational Physics, 295, 340-354.
+
+ # Note: See also self.regularize_fpop function which is somewhat similar.
+
+ # Compute pressure tensor Pi using all f_post-streaming values
+ Pi = momentum_flux.warp_functional(f_post)
+
+ # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq)
+ nt = _d * (_d + 1) // 2
+ for l in range(_q):
+ if _missing_mask[l] == wp.uint8(1):
+ # compute dot product of qi and Pi
+ QiPi = compute_dtype(0.0)
+ for t in range(nt):
+ if t == 0 or t == 3 or t == 5:
+ QiPi += _qi[l, t] * (Pi[t] - rho / compute_dtype(3.0))
+ else:
+ QiPi += _qi[l, t] * Pi[t]
+
+ # Compute c.u
+ cu = compute_dtype(0.0)
+ for d in range(_d):
+ if _c[d, l] == 1:
+ cu += u[d]
+ elif _c[d, l] == -1:
+ cu -= u[d]
+ cu *= compute_dtype(3.0)
+
+ # change f_post using the Grad's approximation
+ f_post[l] = rho * _w[l] * (compute_dtype(1.0) + cu) + _w[l] * compute_dtype(4.5) * QiPi
+
+ return f_post
+
+ @wp.func
+ def moving_wall_fpop_correction(
+ u_wall: Any,
+ lattice_direction: Any,
+ ):
+ # Add forcing term necessary to account for the local density changes caused by the mass displacement
+ # as the object moves with velocity u_wall.
+ # [1] L.-S. Luo, Unified theory of lattice Boltzmann models for nonideal gases, Phys. Rev. Lett. 81 (1998) 1618-1621.
+ # [2] L.-S. Luo, Theory of the lattice Boltzmann method: Lattice Boltzmann models for nonideal gases, Phys. Rev. E 62 (2000) 4982-4996.
+ #
+ # Note: this function must be called within a for-loop over all lattice directions and the populations to be modified must
+ # be only those in the missing direction (the check for missing direction must be outside of this function).
+ cu = compute_dtype(0.0)
+ l = lattice_direction
+ for d in range(_d):
+ if _c[d, l] == 1:
+ cu += u_wall[d]
+ elif _c[d, l] == -1:
+ cu -= u_wall[d]
+ cu *= compute_dtype(6.0) * _w[l]
+ return cu
+
+ @wp.func
+ def interpolated_bounceback(
+ index: Any,
+ _missing_mask: Any,
+ f_0: Any,
+ f_1: Any,
+ f_pre: Any,
+ f_post: Any,
+ u_wall: Any,
+ needs_moving_wall_treatment: bool,
+ needs_mesh_distance: bool,
+ ):
+ # A local single-node version of the interpolated bounce-back boundary condition due to Bouzidi for a lattice
+ # Boltzmann method simulation.
+ # Ref:
+ # [1] Yu, D., Mei, R., Shyy, W., 2003. A unified boundary treatment in lattice boltzmann method,
+ # in: 41st aerospace sciences meeting and exhibit, p. 953.
+
+ one = compute_dtype(1.0)
+ for l in range(_q):
+ # If the mask is missing then take the opposite index
+ if _missing_mask[l] == wp.uint8(1):
+ # The normalized distance to the mesh or "weights" have been stored in known directions of f_1
+ if needs_mesh_distance:
+ # use weights associated with curved boundaries that are properly stored in f_1.
+ weight = compute_dtype(self.distance_decoder_function(f_1, index, l))
+
+ # Use differentiable interpolated BB to find f_missing:
+ f_post[l] = ((one - weight) * f_post[_opp_indices[l]] + weight * (f_pre[l] + f_pre[_opp_indices[l]])) / (one + weight)
+ else:
+ # Use regular halfway bounceback
+ f_post[l] = f_pre[_opp_indices[l]]
+
+ if _missing_mask[_opp_indices[l]] == wp.uint8(1):
+ # These are cases where the boundary is sandwiched between 2 solid cells and so both opposite directions are missing.
+ f_post[l] = f_pre[_opp_indices[l]]
+
+ # Add contribution due to moving_wall to f_missing as is usual in regular Bouzidi BC
+ if needs_moving_wall_treatment:
+ f_post[l] += moving_wall_fpop_correction(u_wall, l)
+ return f_post
+
+ @wp.func
+ def interpolated_nonequilibrium_bounceback(
+ index: Any,
+ _missing_mask: Any,
+ f_0: Any,
+ f_1: Any,
+ f_pre: Any,
+ f_post: Any,
+ u_wall: Any,
+ needs_moving_wall_treatment: bool,
+ needs_mesh_distance: bool,
+ ):
+ # Compute density, velocity using all f_post-collision values
+ rho, u = macroscopic.warp_functional(f_pre)
+ feq = equilibrium.warp_functional(rho, u)
+
+ # Compute equilibrium distribution at the wall
+ if needs_moving_wall_treatment:
+ feq_wall = equilibrium.warp_functional(rho, u_wall)
+ else:
+ feq_wall = _f_vec()
+
+ # Apply method in Tao et al (2018) [1] to find missing populations at the boundary
+ one = compute_dtype(1.0)
+ for l in range(_q):
+ # If the mask is missing then take the opposite index
+ if _missing_mask[l] == wp.uint8(1):
+ # The normalized distance to the mesh or "weights" have been stored in known directions of f_1
+ if needs_mesh_distance:
+ # use weights associated with curved boundaries that are properly stored in f_1.
+ weight = compute_dtype(self.distance_decoder_function(f_1, index, l))
+ else:
+ weight = compute_dtype(0.5)
+
+ # Use non-equilibrium bounceback to find f_missing:
+ fneq = f_pre[_opp_indices[l]] - feq[_opp_indices[l]]
+
+ # Compute equilibrium distribution at the wall
+ # Same quadratic equilibrium but accounting for zero velocity (no-slip)
+ if not needs_moving_wall_treatment:
+ feq_wall[l] = _w[l] * rho
+
+ # Assemble wall population for doing interpolation at the boundary
+ f_wall = feq_wall[l] + fneq
+ f_post[l] = (f_wall + weight * f_pre[l]) / (one + weight)
+
+ return f_post
+
+ @wp.func
+ def neon_index_to_warp(neon_field_hdl: Any, index: Any):
+ # Unpack the global index in Neon at the finest level and convert it to a warp vector
+ cIdx = wp.neon_global_idx(neon_field_hdl, index)
+ gx = wp.neon_get_x(cIdx)
+ gy = wp.neon_get_y(cIdx)
+ gz = wp.neon_get_z(cIdx)
+
+ # XLB is flattening the z dimension in 3D, while neon uses the y dimension
+ if _d == 2:
+ gy, gz = gz, gy
+
+ # Get warp indices
+ index_wp = wp.vec3i(gx, gy, gz)
+ return index_wp
+
+ self.get_bc_thread_data = get_bc_thread_data
self.get_bc_fsum = get_bc_fsum
self.get_normal_vectors = get_normal_vectors
self.bounceback_nonequilibrium = bounceback_nonequilibrium
self.regularize_fpop = regularize_fpop
+ self.grads_approximate_fpop = grads_approximate_fpop
+ self.moving_wall_fpop_correction = moving_wall_fpop_correction
+ self.interpolated_bounceback = interpolated_bounceback
+ self.interpolated_nonequilibrium_bounceback = interpolated_nonequilibrium_bounceback
+ self.neon_get_bc_thread_data = neon_get_bc_thread_data
+ self.neon_index_to_warp = neon_index_to_warp
+
+
+class EncodeAuxiliaryData(Operator):
+ """
+ Operator for encoding boundary auxiliary data during initialization.
+ """
+
+ def __init__(
+ self,
+ boundary_id: int,
+ num_of_aux_data: int,
+ user_defined_functional: Callable,
+ velocity_set: VelocitySet = None,
+ precision_policy: PrecisionPolicy = None,
+ compute_backend: ComputeBackend = None,
+ ):
+ self.user_defined_functional = user_defined_functional
+ self.boundary_id = wp.uint8(boundary_id)
+ self.num_of_aux_data = num_of_aux_data
+
+ super().__init__(velocity_set, precision_policy, compute_backend)
+
+ # Inspect the signature of the user-defined functional.
+ # We assume the profile function takes only the index as input and is hence time-independent.
+ sig = inspect.signature(user_defined_functional)
+ assert self.compute_backend != ComputeBackend.JAX, "Encoding/decoding of auxiliary data are not required for boundary conditions in JAX"
+ assert len(sig.parameters) == 1, f"User-defined functional must take exactly one argument (the index), it received {len(sig.parameters)}."
+
+ # Define a HelperFunctionsBC instance
+ self.bc_helper = HelperFunctionsBC(
+ velocity_set=self.velocity_set,
+ precision_policy=self.precision_policy,
+ compute_backend=self.compute_backend,
+ )
+
+ # TODO: Somehow raise an error if the number of prescribed values does not match the number of missing directions
+
+ def _construct_warp(self):
+ """
+ Constructs the warp kernel for the auxiliary data recovery.
+ """
+ # Find velocity index for (0, 0, 0)
+ lattice_central_index = self.velocity_set.center_index
+ _opp_indices = self.velocity_set.opp_indices
+ _id = self.boundary_id
+ _num_of_aux_data = self.num_of_aux_data
+ _aux_vec = wp.vec(_num_of_aux_data, dtype=self.compute_dtype)
+
+ @wp.func
+ def encoder_functional(
+ index: Any,
+ _missing_mask: Any,
+ field_storage: Any,
+ prescribed_values: Any,
+ ):
+ if len(prescribed_values) != _num_of_aux_data:
+ wp.printf("Error: User-defined profile must return a vector of size %d\n", _num_of_aux_data)
+ return
+
+ # Write the result for all q directions, but only store up to _num_of_aux_data
+ counter = wp.int32(0)
+ for l in range(self.velocity_set.q):
+ # Only store up to _num_of_aux_data
+ if counter == _num_of_aux_data:
+ return
+
+ if l == lattice_central_index:
+ # The first BC auxiliary data is stored in the zero'th index of f_1 associated with its center.
+ self.write_field(field_storage, index, l, self.store_dtype(prescribed_values[l]))
+ counter += 1
+ elif _missing_mask[l] == wp.uint8(1):
+ # The other remaining BC auxiliary data are stored in missing directions of f_1.
+ self.write_field(field_storage, index, _opp_indices[l], self.store_dtype(prescribed_values[l]))
+ counter += 1
+
+ @wp.func
+ def decoder_functional(
+ field_storage: Any,
+ index: Any,
+ _missing_mask: Any,
+ ):
+ """
+ Decode the encoded values needed for the boundary condition treatment from the center location in field_storage.
+ """
+
+ # Define a vector to hold prescribed_values
+ prescribed_values = _aux_vec()
+
+ # Read all q directions, but only retrieve up to _num_of_aux_data
+ counter = wp.int32(0)
+ for l in range(self.velocity_set.q):
+ # Only retrieve up to _num_of_aux_data
+ if counter == _num_of_aux_data:
+ return prescribed_values
+
+ if l == lattice_central_index:
+ # The first BC auxiliary data is stored in the zero'th index of f_1 associated with its center.
+ value = self.read_field(field_storage, index, l)
+ prescribed_values[counter] = self.compute_dtype(value)
+ counter += 1
+ elif _missing_mask[l] == wp.uint8(1):
+ # The other remaining BC auxiliary data are stored in missing directions of f_1.
+ value = self.read_field(field_storage, index, _opp_indices[l])
+ prescribed_values[counter] = self.compute_dtype(value)
+ counter += 1
+
+ # Construct the warp kernel
+ @wp.kernel
+ def kernel(
+ f_1: wp.array4d(dtype=Any),
+ bc_mask: wp.array4d(dtype=wp.uint8),
+ missing_mask: wp.array4d(dtype=wp.uint8),
+ ):
+ # Get the global index
+ i, j, k = wp.tid()
+ index = wp.vec3i(i, j, k)
+
+ # read tid data
+ _, _, _boundary_id, _missing_mask = self.bc_helper.get_bc_thread_data(f_1, f_1, bc_mask, missing_mask, index)
+
+ # Apply the functional
+ # change this to use central location
+ if _boundary_id == _id:
+ # prescribed_values is a q-sized vector of type wp.vec
+ prescribed_values = self.user_defined_functional(index)
+
+ # call the functional
+ encoder_functional(index, _missing_mask, f_1, prescribed_values)
+
+ functional_dict = {"encoder": encoder_functional, "decoder": decoder_functional}
+ return functional_dict, kernel
+
+ def _construct_neon(self):
+ import neon
+
+ """
+ Constructs the Neon container for encoding auxiliary data recovery.
+ """
+ # Use the warp functional for the Neon backend
+ functional_dict, _ = self._construct_warp()
+ encoder_functional = functional_dict["encoder"]
+ _id = self.boundary_id
+
+ # Construct the Neon container
+ @neon.Container.factory(name="EncodingAuxData_" + str(_id))
+ def aux_data_init_container(
+ f_1: Any,
+ bc_mask: Any,
+ missing_mask: Any,
+ ):
+ def aux_data_init_ll(loader: neon.Loader):
+ loader.set_grid(f_1.get_grid())
+
+ f_1_pn = loader.get_write_handle(f_1)
+ bc_mask_pn = loader.get_read_handle(bc_mask)
+ missing_mask_pn = loader.get_read_handle(missing_mask)
+
+ @wp.func
+ def aux_data_init_cl(index: Any):
+ # read tid data
+ _, _, _boundary_id, _missing_mask = self.bc_helper.neon_get_bc_thread_data(f_1_pn, f_1_pn, bc_mask_pn, missing_mask_pn, index)
+
+ # Apply the functional
+ if _boundary_id == _id:
+ warp_index = self.bc_helper.neon_index_to_warp(f_1_pn, index)
+ prescribed_values = self.user_defined_functional(warp_index)
+
+ # Call the functional
+ encoder_functional(index, _missing_mask, f_1_pn, prescribed_values)
+
+ # Declare the kernel in the Neon loader
+ loader.declare_kernel(aux_data_init_cl)
+
+ return aux_data_init_ll
+
+ return functional_dict, aux_data_init_container
+
+ @Operator.register_backend(ComputeBackend.WARP)
+ def warp_implementation(self, f_1, bc_mask, missing_mask):
+ # Launch the warp kernel
+ wp.launch(
+ self.warp_kernel,
+ inputs=[f_1, bc_mask, missing_mask],
+ dim=f_1.shape[1:],
+ )
+ return f_1
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, f_1, bc_mask, missing_mask):
+ c = self.neon_container(f_1, bc_mask, missing_mask)
+ c.run(0, container_runtime=neon.Container.ContainerRuntime.neon)
+ return f_1
+
+
+class MultiresEncodeAuxiliaryData(EncodeAuxiliaryData):
+ """
+ Operator for encoding boundary auxiliary data during initialization.
+ """
+
+ def __init__(
+ self,
+ boundary_id: int,
+ num_of_aux_data: int,
+ user_defined_functional: Callable,
+ velocity_set: VelocitySet = None,
+ precision_policy: PrecisionPolicy = None,
+ compute_backend: ComputeBackend = None,
+ ):
+ super().__init__(
+ boundary_id=boundary_id,
+ num_of_aux_data=num_of_aux_data,
+ user_defined_functional=user_defined_functional,
+ velocity_set=velocity_set,
+ precision_policy=precision_policy,
+ compute_backend=compute_backend,
+ )
+
+ assert self.compute_backend == ComputeBackend.NEON, f"Operator {self.__class__.__name__} not supported in {self.compute_backend} backend."
+
+ def _construct_neon(self):
+ """
+ Constructs the Neon container for encoding auxiliary data recovery.
+ """
+
+ # Borrow the functional from the warp implementation
+ functional_dict, _ = self._construct_warp()
+ encoder_functional = functional_dict["encoder"]
+ _id = self.boundary_id
+
+ # Construct the Neon container
+ @neon.Container.factory(name="MultiresEncodingAuxData_" + str(_id))
+ def aux_data_init_container(
+ f_1: Any,
+ bc_mask: Any,
+ missing_mask: Any,
+ level: Any,
+ ):
+ def aux_data_init_ll(loader: neon.Loader):
+ loader.set_mres_grid(f_1.get_grid(), level)
+
+ f_1_pn = loader.get_mres_write_handle(f_1)
+ bc_mask_pn = loader.get_mres_read_handle(bc_mask)
+ missing_mask_pn = loader.get_mres_read_handle(missing_mask)
+
+ @wp.func
+ def aux_data_init_cl(index: Any):
+ # read tid data
+ _, _, _boundary_id, _missing_mask = self.bc_helper.neon_get_bc_thread_data(f_1_pn, f_1_pn, bc_mask_pn, missing_mask_pn, index)
+
+ # Apply the functional
+ if _boundary_id == _id:
+ # IMPORTANT NOTE:
+ # It is assumed in XLB that the user_defined_functional in multi-res simulations is defined in terms of the indices at the finest level.
+ # This assumption enables handling of BCs whose indices span multiple levels
+ warp_index = self.bc_helper.neon_index_to_warp(f_1_pn, index)
+ prescribed_values = self.user_defined_functional(warp_index)
+
+ # Call the functional
+ encoder_functional(index, _missing_mask, f_1_pn, prescribed_values)
+
+ # Declare the kernel in the Neon loader
+ loader.declare_kernel(aux_data_init_cl)
+
+ return aux_data_init_ll
+
+ return functional_dict, aux_data_init_container
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, f_1, bc_mask, missing_mask, stream):
+ grid = bc_mask.get_grid()
+ for level in range(grid.num_levels):
+ c = self.neon_container(f_1, bc_mask, missing_mask, level)
+ c.run(stream, container_runtime=neon.Container.ContainerRuntime.neon)
+ return f_1
diff --git a/xlb/operator/boundary_masker/__init__.py b/xlb/operator/boundary_masker/__init__.py
index 3417c3c8..5a1ceb75 100644
--- a/xlb/operator/boundary_masker/__init__.py
+++ b/xlb/operator/boundary_masker/__init__.py
@@ -1,2 +1,12 @@
+from xlb.operator.boundary_masker.helper_functions_masker import HelperFunctionsMasker
from xlb.operator.boundary_masker.indices_boundary_masker import IndicesBoundaryMasker
from xlb.operator.boundary_masker.mesh_boundary_masker import MeshBoundaryMasker
+from xlb.operator.boundary_masker.aabb import MeshMaskerAABB
+from xlb.operator.boundary_masker.ray import MeshMaskerRay
+from xlb.operator.boundary_masker.winding import MeshMaskerWinding
+from xlb.operator.boundary_masker.aabb_close import MeshMaskerAABBClose
+from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod
+from xlb.operator.boundary_masker.multires_aabb import MultiresMeshMaskerAABB
+from xlb.operator.boundary_masker.multires_aabb_close import MultiresMeshMaskerAABBClose
+from xlb.operator.boundary_masker.multires_indices_boundary_masker import MultiresIndicesBoundaryMasker
+from xlb.operator.boundary_masker.multires_ray import MultiresMeshMaskerRay
diff --git a/xlb/operator/boundary_masker/aabb.py b/xlb/operator/boundary_masker/aabb.py
new file mode 100644
index 00000000..94def41b
--- /dev/null
+++ b/xlb/operator/boundary_masker/aabb.py
@@ -0,0 +1,198 @@
+"""
+AABB mesh-based boundary masker.
+
+Voxelizes an STL mesh using ``warp.mesh_query_aabb`` for approximate
+one-voxel-thick surface detection around the geometry.
+"""
+
+import warp as wp
+from typing import Any
+from xlb.velocity_set.velocity_set import VelocitySet
+from xlb.precision_policy import PrecisionPolicy
+from xlb.compute_backend import ComputeBackend
+from xlb.operator.boundary_masker.mesh_boundary_masker import MeshBoundaryMasker
+from xlb.operator.operator import Operator
+from xlb.cell_type import BC_SOLID
+
+
+class MeshMaskerAABB(MeshBoundaryMasker):
+ """
+ Operator for creating boundary missing_mask from mesh using Axis-Aligned Bounding Box (AABB) voxelization.
+
+ This implementation uses warp.mesh_query_aabb for efficient mesh-voxel intersection testing,
+ providing approximate 1-voxel thick surface detection around the mesh geometry.
+ Suitable for scenarios where fast, approximate boundary detection is sufficient.
+ """
+
+ def __init__(
+ self,
+ velocity_set: VelocitySet = None,
+ precision_policy: PrecisionPolicy = None,
+ compute_backend: ComputeBackend = None,
+ ):
+ # Call super
+ super().__init__(velocity_set, precision_policy, compute_backend)
+
+ def _construct_warp(self):
+ # Make constants for warp
+ _c = self.velocity_set.c
+ _q = self.velocity_set.q
+ _opp_indices = self.velocity_set.opp_indices
+
+ # Set local constants
+ lattice_central_index = self.velocity_set.center_index
+
+ @wp.func
+ def functional(
+ index: Any,
+ mesh_id: Any,
+ id_number: Any,
+ distances: Any,
+ bc_mask: Any,
+ missing_mask: Any,
+ needs_mesh_distance: Any,
+ ):
+ # position of the point
+ cell_center_pos = self.helper_masker.index_to_position(bc_mask, index)
+ HALF_VOXEL = wp.vec3(0.5, 0.5, 0.5)
+
+ if self.read_field(bc_mask, index, 0) == wp.uint8(BC_SOLID) or self.mesh_voxel_intersect(
+ mesh_id=mesh_id, low=cell_center_pos - HALF_VOXEL
+ ):
+ # Make solid voxel
+ self.write_field(bc_mask, index, 0, wp.uint8(BC_SOLID))
+ else:
+ # Find the boundary voxels and their missing directions
+ for direction_idx in range(_q):
+ if direction_idx == lattice_central_index:
+ # Skip the central index as it is not relevant for boundary masking
+ continue
+
+ # Get the lattice direction vector
+ direction_vec = wp.vec3f(wp.float32(_c[0, direction_idx]), wp.float32(_c[1, direction_idx]), wp.float32(_c[2, direction_idx]))
+
+ # Check to see if this neighbor is solid
+ if self.mesh_voxel_intersect(mesh_id=mesh_id, low=cell_center_pos + direction_vec - HALF_VOXEL):
+ # We know we have a solid neighbor
+ # Set the boundary id and missing_mask
+ self.write_field(bc_mask, index, 0, wp.uint8(id_number))
+ self.write_field(missing_mask, index, _opp_indices[direction_idx], wp.uint8(True))
+
+ # If we don't need the mesh distance, we can return early
+ if not needs_mesh_distance:
+ continue
+
+ # Find the fractional distance to the mesh in each direction
+ # We increase max_length to find intersections in neighboring cells
+ max_length = wp.length(direction_vec)
+ query = wp.mesh_query_ray(mesh_id, cell_center_pos, direction_vec / max_length, 1.5 * max_length)
+ if query.result:
+ # get position of the mesh triangle that intersects with the ray
+ pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v)
+ # We reduce the distance to give some wall thickness
+ dist = wp.length(pos_mesh - cell_center_pos) - 0.5 * max_length
+ weight = dist / max_length
+ self.write_field(distances, index, direction_idx, self.store_dtype(weight))
+ else:
+ # Expected an intersection in this direction but none was found.
+ # Assume the solid extends one lattice unit beyond the BC voxel leading to a distance fraction of 1.
+ self.write_field(distances, index, direction_idx, self.store_dtype(1.0))
+
+ @wp.kernel
+ def kernel(
+ mesh_id: wp.uint64,
+ id_number: wp.int32,
+ distances: wp.array4d(dtype=Any),
+ bc_mask: wp.array4d(dtype=wp.uint8),
+ missing_mask: wp.array4d(dtype=wp.uint8),
+ needs_mesh_distance: bool,
+ ):
+ # get index
+ i, j, k = wp.tid()
+
+ # Get local indices
+ index = wp.vec3i(i, j, k)
+
+ # apply the functional
+ functional(
+ index,
+ mesh_id,
+ id_number,
+ distances,
+ bc_mask,
+ missing_mask,
+ needs_mesh_distance,
+ )
+
+ return functional, kernel
+
+ @Operator.register_backend(ComputeBackend.WARP)
+ def warp_implementation(
+ self,
+ bc,
+ distances,
+ bc_mask,
+ missing_mask,
+ ):
+ return self.warp_implementation_base(
+ bc,
+ distances,
+ bc_mask,
+ missing_mask,
+ )
+
+ def _construct_neon(self):
+ import neon
+
+ # Use the warp functional for the NEON backend
+ functional, _ = self._construct_warp()
+
+ @neon.Container.factory(name="MeshMaskerAABB")
+ def container(
+ mesh_id: Any,
+ id_number: Any,
+ distances: Any,
+ bc_mask: Any,
+ missing_mask: Any,
+ needs_mesh_distance: Any,
+ ):
+ def aabb_launcher(loader: neon.Loader):
+ loader.set_grid(bc_mask.get_grid())
+ bc_mask_pn = loader.get_write_handle(bc_mask)
+ missing_mask_pn = loader.get_write_handle(missing_mask)
+ distances_pn = loader.get_write_handle(distances)
+
+ @wp.func
+ def aabb_kernel(index: Any):
+ # apply the functional
+ functional(
+ index,
+ mesh_id,
+ id_number,
+ distances_pn,
+ bc_mask_pn,
+ missing_mask_pn,
+ needs_mesh_distance,
+ )
+
+ loader.declare_kernel(aabb_kernel)
+
+ return aabb_launcher
+
+ return functional, container
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(
+ self,
+ bc,
+ distances,
+ bc_mask,
+ missing_mask,
+ ):
+ # Prepare inputs
+ mesh_id, bc_id = self._prepare_kernel_inputs(bc, bc_mask)
+
+ # Launch the appropriate neon container
+ c = self.neon_container(mesh_id, bc_id, distances, bc_mask, missing_mask, wp.static(bc.needs_mesh_distance))
+ c.run(0, container_runtime=neon.Container.ContainerRuntime.neon)
+ return distances, bc_mask, missing_mask
diff --git a/xlb/operator/boundary_masker/aabb_close.py b/xlb/operator/boundary_masker/aabb_close.py
new file mode 100644
index 00000000..21ee31f9
--- /dev/null
+++ b/xlb/operator/boundary_masker/aabb_close.py
@@ -0,0 +1,365 @@
+"""
+AABB-Close boundary masker with morphological close operation.
+
+Identifies solid voxels via axis-aligned bounding-box (AABB) intersection,
+then applies a morphological *close* (dilate followed by erode) to fill
+thin gaps and small cavities in the mesh surface. The resulting solid
+mask is used to determine boundary voxels and their missing population
+directions.
+
+Supports both Warp (single-resolution) and Neon (multi-resolution)
+backends.
+"""
+
+import numpy as np
+import warp as wp
+import jax
+from typing import Any
+from xlb.velocity_set.velocity_set import VelocitySet
+from xlb.precision_policy import PrecisionPolicy
+from xlb.compute_backend import ComputeBackend
+from xlb.operator.operator import Operator
+from xlb.operator.boundary_masker.mesh_boundary_masker import MeshBoundaryMasker
+from xlb.cell_type import BC_SOLID
+
+
+class MeshMaskerAABBClose(MeshBoundaryMasker):
+ """Boundary masker using AABB voxelization with morphological close.
+
+ The *close* operation (dilate then erode) thickens the raw solid mask
+ by ``close_voxels`` layers before shrinking it back, sealing small
+ holes and thin slits in the mesh surface.
+
+ Parameters
+ ----------
+ velocity_set : VelocitySet, optional
+ precision_policy : PrecisionPolicy, optional
+ compute_backend : ComputeBackend, optional
+ close_voxels : int
+ Half-width of the morphological structuring element. Must be
+ provided explicitly.
+ """
+
+ def __init__(
+ self,
+ velocity_set: VelocitySet = None,
+ precision_policy: PrecisionPolicy = None,
+ compute_backend: ComputeBackend = None,
+ close_voxels: int = None,
+ ):
+ assert close_voxels is not None, (
+ "Please provide the number of close voxels using the 'close_voxels' argument! e.g., MeshVoxelizationMethod('AABB_CLOSE', close_voxels=3)"
+ )
+ self.close_voxels = close_voxels
+ # Call super
+ self.tile_half = close_voxels
+ self.tile_size = self.tile_half * 2 + 1
+ super().__init__(velocity_set, precision_policy, compute_backend)
+
+ def _construct_warp(self):
+ # Make constants for warp
+ _c = self.velocity_set.c
+ _q = self.velocity_set.q
+ _opp_indices = self.velocity_set.opp_indices
+ TILE_SIZE = wp.constant(self.tile_size)
+ TILE_HALF = wp.constant(self.tile_half)
+ lattice_central_index = self.velocity_set.center_index
+
+ # Erode the solid mask in mask_field, removing a layer of outer solid voxels, storing output in mask_field_out
+ @wp.kernel
+ def erode_tile(mask_field: wp.array3d(dtype=Any), mask_field_out: wp.array3d(dtype=Any)):
+ i, j, k = wp.tid()
+ index = wp.vec3i(i, j, k)
+ if not self.helper_masker.is_in_bounds(index, wp.vec3i(mask_field.shape[0], mask_field.shape[1], mask_field.shape[2]), TILE_HALF):
+ mask_field_out[i, j, k] = mask_field[i, j, k]
+ return
+ t = wp.tile_load(mask_field, shape=(TILE_SIZE, TILE_SIZE, TILE_SIZE), offset=(i - TILE_HALF, j - TILE_HALF, k - TILE_HALF))
+ min_val = wp.tile_min(t)
+ mask_field_out[i, j, k] = min_val[0]
+
+ # Dilate the solid mask in mask_field, adding a layer of outer solid voxels, storing output in mask_field_out
+ @wp.kernel
+ def dilate_tile(mask_field: wp.array3d(dtype=Any), mask_field_out: wp.array3d(dtype=Any)):
+ i, j, k = wp.tid()
+ index = wp.vec3i(i, j, k)
+ if not self.helper_masker.is_in_bounds(index, wp.vec3i(mask_field.shape[0], mask_field.shape[1], mask_field.shape[2]), TILE_HALF):
+ mask_field_out[i, j, k] = mask_field[i, j, k]
+ return
+ t = wp.tile_load(mask_field, shape=(TILE_SIZE, TILE_SIZE, TILE_SIZE), offset=(i - TILE_HALF, j - TILE_HALF, k - TILE_HALF))
+ max_val = wp.tile_max(t)
+ mask_field_out[i, j, k] = max_val[0]
+
+ # Erode the solid mask in mask_field, removing a layer of outer solid voxels, storing output in mask_field_out
+ @wp.func
+ def functional_erode(index: Any, mask_field: Any, mask_field_out: Any):
+ min_val = wp.uint8(BC_SOLID)
+ for l in range(_q):
+ if l == lattice_central_index:
+ continue
+ is_valid = wp.bool(False)
+ ngh = wp.neon_ngh_idx(wp.int8(_c[0, l]), wp.int8(_c[1, l]), wp.int8(_c[2, l]))
+ ngh_val = wp.neon_read_ngh(mask_field, index, ngh, 0, wp.uint8(0), is_valid)
+ if is_valid:
+ # Take the min value of all neighbors in bounds
+ min_val = wp.min(min_val, ngh_val)
+ self.write_field(mask_field_out, index, 0, min_val)
+
+ # Dilate the solid mask in mask_field, adding a layer of outer solid voxels, storing output in mask_field_out
+ @wp.func
+ def functional_dilate(index: Any, mask_field: Any, mask_field_out: Any):
+ max_val = wp.uint8(0)
+ for l in range(_q):
+ if l == lattice_central_index:
+ continue
+ is_valid = wp.bool(False)
+ ngh = wp.neon_ngh_idx(wp.int8(_c[0, l]), wp.int8(_c[1, l]), wp.int8(_c[2, l]))
+ ngh_val = wp.neon_read_ngh(mask_field, index, ngh, 0, wp.uint8(0), is_valid)
+ if is_valid:
+ max_val = wp.max(max_val, ngh_val)
+ self.write_field(mask_field_out, index, 0, max_val)
+
+ # Construct the warp kernel
+ # Find solid voxels that intersect the mesh
+ @wp.func
+ def functional_solid(index: Any, mesh_id: Any, solid_mask: Any, offset: Any):
+ # position of the point
+ cell_center_pos = self.helper_masker.index_to_position(solid_mask, index) + offset
+ half = wp.vec3(0.5, 0.5, 0.5)
+
+ if self.mesh_voxel_intersect(mesh_id=mesh_id, low=cell_center_pos - half):
+ # Make solid voxel
+ self.write_field(solid_mask, index, 0, wp.uint8(BC_SOLID))
+
+ @wp.kernel
+ def kernel_solid(
+ mesh_id: wp.uint64,
+ solid_mask: wp.array3d(dtype=wp.int32),
+ offset: wp.vec3f,
+ ):
+ # get index
+ i, j, k = wp.tid()
+
+ # Get local indices
+ index = wp.vec3i(i, j, k)
+
+ functional_solid(index, mesh_id, solid_mask, offset)
+
+ return
+
+ @wp.func
+ def functional_aabb(
+ index: Any,
+ mesh_id: wp.uint64,
+ id_number: wp.int32,
+ distances: wp.array4d(dtype=Any),
+ bc_mask: wp.array4d(dtype=wp.uint8),
+ missing_mask: wp.array4d(dtype=wp.uint8),
+ solid_mask: wp.array3d(dtype=wp.uint8),
+ needs_mesh_distance: bool,
+ ):
+ # position of the point
+ cell_center_pos = self.helper_masker.index_to_position(bc_mask, index)
+ HALF_VOXEL = wp.vec3(0.5, 0.5, 0.5)
+
+ if self.read_field(solid_mask, index, 0) == wp.uint8(BC_SOLID) or self.read_field(bc_mask, index, 0) == wp.uint8(BC_SOLID):
+ # Make solid voxel
+ self.write_field(bc_mask, index, 0, wp.uint8(BC_SOLID))
+ else:
+ # Find the boundary voxels and their missing directions
+ for direction_idx in range(_q):
+ if direction_idx == lattice_central_index:
+ # Skip the central index as it is not relevant for boundary masking
+ continue
+
+ # Get the lattice direction vector
+ direction_vec = wp.vec3f(wp.float32(_c[0, direction_idx]), wp.float32(_c[1, direction_idx]), wp.float32(_c[2, direction_idx]))
+
+ # Check to see if this neighbor is solid
+ if self.helper_masker.is_in_bounds(index, wp.vec3i(solid_mask.shape[0], solid_mask.shape[1], solid_mask.shape[2]), 1):
+ if self.read_field(solid_mask, index + direction_idx, 0) == wp.uint8(BC_SOLID):
+ # We know we have a solid neighbor
+ # Set the boundary id and missing_mask
+ self.write_field(bc_mask, index, 0, wp.uint8(id_number))
+ self.write_field(missing_mask, index, _opp_indices[direction_idx], wp.uint8(True))
+
+ # If we don't need the mesh distance, we can return early
+ if not needs_mesh_distance:
+ continue
+
+ # Find the fractional distance to the mesh in each direction
+ # We increase max_length to find intersections in neighboring cells
+ max_length = wp.length(direction_vec)
+ query = wp.mesh_query_ray(mesh_id, cell_center_pos, direction_vec / max_length, 1.5 * max_length)
+ if query.result:
+ # get position of the mesh triangle that intersects with the ray
+ pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v)
+ # We reduce the distance to give some wall thickness
+ dist = wp.length(pos_mesh - cell_center_pos) - 0.5 * max_length
+ weight = dist / max_length
+ self.write_field(distances, index, direction_idx, self.store_dtype(weight))
+ else:
+ # Expected an intersection in this direction but none was found.
+ # Assume the solid extends one lattice unit beyond the BC voxel leading to a distance fraction of 1.
+ self.write_field(distances, index, direction_idx, self.store_dtype(1.0))
+
+ # Assign the bc_mask and distances based on the solid_mask we already computed
+ @wp.kernel
+ def kernel(
+ mesh_id: wp.uint64,
+ id_number: wp.int32,
+ distances: wp.array4d(dtype=Any),
+ bc_mask: wp.array4d(dtype=wp.uint8),
+ missing_mask: wp.array4d(dtype=wp.uint8),
+ solid_mask: wp.array3d(dtype=wp.uint8),
+ needs_mesh_distance: bool,
+ ):
+ # get index
+ i, j, k = wp.tid()
+
+ # Get local indices
+ index = wp.vec3i(i, j, k)
+
+ # position of the point
+ cell_center_pos = self.helper_masker.index_to_position(bc_mask, index)
+
+ if solid_mask[i, j, k] == wp.uint8(BC_SOLID) or bc_mask[0, index[0], index[1], index[2]] == wp.uint8(BC_SOLID):
+ # Make solid voxel
+ bc_mask[0, index[0], index[1], index[2]] = wp.uint8(BC_SOLID)
+ else:
+ # Find the boundary voxels and their missing directions
+ for direction_idx in range(_q):
+ if direction_idx == lattice_central_index:
+ # Skip the central index as it is not relevant for boundary masking
+ continue
+ direction_vec = wp.vec3f(wp.float32(_c[0, direction_idx]), wp.float32(_c[1, direction_idx]), wp.float32(_c[2, direction_idx]))
+
+ # Check to see if this neighbor is solid - this is super inefficient TODO: make it way better
+ # if solid_mask[i,j,k] == wp.uint8(BC_SOLID):
+ if solid_mask[i + _c[0, direction_idx], j + _c[1, direction_idx], k + _c[2, direction_idx]] == wp.uint8(BC_SOLID):
+ # We know we have a solid neighbor
+ # Set the boundary id and missing_mask
+ bc_mask[0, index[0], index[1], index[2]] = wp.uint8(id_number)
+ missing_mask[_opp_indices[direction_idx], index[0], index[1], index[2]] = wp.uint8(True)
+
+ # If we don't need the mesh distance, we can return early
+ if not needs_mesh_distance:
+ continue
+
+ # Find the fractional distance to the mesh in each direction
+ # We increase max_length to find intersections in neighboring cells
+ max_length = wp.length(direction_vec)
+ query = wp.mesh_query_ray(mesh_id, cell_center_pos, direction_vec / max_length, 1.5 * max_length)
+ if query.result:
+ # get position of the mesh triangle that intersects with the ray
+ pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v)
+ # We reduce the distance to give some wall thickness
+ dist = wp.length(pos_mesh - cell_center_pos) - 0.5 * max_length
+ weight = self.store_dtype(dist / max_length)
+ distances[direction_idx, index[0], index[1], index[2]] = weight
+ else:
+ # We didn't have an intersection in the given direction but we know we should so we assume the solid is slightly thicker
+ # and one lattice direction away from the BC voxel
+ distances[direction_idx, index[0], index[1], index[2]] = self.store_dtype(1.0)
+
+ functional_dict = {
+ "functional_erode": functional_erode,
+ "functional_dilate": functional_dilate,
+ "functional_solid": functional_solid,
+ "functional_aabb": functional_aabb,
+ }
+ kernel_dict = {
+ "kernel": kernel,
+ "kernel_solid": kernel_solid,
+ "erode_tile": erode_tile,
+ "dilate_tile": dilate_tile,
+ }
+ return functional_dict, kernel_dict
+
+ @Operator.register_backend(ComputeBackend.WARP)
+ def warp_implementation(
+ self,
+ bc,
+ distances,
+ bc_mask,
+ missing_mask,
+ ):
+ assert bc.mesh_vertices is not None, f'Please provide the mesh vertices for {bc.__class__.__name__} BC using keyword "mesh_vertices"!'
+ assert bc.indices is None, f"Please use IndicesBoundaryMasker operator if {bc.__class__.__name__} is imposed on known indices of the grid!"
+ assert bc.mesh_vertices.shape[1] == self.velocity_set.d, (
+ "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!"
+ )
+
+ domain_shape = bc_mask.shape[1:] # (nx, ny, nz)
+ mesh_vertices = bc.mesh_vertices
+ mesh_min = np.min(mesh_vertices, axis=0)
+ mesh_max = np.max(mesh_vertices, axis=0)
+
+ if any(mesh_min < 0) or any(mesh_max >= domain_shape):
+ raise ValueError(
+ f"Mesh extents ({mesh_min}, {mesh_max}) exceed domain dimensions {domain_shape}. The mesh must be fully contained within the domain."
+ )
+
+ # We are done with bc.mesh_vertices. Remove them from BC objects
+ bc.__dict__.pop("mesh_vertices", None)
+
+ mesh_indices = np.arange(mesh_vertices.shape[0])
+ mesh = wp.Mesh(
+ points=wp.array(mesh_vertices, dtype=wp.vec3),
+ indices=wp.array(mesh_indices, dtype=wp.int32),
+ )
+ mesh_id = wp.uint64(mesh.id)
+ bc_id = bc.id
+
+ # Create a padded mask for the solid voxels to account for the tile size
+ # It needs to be padded by twice the tile size on each side since we run two tile operations
+ tile_length = 2 * self.tile_half
+ offset = wp.vec3f(-tile_length, -tile_length, -tile_length)
+ pad = 2 * tile_length
+ nx, ny, nz = domain_shape
+ solid_mask = wp.zeros((nx + pad, ny + pad, nz + pad), dtype=wp.int32)
+ solid_mask_out = wp.zeros((nx + pad, ny + pad, nz + pad), dtype=wp.int32)
+
+ # Prepare the warp kernel dictionary
+ kernel_dict = self.warp_kernel
+
+ # Launch all required kernels for creating the solid mask
+ wp.launch(
+ kernel=kernel_dict["kernel_solid"],
+ inputs=[
+ mesh_id,
+ solid_mask,
+ offset,
+ ],
+ dim=solid_mask.shape,
+ )
+ wp.launch_tiled(
+ kernel=kernel_dict["dilate_tile"],
+ dim=solid_mask.shape,
+ block_dim=32,
+ inputs=[solid_mask, solid_mask_out],
+ )
+ wp.launch_tiled(
+ kernel=kernel_dict["erode_tile"],
+ dim=solid_mask.shape,
+ block_dim=32,
+ inputs=[solid_mask_out, solid_mask],
+ )
+ solid_mask_cropped = wp.array(
+ solid_mask[tile_length:-tile_length, tile_length:-tile_length, tile_length:-tile_length],
+ dtype=wp.uint8,
+ )
+
+ # Launch the main kernel for boundary masker
+ wp.launch(
+ kernel_dict["kernel"],
+ inputs=[mesh_id, bc_id, distances, bc_mask, missing_mask, solid_mask_cropped, wp.static(bc.needs_mesh_distance)],
+ dim=bc_mask.shape[1:],
+ )
+
+ # Resolve out of bound indices
+ wp.launch(
+ self.resolve_out_of_bound_kernel,
+ inputs=[bc_id, bc_mask, missing_mask],
+ dim=bc_mask.shape[1:],
+ )
+ return distances, bc_mask, missing_mask
diff --git a/xlb/operator/boundary_masker/helper_functions_masker.py b/xlb/operator/boundary_masker/helper_functions_masker.py
new file mode 100644
index 00000000..565455fe
--- /dev/null
+++ b/xlb/operator/boundary_masker/helper_functions_masker.py
@@ -0,0 +1,135 @@
+"""
+Warp/Neon helper functions shared by boundary masker operators.
+"""
+
+import warp as wp
+from typing import Any
+from xlb import DefaultConfig, ComputeBackend
+
+
+class HelperFunctionsMasker(object):
+ """Warp ``@wp.func`` helpers for boundary masker operators.
+
+ Provides coordinate-conversion, bounds-checking, pull-index
+ computation, and BC-index membership tests used by the mesh and
+ indices boundary maskers on both Warp and Neon backends.
+ """
+
+ def __init__(self, velocity_set=None, precision_policy=None, compute_backend=None):
+ if compute_backend == ComputeBackend.JAX:
+ raise ValueError("This helper class contains helper functions only for the WARP implementation of some BCs not JAX!")
+
+ # Set the default values from the global config
+ self.velocity_set = velocity_set or DefaultConfig.velocity_set
+ self.precision_policy = precision_policy or DefaultConfig.default_precision_policy
+ self.compute_backend = compute_backend or DefaultConfig.default_backend
+
+ # Set local constants
+ _d = self.velocity_set.d
+ _c = self.velocity_set.c
+
+ @wp.func
+ def neon_index_to_warp(neon_field_hdl: Any, index: Any):
+ # Unpack the global index in Neon at the finest level and convert it to a warp vector
+ cIdx = wp.neon_global_idx(neon_field_hdl, index)
+ gx = wp.neon_get_x(cIdx)
+ gy = wp.neon_get_y(cIdx)
+ gz = wp.neon_get_z(cIdx)
+
+ # XLB is flattening the z dimension in 3D, while neon uses the y dimension
+ if _d == 2:
+ gy, gz = gz, gy
+
+ # Get warp indices
+ index_wp = wp.vec3i(gx, gy, gz)
+ return index_wp
+
+ @wp.func
+ def index_to_position_warp(field: Any, index: wp.vec3i):
+ # position of the point
+ ijk = wp.vec3(wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2]))
+ pos = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center
+ return pos
+
+ @wp.func
+ def index_to_position_neon(field: Any, index: Any):
+ # position of the point
+ index_wp = neon_index_to_warp(field, index)
+ return index_to_position_warp(field, index_wp)
+
+ @wp.func
+ def is_in_bounds(index: wp.vec3i, grid_shape: wp.vec3i, SHIFT: Any = 0):
+ return (
+ index[0] >= SHIFT
+ and index[0] < grid_shape[0] - SHIFT
+ and index[1] >= SHIFT
+ and index[1] < grid_shape[1] - SHIFT
+ and index[2] >= SHIFT
+ and index[2] < grid_shape[2] - SHIFT
+ )
+
+ @wp.func
+ def get_pull_index_warp(
+ field: Any,
+ lattice_dir: wp.int32,
+ index: wp.vec3i,
+ level: Any,
+ ):
+ pull_index = wp.vec3i()
+ offset = wp.vec3i()
+ for d in range(self.velocity_set.d):
+ offset[d] = -_c[d, lattice_dir]
+ for _ in range(level):
+ offset[d] *= 2
+ pull_index[d] = index[d] + offset[d]
+
+ return pull_index, offset
+
+ @wp.func
+ def get_pull_index_neon(
+ field: Any,
+ lattice_dir: wp.int32,
+ index: Any,
+ level: Any,
+ ):
+ # Convert the index to warp
+ index_wp = neon_index_to_warp(field, index)
+ pull_index_wp, _ = get_pull_index_warp(field, lattice_dir, index_wp, level)
+ offset = wp.neon_ngh_idx(wp.int8(-_c[0, lattice_dir]), wp.int8(-_c[1, lattice_dir]), wp.int8(-_c[2, lattice_dir]))
+ return pull_index_wp, offset
+
+ @wp.func
+ def is_in_bc_indices_warp(
+ field: Any,
+ index: Any,
+ bc_indices: wp.array2d(dtype=wp.int32),
+ ii: wp.int32,
+ ):
+ return bc_indices[0, ii] == index[0] and bc_indices[1, ii] == index[1] and bc_indices[2, ii] == index[2]
+
+ @wp.func
+ def is_in_bc_indices_neon(
+ field: Any,
+ index: Any,
+ bc_indices: wp.array2d(dtype=wp.int32),
+ ii: wp.int32,
+ ):
+ index_wp = neon_index_to_warp(field, index)
+ return is_in_bc_indices_warp(field, index_wp, bc_indices, ii)
+
+ # Construct some helper warp functions
+ self.is_in_bounds = is_in_bounds
+ self.index_to_position = index_to_position_warp if self.compute_backend == ComputeBackend.WARP else index_to_position_neon
+ self.get_pull_index = get_pull_index_warp if self.compute_backend == ComputeBackend.WARP else get_pull_index_neon
+ self.is_in_bc_indices = is_in_bc_indices_warp if self.compute_backend == ComputeBackend.WARP else is_in_bc_indices_neon
+
+ def get_grid_shape(self, field):
+ """
+ Get the grid shape from the boundary mask. This is a CPU function that returns the shape of the grid
+ """
+ if self.compute_backend == ComputeBackend.WARP:
+ return field.shape[1:]
+ elif self.compute_backend == ComputeBackend.NEON:
+ return wp.vec3i(field.get_grid().dim.x, field.get_grid().dim.y, field.get_grid().dim.z)
+ else:
+ raise ValueError(f"Unsupported compute backend: {self.compute_backend}")
diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py
index c779972c..e888c284 100644
--- a/xlb/operator/boundary_masker/indices_boundary_masker.py
+++ b/xlb/operator/boundary_masker/indices_boundary_masker.py
@@ -1,12 +1,25 @@
-import numpy as np
-import warp as wp
+"""
+Indices-based boundary masker.
+
+Creates boundary masks from explicit arrays of voxel indices, computing
+missing-population masks via pull-index tests for each tagged voxel.
+"""
+
+from typing import Any
+import copy
+
import jax
import jax.numpy as jnp
+import numpy as np
+import warp as wp
+
from xlb.compute_backend import ComputeBackend
+from xlb.grid import grid_factory
from xlb.operator.operator import Operator
from xlb.operator.stream.stream import Stream
-from xlb.grid import grid_factory
from xlb.precision_policy import Precision
+from xlb.operator.boundary_masker.helper_functions_masker import HelperFunctionsMasker
+from xlb.cell_type import BC_SOLID
class IndicesBoundaryMasker(Operator):
@@ -19,12 +32,21 @@ def __init__(
velocity_set=None,
precision_policy=None,
compute_backend=None,
+ grid=None,
):
- # Make stream operator
- self.stream = Stream(velocity_set, precision_policy, compute_backend)
-
# Call super
super().__init__(velocity_set, precision_policy, compute_backend)
+ self.grid = grid
+ if self.compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON]:
+ # Define masker helper functions
+ self.helper_masker = HelperFunctionsMasker(
+ velocity_set=self.velocity_set,
+ precision_policy=self.precision_policy,
+ compute_backend=self.compute_backend,
+ )
+ else:
+ # Make stream operator
+ self.stream = Stream(velocity_set, precision_policy, compute_backend)
def are_indices_in_interior(self, indices, shape):
"""
@@ -35,141 +57,262 @@ def are_indices_in_interior(self, indices, shape):
:param shape: Tuple representing the shape of the domain (nx, ny) for 2D or (nx, ny, nz) for 3D.
:return: Array of boolean flags where each flag indicates whether the corresponding index is inside the bounds.
"""
- d = self.velocity_set.d
+ _d = self.velocity_set.d
shape_array = np.array(shape)
- return np.all((indices[:d] > 0) & (indices[:d] < shape_array[:d, np.newaxis] - 1), axis=0)
+ return np.all((indices[:_d] > 0) & (indices[:_d] < shape_array[:_d, np.newaxis] - 1), axis=0)
+
+ def _find_bclist_interior(self, bclist, grid_shape):
+ bc_interior = []
+ for bc in bclist:
+ if any(self.are_indices_in_interior(np.array(bc.indices), grid_shape)):
+ bc_copy = copy.copy(bc) # shallow copy of the whole object
+ bc_copy.indices = copy.deepcopy(bc.pad_indices()) # deep copy only the modified part
+ bc_interior.append(bc_copy)
+ return bc_interior
@Operator.register_backend(ComputeBackend.JAX)
# TODO HS: figure out why uncommenting the line below fails unlike other operators!
# @partial(jit, static_argnums=(0))
def jax_implementation(self, bclist, bc_mask, missing_mask, start_index=None):
- # Pad the missing mask to create a grid mask to identify out of bound boundaries
- # Set padded regin to True (i.e. boundary)
+ # Extend the missing mask by padding to identify out of bound boundaries
+ # Set padded region to True (i.e. boundary)
dim = missing_mask.ndim - 1
+ grid_shape = bc_mask[0].shape
nDevices = jax.device_count()
pad_x, pad_y, pad_z = nDevices, 1, 1
- # TODO MEHDI: There is sometimes a halting problem here when padding is used in a multi-GPU setting since we're not jitting this function.
- # For now, we compute the bmap on GPU zero.
- if dim == 2:
- bmap = jnp.zeros((pad_x * 2 + bc_mask[0].shape[0], pad_y * 2 + bc_mask[0].shape[1]), dtype=jnp.uint8)
- bmap = bmap.at[pad_x:-pad_x, pad_y:-pad_y].set(bc_mask[0])
- grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y)), constant_values=True)
- # bmap = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y)), constant_values=0)
- if dim == 3:
- bmap = jnp.zeros((pad_x * 2 + bc_mask[0].shape[0], pad_y * 2 + bc_mask[0].shape[1], pad_z * 2 + bc_mask[0].shape[2]), dtype=jnp.uint8)
- bmap = bmap.at[pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z].set(bc_mask[0])
- grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=True)
- # bmap = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=0)
- # shift indices
- shift_tup = (pad_x, pad_y) if dim == 2 else (pad_x, pad_y, pad_z)
+ # Shift indices due to padding
+ shift = np.array((pad_x, pad_y) if dim == 2 else (pad_x, pad_y, pad_z))[:, np.newaxis]
if start_index is None:
start_index = (0,) * dim
- domain_shape = bc_mask[0].shape
+ # TODO MEHDI: There is sometimes a halting problem here when padding is used in a multi-GPU setting since we're not jitting this function.
+ # For now, we compute the bc_mask_extended on GPU zero.
+ if dim == 2:
+ bc_mask_extended = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y)), constant_values=0)
+ missing_mask_extended = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y)), constant_values=True)
+ if dim == 3:
+ bc_mask_extended = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=0)
+ missing_mask_extended = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=True)
+
+ # Iterate over boundary conditions and set the mask
for bc in bclist:
assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!"
- assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!"
+ assert bc.mesh_vertices is None, (
+ f"Please use operators based on MeshBoundaryMasker if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!"
+ )
id_number = bc.id
bc_indices = np.array(bc.indices)
- local_indices = bc_indices - np.array(start_index)[:, np.newaxis]
- padded_indices = local_indices + np.array(shift_tup)[:, np.newaxis]
- bmap = bmap.at[tuple(padded_indices)].set(id_number)
- if any(self.are_indices_in_interior(bc_indices, domain_shape)) and bc.needs_padding:
- # checking if all indices associated with this BC are in the interior of the domain.
- # This flag is needed e.g. if the no-slip geometry is anywhere but at the boundaries of the computational domain.
+ indices_origin = np.array(start_index)[:, np.newaxis]
+ if any(self.are_indices_in_interior(bc_indices, grid_shape)):
+ # If the indices are in the interior, we assume the usre specified indices are solid indices
+ solid_indices = bc_indices - indices_origin
+ solid_indices_shifted = solid_indices + shift
+
+ # We obtain the boundary indices by padding the solid indices in all lattice directions
+ indices_padded = bc.pad_indices() - indices_origin
+ indices_shifted = indices_padded + shift
+
+ # The missing mask is set to True meaning (exterior or solid nodes) using the original indices.
+ # This is because of the following streaming step which will assign missing directions for the boundary nodes.
if dim == 2:
- grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1]].set(True)
- if dim == 3:
- grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1], padded_indices[2]].set(True)
+ missing_mask_extended = missing_mask_extended.at[:, solid_indices_shifted[0], solid_indices_shifted[1]].set(True)
+ else:
+ missing_mask_extended = missing_mask_extended.at[
+ :, solid_indices_shifted[0], solid_indices_shifted[1], solid_indices_shifted[2]
+ ].set(True)
+ else:
+ indices_shifted = bc_indices - indices_origin + shift
- # Assign the boundary id to the push indices
- push_indices = padded_indices[:, :, None] + self.velocity_set.c[:, None, :]
- push_indices = push_indices.reshape(dim, -1)
- bmap = bmap.at[tuple(push_indices)].set(id_number)
+ # Assign the boundary id to the shifted (and possibly padded) indices
+ bc_mask_extended = bc_mask_extended.at[tuple(indices_shifted)].set(id_number)
# We are done with bc.indices. Remove them from BC objects
bc.__dict__.pop("indices", None)
- grid_mask = self.stream(grid_mask)
+ # Stream the missing mask to identify missing directions
+ missing_mask_extended = self.stream(missing_mask_extended)
+
+ # Crop the extended masks to remove padding
if dim == 2:
- missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y]
- bc_mask = bc_mask.at[0].set(bmap[pad_x:-pad_x, pad_y:-pad_y])
+ missing_mask = missing_mask_extended[:, pad_x:-pad_x, pad_y:-pad_y]
+ bc_mask = bc_mask.at[0].set(bc_mask_extended[pad_x:-pad_x, pad_y:-pad_y])
if dim == 3:
- missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z]
- bc_mask = bc_mask.at[0].set(bmap[pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z])
+ missing_mask = missing_mask_extended[:, pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z]
+ bc_mask = bc_mask.at[0].set(bc_mask_extended[pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z])
return bc_mask, missing_mask
def _construct_warp(self):
# Make constants for warp
- _c = self.velocity_set.c
- _q = wp.constant(self.velocity_set.q)
+ _q = self.velocity_set.q
@wp.func
- def check_index_bounds(index: wp.vec3i, shape: wp.vec3i):
- is_in_bounds = index[0] >= 0 and index[0] < shape[0] and index[1] >= 0 and index[1] < shape[1] and index[2] >= 0 and index[2] < shape[2]
- return is_in_bounds
+ def functional_domain_bounds(
+ index: Any,
+ bc_indices: Any,
+ id_number: Any,
+ is_interior: Any,
+ bc_mask: Any,
+ missing_mask: Any,
+ grid_shape: Any,
+ level: Any = 0,
+ ):
+ for ii in range(bc_indices.shape[1]):
+ # If the current index does not match the boundary condition index, we skip it
+ if not self.helper_masker.is_in_bc_indices(bc_mask, index, bc_indices, ii):
+ continue
+
+ if is_interior[ii] == wp.uint8(True):
+ # If the index is in the interior, we set that index to be a solid node (identified by BC_SOLID)
+ # This information will be used in the next kernel to identify missing directions using the
+ # padded indices of the solid node that are associated with the boundary condition.
+ self.write_field(bc_mask, index, 0, wp.uint8(BC_SOLID))
+ return
+
+ # Set bc_mask for all bc indices
+ self.write_field(bc_mask, index, 0, wp.uint8(id_number[ii]))
+
+ # Stream indices
+ for l in range(_q):
+ # Get the pull index which is the index of the neighboring node where information is pulled from
+ pull_index, _ = self.helper_masker.get_pull_index(bc_mask, l, index, level)
+
+ # Check if pull index is out of bound
+ # These directions will have missing information after streaming
+ if not self.helper_masker.is_in_bounds(pull_index, grid_shape):
+ # Set the missing mask
+ self.write_field(missing_mask, index, l, wp.uint8(True))
+
+ @wp.func
+ def functional_interior_bc_mask(
+ index: Any,
+ bc_indices: Any,
+ id_number: Any,
+ bc_mask: Any,
+ ):
+ for ii in range(bc_indices.shape[1]):
+ # If the current index does not match the boundary condition index, we skip it
+ if not self.helper_masker.is_in_bc_indices(bc_mask, index, bc_indices, ii):
+ continue
+ # Set bc_mask for all interior bc indices
+ self.write_field(bc_mask, index, 0, wp.uint8(id_number[ii]))
+
+ @wp.func
+ def functional_interior_missing_mask(
+ index: Any,
+ bc_indices: Any,
+ bc_mask: Any,
+ missing_mask: Any,
+ grid_shape: Any,
+ level: Any = 0,
+ ):
+ for ii in range(bc_indices.shape[1]):
+ # If the current index does not match the boundary condition index, we skip it
+ if not self.helper_masker.is_in_bc_indices(bc_mask, index, bc_indices, ii):
+ continue
+ for l in range(_q):
+ # Get the index of the streaming direction
+ pull_index, offset = self.helper_masker.get_pull_index(bc_mask, l, index, level)
+
+ # Check if pull index is a fluid node (bc_mask is zero for fluid nodes)
+ bc_mask_ngh = self.read_field_neighbor(bc_mask, index, offset, 0)
+ if (self.helper_masker.is_in_bounds(pull_index, grid_shape)) and (bc_mask_ngh == wp.uint8(BC_SOLID)):
+ self.write_field(missing_mask, index, l, wp.uint8(True))
# Construct the warp 3D kernel
@wp.kernel
- def kernel(
- indices: wp.array2d(dtype=wp.int32),
+ def kernel_domain_bounds(
+ bc_indices: wp.array2d(dtype=wp.int32),
id_number: wp.array1d(dtype=wp.uint8),
- is_interior: wp.array1d(dtype=wp.bool),
+ is_interior: wp.array1d(dtype=wp.uint8),
bc_mask: wp.array4d(dtype=wp.uint8),
- missing_mask: wp.array4d(dtype=wp.bool),
+ missing_mask: wp.array4d(dtype=wp.uint8),
+ grid_shape: wp.vec3i,
):
- # Get the index of indices
- ii = wp.tid()
+ # get index
+ i, j, k = wp.tid()
# Get local indices
- index = wp.vec3i()
- index[0] = indices[0, ii]
- index[1] = indices[1, ii]
- index[2] = indices[2, ii]
-
- # Check if index is in bounds
- shape = wp.vec3i(missing_mask.shape[1], missing_mask.shape[2], missing_mask.shape[3])
- if check_index_bounds(index, shape):
- # Stream indices
- for l in range(_q):
- # Get the index of the streaming direction
- pull_index = wp.vec3i()
- push_index = wp.vec3i()
- for d in range(self.velocity_set.d):
- pull_index[d] = index[d] - _c[d, l]
- push_index[d] = index[d] + _c[d, l]
+ index = wp.vec3i(i, j, k)
+
+ # Call the functional
+ functional_domain_bounds(
+ index,
+ bc_indices,
+ id_number,
+ is_interior,
+ bc_mask,
+ missing_mask,
+ grid_shape,
+ )
- # set bc_mask for all bc indices
- bc_mask[0, index[0], index[1], index[2]] = id_number[ii]
+ @wp.kernel
+ def kernel_interior_bc_mask(
+ bc_indices: wp.array2d(dtype=wp.int32),
+ id_number: wp.array1d(dtype=wp.uint8),
+ bc_mask: wp.array4d(dtype=wp.uint8),
+ ):
+ # get index
+ i, j, k = wp.tid()
- # check if pull index is out of bound
- # These directions will have missing information after streaming
- if not check_index_bounds(pull_index, shape):
- # Set the missing mask
- missing_mask[l, index[0], index[1], index[2]] = True
+ # Get local indices
+ index = wp.vec3i(i, j, k)
- # handling geometries in the interior of the computational domain
- elif check_index_bounds(pull_index, shape) and is_interior[ii]:
- # Set the missing mask
- missing_mask[l, push_index[0], push_index[1], push_index[2]] = True
- bc_mask[0, push_index[0], push_index[1], push_index[2]] = id_number[ii]
+ # Set bc_mask for all interior bc indices
+ functional_interior_bc_mask(
+ index,
+ bc_indices,
+ id_number,
+ bc_mask,
+ )
+ return
+
+ @wp.kernel
+ def kernel_interior_missing_mask(
+ bc_indices: wp.array2d(dtype=wp.int32),
+ bc_mask: wp.array4d(dtype=wp.uint8),
+ missing_mask: wp.array4d(dtype=wp.uint8),
+ grid_shape: wp.vec3i,
+ ):
+ # get index
+ i, j, k = wp.tid()
- return None, kernel
+ # Get local indices
+ index = wp.vec3i(i, j, k)
+
+ functional_interior_missing_mask(index, bc_indices, bc_mask, missing_mask, grid_shape)
+
+ functional_dict = {
+ "functional_domain_bounds": functional_domain_bounds,
+ "functional_interior_bc_mask": functional_interior_bc_mask,
+ "functional_interior_missing_mask": functional_interior_missing_mask,
+ }
+ kernel_dict = {
+ "kernel_domain_bounds": kernel_domain_bounds,
+ "kernel_interior_bc_mask": kernel_interior_bc_mask,
+ "kernel_interior_missing_mask": kernel_interior_missing_mask,
+ }
+ return functional_dict, kernel_dict
+
+ def _prepare_kernel_inputs(self, bclist, grid_shape, start_index=None):
+ """
+ Prepare the inputs for the warp kernel by pre-allocating arrays and filling them with boundary condition information.
+ """
- @Operator.register_backend(ComputeBackend.WARP)
- def warp_implementation(self, bclist, bc_mask, missing_mask, start_index=None):
# Pre-allocate arrays with maximum possible size
- max_size = sum(len(bc.indices[0]) if isinstance(bc.indices, list) else bc.indices.shape[1] for bc in bclist if bc.indices is not None)
+ max_size = sum(
+ len(bc.indices[0]) if isinstance(bc.indices, (list, tuple)) else bc.indices.shape[1] for bc in bclist if bc.indices is not None
+ )
indices = np.zeros((3, max_size), dtype=np.int32)
id_numbers = np.zeros(max_size, dtype=np.uint8)
- is_interior = np.zeros(max_size, dtype=bool)
+ is_interior = np.zeros(max_size, dtype=np.uint8)
current_index = 0
for bc in bclist:
assert bc.indices is not None, f'Please specify indices associated with the {bc.__class__.__name__} BC using keyword "indices"!'
- assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!"
-
+ assert bc.mesh_vertices is None, (
+ f"Please use operators based on MeshBoundaryMasker if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!"
+ )
bc_indices = np.asarray(bc.indices)
num_indices = bc_indices.shape[1]
@@ -188,10 +331,7 @@ def warp_implementation(self, bclist, bc_mask, missing_mask, start_index=None):
id_numbers[current_index : current_index + num_indices] = bc.id
# Set is_interior flags
- if bc.needs_padding:
- is_interior[current_index : current_index + num_indices] = self.are_indices_in_interior(bc_indices, bc_mask[0].shape)
- else:
- is_interior[current_index : current_index + num_indices] = False
+ is_interior[current_index : current_index + num_indices] = self.are_indices_in_interior(bc_indices, grid_shape)
current_index += num_indices
@@ -199,26 +339,232 @@ def warp_implementation(self, bclist, bc_mask, missing_mask, start_index=None):
# bc.__dict__.pop("indices", None)
# Trim arrays to actual size
- indices = indices[:, :current_index]
- id_numbers = id_numbers[:current_index]
- is_interior = is_interior[:current_index]
+ total_index = current_index
+ indices = indices[:, :total_index]
+ id_numbers = id_numbers[:total_index]
+ is_interior = is_interior[:total_index]
# Convert to Warp arrays
- wp_indices = wp.array(indices, dtype=wp.int32)
- wp_id_numbers = wp.array(id_numbers, dtype=wp.uint8)
- wp_is_interior = wp.array(is_interior, dtype=wp.bool)
+ def _to_wp_arrays(indices, id_numbers, is_interior, device=None):
+ return (
+ wp.array(indices, dtype=wp.int32, device=device),
+ wp.array(id_numbers, dtype=wp.uint8, device=device),
+ wp.array(is_interior, dtype=wp.uint8, device=device),
+ )
+
+ if self.compute_backend == ComputeBackend.NEON:
+ grid = self.grid
+ ndevice = 1 if grid is None else grid.bk.get_num_devices()
+
+ if ndevice == 1:
+ return _to_wp_arrays(indices, id_numbers, is_interior)
+ else:
+ # For multi-device, we need to split the indices across devices
+ wp_bc_indices = []
+ wp_id_numbers = []
+ wp_is_interior = []
+ for i in range(ndevice):
+ device_name = grid.bk.get_device_name(i)
+ wp_bc_indices.append(wp.array(indices, dtype=wp.int32, device=device_name))
+ wp_id_numbers.append(wp.array(id_numbers, dtype=wp.uint8, device=device_name))
+ wp_is_interior.append(wp.array(is_interior, dtype=wp.uint8, device=device_name))
+ return wp_bc_indices, wp_id_numbers, wp_is_interior
+ else:
+ return _to_wp_arrays(indices, id_numbers, is_interior)
+
+ @Operator.register_backend(ComputeBackend.WARP)
+ def warp_implementation(self, bclist, bc_mask, missing_mask, start_index=None):
+ # get the grid shape
+ grid_shape = self.helper_masker.get_grid_shape(bc_mask)
+
+ # Find interior boundary conditions
+ bc_interior = self._find_bclist_interior(bclist, grid_shape)
+
+ # Prepare the first kernel inputs for all items in boundary condition list
+ wp_bc_indices, wp_id_numbers, wp_is_interior = self._prepare_kernel_inputs(bclist, grid_shape, start_index)
# Launch the warp kernel
wp.launch(
- self.warp_kernel,
- dim=current_index,
+ self.warp_kernel["kernel_domain_bounds"],
+ dim=bc_mask.shape[1:],
+ inputs=[wp_bc_indices, wp_id_numbers, wp_is_interior, bc_mask, missing_mask, grid_shape],
+ )
+
+ # If there are no interior boundary conditions, skip the rest and retun early
+ if not bc_interior:
+ return bc_mask, missing_mask
+
+ # Prepare the second and third kernel inputs for only a subset of boundary conditions associated with the interior
+ # Note 1: launching order of the following kernels are important here!
+ # Note 2: Due to race conditioning, the two kernels cannot be fused together.
+ wp_bc_indices, wp_id_numbers, _ = self._prepare_kernel_inputs(bc_interior, grid_shape)
+ wp.launch(
+ self.warp_kernel["kernel_interior_missing_mask"],
+ dim=bc_mask.shape[1:],
+ inputs=[wp_bc_indices, bc_mask, missing_mask, grid_shape],
+ )
+ wp.launch(
+ self.warp_kernel["kernel_interior_bc_mask"],
+ dim=bc_mask.shape[1:],
inputs=[
- wp_indices,
+ wp_bc_indices,
wp_id_numbers,
- wp_is_interior,
bc_mask,
- missing_mask,
],
)
return bc_mask, missing_mask
+
+ def _construct_neon(self):
+ import neon
+
+ # Use the warp functional for the NEON backend
+ functional_dict, _ = self._construct_warp()
+ functional_domain_bounds = functional_dict.get("functional_domain_bounds")
+ functional_interior_bc_mask = functional_dict.get("functional_interior_bc_mask")
+ functional_interior_missing_mask = functional_dict.get("functional_interior_missing_mask")
+
+ @neon.Container.factory(name="IndicesBoundaryMasker_DomainBounds")
+ def container_domain_bounds(
+ wp_bc_indices_,
+ wp_id_numbers_,
+ wp_is_interior_,
+ bc_mask,
+ missing_mask,
+ grid_shape,
+ ):
+ def domain_bounds_launcher(loader: neon.Loader):
+ loader.set_grid(bc_mask.get_grid())
+ bc_mask_pn = loader.get_write_handle(bc_mask)
+ missing_mask_pn = loader.get_write_handle(missing_mask)
+ grid = bc_mask.get_grid()
+ bk = grid.backend
+ if bk.get_num_devices() == 1:
+ # If there is only one device, we can use the warp arrays directly
+ wp_bc_indices = wp_bc_indices_
+ wp_id_numbers = wp_id_numbers_
+ wp_is_interior = wp_is_interior_
+ else:
+ device_id = loader.get_device_id()
+ wp_bc_indices = wp_bc_indices_[device_id]
+ wp_id_numbers = wp_id_numbers_[device_id]
+ wp_is_interior = wp_is_interior_[device_id]
+
+ @wp.func
+ def domain_bounds_kernel(index: Any):
+ # apply the functional
+ functional_domain_bounds(
+ index,
+ wp_bc_indices,
+ wp_id_numbers,
+ wp_is_interior,
+ bc_mask_pn,
+ missing_mask_pn,
+ grid_shape,
+ )
+
+ loader.declare_kernel(domain_bounds_kernel)
+
+ return domain_bounds_launcher
+
+ @neon.Container.factory(name="IndicesBoundaryMasker_InteriorBcMask")
+ def container_interior_bc_mask(
+ wp_bc_indices,
+ wp_id_numbers,
+ bc_mask,
+ ):
+ def interior_bc_mask_launcher(loader: neon.Loader):
+ loader.set_grid(bc_mask.get_grid())
+ bc_mask_pn = loader.get_write_handle(bc_mask)
+
+ @wp.func
+ def interior_bc_mask_kernel(index: Any):
+ # apply the functional
+ functional_interior_bc_mask(
+ index,
+ wp_bc_indices,
+ wp_id_numbers,
+ bc_mask_pn,
+ )
+
+ loader.declare_kernel(interior_bc_mask_kernel)
+
+ return interior_bc_mask_launcher
+
+ @neon.Container.factory(name="IndicesBoundaryMasker_InteriorMissingMask")
+ def container_interior_missing_mask(
+ wp_bc_indices,
+ bc_mask,
+ missing_mask,
+ grid_shape,
+ ):
+ def interior_bc_mask_launcher(loader: neon.Loader):
+ loader.set_grid(bc_mask.get_grid())
+ bc_mask_pn = loader.get_write_handle(bc_mask)
+ missing_mask_pn = loader.get_write_handle(missing_mask)
+
+ @wp.func
+ def interior_missing_mask_kernel(index: Any):
+ # apply the functional
+ functional_interior_missing_mask(
+ index,
+ wp_bc_indices,
+ bc_mask_pn,
+ missing_mask_pn,
+ grid_shape,
+ )
+
+ loader.declare_kernel(interior_missing_mask_kernel)
+
+ return interior_bc_mask_launcher
+
+ container_dict = {
+ "container_domain_bounds": container_domain_bounds,
+ "container_interior_bc_mask": container_interior_bc_mask,
+ "container_interior_missing_mask": container_interior_missing_mask,
+ }
+
+ return functional_dict, container_dict
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, bclist, bc_mask, missing_mask, start_index=None):
+ # get the grid shape
+ grid_shape = self.helper_masker.get_grid_shape(bc_mask)
+
+ # Find interior boundary conditions
+ bc_interior = self._find_bclist_interior(bclist, grid_shape)
+
+ # Prepare the first kernel inputs for all items in boundary condition list
+ wp_bc_indices, wp_id_numbers, wp_is_interior = self._prepare_kernel_inputs(bclist, grid_shape, start_index)
+
+ # Launch the first container
+ container_domain_bounds = self.neon_container["container_domain_bounds"](
+ wp_bc_indices,
+ wp_id_numbers,
+ wp_is_interior,
+ bc_mask,
+ missing_mask,
+ grid_shape,
+ )
+ container_domain_bounds.run(0, container_runtime=neon.Container.ContainerRuntime.neon)
+
+ # If there are no interior boundary conditions, skip the rest and retun early
+ if not bc_interior:
+ return bc_mask, missing_mask
+
+ # Prepare the second and third kernel inputs for only a subset of boundary conditions associated with the interior
+ # Note 1: launching order of the following kernels are important here!
+ # Note 2: Due to race conditioning, the two kernels cannot be fused together.
+ wp_bc_indices, wp_id_numbers, _ = self._prepare_kernel_inputs(bc_interior, grid_shape)
+ container_interior_missing_mask = self.neon_container["container_interior_missing_mask"](wp_bc_indices, bc_mask, missing_mask, grid_shape)
+ container_interior_missing_mask.run(0, container_runtime=neon.Container.ContainerRuntime.neon)
+
+ # Launch the third container
+ container_interior_bc_mask = self.neon_container["container_interior_bc_mask"](
+ wp_bc_indices,
+ wp_id_numbers,
+ bc_mask,
+ )
+ container_interior_bc_mask.run(0, container_runtime=neon.Container.ContainerRuntime.neon)
+
+ return bc_mask, missing_mask
diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py
index 40dd0311..c6fb778d 100644
--- a/xlb/operator/boundary_masker/mesh_boundary_masker.py
+++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py
@@ -1,63 +1,71 @@
-# Base class for all equilibriums
+"""
+Abstract base class for mesh-based boundary maskers.
+
+Provides shared input preparation logic (mesh construction, kernel arrays)
+used by AABB, Ray, Winding, and AABB-Close masker subclasses.
+"""
import numpy as np
import warp as wp
-import jax
+from typing import Any
from xlb.velocity_set.velocity_set import VelocitySet
from xlb.precision_policy import PrecisionPolicy
from xlb.compute_backend import ComputeBackend
from xlb.operator.operator import Operator
+from xlb.operator.boundary_masker.helper_functions_masker import HelperFunctionsMasker
class MeshBoundaryMasker(Operator):
"""
- Operator for creating a boundary missing_mask from an STL file
+ Operator for creating a boundary missing_mask from a mesh file
"""
def __init__(
self,
- velocity_set: VelocitySet,
- precision_policy: PrecisionPolicy,
- compute_backend: ComputeBackend.WARP,
+ velocity_set: VelocitySet = None,
+ precision_policy: PrecisionPolicy = None,
+ compute_backend: ComputeBackend = None,
):
# Call super
super().__init__(velocity_set, precision_policy, compute_backend)
+ assert self.compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON], (
+ f"MeshBoundaryMasker is only implemented for {ComputeBackend.WARP} and {ComputeBackend.NEON} backends!"
+ )
+
+ assert self.velocity_set.d == 3, "MeshBoundaryMasker is only implemented for 3D velocity sets!"
# Raise error if used for 2d examples:
if self.velocity_set.d == 2:
raise NotImplementedError("This Operator is not implemented in 2D!")
- # Also using Warp kernels for JAX implementation
- if self.compute_backend == ComputeBackend.JAX:
- self.warp_functional, self.warp_kernel = self._construct_warp()
-
- @Operator.register_backend(ComputeBackend.JAX)
- def jax_implementation(
- self,
- bc,
- bc_mask,
- missing_mask,
- ):
- raise NotImplementedError(f"Operation {self.__class__.__name} not implemented in JAX!")
- # Use Warp backend even for this particular operation.
- wp.init()
- bc_mask = wp.from_jax(bc_mask)
- missing_mask = wp.from_jax(missing_mask)
- bc_mask, missing_mask = self.warp_implementation(bc, bc_mask, missing_mask)
- return wp.to_jax(bc_mask), wp.to_jax(missing_mask)
-
- def _construct_warp(self):
# Make constants for warp
- _c_float = self.velocity_set.c_float
- _q = wp.constant(self.velocity_set.q)
- _opp_indices = self.velocity_set.opp_indices
+ _c = self.velocity_set.c
+ _q = self.velocity_set.q
+
+ if self.compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON]:
+ # Define masker helper functions
+ self.helper_masker = HelperFunctionsMasker(
+ velocity_set=self.velocity_set,
+ precision_policy=self.precision_policy,
+ compute_backend=self.compute_backend,
+ )
@wp.func
- def index_to_position(index: wp.vec3i):
- # position of the point
- ijk = wp.vec3(wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2]))
- pos = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center
- return pos
+ def out_of_bound_pull_index(
+ lattice_dir: wp.int32,
+ index: wp.vec3i,
+ field: wp.array4d(dtype=wp.uint8),
+ grid_shape: wp.vec3i,
+ ):
+ # Get the index of the streaming direction
+ pull_index = wp.vec3i()
+ for d in range(self.velocity_set.d):
+ pull_index[d] = index[d] - _c[d, lattice_dir]
+
+ # check if pull index is out of bound
+ # These directions will have missing information after streaming
+ missing = not self.helper_masker.is_in_bounds(pull_index, grid_shape)
+ return missing
# Function to precompute useful values per triangle, assuming spacing is (1,1,1)
# inputs: verts: triangle vertices, normal: triangle unit normal
@@ -78,6 +86,7 @@ def pre_compute(
dist_edge = wp.mat33f(0.0)
for axis0 in range(0, 3):
+ axis1 = (axis0 + 1) % 3
axis2 = (axis0 + 2) % 3
sgn = 1.0
@@ -85,21 +94,18 @@ def pre_compute(
sgn = -1.0
for i in range(0, 3):
- normal_edge0[i][axis0] = -1.0 * sgn * edges[i][axis0]
- normal_edge1[i][axis0] = sgn * edges[i][axis0]
+ normal_edge0[i, axis0] = -1.0 * sgn * edges[i, axis1]
+ normal_edge1[i, axis0] = sgn * edges[i, axis0]
- dist_edge[i][axis0] = (
- -1.0 * (normal_edge0[i][axis0] * verts[i][axis0] + normal_edge1[i][axis0] * verts[i][axis0])
- + wp.max(0.0, normal_edge0[i][axis0])
- + wp.max(0.0, normal_edge1[i][axis0])
+ dist_edge[i, axis0] = (
+ -1.0 * (normal_edge0[i, axis0] * verts[i, axis0] + normal_edge1[i, axis0] * verts[i, axis1])
+ + wp.max(0.0, normal_edge0[i, axis0])
+ + wp.max(0.0, normal_edge1[i, axis0])
)
return dist1, dist2, normal_edge0, normal_edge1, dist_edge
# Check whether this triangle intersects the unit cube at position low
- # inputs: low: position of the cube, normal: triangle unit normal, dist1, dist2, normal_edge0, normal_edge1, dist_edge: precomputed values
- # outputs: True if intersection, False otherwise
- # reference: Fast parallel surface and solid voxelization on GPUs, M. Schwarz, H-P. Siedel, https://dl.acm.org/doi/10.1145/1882261.1866201
@wp.func
def triangle_box_intersect(
low: wp.vec3f,
@@ -116,7 +122,7 @@ def triangle_box_intersect(
for ax0 in range(0, 3):
ax1 = (ax0 + 1) % 3
for i in range(0, 3):
- intersect = intersect and (normal_edge0[i][ax0] * low[ax0] + normal_edge1[i][ax0] * low[ax1] + dist_edge[i][ax0] >= 0.0)
+ intersect = intersect and (normal_edge0[i, ax0] * low[ax0] + normal_edge1[i, ax0] * low[ax1] + dist_edge[i, ax0] >= 0.0)
return intersect
else:
@@ -147,15 +153,11 @@ def mesh_voxel_intersect(mesh_id: wp.uint64, low: wp.vec3):
return False
- # Construct the warp kernel
- # Do voxelization mesh query (warp.mesh_query_aabb) to find solid voxels
- # - this gives an approximate 1 voxel thick surface around mesh
@wp.kernel
- def kernel(
- mesh_id: wp.uint64,
+ def resolve_out_of_bound_kernel(
id_number: wp.int32,
bc_mask: wp.array4d(dtype=wp.uint8),
- missing_mask: wp.array4d(dtype=wp.bool),
+ missing_mask: wp.array4d(dtype=wp.uint8),
):
# get index
i, j, k = wp.tid()
@@ -163,74 +165,81 @@ def kernel(
# Get local indices
index = wp.vec3i(i, j, k)
- # position of the point
- pos_bc_cell = index_to_position(index)
- half = wp.vec3(0.5, 0.5, 0.5)
+ # domain shape to check for out of bounds
+ grid_shape = wp.vec3i(bc_mask.shape[1], bc_mask.shape[2], bc_mask.shape[3])
- if mesh_voxel_intersect(mesh_id=mesh_id, low=pos_bc_cell - half):
- # Make solid voxel
- bc_mask[0, index[0], index[1], index[2]] = wp.uint8(255)
- else:
- # Find the fractional distance to the mesh in each direction
+ # Find the fractional distance to the mesh in each direction
+ if bc_mask[0, index[0], index[1], index[2]] == wp.uint8(id_number):
for l in range(1, _q):
- _dir = wp.vec3f(_c_float[0, l], _c_float[1, l], _c_float[2, l])
-
- # Check to see if this neighbor is solid - this is super inefficient TODO: make it way better
- if mesh_voxel_intersect(mesh_id=mesh_id, low=pos_bc_cell + _dir - half):
- # We know we have a solid neighbor
- # Set the boundary id and missing_mask
- bc_mask[0, index[0], index[1], index[2]] = wp.uint8(id_number)
- missing_mask[_opp_indices[l], index[0], index[1], index[2]] = True
+ # Ensuring out of bound pull indices are properly considered in the missing_mask
+ if out_of_bound_pull_index(l, index, missing_mask, grid_shape):
+ missing_mask[l, index[0], index[1], index[2]] = wp.uint8(True)
- return None, kernel
+ # Construct some helper warp functions
+ self.mesh_voxel_intersect = mesh_voxel_intersect
+ self.resolve_out_of_bound_kernel = resolve_out_of_bound_kernel
- @Operator.register_backend(ComputeBackend.WARP)
- def warp_implementation(
+ def _prepare_kernel_inputs(
self,
bc,
bc_mask,
- missing_mask,
):
assert bc.mesh_vertices is not None, f'Please provide the mesh vertices for {bc.__class__.__name__} BC using keyword "mesh_vertices"!'
assert bc.indices is None, f"Please use IndicesBoundaryMasker operator if {bc.__class__.__name__} is imposed on known indices of the grid!"
assert bc.mesh_vertices.shape[1] == self.velocity_set.d, (
"Mesh points must be reshaped into an array (N, 3) where N indicates number of points!"
)
- mesh_vertices = bc.mesh_vertices
- id_number = bc.id
- # Check mesh extents against domain dimensions
- domain_shape = bc_mask.shape[1:] # (nx, ny, nz)
+ grid_shape = self.helper_masker.get_grid_shape(bc_mask) # (nx, ny, nz)
+ mesh_vertices = bc.mesh_vertices
mesh_min = np.min(mesh_vertices, axis=0)
mesh_max = np.max(mesh_vertices, axis=0)
- if any(mesh_min < 0) or any(mesh_max >= domain_shape):
+ if any(mesh_min < 0) or any(mesh_max >= grid_shape):
raise ValueError(
- f"Mesh extents ({mesh_min}, {mesh_max}) exceed domain dimensions {domain_shape}. The mesh must be fully contained within the domain."
+ f"Mesh extents ({mesh_min}, {mesh_max}) exceed domain dimensions {grid_shape}. The mesh must be fully contained within the domain."
)
# We are done with bc.mesh_vertices. Remove them from BC objects
bc.__dict__.pop("mesh_vertices", None)
- # Ensure this masker is called only for BCs that need implicit distance to the mesh
- assert not bc.needs_mesh_distance, 'Please use "MeshDistanceBoundaryMasker" if this BC needs mesh distance!'
-
mesh_indices = np.arange(mesh_vertices.shape[0])
mesh = wp.Mesh(
points=wp.array(mesh_vertices, dtype=wp.vec3),
- indices=wp.array(mesh_indices, dtype=int),
+ indices=wp.array(mesh_indices, dtype=wp.int32),
)
+ mesh_id = wp.uint64(mesh.id)
+ bc_id = bc.id
+ return mesh_id, bc_id
+
+ @Operator.register_backend(ComputeBackend.JAX)
+ def jax_implementation(
+ self,
+ bc,
+ bc_mask,
+ missing_mask,
+ ):
+ raise NotImplementedError(f"Operation {self.__class__.__name__} not implemented in JAX!")
- # Launch the warp kernel
+ def warp_implementation_base(
+ self,
+ bc,
+ distances,
+ bc_mask,
+ missing_mask,
+ ):
+ # Prepare inputs
+ mesh_id, bc_id = self._prepare_kernel_inputs(bc, bc_mask)
+
+ # Launch the appropriate warp kernel
wp.launch(
self.warp_kernel,
- inputs=[
- mesh.id,
- id_number,
- bc_mask,
- missing_mask,
- ],
+ inputs=[mesh_id, bc_id, distances, bc_mask, missing_mask, wp.static(bc.needs_mesh_distance)],
dim=bc_mask.shape[1:],
)
-
- return bc_mask, missing_mask
+ wp.launch(
+ self.resolve_out_of_bound_kernel,
+ inputs=[bc_id, bc_mask, missing_mask],
+ dim=bc_mask.shape[1:],
+ )
+ return distances, bc_mask, missing_mask
diff --git a/xlb/operator/boundary_masker/mesh_voxelization_method.py b/xlb/operator/boundary_masker/mesh_voxelization_method.py
new file mode 100644
index 00000000..b0162de7
--- /dev/null
+++ b/xlb/operator/boundary_masker/mesh_voxelization_method.py
@@ -0,0 +1,55 @@
+"""
+Mesh voxelization method registry.
+
+Defines the available voxelization strategies (AABB, Ray, AABB-Close,
+Winding) and provides a factory function to create the corresponding
+:class:`VoxelizationMethod` data object.
+"""
+
+from dataclasses import dataclass
+
+
+METHODS = {
+ "AABB": 1,
+ "RAY": 2,
+ "AABB_CLOSE": 3,
+ "WINDING": 4,
+}
+
+
+@dataclass
+class VoxelizationMethod:
+ """Describes a mesh voxelization strategy.
+
+ Attributes
+ ----------
+ id : int
+ Numeric identifier for the method.
+ name : str
+ Human-readable name (``"AABB"``, ``"RAY"``, etc.).
+ options : dict
+ Extra options (e.g. ``close_voxels`` for AABB_CLOSE).
+ """
+
+ id: int
+ name: str
+ options: dict
+
+
+def MeshVoxelizationMethod(name: str, **options):
+ """Create a :class:`VoxelizationMethod` by name.
+
+ Parameters
+ ----------
+ name : str
+ One of ``"AABB"``, ``"RAY"``, ``"AABB_CLOSE"``, ``"WINDING"``.
+ **options
+ Additional keyword arguments forwarded to
+ ``VoxelizationMethod.options``.
+
+ Returns
+ -------
+ VoxelizationMethod
+ """
+ assert name in METHODS.keys(), f"Unsupported voxelization method: {name}"
+ return VoxelizationMethod(METHODS[name], name, options)
diff --git a/xlb/operator/boundary_masker/multires_aabb.py b/xlb/operator/boundary_masker/multires_aabb.py
new file mode 100644
index 00000000..f9cc5887
--- /dev/null
+++ b/xlb/operator/boundary_masker/multires_aabb.py
@@ -0,0 +1,99 @@
+"""
+Multi-resolution AABB mesh-based boundary masker for the Neon backend.
+"""
+
+import warp as wp
+from typing import Any
+from xlb.velocity_set.velocity_set import VelocitySet
+from xlb.precision_policy import PrecisionPolicy
+from xlb.compute_backend import ComputeBackend
+from xlb.operator.boundary_masker import MeshMaskerAABB
+from xlb.operator.operator import Operator
+
+
+class MultiresMeshMaskerAABB(MeshMaskerAABB):
+ """
+ Operator for creating boundary missing_mask from mesh using Axis-Aligned Bounding Box (AABB) voxelization in multiresolution simulations.
+
+ This implementation uses warp.mesh_query_aabb for efficient mesh-voxel intersection testing,
+ providing approximate 1-voxel thick surface detection around the mesh geometry.
+ Suitable for scenarios where fast, approximate boundary detection is sufficient.
+ TODO@Hesam:
+ Right now, we cannot properly mask a mesh file if it lives on any level other than the finest. This issue can be easily solved by adding
+ gx = wp.neon_get_x(cIdx) // 2 ** level
+ gy = wp.neon_get_y(cIdx) // 2 ** level
+ gz = wp.neon_get_z(cIdx) // 2 ** level
+ to the "neon_index_to_warp" and subsequently add "level" to the arguments of "index_to_position_neon", "get_pull_index_neon" and
+ "is_in_bc_indices_neon". In order to extract "level" from the "neon_field_hdl" we can use the function wp.neon_level(neon_field_hdl).
+ """
+
+ def __init__(
+ self,
+ velocity_set: VelocitySet = None,
+ precision_policy: PrecisionPolicy = None,
+ compute_backend: ComputeBackend = None,
+ ):
+ # Call super
+ super().__init__(velocity_set, precision_policy, compute_backend)
+ if self.compute_backend in [ComputeBackend.JAX, ComputeBackend.WARP]:
+ raise NotImplementedError(f"Operator {self.__class__.__name__} not supported in {self.compute_backend} backend.")
+
+ def _construct_neon(self):
+ # Use the warp functional for the NEON backend
+ functional, _ = self._construct_warp()
+
+ @neon.Container.factory(name="MeshMaskerAABB")
+ def container(
+ mesh_id: Any,
+ id_number: Any,
+ distances: Any,
+ bc_mask: Any,
+ missing_mask: Any,
+ needs_mesh_distance: Any,
+ level: Any,
+ ):
+ def aabb_launcher(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask.get_grid(), level)
+ distances_pn = loader.get_mres_write_handle(distances)
+ bc_mask_pn = loader.get_mres_write_handle(bc_mask)
+ missing_mask_pn = loader.get_mres_write_handle(missing_mask)
+
+ @wp.func
+ def aabb_kernel(index: Any):
+ # apply the functional
+ functional(
+ index,
+ mesh_id,
+ id_number,
+ distances_pn,
+ bc_mask_pn,
+ missing_mask_pn,
+ needs_mesh_distance,
+ )
+
+ loader.declare_kernel(aabb_kernel)
+
+ return aabb_launcher
+
+ return functional, container
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(
+ self,
+ bc,
+ distances,
+ bc_mask,
+ missing_mask,
+ stream=0,
+ ):
+ import neon
+
+ # Prepare inputs
+ mesh_id, bc_id = self._prepare_kernel_inputs(bc, bc_mask)
+
+ grid = bc_mask.get_grid()
+ for level in range(grid.num_levels):
+ # Launch the neon container
+ c = self.neon_container(mesh_id, bc_id, distances, bc_mask, missing_mask, wp.static(bc.needs_mesh_distance), level)
+ c.run(stream, container_runtime=neon.Container.ContainerRuntime.neon)
+ return distances, bc_mask, missing_mask
diff --git a/xlb/operator/boundary_masker/multires_aabb_close.py b/xlb/operator/boundary_masker/multires_aabb_close.py
new file mode 100644
index 00000000..df461706
--- /dev/null
+++ b/xlb/operator/boundary_masker/multires_aabb_close.py
@@ -0,0 +1,273 @@
+"""
+Multi-resolution AABB-Close boundary masker with morphological closing.
+
+Extends the AABB-Close masker for Neon multi-resolution grids, applying
+dilate-then-erode operations to fill narrow channels with solid voxels.
+"""
+
+import warp as wp
+from typing import Any
+from xlb.velocity_set.velocity_set import VelocitySet
+from xlb.precision_policy import PrecisionPolicy
+from xlb.compute_backend import ComputeBackend
+from xlb.operator.boundary_masker import MeshMaskerAABBClose
+from xlb.operator.operator import Operator
+from xlb.cell_type import BC_SOLID
+
+
+class MultiresMeshMaskerAABBClose(MeshMaskerAABBClose):
+ """
+ Operator for creating boundary missing_mask from mesh using Axis-Aligned Bounding Box (AABB) voxelization
+ in multiresolution simulations (NEON backend). It takes in a number of close_voxels to perform morphological
+ operations (dilate followed by erode) to ensure small channels are filled with solid voxels.
+
+ This version provides NEON-specific functionals working on multires partitions (mPartition) and bIndex.
+ """
+
+ def __init__(
+ self,
+ velocity_set: VelocitySet = None,
+ precision_policy: PrecisionPolicy = None,
+ compute_backend: ComputeBackend = None,
+ close_voxels: int = None,
+ ):
+ super().__init__(velocity_set, precision_policy, compute_backend, close_voxels)
+ if self.compute_backend in [ComputeBackend.JAX, ComputeBackend.WARP]:
+ raise NotImplementedError(f"Operator {self.__class__.__name__} not supported in {self.compute_backend} backend.")
+
+ # Build and store NEON dicts
+ self.neon_functional_dict, self.neon_container_dict = self._construct_neon()
+
+ def _construct_neon(self):
+ import neon
+
+ # Use the warp functionals from the base (for reference), but implement NEON variants here
+ functional_dict_warp, _ = self._construct_warp()
+ functional_erode_warp = functional_dict_warp.get("functional_erode")
+ functional_dilate_warp = functional_dict_warp.get("functional_dilate")
+ functional_solid = functional_dict_warp.get("functional_solid")
+ # We will not directly reuse functional_solid / functional_aabb from warp; we write NEON-specific ones.
+
+ # We also need lattice info for neighbor iteration
+ _c = self.velocity_set.c
+ _q = self.velocity_set.q
+ _opp_indices = self.velocity_set.opp_indices
+
+ # Set local constants
+ lattice_central_index = self.velocity_set.center_index
+
+ # Main AABB close: sets bc_mask, missing_mask, distances based on solid_mask
+ # bc_mask: wp.uint8, missing_mask: wp.uint8, distances: dtype from precision policy (float)
+ @wp.func
+ def mres_functional_aabb(
+ index: Any,
+ mesh_id: wp.uint64,
+ id_number: wp.int32,
+ distances_pn: Any, # mPartition(dtype=distance type), cardinality=_q
+ bc_mask_pn: Any, # mPartition_uint8, cardinality=1
+ missing_mask_pn: Any, # mPartition_uint8, cardinality=_q
+ solid_mask_pn: Any, # mPartition_uint8, cardinality=1
+ needs_mesh_distance: bool,
+ ):
+ # Cell center from bc_mask partition
+ cell_center = self.helper_masker.index_to_position(bc_mask_pn, index)
+
+ # If already solid or bc, mark solid
+ solid_val = wp.neon_read(solid_mask_pn, index, 0)
+ bc_val = wp.neon_read(bc_mask_pn, index, 0)
+ if solid_val == wp.uint8(BC_SOLID) or bc_val == wp.uint8(BC_SOLID):
+ wp.neon_write(bc_mask_pn, index, 0, wp.uint8(BC_SOLID))
+ return
+
+ # loop lattice directions
+ for direction_idx in range(_q):
+ # skip central if provided by velocity set
+ if direction_idx == lattice_central_index:
+ continue
+
+ # If neighbor index is valid at this resolution level
+ ngh = wp.neon_ngh_idx(wp.int8(_c[0, direction_idx]), wp.int8(_c[1, direction_idx]), wp.int8(_c[2, direction_idx]))
+ is_valid = wp.bool(False)
+ nval = wp.neon_read_ngh(solid_mask_pn, index, ngh, 0, wp.uint8(0), is_valid)
+ if is_valid:
+ if nval == wp.uint8(BC_SOLID):
+ # Found solid neighbor -> boundary cell
+ self.write_field(bc_mask_pn, index, 0, wp.uint8(id_number))
+ self.write_field(missing_mask_pn, index, _opp_indices[direction_idx], wp.uint8(True))
+
+ if not needs_mesh_distance:
+ # No distance needed; continue to next direction
+ continue
+
+ # Compute mesh distance along lattice direction
+ dir_vec = wp.vec3f(
+ wp.float32(_c[0, direction_idx]),
+ wp.float32(_c[1, direction_idx]),
+ wp.float32(_c[2, direction_idx]),
+ )
+ max_length = wp.length(dir_vec)
+ # Avoid division by zero for any pathological dir (shouldn't happen)
+ norm_dir = dir_vec / (max_length if max_length > 0.0 else 1.0)
+ query = wp.mesh_query_ray(mesh_id, cell_center, norm_dir, 1.5 * max_length)
+ if query.result:
+ pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v)
+ dist = wp.length(pos_mesh - cell_center) - 0.5 * max_length
+ weight = dist / (max_length if max_length > 0.0 else 1.0)
+ # distances has cardinality _q; store into this channel
+ self.write_field(distances_pn, index, direction_idx, self.store_dtype(weight))
+ else:
+ self.write_field(distances_pn, index, direction_idx, self.store_dtype(1.0))
+
+ # Containers
+
+ # Erode: f_field -> f_field_out
+ @neon.Container.factory(name="Erode")
+ def container_erode(f_field: wp.array3d(dtype=Any), f_field_out: wp.array3d(dtype=Any), level: int):
+ def erode_launcher(loader: neon.Loader):
+ loader.set_mres_grid(f_field.get_grid(), level)
+ f_field_pn = loader.get_mres_read_handle(f_field)
+ f_field_out_pn = loader.get_mres_write_handle(f_field_out)
+
+ @wp.func
+ def erode_kernel(index: Any):
+ functional_erode_warp(index, f_field_pn, f_field_out_pn)
+
+ loader.declare_kernel(erode_kernel)
+
+ return erode_launcher
+
+ # Dilate: f_field -> f_field_out
+ @neon.Container.factory(name="Dilate")
+ def container_dilate(f_field: wp.array3d(dtype=Any), f_field_out: wp.array3d(dtype=Any), level: int):
+ def dilate_launcher(loader: neon.Loader):
+ loader.set_mres_grid(f_field.get_grid(), level)
+ f_field_pn = loader.get_mres_read_handle(f_field)
+ f_field_out_pn = loader.get_mres_write_handle(f_field_out)
+
+ @wp.func
+ def dilate_kernel(index: Any):
+ functional_dilate_warp(index, f_field_pn, f_field_out_pn)
+
+ loader.declare_kernel(dilate_kernel)
+
+ return dilate_launcher
+
+ # Solid mask: voxelize mesh into solid_mask
+ @neon.Container.factory(name="Solid")
+ def container_solid(mesh_id: wp.uint64, solid_mask: wp.array3d(dtype=wp.uint8), level: int):
+ def solid_launcher(loader: neon.Loader):
+ loader.set_mres_grid(solid_mask.get_grid(), level)
+ solid_mask_pn = loader.get_mres_write_handle(solid_mask)
+
+ @wp.func
+ def solid_kernel(index: Any):
+ # apply the functional
+ functional_solid(index, mesh_id, solid_mask_pn, wp.vec3f(0.0, 0.0, 0.0))
+
+ loader.declare_kernel(solid_kernel)
+
+ return solid_launcher
+
+ # Main AABB container
+ @neon.Container.factory(name="MeshMaskerAABBClose")
+ def container(
+ mesh_id: Any,
+ id_number: Any,
+ distances: Any,
+ bc_mask: Any,
+ missing_mask: Any,
+ solid_mask: Any,
+ needs_mesh_distance: Any,
+ level: Any,
+ ):
+ def aabb_launcher(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask.get_grid(), level)
+ distances_pn = loader.get_mres_write_handle(distances)
+ bc_mask_pn = loader.get_mres_write_handle(bc_mask)
+ missing_mask_pn = loader.get_mres_write_handle(missing_mask)
+ solid_mask_pn = loader.get_mres_write_handle(solid_mask)
+
+ @wp.func
+ def aabb_kernel(index: Any):
+ mres_functional_aabb(
+ index,
+ mesh_id,
+ id_number,
+ distances_pn,
+ bc_mask_pn,
+ missing_mask_pn,
+ solid_mask_pn,
+ needs_mesh_distance,
+ )
+
+ loader.declare_kernel(aabb_kernel)
+
+ return aabb_launcher
+
+ container_dict = {
+ "container_erode": container_erode,
+ "container_dilate": container_dilate,
+ "container_solid": container_solid,
+ "container_aabb": container,
+ }
+
+ # Expose NEON functionals too (in case callers want to reuse)
+ functional_dict = {
+ "mres_functional_aabb": mres_functional_aabb,
+ }
+
+ return functional_dict, container_dict
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(
+ self,
+ bc,
+ distances,
+ bc_mask,
+ missing_mask,
+ stream=0,
+ ):
+ # Prepare inputs
+ mesh_id, bc_id = self._prepare_kernel_inputs(bc, bc_mask)
+
+ grid = bc_mask.get_grid()
+ # Create fields using new_field
+ solid_mask = grid.new_field(cardinality=1, dtype=wp.uint8, memory_type=neon.MemoryType.device())
+ solid_mask_out = grid.new_field(
+ cardinality=1,
+ dtype=wp.uint8,
+ memory_type=neon.MemoryType.device(),
+ # memory_type=neon.MemoryType.host_device()
+ )
+
+ for level in range(grid.num_levels):
+ # Initialize to 0
+ solid_mask.fill_run(level=level, value=wp.uint8(0), stream_idx=stream)
+ solid_mask_out.fill_run(level=level, value=wp.uint8(0), stream_idx=stream)
+
+ # Launch the neon containers
+ container_solid = self.neon_container_dict["container_solid"](mesh_id, solid_mask, level)
+ container_solid.run(0, container_runtime=neon.Container.ContainerRuntime.neon)
+
+ for _ in range(self.close_voxels):
+ container_dilate = self.neon_container_dict["container_dilate"](solid_mask, solid_mask_out, level)
+ container_dilate.run(0, container_runtime=neon.Container.ContainerRuntime.neon)
+ solid_mask, solid_mask_out = solid_mask_out, solid_mask
+
+ if self.close_voxels % 2 > 0:
+ solid_mask, solid_mask_out = solid_mask_out, solid_mask
+
+ for _ in range(self.close_voxels):
+ container_erode = self.neon_container_dict["container_erode"](solid_mask_out, solid_mask, level)
+ container_erode.run(0, container_runtime=neon.Container.ContainerRuntime.neon)
+ solid_mask, solid_mask_out = solid_mask_out, solid_mask
+
+ if self.close_voxels % 2 > 0:
+ solid_mask, solid_mask_out = solid_mask_out, solid_mask
+
+ container_aabb = self.neon_container_dict["container_aabb"](
+ mesh_id, bc_id, distances, bc_mask, missing_mask, solid_mask, wp.static(bc.needs_mesh_distance), level
+ )
+ container_aabb.run(0, container_runtime=neon.Container.ContainerRuntime.neon)
+
+ return distances, bc_mask, missing_mask
diff --git a/xlb/operator/boundary_masker/multires_indices_boundary_masker.py b/xlb/operator/boundary_masker/multires_indices_boundary_masker.py
new file mode 100644
index 00000000..bf7fc1d7
--- /dev/null
+++ b/xlb/operator/boundary_masker/multires_indices_boundary_masker.py
@@ -0,0 +1,210 @@
+"""
+Multi-resolution indices-based boundary masker for the Neon backend.
+
+Creates boundary masks from explicit voxel indices on multi-resolution
+grids, computing missing-population masks for each tagged voxel.
+"""
+
+from typing import Any
+import copy
+import numpy as np
+
+import warp as wp
+
+from xlb.operator.operator import Operator
+from xlb.velocity_set.velocity_set import VelocitySet
+from xlb.precision_policy import PrecisionPolicy
+from xlb.compute_backend import ComputeBackend
+from xlb.operator.boundary_masker import IndicesBoundaryMasker
+
+
+class MultiresIndicesBoundaryMasker(IndicesBoundaryMasker):
+ """
+ Operator for creating a boundary mask using indices of boundary conditions in a multi-resolution setting.
+ """
+
+ def __init__(
+ self,
+ velocity_set: VelocitySet = None,
+ precision_policy: PrecisionPolicy = None,
+ compute_backend: ComputeBackend = None,
+ ):
+ # Call super
+ super().__init__(velocity_set, precision_policy, compute_backend)
+ if self.compute_backend in [ComputeBackend.JAX, ComputeBackend.WARP]:
+ raise NotImplementedError(f"Operator {self.__class__.__name__} not supported in {self.compute_backend} backend.")
+
+ def _construct_neon(self):
+ # Use the warp functional for the NEON backend
+ functional_dict, _ = self._construct_warp()
+ functional_domain_bounds = functional_dict.get("functional_domain_bounds")
+ functional_interior_bc_mask = functional_dict.get("functional_interior_bc_mask")
+ functional_interior_missing_mask = functional_dict.get("functional_interior_missing_mask")
+
+ @neon.Container.factory(name="IndicesBoundaryMasker_DomainBounds")
+ def container_domain_bounds(
+ wp_bc_indices,
+ wp_id_numbers,
+ wp_is_interior,
+ bc_mask,
+ missing_mask,
+ grid_shape,
+ level,
+ ):
+ def domain_bounds_launcher(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask.get_grid(), level)
+ bc_mask_pn = loader.get_mres_write_handle(bc_mask)
+ missing_mask_pn = loader.get_mres_write_handle(missing_mask)
+
+ @wp.func
+ def domain_bounds_kernel(index: Any):
+ # apply the functional
+ functional_domain_bounds(
+ index,
+ wp_bc_indices,
+ wp_id_numbers,
+ wp_is_interior,
+ bc_mask_pn,
+ missing_mask_pn,
+ grid_shape,
+ level,
+ )
+
+ loader.declare_kernel(domain_bounds_kernel)
+
+ return domain_bounds_launcher
+
+ @neon.Container.factory(name="IndicesBoundaryMasker_InteriorBcMask")
+ def container_interior_bc_mask(
+ wp_bc_indices,
+ wp_id_numbers,
+ bc_mask,
+ level,
+ ):
+ def interior_bc_mask_launcher(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask.get_grid(), level)
+ bc_mask_pn = loader.get_mres_write_handle(bc_mask)
+
+ @wp.func
+ def interior_bc_mask_kernel(index: Any):
+ # apply the functional
+ functional_interior_bc_mask(
+ index,
+ wp_bc_indices,
+ wp_id_numbers,
+ bc_mask_pn,
+ )
+
+ loader.declare_kernel(interior_bc_mask_kernel)
+
+ return interior_bc_mask_launcher
+
+ @neon.Container.factory(name="IndicesBoundaryMasker_InteriorMissingMask")
+ def container_interior_missing_mask(
+ wp_bc_indices,
+ bc_mask,
+ missing_mask,
+ grid_shape,
+ level,
+ ):
+ def interior_bc_mask_launcher(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask.get_grid(), level)
+ bc_mask_pn = loader.get_mres_write_handle(bc_mask)
+ missing_mask_pn = loader.get_mres_write_handle(missing_mask)
+
+ @wp.func
+ def interior_missing_mask_kernel(index: Any):
+ # apply the functional
+ functional_interior_missing_mask(
+ index,
+ wp_bc_indices,
+ bc_mask_pn,
+ missing_mask_pn,
+ grid_shape,
+ level,
+ )
+
+ loader.declare_kernel(interior_missing_mask_kernel)
+
+ return interior_bc_mask_launcher
+
+ container_dict = {
+ "container_domain_bounds": container_domain_bounds,
+ "container_interior_bc_mask": container_interior_bc_mask,
+ "container_interior_missing_mask": container_interior_missing_mask,
+ }
+
+ return functional_dict, container_dict
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, bclist, bc_mask, missing_mask, start_index=None):
+ import neon
+
+ grid = bc_mask.get_grid()
+ num_levels = grid.num_levels
+ grid_shape_finest = self.helper_masker.get_grid_shape(bc_mask)
+ for level in range(num_levels):
+ # Create a copy of the boundary condition list for the current level if the indices at that level are not empty
+ bclist_at_level = []
+ for bc in bclist:
+ if bc.indices is not None and bc.indices[level]:
+ bc_copy = copy.copy(bc) # shallow copy of the whole object
+ indices = copy.deepcopy(bc.indices[level]) # deep copy only the modified part
+ indices = np.array(indices) * 2**level # TODO: This is a hack
+ bc_copy.indices = tuple(indices.tolist()) # convert to tuple
+ bclist_at_level.append(bc_copy)
+
+ # If the boundary condition list is empty, skip to the next level
+ if not bclist_at_level:
+ continue
+
+ # find grid shape at current level
+ # TODO: this is a hack. Should be corrected in the helper function when getting neon global indices
+ grid_shape_at_level = tuple([shape // 2**level for shape in grid_shape_finest])
+ grid_shape_finest_warp = wp.vec3i(*grid_shape_finest)
+
+ # find interior boundary conditions
+ bc_interior = self._find_bclist_interior(bclist_at_level, grid_shape_at_level)
+
+ # Prepare the first kernel inputs for all items in boundary condition list
+ wp_bc_indices, wp_id_numbers, wp_is_interior = self._prepare_kernel_inputs(bclist_at_level, grid_shape_at_level)
+
+ # Launch the first container
+ container_domain_bounds = self.neon_container["container_domain_bounds"](
+ wp_bc_indices,
+ wp_id_numbers,
+ wp_is_interior,
+ bc_mask,
+ missing_mask,
+ grid_shape_finest_warp,
+ level,
+ )
+ container_domain_bounds.run(0, container_runtime=neon.Container.ContainerRuntime.neon)
+
+ # If there are no interior boundary conditions, skip the rest of the processing for this level
+ if not bc_interior:
+ continue
+
+ # Prepare the second and third kernel inputs for only a subset of boundary conditions associated with the interior
+ # Note 1: launching order of the following kernels are important here!
+ # Note 2: Due to race conditioning, the two kernels cannot be fused together.
+ wp_bc_indices, wp_id_numbers, _ = self._prepare_kernel_inputs(bc_interior, grid_shape_at_level)
+ container_interior_missing_mask = self.neon_container["container_interior_missing_mask"](
+ wp_bc_indices,
+ bc_mask,
+ missing_mask,
+ grid_shape_finest_warp,
+ level,
+ )
+ container_interior_missing_mask.run(0, container_runtime=neon.Container.ContainerRuntime.neon)
+
+ # Launch the third container
+ container_interior_bc_mask = self.neon_container["container_interior_bc_mask"](
+ wp_bc_indices,
+ wp_id_numbers,
+ bc_mask,
+ level,
+ )
+ container_interior_bc_mask.run(0, container_runtime=neon.Container.ContainerRuntime.neon)
+
+ return bc_mask, missing_mask
diff --git a/xlb/operator/boundary_masker/multires_ray.py b/xlb/operator/boundary_masker/multires_ray.py
new file mode 100644
index 00000000..9a5a01d8
--- /dev/null
+++ b/xlb/operator/boundary_masker/multires_ray.py
@@ -0,0 +1,90 @@
+"""
+Multi-resolution ray-cast mesh-based boundary masker for the Neon backend.
+"""
+
+import warp as wp
+from typing import Any
+from xlb.velocity_set.velocity_set import VelocitySet
+from xlb.precision_policy import PrecisionPolicy
+from xlb.compute_backend import ComputeBackend
+from xlb.operator.boundary_masker import MeshMaskerRay
+from xlb.operator.operator import Operator
+
+
+class MultiresMeshMaskerRay(MeshMaskerRay):
+ """
+ Operator for creating a boundary missing_mask from an STL file in multiresolution simulations.
+
+ This implementation uses warp.mesh_query_ray for efficient mesh-voxel intersection testing.
+ """
+
+ def __init__(
+ self,
+ velocity_set: VelocitySet = None,
+ precision_policy: PrecisionPolicy = None,
+ compute_backend: ComputeBackend = None,
+ ):
+ # Call super
+ super().__init__(velocity_set, precision_policy, compute_backend)
+ if self.compute_backend in [ComputeBackend.JAX, ComputeBackend.WARP]:
+ raise NotImplementedError(f"Operator {self.__class__.__name__} not supported in {self.compute_backend} backend.")
+
+ def _construct_neon(self):
+ # Use the warp functional for the NEON backend
+ functional, _ = self._construct_warp()
+
+ @neon.Container.factory(name="MeshMaskerRay")
+ def container(
+ mesh_id: Any,
+ id_number: Any,
+ distances: Any,
+ bc_mask: Any,
+ missing_mask: Any,
+ needs_mesh_distance: Any,
+ level: Any,
+ ):
+ def ray_launcher(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask.get_grid(), level)
+ distances_pn = loader.get_mres_write_handle(distances)
+ bc_mask_pn = loader.get_mres_write_handle(bc_mask)
+ missing_mask_pn = loader.get_mres_write_handle(missing_mask)
+
+ @wp.func
+ def ray_kernel(index: Any):
+ # apply the functional
+ functional(
+ index,
+ mesh_id,
+ id_number,
+ distances_pn,
+ bc_mask_pn,
+ missing_mask_pn,
+ needs_mesh_distance,
+ )
+
+ loader.declare_kernel(ray_kernel)
+
+ return ray_launcher
+
+ return functional, container
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(
+ self,
+ bc,
+ distances,
+ bc_mask,
+ missing_mask,
+ stream=0,
+ ):
+ import neon
+
+ # Prepare inputs
+ mesh_id, bc_id = self._prepare_kernel_inputs(bc, bc_mask)
+
+ grid = bc_mask.get_grid()
+ for level in range(grid.num_levels):
+ # Launch the neon container
+ c = self.neon_container(mesh_id, bc_id, distances, bc_mask, missing_mask, wp.static(bc.needs_mesh_distance), level)
+ c.run(stream, container_runtime=neon.Container.ContainerRuntime.neon)
+ return distances, bc_mask, missing_mask
diff --git a/xlb/operator/boundary_masker/ray.py b/xlb/operator/boundary_masker/ray.py
new file mode 100644
index 00000000..f5a2d96b
--- /dev/null
+++ b/xlb/operator/boundary_masker/ray.py
@@ -0,0 +1,175 @@
+"""
+Ray-cast mesh-based boundary masker.
+
+Voxelizes a mesh file by casting rays along each lattice direction using
+``warp.mesh_query_ray`` to detect surface crossings.
+"""
+
+import warp as wp
+from typing import Any
+from xlb.velocity_set.velocity_set import VelocitySet
+from xlb.precision_policy import PrecisionPolicy
+from xlb.compute_backend import ComputeBackend
+from xlb.operator.boundary_masker.mesh_boundary_masker import MeshBoundaryMasker
+from xlb.operator.operator import Operator
+
+
+class MeshMaskerRay(MeshBoundaryMasker):
+ """
+ Operator for creating a boundary missing_mask from a mesh file
+ """
+
+ def __init__(
+ self,
+ velocity_set: VelocitySet = None,
+ precision_policy: PrecisionPolicy = None,
+ compute_backend: ComputeBackend = None,
+ ):
+ # Call super
+ super().__init__(velocity_set, precision_policy, compute_backend)
+
+ def _construct_warp(self):
+ # Make constants for warp
+ _c = self.velocity_set.c
+ _q = self.velocity_set.q
+ _opp_indices = self.velocity_set.opp_indices
+
+ # Set local constants
+ lattice_central_index = self.velocity_set.center_index
+
+ @wp.func
+ def functional(
+ index: Any,
+ mesh_id: Any,
+ id_number: Any,
+ distances: Any,
+ bc_mask: Any,
+ missing_mask: Any,
+ needs_mesh_distance: Any,
+ ):
+ # position of the point
+ cell_center_pos = self.helper_masker.index_to_position(bc_mask, index)
+
+ # Find the fractional distance to the mesh in each direction
+ for direction_idx in range(_q):
+ if direction_idx == lattice_central_index:
+ # Skip the central index as it is not relevant for boundary masking
+ continue
+
+ direction_vec = wp.vec3f(wp.float32(_c[0, direction_idx]), wp.float32(_c[1, direction_idx]), wp.float32(_c[2, direction_idx]))
+ # Max length depends on ray direction (diagonals are longer)
+ max_length = wp.length(direction_vec)
+ query = wp.mesh_query_ray(mesh_id, cell_center_pos, direction_vec / max_length, max_length)
+ if query.result:
+ # Set the boundary id and missing_mask
+ self.write_field(bc_mask, index, 0, wp.uint8(id_number))
+ self.write_field(missing_mask, index, _opp_indices[direction_idx], wp.uint8(True))
+
+ # If we don't need the mesh distance, we can return early
+ if not needs_mesh_distance:
+ continue
+
+ # get position of the mesh triangle that intersects with the ray
+ pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v)
+ dist = wp.length(pos_mesh - cell_center_pos)
+ weight = self.store_dtype(dist / max_length)
+ self.write_field(distances, index, direction_idx, self.store_dtype(weight))
+
+ @wp.kernel
+ def kernel(
+ mesh_id: wp.uint64,
+ id_number: wp.int32,
+ distances: wp.array4d(dtype=Any),
+ bc_mask: wp.array4d(dtype=wp.uint8),
+ missing_mask: wp.array4d(dtype=wp.uint8),
+ needs_mesh_distance: bool,
+ ):
+ # get index
+ i, j, k = wp.tid()
+
+ # Get local indices
+ index = wp.vec3i(i, j, k)
+
+ # apply the functional
+ functional(
+ index,
+ mesh_id,
+ id_number,
+ distances,
+ bc_mask,
+ missing_mask,
+ needs_mesh_distance,
+ )
+
+ return functional, kernel
+
+ @Operator.register_backend(ComputeBackend.WARP)
+ def warp_implementation(
+ self,
+ bc,
+ distances,
+ bc_mask,
+ missing_mask,
+ ):
+ return self.warp_implementation_base(
+ bc,
+ distances,
+ bc_mask,
+ missing_mask,
+ )
+
+ def _construct_neon(self):
+ import neon
+
+ # Use the warp functional for the NEON backend
+ functional, _ = self._construct_warp()
+
+ @neon.Container.factory(name="MeshMaskerRay")
+ def container(
+ mesh_id: Any,
+ id_number: Any,
+ distances: Any,
+ bc_mask: Any,
+ missing_mask: Any,
+ needs_mesh_distance: Any,
+ ):
+ def ray_launcher(loader: neon.Loader):
+ loader.set_grid(bc_mask.get_grid())
+ bc_mask_pn = loader.get_write_handle(bc_mask)
+ missing_mask_pn = loader.get_write_handle(missing_mask)
+ distances_pn = loader.get_write_handle(distances)
+
+ @wp.func
+ def ray_kernel(index: Any):
+ # apply the functional
+ functional(
+ index,
+ mesh_id,
+ id_number,
+ distances_pn,
+ bc_mask_pn,
+ missing_mask_pn,
+ needs_mesh_distance,
+ )
+
+ loader.declare_kernel(ray_kernel)
+
+ return ray_launcher
+
+ return functional, container
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(
+ self,
+ bc,
+ distances,
+ bc_mask,
+ missing_mask,
+ ):
+ # Prepare inputs
+ mesh_id, bc_id = self._prepare_kernel_inputs(bc, bc_mask)
+
+ # Launch the appropriate neon container
+ c = self.neon_container(mesh_id, bc_id, distances, bc_mask, missing_mask, wp.static(bc.needs_mesh_distance))
+ c.run(0, container_runtime=neon.Container.ContainerRuntime.neon)
+ return distances, bc_mask, missing_mask
diff --git a/xlb/operator/boundary_masker/winding.py b/xlb/operator/boundary_masker/winding.py
new file mode 100644
index 00000000..1510f3a5
--- /dev/null
+++ b/xlb/operator/boundary_masker/winding.py
@@ -0,0 +1,115 @@
+"""
+Winding-number mesh-based boundary masker.
+
+Uses the generalized winding-number test (``warp.mesh_query_point``) to
+classify voxels as inside or outside the mesh, providing a
+solid-detection method even for non-watertight geometries.
+"""
+
+import warp as wp
+from typing import Any
+from xlb.velocity_set.velocity_set import VelocitySet
+from xlb.precision_policy import PrecisionPolicy
+from xlb.compute_backend import ComputeBackend
+from xlb.operator.boundary_masker.mesh_boundary_masker import MeshBoundaryMasker
+from xlb.operator.operator import Operator
+from xlb.cell_type import BC_SOLID
+
+
+class MeshMaskerWinding(MeshBoundaryMasker):
+ """
+ Operator for creating a boundary missing_mask from a mesh file
+ """
+
+ def __init__(
+ self,
+ velocity_set: VelocitySet = None,
+ precision_policy: PrecisionPolicy = None,
+ compute_backend: ComputeBackend = None,
+ ):
+ # Call super
+ super().__init__(velocity_set, precision_policy, compute_backend)
+ assert self.compute_backend != ComputeBackend.NEON, (
+ 'MeshVoxelizationMethod("WINDING") is not implemented in Neon yet! Please use a different method of mesh voxelization!'
+ )
+
+ def _construct_warp(self):
+ # Make constants for warp
+ _c = self.velocity_set.c
+ _q = self.velocity_set.q
+ _opp_indices = self.velocity_set.opp_indices
+
+ @wp.kernel
+ def kernel(
+ mesh_id: wp.uint64,
+ id_number: wp.int32,
+ distances: wp.array4d(dtype=Any),
+ bc_mask: wp.array4d(dtype=wp.uint8),
+ missing_mask: wp.array4d(dtype=wp.uint8),
+ needs_mesh_distance: bool,
+ ):
+ # get index
+ i, j, k = wp.tid()
+
+ # Get local indices
+ index = wp.vec3i(i, j, k)
+
+ # position of the point
+ pos_cell = self.helper_masker.index_to_position(bc_mask, index)
+
+ # Compute the maximum length
+ max_length = wp.sqrt(
+ (wp.float32(bc_mask.shape[1])) ** 2.0 + (wp.float32(bc_mask.shape[2])) ** 2.0 + (wp.float32(bc_mask.shape[3])) ** 2.0
+ )
+
+ # evaluate if point is inside mesh
+ query = wp.mesh_query_point_sign_winding_number(mesh_id, pos_cell, max_length)
+ if query.result:
+ # set point to be solid
+ if query.sign <= 0: # TODO: fix this
+ # Make solid voxel
+ bc_mask[0, index[0], index[1], index[2]] = wp.uint8(BC_SOLID)
+
+ # Find the fractional distance to the mesh in each direction
+ for direction_idx in range(1, _q):
+ direction_vec = wp.vec3f(wp.float32(_c[0, direction_idx]), wp.float32(_c[1, direction_idx]), wp.float32(_c[2, direction_idx]))
+ # Max length depends on ray direction (diagonals are longer)
+ max_length = wp.length(direction_vec)
+ query_dir = wp.mesh_query_ray(mesh_id, pos_cell, direction_vec / max_length, max_length)
+ if query_dir.result:
+ # Get the index of the streaming direction
+ push_index = wp.vec3i()
+ for d in range(self.velocity_set.d):
+ push_index[d] = index[d] + _c[d, direction_idx]
+
+ # Set the boundary id and missing_mask
+ bc_mask[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number)
+ missing_mask[direction_idx, push_index[0], push_index[1], push_index[2]] = wp.uint8(True)
+
+ # If we don't need the mesh distance, we can return early
+ if not needs_mesh_distance:
+ continue
+
+ # get position of the mesh triangle that intersects with the ray
+ pos_mesh = wp.mesh_eval_position(mesh_id, query_dir.face, query_dir.u, query_dir.v)
+ cell_center_pos = self.helper_masker.index_to_position(bc_mask, push_index)
+ dist = wp.length(pos_mesh - cell_center_pos)
+ weight = self.store_dtype(dist / max_length)
+ distances[_opp_indices[direction_idx], push_index[0], push_index[1], push_index[2]] = weight
+
+ return None, kernel
+
+ @Operator.register_backend(ComputeBackend.WARP)
+ def warp_implementation(
+ self,
+ bc,
+ distances,
+ bc_mask,
+ missing_mask,
+ ):
+ return self.warp_implementation_base(
+ bc,
+ distances,
+ bc_mask,
+ missing_mask,
+ )
diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py
index ac2da2e0..29331dc7 100644
--- a/xlb/operator/collision/bgk.py
+++ b/xlb/operator/collision/bgk.py
@@ -1,3 +1,7 @@
+"""
+Bhatnagar-Gross-Krook (BGK) single-relaxation-time collision operator.
+"""
+
import jax.numpy as jnp
from jax import jit
import warp as wp
@@ -10,13 +14,19 @@
class BGK(Collision):
- """
- BGK collision operator for LBM.
+ """Single-relaxation-time BGK collision operator.
+
+ Relaxes the distribution function toward equilibrium at a rate
+ controlled by the relaxation parameter *omega*::
+
+ f_out = f - omega * (f - f_eq)
+
+ Supports JAX, Warp, and Neon backends.
"""
@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0,))
- def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u, omega):
+ def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, omega):
fneq = f - feq
fout = f - self.compute_dtype(omega) * fneq
return fout
@@ -28,7 +38,7 @@ def _construct_warp(self):
# Construct the functional
@wp.func
- def functional(f: Any, feq: Any, rho: Any, u: Any, omega: Any):
+ def functional(f: Any, feq: Any, omega: Any):
fneq = f - feq
fout = f - self.compute_dtype(omega) * fneq
return fout
@@ -39,8 +49,6 @@ def kernel(
f: wp.array4d(dtype=Any),
feq: wp.array4d(dtype=Any),
fout: wp.array4d(dtype=Any),
- rho: wp.array4d(dtype=Any),
- u: wp.array4d(dtype=Any),
omega: Any,
):
# Get the global index
@@ -55,7 +63,7 @@ def kernel(
_feq[l] = feq[l, index[0], index[1], index[2]]
# Compute the collision
- _fout = functional(_f, _feq, rho, u, omega)
+ _fout = functional(_f, _feq, omega)
# Write the result
for l in range(self.velocity_set.q):
@@ -63,8 +71,12 @@ def kernel(
return functional, kernel
+ def _construct_neon(self):
+ functional, _ = self._construct_warp()
+ return functional, None
+
@Operator.register_backend(ComputeBackend.WARP)
- def warp_implementation(self, f, feq, fout, rho, u, omega):
+ def warp_implementation(self, f, feq, fout, omega):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
@@ -72,8 +84,6 @@ def warp_implementation(self, f, feq, fout, rho, u, omega):
f,
feq,
fout,
- rho,
- u,
omega,
],
dim=f.shape[1:],
diff --git a/xlb/operator/collision/forced_collision.py b/xlb/operator/collision/forced_collision.py
index 80c9b0b2..55fa631c 100644
--- a/xlb/operator/collision/forced_collision.py
+++ b/xlb/operator/collision/forced_collision.py
@@ -1,3 +1,7 @@
+"""
+Collision operator with external body-force correction.
+"""
+
import jax.numpy as jnp
from jax import jit
import warp as wp
@@ -11,8 +15,19 @@
class ForcedCollision(Collision):
- """
- A collision operator for LBM with external force.
+ """Collision operator that wraps another collision with a forcing term.
+
+ After the inner collision the forcing operator is applied to
+ incorporate the effect of an external body force.
+
+ Parameters
+ ----------
+ collision_operator : Operator
+ The base collision operator (e.g. :class:`BGK`).
+ forcing_scheme : str
+ Forcing scheme. Currently only ``"exact_difference"`` is supported.
+ force_vector : array-like
+ External force vector of length ``d`` (number of spatial dimensions).
"""
def __init__(
@@ -33,9 +48,9 @@ def __init__(
@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0,))
- def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u, omega):
- fout = self.collision_operator(f, feq, rho, u, omega)
- fout = self.forcing_operator(fout, feq, rho, u)
+ def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, omega):
+ fout = self.collision_operator(f, feq, omega)
+ fout = self.forcing_operator(fout, feq)
return fout
def _construct_warp(self):
@@ -45,9 +60,9 @@ def _construct_warp(self):
# Construct the functional
@wp.func
- def functional(f: Any, feq: Any, rho: Any, u: Any, omega: Any):
- fout = self.collision_operator.warp_functional(f, feq, rho, u, omega)
- fout = self.forcing_operator.warp_functional(fout, feq, rho, u)
+ def functional(f: Any, feq: Any, omega: Any):
+ fout = self.collision_operator.warp_functional(f, feq, omega)
+ fout = self.forcing_operator.warp_functional(fout, feq)
return fout
# Construct the warp kernel
@@ -56,8 +71,6 @@ def kernel(
f: wp.array4d(dtype=Any),
feq: wp.array4d(dtype=Any),
fout: wp.array4d(dtype=Any),
- rho: wp.array4d(dtype=Any),
- u: wp.array4d(dtype=Any),
omega: Any,
):
# Get the global index
@@ -71,13 +84,9 @@ def kernel(
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1], index[2]]
_feq[l] = feq[l, index[0], index[1], index[2]]
- _u = _u_vec()
- for l in range(_d):
- _u[l] = u[l, index[0], index[1], index[2]]
- _rho = rho[0, index[0], index[1], index[2]]
# Compute the collision
- _fout = functional(_f, _feq, _rho, _u, omega)
+ _fout = functional(_f, _feq, omega)
# Write the result
for l in range(self.velocity_set.q):
@@ -86,7 +95,7 @@ def kernel(
return functional, kernel
@Operator.register_backend(ComputeBackend.WARP)
- def warp_implementation(self, f, feq, fout, rho, u, omega):
+ def warp_implementation(self, f, feq, fout, omega):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
@@ -94,8 +103,6 @@ def warp_implementation(self, f, feq, fout, rho, u, omega):
f,
feq,
fout,
- rho,
- u,
omega,
],
dim=f.shape[1:],
diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py
index 7724e07e..d814a7cf 100644
--- a/xlb/operator/collision/kbc.py
+++ b/xlb/operator/collision/kbc.py
@@ -43,8 +43,6 @@ def jax_implementation(
self,
f: jnp.ndarray,
feq: jnp.ndarray,
- rho: jnp.ndarray,
- u: jnp.ndarray,
omega,
):
"""
@@ -56,18 +54,14 @@ def jax_implementation(
Distribution function.
feq : jax.numpy.array
Equilibrium distribution function.
- rho : jax.numpy.array
- Density.
- u : jax.numpy.array
- Velocity.
"""
fneq = f - feq
if isinstance(self.velocity_set, D2Q9):
shear = self.decompose_shear_d2q9_jax(fneq)
- delta_s = shear * rho / 4.0
+ delta_s = shear / 4.0
elif isinstance(self.velocity_set, D3Q27):
shear = self.decompose_shear_d3q27_jax(fneq)
- delta_s = shear * rho
+ delta_s = shear
else:
raise NotImplementedError("Velocity set not supported: {}".format(type(self.velocity_set)))
@@ -269,18 +263,16 @@ def compute_entropic_scalar_products(
def functional(
f: Any,
feq: Any,
- rho: Any,
- u: Any,
omega: Any,
):
# Compute shear and delta_s
fneq = f - feq
if wp.static(self.velocity_set.d == 3):
shear = decompose_shear_d3q27(fneq)
- delta_s = shear * rho
+ delta_s = shear
else:
shear = decompose_shear_d2q9(fneq)
- delta_s = shear * rho / self.compute_dtype(4.0)
+ delta_s = shear / self.compute_dtype(4.0)
# Compute required constants based on the input omega (omega is the inverse relaxation time)
_beta = self.compute_dtype(0.5) * self.compute_dtype(omega)
@@ -301,8 +293,6 @@ def kernel(
f: wp.array4d(dtype=Any),
feq: wp.array4d(dtype=Any),
fout: wp.array4d(dtype=Any),
- rho: wp.array4d(dtype=Any),
- u: wp.array4d(dtype=Any),
omega: Any,
):
# Get the global index
@@ -316,13 +306,9 @@ def kernel(
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1], index[2]]
_feq[l] = feq[l, index[0], index[1], index[2]]
- _u = _u_vec()
- for l in range(_d):
- _u[l] = u[l, index[0], index[1], index[2]]
- _rho = rho[0, index[0], index[1], index[2]]
# Compute the collision
- _fout = functional(_f, _feq, _rho, _u, omega)
+ _fout = functional(_f, _feq, omega)
# Write the result
for l in range(self.velocity_set.q):
@@ -330,8 +316,15 @@ def kernel(
return functional, kernel
+ def _construct_neon(self):
+ # Redefine the momentum flux operator for the neon backend
+ # This is because the neon backend relies on the warp functionals for its operations.
+ self.momentum_flux = MomentumFlux(compute_backend=ComputeBackend.WARP)
+ functional, _ = self._construct_warp()
+ return functional, None
+
@Operator.register_backend(ComputeBackend.WARP)
- def warp_implementation(self, f, feq, fout, rho, u, omega):
+ def warp_implementation(self, f, feq, fout, omega):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
@@ -339,8 +332,6 @@ def warp_implementation(self, f, feq, fout, rho, u, omega):
f,
feq,
fout,
- rho,
- u,
omega,
],
dim=f.shape[1:],
diff --git a/xlb/operator/collision/smagorinsky_les_bgk.py b/xlb/operator/collision/smagorinsky_les_bgk.py
index 92ae23d2..aa552094 100644
--- a/xlb/operator/collision/smagorinsky_les_bgk.py
+++ b/xlb/operator/collision/smagorinsky_les_bgk.py
@@ -1,3 +1,7 @@
+"""
+BGK collision operator with Smagorinsky large-eddy-simulation sub-grid model.
+"""
+
import jax.numpy as jnp
from jax import jit
import warp as wp
@@ -12,8 +16,19 @@
class SmagorinskyLESBGK(Collision):
- """
- BGK collision operator for LBM with Smagorinsky LES model.
+ """BGK collision with Smagorinsky LES turbulence modelling.
+
+ Adjusts the effective relaxation time based on the local strain rate
+ estimated from the non-equilibrium stress tensor, using the
+ Smagorinsky model constant *C_s*.
+
+ Parameters
+ ----------
+ velocity_set : VelocitySet, optional
+ precision_policy : PrecisionPolicy, optional
+ compute_backend : ComputeBackend, optional
+ smagorinsky_coef : float
+ Smagorinsky model constant (default 0.17).
"""
def __init__(
@@ -28,7 +43,7 @@ def __init__(
@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0,))
- def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray, u: jnp.ndarray, omega):
+ def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, omega):
fneq = f - feq
pi_neq = jnp.tensordot(self.velocity_set.cc, fneq, axes=(0, 0))
@@ -44,9 +59,7 @@ def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray,
tau0 = self.compute_dtype(1.0) / self.compute_dtype(omega)
cs = self.compute_dtype(self.smagorinsky_coef)
- tau = self.compute_dtype(0.5) * (
- tau0 + jnp.sqrt(tau0 * tau0 + self.compute_dtype(36.0) * (cs * cs) * jnp.sqrt(strain))
- )
+ tau = self.compute_dtype(0.5) * (tau0 + jnp.sqrt(tau0 * tau0 + self.compute_dtype(36.0) * (cs * cs) * jnp.sqrt(strain)))
omega_eff = self.compute_dtype(1.0) / tau
fout = f - omega_eff[None, ...] * fneq
@@ -67,38 +80,11 @@ def _construct_warp(self):
def functional(
f: Any,
feq: Any,
- rho: Any,
- u: Any,
omega: Any,
):
# Compute the non-equilibrium distribution
fneq = f - feq
- # Sailfish implementation
- # {
- # float tmp, strain;
-
- # strain = 0.0f;
-
- # // Off-diagonal components, count twice for symmetry reasons.
- # %for a in range(0, dim):
- # %for b in range(a + 1, dim):
- # tmp = ${cex(sym.ex_flux(grid, 'd0', a, b, config), pointers=True)} -
- # ${cex(sym.ex_eq_flux(grid, a, b))};
- # strain += 2.0f * tmp * tmp;
- # %endfor
- # %endfor
-
- # // Diagonal components.
- # %for a in range(0, dim):
- # tmp = ${cex(sym.ex_flux(grid, 'd0', a, a, config), pointers=True)} -
- # ${cex(sym.ex_eq_flux(grid, a, a))};
- # strain += tmp * tmp;
- # %endfor
-
- # tau0 += 0.5f * (sqrtf(tau0 * tau0 + 36.0f * ${cex(smagorinsky_const**2)} * sqrtf(strain)) - tau0);
- # }
-
# Compute strain
pi_neq = _pi_vec()
for a in range(_pi_dim):
@@ -117,8 +103,7 @@ def functional(
# Compute the Smagorinsky model
_tau = self.compute_dtype(1.0) / self.compute_dtype(omega)
tau = _tau + (
- self.compute_dtype(0.5)
- * (wp.sqrt(_tau * _tau + self.compute_dtype(36.0) * (_smagorinsky_coef**2.0) * wp.sqrt(strain)) - _tau)
+ self.compute_dtype(0.5) * (wp.sqrt(_tau * _tau + self.compute_dtype(36.0) * (_smagorinsky_coef**2.0) * wp.sqrt(strain)) - _tau)
)
# Compute the collision
@@ -130,8 +115,6 @@ def functional(
def kernel(
f: wp.array4d(dtype=Any),
feq: wp.array4d(dtype=Any),
- rho: wp.array4d(dtype=Any),
- u: wp.array4d(dtype=Any),
fout: wp.array4d(dtype=Any),
omega: wp.float32,
):
@@ -145,13 +128,9 @@ def kernel(
for l in range(self.velocity_set.q):
_f[l] = f[l, index[0], index[1], index[2]]
_feq[l] = feq[l, index[0], index[1], index[2]]
- _u = _u_vec()
- for l in range(_d):
- _u[l] = u[l, index[0], index[1], index[2]]
- _rho = rho[0, index[0], index[1], index[2]]
# Compute the collision
- _fout = functional(_f, _feq, _rho, _u, omega)
+ _fout = functional(_f, _feq, omega)
# Write the result
for l in range(self.velocity_set.q):
@@ -160,15 +139,13 @@ def kernel(
return functional, kernel
@Operator.register_backend(ComputeBackend.WARP)
- def warp_implementation(self, f, feq, rho, u, fout, omega):
+ def warp_implementation(self, f, feq, fout, omega):
# Launch the warp kernel
wp.launch(
self.warp_kernel,
inputs=[
f,
feq,
- rho,
- u,
fout,
omega,
],
diff --git a/xlb/operator/equilibrium/__init__.py b/xlb/operator/equilibrium/__init__.py
index 987aa74a..beb7bb5e 100644
--- a/xlb/operator/equilibrium/__init__.py
+++ b/xlb/operator/equilibrium/__init__.py
@@ -1 +1,3 @@
-from xlb.operator.equilibrium.quadratic_equilibrium import Equilibrium, QuadraticEquilibrium
+from xlb.operator.equilibrium.equilibrium import Equilibrium
+from xlb.operator.equilibrium.quadratic_equilibrium import QuadraticEquilibrium
+from xlb.operator.equilibrium.multires_quadratic_equilibrium import MultiresQuadraticEquilibrium
diff --git a/xlb/operator/equilibrium/multires_quadratic_equilibrium.py b/xlb/operator/equilibrium/multires_quadratic_equilibrium.py
new file mode 100644
index 00000000..0a3bb959
--- /dev/null
+++ b/xlb/operator/equilibrium/multires_quadratic_equilibrium.py
@@ -0,0 +1,75 @@
+"""
+Multi-resolution quadratic equilibrium operator for the Neon backend.
+"""
+
+import warp as wp
+from typing import Any
+from xlb.compute_backend import ComputeBackend
+from xlb.operator.equilibrium import QuadraticEquilibrium
+from xlb.operator import Operator
+
+
+class MultiresQuadraticEquilibrium(QuadraticEquilibrium):
+ """Quadratic equilibrium operator for multi-resolution grids (Neon only).
+
+ Computes the second-order Hermite-polynomial equilibrium distribution
+ from density and velocity at every active cell on each grid level.
+ Cells that have child refinement (halo cells) are zeroed out.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ if self.compute_backend in [ComputeBackend.JAX, ComputeBackend.WARP]:
+ raise NotImplementedError(f"Operator {self.__class__.__name__} not supported in {self.compute_backend} backend.")
+
+ def _construct_neon(self):
+ import neon
+
+ # Use the warp functional for the NEON backend
+ functional, _ = self._construct_warp()
+
+ # Set local constants TODO: This is a hack and should be fixed with warp update
+ _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
+
+ @neon.Container.factory(name="QuadraticEquilibrium")
+ def container(
+ rho: Any,
+ u: Any,
+ f: Any,
+ level,
+ ):
+ def quadratic_equilibrium_ll(loader: neon.Loader):
+ loader.set_mres_grid(rho.get_grid(), level)
+
+ rho_pn = loader.get_mres_read_handle(rho)
+ u_pn = loader.get_mres_read_handle(u)
+ f_pn = loader.get_mres_write_handle(f)
+
+ @wp.func
+ def quadratic_equilibrium_cl(index: Any):
+ _u = _u_vec()
+ for d in range(self.velocity_set.d):
+ _u[d] = self.compute_dtype(wp.neon_read(u_pn, index, d))
+ _rho = self.compute_dtype(wp.neon_read(rho_pn, index, 0))
+ feq = functional(_rho, _u)
+
+ if wp.neon_has_child(f_pn, index):
+ for l in range(self.velocity_set.q):
+ feq[l] = self.compute_dtype(0.0)
+ # Set the output
+ for l in range(self.velocity_set.q):
+ wp.neon_write(f_pn, index, l, self.store_dtype(feq[l]))
+
+ loader.declare_kernel(quadratic_equilibrium_cl)
+
+ return quadratic_equilibrium_ll
+
+ return functional, container
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, rho, u, f, stream=0):
+ grid = f.get_grid()
+ for level in range(grid.num_levels):
+ c = self.neon_container(rho, u, f, level)
+ c.run(stream, container_runtime=neon.Container.ContainerRuntime.neon)
+ return f
diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py
index f3572d74..5ec14618 100644
--- a/xlb/operator/equilibrium/quadratic_equilibrium.py
+++ b/xlb/operator/equilibrium/quadratic_equilibrium.py
@@ -2,10 +2,12 @@
import jax.numpy as jnp
from jax import jit
import warp as wp
+import os
+
from typing import Any
from xlb.compute_backend import ComputeBackend
-from xlb.operator.equilibrium.equilibrium import Equilibrium
+from xlb.operator.equilibrium import Equilibrium
from xlb.operator import Operator
@@ -15,6 +17,9 @@ class QuadraticEquilibrium(Equilibrium):
Standard equilibrium model for LBM.
"""
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0))
def jax_implementation(self, rho, u):
@@ -96,3 +101,48 @@ def warp_implementation(self, rho, u, f):
dim=rho.shape[1:],
)
return f
+
+ def _construct_neon(self):
+ import neon
+
+ # Use the warp functional for the NEON backend
+ functional, _ = self._construct_warp()
+
+ # Set local constants TODO: This is a hack and should be fixed with warp update
+ _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
+
+ @neon.Container.factory(name="QuadraticEquilibrium")
+ def container(
+ rho: Any,
+ u: Any,
+ f: Any,
+ ):
+ def quadratic_equilibrium_ll(loader: neon.Loader):
+ loader.set_grid(rho.get_grid())
+ rho_pn = loader.get_read_handle(rho)
+ u_pn = loader.get_read_handle(u)
+ f_pn = loader.get_write_handle(f)
+
+ @wp.func
+ def quadratic_equilibrium_cl(index: typing.Any):
+ _u = _u_vec()
+ for d in range(self.velocity_set.d):
+ _u[d] = self.compute_dtype(wp.neon_read(u_pn, index, d))
+ _rho = self.compute_dtype(wp.neon_read(rho_pn, index, 0))
+ feq = functional(_rho, _u)
+
+ # Set the output
+ for l in range(self.velocity_set.q):
+ wp.neon_write(f_pn, index, l, self.store_dtype(feq[l]))
+
+ loader.declare_kernel(quadratic_equilibrium_cl)
+
+ return quadratic_equilibrium_ll
+
+ return functional, container
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, rho, u, f):
+ c = self.neon_container(rho, u, f)
+ c.run(0, container_runtime=neon.Container.ContainerRuntime.neon)
+ return f
diff --git a/xlb/operator/force/__init__.py b/xlb/operator/force/__init__.py
index ba8a13c3..f3ceec57 100644
--- a/xlb/operator/force/__init__.py
+++ b/xlb/operator/force/__init__.py
@@ -1,2 +1,3 @@
from xlb.operator.force.momentum_transfer import MomentumTransfer
from xlb.operator.force.exact_difference_force import ExactDifference
+from xlb.operator.force.multires_momentum_transfer import MultiresMomentumTransfer
diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py
index 1c6255d3..05952d7d 100644
--- a/xlb/operator/force/momentum_transfer.py
+++ b/xlb/operator/force/momentum_transfer.py
@@ -3,6 +3,7 @@
from jax import jit, lax
import warp as wp
from typing import Any
+from enum import Enum, auto
from xlb.velocity_set.velocity_set import VelocitySet
from xlb.precision_policy import PrecisionPolicy
@@ -11,6 +12,112 @@
from xlb.operator.stream import Stream
+# Enum used to keep track of LBM operations
+class LBMOperationSequence(Enum):
+ """
+ Note that for dense and single resolution simulations in XLB, the order of operations in the stepper is "stream-then-collide".
+ For MultiRes stepper however the order of operations is always "collide-then-stream" except at the finest level when the FUSION_AT_FINEST
+ optimization is used.
+ In that case the order of operations is "stream-then-collide" ONLY at the finest level.
+ """
+
+ STREAM_THEN_COLLIDE = auto()
+ COLLIDE_THEN_STREAM = auto()
+
+
+class FetchPopulations(Operator):
+ """
+ This operator is used to get the post-collision and post-streaming populations
+ Note that for dense and single resolution simulations in XLB, the order of operations in the stepper is "stream-then-collide".
+ Therefore, f_0 represents the post-collision values and post_streaming values of the current time step need to be reconstructed
+ by applying the streaming and boundary conditions. These populations are readily available in XLB when using multi-resolution
+ grids because the mres stepper relies on "collide-then-stream".
+ """
+
+ def __init__(
+ self,
+ no_slip_bc_instance,
+ operation_sequence: LBMOperationSequence = LBMOperationSequence.STREAM_THEN_COLLIDE,
+ velocity_set: VelocitySet = None,
+ precision_policy: PrecisionPolicy = None,
+ compute_backend: ComputeBackend = None,
+ ):
+ self.no_slip_bc_instance = no_slip_bc_instance
+ self.stream = Stream(velocity_set, precision_policy, compute_backend)
+ self.operation_sequence = operation_sequence
+
+ if compute_backend == ComputeBackend.WARP:
+ self.stream_functional = self.stream.warp_functional
+ self.bc_functional = self.no_slip_bc_instance.warp_functional
+ elif compute_backend == ComputeBackend.NEON:
+ self.stream_functional = self.stream.neon_functional
+ self.bc_functional = self.no_slip_bc_instance.neon_functional
+
+ # Call the parent constructor
+ super().__init__(
+ velocity_set,
+ precision_policy,
+ compute_backend,
+ )
+
+ @Operator.register_backend(ComputeBackend.JAX)
+ @partial(jit, static_argnums=(0))
+ def jax_implementation(self, f_0, f_1, bc_mask, missing_mask):
+ # Give the input post-collision populations, streaming once and apply the BC the find post-stream values.
+ f_post_collision = f_0
+ f_post_stream = self.stream(f_post_collision)
+ f_post_stream = self.no_slip_bc_instance(f_post_collision, f_post_stream, bc_mask, missing_mask)
+ return f_post_collision, f_post_stream
+
+ def _construct_warp(self):
+ _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
+
+ @wp.func
+ def functional_stream_then_collide(
+ index: Any,
+ f_0: Any,
+ f_1: Any,
+ _missing_mask: Any,
+ ):
+ # Get the distribution function
+ f_post_collision = _f_vec()
+ for l in range(self.velocity_set.q):
+ f_post_collision[l] = self.compute_dtype(self.read_field(f_0, index, l))
+
+ # Apply streaming (pull method)
+ timestep = 0
+ f_post_stream = self.stream_functional(f_0, index)
+ f_post_stream = self.bc_functional(index, timestep, _missing_mask, f_0, f_1, f_post_collision, f_post_stream)
+ return f_post_collision, f_post_stream
+
+ @wp.func
+ def functional_collide_then_stream(
+ index: Any,
+ f_0: Any,
+ f_1: Any,
+ _missing_mask: Any,
+ ):
+ # Get the distribution function
+ f_post_collision = _f_vec()
+ f_post_stream = _f_vec()
+ for l in range(self.velocity_set.q):
+ f_post_stream[l] = self.compute_dtype(self.read_field(f_0, index, l))
+ f_post_collision[l] = self.compute_dtype(self.read_field(f_1, index, l))
+ return f_post_collision, f_post_stream
+
+ if self.operation_sequence == LBMOperationSequence.STREAM_THEN_COLLIDE:
+ return functional_stream_then_collide, None
+ elif self.operation_sequence == LBMOperationSequence.COLLIDE_THEN_STREAM:
+ return functional_collide_then_stream, None
+ else:
+ raise ValueError(f"Unknown operation sequence: {self.operation_sequence}")
+
+ def _construct_neon(self):
+ # Use the warp functional for the NEON backend
+ functional, _ = self._construct_warp()
+ return functional, None
+
+
class MomentumTransfer(Operator):
"""
An opertor for the momentum exchange method to compute the boundary force vector exerted on the solid geometry
@@ -34,12 +141,23 @@ class MomentumTransfer(Operator):
def __init__(
self,
no_slip_bc_instance,
+ operation_sequence: LBMOperationSequence = LBMOperationSequence.STREAM_THEN_COLLIDE,
velocity_set: VelocitySet = None,
precision_policy: PrecisionPolicy = None,
compute_backend: ComputeBackend = None,
):
+ # Assign the no-slip boundary condition instance
self.no_slip_bc_instance = no_slip_bc_instance
- self.stream = Stream(velocity_set, precision_policy, compute_backend)
+ self.operation_sequence = operation_sequence
+
+ # Define the needed for the momentum transfer
+ self.fetcher = FetchPopulations(
+ no_slip_bc_instance=self.no_slip_bc_instance,
+ operation_sequence=self.operation_sequence,
+ velocity_set=velocity_set,
+ precision_policy=precision_policy,
+ compute_backend=compute_backend,
+ )
# Call the parent constructor
super().__init__(
@@ -48,6 +166,11 @@ def __init__(
compute_backend,
)
+ if self.compute_backend != ComputeBackend.JAX:
+ # Allocate the force vector (the total integral value will be computed)
+ _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
+ self.force = wp.zeros((1), dtype=_u_vec)
+
@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0))
def jax_implementation(self, f_0, f_1, bc_mask, missing_mask):
@@ -71,9 +194,7 @@ def jax_implementation(self, f_0, f_1, bc_mask, missing_mask):
The force exerted on the solid geometry at each boundary node.
"""
# Give the input post-collision populations, streaming once and apply the BC the find post-stream values.
- f_post_collision = f_0
- f_post_stream = self.stream(f_post_collision)
- f_post_stream = self.no_slip_bc_instance(f_post_collision, f_post_stream, bc_mask, missing_mask)
+ f_post_collision, f_post_stream = self.fetcher(f_0, f_1, bc_mask, missing_mask)
# Compute momentum transfer
boundary = bc_mask == self.no_slip_bc_instance.id
@@ -90,61 +211,42 @@ def jax_implementation(self, f_0, f_1, bc_mask, missing_mask):
return force_net
def _construct_warp(self):
- # Set local constants TODO: This is a hack and should be fixed with warp update
+ # Set local constants
_c = self.velocity_set.c
_opp_indices = self.velocity_set.opp_indices
- _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
- _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool
+ _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8)
_no_slip_id = self.no_slip_bc_instance.id
- # Find velocity index for 0, 0, 0
- for l in range(self.velocity_set.q):
- if _c[0, l] == 0 and _c[1, l] == 0 and _c[2, l] == 0:
- zero_index = l
- _zero_index = wp.int32(zero_index)
+ # Find velocity index for (0, 0, 0)
+ lattice_central_index = self.velocity_set.center_index
- # Construct the warp kernel
- @wp.kernel
- def kernel(
- f_0: wp.array4d(dtype=Any),
- f_1: wp.array4d(dtype=Any),
- bc_mask: wp.array4d(dtype=wp.uint8),
- missing_mask: wp.array4d(dtype=wp.bool),
- force: wp.array(dtype=Any),
+ @wp.func
+ def functional(
+ index: Any,
+ f_0: Any,
+ f_1: Any,
+ bc_mask: Any,
+ missing_mask: Any,
+ force: Any,
):
- # Get the global index
- i, j, k = wp.tid()
- index = wp.vec3i(i, j, k)
-
# Get the boundary id
- _boundary_id = bc_mask[0, index[0], index[1], index[2]]
+ _boundary_id = self.read_field(bc_mask, index, 0)
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
- # TODO fix vec bool
- if missing_mask[l, index[0], index[1], index[2]]:
- _missing_mask[l] = wp.uint8(1)
- else:
- _missing_mask[l] = wp.uint8(0)
+ _missing_mask[l] = self.read_field(missing_mask, index, l)
# Determin if boundary is an edge by checking if center is missing
is_edge = wp.bool(False)
if _boundary_id == wp.uint8(_no_slip_id):
- if _missing_mask[_zero_index] == wp.uint8(0):
+ if _missing_mask[lattice_central_index] == wp.uint8(0):
is_edge = wp.bool(True)
# If the boundary is an edge then add the momentum transfer
m = _u_vec()
if is_edge:
- # Get the distribution function
- f_post_collision = _f_vec()
- for l in range(self.velocity_set.q):
- f_post_collision[l] = f_0[l, index[0], index[1], index[2]]
-
- # Apply streaming (pull method)
- timestep = 0
- f_post_stream = self.stream.warp_functional(f_0, index)
- f_post_stream = self.no_slip_bc_instance.warp_functional(index, timestep, _missing_mask, f_0, f_1, f_post_collision, f_post_stream)
+ # fetch the post-collision and post-streaming populations
+ f_post_collision, f_post_stream = self.fetcher_functional(index, f_0, f_1, _missing_mask)
# Compute the momentum transfer
for d in range(self.velocity_set.d):
@@ -156,21 +258,105 @@ def kernel(
m[d] += phi
elif _c[d, _opp_indices[l]] == -1:
m[d] -= phi
-
+ # Atomic sum to get the total force vector
wp.atomic_add(force, 0, m)
- return None, kernel
+ # Construct the warp kernel
+ @wp.kernel
+ def kernel(
+ f_0: wp.array4d(dtype=Any),
+ f_1: wp.array4d(dtype=Any),
+ bc_mask: wp.array4d(dtype=wp.uint8),
+ missing_mask: wp.array4d(dtype=wp.uint8),
+ force: wp.array(dtype=Any),
+ ):
+ # Get the global index
+ i, j, k = wp.tid()
+ index = wp.vec3i(i, j, k)
+
+ # Call the functional to compute the force
+ functional(
+ index,
+ f_0,
+ f_1,
+ bc_mask,
+ missing_mask,
+ force,
+ )
+
+ return functional, kernel
@Operator.register_backend(ComputeBackend.WARP)
def warp_implementation(self, f_0, f_1, bc_mask, missing_mask):
- # Allocate the force vector (the total integral value will be computed)
- _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
- force = wp.zeros((1), dtype=_u_vec)
+ # Ensure the force is initialized to zero
+ self.force *= self.compute_dtype(0.0)
+
+ # Define the warp functionals needed for this operation
+ self.fetcher_functional = self.fetcher.warp_functional
# Launch the warp kernel
wp.launch(
self.warp_kernel,
- inputs=[f_0, f_1, bc_mask, missing_mask, force],
+ inputs=[f_0, f_1, bc_mask, missing_mask, self.force],
dim=f_0.shape[1:],
)
- return force.numpy()[0]
+ return self.force.numpy()[0]
+
+ def _construct_neon(self):
+ import neon
+
+ # Use the warp functional for the NEON backend
+ functional, _ = self._construct_warp()
+
+ @neon.Container.factory(name="MomentumTransfer")
+ def container(
+ f_0: Any,
+ f_1: Any,
+ bc_mask: Any,
+ missing_mask: Any,
+ force: Any,
+ ):
+ def container_launcher(loader: neon.Loader):
+ loader.set_grid(bc_mask.get_grid())
+ bc_mask_pn = loader.get_write_handle(bc_mask)
+ missing_mask_pn = loader.get_write_handle(missing_mask)
+ f_0_pn = loader.get_write_handle(f_0)
+ f_1_pn = loader.get_write_handle(f_1)
+
+ @wp.func
+ def container_kernel(index: Any):
+ # apply the functional
+ functional(
+ index,
+ f_0_pn,
+ f_1_pn,
+ bc_mask_pn,
+ missing_mask_pn,
+ force,
+ )
+
+ loader.declare_kernel(container_kernel)
+
+ return container_launcher
+
+ return functional, container
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(
+ self,
+ f_0,
+ f_1,
+ bc_mask,
+ missing_mask,
+ stream=0,
+ ):
+ # Ensure the force is initialized to zero
+ self.force *= self.compute_dtype(0.0)
+
+ # Define the neon functionals needed for this operation
+ self.fetcher_functional = self.fetcher.neon_functional
+
+ # Launch the neon container
+ c = self.neon_container(f_0, f_1, bc_mask, missing_mask, self.force)
+ c.run(stream, container_runtime=neon.Container.ContainerRuntime.neon)
+ return self.force.numpy()[0]
diff --git a/xlb/operator/force/multires_momentum_transfer.py b/xlb/operator/force/multires_momentum_transfer.py
new file mode 100644
index 00000000..f7e04a43
--- /dev/null
+++ b/xlb/operator/force/multires_momentum_transfer.py
@@ -0,0 +1,137 @@
+"""
+Multi-resolution momentum-transfer force operator for the Neon backend.
+"""
+
+from typing import Any
+
+import warp as wp
+
+from xlb.velocity_set.velocity_set import VelocitySet
+from xlb.precision_policy import PrecisionPolicy
+from xlb.compute_backend import ComputeBackend
+from xlb.operator.operator import Operator
+from xlb.operator.force import MomentumTransfer
+from xlb.mres_perf_optimization_type import MresPerfOptimizationType
+
+
+class MultiresMomentumTransfer(MomentumTransfer):
+ """Momentum-transfer force computation on a multi-resolution grid.
+
+ Extends :class:`MomentumTransfer` with Neon-specific container code that
+ iterates over all grid levels. The LBM operation sequence (collide-then-
+ stream vs. stream-then-collide) is inferred from the performance
+ optimization type.
+
+ Parameters
+ ----------
+ no_slip_bc_instance : BoundaryCondition
+ The no-slip BC whose tagged voxels define the force integration
+ surface.
+ mres_perf_opt : MresPerfOptimizationType
+ Multi-resolution performance strategy.
+ velocity_set : VelocitySet, optional
+ precision_policy : PrecisionPolicy, optional
+ compute_backend : ComputeBackend, optional
+ """
+
+ def __init__(
+ self,
+ no_slip_bc_instance,
+ mres_perf_opt=MresPerfOptimizationType.NAIVE_COLLIDE_STREAM,
+ velocity_set: VelocitySet = None,
+ precision_policy: PrecisionPolicy = None,
+ compute_backend: ComputeBackend = None,
+ ):
+ from xlb.operator.force.momentum_transfer import LBMOperationSequence
+
+ if compute_backend in [ComputeBackend.JAX, ComputeBackend.WARP]:
+ raise NotImplementedError(f"Operator {self.__class__.__name__} not supported in {compute_backend} backend.")
+
+ # Set the sequence of operations based on the performance optimization type
+ if mres_perf_opt == MresPerfOptimizationType.NAIVE_COLLIDE_STREAM:
+ operation_sequence = LBMOperationSequence.COLLIDE_THEN_STREAM
+ elif mres_perf_opt in (
+ MresPerfOptimizationType.FUSION_AT_FINEST,
+ MresPerfOptimizationType.FUSION_AT_FINEST_SFV,
+ MresPerfOptimizationType.FUSION_AT_FINEST_SFV_ALL,
+ ):
+ operation_sequence = LBMOperationSequence.STREAM_THEN_COLLIDE
+ else:
+ raise ValueError(f"Unknown performance optimization type: {mres_perf_opt}")
+
+ # Check if the performance optimization type is compatible with the use of mesh distance
+ if operation_sequence != LBMOperationSequence.STREAM_THEN_COLLIDE:
+ assert not no_slip_bc_instance.needs_mesh_distance, (
+ "Mesh distance is only supported in the MultiresMomentumTransfer operator when the LBM operation sequence is STREAM_THEN_COLLIDE."
+ )
+
+ # Print a warning to the user about the boundary voxels
+ print(
+ "WARNING! make sure boundary voxels are all at the same level and not among the transition regions from one level to another. "
+ "Otherwise, the results of force calculation are not correct!\n"
+ )
+
+ # Call super
+ super().__init__(no_slip_bc_instance, operation_sequence, velocity_set, precision_policy, compute_backend)
+
+ def _construct_neon(self):
+ import neon
+
+ # Use the warp functional for the NEON backend
+ functional, _ = self._construct_warp()
+
+ @neon.Container.factory(name="MomentumTransfer")
+ def container(
+ f_0: Any,
+ f_1: Any,
+ bc_mask: Any,
+ missing_mask: Any,
+ force: Any,
+ level: Any,
+ ):
+ def container_launcher(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask.get_grid(), level)
+ bc_mask_pn = loader.get_mres_write_handle(bc_mask)
+ missing_mask_pn = loader.get_mres_write_handle(missing_mask)
+ f_0_pn = loader.get_mres_write_handle(f_0)
+ f_1_pn = loader.get_mres_write_handle(f_1)
+
+ @wp.func
+ def container_kernel(index: Any):
+ # apply the functional
+ functional(
+ index,
+ f_0_pn,
+ f_1_pn,
+ bc_mask_pn,
+ missing_mask_pn,
+ force,
+ )
+
+ loader.declare_kernel(container_kernel)
+
+ return container_launcher
+
+ return functional, container
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(
+ self,
+ f_0,
+ f_1,
+ bc_mask,
+ missing_mask,
+ stream=0,
+ ):
+ # Ensure the force is initialized to zero
+ self.force *= self.compute_dtype(0.0)
+
+ # Define the neon functionals needed for this operation
+ self.fetcher_functional = self.fetcher.neon_functional
+
+ grid = bc_mask.get_grid()
+ for level in range(grid.num_levels):
+ # Launch the neon container
+ c = self.neon_container(f_0, f_1, bc_mask, missing_mask, self.force, level)
+ c.run(stream, container_runtime=neon.Container.ContainerRuntime.neon)
+ return self.force.numpy()[0]
diff --git a/xlb/operator/macroscopic/__init__.py b/xlb/operator/macroscopic/__init__.py
index 75dec9ea..75eacee6 100644
--- a/xlb/operator/macroscopic/__init__.py
+++ b/xlb/operator/macroscopic/__init__.py
@@ -2,3 +2,4 @@
from xlb.operator.macroscopic.second_moment import SecondMoment
from xlb.operator.macroscopic.zero_moment import ZeroMoment
from xlb.operator.macroscopic.first_moment import FirstMoment
+from xlb.operator.macroscopic.multires_macroscopic import MultiresMacroscopic
diff --git a/xlb/operator/macroscopic/first_moment.py b/xlb/operator/macroscopic/first_moment.py
index cb99a9ff..626767fd 100644
--- a/xlb/operator/macroscopic/first_moment.py
+++ b/xlb/operator/macroscopic/first_moment.py
@@ -23,17 +23,31 @@ def _construct_warp(self):
_u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype)
@wp.func
- def functional(
- f: _f_vec,
- rho: Any,
- ):
- u = _u_vec()
+ def neumaier_sum_component(d: int, f: _f_vec):
+ total = self.compute_dtype(0.0)
+ compensation = self.compute_dtype(0.0)
for l in range(self.velocity_set.q):
- for d in range(self.velocity_set.d):
- if _c[d, l] == 1:
- u[d] += f[l]
- elif _c[d, l] == -1:
- u[d] -= f[l]
+ # Get contribution based on the sign of _c[d, l]
+ if _c[d, l] == 1:
+ val = f[l]
+ elif _c[d, l] == -1:
+ val = -f[l]
+ else:
+ val = self.compute_dtype(0.0)
+ t = total + val
+ if wp.abs(total) >= wp.abs(val):
+ compensation = compensation + ((total - t) + val)
+ else:
+ compensation = compensation + ((val - t) + total)
+ total = t
+ return total + compensation
+
+ @wp.func
+ def functional(f: _f_vec, rho: Any):
+ u = _u_vec()
+ # Use Neumaier summation for each spatial component
+ for d in range(self.velocity_set.d):
+ u[d] = neumaier_sum_component(d, f)
u /= rho
return u
@@ -65,3 +79,12 @@ def warp_implementation(self, f, rho, u):
dim=u.shape[1:],
)
return u
+
+ def _construct_neon(self):
+ functional, _ = self._construct_warp()
+ return functional, None
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, f, rho):
+ # raise exception as this feature is not implemented yet
+ raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.")
diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py
index ab1193b0..971081c9 100644
--- a/xlb/operator/macroscopic/macroscopic.py
+++ b/xlb/operator/macroscopic/macroscopic.py
@@ -20,20 +20,18 @@ def __init__(self, *args, **kwargs):
@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0), inline=True)
- def jax_implementation(self, f):
+ def jax_implementation(self, f, rho=None, u=None):
rho = self.zero_moment(f)
u = self.first_moment(f, rho)
return rho, u
def _construct_warp(self):
- zero_moment_func = self.zero_moment.warp_functional
- first_moment_func = self.first_moment.warp_functional
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
@wp.func
def functional(f: _f_vec):
- rho = zero_moment_func(f)
- u = first_moment_func(f, rho)
+ rho = self.zero_moment.warp_functional(f)
+ u = self.first_moment.warp_functional(f, rho)
return rho, u
@wp.kernel
@@ -64,3 +62,52 @@ def warp_implementation(self, f, rho, u):
dim=rho.shape[1:],
)
return rho, u
+
+ def _construct_neon(self):
+ import neon
+
+ # Redefine the zero and first moment operators for the neon backend
+ # This is because the neon backend relies on the warp functionals for its operations.
+ self.zero_moment = ZeroMoment(compute_backend=ComputeBackend.WARP)
+ self.first_moment = FirstMoment(compute_backend=ComputeBackend.WARP)
+ functional, _ = self._construct_warp()
+
+ # Set local vectors
+ _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
+
+ @neon.Container.factory("macroscopic")
+ def container(
+ f_field: Any,
+ rho_field: Any,
+ u_fild: Any,
+ ):
+ _d = self.velocity_set.d
+
+ def macroscopic_ll(loader: neon.Loader):
+ loader.set_grid(f_field.get_grid())
+
+ rho = loader.get_read_handle(rho_field)
+ u = loader.get_read_handle(u_fild)
+ f = loader.get_read_handle(f_field)
+
+ @wp.func
+ def macroscopic_cl(gIdx: typing.Any):
+ _f = _f_vec()
+ for l in range(self.velocity_set.q):
+ _f[l] = self.compute_dtype(wp.neon_read(f, gIdx, l))
+ _rho, _u = functional(_f)
+ wp.neon_write(rho, gIdx, 0, self.store_dtype(_rho))
+ for d in range(_d):
+ wp.neon_write(u, gIdx, d, self.store_dtype(_u[d]))
+
+ loader.declare_kernel(macroscopic_cl)
+
+ return macroscopic_ll
+
+ return functional, container
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, f, rho, u):
+ c = self.neon_container(f, rho, u)
+ c.run(0)
+ return rho, u
diff --git a/xlb/operator/macroscopic/multires_macroscopic.py b/xlb/operator/macroscopic/multires_macroscopic.py
new file mode 100644
index 00000000..b7df10c7
--- /dev/null
+++ b/xlb/operator/macroscopic/multires_macroscopic.py
@@ -0,0 +1,89 @@
+"""
+Multi-resolution macroscopic moment computation for the Neon backend.
+"""
+
+from functools import partial
+import jax.numpy as jnp
+from jax import jit
+import warp as wp
+from typing import Any
+
+from xlb.compute_backend import ComputeBackend
+from xlb.operator.operator import Operator
+from xlb.operator.macroscopic import Macroscopic, ZeroMoment, FirstMoment
+from xlb.cell_type import BC_SOLID
+
+
+class MultiresMacroscopic(Macroscopic):
+ """Compute density and velocity on a multi-resolution grid (Neon only).
+
+ Iterates over all grid levels, computing zero-th and first moments of
+ the distribution function. Solid voxels and voxels that have child
+ refinement (halo cells) are set to zero.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ if self.compute_backend in [ComputeBackend.JAX, ComputeBackend.WARP]:
+ raise NotImplementedError(f"Operator {self.__class__.__name__} not supported in {self.compute_backend} backend.")
+
+ def _construct_neon(self):
+ import neon
+
+ # Redefine the zero and first moment operators for the neon backend
+ # This is because the neon backend relies on the warp functionals for its operations.
+ self.zero_moment = ZeroMoment(compute_backend=ComputeBackend.WARP)
+ self.first_moment = FirstMoment(compute_backend=ComputeBackend.WARP)
+ _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
+ functional, _ = self._construct_warp()
+
+ @neon.Container.factory("macroscopic")
+ def container(
+ level: int,
+ f_field: Any,
+ bc_mask: Any,
+ rho_field: Any,
+ u_fild: Any,
+ ):
+ _d = self.velocity_set.d
+
+ def macroscopic_ll(loader: neon.Loader):
+ loader.set_mres_grid(f_field.get_grid(), level)
+
+ rho = loader.get_mres_write_handle(rho_field)
+ u = loader.get_mres_write_handle(u_fild)
+ f = loader.get_mres_read_handle(f_field)
+ bc_mask_pn = loader.get_mres_read_handle(bc_mask)
+
+ @wp.func
+ def macroscopic_cl(gIdx: typing.Any):
+ _f = _f_vec()
+ _boundary_id = wp.neon_read(bc_mask_pn, gIdx, 0)
+
+ for l in range(self.velocity_set.q):
+ _f[l] = self.compute_dtype(wp.neon_read(f, gIdx, l))
+
+ _rho, _u = functional(_f)
+
+ if _boundary_id == wp.uint8(BC_SOLID) or wp.neon_has_child(f, gIdx):
+ _rho = self.compute_dtype(0.0)
+ for d in range(_d):
+ _u[d] = self.compute_dtype(0.0)
+
+ wp.neon_write(rho, gIdx, 0, self.store_dtype(_rho))
+ for d in range(_d):
+ wp.neon_write(u, gIdx, d, self.store_dtype(_u[d]))
+
+ loader.declare_kernel(macroscopic_cl)
+
+ return macroscopic_ll
+
+ return functional, container
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, f, bc_mask, rho, u, streamId=0):
+ grid = f.get_grid()
+ for level in range(grid.num_levels):
+ c = self.neon_container(level, f, bc_mask, rho, u)
+ c.run(streamId)
+ return rho, u
diff --git a/xlb/operator/macroscopic/second_moment.py b/xlb/operator/macroscopic/second_moment.py
index 6c7e70ea..1a0a0f07 100644
--- a/xlb/operator/macroscopic/second_moment.py
+++ b/xlb/operator/macroscopic/second_moment.py
@@ -104,3 +104,12 @@ def warp_implementation(self, f, pi):
# Launch the warp kernel
wp.launch(self.warp_kernel, inputs=[f, pi], dim=pi.shape[1:])
return pi
+
+ def _construct_neon(self):
+ functional, _ = self._construct_warp()
+ return functional, None
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, f, rho):
+ # raise exception as this feature is not implemented yet
+ raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.")
diff --git a/xlb/operator/macroscopic/zero_moment.py b/xlb/operator/macroscopic/zero_moment.py
index 8abb4de7..f536f8d7 100644
--- a/xlb/operator/macroscopic/zero_moment.py
+++ b/xlb/operator/macroscopic/zero_moment.py
@@ -20,11 +20,23 @@ def _construct_warp(self):
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
@wp.func
- def functional(f: _f_vec):
- rho = self.compute_dtype(0.0)
+ def neumaier_sum(f: _f_vec):
+ total = self.compute_dtype(0.0)
+ compensation = self.compute_dtype(0.0)
for l in range(self.velocity_set.q):
- rho += f[l]
- return rho
+ x = f[l]
+ t = total + x
+ # Using wp.abs to compute absolute value
+ if wp.abs(total) >= wp.abs(x):
+ compensation = compensation + ((total - t) + x)
+ else:
+ compensation = compensation + ((x - t) + total)
+ total = t
+ return total + compensation
+
+ @wp.func
+ def functional(f: _f_vec):
+ return neumaier_sum(f)
@wp.kernel
def kernel(
@@ -47,3 +59,12 @@ def kernel(
def warp_implementation(self, f, rho):
wp.launch(self.warp_kernel, inputs=[f, rho], dim=rho.shape[1:])
return rho
+
+ def _construct_neon(self):
+ functional, _ = self._construct_warp()
+ return functional, None
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, f, rho):
+ # raise exception as this feature is not implemented yet
+ raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.")
diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py
index fcbd07b0..4405708c 100644
--- a/xlb/operator/operator.py
+++ b/xlb/operator/operator.py
@@ -1,6 +1,22 @@
+"""
+Base operator module for XLB.
+
+Every LBM operator (collision, streaming, equilibrium, boundary condition,
+masker, stepper, etc.) inherits from :class:`Operator`. The class provides:
+
+* **Backend dispatch** — ``__call__`` automatically selects the registered
+ implementation for the active compute backend.
+* **Precision management** — ``compute_dtype`` and ``store_dtype`` properties
+ return the correct type for the active backend and precision policy.
+* **Kernel construction hooks** — ``_construct_warp()`` / ``_construct_neon()``
+ are called at init time to compile backend-specific kernels and functionals.
+"""
+
import inspect
import traceback
import jax
+import warp as wp
+from typing import Any
from xlb.compute_backend import ComputeBackend
from xlb import DefaultConfig
@@ -17,6 +33,17 @@ class Operator:
_backends = {}
def __init__(self, velocity_set=None, precision_policy=None, compute_backend=None):
+ """Initialize the operator.
+
+ Parameters
+ ----------
+ velocity_set : VelocitySet, optional
+ Lattice velocity set. Defaults to ``DefaultConfig.velocity_set``.
+ precision_policy : PrecisionPolicy, optional
+ Precision policy. Defaults to ``DefaultConfig.default_precision_policy``.
+ compute_backend : ComputeBackend, optional
+ Compute backend. Defaults to ``DefaultConfig.default_backend``.
+ """
# Set the default values from the global config
self.velocity_set = velocity_set or DefaultConfig.velocity_set
self.precision_policy = precision_policy or DefaultConfig.default_precision_policy
@@ -26,10 +53,18 @@ def __init__(self, velocity_set=None, precision_policy=None, compute_backend=Non
if self.compute_backend not in ComputeBackend:
raise ValueError(f"Compute_backend {compute_backend} is not supported")
+ # Construct read/write functions for the compute backend
+ if self.compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON]:
+ self.read_field, self.write_field = self._construct_read_write_functions()
+ self.read_field_neighbor = self._construct_read_field_neighbor()
+
# Construct the kernel based compute_backend functions TODO: Maybe move this to the register or something
if self.compute_backend == ComputeBackend.WARP:
self.warp_functional, self.warp_kernel = self._construct_warp()
+ if self.compute_backend == ComputeBackend.NEON:
+ self.neon_functional, self.neon_container = self._construct_neon()
+
# Updating JAX config in case fp64 is requested
if self.compute_backend == ComputeBackend.JAX and (
precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32
@@ -52,6 +87,20 @@ def decorator(func):
return decorator
def __call__(self, *args, callback=None, **kwargs):
+ """Dispatch to the registered backend implementation.
+
+ Iterates over all registered implementations for this operator class
+ and the active backend, attempts to bind the provided arguments, and
+ executes the first matching signature. An optional *callback* is
+ invoked with the result after successful execution.
+
+ Raises
+ ------
+ NotImplementedError
+ If no implementation is registered for the active backend.
+ Exception
+ If all candidate implementations raise errors.
+ """
method_candidates = [
(key, method) for key, method in self._backends.items() if key[0] == self.__class__.__name__ and key[1] == self.compute_backend
]
@@ -80,7 +129,7 @@ def __call__(self, *args, callback=None, **kwargs):
error = e
traceback_str = traceback.format_exc()
continue # This skips to the next candidate if binding fails
-
+ method_candidates = [(key, method) for key, method in self._backends.items() if key[1] == self.compute_backend]
raise Exception(f"Error captured for backend with key {key} for operator {self.__class__.__name__}: {error}\n {traceback_str}")
@property
@@ -123,6 +172,8 @@ def compute_dtype(self):
return self.precision_policy.compute_precision.jax_dtype
elif self.compute_backend == ComputeBackend.WARP:
return self.precision_policy.compute_precision.wp_dtype
+ elif self.compute_backend == ComputeBackend.NEON:
+ return self.precision_policy.compute_precision.wp_dtype
@property
def store_dtype(self):
@@ -133,6 +184,20 @@ def store_dtype(self):
return self.precision_policy.store_precision.jax_dtype
elif self.compute_backend == ComputeBackend.WARP:
return self.precision_policy.store_precision.wp_dtype
+ elif self.compute_backend == ComputeBackend.NEON:
+ return self.precision_policy.store_precision.wp_dtype
+
+ def get_precision_policy(self):
+ """
+ Returns the precision policy
+ """
+ return self.precision_policy
+
+ def get_grid(self):
+ """
+ Returns the grid object
+ """
+ return self.grid
def _construct_warp(self):
"""
@@ -142,3 +207,110 @@ def _construct_warp(self):
Leave it for now, as it is not clear how the warp compute backend will evolve
"""
return None, None
+
+ def _construct_neon(self):
+ """
+ Construct the Neon functional and Neon container of the operator
+ TODO: Maybe a better way to do this?
+ Maybe add this to the backend decorator?
+ Leave it for now, as it is not clear how the neon backend will evolve
+ """
+ return None, None
+
+ def _construct_read_write_functions(self):
+ """Build backend-specific ``read_field`` / ``write_field`` helpers.
+
+ For the Warp backend these are direct 4-D array accesses. For the
+ Neon backend they wrap ``wp.neon_read`` / ``wp.neon_write``.
+
+ Returns
+ -------
+ tuple of wp.func
+ ``(read_field, write_field)``
+ """
+ if self.compute_backend == ComputeBackend.WARP:
+
+ @wp.func
+ def read_field(
+ field: Any,
+ index: Any,
+ direction: Any,
+ ):
+ # This function reads a field value at a given index and direction.
+ return field[direction, index[0], index[1], index[2]]
+
+ @wp.func
+ def write_field(
+ field: Any,
+ index: Any,
+ direction: Any,
+ value: Any,
+ ):
+ # This function writes a value to a field at a given index and direction.
+ field[direction, index[0], index[1], index[2]] = value
+
+ elif self.compute_backend == ComputeBackend.NEON:
+ import neon
+
+ @wp.func
+ def read_field(
+ field: Any,
+ index: Any,
+ direction: Any,
+ ):
+ # This function reads a field value at a given index and direction.
+ return wp.neon_read(field, index, direction)
+
+ @wp.func
+ def write_field(
+ field: Any,
+ index: Any,
+ direction: Any,
+ value: Any,
+ ):
+ # This function writes a value to a field at a given index and direction.
+ wp.neon_write(field, index, direction, value)
+
+ else:
+ raise ValueError(f"Unsupported compute backend: {self.compute_backend}")
+
+ return read_field, write_field
+
+ def _construct_read_field_neighbor(self):
+ """
+ Construct a function to read a field value at a neighboring index along a given direction.
+ """
+
+ if self.compute_backend == ComputeBackend.WARP:
+
+ @wp.func
+ def read_field_neighbor(
+ field: Any,
+ index: Any,
+ offset: Any,
+ direction: Any,
+ ):
+ # This function reads a field value at a given neighboring index and direction.
+ neighbor = index + offset
+ return field[direction, neighbor[0], neighbor[1], neighbor[2]]
+
+ elif self.compute_backend == ComputeBackend.NEON:
+ import neon
+ # from neon.multires.mPartition import neon_get_type
+
+ @wp.func
+ def read_field_neighbor(
+ field: Any,
+ index: Any,
+ offset: Any,
+ direction: Any,
+ ):
+ # This function reads a field value at a given neighboring index and direction.
+ unused_is_valid = wp.bool(False)
+ # dtype = neon_get_type(field) # This is a placeholder to ensure the dtype is set correctly
+ return wp.neon_read_ngh(field, index, offset, direction, wp.uint8(0.0), unused_is_valid)
+
+ else:
+ raise ValueError(f"Unsupported compute backend: {self.compute_backend}")
+
+ return read_field_neighbor
diff --git a/xlb/operator/stepper/__init__.py b/xlb/operator/stepper/__init__.py
index 1eab1668..87c1274f 100644
--- a/xlb/operator/stepper/__init__.py
+++ b/xlb/operator/stepper/__init__.py
@@ -1,3 +1,4 @@
from xlb.operator.stepper.stepper import Stepper
from xlb.operator.stepper.nse_stepper import IncompressibleNavierStokesStepper
+from xlb.operator.stepper.nse_multires_stepper import MultiresIncompressibleNavierStokesStepper
from xlb.operator.stepper.ibm_stepper import IBMStepper
diff --git a/xlb/operator/stepper/nse_multires_stepper.py b/xlb/operator/stepper/nse_multires_stepper.py
new file mode 100644
index 00000000..629f1b07
--- /dev/null
+++ b/xlb/operator/stepper/nse_multires_stepper.py
@@ -0,0 +1,1196 @@
+"""
+Multi-Resolution Navier-Stokes Stepper for the NEON Backend
+
+This module implements the multi-resolution LBM stepper using Warp kernels on the
+Neon multi-GPU runtime. It uses several programming patterns specific to Warp's
+compile-time code generation model.
+
+Compile-Time Specialization Pattern
+-----------------------------------
+Warp's @wp.func decorator traces Python code at kernel compilation time, not runtime.
+This means runtime boolean parameters cause Warp to emit branching code for both paths,
+increasing register pressure even when only one path is ever taken.
+
+To generate optimized, branch-free kernels, we use a **factory pattern** that captures
+boolean configuration at function-definition time:
+
+ def make_specialized_func(do_feature: bool):
+ @wp.func
+ def impl(...):
+ if wp.static(do_feature): # Evaluated at compile time
+ # This code is only emitted when do_feature=True
+ ...
+ else:
+ # This code is only emitted when do_feature=False
+ ...
+ return impl
+
+ # Generate specialized variants
+ func_with_feature = make_specialized_func(do_feature=True)
+ func_without_feature = make_specialized_func(do_feature=False)
+
+The `wp.static()` call evaluates its argument during Warp's tracing phase. Since
+`do_feature` is a Python bool captured in the closure, Warp sees a constant and
+eliminates the dead branch entirely.
+
+This pattern is used for:
+- `apply_bc_post_streaming` / `apply_bc_post_collision`: Specialized BC application
+ for streaming vs collision implementation steps
+- `collide_bc_accum` / `collide_simple`: Collision pipeline variants with/without
+ BC application and multi-resolution accumulation
+
+Closure Capture for Self Attributes
+-----------------------------------
+Warp cannot resolve `self.X` in plain assignments inside @wp.func bodies (e.g.,
+`_c = self.velocity_set.c` fails with "Invalid external reference type"). However,
+it can resolve `self.X` in:
+- Function call contexts: `self.stream.neon_functional(...)`
+- Range arguments: `range(self.velocity_set.q)`
+- Type casts: `self.compute_dtype(0)`
+
+For other uses, we pre-capture attributes at the Python level before defining the
+@wp.func, making them available as simple closure variables:
+
+ _c = self.velocity_set.c # Captured in Python scope
+
+ @wp.func
+ def my_kernel(...):
+ # Use _c directly — Warp sees it as a closure variable
+ direction = wp.neon_ngh_idx(wp.int8(_c[0, l]), ...)
+
+Cell Type Constants
+-------------------
+Cell types are defined in `xlb.cell_type`:
+- BC_SFV (254): Simple Fluid Voxel — no BC, no explosion/coalescence
+- BC_SOLID (255): Solid obstacle voxel
+- BC_NONE (0): Regular fluid voxel with potential BCs or multi-res interactions
+"""
+
+import nvtx
+import warp as wp
+from typing import Any
+
+from xlb import DefaultConfig
+from xlb.compute_backend import ComputeBackend
+from xlb.precision_policy import Precision
+from xlb.operator import Operator
+from xlb.operator.stream import Stream
+from xlb.operator.collision import BGK, KBC
+from xlb.operator.equilibrium import MultiresQuadraticEquilibrium
+from xlb.operator.macroscopic import MultiresMacroscopic
+from xlb.operator.stepper import Stepper
+from xlb.operator.boundary_condition.boundary_condition import ImplementationStep
+from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry
+from xlb.operator.collision import ForcedCollision
+from xlb.helper import check_bc_overlaps
+from xlb.operator.boundary_masker import (
+ MeshVoxelizationMethod,
+ MultiresMeshMaskerAABB,
+ MultiresMeshMaskerAABBClose,
+ MultiresIndicesBoundaryMasker,
+ MultiresMeshMaskerRay,
+)
+from xlb.operator.boundary_condition.helper_functions_bc import MultiresEncodeAuxiliaryData
+from xlb.cell_type import BC_SFV, BC_SOLID
+
+"""
+SFV = Simple Fluid Voxel: a fluid voxel that is not a BC nor is involved in explosion or coalescence
+CFV = Complex Fluid Voxel: a fluid voxel that is not a SFV
+"""
+
+
+class MultiresIncompressibleNavierStokesStepper(Stepper):
+ """Multi-resolution incompressible Navier-Stokes stepper for the Neon backend.
+
+ Implements the full LBM step (stream, collide, boundary conditions) across
+ a hierarchy of grid levels using Neon containers. Each container is a
+ compile-time specialized Warp kernel wrapped in a Neon execution-graph
+ node.
+
+ The stepper supports several performance optimization strategies (see
+ :class:`MresPerfOptimizationType`):
+
+ * **NAIVE_COLLIDE_STREAM** — separate collide and stream containers at
+ every level.
+ * **FUSION_AT_FINEST** — fused stream+collide at the finest level.
+ * **FUSION_AT_FINEST_SFV** — additionally splits SFV / CFV voxels at
+ the finest level for reduced branching.
+ * **FUSION_AT_FINEST_SFV_ALL** — SFV / CFV splitting at all levels.
+
+ Parameters
+ ----------
+ grid : NeonMultiresGrid
+ The multi-resolution grid.
+ boundary_conditions : list of BoundaryCondition
+ Boundary conditions to apply.
+ collision_type : str
+ Collision operator type: ``"BGK"`` or ``"KBC"``.
+ forcing_scheme : str
+ Forcing scheme name (only used when *force_vector* is given).
+ force_vector : array-like, optional
+ External body force vector.
+ """
+
+ def __init__(
+ self,
+ grid,
+ boundary_conditions=[],
+ collision_type="BGK",
+ forcing_scheme="exact_difference",
+ force_vector=None,
+ ):
+ super().__init__(grid, boundary_conditions)
+
+ # Construct the collision operator
+ if collision_type == "BGK":
+ self.collision = BGK(self.velocity_set, self.precision_policy, self.compute_backend)
+ elif collision_type == "KBC":
+ self.collision = KBC(self.velocity_set, self.precision_policy, self.compute_backend)
+
+ if force_vector is not None:
+ self.collision = ForcedCollision(collision_operator=self.collision, forcing_scheme=forcing_scheme, force_vector=force_vector)
+
+ # Construct the operators
+ self.stream = Stream(self.velocity_set, self.precision_policy, self.compute_backend)
+ self.equilibrium = MultiresQuadraticEquilibrium(self.velocity_set, self.precision_policy, self.compute_backend)
+ self.macroscopic = MultiresMacroscopic(self.velocity_set, self.precision_policy, self.compute_backend)
+
+ def prepare_fields(self, rho, u, initializer=None):
+ import neon
+
+ """Prepare the fields required for the stepper.
+
+ Args:
+ initializer: Optional operator to initialize the distribution functions.
+ If provided, it should be a callable that takes (grid, velocity_set,
+ precision_policy, compute_backend) as arguments and returns initialized f_0.
+ If None, default equilibrium initialization is used with rho=1 and u=0.
+
+ Returns:
+ Tuple of (f_0, f_1, bc_mask, missing_mask):
+ - f_0: Initial distribution functions
+ - f_1: Copy of f_0 for double-buffering
+ - bc_mask: Boundary condition mask indicating which BC applies to each node
+ - missing_mask: Mask indicating which populations are missing at boundary nodes
+ """
+
+ f_0 = self.grid.create_field(
+ cardinality=self.velocity_set.q, dtype=self.precision_policy.store_precision, neon_memory_type=neon.MemoryType.device()
+ )
+
+ f_1 = self.grid.create_field(
+ cardinality=self.velocity_set.q, dtype=self.precision_policy.store_precision, neon_memory_type=neon.MemoryType.device()
+ )
+
+ missing_mask = self.grid.create_field(cardinality=self.velocity_set.q, dtype=Precision.UINT8)
+ bc_mask = self.grid.create_field(cardinality=1, dtype=Precision.UINT8)
+
+ for level in range(self.grid.count_levels):
+ f_1.copy_from_run(level, f_0, 0)
+
+ # Process boundary conditions and update masks
+ f_1, bc_mask, missing_mask = self._process_boundary_conditions(self.boundary_conditions, f_1, bc_mask, missing_mask)
+ # Initialize auxiliary data if needed
+ f_1 = self._initialize_auxiliary_data(self.boundary_conditions, f_1, bc_mask, missing_mask)
+
+ # Initialize distribution functions if initializer is provided
+ if initializer is not None:
+ # Refer to xlb.helper.initializers for available initializers
+ f_0 = initializer(bc_mask, f_0)
+ else:
+ from xlb.helper.initializers import initialize_multires_eq
+
+ f_0 = initialize_multires_eq(f_0, self.grid, self.velocity_set, self.precision_policy, self.compute_backend, rho=rho, u=u)
+
+ return f_0, f_1, bc_mask, missing_mask
+
+ def prepare_coalescence_count(self, coalescence_factor, bc_mask):
+ """Precompute coalescence weighting factors for multi-resolution streaming.
+
+ For each non-halo voxel at every level, this method accumulates
+ the number of finer neighbours that contribute populations via
+ coalescence (child-to-parent transfer), then inverts the count
+ so that the streaming kernel can apply the correct averaging weight.
+
+ Parameters
+ ----------
+ coalescence_factor : field
+ Multi-resolution field to store the per-direction coalescence
+ weights (modified in-place).
+ bc_mask : field
+ Boundary-condition mask used to skip solid voxels.
+ """
+ lattice_central_index = self.velocity_set.center_index
+ num_levels = coalescence_factor.get_grid().num_levels
+
+ @neon.Container.factory(name="sum_kernel_by_level")
+ def sum_kernel_by_level(level):
+ def ll_coalescence_count(loader: neon.Loader):
+ loader.set_mres_grid(coalescence_factor.get_grid(), level)
+
+ coalescence_factor_pn = loader.get_mres_read_handle(coalescence_factor)
+ bc_mask_pn = loader.get_mres_read_handle(bc_mask)
+
+ _c = self.velocity_set.c
+ _w = self.velocity_set.w
+
+ @wp.func
+ def cl_collide_coarse(index: Any):
+ _boundary_id = wp.neon_read(bc_mask_pn, index, 0)
+ if _boundary_id == wp.uint8(BC_SOLID):
+ return
+ if not wp.neon_has_child(coalescence_factor_pn, index):
+ for l in range(self.velocity_set.q):
+ if level < num_levels - 1:
+ push_direction = wp.neon_ngh_idx(wp.int8(_c[0, l]), wp.int8(_c[1, l]), wp.int8(_c[2, l]))
+ val = self.store_dtype(1)
+ wp.neon_mres_lbm_store_op(coalescence_factor_pn, index, l, push_direction, val)
+
+ loader.declare_kernel(cl_collide_coarse)
+
+ return ll_coalescence_count
+
+ for level in range(num_levels):
+ sum_kernel = sum_kernel_by_level(level)
+ sum_kernel.run(0)
+
+ @neon.Container.factory(name="sum_kernel_by_level")
+ def invert_count(level):
+ def loading(loader: neon.Loader):
+ loader.set_mres_grid(coalescence_factor.get_grid(), level)
+
+ coalescence_factor_pn = loader.get_mres_read_handle(coalescence_factor)
+ bc_mask_pn = loader.get_mres_read_handle(bc_mask)
+
+ _c = self.velocity_set.c
+ _w = self.velocity_set.w
+
+ @wp.func
+ def compute(index: Any):
+ _boundary_id = wp.neon_read(bc_mask_pn, index, 0)
+ if _boundary_id == wp.uint8(BC_SOLID):
+ return
+
+ if wp.neon_has_child(coalescence_factor_pn, index):
+ # we are a halo cell so we just exit
+ return
+
+ for l in range(self.velocity_set.q):
+ if l == lattice_central_index:
+ continue
+
+ pull_direction = wp.neon_ngh_idx(wp.int8(-_c[0, l]), wp.int8(-_c[1, l]), wp.int8(-_c[2, l]))
+
+ has_ngh_at_same_level = wp.bool(False)
+ coalescence_factor = self.compute_dtype(
+ wp.neon_read_ngh(coalescence_factor_pn, index, pull_direction, l, self.store_dtype(0), has_ngh_at_same_level)
+ )
+
+ if not wp.neon_has_finer_ngh(coalescence_factor_pn, index, pull_direction):
+ pass
+ else:
+ # Finer neighbour exists in the pull direction (opposite of l).
+ # Read from the halo sitting on top of that finer neighbour.
+ if has_ngh_at_same_level:
+ # Finer ngh in pull direction: YES
+ # Same-level ngh: YES
+ # Compute coalescence factor
+ if coalescence_factor > self.compute_dtype(0):
+ coalescence_factor = self.compute_dtype(1) / (self.compute_dtype(2) * coalescence_factor)
+ wp.neon_write(coalescence_factor_pn, index, l, self.store_dtype(coalescence_factor))
+
+ loader.declare_kernel(compute)
+
+ return loading
+
+ for level in range(num_levels):
+ sum_kernel = invert_count(level)
+ sum_kernel.run(0)
+ return
+
+ @classmethod
+ def _process_boundary_conditions(cls, boundary_conditions, f_1, bc_mask, missing_mask):
+ """Process boundary conditions and update boundary masks."""
+
+ # Check for boundary condition overlaps
+ # TODO! check_bc_overlaps(boundary_conditions, DefaultConfig.velocity_set.d, DefaultConfig.default_backend)
+
+ # Create boundary maskers
+ indices_masker = MultiresIndicesBoundaryMasker(
+ velocity_set=DefaultConfig.velocity_set,
+ precision_policy=DefaultConfig.default_precision_policy,
+ compute_backend=DefaultConfig.default_backend,
+ )
+
+ # Split boundary conditions by type
+ bc_with_vertices = [bc for bc in boundary_conditions if bc.mesh_vertices is not None]
+ bc_with_indices = [bc for bc in boundary_conditions if bc.indices is not None]
+
+ # Process indices-based boundary conditions
+ if bc_with_indices:
+ bc_mask, missing_mask = indices_masker(bc_with_indices, bc_mask, missing_mask)
+
+ # Process mesh-based boundary conditions for 3D
+ if DefaultConfig.velocity_set.d == 3 and bc_with_vertices:
+ for bc in bc_with_vertices:
+ if bc.voxelization_method.id is MeshVoxelizationMethod("AABB").id:
+ mesh_masker = MultiresMeshMaskerAABB(
+ velocity_set=DefaultConfig.velocity_set,
+ precision_policy=DefaultConfig.default_precision_policy,
+ compute_backend=DefaultConfig.default_backend,
+ )
+ elif bc.voxelization_method.id is MeshVoxelizationMethod("RAY").id:
+ mesh_masker = MultiresMeshMaskerRay(
+ velocity_set=DefaultConfig.velocity_set,
+ precision_policy=DefaultConfig.default_precision_policy,
+ compute_backend=DefaultConfig.default_backend,
+ )
+ elif bc.voxelization_method.id is MeshVoxelizationMethod("AABB_CLOSE").id:
+ mesh_masker = MultiresMeshMaskerAABBClose(
+ velocity_set=DefaultConfig.velocity_set,
+ precision_policy=DefaultConfig.default_precision_policy,
+ compute_backend=DefaultConfig.default_backend,
+ close_voxels=bc.voxelization_method.options.get("close_voxels"),
+ )
+ else:
+ raise ValueError(f"Unsupported voxelization method for multi-res: {bc.voxelization_method}")
+ # Apply the mesh masker to the boundary condition
+ f_1, bc_mask, missing_mask = mesh_masker(bc, f_1, bc_mask, missing_mask)
+
+ return f_1, bc_mask, missing_mask
+
+ @staticmethod
+ def _initialize_auxiliary_data(boundary_conditions, f_1, bc_mask, missing_mask):
+ """Initialize auxiliary data for boundary conditions that require it."""
+ for bc in boundary_conditions:
+ if bc.needs_aux_init and not bc.is_initialized_with_aux_data:
+ # Create the encoder operator for storing the auxiliary data
+ encode_auxiliary_data = MultiresEncodeAuxiliaryData(
+ bc.id,
+ bc.num_of_aux_data,
+ bc.profile,
+ velocity_set=bc.velocity_set,
+ precision_policy=bc.precision_policy,
+ compute_backend=bc.compute_backend,
+ )
+
+ # Encode the auxiliary data in f_1
+ f_1 = encode_auxiliary_data(f_1, bc_mask, missing_mask, stream=0)
+ bc.is_initialized_with_aux_data = True
+ return f_1
+
+ def _construct_neon(self):
+ # Pre-capture self attributes that Warp cannot resolve inside @wp.func bodies.
+ # Warp rejects `self` as an "Invalid external reference type" when it appears
+ # in a plain assignment (e.g. `_c = self.velocity_set.c`). Capturing here
+ # makes these values available as simple closure variables.
+ lattice_central_index = self.velocity_set.center_index
+ _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
+ _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8)
+ _opp_indices = self.velocity_set.opp_indices
+ _c = self.velocity_set.c
+
+ # Read the list of bc_to_id created upon instantiation
+ bc_to_id = boundary_condition_registry.bc_to_id
+
+ # Gather IDs of ExtrapolationOutflowBC boundary conditions
+ extrapolation_outflow_bc_ids = []
+ for bc_name, bc_id in bc_to_id.items():
+ if bc_name.startswith("ExtrapolationOutflowBC"):
+ extrapolation_outflow_bc_ids.append(bc_id)
+
+ # Factory for apply_bc: generates compile-time specialized variants
+ def make_apply_bc(is_post_streaming: bool):
+ @wp.func
+ def apply_bc_impl(
+ index: Any,
+ timestep: Any,
+ _boundary_id: Any,
+ _missing_mask: Any,
+ f_0: Any,
+ f_1: Any,
+ f_pre: Any,
+ f_post: Any,
+ ):
+ f_result = f_post
+
+ for i in range(wp.static(len(self.boundary_conditions))):
+ if wp.static(is_post_streaming):
+ if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.STREAMING):
+ if _boundary_id == wp.static(self.boundary_conditions[i].id):
+ f_result = wp.static(self.boundary_conditions[i].neon_functional)(
+ index, timestep, _missing_mask, f_0, f_1, f_pre, f_post
+ )
+ else:
+ if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.COLLISION):
+ if _boundary_id == wp.static(self.boundary_conditions[i].id):
+ f_result = wp.static(self.boundary_conditions[i].neon_functional)(
+ index, timestep, _missing_mask, f_0, f_1, f_pre, f_post
+ )
+ if wp.static(self.boundary_conditions[i].id in extrapolation_outflow_bc_ids):
+ if _boundary_id == wp.static(self.boundary_conditions[i].id):
+ f_result = wp.static(self.boundary_conditions[i].assemble_auxiliary_data)(
+ index, timestep, _missing_mask, f_0, f_1, f_pre, f_post
+ )
+ return f_result
+
+ return apply_bc_impl
+
+ # Compile-time specialized BC application variants
+ apply_bc_post_streaming = make_apply_bc(is_post_streaming=True)
+ apply_bc_post_collision = make_apply_bc(is_post_streaming=False)
+
+ @wp.func
+ def neon_get_thread_data(
+ f0_pn: Any,
+ missing_mask_pn: Any,
+ index: Any,
+ ):
+ # Read thread data for populations
+ _f0_thread = _f_vec()
+ _missing_mask = _missing_mask_vec()
+ for l in range(self.velocity_set.q):
+ # q-sized vector of pre-streaming populations
+ _f0_thread[l] = self.compute_dtype(wp.neon_read(f0_pn, index, l))
+ _missing_mask[l] = wp.neon_read(missing_mask_pn, index, l)
+
+ return _f0_thread, _missing_mask
+
+ @wp.func
+ def neon_apply_aux_recovery_bc(
+ index: Any,
+ _boundary_id: Any,
+ _missing_mask: Any,
+ f_0_pn: Any,
+ f_1_pn: Any,
+ ):
+ # Note:
+ # In XLB, the BC auxiliary data (e.g. prescribed values of pressure or normal velocity) are stored in (i) central index of f_1 and/or
+ # (ii) missing directions of f_1. Some BCs may or may not need all these available storage space. This function checks whether
+ # the BC needs recovery of auxiliary data and then recovers the information for the next iteration (due to buffer swapping) by
+ # writting the values of f_1 into f_0.
+
+ # Unroll the loop over boundary conditions
+ for i in range(wp.static(len(self.boundary_conditions))):
+ if wp.static(self.boundary_conditions[i].needs_aux_recovery):
+ if _boundary_id == wp.static(self.boundary_conditions[i].id):
+ for l in range(self.velocity_set.q):
+ # Perform the swapping of data
+ if l == lattice_central_index:
+ # (i) Recover the values stored in the central index of f_1
+ _f1_thread = wp.neon_read(f_1_pn, index, l)
+ wp.neon_write(f_0_pn, index, l, self.store_dtype(_f1_thread))
+ elif _missing_mask[l] == wp.uint8(1):
+ # (ii) Recover the values stored in the missing directions of f_1
+ _f1_thread = wp.neon_read(f_1_pn, index, _opp_indices[l])
+ wp.neon_write(f_0_pn, index, _opp_indices[l], self.store_dtype(_f1_thread))
+
+ # Factory for neon_collide_pipeline: generates compile-time specialized variants
+ def make_collide_pipeline(do_bc: bool, do_accumulation: bool):
+ @wp.func
+ def collide_pipeline_impl(
+ index: Any,
+ timestep: Any,
+ _boundary_id: Any,
+ _missing_mask: Any,
+ f_0_pn: Any,
+ f_1_pn: Any,
+ _f_post_stream: Any,
+ omega: Any,
+ num_levels: int,
+ level: int,
+ accumulation_pn: Any,
+ ):
+ _rho, _u = self.macroscopic.neon_functional(_f_post_stream)
+ _feq = self.equilibrium.neon_functional(_rho, _u)
+ _f_post_collision = self.collision.neon_functional(_f_post_stream, _feq, omega)
+
+ if wp.static(do_bc):
+ _f_post_collision = apply_bc_post_collision(
+ index, timestep, _boundary_id, _missing_mask, f_0_pn, f_1_pn, _f_post_stream, _f_post_collision
+ )
+ neon_apply_aux_recovery_bc(index, _boundary_id, _missing_mask, f_0_pn, f_1_pn)
+
+ if wp.static(do_accumulation):
+ for l in range(self.velocity_set.q):
+ push_direction = wp.neon_ngh_idx(wp.int8(_c[0, l]), wp.int8(_c[1, l]), wp.int8(_c[2, l]))
+ if level < num_levels - 1:
+ wp.neon_mres_lbm_store_op(accumulation_pn, index, l, push_direction, self.store_dtype(_f_post_collision[l]))
+ wp.neon_write(f_1_pn, index, l, self.store_dtype(_f_post_collision[l]))
+ else:
+ for l in range(self.velocity_set.q):
+ wp.neon_write(f_1_pn, index, l, self.store_dtype(_f_post_collision[l]))
+
+ return _f_post_collision
+
+ return collide_pipeline_impl
+
+ # Compile-time specialized collision pipeline variants
+ collide_bc_accum = make_collide_pipeline(do_bc=True, do_accumulation=True)
+ collide_bc_only = make_collide_pipeline(do_bc=True, do_accumulation=False)
+ collide_simple = make_collide_pipeline(do_bc=False, do_accumulation=False)
+
+ @wp.func
+ def neon_stream_explode_coalesce(
+ index: Any,
+ f_0_pn: Any,
+ coalescence_factor_pn: Any,
+ ):
+ _f_post_stream = self.stream.neon_functional(f_0_pn, index)
+
+ for l in range(self.velocity_set.q):
+ if l == lattice_central_index:
+ continue
+
+ pull_direction = wp.neon_ngh_idx(wp.int8(-_c[0, l]), wp.int8(-_c[1, l]), wp.int8(-_c[2, l]))
+
+ has_ngh_at_same_level = wp.bool(False)
+ accumulated = wp.neon_read_ngh(f_0_pn, index, pull_direction, l, self.store_dtype(0), has_ngh_at_same_level)
+
+ if not wp.neon_has_finer_ngh(f_0_pn, index, pull_direction):
+ # No finer ngh in the pull direction (opposite of l)
+ if not has_ngh_at_same_level:
+ # No same-level ngh — could we have a coarser-level ngh?
+ if wp.neon_has_parent(f_0_pn, index):
+ # Halo cell on top of us (parent exists)
+ has_a_coarser_ngh = wp.bool(False)
+ exploded_pop = wp.neon_lbm_read_coarser_ngh(f_0_pn, index, pull_direction, l, self.store_dtype(0), has_a_coarser_ngh)
+ if has_a_coarser_ngh:
+ # No finer ngh in pull direction, no same-level ngh,
+ # but a parent (ghost cell) exists with a coarser ngh
+ # -> Explosion: read the exploded population from the
+ # coarser level's halo.
+ _f_post_stream[l] = self.compute_dtype(exploded_pop)
+ else:
+ # Finer ngh exists in the pull direction (opposite of l).
+ # Read from the halo on top of that finer ngh.
+ if has_ngh_at_same_level:
+ # Finer ngh in pull direction: YES
+ # Same-level ngh: YES
+ # -> Coalescence
+ coalescence_factor = wp.neon_read(coalescence_factor_pn, index, l)
+ accumulated = accumulated * coalescence_factor
+ _f_post_stream[l] = self.compute_dtype(accumulated)
+
+ return _f_post_stream
+
+ @neon.Container.factory(name="collide_coarse")
+ def collide_coarse(level: int, f_0_fd: Any, f_1_fd: Any, bc_mask_fd: Any, missing_mask_fd: Any, omega: Any, timestep: int):
+ num_levels = f_0_fd.get_grid().num_levels
+
+ def ll(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask_fd.get_grid(), level)
+ if level + 1 < f_0_fd.get_grid().num_levels:
+ f_0_pn = loader.get_mres_write_handle(f_0_fd, neon.Loader.Operation.stencil_up)
+ f_1_pn = loader.get_mres_write_handle(f_1_fd, neon.Loader.Operation.stencil_up)
+ else:
+ f_0_pn = loader.get_mres_read_handle(f_0_fd)
+ f_1_pn = loader.get_mres_write_handle(f_1_fd)
+ bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd)
+ missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd)
+
+ @wp.func
+ def device(index: Any):
+ _boundary_id = wp.neon_read(bc_mask_pn, index, 0)
+ if _boundary_id == wp.uint8(BC_SOLID):
+ return
+ if not wp.neon_has_child(f_0_pn, index):
+ _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index)
+ collide_bc_accum(
+ index,
+ timestep,
+ _boundary_id,
+ _missing_mask,
+ f_0_pn,
+ f_1_pn,
+ _f0_thread,
+ omega,
+ num_levels,
+ level,
+ f_1_pn,
+ )
+ else:
+ for l in range(self.velocity_set.q):
+ wp.neon_write(f_1_pn, index, l, self.store_dtype(0))
+
+ loader.declare_kernel(device)
+
+ return ll
+
+ @neon.Container.factory(name="SFV_collide_coarse")
+ def SFV_collide_coarse(level: int, f_0_fd: Any, f_1_fd: Any, bc_mask_fd: Any, missing_mask_fd: Any, omega: Any, timestep: int):
+ """Collision on SFV voxels only — no BCs, no multi-resolution accumulation."""
+
+ def ll(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask_fd.get_grid(), level)
+ f_0_pn = loader.get_mres_read_handle(f_0_fd)
+ f_1_pn = loader.get_mres_write_handle(f_1_fd)
+ bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd)
+ missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd)
+
+ @wp.func
+ def device(index: Any):
+ _boundary_id = wp.neon_read(bc_mask_pn, index, 0)
+ if _boundary_id != wp.uint8(BC_SFV):
+ return
+ _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index)
+ collide_simple(
+ index,
+ 0,
+ _boundary_id,
+ _missing_mask,
+ f_0_pn,
+ f_1_pn,
+ _f0_thread,
+ omega,
+ 0,
+ level,
+ f_1_pn,
+ )
+
+ loader.declare_kernel(device)
+
+ return ll
+
+ @neon.Container.factory(name="CFV_collide_coarse")
+ def CFV_collide_coarse(level: int, f_0_fd: Any, f_1_fd: Any, bc_mask_fd: Any, missing_mask_fd: Any, omega: Any, timestep: int):
+ """Collision on CFV voxels only — skips both solid and SFV."""
+ num_levels = f_0_fd.get_grid().num_levels
+
+ def ll(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask_fd.get_grid(), level)
+ if level + 1 < f_0_fd.get_grid().num_levels:
+ f_0_pn = loader.get_mres_write_handle(f_0_fd, neon.Loader.Operation.stencil_up)
+ f_1_pn = loader.get_mres_write_handle(f_1_fd, neon.Loader.Operation.stencil_up)
+ else:
+ f_0_pn = loader.get_mres_read_handle(f_0_fd)
+ f_1_pn = loader.get_mres_write_handle(f_1_fd)
+ bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd)
+ missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd)
+
+ @wp.func
+ def device(index: Any):
+ _boundary_id = wp.neon_read(bc_mask_pn, index, 0)
+ if _boundary_id == wp.uint8(BC_SOLID):
+ return
+ if _boundary_id == wp.uint8(BC_SFV):
+ return
+ if not wp.neon_has_child(f_0_pn, index):
+ _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index)
+ collide_bc_accum(
+ index,
+ timestep,
+ _boundary_id,
+ _missing_mask,
+ f_0_pn,
+ f_1_pn,
+ _f0_thread,
+ omega,
+ num_levels,
+ level,
+ f_1_pn,
+ )
+ else:
+ for l in range(self.velocity_set.q):
+ wp.neon_write(f_1_pn, index, l, self.store_dtype(0))
+
+ loader.declare_kernel(device)
+
+ return ll
+
+ @neon.Container.factory(name="stream_coarse_step_ABC")
+ def stream_coarse_step_ABC(level: int, f_0_fd: Any, f_1_fd: Any, bc_mask_fd: Any, missing_mask_fd: Any, omega: Any, timestep: int):
+ def ll(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask_fd.get_grid(), level)
+ f_0_pn = loader.get_mres_read_handle(f_0_fd)
+ f_1_pn = loader.get_mres_write_handle(f_1_fd)
+ bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd)
+ missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd)
+ coalescence_factor_pn = loader.get_mres_read_handle(omega)
+
+ @wp.func
+ def device(index: Any):
+ _boundary_id = wp.neon_read(bc_mask_pn, index, 0)
+ if _boundary_id == wp.uint8(BC_SOLID):
+ return
+ if wp.neon_has_child(f_0_pn, index):
+ return
+
+ _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index)
+ _f_post_collision = _f0_thread
+ _f_post_stream = neon_stream_explode_coalesce(index, f_0_pn, coalescence_factor_pn)
+
+ _f_post_stream = apply_bc_post_streaming(
+ index, timestep, _boundary_id, _missing_mask, f_0_pn, f_1_pn, _f_post_collision, _f_post_stream
+ )
+ neon_apply_aux_recovery_bc(index, _boundary_id, _missing_mask, f_0_pn, f_1_pn)
+
+ for l in range(self.velocity_set.q):
+ wp.neon_write(f_1_pn, index, l, self.store_dtype(_f_post_stream[l]))
+
+ loader.declare_kernel(device)
+
+ return ll
+
+ @neon.Container.factory(name="SFV_stream_coarse_step_ABC")
+ def SFV_stream_coarse_step_ABC(level: int, f_0_fd: Any, f_1_fd: Any, bc_mask_fd: Any, missing_mask_fd: Any, omega: Any, timestep: int):
+ """Stream on CFV voxels only — skips SFV and solid."""
+
+ def ll(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask_fd.get_grid(), level)
+ f_0_pn = loader.get_mres_read_handle(f_0_fd)
+ f_1_pn = loader.get_mres_write_handle(f_1_fd)
+ bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd)
+ missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd)
+ coalescence_factor_pn = loader.get_mres_read_handle(omega)
+
+ @wp.func
+ def device(index: Any):
+ _boundary_id = wp.neon_read(bc_mask_pn, index, 0)
+ if _boundary_id == wp.uint8(BC_SFV):
+ return
+ if _boundary_id == wp.uint8(BC_SOLID):
+ return
+ if wp.neon_has_child(f_0_pn, index):
+ return
+
+ _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index)
+ _f_post_collision = _f0_thread
+ _f_post_stream = neon_stream_explode_coalesce(index, f_0_pn, coalescence_factor_pn)
+
+ _f_post_stream = apply_bc_post_streaming(
+ index, timestep, _boundary_id, _missing_mask, f_0_pn, f_1_pn, _f_post_collision, _f_post_stream
+ )
+ neon_apply_aux_recovery_bc(index, _boundary_id, _missing_mask, f_0_pn, f_1_pn)
+
+ for l in range(self.velocity_set.q):
+ wp.neon_write(f_1_pn, index, l, self.store_dtype(_f_post_stream[l]))
+
+ loader.declare_kernel(device)
+
+ return ll
+
+ @neon.Container.factory(name="SFV_reset_bc_mask")
+ def SFV_reset_bc_mask(
+ level: int,
+ f_0_fd: Any,
+ f_1_fd: Any,
+ bc_mask_fd: Any,
+ missing_mask_fd: Any,
+ ):
+ """
+ Setting the BC type to BC_SFV
+ """
+
+ def ll_stream_coarse(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask_fd.get_grid(), level)
+
+ f_0_pn = loader.get_mres_read_handle(f_0_fd)
+
+ bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd)
+ missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd)
+
+ _c = self.velocity_set.c
+
+ @wp.func
+ def cl_stream_coarse(index: Any):
+ _boundary_id = wp.neon_read(bc_mask_pn, index, 0)
+ if _boundary_id == wp.uint8(BC_SOLID):
+ return
+ if _boundary_id != 0:
+ return
+
+ if wp.neon_has_child(f_0_pn, index):
+ # we are a halo cell so we just exit
+ return
+
+ # do stream normally
+ _missing_mask = _missing_mask_vec()
+ _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index)
+ _f_post_collision = _f0_thread
+ _f_post_stream = self.stream.neon_functional(f_0_pn, index)
+
+ for l in range(self.velocity_set.q):
+ if l == lattice_central_index:
+ continue
+
+ pull_direction = wp.neon_ngh_idx(wp.int8(-_c[0, l]), wp.int8(-_c[1, l]), wp.int8(-_c[2, l]))
+
+ has_ngh_at_same_level = wp.bool(False)
+ wp.neon_read_ngh(f_0_pn, index, pull_direction, l, self.store_dtype(0), has_ngh_at_same_level)
+
+ if not wp.neon_has_finer_ngh(f_0_pn, index, pull_direction):
+ if not has_ngh_at_same_level:
+ if wp.neon_has_parent(f_0_pn, index):
+ has_a_coarser_ngh = wp.bool(False)
+ wp.neon_lbm_read_coarser_ngh(f_0_pn, index, pull_direction, l, self.store_dtype(0), has_a_coarser_ngh)
+ if has_a_coarser_ngh:
+ # Explosion: not an SFV
+ return
+ else:
+ if has_ngh_at_same_level:
+ # Coalescence: not an SFV
+ return
+
+ # Voxel is a pure fluid cell with no multi-resolution interactions — mark as SFV
+ wp.neon_write(bc_mask_pn, index, 0, wp.uint8(BC_SFV))
+
+ loader.declare_kernel(cl_stream_coarse)
+
+ return ll_stream_coarse
+
+ @neon.Container.factory(name="SFV_stream_coarse_step")
+ def SFV_stream_coarse_step(level: int, f_0_fd: Any, f_1_fd: Any, bc_mask_fd: Any, missing_mask_fd: Any):
+ def ll_stream_coarse(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask_fd.get_grid(), level)
+
+ f_0_pn = loader.get_mres_read_handle(f_0_fd)
+ f_1_pn = loader.get_mres_write_handle(f_1_fd)
+
+ bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd)
+ missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd)
+
+ _c = self.velocity_set.c
+
+ @wp.func
+ def cl_stream_coarse(index: Any):
+ _boundary_id = wp.neon_read(bc_mask_pn, index, 0)
+ if _boundary_id != wp.uint8(BC_SFV):
+ return
+ # BC_SFV voxel type:
+ # - They are not BC voxels
+ # - They are not on a resolution jump -> they do not do coalescence or explosion
+ # - They are not mr halo cells
+
+ _missing_mask = _missing_mask_vec()
+ _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index)
+ _f_post_collision = _f0_thread
+ _f_post_stream = self.stream.neon_functional(f_0_pn, index)
+
+ for l in range(self.velocity_set.q):
+ wp.neon_write(f_1_pn, index, l, self.store_dtype(_f_post_stream[l]))
+
+ loader.declare_kernel(cl_stream_coarse)
+
+ return ll_stream_coarse
+
+ @wp.func
+ def neon_stream_finest_with_explosion(
+ index: Any,
+ f_0_pn: Any,
+ explosion_src_pn: Any,
+ ):
+ _f_post_stream = self.stream.neon_functional(f_0_pn, index)
+
+ for l in range(self.velocity_set.q):
+ if l == lattice_central_index:
+ continue
+
+ pull_direction = wp.neon_ngh_idx(wp.int8(-_c[0, l]), wp.int8(-_c[1, l]), wp.int8(-_c[2, l]))
+
+ has_ngh_at_same_level = wp.bool(False)
+ wp.neon_read_ngh(f_0_pn, index, pull_direction, l, self.store_dtype(0), has_ngh_at_same_level)
+
+ if not has_ngh_at_same_level:
+ # No same-level ngh — could we have a coarser-level ngh?
+ if wp.neon_has_parent(f_0_pn, index):
+ # Parent exists — try to read the exploded population from the coarser level
+ has_a_coarser_ngh = wp.bool(False)
+ exploded_pop = wp.neon_lbm_read_coarser_ngh(
+ explosion_src_pn, index, pull_direction, l, self.store_dtype(0), has_a_coarser_ngh
+ )
+ if has_a_coarser_ngh:
+ # No finer ngh in pull direction, no same-level ngh,
+ # but a parent (ghost cell) exists with a coarser ngh
+ # -> Explosion: read the exploded population from the
+ # coarser level's halo.
+ _f_post_stream[l] = self.compute_dtype(exploded_pop)
+
+ return _f_post_stream
+
+ @neon.Container.factory(name="finest_fused_pull")
+ def finest_fused_pull(
+ level: int,
+ f_0_fd: Any,
+ f_1_fd: Any,
+ bc_mask_fd: Any,
+ missing_mask_fd: Any,
+ omega: Any,
+ timestep: Any,
+ is_f1_the_explosion_src_field: bool,
+ ):
+ if level != 0:
+ raise Exception("Only the finest level is supported for now")
+ num_levels = f_0_fd.get_grid().num_levels
+
+ def ll(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask_fd.get_grid(), level)
+ if level + 1 < f_0_fd.get_grid().num_levels:
+ f_0_pn = loader.get_mres_write_handle(f_0_fd, neon.Loader.Operation.stencil_up)
+ f_1_pn = loader.get_mres_write_handle(f_1_fd, neon.Loader.Operation.stencil_up)
+ else:
+ f_0_pn = loader.get_mres_read_handle(f_0_fd)
+ f_1_pn = loader.get_mres_write_handle(f_1_fd)
+ bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd)
+ missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd)
+ explosion_src_pn = f_1_pn if is_f1_the_explosion_src_field else f_0_pn
+ accumulation_pn = f_1_pn if is_f1_the_explosion_src_field else f_0_pn
+
+ @wp.func
+ def device(index: Any):
+ _boundary_id = wp.neon_read(bc_mask_pn, index, 0)
+ if _boundary_id == wp.uint8(BC_SOLID):
+ return
+ if wp.neon_has_child(f_0_pn, index):
+ return
+
+ _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index)
+ _f_post_collision = _f0_thread
+ _f_post_stream = neon_stream_finest_with_explosion(index, f_0_pn, explosion_src_pn)
+
+ _f_post_stream = apply_bc_post_streaming(
+ index, timestep, _boundary_id, _missing_mask, f_0_pn, f_1_pn, _f_post_collision, _f_post_stream
+ )
+
+ collide_bc_accum(
+ index,
+ timestep,
+ _boundary_id,
+ _missing_mask,
+ f_0_pn,
+ f_1_pn,
+ _f_post_stream,
+ omega,
+ num_levels,
+ level,
+ accumulation_pn,
+ )
+
+ loader.declare_kernel(device)
+
+ return ll
+
+ @neon.Container.factory(name="CFV_finest_fused_pull")
+ def CFV_finest_fused_pull(
+ level: int,
+ f_0_fd: Any,
+ f_1_fd: Any,
+ bc_mask_fd: Any,
+ missing_mask_fd: Any,
+ omega: Any,
+ timestep: Any,
+ is_f1_the_explosion_src_field: bool,
+ ):
+ """Fused stream+collide on CFV voxels at the finest level — skips SFV and solid."""
+ if level != 0:
+ raise Exception("Only the finest level is supported for now")
+ num_levels = f_0_fd.get_grid().num_levels
+
+ def ll(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask_fd.get_grid(), level)
+ if level + 1 < f_0_fd.get_grid().num_levels:
+ f_0_pn = loader.get_mres_write_handle(f_0_fd, neon.Loader.Operation.stencil_up)
+ f_1_pn = loader.get_mres_write_handle(f_1_fd, neon.Loader.Operation.stencil_up)
+ else:
+ f_0_pn = loader.get_mres_read_handle(f_0_fd)
+ f_1_pn = loader.get_mres_write_handle(f_1_fd)
+ bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd)
+ missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd)
+ explosion_src_pn = f_1_pn if is_f1_the_explosion_src_field else f_0_pn
+ accumulation_pn = f_1_pn if is_f1_the_explosion_src_field else f_0_pn
+
+ @wp.func
+ def device(index: Any):
+ _boundary_id = wp.neon_read(bc_mask_pn, index, 0)
+ if _boundary_id == wp.uint8(BC_SOLID):
+ return
+ if _boundary_id == wp.uint8(BC_SFV):
+ return
+ if wp.neon_has_child(f_0_pn, index):
+ return
+
+ _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index)
+ _f_post_collision = _f0_thread
+ _f_post_stream = neon_stream_finest_with_explosion(index, f_0_pn, explosion_src_pn)
+
+ _f_post_stream = apply_bc_post_streaming(
+ index, timestep, _boundary_id, _missing_mask, f_0_pn, f_1_pn, _f_post_collision, _f_post_stream
+ )
+
+ collide_bc_accum(
+ index,
+ timestep,
+ _boundary_id,
+ _missing_mask,
+ f_0_pn,
+ f_1_pn,
+ _f_post_stream,
+ omega,
+ num_levels,
+ level,
+ accumulation_pn,
+ )
+
+ loader.declare_kernel(device)
+
+ return ll
+
+ @neon.Container.factory(name="SFV_finest_fused_pull")
+ def SFV_finest_fused_pull(level: int, f_0_fd: Any, f_1_fd: Any, bc_mask_fd: Any, missing_mask_fd: Any, omega: Any):
+ """Fused stream+collide on SFV voxels at the finest level — no BCs, no explosion."""
+ if level != 0:
+ raise Exception("Only the finest level is supported for now")
+
+ def ll(loader: neon.Loader):
+ loader.set_mres_grid(bc_mask_fd.get_grid(), level)
+ f_0_pn = loader.get_mres_read_handle(f_0_fd)
+ f_1_pn = loader.get_mres_write_handle(f_1_fd)
+ bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd)
+ missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd)
+
+ @wp.func
+ def device(index: Any):
+ _boundary_id = wp.neon_read(bc_mask_pn, index, 0)
+ if _boundary_id != wp.uint8(BC_SFV):
+ return
+ _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index)
+ _f_post_stream = self.stream.neon_functional(f_0_pn, index)
+ collide_simple(
+ index,
+ 0,
+ _boundary_id,
+ _missing_mask,
+ f_0_pn,
+ f_1_pn,
+ _f_post_stream,
+ omega,
+ 0,
+ 0,
+ f_1_pn,
+ )
+
+ loader.declare_kernel(device)
+
+ return ll
+
+ return None, {
+ "collide_coarse": collide_coarse,
+ "stream_coarse_step_ABC": stream_coarse_step_ABC,
+ "finest_fused_pull": finest_fused_pull,
+ "CFV_finest_fused_pull": CFV_finest_fused_pull,
+ "SFV_finest_fused_pull": SFV_finest_fused_pull,
+ "SFV_reset_bc_mask": SFV_reset_bc_mask,
+ "CFV_collide_coarse": CFV_collide_coarse,
+ "SFV_collide_coarse": SFV_collide_coarse,
+ "SFV_stream_coarse_step_ABC": SFV_stream_coarse_step_ABC,
+ "SFV_stream_coarse_step": SFV_stream_coarse_step,
+ }
+
+ def launch_container(self, streamId, op_name, mres_level, f_0, f_1, bc_mask, missing_mask, omega, timestep):
+ """Immediately launch a single Neon container by name.
+
+ Parameters
+ ----------
+ streamId : int
+ CUDA stream index.
+ op_name : str
+ Key into the container dictionary returned by ``_construct_neon``.
+ mres_level : int
+ Grid level to execute on.
+ f_0, f_1 : field
+ Double-buffered distribution-function fields.
+ bc_mask, missing_mask : field
+ Boundary condition and missing-population masks.
+ omega : float
+ Relaxation parameter at this level.
+ timestep : int
+ Current simulation timestep.
+ """
+ self.neon_container[op_name](mres_level, f_0, f_1, bc_mask, missing_mask, omega, timestep).run(0)
+
+ def add_to_app(self, **kwargs):
+ """Append a container invocation to the Neon skeleton application list.
+
+ Required keyword arguments are ``op_name`` (str) and ``app`` (list).
+ All remaining keyword arguments are forwarded to the container
+ factory for the given ``op_name``. Argument validation is performed
+ before the call, and a ``ValueError`` is raised on mismatch.
+ """
+ import inspect
+
+ def validate_kwargs_forward(func, kwargs):
+ """
+ Check whether `func(**kwargs)` would be valid,
+ and return *all* the issues instead of raising on the first one.
+
+ Returns a dict; empty dict means "everything is OK".
+ """
+ sig = inspect.signature(func)
+ params = sig.parameters
+
+ errors = {}
+
+ # --- 1. Positional-only required params (cannot be given via kwargs) ---
+ pos_only_required = [name for name, p in params.items() if p.kind == inspect.Parameter.POSITIONAL_ONLY and p.default is inspect._empty]
+ if pos_only_required:
+ errors["positional_only_required"] = pos_only_required
+
+ # --- 2. Unexpected kwargs (if no **kwargs in target) ---
+ has_var_kw = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())
+ if not has_var_kw:
+ allowed_kw = {
+ name
+ for name, p in params.items()
+ if p.kind
+ in (
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ inspect.Parameter.KEYWORD_ONLY,
+ )
+ }
+ unexpected = sorted(set(kwargs) - allowed_kw)
+ if unexpected:
+ errors["unexpected_kwargs"] = unexpected
+
+ # --- 3. Missing required keyword-passable params ---
+ missing_required = [
+ name
+ for name, p in params.items()
+ if p.kind
+ in (
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ inspect.Parameter.KEYWORD_ONLY,
+ )
+ and p.default is inspect._empty # no default
+ and name not in kwargs # not provided
+ ]
+ if missing_required:
+ errors["missing_required"] = missing_required
+
+ return errors
+
+ container_generator = None
+ try:
+ op_name = kwargs.pop("op_name")
+ app = kwargs.pop("app")
+ except:
+ raise ValueError("op_name and app must be provided as keyword arguments")
+
+ try:
+ container_generator = self.neon_container[op_name]
+ except KeyError:
+ raise ValueError(f"Operator {op_name} not found in neon container. Available operators: {list(self.neon_container.keys())}")
+
+ errors = validate_kwargs_forward(container_generator, kwargs)
+ if errors:
+ raise ValueError(f"Cannot forward kwargs to target: {errors}")
+
+ nvtx.push_range(f"New Container {op_name}", color="yellow")
+ app.append(container_generator(**kwargs))
+ nvtx.pop_range()
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_launch(self, f_0, f_1, bc_mask, missing_mask, omega, timestep):
+ """Execute a single LBM step through the Neon backend (direct launch)."""
+ c = self.neon_container(f_0, f_1, bc_mask, missing_mask, omega, timestep)
+ c.run(0)
+ return f_0, f_1
diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py
index ef0ff412..475958ac 100644
--- a/xlb/operator/stepper/nse_stepper.py
+++ b/xlb/operator/stepper/nse_stepper.py
@@ -1,6 +1,13 @@
-# Base class for all stepper operators
+"""
+Single-resolution incompressible Navier-Stokes stepper.
+
+Implements the full LBM step (stream, collide, apply BCs) for a single-
+resolution grid. Supports pull and push streaming schemes on JAX, a
+pull-only fused kernel on Warp, and a pull-only Neon container.
+"""
from functools import partial
+
from jax import jit
import warp as wp
from typing import Any
@@ -17,12 +24,44 @@
from xlb.operator.boundary_condition.boundary_condition import ImplementationStep
from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry
from xlb.operator.collision import ForcedCollision
-from xlb.operator.boundary_masker import IndicesBoundaryMasker, MeshBoundaryMasker
+from xlb.operator.boundary_masker import (
+ IndicesBoundaryMasker,
+ MeshVoxelizationMethod,
+ MeshMaskerAABB,
+ MeshMaskerRay,
+ MeshMaskerWinding,
+ MeshMaskerAABBClose,
+)
from xlb.helper import check_bc_overlaps
-from xlb.helper.nse_solver import create_nse_fields
+from xlb.helper.nse_fields import create_nse_fields
+from xlb.operator.boundary_condition.helper_functions_bc import EncodeAuxiliaryData
+from xlb.cell_type import BC_SOLID
class IncompressibleNavierStokesStepper(Stepper):
+ """Single-resolution incompressible Navier-Stokes LBM stepper.
+
+ Composes streaming, collision, equilibrium, macroscopic, and boundary-
+ condition operators into a complete timestep.
+
+ Parameters
+ ----------
+ grid : Grid
+ Computational grid.
+ boundary_conditions : list of BoundaryCondition
+ Boundary conditions to apply each step.
+ collision_type : str
+ ``"BGK"``, ``"KBC"``, or ``"SmagorinskyLESBGK"``.
+ streaming_scheme : str
+ ``"pull"`` (default) or ``"push"`` (JAX only).
+ forcing_scheme : str
+ Forcing scheme name (used when *force_vector* is given).
+ force_vector : array-like, optional
+ External body force vector.
+ backend_config : dict
+ Backend-specific options (e.g. Neon OCC configuration).
+ """
+
def __init__(
self,
grid,
@@ -31,8 +70,10 @@ def __init__(
streaming_scheme="pull",
forcing_scheme="exact_difference",
force_vector=None,
+ backend_config={},
):
super().__init__(grid, boundary_conditions)
+ self.backend_config = backend_config
# Construct the collision operator
if collision_type == "BGK":
@@ -76,63 +117,112 @@ def prepare_fields(self, initializer=None):
grid=self.grid, velocity_set=self.velocity_set, compute_backend=self.compute_backend, precision_policy=self.precision_policy
)
- # Initialize distribution functions if initializer is provided
- if initializer is not None:
- f_0 = initializer(self.grid, self.velocity_set, self.precision_policy, self.compute_backend)
- else:
- from xlb.helper.initializers import initialize_eq
-
- f_0 = initialize_eq(f_0, self.grid, self.velocity_set, self.precision_policy, self.compute_backend)
-
# Copy f_0 using backend-specific copy to f_1
if self.compute_backend == ComputeBackend.JAX:
f_1 = f_0.copy()
- else:
+ if self.compute_backend == ComputeBackend.WARP:
wp.copy(f_1, f_0)
+ if self.compute_backend == ComputeBackend.NEON:
+ f_1.copy_from_run(f_0, 0)
+ # Important note: XLB uses f_1 buffer (center index and missing directions) to store auxiliary data for boundary conditions.
# Process boundary conditions and update masks
- bc_mask, missing_mask = self._process_boundary_conditions(self.boundary_conditions, bc_mask, missing_mask)
+ f_1, bc_mask, missing_mask = self._process_boundary_conditions(self.boundary_conditions, f_1, bc_mask, missing_mask)
+
# Initialize auxiliary data if needed
- f_0, f_1 = self._initialize_auxiliary_data(self.boundary_conditions, f_0, f_1, bc_mask, missing_mask)
+ f_1 = self._initialize_auxiliary_data(self.boundary_conditions, f_1, bc_mask, missing_mask)
+ # bc_mask.update_host(0)
+ # missing_mask.update_host(0)
+ wp.synchronize()
+ # bc_mask.export_vti("bc_mask.vti", 'bc_mask')
+ # missing_mask.export_vti("missing_mask.vti", 'missing_mask')
+
+ # Initialize distribution functions if initializer is provided
+ if initializer is not None:
+ f_0 = initializer(bc_mask, f_0)
+ else:
+ from xlb.helper.initializers import initialize_eq
+
+ f_0 = initialize_eq(f_0, self.grid, self.velocity_set, self.precision_policy, self.compute_backend)
return f_0, f_1, bc_mask, missing_mask
- @classmethod
- def _process_boundary_conditions(cls, boundary_conditions, bc_mask, missing_mask):
+ def _process_boundary_conditions(self, boundary_conditions, f_1, bc_mask, missing_mask):
"""Process boundary conditions and update boundary masks."""
+
# Check for boundary condition overlaps
check_bc_overlaps(boundary_conditions, DefaultConfig.velocity_set.d, DefaultConfig.default_backend)
+
# Create boundary maskers
indices_masker = IndicesBoundaryMasker(
velocity_set=DefaultConfig.velocity_set,
precision_policy=DefaultConfig.default_precision_policy,
compute_backend=DefaultConfig.default_backend,
+ grid=self.grid,
)
+
# Split boundary conditions by type
bc_with_vertices = [bc for bc in boundary_conditions if bc.mesh_vertices is not None]
bc_with_indices = [bc for bc in boundary_conditions if bc.indices is not None]
+
# Process indices-based boundary conditions
if bc_with_indices:
bc_mask, missing_mask = indices_masker(bc_with_indices, bc_mask, missing_mask)
+
# Process mesh-based boundary conditions for 3D
if DefaultConfig.velocity_set.d == 3 and bc_with_vertices:
- mesh_masker = MeshBoundaryMasker(
- velocity_set=DefaultConfig.velocity_set,
- precision_policy=DefaultConfig.default_precision_policy,
- compute_backend=DefaultConfig.default_backend,
- )
for bc in bc_with_vertices:
- bc_mask, missing_mask = mesh_masker(bc, bc_mask, missing_mask)
+ if bc.voxelization_method.id is MeshVoxelizationMethod("AABB").id:
+ mesh_masker = MeshMaskerAABB(
+ velocity_set=DefaultConfig.velocity_set,
+ precision_policy=DefaultConfig.default_precision_policy,
+ compute_backend=DefaultConfig.default_backend,
+ )
+ elif bc.voxelization_method.id is MeshVoxelizationMethod("RAY").id:
+ mesh_masker = MeshMaskerRay(
+ velocity_set=DefaultConfig.velocity_set,
+ precision_policy=DefaultConfig.default_precision_policy,
+ compute_backend=DefaultConfig.default_backend,
+ )
+ elif bc.voxelization_method.id is MeshVoxelizationMethod("WINDING").id:
+ mesh_masker = MeshMaskerWinding(
+ velocity_set=DefaultConfig.velocity_set,
+ precision_policy=DefaultConfig.default_precision_policy,
+ compute_backend=DefaultConfig.default_backend,
+ )
+ elif bc.voxelization_method.id is MeshVoxelizationMethod("AABB_CLOSE").id:
+ mesh_masker = MeshMaskerAABBClose(
+ velocity_set=DefaultConfig.velocity_set,
+ precision_policy=DefaultConfig.default_precision_policy,
+ compute_backend=DefaultConfig.default_backend,
+ close_voxels=bc.voxelization_method.options.get("close_voxels"),
+ )
+ else:
+ raise ValueError(f"Unsupported voxelization method: {bc.voxelization_method}")
+ # Apply the mesh masker to the boundary condition
+ f_1, bc_mask, missing_mask = mesh_masker(bc, f_1, bc_mask, missing_mask)
- return bc_mask, missing_mask
+ return f_1, bc_mask, missing_mask
@staticmethod
- def _initialize_auxiliary_data(boundary_conditions, f_0, f_1, bc_mask, missing_mask):
+ def _initialize_auxiliary_data(boundary_conditions, f_1, bc_mask, missing_mask):
"""Initialize auxiliary data for boundary conditions that require it."""
for bc in boundary_conditions:
if bc.needs_aux_init and not bc.is_initialized_with_aux_data:
- f_0, f_1 = bc.aux_data_init(f_0, f_1, bc_mask, missing_mask)
- return f_0, f_1
+ # Create the encoder operator for storing the auxiliary data
+ encode_auxiliary_data = EncodeAuxiliaryData(
+ bc.id,
+ bc.num_of_aux_data,
+ bc.profile,
+ velocity_set=bc.velocity_set,
+ precision_policy=bc.precision_policy,
+ compute_backend=bc.compute_backend,
+ )
+
+ # Encode the auxiliary data in f_1
+ f_1 = encode_auxiliary_data(f_1, bc_mask, missing_mask)
+ bc.is_initialized_with_aux_data = True
+ return f_1
@Operator.register_backend(ComputeBackend.JAX)
@partial(jit, static_argnums=(0,))
@@ -173,11 +263,11 @@ def jax_implementation_pull(self, f_0, f_1, bc_mask, missing_mask, omega, timest
feq = self.equilibrium(rho, u)
# Apply collision
- f_post_collision = self.collision(f_post_stream, feq, rho, u, omega)
+ f_post_collision = self.collision(f_post_stream, feq, omega)
# Apply collision type boundary conditions
for bc in self.boundary_conditions:
- f_post_collision = bc.update_bc_auxilary_data(f_post_stream, f_post_collision, bc_mask, missing_mask)
+ f_post_collision = bc.assemble_auxiliary_data(f_post_stream, f_post_collision, bc_mask, missing_mask)
if bc.implementation_step == ImplementationStep.COLLISION:
f_post_collision = bc(
f_post_stream,
@@ -210,11 +300,11 @@ def jax_implementation_push(self, f_0, f_1, bc_mask, missing_mask, omega, timest
feq = self.equilibrium(rho, u)
# Apply collision
- f_post_collision = self.collision(f_post_stream, feq, rho, u, omega)
+ f_post_collision = self.collision(f_post_stream, feq, omega)
# Apply collision type boundary conditions
for bc in self.boundary_conditions:
- f_post_collision = bc.update_bc_auxilary_data(f_post_stream, f_post_collision, bc_mask, missing_mask)
+ f_post_collision = bc.update_bc_auxiliary_data(f_post_stream, f_post_collision, bc_mask, missing_mask)
if bc.implementation_step == ImplementationStep.COLLISION:
f_post_collision = bc(
f_post_stream,
@@ -247,27 +337,23 @@ def _construct_warp(self):
_f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
_missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8)
_opp_indices = self.velocity_set.opp_indices
+ lattice_central_index = self.velocity_set.center_index
# Read the list of bc_to_id created upon instantiation
bc_to_id = boundary_condition_registry.bc_to_id
- id_to_bc = boundary_condition_registry.id_to_bc
# Gather IDs of ExtrapolationOutflowBC boundary conditions
extrapolation_outflow_bc_ids = []
for bc_name, bc_id in bc_to_id.items():
if bc_name.startswith("ExtrapolationOutflowBC"):
extrapolation_outflow_bc_ids.append(bc_id)
- # Group active boundary conditions
- active_bcs = set(boundary_condition_registry.id_to_bc[bc.id] for bc in self.boundary_conditions)
-
- _opp_indices = self.velocity_set.opp_indices
@wp.func
def apply_bc(
index: Any,
timestep: Any,
_boundary_id: Any,
- missing_mask: Any,
+ _missing_mask: Any,
f_0: Any,
f_1: Any,
f_pre: Any,
@@ -281,39 +367,33 @@ def apply_bc(
if is_post_streaming:
if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.STREAMING):
if _boundary_id == wp.static(self.boundary_conditions[i].id):
- f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, missing_mask, f_0, f_1, f_pre, f_post)
+ f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, _missing_mask, f_0, f_1, f_pre, f_post)
else:
if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.COLLISION):
if _boundary_id == wp.static(self.boundary_conditions[i].id):
- f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, missing_mask, f_0, f_1, f_pre, f_post)
+ f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, _missing_mask, f_0, f_1, f_pre, f_post)
if wp.static(self.boundary_conditions[i].id in extrapolation_outflow_bc_ids):
if _boundary_id == wp.static(self.boundary_conditions[i].id):
- f_result = wp.static(self.boundary_conditions[i].update_bc_auxilary_data)(
- index, timestep, missing_mask, f_0, f_1, f_pre, f_post
+ f_result = wp.static(self.boundary_conditions[i].assemble_auxiliary_data)(
+ index, timestep, _missing_mask, f_0, f_1, f_pre, f_post
)
return f_result
@wp.func
def get_thread_data(
f0_buffer: wp.array4d(dtype=Any),
- f1_buffer: wp.array4d(dtype=Any),
missing_mask: wp.array4d(dtype=Any),
index: Any,
):
# Read thread data for populations
_f0_thread = _f_vec()
- _f1_thread = _f_vec()
_missing_mask = _missing_mask_vec()
for l in range(self.velocity_set.q):
# q-sized vector of pre-streaming populations
_f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1], index[2]])
- _f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1], index[2]])
- if missing_mask[l, index[0], index[1], index[2]]:
- _missing_mask[l] = wp.uint8(1)
- else:
- _missing_mask[l] = wp.uint8(0)
+ _missing_mask[l] = missing_mask[l, index[0], index[1], index[2]]
- return _f0_thread, _f1_thread, _missing_mask
+ return _f0_thread, _missing_mask
@wp.func
def apply_aux_recovery_bc(
@@ -321,25 +401,28 @@ def apply_aux_recovery_bc(
_boundary_id: Any,
_missing_mask: Any,
f_0: Any,
- _f1_thread: Any,
+ f_1: Any,
):
# Note:
# In XLB, the BC auxiliary data (e.g. prescribed values of pressure or normal velocity) are stored in (i) central index of f_1 and/or
# (ii) missing directions of f_1. Some BCs may or may not need all these available storage space. This function checks whether
# the BC needs recovery of auxiliary data and then recovers the information for the next iteration (due to buffer swapping) by
- # writting the thread values of f_1 (i.e._f1_thread) into f_0.
+ # writting the values of f_1 into f_0.
# Unroll the loop over boundary conditions
for i in range(wp.static(len(self.boundary_conditions))):
if wp.static(self.boundary_conditions[i].needs_aux_recovery):
if _boundary_id == wp.static(self.boundary_conditions[i].id):
- # Perform the swapping of data
- # (i) Recover the values stored in the central index of f_1
- f_0[0, index[0], index[1], index[2]] = self.store_dtype(_f1_thread[0])
- # (ii) Recover the values stored in the missing directions of f_1
- for l in range(1, self.velocity_set.q):
- if _missing_mask[l] == wp.uint8(1):
- f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(_f1_thread[_opp_indices[l]])
+ for l in range(self.velocity_set.q):
+ # Perform the swapping of data
+ if l == lattice_central_index:
+ # (i) Recover the values stored in the central index of f_1
+ f_0[l, index[0], index[1], index[2]] = self.store_dtype(f_1[l, index[0], index[1], index[2]])
+ elif _missing_mask[l] == wp.uint8(1):
+ # (ii) Recover the values stored in the missing directions of f_1
+ f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(
+ f_1[_opp_indices[l], index[0], index[1], index[2]]
+ )
@wp.kernel
def kernel(
@@ -354,13 +437,13 @@ def kernel(
index = wp.vec3i(i, j, k)
_boundary_id = bc_mask[0, index[0], index[1], index[2]]
- if _boundary_id == wp.uint8(255):
+ if _boundary_id == wp.uint8(BC_SOLID):
return
# Apply streaming
_f_post_stream = self.stream.warp_functional(f_0, index)
- _f0_thread, _f1_thread, _missing_mask = get_thread_data(f_0, f_1, missing_mask, index)
+ _f0_thread, _missing_mask = get_thread_data(f_0, missing_mask, index)
_f_post_collision = _f0_thread
# Apply post-streaming boundary conditions
@@ -368,13 +451,13 @@ def kernel(
_rho, _u = self.macroscopic.warp_functional(_f_post_stream)
_feq = self.equilibrium.warp_functional(_rho, _u)
- _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u, omega)
+ _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, omega)
# Apply post-collision boundary conditions
_f_post_collision = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision, False)
# Apply auxiliary recovery for boundary conditions (swapping)
- apply_aux_recovery_bc(index, _boundary_id, _missing_mask, f_0, _f1_thread)
+ apply_aux_recovery_bc(index, _boundary_id, _missing_mask, f_0, f_1)
# Store the result in f_1
for l in range(self.velocity_set.q):
@@ -390,3 +473,188 @@ def warp_implementation(self, f_0, f_1, bc_mask, missing_mask, omega, timestep):
dim=f_0.shape[1:],
)
return f_0, f_1
+
+ def _construct_neon(self):
+ import neon
+
+ # Set local constants
+ _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
+ _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8)
+ _opp_indices = self.velocity_set.opp_indices
+ lattice_central_index = self.velocity_set.center_index
+
+ # Read the list of bc_to_id created upon instantiation
+ bc_to_id = boundary_condition_registry.bc_to_id
+
+ # Gather IDs of ExtrapolationOutflowBC boundary conditions
+ extrapolation_outflow_bc_ids = []
+ for bc_name, bc_id in bc_to_id.items():
+ if bc_name.startswith("ExtrapolationOutflowBC"):
+ extrapolation_outflow_bc_ids.append(bc_id)
+
+ @wp.func
+ def apply_bc(
+ index: Any,
+ timestep: Any,
+ _boundary_id: Any,
+ _missing_mask: Any,
+ f_0: Any,
+ f_1: Any,
+ f_pre: Any,
+ f_post: Any,
+ is_post_streaming: bool,
+ ):
+ f_result = f_post
+
+ # Unroll the loop over boundary conditions
+ for i in range(wp.static(len(self.boundary_conditions))):
+ if is_post_streaming:
+ if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.STREAMING):
+ if _boundary_id == wp.static(self.boundary_conditions[i].id):
+ f_result = wp.static(self.boundary_conditions[i].neon_functional)(index, timestep, _missing_mask, f_0, f_1, f_pre, f_post)
+ else:
+ if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.COLLISION):
+ if _boundary_id == wp.static(self.boundary_conditions[i].id):
+ f_result = wp.static(self.boundary_conditions[i].neon_functional)(index, timestep, _missing_mask, f_0, f_1, f_pre, f_post)
+ if wp.static(self.boundary_conditions[i].id in extrapolation_outflow_bc_ids):
+ if _boundary_id == wp.static(self.boundary_conditions[i].id):
+ f_result = wp.static(self.boundary_conditions[i].assemble_auxiliary_data)(
+ index, timestep, _missing_mask, f_0, f_1, f_pre, f_post
+ )
+ return f_result
+
+ @wp.func
+ def neon_get_thread_data(
+ f0_pn: Any,
+ missing_mask_pn: Any,
+ index: Any,
+ ):
+ # Read thread data for populations
+ _f0_thread = _f_vec()
+ _missing_mask = _missing_mask_vec()
+ for l in range(self.velocity_set.q):
+ # q-sized vector of pre-streaming populations
+ _f0_thread[l] = self.compute_dtype(wp.neon_read(f0_pn, index, l))
+ _missing_mask[l] = wp.neon_read(missing_mask_pn, index, l)
+
+ return _f0_thread, _missing_mask
+
+ @wp.func
+ def neon_apply_aux_recovery_bc(
+ index: Any,
+ _boundary_id: Any,
+ _missing_mask: Any,
+ f_0_pn: Any,
+ f_1_pn: Any,
+ ):
+ # Note:
+ # In XLB, the BC auxiliary data (e.g. prescribed values of pressure or normal velocity) are stored in (i) central index of f_1 and/or
+ # (ii) missing directions of f_1. Some BCs may or may not need all these available storage space. This function checks whether
+ # the BC needs recovery of auxiliary data and then recovers the information for the next iteration (due to buffer swapping) by
+ # writting the values of f_1 into f_0.
+
+ # Unroll the loop over boundary conditions
+ for i in range(wp.static(len(self.boundary_conditions))):
+ if wp.static(self.boundary_conditions[i].needs_aux_recovery):
+ if _boundary_id == wp.static(self.boundary_conditions[i].id):
+ for l in range(self.velocity_set.q):
+ # Perform the swapping of data
+ if l == lattice_central_index:
+ # (i) Recover the values stored in the central index of f_1
+ _f1_thread = wp.neon_read(f_1_pn, index, l)
+ wp.neon_write(f_0_pn, index, l, self.store_dtype(_f1_thread))
+ elif _missing_mask[l] == wp.uint8(1):
+ # (ii) Recover the values stored in the missing directions of f_1
+ _f1_thread = wp.neon_read(f_1_pn, index, _opp_indices[l])
+ wp.neon_write(f_0_pn, index, _opp_indices[l], self.store_dtype(_f1_thread))
+
+ @neon.Container.factory(name="nse_stepper")
+ def container(
+ f_0_fd: Any,
+ f_1_fd: Any,
+ bc_mask_fd: Any,
+ missing_mask_fd: Any,
+ omega: Any,
+ timestep: int,
+ ):
+ def nse_stepper_ll(loader: neon.Loader):
+ loader.set_grid(bc_mask_fd.get_grid())
+
+ f_0_pn = loader.get_read_handle(
+ f_0_fd,
+ operation=neon.Loader.Operation.stencil,
+ discretization=neon.Loader.Discretization.lattice,
+ )
+ bc_mask_pn = loader.get_read_handle(bc_mask_fd)
+ missing_mask_pn = loader.get_read_handle(missing_mask_fd)
+
+ f_1_pn = loader.get_write_handle(f_1_fd)
+
+ @wp.func
+ def nse_stepper_cl(index: Any):
+ _boundary_id = wp.neon_read(bc_mask_pn, index, 0)
+ if _boundary_id == wp.uint8(BC_SOLID):
+ return
+ # Apply streaming
+ _f_post_stream = self.stream.neon_functional(f_0_pn, index)
+
+ _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index)
+ _f_post_collision = _f0_thread
+
+ # Apply post-streaming boundary conditions
+ _f_post_stream = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0_pn, f_1_pn, _f_post_collision, _f_post_stream, True)
+
+ _rho, _u = self.macroscopic.neon_functional(_f_post_stream)
+ _feq = self.equilibrium.neon_functional(_rho, _u)
+ _f_post_collision = self.collision.neon_functional(_f_post_stream, _feq, omega)
+
+ # Apply post-collision boundary conditions
+ _f_post_collision = apply_bc(
+ index, timestep, _boundary_id, _missing_mask, f_0_pn, f_1_pn, _f_post_stream, _f_post_collision, False
+ )
+
+ # Apply auxiliary recovery for boundary conditions (swapping)
+ neon_apply_aux_recovery_bc(index, _boundary_id, _missing_mask, f_0_pn, f_1_pn)
+
+ # Store the result in f_1
+ for l in range(self.velocity_set.q):
+ wp.neon_write(f_1_pn, index, l, self.store_dtype(_f_post_collision[l]))
+
+ loader.declare_kernel(nse_stepper_cl)
+
+ return nse_stepper_ll
+
+ return None, container
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_launch(self, f_0, f_1, bc_mask, missing_mask, omega, timestep):
+ if timestep == 0:
+ self.prepare_skeleton(f_0, f_1, bc_mask, missing_mask, omega)
+ self.sk[self.sk_iter].run()
+ self.sk_iter = (self.sk_iter + 1) % 2
+ return f_0, f_1
+
+ def prepare_skeleton(self, f_0, f_1, bc_mask, missing_mask, omega):
+ """Build the Neon odd/even skeletons for double-buffered time stepping."""
+ grid = f_0.get_grid()
+ bk = grid.backend
+ self.neon_skeleton = {"odd": {}, "even": {}}
+ self.neon_skeleton["odd"]["container"] = self.neon_container(f_0, f_1, bc_mask, missing_mask, omega, 0)
+ self.neon_skeleton["even"]["container"] = self.neon_container(f_1, f_0, bc_mask, missing_mask, omega, 1)
+ # check if 'occ' is a valid key
+ if "occ" not in self.backend_config:
+ occ = neon.SkeletonConfig.OCC.none()
+ else:
+ occ = self.backend_config["occ"]
+ # check that occ is of type neon.SkeletonConfig.OCC
+ if not isinstance(occ, neon.SkeletonConfig.OCC):
+ print(type(occ))
+ raise ValueError("occ must be of type neon.SkeletonConfig.OCC")
+
+ for key in self.neon_skeleton:
+ self.neon_skeleton[key]["app"] = [self.neon_skeleton[key]["container"]]
+ self.neon_skeleton[key]["skeleton"] = neon.Skeleton(backend=bk)
+ self.neon_skeleton[key]["skeleton"].sequence(name="mres_nse_stepper", containers=self.neon_skeleton[key]["app"], occ=occ)
+
+ self.sk = [self.neon_skeleton["odd"]["skeleton"], self.neon_skeleton["even"]["skeleton"]]
+ self.sk_iter = 0
diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py
index 247fa5a9..7cfe2d19 100644
--- a/xlb/operator/stream/stream.py
+++ b/xlb/operator/stream/stream.py
@@ -1,4 +1,9 @@
-# Base class for all streaming operators
+"""
+Streaming operator for the Lattice Boltzmann Method.
+
+Implements the pull-scheme propagation step: each voxel reads populations
+from its lattice neighbours according to the velocity-set directions.
+"""
from functools import partial
import jax.numpy as jnp
@@ -11,8 +16,14 @@
class Stream(Operator):
- """
- Base class for all streaming operators. This is used for pulling the distribution
+ """Pull-scheme streaming operator.
+
+ Propagates distribution functions by reading each population from the
+ upstream neighbour along the corresponding lattice direction. Periodic
+ boundaries are applied automatically when a pull index falls outside
+ the domain (Warp backend only; JAX uses ``jnp.roll``).
+
+ Supports JAX, Warp, and Neon backends.
"""
@Operator.register_backend(ComputeBackend.JAX)
@@ -112,3 +123,32 @@ def warp_implementation(self, f_0, f_1):
dim=f_0.shape[1:],
)
return f_1
+
+ def _construct_neon(self):
+ # Set local constants TODO: This is a hack and should be fixed with warp update
+ _c = self.velocity_set.c
+ _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype)
+
+ # Construct the funcional to get streamed indices
+ @wp.func
+ def functional(
+ f: Any,
+ index: Any,
+ ):
+ # Pull the distribution function
+ _f = _f_vec()
+ for l in range(self.velocity_set.q):
+ # Get pull offset
+ ngh = wp.neon_ngh_idx(wp.int8(-_c[0, l]), wp.int8(-_c[1, l]), wp.int8(-_c[2, l]))
+ unused_is_valid = wp.bool(False)
+
+ # Read the distribution function from the neighboring cell in the pull direction
+ _f[l] = self.compute_dtype(wp.neon_read_ngh(f, index, ngh, l, self.store_dtype(0), unused_is_valid))
+ return _f
+
+ return functional, None
+
+ @Operator.register_backend(ComputeBackend.NEON)
+ def neon_implementation(self, f_0, f_1):
+ # raise exception as this feature is not implemented yet
+ raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.")
diff --git a/xlb/precision_policy.py b/xlb/precision_policy.py
index 7d31c8a3..32a6d567 100644
--- a/xlb/precision_policy.py
+++ b/xlb/precision_policy.py
@@ -1,11 +1,18 @@
-# Enum for precision policy
+"""
+Precision and precision-policy enumerations for XLB.
+
+:class:`Precision` maps symbolic precisions to Warp and JAX dtypes.
+:class:`PrecisionPolicy` pairs a *compute* precision (used during
+arithmetic) with a *store* precision (used in memory), enabling
+mixed-precision simulations.
+"""
from enum import Enum, auto
-import jax.numpy as jnp
-import warp as wp
class Precision(Enum):
+ """Scalar precision levels with Warp and JAX dtype accessors."""
+
FP64 = auto()
FP32 = auto()
FP16 = auto()
@@ -14,6 +21,8 @@ class Precision(Enum):
@property
def wp_dtype(self):
+ import warp as wp
+
if self == Precision.FP64:
return wp.float64
elif self == Precision.FP32:
@@ -29,6 +38,8 @@ def wp_dtype(self):
@property
def jax_dtype(self):
+ import jax.numpy as jnp
+
if self == Precision.FP64:
return jnp.float64
elif self == Precision.FP32:
@@ -44,6 +55,12 @@ def jax_dtype(self):
class PrecisionPolicy(Enum):
+ """Mixed-precision policy pairing compute and store precisions.
+
+ The naming convention is ````, e.g. ``FP32FP16``
+ computes in FP32 and stores results in FP16.
+ """
+
FP64FP64 = auto()
FP64FP32 = auto()
FP64FP16 = auto()
@@ -81,9 +98,13 @@ def store_precision(self):
raise ValueError("Invalid precision policy")
def cast_to_compute_jax(self, array):
+ import jax.numpy as jnp
+
compute_precision = self.compute_precision
return jnp.array(array, dtype=compute_precision.jax_dtype)
def cast_to_store_jax(self, array):
+ import jax.numpy as jnp
+
store_precision = self.store_precision
return jnp.array(array, dtype=store_precision.jax_dtype)
diff --git a/xlb/utils/__init__.py b/xlb/utils/__init__.py
index 735cfad8..b7bdbf9a 100644
--- a/xlb/utils/__init__.py
+++ b/xlb/utils/__init__.py
@@ -6,9 +6,12 @@
rotate_geometry,
voxelize_stl,
axangle2mat,
+ ToJAX,
+ UnitConvertor,
save_usd_vorticity,
save_usd_q_criterion,
update_usd_lagrangian_parts,
plot_object_placement,
colorize_scalars,
)
+from .mesher import make_cuboid_mesh, MultiresIO
diff --git a/xlb/utils/mesher.py b/xlb/utils/mesher.py
new file mode 100644
index 00000000..8b849b87
--- /dev/null
+++ b/xlb/utils/mesher.py
@@ -0,0 +1,941 @@
+"""
+Multi-resolution mesh utilities.
+
+Provides geometry preparation and I/O for multi-resolution LBM simulations:
+
+* :func:`make_cuboid_mesh` — builds a strongly-balanced cuboid mesh hierarchy
+ from an STL file and a sequence of domain multipliers.
+* :func:`prepare_sparsity_pattern` — converts level data into the sparsity
+ arrays required by :func:`multires_grid_factory`.
+* :class:`MultiresIO` — exports multi-resolution Neon field data to HDF5 /
+ XDMF, 2-D slice images, and 1-D line profiles.
+"""
+
+import numpy as np
+import trimesh
+from typing import Any, Optional
+
+import warp as wp
+from xlb.utils.utils import UnitConvertor
+
+
+def adjust_bbox(cuboid_max, cuboid_min, voxel_size_up):
+ """
+ Adjust the bounding box to the nearest points of one level finer grid that encloses the desired region.
+
+ Args:
+ cuboid_min (np.ndarray): Desired minimum coordinates of the bounding box.
+ cuboid_max (np.ndarray): Desired maximum coordinates of the bounding box.
+ voxel_size_up (float): Voxel size of one level higher (finer) grid.
+
+ Returns:
+ tuple: (adjusted_min, adjusted_max) snapped to grid points of one level higher.
+ """
+ adjusted_min = np.round(cuboid_min / voxel_size_up) * voxel_size_up
+ adjusted_max = np.round(cuboid_max / voxel_size_up) * voxel_size_up
+ return adjusted_min, adjusted_max
+
+
+def prepare_sparsity_pattern(level_data):
+ """
+ Prepare the sparsity pattern for the multiresolution grid based on the level data. "level_data" is expected to be formatted as in
+ the output of "make_cuboid_mesh".
+ """
+ num_levels = len(level_data)
+ level_origins = []
+ sparsity_pattern = []
+ for lvl in range(num_levels):
+ # Get the level mask from the level data
+ level_mask = level_data[lvl][0]
+
+ # Ensure level_0 is contiguous int32
+ level_mask = np.ascontiguousarray(level_mask, dtype=np.int32)
+
+ # Append the padded level mask to the sparsity pattern
+ sparsity_pattern.append(level_mask)
+
+ # Get the origin for this level
+ level_origins.append(level_data[lvl][2])
+
+ return sparsity_pattern, level_origins
+
+
+def make_cuboid_mesh(voxel_size, cuboids, stl_filename):
+ """
+ Create a strongly-balanced multi-level cuboid mesh with a sequence of bounding boxes.
+ Outputs mask arrays that are set to True only in regions not covered by finer levels.
+
+ Args:
+ voxel_size (float): Voxel size of the finest grid .
+ cuboids (list): List of multipliers defining each level's domain.
+ stl_name (str): Path to the STL file.
+
+ Returns:
+ list: Level data with mask arrays, voxel sizes, origins, and levels.
+ """
+ # Load the mesh and get its bounding box
+ mesh = trimesh.load_mesh(stl_filename, process=False)
+ assert not mesh.is_empty, "Loaded mesh is empty or invalid."
+
+ mesh_vertices = mesh.vertices
+ min_bound = mesh_vertices.min(axis=0)
+ max_bound = mesh_vertices.max(axis=0)
+ partSize = max_bound - min_bound
+
+ level_data = []
+ adjusted_bboxes = []
+ max_voxel_size = voxel_size * pow(2, (len(cuboids) - 1))
+ # Step 1: Generate all levels and store their data
+ for level in range(len(cuboids)):
+ # Compute desired bounding box for this level
+ cuboid_min = np.array(
+ [
+ min_bound[0] - cuboids[level][0] * partSize[0],
+ min_bound[1] - cuboids[level][2] * partSize[1],
+ min_bound[2] - cuboids[level][4] * partSize[2],
+ ],
+ dtype=float,
+ )
+
+ cuboid_max = np.array(
+ [
+ max_bound[0] + cuboids[level][1] * partSize[0],
+ max_bound[1] + cuboids[level][3] * partSize[1],
+ max_bound[2] + cuboids[level][5] * partSize[2],
+ ],
+ dtype=float,
+ )
+
+ # Set voxel size for this level
+ voxel_size_level = max_voxel_size / pow(2, level)
+
+ # Adjust bounding box to align with one level up (finer grid)
+ if level > 0:
+ voxel_level_up = max_voxel_size / pow(2, level - 1)
+ else:
+ voxel_level_up = voxel_size_level
+ adjusted_min, adjusted_max = adjust_bbox(cuboid_max, cuboid_min, voxel_level_up)
+
+ xmin, ymin, zmin = adjusted_min
+ xmax, ymax, zmax = adjusted_max
+
+ # Compute number of voxels based on level-specific voxel size
+ nx = int(np.round((xmax - xmin) / voxel_size_level))
+ ny = int(np.round((ymax - ymin) / voxel_size_level))
+ nz = int(np.round((zmax - zmin) / voxel_size_level))
+ print(f"Domain {nx}, {ny}, {nz} Origin {adjusted_min} Voxel Size {voxel_size_level} Voxel Level Up {voxel_level_up}")
+
+ voxel_matrix = np.ones((nx, ny, nz), dtype=bool)
+
+ origin = adjusted_min
+ level_data.append((voxel_matrix, voxel_size_level, origin, level))
+ adjusted_bboxes.append((adjusted_min, adjusted_max))
+
+ # Step 2: Adjust coarser levels to exclude regions covered by finer levels
+ for k in range(len(level_data) - 1): # Exclude the finest level
+ # Current level's data
+ voxel_matrix_k = level_data[k][0]
+ origin_k = level_data[k][2]
+ voxel_size_k = level_data[k][1]
+ nx, ny, nz = voxel_matrix_k.shape
+
+ # Next finer level's bounding box
+ adjusted_min_k1, adjusted_max_k1 = adjusted_bboxes[k + 1]
+
+ # Compute index ranges in level k that overlap with level k+1's bounding box
+ # Use epsilon (1e-10) to handle floating-point precision
+ i_start = max(0, int(np.ceil((adjusted_min_k1[0] - origin_k[0] - 1e-10) / voxel_size_k)))
+ i_end = min(nx, int(np.floor((adjusted_max_k1[0] - origin_k[0] + 1e-10) / voxel_size_k)))
+ j_start = max(0, int(np.ceil((adjusted_min_k1[1] - origin_k[1] - 1e-10) / voxel_size_k)))
+ j_end = min(ny, int(np.floor((adjusted_max_k1[1] - origin_k[1] + 1e-10) / voxel_size_k)))
+ k_start = max(0, int(np.ceil((adjusted_min_k1[2] - origin_k[2] - 1e-10) / voxel_size_k)))
+ k_end = min(nz, int(np.floor((adjusted_max_k1[2] - origin_k[2] + 1e-10) / voxel_size_k)))
+
+ # Set overlapping region to zero
+ voxel_matrix_k[i_start:i_end, j_start:j_end, k_start:k_end] = 0
+
+ # Step 3 Convert to Indices from STL units
+ num_levels = len(level_data)
+ level_data = [(dr, int(v / voxel_size), np.round(dOrigin / v).astype(int), num_levels - 1 - l) for dr, v, dOrigin, l in level_data]
+
+ return list(reversed(level_data))
+
+
+class MultiresIO(object):
+ """I/O helper for multi-resolution Neon field data.
+
+ Converts hierarchical Neon ``mGrid`` fields into merged unstructured
+ hexahedral meshes and exports them as HDF5 + XDMF (for ParaView),
+ 2-D slice PNG images, or 1-D line CSV profiles.
+
+ The constructor precomputes the merged geometry (coordinates,
+ connectivity, centroids) and allocates intermediate Warp fields so
+ that repeated exports only need to transfer data from the Neon fields.
+ """
+
+ def __init__(
+ self,
+ field_name_cardinality_dict,
+ levels_data,
+ unit_convertor: UnitConvertor = None,
+ offset: Optional[tuple] = (0.0, 0.0, 0.0),
+ store_precision=None,
+ ):
+ """
+ Initialize the MultiresIO object.
+
+ Parameters
+ ----------
+ field_name_cardinality_dict : dict
+ A dictionary mapping field names to their cardinalities.
+ Example: {'velocity_x': 1, 'velocity_y': 1, 'velocity': 3, 'density': 1}
+ levels_data : list of tuples
+ Each tuple contains (data, voxel_size, origin, level).
+ unit_convertor : UnitConvertor
+ An instance of the UnitConvertor class for unit conversions.
+ offset : tuple, optional
+ Offset to be applied to the coordinates.
+ store_precision : str, optional
+ The precision policy for storing data.
+ """
+ # Set the unit convertor object
+ self.unit_convertor = unit_convertor
+
+ # Process the multires geometry and extract coordinates and connectivity in the coordinate system of the finest level
+ coordinates, connectivity, level_id_field, total_cells = self.process_geometry(levels_data)
+
+ # Ensure that coordinates and connectivity are not empty
+ assert coordinates.size != 0, "Error: No valid data to process. Check the input levels_data."
+
+ # Merge duplicate points
+ coordinates, connectivity = self._merge_duplicates(coordinates, connectivity, levels_data)
+
+ # Transform coordinates to physical units and apply offset if provided
+ coordinates = self._transform_coordinates(coordinates, offset)
+
+ # Assign to self
+ self.field_name_cardinality_dict = field_name_cardinality_dict
+ self.levels_data = levels_data
+ self.coordinates = coordinates
+ self.connectivity = connectivity
+ self.level_id_field = level_id_field
+ self.total_cells = total_cells
+ self.centroids = np.mean(coordinates[connectivity], axis=1)
+
+ # Set the default precision policy if not provided
+ from xlb import DefaultConfig
+
+ if store_precision is None:
+ self.store_precision = DefaultConfig.default_precision_policy.store_precision
+ self.store_dtype = DefaultConfig.default_precision_policy.store_precision.wp_dtype
+
+ # Prepare and allocate the inputs for the NEON container
+ self.field_warp_dict, self.origin_list = self._prepare_container_inputs()
+
+ # Construct the NEON container for exporting multi-resolution data
+ self.container = self._construct_neon_container()
+
+ def process_geometry(self, levels_data):
+ """Build merged coordinates and connectivity from all levels.
+
+ Returns
+ -------
+ coordinates : np.ndarray, shape (N, 3)
+ Vertex positions (8 per active voxel, before deduplication).
+ connectivity : np.ndarray, shape (M, 8)
+ Hexahedral connectivity (one row per active voxel).
+ level_id_field : np.ndarray, shape (M,)
+ Grid level index for each cell.
+ total_cells : int
+ Total number of active voxels across all levels.
+ """
+ num_voxels_per_level = [np.sum(data) for data, _, _, _ in levels_data]
+ num_points_per_level = [8 * nv for nv in num_voxels_per_level]
+ point_id_offsets = np.cumsum([0] + num_points_per_level[:-1])
+
+ all_corners = []
+ all_connectivity = []
+ level_id_field = []
+ total_cells = 0
+
+ for level_idx, (data, voxel_size, origin, level) in enumerate(levels_data):
+ origin = origin * voxel_size
+ corners_list, conn_list = self._process_level(data, voxel_size, origin, point_id_offsets[level_idx])
+
+ if corners_list:
+ print(f"\tProcessing level {level}: Voxel size {voxel_size}, Origin {origin}, Shape {data.shape}")
+ all_corners.extend(corners_list)
+ all_connectivity.extend(conn_list)
+ num_cells = sum(c.shape[0] for c in conn_list)
+ level_id_field.extend([level] * num_cells)
+ total_cells += num_cells
+ else:
+ print(f"\tSkipping level {level} (no unique data)")
+
+ # Stacking coordinates and connectivity
+ coordinates = np.concatenate(all_corners, axis=0).astype(np.float32)
+ connectivity = np.concatenate(all_connectivity, axis=0).astype(np.int32)
+ level_id_field = np.array(level_id_field, dtype=np.uint8)
+
+ return coordinates, connectivity, level_id_field, total_cells
+
+ def _process_level(self, data, voxel_size, origin, point_id_offset):
+ """
+ Given a voxel grid, returns all corners and connectivity in NumPy for this resolution level.
+ """
+ true_indices = np.argwhere(data)
+ if true_indices.size == 0:
+ return [], []
+
+ max_voxels_per_chunk = 268_435_450
+ chunks = np.array_split(true_indices, max(1, (len(true_indices) + max_voxels_per_chunk - 1) // max_voxels_per_chunk))
+
+ all_corners = []
+ all_connectivity = []
+ pid_offset = point_id_offset
+
+ for chunk in chunks:
+ if chunk.size == 0:
+ continue
+ corners, connectivity = self._process_voxel_chunk(chunk, np.asarray(origin, dtype=np.float32), voxel_size, pid_offset)
+ all_corners.append(corners)
+ all_connectivity.append(connectivity)
+ pid_offset += len(chunk) * 8
+
+ return all_corners, all_connectivity
+
+ def _process_voxel_chunk(self, true_indices, origin, voxel_size, point_id_offset):
+ """
+ Given a set of voxel indices, returns 8 corners and connectivity for each cube using NumPy.
+ """
+ true_indices = np.asarray(true_indices, dtype=np.float32)
+ mins = origin + true_indices * voxel_size
+ offsets = np.array(
+ [
+ [0, 0, 0],
+ [1, 0, 0],
+ [1, 1, 0],
+ [0, 1, 0],
+ [0, 0, 1],
+ [1, 0, 1],
+ [1, 1, 1],
+ [0, 1, 1],
+ ],
+ dtype=np.float32,
+ )
+
+ corners = (mins[:, None, :] + offsets[None, :, :] * voxel_size).reshape(-1, 3).astype(np.float32)
+ base_ids = point_id_offset + np.arange(len(true_indices), dtype=np.int32) * 8
+ connectivity = (base_ids[:, None] + np.arange(8, dtype=np.int32)).astype(np.int32)
+
+ return corners, connectivity
+
+ def save_xdmf(self, h5_filename, xmf_filename, total_cells, num_points, fields={}):
+ """Write an XDMF descriptor that references the companion HDF5 file."""
+ # Generate an XDMF file to accompany the HDF5 file
+ print(f"\tGenerating XDMF file: {xmf_filename}")
+ hdf5_rel_path = h5_filename.split("/")[-1]
+ with open(xmf_filename, "w") as xmf:
+ xmf.write(f'''
+
+
+
+
+
+
+ {hdf5_rel_path}:/Mesh/Connectivity
+
+
+
+
+ {hdf5_rel_path}:/Mesh/Points
+
+
+
+
+ {hdf5_rel_path}:/Mesh/Level
+
+
+ ''')
+ for field_name in fields.keys():
+ xmf.write(f'''
+
+
+ {hdf5_rel_path}:/Fields/{field_name}
+
+
+ ''')
+ xmf.write("""
+
+
+
+ """)
+ print("\tXDMF file written successfully")
+ return
+
+ def save_hdf5_file(self, filename, coordinates, connectivity, level_id_field, fields_data, compression="gzip", compression_opts=0):
+ """Write the processed mesh data to an HDF5 file.
+ Parameters
+ ----------
+ filename : str
+ The name of the output HDF5 file.
+ coordinates : numpy.ndarray
+ An array of all coordinates.
+ connectivity : numpy.ndarray
+ An array of all connectivity data.
+ level_id_field : numpy.ndarray
+ An array of all level data.
+ fields_data : dict
+ A dictionary of all field data.
+ compression : str, optional
+ The compression method to use for the HDF5 file.
+ compression_opts : int, optional
+ The compression options to use for the HDF5 file.
+ """
+ import h5py
+
+ with h5py.File(filename + ".h5", "w") as f:
+ f.create_dataset("/Mesh/Points", data=coordinates, compression=compression, compression_opts=compression_opts, chunks=True)
+ f.create_dataset(
+ "/Mesh/Connectivity",
+ data=connectivity,
+ compression=compression,
+ compression_opts=compression_opts,
+ chunks=True,
+ )
+ f.create_dataset("/Mesh/Level", data=level_id_field, compression=compression, compression_opts=compression_opts)
+ fg = f.create_group("/Fields")
+ for fname, fdata in fields_data.items():
+ fg.create_dataset(fname, data=fdata.astype(np.float32), compression=compression, compression_opts=compression_opts, chunks=True)
+
+ def _merge_duplicates(self, coordinates, connectivity, levels_data):
+ """Deduplicate vertices shared between adjacent voxels.
+
+ Uses spatial hashing (grid-snapped coordinates) processed in
+ chunks to keep memory bounded.
+ """
+ # Merging duplicate points
+ tolerance = 0.01
+ chunk_size = 10_000_000 # Adjust based on GPU memory
+ num_points = coordinates.shape[0]
+ unique_points = []
+ mapping = np.zeros(num_points, dtype=np.int32)
+ unique_idx = 0
+
+ # Get the grid shape of computational box at the finest level from the levels_data
+ num_levels = len(levels_data)
+ grid_shape_finest = np.array(levels_data[-1][0].shape) * 2 ** (num_levels - 1)
+
+ for start in range(0, num_points, chunk_size):
+ end = min(start + chunk_size, num_points)
+ coords_chunk = coordinates[start:end]
+
+ # Simple hashing: grid coordinates as tuple keys
+ grid_coords = np.round(coords_chunk / tolerance).astype(np.int64)
+ hash_keys = grid_coords[:, 0] + grid_coords[:, 1] * grid_shape_finest[0] + grid_coords[:, 2] * grid_shape_finest[0] * grid_shape_finest[1]
+ unique_hash, inverse = np.unique(hash_keys, return_inverse=True)
+ unique_hash, unique_indices, inverse = np.unique(hash_keys, return_index=True, return_inverse=True)
+ unique_chunk = coords_chunk[unique_indices]
+
+ unique_points.append(unique_chunk)
+ mapping[start:end] = inverse + unique_idx
+ unique_idx += len(unique_hash)
+
+ coordinates = np.concatenate(unique_points)
+ connectivity = mapping[connectivity]
+ return coordinates, connectivity
+
+ def _transform_coordinates(self, coordinates, offset):
+ """Convert lattice coordinates to physical units and apply offset."""
+ offset = np.array(offset, dtype=np.float32)
+ if self.unit_convertor is not None:
+ coordinates = self.unit_convertor.length_to_physical(coordinates)
+ return coordinates + offset
+
+ def _prepare_container_inputs(self):
+ """Allocate dense Warp fields used as staging buffers for Neon-to-NumPy transfer."""
+ # load necessary modules
+ from xlb.compute_backend import ComputeBackend
+ from xlb.grid import grid_factory
+
+ # Get the number of levels from the levels_data
+ num_levels = len(self.levels_data)
+
+ # Prepare lists to hold warp fields and origins allocated for each level
+ field_warp_dict = {}
+ origin_list = []
+ for field_name, cardinality in self.field_name_cardinality_dict.items():
+ field_warp_dict[field_name] = []
+ for level in range(num_levels):
+ # get the shape of the grid at this level
+ box_shape = self.levels_data[level][0].shape
+
+ # Use the warp backend to create dense fields to be written in multi-res NEON fields
+ grid_dense = grid_factory(box_shape, compute_backend=ComputeBackend.WARP)
+ field_warp_dict[field_name].append(grid_dense.create_field(cardinality=cardinality, dtype=self.store_precision))
+ origin_list.append(wp.vec3i(*([int(x) for x in self.levels_data[level][2]])))
+
+ return field_warp_dict, origin_list
+
+ def _construct_neon_container(self):
+ """
+ Constructs a NEON container for exporting multi-resolution data to HDF5.
+ This container will be used to transfer multi-resolution NEON fields into stacked warp fields.
+ """
+ import neon
+
+ @neon.Container.factory(name="HDF5MultiresExporter")
+ def container(
+ field_neon: Any,
+ field_warp: Any,
+ origin: Any,
+ level: Any,
+ ):
+ def launcher(loader: neon.Loader):
+ loader.set_mres_grid(field_neon.get_grid(), level)
+ field_neon_hdl = loader.get_mres_read_handle(field_neon)
+ refinement = 2**level
+
+ @wp.func
+ def kernel(index: Any):
+ cIdx = wp.neon_global_idx(field_neon_hdl, index)
+ # Get local indices by dividing the global indices (associated with the finest level) by 2^level
+ # Subtract the origin to get the local indices in the warp field
+ lx = wp.neon_get_x(cIdx) // refinement - origin[0]
+ ly = wp.neon_get_y(cIdx) // refinement - origin[1]
+ lz = wp.neon_get_z(cIdx) // refinement - origin[2]
+
+ # write the values to the warp field
+ cardinality = field_warp.shape[0]
+ for card in range(cardinality):
+ field_warp[card, lx, ly, lz] = self.store_dtype(wp.neon_read(field_neon_hdl, index, card))
+
+ loader.declare_kernel(kernel)
+
+ return launcher
+
+ return container
+
+ def get_fields_data(self, field_neon_dict):
+ """
+ Extracts and prepares the fields data from the NEON fields for export.
+ """
+ # Check if the field_neon_dict is empty
+ if not field_neon_dict:
+ return {}
+
+ # Ensure that this operator is called on multires grids
+ grid_mres = next(iter(field_neon_dict.values())).get_grid()
+ assert grid_mres.name == "mGrid", f"Operation {self.__class__.__name__} is only applicable to multi-resolution cases!"
+
+ for field_name in field_neon_dict.keys():
+ assert field_name in self.field_name_cardinality_dict.keys(), (
+ f"Field {field_name} is not provided in the instantiation of the MultiresIO class!"
+ )
+
+ # number of levels
+ num_levels = grid_mres.num_levels
+ assert num_levels == len(self.levels_data), "Error: Inconsistent number of levels!"
+
+ # Prepare the fields dictionary to be written by transfering multi-res NEON fields into stacked warp fields and then numpy arrays
+ fields_data = {}
+ for field_name, cardinality in self.field_name_cardinality_dict.items():
+ if field_name not in field_neon_dict:
+ continue
+ for card in range(cardinality):
+ fields_data[f"{field_name}_{card}"] = []
+
+ # Iterate over each field and level to fill the dictionary with numpy fields
+ for field_name, cardinality in self.field_name_cardinality_dict.items():
+ if field_name not in field_neon_dict:
+ continue
+ for level in range(num_levels):
+ # Create the container and run it to fill the warp fields
+ c = self.container(field_neon_dict[field_name], self.field_warp_dict[field_name][level], self.origin_list[level], level)
+ c.run(0, container_runtime=neon.Container.ContainerRuntime.neon)
+
+ # Ensure all operations are complete before converting to JAX and Numpy arrays
+ wp.synchronize()
+
+ # Convert the warp fields to numpy arrays and use level's mask to filter the data
+ mask = self.levels_data[level][0]
+ field_np = self.field_warp_dict[field_name][level].numpy()
+ for card in range(cardinality):
+ field_np_card = field_np[card][mask]
+ fields_data[f"{field_name}_{card}"].append(field_np_card)
+
+ # Concatenate all field data
+ for field_name in fields_data.keys():
+ fields_data[field_name] = np.concatenate(fields_data[field_name])
+ assert fields_data[field_name].size == self.total_cells, f"Error: Field {field_name} size mismatch!"
+
+ # Unit conversion if applicable
+ if self.unit_convertor is not None:
+ if "velocity" in field_name.lower():
+ fields_data[field_name] = self.unit_convertor.velocity_to_physical(fields_data[field_name])
+ elif "density" in field_name.lower():
+ fields_data[field_name] = self.unit_convertor.density_to_physical(fields_data[field_name])
+ elif "pressure" in field_name.lower():
+ fields_data[field_name] = self.unit_convertor.pressure_to_physical(fields_data[field_name])
+ # Add more physical quantities as needed
+
+ return fields_data
+
+ def to_hdf5(self, output_filename, field_neon_dict, compression="gzip", compression_opts=0):
+ """
+ Export the multi-resolution mesh data to an HDF5 file.
+ Parameters
+ ----------
+ output_filename : str
+ The name of the output HDF5 file (without extension).
+ field_neon_dict : a dictionary of neon mGrid Fields
+ Eg. The NEON fields containing velocity and density data as { "velocity": velocity_neon, "density": density_neon}
+ compression : str, optional
+ The compression method to use for the HDF5 file.
+ compression_opts : int, optional
+ The compression options to use for the HDF5 file.
+ """
+ import time
+
+ # Get the fields data from the NEON fields
+ fields_data = self.get_fields_data(field_neon_dict)
+
+ # Save XDMF file
+ self.save_xdmf(output_filename + ".h5", output_filename + ".xmf", self.total_cells, len(self.coordinates), fields_data)
+
+ # Writing HDF5 file
+ print("\tWriting HDF5 file")
+ tic_write = time.perf_counter()
+ self.save_hdf5_file(output_filename, self.coordinates, self.connectivity, self.level_id_field, fields_data, compression, compression_opts)
+ toc_write = time.perf_counter()
+ print(f"\tHDF5 file written in {toc_write - tic_write:0.1f} seconds")
+
+ def to_slice_image(
+ self,
+ output_filename,
+ field_neon_dict,
+ plane_point,
+ plane_normal,
+ slice_thickness=1.0,
+ bounds=[0, 1, 0, 1],
+ grid_res=512,
+ cmap=None,
+ component=None,
+ show_axes=False,
+ show_colorbar=False,
+ **kwargs,
+ ):
+ """
+ Export an arbitrary-plane slice from unstructured point data to PNG.
+
+ Parameters
+ ----------
+ output_filename : str
+ Output PNG filename (without extension).
+ field_neon_dict : dict
+ A dictionary of NEON fields containing the data to be plotted.
+ Example: {"velocity": velocity_neon, "density": density_neon}
+ plane_point : array_like
+ A point [x, y, z] on the plane.
+ plane_normal : array_like
+ Plane normal vector [nx, ny, nz].
+ slice_thickness : float
+ How thick (in units of the coordinate system) the slice should be.
+ grid_resolution : tuple
+ Resolution of output image (pixels in plane u, v directions).
+ grid_size : tuple
+ Physical size of slice grid (width, height).
+ cmap : str
+ Matplotlib colormap.
+ """
+ # Get the fields data from the NEON fields
+ assert len(field_neon_dict.keys()) == 1, "Error: This function is designed to plot a single field at a time."
+ fields_data = self.get_fields_data(field_neon_dict)
+
+ # Check if the component is within the valid range
+ if component is None:
+ print("\tCreating slice image of the field magnitude!")
+ cell_data = list(fields_data.values())
+ squared = [comp**2 for comp in cell_data]
+ cell_data = np.sqrt(sum(squared))
+ field_name = list(fields_data.keys())[0].split("_")[0] + "_magnitude"
+ else:
+ assert component < max(self.field_name_cardinality_dict.values()), (
+ f"Error: Component {component} is out of range for the provided fields."
+ )
+ print(f"\tCreating slice image for component {component} of the input field!")
+ field_name = list(fields_data.keys())[component]
+ cell_data = fields_data[field_name]
+
+ # Plot each field in the dictionary
+ self._to_slice_image_single_field(
+ f"{output_filename}_{field_name}",
+ cell_data,
+ plane_point,
+ plane_normal,
+ slice_thickness=slice_thickness,
+ bounds=bounds,
+ grid_res=grid_res,
+ cmap=cmap,
+ show_axes=show_axes,
+ show_colorbar=show_colorbar,
+ **kwargs,
+ )
+ print(f"\tSlice image for field {field_name} saved as {output_filename}.png")
+
+ def _to_slice_image_single_field(
+ self,
+ output_filename,
+ field_data,
+ plane_point,
+ plane_normal,
+ slice_thickness,
+ bounds,
+ grid_res,
+ cmap,
+ show_axes,
+ show_colorbar,
+ **kwargs,
+ ):
+ """
+ Helper function to create a slice image for a single field.
+ """
+ from matplotlib import cm
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from scipy.spatial import cKDTree
+
+ # field data are associated with the cells centers
+ cell_values = field_data
+
+ # get the normalized plane normal
+ plane_normal = np.asarray(np.abs(plane_normal))
+ n = plane_normal / np.linalg.norm(plane_normal)
+
+ # Compute signed distances of each cell center to the plane
+ plane_point *= plane_normal
+ sdf = np.dot(self.centroids - plane_point, n)
+
+ # Filter: cells with centroid near plane
+ mask = np.abs(sdf) <= slice_thickness / 2
+ if not np.any(mask):
+ raise ValueError("No cells intersect the plane within thickness.")
+
+ # Project centroids to plane
+ centroids_slice = self.centroids[mask]
+ sdf_slice = sdf[mask]
+ proj = centroids_slice - np.outer(sdf_slice, n)
+
+ values = cell_values[mask]
+
+ # Build in-plane basis
+ if np.allclose(n, [1, 0, 0]):
+ u1 = np.array([0, 1, 0])
+ else:
+ u1 = np.array([1, 0, 0])
+ u2 = np.abs(np.cross(n, u1))
+
+ local_x = np.dot(proj - plane_point, u1)
+ local_y = np.dot(proj - plane_point, u2)
+
+ # Define extent of the plot
+ xmin, xmax, ymin, ymax = local_x.min(), local_x.max(), local_y.min(), local_y.max()
+ Lx = xmax - xmin
+ Ly = ymax - ymin
+ extent = np.array([xmin + bounds[0] * Lx, xmin + bounds[1] * Lx, ymin + bounds[2] * Ly, ymin + bounds[3] * Ly])
+ mask_bounds = (extent[0] <= local_x) & (local_x <= extent[1]) & (extent[2] <= local_y) & (local_y <= extent[3])
+
+ if cmap is None:
+ cmap = cm.nipy_spectral
+
+ # Adjust vertical resolution based on bounds
+ bounded_x_min = local_x[mask_bounds].min()
+ bounded_x_max = local_x[mask_bounds].max()
+ bounded_y_min = local_y[mask_bounds].min()
+ bounded_y_max = local_y[mask_bounds].max()
+ width_x = bounded_x_max - bounded_x_min
+ height_y = bounded_y_max - bounded_y_min
+ aspect_ratio = height_y / width_x
+ grid_resY = max(1, int(np.round(grid_res * aspect_ratio)))
+
+ # Create grid
+ grid_x = np.linspace(bounded_x_min, bounded_x_max, grid_res)
+ grid_y = np.linspace(bounded_y_min, bounded_y_max, grid_resY)
+ xv, yv = np.meshgrid(grid_x, grid_y, indexing="xy")
+
+ # Fast KDTree-based interpolation
+ points = np.column_stack((local_x[mask_bounds], local_y[mask_bounds]))
+ tree = cKDTree(points)
+
+ # Query points
+ query_points = np.column_stack((xv.ravel(), yv.ravel()))
+
+ # Find k nearest neighbors for smoother interpolation
+ k = min(4, len(points)) # Use 4 neighbors or less if not enough points
+ distances, indices = tree.query(query_points, k=k, workers=-1) # -1 uses all cores
+
+ # Inverse distance weighting
+ epsilon = 1e-10
+ weights = 1.0 / (distances + epsilon)
+ weights /= weights.sum(axis=1, keepdims=True)
+
+ # Interpolate values
+ neighbor_values = values[mask_bounds][indices]
+ grid_field = (neighbor_values * weights).sum(axis=1).reshape(grid_resY, grid_res)
+
+ # Plot
+ if show_colorbar or show_axes:
+ dpi = 300
+ plt.imshow(
+ grid_field,
+ extent=[bounded_x_min, bounded_x_max, bounded_y_min, bounded_y_max],
+ cmap=cmap,
+ origin="lower",
+ aspect="equal",
+ **kwargs,
+ )
+ if show_colorbar:
+ plt.colorbar()
+ if not show_axes:
+ plt.axis("off")
+ plt.savefig(output_filename + ".png", dpi=dpi, bbox_inches="tight", pad_inches=0)
+ plt.close()
+ else:
+ plt.imsave(output_filename + ".png", grid_field, cmap=cmap, origin="lower")
+
+ def to_line(
+ self,
+ output_filename,
+ field_neon_dict,
+ start_point,
+ end_point,
+ resolution,
+ component=None,
+ radius=1.0,
+ **kwargs,
+ ):
+ """
+ Extract field data along a line between start_point and end_point and save to a CSV file.
+
+ This function performs two main steps:
+ 1. Extracts field data from field_neon_dict, handling components or computing magnitude.
+ 2. Interpolates the field values along a line defined by start_point and end_point,
+ then saves the results (coordinates and field values) to a CSV file.
+
+ Parameters
+ ----------
+ output_filename : str
+ The name of the output CSV file (without extension). Example: "velocity_profile".
+ field_neon_dict : dict
+ A dictionary containing the field data to extract, with a single key-value pair.
+ The key is the field name (e.g., "velocity"), and the value is the NEON data object
+ containing the field values. Example: {"velocity": velocity_neon}.
+ start_point : array_like
+ The starting point of the line in 3D space (e.g., [x0, y0, z0]).
+ Units must match the coordinate system used in the class (voxel units if untransformed,
+ or model units if scale/offset are applied).
+ end_point : array_like
+ The ending point of the line in 3D space (e.g., [x1, y1, z1]).
+ Units must match the coordinate system used in the class.
+ resolution : int
+ The number of points along the line where the field will be interpolated.
+ Example: 100 for 100 evenly spaced points.
+ component : int, optional
+ The specific component of the field to extract (e.g., 0 for x-component, 1 for y-component).
+ If None, the magnitude of the field is computed. Default is None.
+ radius : int
+ The specified distance (in units of the coordinate system) to prefilter and query for line plot
+
+ Returns
+ -------
+ None
+ The function writes the output to a CSV file and prints a confirmation message.
+
+ Notes
+ -----
+ - The output CSV file will contain columns: 'x', 'y', 'z', and the value of the field name (e.g., 'velocity_x' or 'velocity_magnitude').
+ """
+
+ # Get the fields data from the NEON fields
+ assert len(field_neon_dict.keys()) == 1, "Error: This function is designed to plot a single field at a time."
+ fields_data = self.get_fields_data(field_neon_dict)
+
+ # Check if the component is within the valid range
+ if component is None:
+ print("\tCreating csv plot of the field magnitude!")
+ cell_data = list(fields_data.values())
+ squared = [comp**2 for comp in cell_data]
+ cell_data = np.sqrt(sum(squared))
+ field_name = list(fields_data.keys())[0].split("_")[0] + "_magnitude"
+
+ else:
+ assert component < max(self.field_name_cardinality_dict.values()), (
+ f"Error: Component {component} is out of range for the provided fields."
+ )
+ print(f"\tCreating csv plot for component {component} of the input field!")
+ field_name = list(fields_data.keys())[component]
+ cell_data = fields_data[field_name]
+
+ # Plot each field in the dictionary
+ self._to_line_field(
+ f"{output_filename}_{field_name}",
+ cell_data,
+ start_point,
+ end_point,
+ resolution,
+ radius=radius,
+ **kwargs,
+ )
+ print(f"\tLine Plot for field {field_name} saved as {output_filename}.csv")
+
+ def _to_line_field(
+ self,
+ output_filename,
+ cell_data,
+ start_point,
+ end_point,
+ resolution,
+ radius,
+ **kwargs,
+ ):
+ """
+ Helper function to create a line plot for a single field.
+ """
+ import numpy as np
+
+ # cell_points = self.coordinates[self.connectivity] # Shape: (M, K, 3), where M is num cells, K is nodes per cell
+ # centroids = np.mean(cell_points, axis=1) # Shape: (M, 3)
+ centroids = self.centroids
+ p0 = np.array(start_point, dtype=np.float32)
+ p1 = np.array(end_point, dtype=np.float32)
+
+ # direction and parameter t for each centroid
+ d = p1 - p0
+ L = np.linalg.norm(d)
+ d_unit = d / L
+ v = centroids - p0
+ t = v.dot(d_unit)
+ closest = p0 + np.outer(t, d_unit)
+ perp_dist = np.linalg.norm(centroids - closest, axis=1)
+
+ # optionally mask to [0,L] or a small perp-radius
+ mask = (t >= 0) & (t <= L) & (perp_dist <= radius)
+ t, data = t[mask], cell_data[mask]
+
+ # sort by t
+ idx = np.argsort(t)
+ t_sorted = t[idx]
+ data_sorted = data[idx]
+
+ # target samples
+ t_line = np.linspace(0, L, resolution)
+
+ # 1D linear interpolation
+ vals_line = np.interp(t_line, t_sorted, data_sorted, left=np.nan, right=np.nan)
+
+ # reconstruct (x,y,z)
+ line_xyz = p0[None, :] + t_line[:, None] * d_unit[None, :]
+
+ # vectorized CSV dump
+ out = np.hstack([line_xyz, vals_line[:, None]])
+ np.savetxt(output_filename + ".csv", out, delimiter=",", header="x,y,z,value", comments="")
diff --git a/xlb/utils/utils.py b/xlb/utils/utils.py
index a42dea98..d707fb5c 100644
--- a/xlb/utils/utils.py
+++ b/xlb/utils/utils.py
@@ -1,3 +1,11 @@
+"""
+General-purpose utilities for XLB.
+
+Includes helpers for field downsampling, VTK/image/USD I/O, geometry
+rotation, STL voxelization, Neon-to-JAX field transfer, and
+physical-to-lattice unit conversion.
+"""
+
import numpy as np
import matplotlib.pylab as plt
from matplotlib import cm
@@ -83,7 +91,7 @@ def save_image(fld, timestep=None, prefix=None, **kwargs):
if len(fld.shape) > 3:
raise ValueError("The input field should be 2D!")
if len(fld.shape) == 3:
- fld = np.sqrt(fld[0, ...] ** 2 + fld[0, ...] ** 2)
+ fld = np.sqrt(fld[0, ...] ** 2 + fld[1, ...] ** 2 + fld[2, ...] ** 2)
plt.clf()
kwargs.pop("cmap", None)
@@ -237,7 +245,7 @@ def rotate_geometry(indices, origin, axis, angle):
return tuple(jnp.rint(indices_rotated).astype("int32").T)
-def voxelize_stl(stl_filename, length_lbm_unit=None, tranformation_matrix=None, pitch=None):
+def voxelize_stl(stl_filename, length_lbm_unit=None, transformation_matrix=None, pitch=None):
"""
Converts an STL file to a voxelized mesh.
@@ -247,7 +255,7 @@ def voxelize_stl(stl_filename, length_lbm_unit=None, tranformation_matrix=None,
The name of the STL file to be voxelized.
length_lbm_unit : float, optional
The unit length in LBM. Either this or 'pitch' must be provided.
- tranformation_matrix : array-like, optional
+ transformation_matrix : array-like, optional
A transformation matrix to be applied to the mesh before voxelization.
pitch : float, optional
The pitch of the voxel grid. Either this or 'length_lbm_unit' must be provided.
@@ -267,8 +275,8 @@ def voxelize_stl(stl_filename, length_lbm_unit=None, tranformation_matrix=None,
raise ValueError("Either 'length_lbm_unit' or 'pitch' must be provided!")
mesh = trimesh.load_mesh(stl_filename, process=False)
length_phys_unit = mesh.extents.max()
- if tranformation_matrix is not None:
- mesh.apply_transform(tranformation_matrix)
+ if transformation_matrix is not None:
+ mesh.apply_transform(transformation_matrix)
if pitch is None:
pitch = length_phys_unit / length_lbm_unit
mesh_voxelized = mesh.voxelized(pitch=pitch)
@@ -319,6 +327,190 @@ def axangle2mat(axis, angle, is_normalized=False):
])
+class ToJAX(object):
+ """Convert a Neon field to a JAX array via an intermediate Warp grid."""
+
+ def __init__(self, field_name, field_cardinality, grid_shape, store_precision=None):
+ """Initialise the Neon-to-JAX converter.
+
+ Parameters
+ ----------
+ field_name : str
+ The name of the field to be converted.
+ field_cardinality : int
+ The cardinality of the field to be converted.
+ grid_shape : tuple
+ The shape of the grid on which the field is defined.
+ store_precision : Precision, optional
+ Storage precision. Defaults to the global config value.
+ """
+ from xlb.compute_backend import ComputeBackend
+ from xlb.grid import grid_factory
+ from xlb import DefaultConfig
+
+ # Assign to self
+ self.field_name = field_name
+ self.field_cardinality = field_cardinality
+ self.grid_shape = grid_shape
+ self.compute_backend = DefaultConfig.default_backend
+ self.velocity_set = DefaultConfig.velocity_set
+ if store_precision is None:
+ self.store_precision = DefaultConfig.default_precision_policy.store_precision
+ self.store_dtype = DefaultConfig.default_precision_policy.store_precision.wp_dtype
+
+ if self.compute_backend == ComputeBackend.NEON:
+ # Allocate warp fields for copying neon fields
+ # Use the warp backend to create dense fields for copying NEON dGrid fields
+ grid_dense = grid_factory(grid_shape, compute_backend=ComputeBackend.WARP)
+ self.warp_field = grid_dense.create_field(cardinality=self.field_cardinality, dtype=self.store_precision)
+
+ def copy_neon_to_warp(self, neon_field):
+ """Convert a dense neon field to a warp field by copying."""
+ import warp as wp
+ import neon
+ from typing import Any
+
+ assert neon_field.get_grid().name == "dGrid", "to_warp only supports dense grids"
+ _d = self.velocity_set.d
+
+ @neon.Container.factory("to_warp")
+ def container(src_field: Any, dst_field: Any, cardinality: wp.int32):
+ def loading_step(loader: neon.Loader):
+ loader.set_grid(src_field.get_grid())
+ src_pn = loader.get_read_handle(src_field)
+
+ @wp.func
+ def cloning(gridIdx: Any):
+ cIdx = wp.neon_global_idx(src_pn, gridIdx)
+ gx = wp.neon_get_x(cIdx)
+ gy = wp.neon_get_y(cIdx)
+ gz = wp.neon_get_z(cIdx)
+
+ # XLB is flattening the z dimension in 3D, while neon uses the y dimension
+ if _d == 2:
+ gy, gz = gz, gy
+
+ for card in range(cardinality):
+ value = wp.neon_read(src_pn, gridIdx, card)
+ dst_field[card, gx, gy, gz] = value
+
+ loader.declare_kernel(cloning)
+
+ return loading_step
+
+ cardinality = neon_field.cardinality
+ c = container(neon_field, self.warp_field, cardinality)
+ c.run(0)
+ wp.synchronize()
+ return self.warp_field
+
+ def __call__(self, field):
+ from xlb.compute_backend import ComputeBackend
+ import warp as wp
+
+ if self.compute_backend == ComputeBackend.JAX:
+ return field
+ elif self.compute_backend == ComputeBackend.WARP:
+ return wp.to_jax(field)
+ elif self.compute_backend == ComputeBackend.NEON:
+ assert field.cardinality == self.field_cardinality, (
+ f"Field cardinality mismatch! Expected {self.field_cardinality}, got {field.cardinality}!"
+ )
+ return wp.to_jax(self.copy_neon_to_warp(field))
+
+ else:
+ raise ValueError("Unsupported compute backend!")
+
+
+class UnitConvertor(object):
+ def __init__(
+ self,
+ velocity_lbm_unit: float,
+ velocity_physical_unit: float,
+ voxel_size_physical_unit: float,
+ density_physical_unit: float = 1.2041,
+ pressure_physical_unit: float = 1.101325e5,
+ ):
+ """
+ Initialize the UnitConvertor object.
+
+ Parameters
+ ----------
+ velocity_lbm_unit : float
+ The reference velocity in lattice Boltzmann units.
+ velocity_physical_unit : float
+ The reference velocity in physical units (e.g., m/s).
+ voxel_size_physical_unit : float
+ The size of a voxel in physical units (e.g., meters).
+ density_physical_unit : float, optional
+ The reference density in physical units (e.g., kg/m^3). Default is 1.2041 (density of air at room temperature).
+ pressure_physical_unit : float, optional
+ The reference pressure in physical units (e.g., Pascals). Default is 1.101325e5 (atmospheric pressure at sea level).
+ """
+
+ self.voxel_size = voxel_size_physical_unit
+ self.velocity_lbm_unit = velocity_lbm_unit
+ self.velocity_phys_unit = velocity_physical_unit
+
+ # Reference density and pressure in physical units
+ self.reference_density = density_physical_unit
+ self.referece_pressure = pressure_physical_unit
+
+ @property
+ def time_step_physical(self):
+ return self.voxel_size * self.velocity_lbm_unit / self.velocity_phys_unit
+
+ @property
+ def reference_length(self):
+ return self.voxel_size
+
+ @property
+ def reference_time(self):
+ return self.time_step_physical
+
+ @property
+ def reference_velocity(self):
+ return self.reference_length / self.reference_time
+
+ def length_to_lbm(self, length_phys):
+ return length_phys / self.reference_length
+
+ def length_to_physical(self, length_lbm):
+ return length_lbm * self.reference_length
+
+ def time_to_lbm(self, time_phys):
+ return time_phys / self.reference_time
+
+ def time_to_physical(self, time_lbm):
+ return time_lbm * self.reference_time
+
+ def density_to_lbm(self, rho_phys):
+ return rho_phys / self.reference_density
+
+ def density_to_physical(self, rho_lbm):
+ return rho_lbm * self.reference_density
+
+ def velocity_to_lbm(self, velocity_phys):
+ return velocity_phys / self.reference_velocity
+
+ def velocity_to_physical(self, velocity_lbm):
+ return velocity_lbm * self.reference_velocity
+
+ def viscosity_to_lbm(self, viscosity_phys):
+ return viscosity_phys * (self.reference_time / (self.reference_length**2))
+
+ def viscosity_to_physical(self, viscosity_lbm):
+ return viscosity_lbm * (self.reference_length**2 / self.reference_time)
+
+ def pressure_to_lbm(self, pressure_phys):
+ pressure_perturbation = pressure_phys - self.reference_pressure
+ return pressure_perturbation / self.reference_density / self.reference_velocity**2
+
+ def pressure_to_physical(self, pressure_lbm):
+ pressure_perturbation = pressure_lbm - 1.0 / 3.0
+ return self.referece_pressure + pressure_perturbation * self.reference_density * (self.reference_velocity**2)
+
+
@wp.kernel
def get_color(
low: float,
diff --git a/xlb/velocity_set/velocity_set.py b/xlb/velocity_set/velocity_set.py
index 8b8d3213..6f9634b2 100644
--- a/xlb/velocity_set/velocity_set.py
+++ b/xlb/velocity_set/velocity_set.py
@@ -1,4 +1,10 @@
-# Base Velocity Set class
+"""
+Base velocity-set class for the Lattice Boltzmann Method.
+
+Defines lattice directions, weights, and derived properties (opposite
+indices, moments, etc.) for any DdQq stencil. Backend-specific constants
+(Warp vectors, JAX arrays, Neon lattice objects) are initialised lazily.
+"""
import math
import numpy as np
@@ -44,6 +50,8 @@ def __init__(self, d, q, c, w, precision_policy, compute_backend):
# Convert properties to backend-specific format
if self.compute_backend == ComputeBackend.WARP:
self._init_warp_properties()
+ elif self.compute_backend == ComputeBackend.NEON:
+ self._init_neon_properties()
elif self.compute_backend == ComputeBackend.JAX:
self._init_jax_properties()
else:
@@ -72,6 +80,7 @@ def _init_numpy_properties(self, c, w):
self.main_indices = self._construct_main_indices()
self.right_indices = self._construct_right_indices()
self.left_indices = self._construct_left_indices()
+ self.center_index = self._get_center_index()
def _init_warp_properties(self):
"""
@@ -85,6 +94,12 @@ def _init_warp_properties(self):
self.c_float = wp.constant(wp.mat((self.d, self.q), dtype=dtype)(self._c_float))
self.qi = wp.constant(wp.mat((self.q, self.d * (self.d + 1) // 2), dtype=dtype)(self._qi))
+ def _init_neon_properties(self):
+ """
+ Convert NumPy properties to Neon-specific properties which are identical to Warp.
+ """
+ self._init_warp_properties()
+
def _init_jax_properties(self):
"""
Convert NumPy properties to JAX-specific properties.
@@ -220,6 +235,23 @@ def _construct_left_indices(self):
"""
return np.nonzero(self._c.T[:, 0] == -1)[0]
+ def _get_center_index(self):
+ """
+ This function returns the index of the center point in the lattice associated with (0,0,0)
+
+ Returns
+ -------
+ numpy.ndarray
+ The index of the zero lattice velocity.
+ """
+ arr = self._c.T
+ if self.d == 2:
+ target = np.array([0, 0])
+ else:
+ target = np.array([0, 0, 0])
+ match = np.all(arr == target, axis=1)
+ return int(np.nonzero(match)[0][0])
+
def __str__(self):
"""
This function returns the name of the lattice in the format of DxQy.