diff --git a/.gitignore b/.gitignore index 44c7f5a5..2d4f7cd5 100644 --- a/.gitignore +++ b/.gitignore @@ -155,6 +155,14 @@ checkpoints/* dist/ build/ *.egg-info/ +*.dot + +# Ignore h5 and xmf formats +*.h5 +*.xmf + +# Ignore CSV files +*.csv # USD files *.usd @@ -162,4 +170,4 @@ build/ *.usdc *.usd.gz *.usd.zip -*.usd.bz2 \ No newline at end of file +*.usd.bz2 diff --git a/README.md b/README.md index 97b2082a..65f312cc 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ # XLB: A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning -XLB is a fully differentiable 2D/3D Lattice Boltzmann Method (LBM) library that leverages hardware acceleration. It supports [JAX](https://github.com/google/jax) and [NVIDIA Warp](https://github.com/NVIDIA/warp) backends, and is specifically designed to solve fluid dynamics problems in a computationally efficient and differentiable manner. Its unique combination of features positions it as an exceptionally suitable tool for applications in physics-based machine learning. With the new Warp backend, XLB now offers state-of-the-art performance for even faster simulations. +XLB is a fully differentiable 2D/3D Lattice Boltzmann Method (LBM) library that leverages hardware acceleration. It supports [JAX](https://github.com/google/jax), [NVIDIA Warp](https://github.com/NVIDIA/warp), and [Neon](https://github.com/Autodesk/Neon) backends, and is specifically designed to solve fluid dynamics problems in a computationally efficient and differentiable manner. Its unique combination of features positions it as an exceptionally suitable tool for applications in physics-based machine learning. With the Warp backend, XLB offers state-of-the-art single-GPU performance, and with the new Neon backend it extends to multi-GPU (single-resolution). More importantly, the Neon backend provides grid refinement capabilities for multi-resolution simulations. 
## Getting Started To get started with XLB, you can install it using pip. There are different installation options depending on your hardware and needs: @@ -28,6 +28,16 @@ This installation is for the JAX backend with TPU support: pip install "xlb[tpu]" ``` +### Installation with Neon support +Neon backend enables multi-GPU dense and single-GPU multi-resolution representations. Neon depends on a custom fork of warp-lang, so any existing warp installation must be removed before installing Neon. The Python interface for Neon can be installed from a pre-built wheel hosted on GitHub. Note that the wheel currently requires GLIBC >= 2.38 (e.g., Ubuntu 24.04 or later). + +```bash +pip uninstall warp-lang +pip install https://github.com/Autodesk/Neon/releases/download/v0.5.2a1/neon_gpu-0.5.2a1-cp312-cp312-linux_x86_64.whl +``` + + + ### Notes: - For Mac users: Use the basic CPU installation command as JAX's GPU support is not available on MacOS - The NVIDIA Warp backend is included in all installation options and supports CUDA automatically when available @@ -63,11 +73,36 @@ If you use XLB in your research, please cite the following paper: } ``` +If you use the grid refinement capabilities in your work, please also cite: + +``` +@inproceedings{mahmoud2024optimized, + title={Optimized {GPU} implementation of grid refinement in lattice {Boltzmann} method}, + author={Mahmoud, Ahmed H and Salehipour, Hesam and Meneghin, Massimiliano}, + booktitle={2024 IEEE International Parallel and Distributed Processing Symposium (IPDPS)}, + pages={398--407}, + year={2024}, + organization={IEEE} +} + +@inproceedings{meneghin2022neon, + title={Neon: A Multi-{GPU} Programming Model for Grid-based Computations}, + author={Meneghin, Massimiliano and Mahmoud, Ahmed H. and Jayaraman, Pradeep Kumar and Morris, Nigel J. 
W.}, + booktitle={Proceedings of the 36th IEEE International Parallel and Distributed Processing Symposium}, + pages={817--827}, + year={2022}, + month={june}, + doi={10.1109/IPDPS53621.2022.00084}, + url={https://escholarship.org/uc/item/9fz7k633} +} +``` + ## Key Features -- **Multiple Backend Support:** XLB now includes support for multiple backends including JAX and NVIDIA Warp, providing *state-of-the-art* performance for lattice Boltzmann simulations. Currently, only single GPU is supported for the Warp backend. +- **Multiple Backend Support:** XLB includes support for JAX, NVIDIA Warp, and Neon backends, providing *state-of-the-art* performance for lattice Boltzmann simulations. The Warp backend targets single-GPU runs, while the Neon backend enables multi-GPU single-resolution and single-GPU multi-resolution simulations. +- **Multi-Resolution Grid Refinement:** Mesh refinement with nested cuboid grids and multiple kernel-fusion strategies for optimal performance on the Neon backend. - **Integration with JAX Ecosystem:** The library can be easily integrated with JAX's robust ecosystem of machine learning libraries such as [Flax](https://github.com/google/flax), [Haiku](https://github.com/deepmind/dm-haiku), [Optax](https://github.com/deepmind/optax), and many more. - **Differentiable LBM Kernels:** XLB provides differentiable LBM kernels that can be used in differentiable physics and deep learning applications. -- **Scalability:** XLB is capable of scaling on distributed multi-GPU systems using the JAX backend, enabling the execution of large-scale simulations on hundreds of GPUs with billions of cells. +- **Scalability:** XLB is capable of scaling on distributed multi-GPU systems using the JAX backend or the Neon backend, enabling the execution of large-scale simulations on hundreds of GPUs with billions of cells. - **Support for Various LBM Boundary Conditions and Kernels:** XLB supports several LBM boundary conditions and collision kernels. 
- **User-Friendly Interface:** Written entirely in Python, XLB emphasizes a highly accessible interface that allows users to extend the library with ease and quickly set up and run new simulations. - **Leverages JAX Array and Shardmap:** The library incorporates the new JAX array unified array type and JAX shardmap, providing users with a numpy-like interface. This allows users to focus solely on the semantics, leaving performance optimizations to the compiler. @@ -103,7 +138,7 @@ If you use XLB in your research, please cite the following paper:

- Airflow in to, out of, and within a building (~400 million cells) + Airflow into, out of, and within a building (~400 million cells)

@@ -128,6 +163,7 @@ The stages of a fluid density field from an initial state to the emergence of th - BGK collision model (Standard LBM collision model) - KBC collision model (unconditionally stable for flows with high Reynolds number) +- Smagorinsky LES sub-grid model for turbulence modelling ### Machine Learning @@ -143,21 +179,25 @@ The stages of a fluid density field from an initial state to the emergence of th ### Compute Capabilities - Single GPU support for the Warp backend with state-of-the-art performance +- Multi-GPU support using the Neon backend with single-resolution grids +- Grid refinement support on single-GPU using the Neon backend - Distributed Multi-GPU support using the JAX backend - Mixed-Precision support (store vs compute) +- Multiple kernel-fusion performance strategies for multi-resolution simulations - Out-of-core support (coming soon) ### Output - Binary and ASCII VTK output (based on PyVista library) +- HDF5/XDMF output for multi-resolution data (with gzip compression) - In-situ rendering using [PhantomGaze](https://github.com/loliverhennigh/PhantomGaze) library - [Orbax](https://github.com/google/orbax)-based distributed asynchronous checkpointing -- Image Output +- Image Output (including multi-resolution slice images) - 3D mesh voxelizer using trimesh ### Boundary conditions -- **Equilibrium BC:** In this boundary condition, the fluid populations are assumed to be in at equilibrium. Can be used to set prescribed velocity or pressure. +- **Equilibrium BC:** In this boundary condition, the fluid populations are assumed to be at equilibrium. Can be used to set prescribed velocity or pressure. - **Full-Way Bounceback BC:** In this boundary condition, the velocity of the fluid populations is reflected back to the fluid side of the boundary, resulting in zero fluid velocity at the boundary. 
@@ -171,17 +211,22 @@ The stages of a fluid density field from an initial state to the emergence of th - **Interpolated Bounceback BC:** Interpolated bounce-back boundary condition for representing curved boundaries. +- **Hybrid BC:** Combines regularized and bounce-back methods with optional wall-distance interpolation for improved accuracy on curved geometries. + +- **Grad's Approximation BC:** Boundary condition based on Grad's approximation of the non-equilibrium distribution. + ## Roadmap -### Work in Progress (WIP) -*Note: Some of the work-in-progress features can be found in the branches of the XLB repository. For contributions to these features, please reach out.* +### Recently Completed - - 🌐 **Grid Refinement:** Implementing adaptive mesh refinement techniques for enhanced simulation accuracy. + - ✅ **Grid Refinement:** Multi-resolution LBM with nested cuboid grids and multiple kernel-fusion strategies via the Neon backend. - - 💾 **Out-of-Core Computations:** Enabling simulations that exceed available GPU memory, suitable for CPU+GPU coherent memory models such as NVIDIA's Grace Superchips (coming soon). + - ✅ **Multi-GPU Acceleration using [Neon](https://github.com/Autodesk/Neon) + Warp:** Multi-GPU support through Neon's data structures with Warp-based kernels for single-resolution settings. +### Work in Progress (WIP) +*Note: Some of the work-in-progress features can be found in the branches of the XLB repository. For contributions to these features, please reach out.* -- ⚡ **Multi-GPU Acceleration using [Neon](https://github.com/Autodesk/Neon) + Warp:** Using Neon's data structure for improved scaling. + - 💾 **Out-of-Core Computations:** Enabling simulations that exceed available GPU memory, suitable for CPU+GPU coherent memory models such as NVIDIA's Grace Superchips (coming soon). 
- 🗜️ **GPU Accelerated Lossless Compression and Decompression**: Implementing high-performance lossless compression and decompression techniques for larger-scale simulations and improved performance. diff --git a/examples/cfd/data/ahmed.json b/examples/cfd/data/ahmed.json new file mode 100644 index 00000000..59ee02e7 --- /dev/null +++ b/examples/cfd/data/ahmed.json @@ -0,0 +1,22 @@ +{ + "_comment": "Ahmed Car Model, slant - angle = 25 degree. Profiles on symmetry plane (y=0) covering entire field. Origin of coordinate system: x=0: end of the car, y=0: symmetry plane, z=0: ground plane S.Becker/H. Lienhart/C Stoots, Institute of Fluid Mechanics, University Erlangen-Nuremberg, Erlangen, Germany, Coordinates in meters need to convert to voxels, Velocity data in m/s", + "data": { + "-1.162" : { "x-velocity" : [26.995,29.825,29.182,28.488,27.703,26.988,26.456,26.163,26.190,26.523,27.083,28.033,29.131,30.429,31.747,33.036,34.268,35.354,36.312,37.083,37.770,38.484,39.033,39.447,39.839,40.086,40.268,40.380,40.451], "height" : [0.028,0.048,0.068,0.088,0.108,0.128,0.148,0.168,0.188,0.208,0.228,0.248,0.268,0.288,0.308,0.328,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.7388]}, + "-1.062" : { "x-velocity" : [30.307,28.962,25.812,21.232,15.848,10.812,7.459,6.080,5.845,6.196,7.428,10.456,15.718,22.129,28.090,32.707,35.888,37.891,39.071,39.840,40.261,40.604,40.767,40.820,40.870,40.890,40.907,40.871,40.853], "height" : [0.028,0.048,0.068,0.088,0.108,0.128,0.148,0.168,0.188,0.208,0.228,0.248,0.268,0.288,0.308,0.328,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}, + "-0.962" : { "x-velocity" : [52.216,51.303,50.196,48.833,47.728,46.790,45.514,44.222,43.379,42.829,42.322,42.056,41.876,41.706,41.584], "height" : [0.363,0.368,0.378,0.388,0.398,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}, + "-0.862" : { "x-velocity" : 
[46.589,46.538,46.228,46.033,45.810,45.554,45.056,44.369,43.789,43.275,42.789,42.344,42.148,41.913,41.720], "height" : [0.363,0.368,0.378,0.388,0.398,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}, + "-0.562" : { "x-velocity" : [43.237,43.262,43.248,43.225,43.183,43.145,43.083,43.030,42.904,42.776,42.685,42.434,42.358,42.197,42.042], "height" : [0.363,0.368,0.378,0.388,0.398,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}, + "-0.362" : { "x-velocity" : [44.493,44.491,44.443,44.379,44.297,44.215,44.067,43.867,43.577,43.306,43.061,42.689,42.527,42.293,42.105], "height" : [0.363,0.368,0.378,0.388,0.398,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}, + "-0.212" : { "x-velocity" : [49.202,48.429,47.805,46.697,45.883,44.913,44.195,43.650,43.130,42.677,42.432,42.154,41.961], "height" : [0.368,0.378,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}, + "-0.162" : { "x-velocity" : [50.511,49.784,48.894,48.103,47.468,46.322,45.563,44.581,43.933,43.383,42.905,42.505,42.293,42.042,41.863], "height" : [0.348,0.358,0.368,0.378,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}, + "-0.112" : { "x-velocity" : [27.615,35.449,41.526,46.068,46.277,46.038,45.774,45.505,45.237,44.701,44.326,43.765,43.284,42.890,42.529,42.247,42.082,41.880,41.732], "height" : [0.318,0.323,0.328,0.338,0.348,0.358,0.368,0.378,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}, + "-0.062" : { "x-velocity" : [22.891,27.789,32.292,36.568,39.533,41.426,42.371,42.971,43.030,43.081,43.074,43.065,43.039,42.996,42.908,42.665,42.456,42.294,42.105,41.929,41.827,41.660,41.546], "height" : [0.298,0.303,0.308,0.313,0.318,0.323,0.328,0.338,0.348,0.358,0.368,0.378,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}, + "-0.012" : { "x-velocity" : 
[23.304,26.317,29.429,32.341,34.923,37.106,38.673,39.841,40.447,40.780,40.973,41.085,41.193,41.282,41.359,41.442,41.522,41.699,41.737,41.749,41.724,41.714,41.642,41.574,41.518,41.431,41.366], "height" : [0.278,0.283,0.288,0.293,0.298,0.303,0.308,0.313,0.318,0.323,0.328,0.338,0.348,0.358,0.368,0.378,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}, + "0.038" : { "x-velocity" : [42.752,37.392,15.320,-4.501,-8.079,-8.892,-8.420,-7.027,-5.143,-2.903,-0.936,0.927,2.200,3.099,3.622,4.026,4.280,4.520,5.620,8.938,13.913,17.872,21.148,24.814,29.075,33.188,36.424,38.490,39.388,39.675,39.794,39.911,40.007,40.219,40.425,40.643,40.757,40.896,40.994,41.058,41.124,41.127,41.143,41.106,41.080], "height" : [0.028,0.038,0.048,0.058,0.068,0.078,0.088,0.098,0.108,0.118,0.128,0.138,0.148,0.158,0.168,0.178,0.188,0.198,0.208,0.218,0.228,0.238,0.248,0.258,0.268,0.278,0.288,0.298,0.308,0.318,0.328,0.338,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}, + "0.088" : { "x-velocity" : [41.859,35.830,22.660,7.745,-5.808,-12.650,-14.748,-13.756,-10.659,-6.484,-2.121,1.303,3.672,5.441,7.066,9.157,11.613,14.620,17.662,20.639,23.565,26.437,29.484,32.441,35.024,36.938,37.938,38.377,38.595,38.728,38.856,38.976,39.133,39.438,39.749,39.975,40.129,40.344,40.499,40.649,40.783,40.853,40.927,40.945,40.960], "height" : [0.028,0.038,0.048,0.058,0.068,0.078,0.088,0.098,0.108,0.118,0.128,0.138,0.148,0.158,0.168,0.178,0.188,0.198,0.208,0.218,0.228,0.238,0.248,0.258,0.268,0.278,0.288,0.298,0.308,0.318,0.328,0.338,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}, + "0.138" : { "x-velocity" : [36.223,32.501,24.752,14.281,2.799,-6.218,-10.908,-11.892,-9.708,-5.258,-0.140,4.331,7.882,10.995,13.961,16.699,19.477,22.063,24.651,27.081,29.524,31.950,34.043,35.594,36.506,37.053,37.386,37.614,37.832,38.032,38.214,38.397,38.575,38.940,39.298,39.533,39.749,40.028,40.206,40.404,40.580,40.691,40.803,40.858,40.921], "height" : 
[0.028,0.038,0.048,0.058,0.068,0.078,0.088,0.098,0.108,0.118,0.128,0.138,0.148,0.158,0.168,0.178,0.188,0.198,0.208,0.218,0.228,0.238,0.248,0.258,0.268,0.278,0.288,0.298,0.308,0.318,0.328,0.338,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}, + "0.188" : { "x-velocity" : [29.417,27.755,23.967,18.261,11.662,5.405,0.676,-0.652,0.937,4.261,7.958,11.427,14.366,17.138,19.735,22.151,24.577,26.883,29.165,31.111,32.781,34.072,34.893,35.524,35.974,36.329,36.604,36.872,37.138,37.402,37.673,37.900,38.112,38.518,38.829,39.088,39.326,39.639,39.871,40.096,40.275,40.423,40.523,40.603,40.687], "height" : [0.028,0.038,0.048,0.058,0.068,0.078,0.088,0.098,0.108,0.118,0.128,0.138,0.148,0.158,0.168,0.178,0.188,0.198,0.208,0.218,0.228,0.238,0.248,0.258,0.268,0.278,0.288,0.298,0.308,0.318,0.328,0.338,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}, + "0.238" : { "x-velocity" : [24.405,24.168,22.782,20.196,16.970,13.937,12.137,11.757,12.851,14.649,16.780,18.995,21.070,23.335,25.280,27.468,29.262,30.832,32.133,33.102,33.856,34.473,34.922,35.340,35.698,36.039,36.336,36.629,36.906,37.193,37.454,37.691,37.929,38.329,38.611,38.875,39.126,39.414,39.677,39.917,40.097,40.259,40.380,40.478,40.568], "height" : [0.028,0.038,0.048,0.058,0.068,0.078,0.088,0.098,0.108,0.118,0.128,0.138,0.148,0.158,0.168,0.178,0.188,0.198,0.208,0.218,0.228,0.238,0.248,0.258,0.268,0.278,0.288,0.298,0.308,0.318,0.328,0.338,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]}, + "0.288" : { "x-velocity" : [21.489,22.225,22.127,21.456,20.404,19.743,19.541,19.909,21.002,22.381,24.018,25.670,27.421,28.998,30.371,31.523,32.406,33.111,33.670,34.155,34.532,34.893,35.240,35.567,35.875,36.158,36.437,36.708,36.974,37.230,37.473,37.709,37.932,38.266,38.515,38.773,39.008,39.270,39.562,39.782,39.962,40.148,40.266,40.369,40.475], "height" : 
[0.028,0.038,0.048,0.058,0.068,0.078,0.088,0.098,0.108,0.118,0.128,0.138,0.148,0.158,0.168,0.178,0.188,0.198,0.208,0.218,0.228,0.238,0.248,0.258,0.268,0.278,0.288,0.298,0.308,0.318,0.328,0.338,0.348,0.368,0.388,0.408,0.428,0.458,0.488,0.518,0.558,0.598,0.638,0.688,0.738]} + } +} \ No newline at end of file diff --git a/examples/cfd/flow_past_sphere_3d.py b/examples/cfd/flow_past_sphere_3d.py index 2872d7fe..1616d32a 100644 --- a/examples/cfd/flow_past_sphere_3d.py +++ b/examples/cfd/flow_past_sphere_3d.py @@ -20,7 +20,7 @@ omega = 1.6 grid_shape = (512 // 2, 128 // 2, 128 // 2) -compute_backend = ComputeBackend.WARP +compute_backend = ComputeBackend.JAX precision_policy = PrecisionPolicy.FP32FP32 velocity_set = xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=compute_backend) u_max = 0.04 @@ -51,7 +51,7 @@ z = np.arange(grid_shape[2]) X, Y, Z = np.meshgrid(x, y, z, indexing="ij") indices = np.where((X - grid_shape[0] // 6) ** 2 + (Y - grid_shape[1] // 2) ** 2 + (Z - grid_shape[2] // 2) ** 2 < sphere_radius**2) -sphere = [tuple(indices[i]) for i in range(velocity_set.d)] +sphere = [tuple(indices[i].tolist()) for i in range(velocity_set.d)] # Define Boundary Conditions @@ -80,21 +80,25 @@ def bc_profile_jax(): return bc_profile_jax - elif compute_backend == ComputeBackend.WARP: + else: + wp_dtype = precision_policy.compute_precision.wp_dtype + H_y = wp_dtype(grid_shape[1] - 1) # Height in y direction + H_z = wp_dtype(grid_shape[2] - 1) # Height in z direction + two = wp_dtype(2.0) @wp.func def bc_profile_warp(index: wp.vec3i): # Poiseuille flow profile: parabolic velocity distribution - y = wp.float32(index[1]) - z = wp.float32(index[2]) + y = wp_dtype(index[1]) + z = wp_dtype(index[2]) # Calculate normalized distance from center - y_center = y - (H_y / 2.0) - z_center = z - (H_z / 2.0) - r_squared = (2.0 * y_center / H_y) ** 2.0 + (2.0 * z_center / H_z) ** 2.0 + y_center = y - (H_y / two) + z_center = z - (H_z / two) + r_squared = (two * 
y_center / H_y) ** two + (two * z_center / H_z) ** two # Parabolic profile: u = u_max * (1 - r²) - return wp.vec(u_max * wp.max(0.0, 1.0 - r_squared), length=1) + return wp.vec(wp_dtype(u_max) * wp.max(wp_dtype(0.0), wp_dtype(1.0) - r_squared), length=1) return bc_profile_warp @@ -122,15 +126,33 @@ def bc_profile_warp(index: wp.vec3i): precision_policy=precision_policy, velocity_set=xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=ComputeBackend.JAX), ) +to_jax = xlb.utils.ToJAX("populations", velocity_set.q, grid_shape) + +# Setup Momentum Transfer for Force Calculation +from xlb.operator.force.momentum_transfer import MomentumTransfer + +momentum_transfer = MomentumTransfer(bc_sphere, compute_backend=compute_backend) +sphere_cross_section = np.pi * sphere_radius**2 # Post-Processing Function -def post_process(step, f_current): +def post_process(step, f_0, f_1): + wp.synchronize() + + # Compute lift and drag + boundary_force = momentum_transfer(f_0, f_1, bc_mask, missing_mask) + drag = boundary_force[0] # x-direction + lift = boundary_force[2] + cd = 2.0 * drag / (u_max**2 * sphere_cross_section) + cl = 2.0 * lift / (u_max**2 * sphere_cross_section) + print(f"CD={cd}, CL={cl}") + # Convert to JAX array if necessary - if not isinstance(f_current, jnp.ndarray): - f_current = wp.to_jax(f_current) + if not isinstance(f_0, jnp.ndarray): + f_0 = to_jax(f_0) + wp.synchronize() - rho, u = macro(f_current) + rho, u = macro(f_0) # Remove boundary cells u = u[:, 1:-1, 1:-1, 1:-1] @@ -158,9 +180,7 @@ def post_process(step, f_current): f_0, f_1 = f_1, f_0 # Swap the buffers if step % post_process_interval == 0 or step == num_steps - 1: - if compute_backend == ComputeBackend.WARP: - wp.synchronize() - post_process(step, f_0) + post_process(step, f_0, f_1) end_time = time.time() elapsed = end_time - start_time print(f"Completed step {step}. 
Time elapsed for {post_process_interval} steps: {elapsed:.6f} seconds.") diff --git a/examples/cfd/multires_flow_past_sphere_3d.py b/examples/cfd/multires_flow_past_sphere_3d.py new file mode 100644 index 00000000..ea510a29 --- /dev/null +++ b/examples/cfd/multires_flow_past_sphere_3d.py @@ -0,0 +1,290 @@ +""" +3D flow past a sphere with multi-resolution LBM. + +Demonstrates the multi-resolution Neon backend for a 3-D Poiseuille- +inlet flow past a sphere inside a nested cuboid multi-resolution domain. +Uses AABB voxelization with a hybrid (non-equilibrium regularized) BC on the sphere surface and +computes lift/drag via momentum transfer. +""" + +import neon +import warp as wp +import numpy as np +import time + +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import PrecisionPolicy +from xlb.grid import multires_grid_factory +from xlb.operator.boundary_condition import ( + FullwayBounceBackBC, + HalfwayBounceBackBC, + RegularizedBC, + ExtrapolationOutflowBC, + DoNothingBC, + ZouHeBC, + HybridBC, +) +from xlb.operator.boundary_masker import MeshVoxelizationMethod +from xlb.utils.mesher import make_cuboid_mesh, prepare_sparsity_pattern +from xlb.operator.force import MultiresMomentumTransfer + + +def generate_cuboid_mesh(stl_filename, num_finest_voxels_across_part): + """ + Generate a cuboid mesh based on the provided voxel size and domain multipliers. 
+ """ + import trimesh + import os + + # Domain multipliers for each refinement level + # First entry should be full domain size + # Domain multipliers + domainMultiplier = [ + [7, 22, 7, 7, 7, 7], # -x, x, -y, y, -z, z (sphere at 1/4 domain from inlet) + [3, 12, 5, 5, 5, 5], # -x, x, -y, y, -z, z (wake-biased) + [2, 8, 4, 4, 4, 4], + [1, 5, 2, 2, 2, 2], + # [1, 2, 1, 1, 1, 1], + # [0.4, 1, 0.4, 0.4, 0.4, 0.4], + # [0.2, 0.4, 0.2, 0.2, 0.2, 0.2], + ] + + # Load the mesh + mesh = trimesh.load_mesh(stl_filename, process=False) + assert not mesh.is_empty, ValueError("Loaded mesh is empty or invalid.") + + # Compute original bounds + # Find voxel size and sphere radius + min_bound = mesh.vertices.min(axis=0) + max_bound = mesh.vertices.max(axis=0) + partSize = max_bound - min_bound + + # smallest voxel size + voxel_size = min(partSize) / num_finest_voxels_across_part + + # Compute translation to put mesh into first octant of that domain + shift = np.array( + [ + domainMultiplier[0][0] * partSize[0] - min_bound[0], + domainMultiplier[0][2] * partSize[1] - min_bound[1], + domainMultiplier[0][4] * partSize[2] - min_bound[2], + ], + dtype=float, + ) + + # Apply translation and save out temp stl + mesh.apply_translation(shift) + _ = mesh.vertex_normals + mesh_vertices = np.asarray(mesh.vertices) / voxel_size + mesh.export("temp.stl") + + # Mesh based on temp stl + level_data = make_cuboid_mesh( + voxel_size, + domainMultiplier, + "temp.stl", + ) + grid_shape_finest = tuple([i * 2 ** (len(level_data) - 1) for i in level_data[-1][0].shape]) + print(f"Full shape based on finest voxels size is {grid_shape_finest}") + os.remove("temp.stl") + return level_data, mesh_vertices, tuple([int(a) for a in grid_shape_finest]) + + +# -------------------------- Simulation Setup -------------------------- + +# The following parameters define the resolution of the voxelized grid +sphere_radius = 5 +num_finest_voxels_across_part = 2 * sphere_radius + +# Other setup parameters +Re = 5000.0 
+compute_backend = ComputeBackend.NEON +precision_policy = PrecisionPolicy.FP32FP32 +velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=compute_backend) +u_max = 0.04 +num_steps = 10000 +post_process_interval = 1000 + +# Initialize XLB +xlb.init( + velocity_set=velocity_set, + default_backend=compute_backend, + default_precision_policy=precision_policy, +) + +# Generate the cuboid mesh and sphere vertices +stl_filename = "../stl-files/sphere.stl" +level_data, sphere, grid_shape_finest = generate_cuboid_mesh(stl_filename, num_finest_voxels_across_part) + + +# Define exporter object for hdf5 output +from xlb.utils import MultiresIO + +# Define an exporter for the multiresolution data +exporter = MultiresIO({"velocity": 3, "density": 1}, level_data) + +# Prepare the sparsity pattern and origins from the level data +sparsity_pattern, level_origins = prepare_sparsity_pattern(level_data) + +# get the number of levels +num_levels = len(level_data) + +# Create the multires grid +grid = multires_grid_factory( + grid_shape_finest, + velocity_set=velocity_set, + sparsity_pattern_list=sparsity_pattern, + sparsity_pattern_origins=[neon.Index_3d(*box_origin) for box_origin in level_origins], +) + +# Define Boundary Indices +coarsest_level = grid.count_levels - 1 +box = grid.bounding_box_indices(shape=grid.level_to_shape(coarsest_level)) +box_no_edge = grid.bounding_box_indices(shape=grid.level_to_shape(coarsest_level), remove_edges=True) +inlet = box_no_edge["left"] +outlet = box_no_edge["right"] +walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(velocity_set.d)] +walls = np.unique(np.array(walls), axis=-1).tolist() + + +# Define Boundary Conditions +def bc_profile(): + """Build a Warp function for a Poiseuille parabolic inlet velocity profile.""" + assert compute_backend == ComputeBackend.NEON + # IMPORTANT NOTE: the user defined functional must be defined in terms of the indices at the finest level 
+ _, ny, nz = grid_shape_finest + dtype = precision_policy.compute_precision.wp_dtype + H_y = dtype(ny) # Length in y direction (finest level) + H_z = dtype(nz) # Length in z direction (finest level) + two = dtype(2.0) + one = dtype(1.0) + zero = dtype(0.0) + u_max_wp = dtype(u_max) + _u_vec = wp.vec(velocity_set.d, dtype=dtype) + + @wp.func + def bc_profile_warp(index: wp.vec3i): + # Poiseuille flow profile: parabolic velocity distribution + y = dtype(index[1]) + z = dtype(index[2]) + + # Calculate normalized distance from center + y_center = y - (H_y / two) + z_center = z - (H_z / two) + r_squared = (two * y_center / H_y) ** two + (two * z_center / H_z) ** two + + # Parabolic profile: u = u_max * (1 - r²) + # Note that unlike RegularizedBC and ZouHeBC which only accept normal velocity, hybridBC accepts the full velocity vector + + # For hybridBC + # return _u_vec(u_max_wp * wp.max(zero, one - r_squared), zero, zero) + + # For Regularized and ZouHe + return wp.vec(u_max_wp * wp.max(zero, one - r_squared), length=1) + + return bc_profile_warp + + +# Convert bc indices to a list of list (first entry corresponds to the finest level) +inlet = [[] for _ in range(num_levels - 1)] + [inlet] +outlet = [[] for _ in range(num_levels - 1)] + [outlet] +walls = [[] for _ in range(num_levels - 1)] + [walls] + +# Initialize Boundary Conditions +bc_left = RegularizedBC("velocity", profile=bc_profile(), indices=inlet) +# Alternatives: +# bc_left = HybridBC(bc_method="bounceback_regularized", profile=bc_profile(), indices=inlet) +# bc_left = RegularizedBC("velocity", prescribed_value=(u_max, 0.0, 0.0), indices=inlet) +bc_walls = FullwayBounceBackBC(indices=walls) +bc_outlet = DoNothingBC(indices=outlet) +# bc_outlet = ExtrapolationOutflowBC(indices=outlet) +bc_sphere = HybridBC( + bc_method="nonequilibrium_regularized", mesh_vertices=sphere, voxelization_method=MeshVoxelizationMethod("AABB"), use_mesh_distance=True +) +# bc_sphere = HalfwayBounceBackBC(mesh_vertices=sphere, 
voxelization_method=MeshVoxelizationMethod('AABB')) + +boundary_conditions = [bc_walls, bc_left, bc_outlet, bc_sphere] + +# Configure the simulation relaxation time +visc = u_max * num_finest_voxels_across_part / Re +omega_finest = 1.0 / (3.0 * visc + 0.5) + +# Make initializer operator +from xlb.helper.initializers import CustomMultiresInitializer + +initializer = CustomMultiresInitializer( + bc_id=bc_outlet.id, + constant_velocity_vector=(u_max, 0.0, 0.0), + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, +) + +# Define a multi-resolution simulation manager +sim = xlb.helper.MultiresSimulationManager( + omega_finest=omega_finest, + grid=grid, + boundary_conditions=boundary_conditions, + collision_type="KBC", + initializer=initializer, + mres_perf_opt=xlb.mres_perf_optimization_type.MresPerfOptimizationType.FUSION_AT_FINEST, +) + +# Setup Momentum Transfer for Force Calculation +bc_sphere = boundary_conditions[-1] +momentum_transfer = MultiresMomentumTransfer(bc_sphere, mres_perf_opt=sim.mres_perf_opt, compute_backend=compute_backend) + + +def print_lift_drag(sim): + """Compute and print drag and lift coefficients from the simulation state.""" + boundary_force = momentum_transfer(sim.f_0, sim.f_1, sim.bc_mask, sim.missing_mask) + drag = boundary_force[0] # x-direction + lift = boundary_force[2] + sphere_cross_section = np.pi * sphere_radius**2 + u_avg = 0.5 * u_max + cd = 2.0 * drag / (u_avg**2 * sphere_cross_section) + cl = 2.0 * lift / (u_avg**2 * sphere_cross_section) + print(f"\tCD={cd}, CL={cl}") + + +# -------------------------- Simulation Loop -------------------------- + +wp.synchronize() +start_time = time.time() +for step in range(num_steps): + sim.step() + + if step % post_process_interval == 0 or step == num_steps - 1: + # # Export VTK for comparison + # tic_write = time.perf_counter() + # sim.export_macroscopic("multires_flow_over_sphere_3d_") + # toc_write = time.perf_counter() + # print(f"\tVTK file 
written in {toc_write - tic_write:0.1f} seconds") + + # Call the Macroscopic operator to compute macroscopic fields + wp.synchronize() + sim.macro(sim.f_0, sim.bc_mask, sim.rho, sim.u, streamId=0) + + # Call the exporter to save the current state + nx, ny, nz = grid_shape_finest + filename = f"multires_flow_past_sphere_3d_{step:04d}" + wp.synchronize() + exporter.to_hdf5(filename, {"velocity": sim.u, "density": sim.rho}, compression="gzip", compression_opts=2) + exporter.to_slice_image( + filename, + {"velocity": sim.u}, + plane_point=(nx // 2, ny // 2, nz // 2), + plane_normal=(0, 0, 1), + grid_res=256, + slice_thickness=2 ** (num_levels - 1), + bounds=(0.1, 0.6, 0.3, 0.7), + ) + + # Print lift and drag coefficients + print_lift_drag(sim) + wp.synchronize() + end_time = time.time() + elapsed = end_time - start_time + print(f"\tCompleted step {step}. Time elapsed for {post_process_interval} steps: {elapsed:.6f} seconds.") + start_time = time.time() diff --git a/examples/cfd/multires_windtunnel_3d.py b/examples/cfd/multires_windtunnel_3d.py new file mode 100644 index 00000000..d94d4357 --- /dev/null +++ b/examples/cfd/multires_windtunnel_3d.py @@ -0,0 +1,575 @@ +""" +Ahmed body aerodynamics with multi-resolution LBM. + +Simulates turbulent flow around the Ahmed body (25-degree slant angle) +using the XLB multi-resolution Neon backend. Computes drag and lift +coefficients via momentum transfer and exports HDF5/XDMF data for +post-processing. 
+""" + +import neon +import warp as wp +import numpy as np +import time +import os +import matplotlib.pyplot as plt +import trimesh +import shutil + +import xlb +from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import PrecisionPolicy +from xlb.grid import multires_grid_factory +from xlb.operator.boundary_condition import ( + DoNothingBC, + HybridBC, + RegularizedBC, +) +from xlb.operator.boundary_masker import MeshVoxelizationMethod +from xlb.utils.mesher import prepare_sparsity_pattern, make_cuboid_mesh, MultiresIO +from xlb.utils import UnitConvertor +from xlb.operator.force import MultiresMomentumTransfer +from xlb.helper.initializers import CustomMultiresInitializer + +wp.clear_kernel_cache() +wp.config.quiet = True + +# User Configuration +# ================= +# Physical and simulation parameters +wind_speed_lbm = 0.05 # Lattice velocity +wind_speed_mps = 38.0 # Physical inlet velocity in m/s (user input) +flow_passes = 2 # Domain flow passes +kinematic_viscosity = 1.508e-5 # Kinematic viscosity of air in m^2/s 1.508e-5 +voxel_size = 0.005 # Finest voxel size in meters + +# STL filename +stl_filename = "../stl-files/Ahmed_25_NoLegs.stl" +script_name = "Ahmed" + +# I/O settings +print_interval_percentage = 1 # Print every 1% of iterations +file_output_crossover_percentage = 10 # Crossover at 10% of iterations +num_file_outputs_pre_crossover = 20 # Outputs before crossover +num_file_outputs_post_crossover = 5 # Outputs after crossover + +# Other setup parameters +compute_backend = ComputeBackend.NEON +precision_policy = PrecisionPolicy.FP32FP32 +velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=compute_backend) + + +def generate_cuboid_mesh(stl_filename, voxel_size): + """ + Alternative cuboid mesh generation based on Apolo's method with domain multipliers per level.
+ """ + # Domain multipliers for each refinement level + domain_multiplier = [ + [3.0, 4.0, 2.5, 2.5, 0.0, 4.0], # -x, x, -y, y, -z, z + [1.2, 1.25, 1.75, 1.75, 0.0, 1.5], + [0.8, 1.0, 1.25, 1.25, 0.0, 1.2], + [0.5, 0.65, 0.6, 0.60, 0.0, 0.6], + [0.25, 0.25, 0.25, 0.25, 0.0, 0.25], + ] + + # Load the mesh + mesh = trimesh.load_mesh(stl_filename, process=False) + if mesh.is_empty: + raise ValueError("Loaded mesh is empty or invalid.") + + # Compute original bounds + min_bound = mesh.vertices.min(axis=0) + max_bound = mesh.vertices.max(axis=0) + partSize = max_bound - min_bound + x0 = max_bound[0] # End of car for Ahmed + + # Compute translation to put mesh into first octant of the domain + stl_shift = np.array( + [ + domain_multiplier[0][0] * partSize[0] - min_bound[0], + domain_multiplier[0][2] * partSize[1] - min_bound[1], + domain_multiplier[0][4] * partSize[2] - min_bound[2], + ], + dtype=float, + ) + + # Apply translation and save out temp STL + mesh.apply_translation(stl_shift) + _ = mesh.vertex_normals + mesh_vertices = np.asarray(mesh.vertices) + mesh.export("temp.stl") + + # Generate mesh using make_cuboid_mesh + level_data = make_cuboid_mesh( + voxel_size, + domain_multiplier, + "temp.stl", + ) + + num_levels = len(level_data) + grid_shape_finest = tuple([int(i * 2 ** (num_levels - 1)) for i in level_data[-1][0].shape]) + print(f"Full shape based on finest voxel size is {grid_shape_finest}") + os.remove("temp.stl") + + return ( + level_data, + mesh_vertices, + tuple([int(a) for a in grid_shape_finest]), + stl_shift, + x0, + ) + + +# Boundary Conditions Setup +# ========================= +def setup_boundary_conditions(grid, level_data, body_vertices, wind_speed_mps): + """ + Set up boundary conditions for the simulation. 
+ """ + # Convert wind speed to lattice units + wind_speed_lbm = unit_convertor.velocity_to_lbm(wind_speed_mps) + + left_indices = grid.boundary_indices_across_levels(level_data, box_side="left", remove_edges=True) + right_indices = grid.boundary_indices_across_levels(level_data, box_side="right", remove_edges=True) + top_indices = grid.boundary_indices_across_levels(level_data, box_side="top", remove_edges=False) + bottom_indices = grid.boundary_indices_across_levels(level_data, box_side="bottom", remove_edges=False) + front_indices = grid.boundary_indices_across_levels(level_data, box_side="front", remove_edges=False) + back_indices = grid.boundary_indices_across_levels(level_data, box_side="back", remove_edges=False) + + # Initialize boundary conditions + bc_inlet = RegularizedBC("velocity", prescribed_value=(wind_speed_lbm, 0.0, 0.0), indices=left_indices) + bc_outlet = DoNothingBC(indices=right_indices) + bc_top = HybridBC(bc_method="nonequilibrium_regularized", indices=top_indices) + bc_bottom = HybridBC(bc_method="nonequilibrium_regularized", indices=bottom_indices) + bc_front = HybridBC(bc_method="nonequilibrium_regularized", indices=front_indices) + bc_back = HybridBC(bc_method="nonequilibrium_regularized", indices=back_indices) + bc_body = HybridBC( + bc_method="nonequilibrium_regularized", + mesh_vertices=unit_convertor.length_to_lbm(body_vertices), + voxelization_method=MeshVoxelizationMethod("AABB_CLOSE", close_voxels=4), + use_mesh_distance=True, + ) + + return [bc_top, bc_bottom, bc_front, bc_back, bc_inlet, bc_outlet, bc_body] + + +# Simulation Initialization +# ========================= +def initialize_simulation( + grid, boundary_conditions, omega_finest, initializer, collision_type="KBC", mres_perf_opt=xlb.MresPerfOptimizationType.FUSION_AT_FINEST +): + """ + Initialize the multiresolution simulation manager. 
+ """ + sim = xlb.helper.MultiresSimulationManager( + omega_finest=omega_finest, + grid=grid, + boundary_conditions=boundary_conditions, + collision_type=collision_type, + initializer=initializer, + mres_perf_opt=mres_perf_opt, + ) + return sim + + +# Utility Functions +# ================= +def compute_force_coefficients(sim, step, momentum_transfer, wind_speed_lbm, reference_area): + """ + Calculate and print lift and drag coefficients. + """ + boundary_force = momentum_transfer(sim.f_0, sim.f_1, sim.bc_mask, sim.missing_mask) + drag = boundary_force[0] + lift = boundary_force[2] + cd = 2.0 * drag / (wind_speed_lbm**2 * reference_area) + cl = 2.0 * lift / (wind_speed_lbm**2 * reference_area) + if np.isnan(cd) or np.isnan(cl): + print(f"NaN detected in coefficients at step {step}") + raise ValueError(f"NaN detected in coefficients at step {step}: Cd={cd}, Cl={cl}") + drag_values.append([cd, cl]) + return cd, cl, drag + + +def plot_force_coefficients(drag_values, output_dir, print_interval, script_name, percentile_range=(15, 85), use_log_scale=False): + """ + Plot CD and CL over time and save the plot to the output directory. 
+ """ + drag_values_array = np.array(drag_values) + steps = np.arange(0, len(drag_values) * print_interval, print_interval) + cd_values = drag_values_array[:, 0] + cl_values = drag_values_array[:, 1] + y_min = min(np.percentile(cd_values, percentile_range[0]), np.percentile(cl_values, percentile_range[0])) + y_max = max(np.percentile(cd_values, percentile_range[1]), np.percentile(cl_values, percentile_range[1])) + padding = (y_max - y_min) * 0.1 + y_min, y_max = y_min - padding, y_max + padding + if use_log_scale: + y_min = max(y_min, 1e-6) + plt.figure(figsize=(10, 6)) + plt.plot(steps, cd_values, label="Drag Coefficient (Cd)", color="blue") + plt.plot(steps, cl_values, label="Lift Coefficient (Cl)", color="red") + plt.xlabel("Simulation Step") + plt.ylabel("Coefficient") + plt.title(f"{script_name}: Drag and Lift Coefficients Over Time") + plt.legend() + plt.grid(True) + plt.ylim(y_min, y_max) + if use_log_scale: + plt.yscale("log") + plt.savefig(os.path.join(output_dir, "drag_lift_plot.png")) + plt.close() + + +def compute_voxel_statistics(sim, bc_mask_exporter, sparsity_pattern, boundary_conditions, unit_convertor): + """ + Compute active/solid voxels, totals, lattice updates, and reference area based on simulation data. 
+ """ + fields_data = bc_mask_exporter.get_fields_data({"bc_mask": sim.bc_mask}) + bc_mask_data = fields_data["bc_mask_0"] + level_id_field = bc_mask_exporter.level_id_field + + # Compute solid voxels per level (assuming 255 is the solid marker) + solid_voxels = [] + for lvl in range(num_levels): + level_mask = level_id_field == lvl + solid_voxels.append(np.sum(bc_mask_data[level_mask] == 255)) + + # Compute active voxels (total non-zero in sparsity minus solids) + active_voxels = [np.count_nonzero(mask) for mask in sparsity_pattern] + active_voxels = [max(0, active_voxels[lvl] - solid_voxels[lvl]) for lvl in range(num_levels)] + + # Totals + total_voxels = sum(active_voxels) + total_lattice_updates_per_step = sum(active_voxels[lvl] * (2 ** (num_levels - 1 - lvl)) for lvl in range(num_levels)) + + # Compute reference area (projected on YZ plane at finest level) + finest_level = 0 + mask_finest = level_id_field == finest_level + bc_mask_finest = bc_mask_data[mask_finest] + active_indices_finest = np.argwhere(sparsity_pattern[0]) + bc_body_id = boundary_conditions[-1].id # Assuming last BC is bc_body + solid_voxels_indices = active_indices_finest[bc_mask_finest == bc_body_id] + unique_jk = np.unique(solid_voxels_indices[:, 1:3], axis=0) + reference_area = unique_jk.shape[0] + reference_area_physical = reference_area * unit_convertor.reference_length**2 + + return { + "active_voxels": active_voxels, + "solid_voxels": solid_voxels, + "total_voxels": total_voxels, + "total_lattice_updates_per_step": total_lattice_updates_per_step, + "reference_area": reference_area, + "reference_area_physical": reference_area_physical, + } + + +def plot_data(x0, output_dir, delta_x_coarse, sim, IOexporter, prefix="Ahmed"): + """ + Ahmed Car Model, slant - angle = 25 degree + Profiles on symmetry plane (y=0) covering entire field + Origin of coordinate system: + x=0: end of the car, y=0: symmetry plane, z=0: ground plane + + S.Becker/H. 
Lienhart/C.Stoots + Institute of Fluid Mechanics + University Erlangen-Nuremberg + Erlangen, Germany + Coordinates are in meters and need to be converted to voxels + Velocity data in m/s + """ + + def _load_sim_line(csv_path): + """ + Read a CSV exported by IOexporter.to_line without pandas. + Returns (z, Ux). + """ + # Read with header as column names + data = np.genfromtxt( + csv_path, + delimiter=",", + names=True, + autostrip=True, + dtype=None, + encoding="utf-8", + ) + if data.size == 0: + raise ValueError(f"No data in {csv_path}") + + z = np.asarray(data["z"], dtype=float) + ux = np.asarray(data["value"], dtype=float) + return z, ux + + # Load reference data + import json + + ref_data_path = "examples/cfd/data/ahmed.json" + with open(ref_data_path, "r") as file: + data = json.load(file) + + for x_str in data["data"].keys(): + # Extract reference horizontal velocity in m/s and its corresponding height in m + refX = np.array(data["data"][x_str]["x-velocity"]) + refY = np.array(data["data"][x_str]["height"]) + + # From reference x0 (rear of body) find x1 for plot + x_pos = float(x_str) + x1 = x0 + x_pos + + print(f" x1 is {x1}") + sim.macro(sim.f_0, sim.bc_mask, sim.rho, sim.u, streamId=0) + filename = os.path.join(output_dir, f"{prefix}_{x_str}") + wp.synchronize() + IOexporter.to_line( + filename, + {"velocity": sim.u}, + start_point=(x1, 0, 0), + end_point=(x1, 0, 0.8), + resolution=250, + component=0, + radius=delta_x_coarse, # needed with model units + ) + # read the CSV written by the exporter + csv_path = filename + "_velocity_0.csv" + print(f"CSV path is {csv_path}") + + try: + sim_z, sim_ux = _load_sim_line(csv_path) + except Exception as e: + print(f"Failed to read {csv_path}: {e}") + continue + + # plot reference vs simulation + plt.figure(figsize=(4.5, 6)) + plt.plot(refX, refY, "o", mfc="none", label="Experimental)") + plt.plot(sim_ux, sim_z, "-", lw=2, label="Simulation") + plt.xlim(np.min(refX) * 0.9, np.max(refX) * 1.1) + plt.ylim(np.min(refY), np.max(refY))
+ plt.xlabel("Ux [m/s]") + plt.ylabel("z [m]") + plt.title(f"Velocity Plot at {x_pos:+.3f}") + plt.grid(True, alpha=0.3) + plt.legend() + plt.tight_layout() + plt.savefig(filename + ".png", dpi=150) + plt.close() + + +# Main Script +# =========== +# Initialize XLB + +xlb.init( + velocity_set=velocity_set, + default_backend=compute_backend, + default_precision_policy=precision_policy, +) + +# Generate mesh +level_data, body_vertices, grid_shape_zip, stl_shift, x0 = generate_cuboid_mesh(stl_filename, voxel_size) + +# Prepare the sparsity pattern and origins from the level data +sparsity_pattern, level_origins = prepare_sparsity_pattern(level_data) + +# Define a unit convertor +unit_convertor = UnitConvertor( + velocity_lbm_unit=wind_speed_lbm, + velocity_physical_unit=wind_speed_mps, + voxel_size_physical_unit=voxel_size, +) + +# Calculate lattice parameters +num_levels = len(level_data) +delta_x_coarse = voxel_size * 2 ** (num_levels - 1) +nu_lattice = unit_convertor.viscosity_to_lbm(kinematic_viscosity) +omega_finest = 1.0 / (3.0 * nu_lattice + 0.5) + +# Create output directory +current_dir = os.path.join(os.path.dirname(__file__)) +output_dir = os.path.join(current_dir, script_name) +if os.path.exists(output_dir): + shutil.rmtree(output_dir) +os.makedirs(output_dir) + +# Define exporter objects +field_name_cardinality_dict = {"velocity": 3, "density": 1} +h5exporter = MultiresIO( + field_name_cardinality_dict, + level_data, + offset=-stl_shift, + unit_convertor=unit_convertor, +) +bc_mask_exporter = MultiresIO( + {"bc_mask": 1}, + level_data, + offset=-stl_shift, + unit_convertor=unit_convertor, +) + +# Create grid +grid = multires_grid_factory( + grid_shape_zip, + velocity_set=velocity_set, + sparsity_pattern_list=sparsity_pattern, + sparsity_pattern_origins=[neon.Index_3d(*box_origin) for box_origin in level_origins], +) + +# Calculate num_steps +coarsest_level = grid.count_levels - 1 +grid_shape_x_coarsest = grid.level_to_shape(coarsest_level)[0] +num_steps = 
int(flow_passes * (grid_shape_x_coarsest / wind_speed_lbm)) + +# Calculate print and file output intervals +print_interval = max(1, int(num_steps * (print_interval_percentage / 100.0))) +crossover_step = int(num_steps * (file_output_crossover_percentage / 100.0)) +file_output_interval_pre_crossover = ( + max(1, int(crossover_step / num_file_outputs_pre_crossover)) if num_file_outputs_pre_crossover > 0 else num_steps + 1 +) +file_output_interval_post_crossover = ( + max(1, int((num_steps - crossover_step) / num_file_outputs_post_crossover)) if num_file_outputs_post_crossover > 0 else num_steps + 1 +) + +# Setup boundary conditions +boundary_conditions = setup_boundary_conditions(grid, level_data, body_vertices, wind_speed_mps) + +# Create initializer +wind_speed_lbm = unit_convertor.velocity_to_lbm(wind_speed_mps) +initializer = CustomMultiresInitializer( + bc_id=boundary_conditions[-2].id, # bc_outlet + constant_velocity_vector=(wind_speed_lbm, 0.0, 0.0), + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, +) + +# Initialize simulation +sim = initialize_simulation(grid, boundary_conditions, omega_finest, initializer) + +# Compute voxel statistics and reference area +stats = compute_voxel_statistics(sim, bc_mask_exporter, sparsity_pattern, boundary_conditions, unit_convertor) +active_voxels = stats["active_voxels"] +solid_voxels = stats["solid_voxels"] +total_voxels = stats["total_voxels"] +total_lattice_updates_per_step = stats["total_lattice_updates_per_step"] +reference_area = stats["reference_area"] +reference_area_physical = stats["reference_area_physical"] + +# Save initial bc_mask +filename = os.path.join(output_dir, f"{script_name}_initial_bc_mask") +try: + bc_mask_exporter.to_hdf5(filename, {"bc_mask": sim.bc_mask}, compression="gzip", compression_opts=0) + xmf_filename = f"{filename}.xmf" + hdf5_basename = f"{script_name}_initial_bc_mask.h5" +except Exception as e: + print(f"Error during initial bc_mask 
output: {e}") +wp.synchronize() + + +# Setup momentum transfer +momentum_transfer = MultiresMomentumTransfer( + boundary_conditions[-1], + mres_perf_opt=xlb.MresPerfOptimizationType.FUSION_AT_FINEST, + compute_backend=compute_backend, +) + +# Print simulation info +print("\n" + "=" * 50 + "\n") +print(f"Number of flow passes: {flow_passes}") +print(f"Calculated iterations: {num_steps:,}") +print(f"Finest voxel size: {voxel_size} meters") +print(f"Coarsest voxel size: {delta_x_coarse} meters") +print(f"Total voxels: {sum(np.count_nonzero(mask) for mask in sparsity_pattern):,}") +print(f"Total active voxels: {total_voxels:,}") +print(f"Active voxels per level: {[int(v) for v in active_voxels]}") +print(f"Solid voxels per level: {[int(v) for v in solid_voxels]}") +print(f"Total lattice updates per global step: {total_lattice_updates_per_step:,}") +print(f"Number of refinement levels: {num_levels}") +print(f"Physical inlet velocity: {wind_speed_mps:.4f} m/s") +print(f"Lattice velocity (ulb): {wind_speed_lbm}") +print(f"Computed reference area (bc_mask): {reference_area} lattice units") +print(f"Physical reference area (bc_mask): {reference_area_physical:.6f} m^2") +print("\n" + "=" * 50 + "\n") + +# -------------------------- Simulation Loop -------------------------- +wp.synchronize() +start_time = time.time() +compute_time = 0.0 +steps_since_last_print = 0 +drag_values = [] + +for step in range(num_steps): + step_start = time.time() + sim.step() + wp.synchronize() + compute_time += time.time() - step_start + steps_since_last_print += 1 + if step % print_interval == 0 or step == num_steps - 1: + sim.macro(sim.f_0, sim.bc_mask, sim.rho, sim.u, streamId=0) + wp.synchronize() + cd, cl, drag = compute_force_coefficients(sim, step, momentum_transfer, wind_speed_lbm, reference_area) + filename = os.path.join(output_dir, f"{script_name}_{step:04d}") + h5exporter.to_hdf5(filename, {"velocity": sim.u, "density": sim.rho}, compression="gzip", compression_opts=0) + 
h5exporter.to_slice_image( + filename, + {"velocity": sim.u}, + plane_point=(1, 0, 0), + plane_normal=(0, 1, 0), + grid_res=2000, + bounds=(0.25, 0.75, 0, 0.5), + show_axes=False, + show_colorbar=False, + slice_thickness=delta_x_coarse, # needed when using model units + ) + end_time = time.time() + elapsed = end_time - start_time + total_lattice_updates = total_lattice_updates_per_step * steps_since_last_print + MLUPS = total_lattice_updates / compute_time / 1e6 if compute_time > 0 else 0.0 + current_flow_passes = step * wind_speed_lbm / grid_shape_x_coarsest + remaining_steps = num_steps - step - 1 + time_remaining = 0.0 if MLUPS == 0 else (total_lattice_updates_per_step * remaining_steps) / (MLUPS * 1e6) + hours, rem = divmod(time_remaining, 3600) + minutes, seconds = divmod(rem, 60) + time_remaining_str = f"{int(hours):02d}h {int(minutes):02d}m {int(seconds):02d}s" + percent_complete = (step + 1) / num_steps * 100 + print(f"Completed step {step}/{num_steps} ({percent_complete:.2f}% complete)") + print(f" Flow Passes: {current_flow_passes:.2f}") + print(f" Time elapsed: {elapsed:.1f}s, Compute time: {compute_time:.1f}s, ETA: {time_remaining_str}") + print(f" MLUPS: {MLUPS:.1f}") + print(f" Cd={cd:.3f}, Cl={cl:.3f}, Drag Force (lattice units)={drag:.3f}") + start_time = time.time() + compute_time = 0.0 + steps_since_last_print = 0 + file_output_interval = file_output_interval_pre_crossover if step < crossover_step else file_output_interval_post_crossover + if step % file_output_interval == 0 or step == num_steps - 1: + sim.macro(sim.f_0, sim.bc_mask, sim.rho, sim.u, streamId=0) + filename = os.path.join(output_dir, f"{script_name}_{step:04d}") + try: + h5exporter.to_hdf5(filename, {"velocity": sim.u, "density": sim.rho}, compression="gzip", compression_opts=0) + xmf_filename = f"{filename}.xmf" + hdf5_basename = f"{script_name}_{step:04d}.h5" + except Exception as e: + print(f"Error during file output at step {step}: {e}") + wp.synchronize() + if step == num_steps 
- 1: + plot_data(x0, output_dir, delta_x_coarse, sim, h5exporter, prefix="Ahmed") + +# Save drag and lift data to CSV +if len(drag_values) > 0: + with open(os.path.join(output_dir, "drag_lift.csv"), "w") as fd: + fd.write("Step,Cd,Cl\n") + for i, (cd, cl) in enumerate(drag_values): + fd.write(f"{i * print_interval},{cd},{cl}\n") + plot_force_coefficients(drag_values, output_dir, print_interval, script_name) + +# Calculate and print average Cd and Cl for the last 50% +drag_values_array = np.array(drag_values) +if len(drag_values) > 0: + start_index = len(drag_values) // 2 + last_half = drag_values_array[start_index:, :] + avg_cd = np.mean(last_half[:, 0]) + avg_cl = np.mean(last_half[:, 1]) + print(f"Average Drag Coefficient (Cd) for last 50%: {avg_cd:.6f}") + print(f"Average Lift Coefficient (Cl) for last 50%: {avg_cl:.6f}") + print(f"Experimental Drag Coefficient (Cd): {0.3088}") + print(f"Error Drag Coefficient (Cd): {((avg_cd - 0.3088) / 0.3088) * 100:.2f}%") + +else: + print("No drag or lift data collected.") diff --git a/examples/cfd/rotating_sphere_3d.py b/examples/cfd/rotating_sphere_3d.py new file mode 100644 index 00000000..7a0b10a4 --- /dev/null +++ b/examples/cfd/rotating_sphere_3d.py @@ -0,0 +1,327 @@ +""" +Rotating sphere 3-D example (single-resolution). + +Simulates flow past a sphere rotating about the y-axis using the +halfway bounce-back BC with a prescribed rotational-velocity profile. +Computes drag and lift coefficients over time and saves VTK snapshots. 
+""" + +import xlb +import trimesh +import time +import warp as wp +import numpy as np +import jax.numpy as jnp +from typing import Any + +from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import PrecisionPolicy +from xlb.grid import grid_factory +from xlb.operator.stepper import IncompressibleNavierStokesStepper +from xlb.operator.boundary_condition import ( + HalfwayBounceBackBC, + FullwayBounceBackBC, + RegularizedBC, + DoNothingBC, + HybridBC, +) +from xlb.operator.force.momentum_transfer import MomentumTransfer +from xlb.operator.macroscopic import Macroscopic +from xlb.utils import save_fields_vtk, save_image +import matplotlib.pyplot as plt +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator import Operator +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.operator.boundary_masker import MeshVoxelizationMethod + +# -------------------------- Simulation Setup -------------------------- + +# Grid parameters +wp.clear_kernel_cache() +diam = 32 +grid_size_x, grid_size_y, grid_size_z = 10 * diam, 7 * diam, 7 * diam +grid_shape = (grid_size_x, grid_size_y, grid_size_z) + +# Simulation Configuration +compute_backend = ComputeBackend.WARP +precision_policy = PrecisionPolicy.FP32FP32 + +velocity_set = xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=compute_backend) +wind_speed = 0.04 +num_steps = 100000 +print_interval = 1000 +post_process_interval = 1000 + +# Physical Parameters +Re = 200.0 +visc = wind_speed * diam / Re +omega = 1.0 / (3.0 * visc + 0.5) + +# Rotational speed parameters (see [1] which discusses the problem in terms of 2 non-dimensional parameters: Re and Omega) +# [1] J. Fluid Mech. (2016), vol. 807, pp. 62–86. © Cambridge University Press 2016, doi:10.1017/jfm.2016.596 +# \Omega = \omega * D / (2 U_\infty) where Omega is non-dimensional and omega is dimensional.
+rot_rate_nondim = -0.2 +rot_rate = 2.0 * wind_speed * rot_rate_nondim / diam + +# Print simulation info +print("\n" + "=" * 50 + "\n") +print("Simulation Configuration:") +print(f"Grid size: {grid_size_x} x {grid_size_y} x {grid_size_z}") +print(f"Backend: {compute_backend}") +print(f"Velocity set: {velocity_set}") +print(f"Precision policy: {precision_policy}") +print(f"Prescribed velocity: {wind_speed}") +print(f"Reynolds number: {Re}") +print(f"Max iterations: {num_steps}") +print("\n" + "=" * 50 + "\n") + +# Initialize XLB +xlb.init( + velocity_set=velocity_set, + default_backend=compute_backend, + default_precision_policy=precision_policy, +) + +# Create Grid +grid = grid_factory(grid_shape, compute_backend=compute_backend) + +# Bounding box indices +box = grid.bounding_box_indices() +box_no_edge = grid.bounding_box_indices(remove_edges=True) +inlet = box_no_edge["left"] +outlet = box["right"] +walls = [box["bottom"][i] + box["top"][i] + box["front"][i] + box["back"][i] for i in range(velocity_set.d)] +walls = np.unique(np.array(walls), axis=-1).tolist() + +# Load the mesh (replace with your own mesh) +stl_filename = "../stl-files/sphere.stl" +mesh = trimesh.load_mesh(stl_filename, process=False) +mesh_vertices = mesh.vertices + +# Transform the mesh points to be located in the right position in the wind tunnel +mesh_vertices -= mesh_vertices.min(axis=0) +mesh_extents = mesh_vertices.max(axis=0) +length_phys_unit = mesh_extents.max() +length_lbm_unit = grid_shape[1] / 7 +dx = length_phys_unit / length_lbm_unit +mesh_vertices = mesh_vertices / dx +shift = np.array([grid_shape[0] / 3, (grid_shape[1] - mesh_extents[1] / dx) / 2, (grid_shape[2] - mesh_extents[2] / dx) / 2]) +sphere = mesh_vertices + shift +diam = np.max(sphere.max(axis=0) - sphere.min(axis=0)) +sphere_cross_section = np.pi * diam**2 / 4.0 + + +# Define rotating boundary profile +def bc_profile(): + """Build a Warp function returning the rotational wall velocity at a voxel.""" + dtype = 
precision_policy.compute_precision.wp_dtype + _u_vec = wp.vec(velocity_set.d, dtype=dtype) + angular_velocity = _u_vec(0.0, rot_rate, 0.0) + origin_np = shift + diam / 2 + origin_wp = _u_vec(origin_np[0], origin_np[1], origin_np[2]) + + @wp.func + def bc_profile_warp(index: wp.vec3i): + x = dtype(index[0]) + y = dtype(index[1]) + z = dtype(index[2]) + surface_coord = _u_vec(x, y, z) - origin_wp + return wp.cross(angular_velocity, surface_coord) + + return bc_profile_warp + + +# Define boundary conditions +bc_left = RegularizedBC("velocity", prescribed_value=(wind_speed, 0.0, 0.0), indices=inlet) +bc_do_nothing = DoNothingBC(indices=outlet) +# bc_sphere = HalfwayBounceBackBC(mesh_vertices=sphere, voxelization_method="ray", profile=bc_profile()) +bc_sphere = HybridBC( + bc_method="nonequilibrium_regularized", + mesh_vertices=sphere, + use_mesh_distance=True, + voxelization_method=MeshVoxelizationMethod("RAY"), + profile=bc_profile(), +) +# Not assigning a BC to the walls makes them periodic. +boundary_conditions = [bc_left, bc_do_nothing, bc_sphere] + + +# Setup Stepper +stepper = IncompressibleNavierStokesStepper( + grid=grid, + boundary_conditions=boundary_conditions, + collision_type="KBC", +) + +# Make initializer operator +from xlb.helper.initializers import CustomInitializer + +initializer = CustomInitializer( + bc_id=bc_do_nothing.id, + constant_velocity_vector=(wind_speed, 0.0, 0.0), + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, +) + +# Prepare Fields +f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields(initializer=initializer) + + +# -------------------------- Helper Functions -------------------------- + + +def plot_coefficient(time_steps, coefficients, prefix="drag"): + """ + Plot the drag coefficient with various moving averages. + + Args: + time_steps (list): List of time steps. + coefficients (list): List of force coefficients.
+ """ + # Convert lists to numpy arrays for processing + time_steps_np = np.array(time_steps) + coefficients_np = np.array(coefficients) + + # Define moving average windows + windows = [10, 100, 1000, 10000, 100000] + labels = ["MA 10", "MA 100", "MA 1,000", "MA 10,000", "MA 100,000"] + + plt.figure(figsize=(12, 8)) + plt.plot(time_steps_np, coefficients_np, label="Raw", alpha=0.5) + + for window, label in zip(windows, labels): + if len(coefficients_np) >= window: + ma = np.convolve(coefficients_np, np.ones(window) / window, mode="valid") + plt.plot(time_steps_np[window - 1 :], ma, label=label) + + plt.ylim(-1.0, 1.0) + plt.legend() + plt.xlabel("Time step") + plt.ylabel("Drag coefficient") + plt.title("Drag Coefficient Over Time with Moving Averages") + plt.savefig(prefix + "_ma.png") + plt.close() + + +def post_process( + step, + f_0, + f_1, + grid_shape, + macro, + momentum_transfer, + missing_mask, + bc_mask, + wind_speed, + car_cross_section, + drag_coefficients, + lift_coefficients, + time_steps, +): + """Compute macroscopic fields, force coefficients, and save VTK output.""" + """ + Post-process simulation data: save fields, compute forces, and plot drag coefficient. + + Args: + step (int): Current time step. + f_current: Current distribution function. + grid_shape (tuple): Shape of the grid. + macro: Macroscopic operator object. + momentum_transfer: MomentumTransfer operator object. + missing_mask: Missing mask from stepper. + bc_mask: Boundary condition mask from stepper. + wind_speed (float): Prescribed wind speed. + car_cross_section (float): Cross-sectional area of the car. + drag_coefficients (list): List to store drag coefficients. + lift_coefficients (list): List to store lift coefficients. + time_steps (list): List to store time steps. 
+ """ + wp.synchronize() + # Convert to JAX array if necessary + if not isinstance(f_0, jnp.ndarray): + f_0_jax = wp.to_jax(f_0) + else: + f_0_jax = f_0 + + # Compute macroscopic quantities + rho, u = macro(f_0_jax) + + # Remove boundary cells + u = u[:, 1:-1, 1:-1, 1:-1] + u_magnitude = jnp.sqrt(u[0] ** 2 + u[1] ** 2 + u[2] ** 2) + + fields = {"ux": u[0], "uy": u[1], "uz": u[2], "u_magnitude": u_magnitude} + + # Save fields in VTK format + # save_fields_vtk(fields, timestep=step) + + # Save the u_magnitude slice at the mid y-plane + mid_y = grid_shape[1] // 2 + save_image(fields["u_magnitude"][:, mid_y, :], timestep=step) + + # Compute lift and drag + boundary_force = momentum_transfer(f_0, f_1, bc_mask, missing_mask) + drag = boundary_force[0] # x-direction + lift = boundary_force[2] + cd = 2.0 * drag / (wind_speed**2 * car_cross_section) + cl = 2.0 * lift / (wind_speed**2 * car_cross_section) + print(f"CD={cd}, CL={cl}") + drag_coefficients.append(cd) + lift_coefficients.append(cl) + time_steps.append(step) + + # Plot drag coefficient + plot_coefficient(time_steps, drag_coefficients, prefix="drag") + plot_coefficient(time_steps, lift_coefficients, prefix="lift") + + +# Setup Momentum Transfer for Force Calculation +bc_car = boundary_conditions[-1] +momentum_transfer = MomentumTransfer(bc_car, compute_backend=compute_backend) + +# Define Macroscopic Calculation +macro = Macroscopic( + compute_backend=ComputeBackend.JAX, + precision_policy=precision_policy, + velocity_set=xlb.velocity_set.D3Q27(precision_policy=precision_policy, compute_backend=ComputeBackend.JAX), +) + +# Initialize Lists to Store Coefficients and Time Steps +time_steps = [] +drag_coefficients = [] +lift_coefficients = [] + +# -------------------------- Simulation Loop -------------------------- + +start_time = time.time() +for step in range(num_steps): + # Perform simulation step + f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, step) + f_0, f_1 = f_1, f_0 # Swap the buffers + + # 
Print progress at intervals + if step % print_interval == 0: + elapsed_time = time.time() - start_time + print(f"Iteration: {step}/{num_steps} | Time elapsed: {elapsed_time:.2f}s") + start_time = time.time() + + # Post-process at intervals and final step + if (step % post_process_interval == 0) or (step == num_steps - 1): + post_process( + step, + f_0, + f_1, + grid_shape, + macro, + momentum_transfer, + missing_mask, + bc_mask, + wind_speed, + sphere_cross_section, + drag_coefficients, + lift_coefficients, + time_steps, + ) + +print("Simulation completed successfully.") diff --git a/examples/cfd/windtunnel_3d.py b/examples/cfd/windtunnel_3d.py index 074570b8..9909132c 100644 --- a/examples/cfd/windtunnel_3d.py +++ b/examples/cfd/windtunnel_3d.py @@ -10,6 +10,7 @@ FullwayBounceBackBC, RegularizedBC, ExtrapolationOutflowBC, + HybridBC, ) from xlb.operator.force.momentum_transfer import MomentumTransfer from xlb.operator.macroscopic import Macroscopic @@ -18,6 +19,7 @@ import numpy as np import jax.numpy as jnp import matplotlib.pyplot as plt +from xlb.operator.boundary_masker import MeshVoxelizationMethod # -------------------------- Simulation Setup -------------------------- @@ -74,6 +76,7 @@ # Load the mesh (replace with your own mesh) stl_filename = "../stl-files/DrivAer-Notchback.stl" +voxelization_method = MeshVoxelizationMethod("RAY") mesh = trimesh.load_mesh(stl_filename, process=False) mesh_vertices = mesh.vertices @@ -84,7 +87,15 @@ length_lbm_unit = grid_shape[0] / 4 dx = length_phys_unit / length_lbm_unit mesh_vertices = mesh_vertices / dx -shift = np.array([grid_shape[0] / 4, (grid_shape[1] - mesh_extents[1] / dx) / 2, 0.0]) + +# Depending on the voxelization method, shift_z ensures the bottom ground does not intersect with the voxelized mesh +# Any smaller shift value would lead to large lift computations due to the initial equilibrium distributions. Bigger +# values would be fine but leave a gap between surfaces that are supposed to touch. 
+if voxelization_method in (MeshVoxelizationMethod("RAY"), MeshVoxelizationMethod("WINDING")): + shift_z = 2 +elif voxelization_method in (MeshVoxelizationMethod("AABB"), MeshVoxelizationMethod("AABB_CLOSE", close_voxels=3)): + shift_z = 3 +shift = np.array([grid_shape[0] / 4, (grid_shape[1] - mesh_extents[1] / dx) / 2, shift_z]) car_vertices = mesh_vertices + shift car_cross_section = np.prod(mesh_extents[1:]) / dx**2 @@ -92,15 +103,22 @@ bc_left = RegularizedBC("velocity", prescribed_value=(wind_speed, 0.0, 0.0), indices=inlet) bc_walls = FullwayBounceBackBC(indices=walls) bc_do_nothing = ExtrapolationOutflowBC(indices=outlet) -bc_car = HalfwayBounceBackBC(mesh_vertices=car_vertices) +bc_car = HalfwayBounceBackBC(mesh_vertices=car_vertices, voxelization_method=voxelization_method) +# bc_car = HybridBC(bc_method="nonequilibrium_regularized", mesh_vertices=car_vertices, +# voxelization_method=voxelization_method, use_mesh_distance=True) boundary_conditions = [bc_walls, bc_left, bc_do_nothing, bc_car] +# Configure backend options: +# backend_config = {"occ": neon.SkeletonConfig.OCC.from_string("standard"), "device_list": [0, 1]} if compute_backend == ComputeBackend.NEON else {} +backend_config = {} + # Setup Stepper stepper = IncompressibleNavierStokesStepper( grid=grid, boundary_conditions=boundary_conditions, collision_type="KBC", + backend_config=backend_config, ) # Prepare Fields @@ -110,28 +128,28 @@ # -------------------------- Helper Functions -------------------------- -def plot_drag_coefficient(time_steps, drag_coefficients): +def plot_coefficient(time_steps, coefficients, prefix="drag"): """ Plot the drag coefficient with various moving averages. Args: time_steps (list): List of time steps. - drag_coefficients (list): List of drag coefficients. + coefficients (list): List of force coefficients. 
""" # Convert lists to numpy arrays for processing time_steps_np = np.array(time_steps) - drag_coefficients_np = np.array(drag_coefficients) + coefficients_np = np.array(coefficients) # Define moving average windows windows = [10, 100, 1000, 10000, 100000] labels = ["MA 10", "MA 100", "MA 1,000", "MA 10,000", "MA 100,000"] plt.figure(figsize=(12, 8)) - plt.plot(time_steps_np, drag_coefficients_np, label="Raw", alpha=0.5) + plt.plot(time_steps_np, coefficients_np, label="Raw", alpha=0.5) for window, label in zip(windows, labels): - if len(drag_coefficients_np) >= window: - ma = np.convolve(drag_coefficients_np, np.ones(window) / window, mode="valid") + if len(coefficients_np) >= window: + ma = np.convolve(coefficients_np, np.ones(window) / window, mode="valid") plt.plot(time_steps_np[window - 1 :], ma, label=label) plt.ylim(-1.0, 1.0) @@ -139,7 +157,7 @@ def plot_drag_coefficient(time_steps, drag_coefficients): plt.xlabel("Time step") plt.ylabel("Drag coefficient") plt.title("Drag Coefficient Over Time with Moving Averages") - plt.savefig("drag_coefficient_ma.png") + plt.savefig(prefix + "_ma.png") plt.close() @@ -177,7 +195,7 @@ def post_process( """ # Convert to JAX array if necessary if not isinstance(f_0, jnp.ndarray): - f_0_jax = wp.to_jax(f_0) + f_0_jax = to_jax(f_0) else: f_0_jax = f_0 @@ -203,12 +221,14 @@ def post_process( lift = boundary_force[2] cd = 2.0 * drag / (wind_speed**2 * car_cross_section) cl = 2.0 * lift / (wind_speed**2 * car_cross_section) + print(f"CD={cd}, CL={cl}") drag_coefficients.append(cd) lift_coefficients.append(cl) time_steps.append(step) # Plot drag coefficient - plot_drag_coefficient(time_steps, drag_coefficients) + plot_coefficient(time_steps, drag_coefficients, prefix="drag") + plot_coefficient(time_steps, lift_coefficients, prefix="lift") # Setup Momentum Transfer for Force Calculation @@ -221,6 +241,7 @@ def post_process( precision_policy=precision_policy, velocity_set=xlb.velocity_set.D3Q27(precision_policy=precision_policy, 
compute_backend=ComputeBackend.JAX), ) +to_jax = xlb.utils.ToJAX("populations", velocity_set.q, grid_shape) # Initialize Lists to Store Coefficients and Time Steps time_steps = [] @@ -237,7 +258,7 @@ def post_process( # Print progress at intervals if step % print_interval == 0: - if compute_backend == ComputeBackend.WARP: + if compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON]: wp.synchronize() elapsed_time = time.time() - start_time print(f"Iteration: {step}/{num_steps} | Time elapsed: {elapsed_time:.2f}s") diff --git a/examples/performance/mlups_3d.py b/examples/performance/mlups_3d.py index 409a8d59..66945e89 100644 --- a/examples/performance/mlups_3d.py +++ b/examples/performance/mlups_3d.py @@ -9,21 +9,109 @@ from xlb.operator.stepper import IncompressibleNavierStokesStepper from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC from xlb.distribute import distribute +from xlb.operator.macroscopic import Macroscopic + # -------------------------- Simulation Setup -------------------------- def parse_arguments(): - parser = argparse.ArgumentParser(description="MLUPS for 3D Lattice Boltzmann Method Simulation (BGK)") - parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid") - parser.add_argument("num_steps", type=int, help="Number of timesteps for the simulation") - parser.add_argument("compute_backend", type=str, help="Backend for the simulation (jax or warp)") - parser.add_argument("precision", type=str, help="Precision for the simulation (e.g., fp32/fp32)") - return parser.parse_args() + # Define valid options for consistency + COMPUTE_BACKENDS = ["neon", "warp", "jax"] + PRECISION_OPTIONS = ["fp32/fp32", "fp64/fp64", "fp64/fp32", "fp32/fp16"] + VELOCITY_SETS = ["D3Q19", "D3Q27"] + COLLISION_MODELS = ["BGK", "KBC"] + OCC_OPTIONS = ["standard", "none"] + + parser = argparse.ArgumentParser( + description="MLUPS Benchmark for 3D Lattice Boltzmann Method Simulation", + epilog=f""" +Examples: + 
%(prog)s 100 1000 neon fp32/fp32 + %(prog)s 200 500 neon fp64/fp64 --collision_model KBC --velocity_set D3Q27 + %(prog)s 150 2000 neon fp32/fp32 --gpu_devices=[0,1,2] --measure_scalability --report + %(prog)s 100 1000 neon fp32/fp32 --repetitions 5 --export_final_velocity + """, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + # Positional arguments + parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid (e.g., 100)") + parser.add_argument("num_steps", type=int, help="Number of timesteps for the simulation (e.g., 1000)") + parser.add_argument("compute_backend", type=str, choices=COMPUTE_BACKENDS, help=f"Backend for the simulation ({', '.join(COMPUTE_BACKENDS)})") + parser.add_argument("precision", type=str, choices=PRECISION_OPTIONS, help=f"Precision for the simulation ({', '.join(PRECISION_OPTIONS)})") + + # Optional arguments + parser.add_argument("--gpu_devices", type=str, default=None, help="CUDA devices to use for Neon backend (e.g., [0,1,2] or [0])") + parser.add_argument( + "--velocity_set", + type=str, + default="D3Q19", + choices=VELOCITY_SETS, + help=f"Lattice velocity set (default: D3Q19, choices: {', '.join(VELOCITY_SETS)})", + ) + parser.add_argument( + "--collision_model", + type=str, + default="BGK", + choices=COLLISION_MODELS, + help=f"Collision model (default: BGK, choices: {', '.join(COLLISION_MODELS)}, KBC requires D3Q27)", + ) + parser.add_argument( + "--occ", + type=str, + default="standard", + choices=OCC_OPTIONS, + help=f"Overlapping Communication and Computation strategy (default: standard, choices: {', '.join(OCC_OPTIONS)})", + ) + parser.add_argument("--report", action="store_true", help="Generate Neon performance report") + parser.add_argument("--export_final_velocity", action="store_true", help="Export final velocity field to VTI file") + parser.add_argument("--measure_scalability", action="store_true", help="Measure performance across different GPU counts") + parser.add_argument( + 
"--repetitions", type=int, default=1, metavar="N", help="Number of simulation repetitions for statistical analysis (default: 1)" + ) + args = parser.parse_args() -def setup_simulation(args): - compute_backend = ComputeBackend.JAX if args.compute_backend == "jax" else ComputeBackend.WARP + # Parse gpu_devices string to list + if args.gpu_devices is not None: + try: + import ast + + args.gpu_devices = ast.literal_eval(args.gpu_devices) + if not isinstance(args.gpu_devices, list): + args.gpu_devices = [args.gpu_devices] # Handle single integer case + except (ValueError, SyntaxError): + raise ValueError("Invalid gpu_devices format. Use format like [0,1,2] or [0]") + + # Validate and convert compute backend + compute_backend_map = { + "jax": ComputeBackend.JAX, + "warp": ComputeBackend.WARP, + "neon": ComputeBackend.NEON, + } + compute_backend = compute_backend_map.get(args.compute_backend) + if compute_backend is None: + raise ValueError(f"Invalid compute backend '{args.compute_backend}'. Use: {', '.join(COMPUTE_BACKENDS)}") + args.compute_backend = compute_backend + + # Handle GPU devices for Neon backend + if args.compute_backend == ComputeBackend.NEON: + if args.gpu_devices is None: + print("[INFO] No GPU devices specified. 
Using default device 0.") + args.gpu_devices = [0] + + import neon + + occ_enum = neon.SkeletonConfig.OCC.from_string(args.occ) + args.occ_enum = occ_enum # Store the enum for Neon + args.occ_display = args.occ # Store the original string for display + else: + if args.gpu_devices is not None: + raise ValueError(f"--gpu_devices can only be used with Neon backend, not {args.compute_backend.name}") + args.gpu_devices = [0] # Default for non-Neon backends + + # Checking precision policy precision_policy_map = { "fp32/fp32": PrecisionPolicy.FP32FP32, "fp64/fp64": PrecisionPolicy.FP64FP64, @@ -32,18 +120,77 @@ def setup_simulation(args): } precision_policy = precision_policy_map.get(args.precision) if precision_policy is None: - raise ValueError("Invalid precision specified.") + raise ValueError(f"Invalid precision '{args.precision}'. Use: {', '.join(PRECISION_OPTIONS)}") + args.precision_policy = precision_policy + + # Validate collision model and velocity set compatibility + if args.collision_model == "KBC" and args.velocity_set != "D3Q27": + raise ValueError("KBC collision model requires D3Q27 velocity set. 
Use --velocity_set D3Q27") + + if args.velocity_set == "D3Q19": + velocity_set = xlb.velocity_set.D3Q19(precision_policy=args.precision_policy, compute_backend=compute_backend) + elif args.velocity_set == "D3Q27": + velocity_set = xlb.velocity_set.D3Q27(precision_policy=args.precision_policy, compute_backend=compute_backend) + args.velocity_set = velocity_set + + print_args(args) + + return args + + +def print_args(args): + """Print simulation configuration in a clean, organized format""" + print("\n" + "=" * 70) + print(" SIMULATION CONFIGURATION") + print("=" * 70) + + # Grid and simulation parameters + print("GRID & SIMULATION:") + print(f" Grid Size: {args.cube_edge}³ ({args.cube_edge:,} × {args.cube_edge:,} × {args.cube_edge:,})") + print(f" Total Lattice Points: {args.cube_edge**3:,}") + print(f" Time Steps: {args.num_steps:,}") + print(f" Repetitions: {args.repetitions}") + # Computational settings + print("\nCOMPUTATIONAL SETTINGS:") + print(f" Compute Backend: {args.compute_backend.name}") + print(f" Precision Policy: {args.precision}") + print(f" Velocity Set: {args.velocity_set.__class__.__name__}") + print(f" Collision Model: {args.collision_model}") + + # Backend-specific settings + if args.compute_backend.name == "NEON": + print("\nNEON BACKEND SETTINGS:") + print(f" GPU Devices: {args.gpu_devices}") + print(f" OCC Strategy: {args.occ_display}") + + # Output options + print("\nOUTPUT OPTIONS:") + print(f" Generate Report: {'Yes' if args.report else 'No'}") + print(f" Measure Scalability: {'Yes' if args.measure_scalability else 'No'}") + print(f" Export Velocity: {'Yes' if args.export_final_velocity else 'No'}") + + print("=" * 70) + print("Starting simulation...\n") + + +def init_xlb(args): xlb.init( - velocity_set=xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=compute_backend), - default_backend=compute_backend, - default_precision_policy=precision_policy, + velocity_set=args.velocity_set, + 
default_backend=args.compute_backend, + default_precision_policy=args.precision_policy, ) - return compute_backend, precision_policy + options = None + if args.compute_backend == ComputeBackend.NEON: + neon_options = {"occ": args.occ_enum, "device_list": args.gpu_devices} + options = neon_options + return args.compute_backend, args.precision_policy, options -def run_simulation(compute_backend, precision_policy, grid_shape, num_steps): - grid = grid_factory(grid_shape) +def run_simulation( + compute_backend, precision_policy, grid_shape, num_steps, options, export_final_velocity, repetitions, num_devices, collision_model +): + grid = grid_factory(grid_shape, backend_config=options) box = grid.bounding_box_indices() box_no_edge = grid.bounding_box_indices(remove_edges=True) @@ -59,7 +206,8 @@ def run_simulation(compute_backend, precision_policy, grid_shape, num_steps): stepper = IncompressibleNavierStokesStepper( grid=grid, boundary_conditions=boundary_conditions, - collision_type="BGK", + collision_type=collision_model, + backend_config=options, ) # Distribute if using JAX @@ -74,14 +222,44 @@ def run_simulation(compute_backend, precision_policy, grid_shape, num_steps): omega = 1.0 f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields() - start_time = time.time() - for i in range(num_steps): + warmup_iterations = 10 + # Warm-up iterations + for i in range(warmup_iterations): f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, i) f_0, f_1 = f_1, f_0 wp.synchronize() - elapsed_time = time.time() - start_time + export_num_steps = warmup_iterations + + elapsed_time_list = [] + for i in range(repetitions): + start_time = time.time() + for i in range(num_steps): + f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, i) + f_0, f_1 = f_1, f_0 + wp.synchronize() + elapsed_time = time.time() - start_time + elapsed_time_list.append(elapsed_time) + export_num_steps += num_steps + + # Define Macroscopic Calculation + macro = Macroscopic( +
compute_backend=compute_backend, + precision_policy=precision_policy, + velocity_set=xlb.velocity_set.D3Q19(precision_policy=precision_policy, compute_backend=compute_backend), + ) + + if compute_backend == ComputeBackend.NEON: + if export_final_velocity: + rho = grid.create_field(cardinality=1, dtype=precision_policy.store_precision) + u = grid.create_field(cardinality=3, dtype=precision_policy.store_precision) - return elapsed_time + macro(f_0, rho, u) + wp.synchronize() + u.update_host(0) + wp.synchronize() + u.export_vti(f"mlups_3d_size_{grid_shape[0]}_dev_{num_devices}_step_{export_num_steps}.vti", "u") + + return elapsed_time_list def calculate_mlups(cube_edge, num_steps, elapsed_time): @@ -90,15 +268,293 @@ def calculate_mlups(cube_edge, num_steps, elapsed_time): return mlups +def print_summary_with_stats(args, stats): + """Print comprehensive simulation summary with statistics from multiple repetitions""" + total_lattice_points = args.cube_edge**3 + total_lattice_updates = total_lattice_points * args.num_steps + + mean_mlups = stats["mean_mlups"] + std_mlups = stats["std_dev_mlups"] + mean_elapsed_time = stats["mean_elapsed_time"] + std_elapsed_time = stats["std_dev_elapsed_time"] + + print("\n\n\n" + "=" * 70) + print(" SIMULATION SUMMARY") + print("=" * 70) + + # Simulation Parameters + print("SIMULATION PARAMETERS:") + print("-" * 25) + print(f" Grid Size: {args.cube_edge}³ ({args.cube_edge:,} × {args.cube_edge:,} × {args.cube_edge:,})") + print(f" Total Lattice Points: {total_lattice_points:,}") + print(f" Time Steps: {args.num_steps:,}") + print(f" Total Lattice Updates: {total_lattice_updates:,}") + print(f" Repetitions: {args.repetitions}") + print(f" Compute Backend: {args.compute_backend.name}") + print(f" Precision Policy: {args.precision}") + print(f" Velocity Set: {args.velocity_set.__class__.__name__}") + print(f" Collision Model: {args.collision_model}") + print(f" Generate Report: {'Yes' if args.report else 'No'}") + print(f" Measure 
Scalability: {'Yes' if args.measure_scalability else 'No'}") + + if args.compute_backend.name == "NEON": + print(f" GPU Devices: {args.gpu_devices}") + occ_display = args.occ_display + print(f" OCC Strategy: {occ_display}") + + print() + + # Raw Data (if multiple repetitions) + if args.repetitions > 1: + print("RAW MEASUREMENT DATA:") + print("-" * 21) + print(f"{'Run':<6} {'Elapsed Time (s)':<18} {'MLUPs':<12} {'Time/Step (ms)':<15}") + print("-" * 53) + + raw_elapsed_times = stats["raw_elapsed_times"] + raw_mlups = stats["raw_mlups"] + + for i, (elapsed_time, mlups) in enumerate(zip(raw_elapsed_times, raw_mlups)): + time_per_step = elapsed_time / args.num_steps * 1000 + print(f"{i + 1:<6} {elapsed_time:<18.3f} {mlups:<12.2f} {time_per_step:<15.3f}") + + print("-" * 53) + print() + + # Performance Results (Statistical Summary) + print("PERFORMANCE RESULTS:") + print("-" * 20) + if args.repetitions > 1: + print(f" Time in main loop: {mean_elapsed_time:.3f} ± {std_elapsed_time:.3f} seconds") + print(f" MLUPs: {mean_mlups:.2f} ± {std_mlups:.2f}") + print(f" Time per LBM step: {mean_elapsed_time / args.num_steps * 1000:.3f} ± {std_elapsed_time / args.num_steps * 1000:.3f} ms") + else: + print(f" Time in main loop: {mean_elapsed_time:.3f} seconds") + print(f" MLUPs: {mean_mlups:.2f}") + print(f" Time per LBM step: {mean_elapsed_time / args.num_steps * 1000:.3f} ms") + + if args.compute_backend.name == "NEON" and len(args.gpu_devices) > 1: + mlups_per_gpu = mean_mlups / len(args.gpu_devices) + if args.repetitions > 1: + mlups_per_gpu_std = std_mlups / len(args.gpu_devices) + print(f" MLUPs per GPU: {mlups_per_gpu:.2f} ± {mlups_per_gpu_std:.2f}") + else: + print(f" MLUPs per GPU: {mlups_per_gpu:.2f}") + + print("=" * 70) + + +def print_scalability_summary(args, stats_list): + """Print comprehensive scalability summary with MLUPs statistics for different GPU counts""" + total_lattice_points = args.cube_edge**3 + total_lattice_updates = total_lattice_points * 
args.num_steps + + print("\n\n\n" + "=" * 95) + print(" SCALABILITY ANALYSIS") + print("=" * 95) + + # Simulation Parameters + print("SIMULATION PARAMETERS:") + print("-" * 25) + print(f" Grid Size: {args.cube_edge}³ ({args.cube_edge:,} × {args.cube_edge:,} × {args.cube_edge:,})") + print(f" Total Lattice Points: {total_lattice_points:,}") + print(f" Time Steps: {args.num_steps:,}") + print(f" Total Lattice Updates: {total_lattice_updates:,}") + print(f" Repetitions: {args.repetitions}") + print(f" Compute Backend: {args.compute_backend.name}") + print(f" Precision Policy: {args.precision}") + print(f" Velocity Set: {args.velocity_set.__class__.__name__}") + print(f" Collision Model: {args.collision_model}") + + if args.compute_backend.name == "NEON": + occ_display = args.occ_display + print(f" OCC Strategy: {occ_display}") + print(f" Available GPU Devices: {args.gpu_devices}") + + print() + + # Extract mean MLUPs for calculations + mlups_means = [stats["mean_mlups"] for stats in stats_list] + baseline_mlups = mlups_means[0] if mlups_means else 0 + + # Scalability Results + print("SCALABILITY RESULTS:") + print("-" * 20) + print(f"{'GPUs':<6} {'MLUPs (mean±std)':<18} {'Speedup':<10} {'Efficiency':<12} {'MLUPs/GPU':<12}") + print("-" * 68) + + for i, stats in enumerate(stats_list): + num_gpus = i + 1 + mean_mlups = stats["mean_mlups"] + std_mlups = stats["std_dev_mlups"] + speedup = mean_mlups / baseline_mlups if baseline_mlups > 0 else 0 + efficiency = (speedup / num_gpus) if num_gpus > 0 else 0 + mlups_per_gpu = mean_mlups / num_gpus if num_gpus > 0 else 0 + + # Format MLUPs with standard deviation + if args.repetitions > 1: + mlups_str = f"{mean_mlups:.2f}±{std_mlups:.2f}" + else: + mlups_str = f"{mean_mlups:.2f}" + + print(f"{num_gpus:<6} {mlups_str:<18} {speedup:<10.2f} {efficiency:<11.3f} {mlups_per_gpu:<12.2f}") + + print("-" * 68) + + # Summary Statistics + if len(stats_list) > 1: + max_mlups = max(mlups_means) + max_mlups_idx = mlups_means.index(max_mlups) 
+ max_speedup = max_mlups / baseline_mlups if baseline_mlups > 0 else 0 + best_efficiency_idx = 0 + best_efficiency = 0.0 + + for i, mean_mlups in enumerate(mlups_means): + num_gpus = i + 1 + speedup = mean_mlups / baseline_mlups if baseline_mlups > 0 else 0 + efficiency = (speedup / num_gpus) if num_gpus > 0 else 0 + if efficiency > best_efficiency: + best_efficiency = efficiency + best_efficiency_idx = i + + print() + print("SUMMARY STATISTICS:") + print("-" * 19) + print(f" Best Performance: {max_mlups:.2f} MLUPs ({max_mlups_idx + 1} GPUs)") + if args.repetitions > 1: + max_std = stats_list[max_mlups_idx]["std_dev_mlups"] + print(f" Performance Std Dev: ±{max_std:.2f} MLUPs") + print(f" Maximum Speedup: {max_speedup:.2f}x") + print(f" Best Efficiency: {best_efficiency:.3f} ({best_efficiency_idx + 1} GPUs)") + print(f" Scalability Range: 1-{len(stats_list)} GPUs") + + print("=" * 95) + + +def report(args, stats): + import neon + import sys + + report = neon.Report("LBM MLUPS LDC") + + # Save the full command line + command_line = " ".join(sys.argv) + report.add_member("command_line", command_line) + + report.add_member("velocity_set", args.velocity_set.__class__.__name__) + report.add_member("compute_backend", args.compute_backend.name) + report.add_member("precision_policy", args.precision) + report.add_member("collision_model", args.collision_model) + report.add_member("grid_size", args.cube_edge) + report.add_member("num_steps", args.num_steps) + report.add_member("repetitions", args.repetitions) + + # Statistical measures + report.add_member("mean_elapsed_time", stats["mean_elapsed_time"]) + report.add_member("mean_mlups", stats["mean_mlups"]) + report.add_member("std_dev_elapsed_time", stats["std_dev_elapsed_time"]) + report.add_member("std_dev_mlups", stats["std_dev_mlups"]) + + # Raw data vectors (if multiple repetitions) + if args.repetitions > 1: + report.add_member_vector("raw_elapsed_times", stats["raw_elapsed_times"]) + 
report.add_member_vector("raw_mlups", stats["raw_mlups"]) + + # Legacy fields for backwards compatibility + report.add_member("elapsed_time", stats["mean_elapsed_time"]) + report.add_member("mlups", stats["mean_mlups"]) + + report.add_member("occ", args.occ_display) + report.add_member_vector("gpu_devices", args.gpu_devices) + report.add_member("num_devices", len(args.gpu_devices)) + report.add_member("measure_scalability", args.measure_scalability) + + # Generate report name following the convention: script_name + parameters + report_name = "mlups_3d" + report_name += f"_velocity_set_{args.velocity_set.__class__.__name__}" + report_name += f"_compute_backend_{args.compute_backend.name}" + report_name += f"_precision_policy_{args.precision.replace('/', '_')}" + report_name += f"_collision_model_{args.collision_model}" + report_name += f"_grid_size_{args.cube_edge}" + report_name += f"_num_steps_{args.num_steps}" + + if args.compute_backend.name == "NEON": + report_name += f"_occ_{args.occ_display}" + report_name += f"_num_devices_{len(args.gpu_devices)}" + + if args.repetitions > 1: + report_name += f"_repetitions_{args.repetitions}" + + report.write(report_name, True) + + # -------------------------- Simulation Loop -------------------------- -args = parse_arguments() -compute_backend, precision_policy = setup_simulation(args) -grid_shape = (args.cube_edge, args.cube_edge, args.cube_edge) -elapsed_time = run_simulation(compute_backend=compute_backend, precision_policy=precision_policy, grid_shape=grid_shape, num_steps=args.num_steps) +def benchmark(args): + compute_backend, precision_policy, options = init_xlb(args) + grid_shape = (args.cube_edge, args.cube_edge, args.cube_edge) + + elapsed_time_list = [] + mlups_list = [] + elapsed_time_list = run_simulation( + compute_backend=compute_backend, + precision_policy=precision_policy, + grid_shape=grid_shape, + num_steps=args.num_steps, + options=options, + export_final_velocity=args.export_final_velocity, + 
repetitions=args.repetitions, + num_devices=len(args.gpu_devices), + collision_model=args.collision_model, + ) + + for elapsed_time in elapsed_time_list: + mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) + mlups_list.append(mlups) + + mean_mlups = np.mean(mlups_list) + std_dev_mlups = np.std(mlups_list) + mean_elapsed_time = np.mean(elapsed_time_list) + std_dev_elapsed_time = np.std(elapsed_time_list) + + stats = { + "mean_mlups": mean_mlups, + "std_dev_mlups": std_dev_mlups, + "mean_elapsed_time": mean_elapsed_time, + "std_dev_elapsed_time": std_dev_elapsed_time, + "num_devices": len(args.gpu_devices), + "raw_mlups": mlups_list, + "raw_elapsed_times": elapsed_time_list, + } + # Generate report if requested + if args.report: + report(args, stats) + print("Report generated successfully.") + + return stats + + +def main(): + args = parse_arguments() + if not args.measure_scalability: + stats = benchmark(args) + # For single run, print_summary expects individual values with additional stats + print_summary_with_stats(args, stats) + return + + stats_list = [] + for num_devices in range(1, len(args.gpu_devices) + 1): + import copy + + args_copy = copy.deepcopy(args) + args_copy.gpu_devices = args_copy.gpu_devices[:num_devices] + stats = benchmark(args_copy) + stats_list.append(stats) + + # Print comprehensive scalability analysis + print_scalability_summary(args, stats_list) -mlups = calculate_mlups(args.cube_edge, args.num_steps, elapsed_time) -print(f"Simulation completed in {elapsed_time:.2f} seconds") -print(f"MLUPs: {mlups:.2f}") +if __name__ == "__main__": + main() diff --git a/examples/performance/mlups_3d_multires.py b/examples/performance/mlups_3d_multires.py new file mode 100644 index 00000000..ba71985b --- /dev/null +++ b/examples/performance/mlups_3d_multires.py @@ -0,0 +1,402 @@ +""" +MLUPS benchmark for the multi-resolution LBM solver. 
+ +Runs a lid-driven cavity simulation on a multi-resolution Neon grid and +reports the Equivalent Million Lattice Updates Per Second (EMLUPS). + +Usage:: + + python mlups_3d_multires.py <cube_edge> <num_steps> neon <precision> \\ + <num_levels> <mres_perf_opt> [options] + +Example:: + + python mlups_3d_multires.py 100 1000 neon fp32/fp32 2 NAIVE_COLLIDE_STREAM +""" + +import xlb +import argparse +import time +import warp as wp +import numpy as np +import neon + +from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import PrecisionPolicy +from xlb.grid import multires_grid_factory +from xlb.operator.stepper import MultiresIncompressibleNavierStokesStepper +from xlb.operator.boundary_condition import FullwayBounceBackBC, EquilibriumBC +from xlb.mres_perf_optimization_type import MresPerfOptimizationType + + +def parse_arguments(): + """Parse and validate command-line arguments.""" + parser = argparse.ArgumentParser( + description="MLUPS for 3D Lattice Boltzmann Method Simulation with Multi-resolution Grid", + epilog=""" +Examples: + %(prog)s 100 1000 neon fp32/fp32 2 NAIVE_COLLIDE_STREAM + %(prog)s 200 500 neon fp64/fp64 3 FUSION_AT_FINEST --report + %(prog)s 50 2000 neon fp32/fp16 2 NAIVE_COLLIDE_STREAM --export_final_velocity + +Valid values: + compute_backend: neon + precision: fp32/fp32, fp64/fp64, fp64/fp32, fp32/fp16 + mres_perf_opt: NAIVE_COLLIDE_STREAM, FUSION_AT_FINEST + velocity_set: D3Q19, D3Q27 + collision_model: BGK, KBC + """, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + # Positional arguments + parser.add_argument("cube_edge", type=int, help="Length of the edge of the cubic grid (e.g., 100)") + parser.add_argument("num_steps", type=int, help="Number of timesteps for the simulation (e.g., 1000)") + parser.add_argument("compute_backend", type=str, help="Backend for the simulation (neon)") + parser.add_argument("precision", type=str, help="Precision for the simulation (fp32/fp32, fp64/fp64, fp64/fp32, fp32/fp16)") + parser.add_argument("num_levels", type=int, help="Number of
levels for the multiresolution grid (e.g., 2)") + parser.add_argument( + "mres_perf_opt", + type=MresPerfOptimizationType.from_string, + help="Multi-resolution performance optimization strategy (NAIVE_COLLIDE_STREAM, FUSION_AT_FINEST)", + ) + + # Optional arguments + parser.add_argument("--num_devices", type=int, default=0, help="Number of devices for the simulation (default: 0)") + parser.add_argument("--velocity_set", type=str, default="D3Q19", help="Lattice type: D3Q19 or D3Q27 (default: D3Q19)") + parser.add_argument("--collision_model", type=str, default="BGK", help="Collision model: BGK or KBC (default: BGK)") + + parser.add_argument("--report", action="store_true", help="Generate a neon report file (default: disabled)") + parser.add_argument("--export_final_velocity", action="store_true", help="Export the final velocity field to a vti file (default: disabled)") + + try: + args = parser.parse_args() + except SystemExit: + # Re-raise with custom message + print("\n" + "=" * 60) + print("USAGE EXAMPLES:") + print("=" * 60) + print("python mlups_3d_multires.py 100 1000 neon fp32/fp32 2 NAIVE_COLLIDE_STREAM") + print("python mlups_3d_multires.py 200 500 neon fp64/fp64 3 FUSION_AT_FINEST --report") + print("\nVALID VALUES:") + print(" compute_backend: neon") + print(" precision: fp32/fp32, fp64/fp64, fp64/fp32, fp32/fp16") + print(" mres_perf_opt: NAIVE_COLLIDE_STREAM, FUSION_AT_FINEST") + print(" velocity_set: D3Q19, D3Q27") + print(" collision_model: BGK, KBC") + print("=" * 60) + raise + + print_args(args) + + if args.compute_backend != "neon": + raise ValueError("Invalid compute backend specified. Use 'neon' which supports multi-resolution!") + + if args.collision_model not in ["BGK", "KBC"]: + raise ValueError("Invalid collision model specified. 
Use 'BGK' or 'KBC'.") + + return args + + +def print_args(args): + """Print the simulation configuration to stdout.""" + # Print simulation configuration + print("=" * 60) + print(" 3D LATTICE BOLTZMANN SIMULATION CONFIG") + print("=" * 60) + print(f"Grid Size: {args.cube_edge}³ ({args.cube_edge:,} × {args.cube_edge:,} × {args.cube_edge:,})") + print(f"Total Lattice Points: {args.cube_edge**3:,}") + print(f"Time Steps: {args.num_steps:,}") + print(f"Number Levels: {args.num_levels}") + print(f"Compute Backend: {args.compute_backend}") + print(f"Precision Policy: {args.precision}") + print(f"Velocity Set: {args.velocity_set}") + print(f"Collision Model: {args.collision_model}") + print(f"Mres Perf Opt: {args.mres_perf_opt}") + print(f"Generate Report: {'Yes' if args.report else 'No'}") + print(f"Export Velocity: {'Yes' if args.export_final_velocity else 'No'}") + + print("=" * 60) + print("Starting simulation...") + print() + + +def setup_simulation(args): + """Initialize XLB globals (velocity set, backend, precision) from CLI args. + + Returns + ------- + VelocitySet + The configured lattice velocity set. + """ + compute_backend = None + if args.compute_backend == "neon": + compute_backend = ComputeBackend.NEON + else: + raise ValueError("Invalid compute backend specified. 
def ldc_multires_setup(grid_shape, velocity_set, num_levels, shell_width=8):
    """Lid-driven cavity with refinement peeling inward from the boundary.

    Each finer level covers only the outermost shell of its parent,
    concentrating resolution near the walls.

    Parameters
    ----------
    grid_shape : tuple of int
        Domain size at the finest level (three entries are read).
    velocity_set : VelocitySet
        Lattice velocity set.
    num_levels : int
        Number of refinement levels.
    shell_width : int, optional
        Thickness parameter of the boundary shell kept active on every
        level except the coarsest. A voxel is active when any of its
        coordinates ``c`` satisfies ``c <= shell_width`` or
        ``c >= extent - 1 - shell_width``. Defaults to 8.

    Returns
    -------
    grid : NeonMultiresGrid
    lid : list of index arrays (per level)
    walls : list of index arrays (per level)
    """
    nx, ny, nz = grid_shape[0], grid_shape[1], grid_shape[2]

    def shell_mask(level):
        """Return the int32 active-voxel mask (1 = active) for a non-coarsest level."""
        divider = 2**level
        extents = (nx // divider, ny // divider, nz // divider)
        mask = np.ones(extents, dtype=np.int32)
        # A voxel is inactive only when *every* coordinate lies strictly
        # inside the shell, i.e. shell_width < c < extent - 1 - shell_width.
        # Clearing that interior box with one slice assignment replaces a
        # per-voxel Python loop. max() keeps the slice empty (instead of
        # wrapping to a negative stop) when the level is thinner than the shell.
        interior = tuple(slice(shell_width + 1, max(shell_width + 1, n - 1 - shell_width)) for n in extents)
        mask[interior] = 0
        return mask

    levels = [shell_mask(level) for level in range(num_levels - 1)]

    # The coarsest level is fully active and one voxel larger per axis than
    # the integer-divided extent.
    coarsest_divider = 2 ** (num_levels - 1)
    coarsest_extents = (
        nx // coarsest_divider + 1,
        ny // coarsest_divider + 1,
        nz // coarsest_divider + 1,
    )
    levels.append(np.ones(coarsest_extents, dtype=np.int32))

    grid = multires_grid_factory(
        grid_shape,
        velocity_set=velocity_set,
        sparsity_pattern_list=levels,
        sparsity_pattern_origins=[neon.Index_3d(0, 0, 0)] * len(levels),
    )

    box = grid.bounding_box_indices()
    box_no_edge = grid.bounding_box_indices(remove_edges=True)
    lid = box_no_edge["top"]
    # NOTE(review): `+` is assumed to concatenate per-axis index lists here
    # (i.e. bounding_box_indices returns lists) — confirm against Grid.
    walls = [box["bottom"][i] + box["left"][i] + box["right"][i] + box["front"][i] + box["back"][i] for i in range(len(grid.shape))]
    # Faces share edges/corners; drop duplicated index columns.
    walls = np.unique(np.array(walls), axis=-1).tolist()
    # Convert bc indices to a list of lists, where the first entry corresponds
    # to the finest level; the boundary conditions are placed on that level only.
    lid = [lid] + [[] for _ in range(num_levels - 1)]
    walls = [walls] + [[] for _ in range(num_levels - 1)]
    return grid, lid, walls
def calculate_mlups(cube_edge, num_steps, elapsed_time, num_levels):
    """Return the Equivalent Million Lattice Updates Per Second (EMLUPS).

    The step count is first scaled by ``2 ** (num_levels - 1)`` to express
    it in finest-level steps, then multiplied by the ``cube_edge ** 3``
    lattice points and divided by the elapsed wall-clock time.

    Parameters
    ----------
    cube_edge : int
        Edge length of the cubic domain, in lattice points.
    num_steps : int
        Number of steps passed on the command line (coarsest-level steps).
    elapsed_time : float
        Wall-clock duration of the timed loop, in seconds.
    num_levels : int
        Number of grid levels.

    Returns
    -------
    dict
        ``{"EMLUPS": float, "finer_steps": int}``
    """
    refinement_ratio = 2 ** (num_levels - 1)
    finest_steps = num_steps * refinement_ratio
    updates_per_second = cube_edge**3 * finest_steps / elapsed_time
    return {"EMLUPS": updates_per_second / 1e6, "finer_steps": finest_steps}
def main():
    """Run the benchmark end to end and print a result summary.

    Parses the CLI arguments, configures XLB, runs the timed simulation on
    a cubic domain, derives the EMLUPS figure, prints the summary and,
    when ``--report`` style output was requested, writes the report.
    """
    args = parse_arguments()
    edge = args.cube_edge
    lattice = setup_simulation(args)

    stats = run(
        lattice,
        (edge, edge, edge),
        args.num_steps,
        args.num_levels,
        args.collision_model,
        args.export_final_velocity,
        mres_perf_opt=args.mres_perf_opt,
    )
    mlups_stats = calculate_mlups(edge, args.num_steps, stats["time"], stats["num_levels"])

    summary = (
        f"Simulation completed in {stats['time']:.2f} seconds",
        f"Number of levels {stats['num_levels']}",
        f"Cube edge {edge}",
        f"Coarse Iterations {args.num_steps}",
        f"Fine Iterations {mlups_stats['finer_steps']}",
        f"EMLUPs: {mlups_stats['EMLUPS']:.2f}",
    )
    for line in summary:
        print(line)

    # The report is optional; nothing else happens after it.
    if not args.report:
        return
    generate_report(args, stats, mlups_stats)
dtype=xlb.Precision.UINT8) bc_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) diff --git a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py index 3cc15cb3..4f5d6757 100644 --- a/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py +++ b/tests/boundary_conditions/bc_fullway_bounce_back/test_bc_fullway_bounce_back_warp.py @@ -33,7 +33,7 @@ def test_fullway_bounce_back_warp(dim, velocity_set, grid_shape): my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set - missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) + missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.UINT8) bc_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) diff --git a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py index 4ec0639e..cb012cee 100644 --- a/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py +++ b/tests/boundary_conditions/mask/test_bc_indices_masker_warp.py @@ -32,7 +32,7 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): my_grid = grid_factory(grid_shape) velocity_set = DefaultConfig.velocity_set - missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.BOOL) + missing_mask = my_grid.create_field(cardinality=velocity_set.q, dtype=xlb.Precision.UINT8) bc_mask = my_grid.create_field(cardinality=1, dtype=xlb.Precision.UINT8) @@ -61,7 +61,7 @@ def test_indices_masker_warp(dim, velocity_set, grid_shape): bc_mask, missing_mask, ) - assert missing_mask.dtype == xlb.Precision.BOOL.wp_dtype + assert missing_mask.dtype == xlb.Precision.UINT8.wp_dtype assert bc_mask.dtype == xlb.Precision.UINT8.wp_dtype diff --git a/tests/kernels/collision/test_bgk_collision_jax.py 
b/tests/kernels/collision/test_bgk_collision_jax.py index f3f4308f..3d6ea901 100644 --- a/tests/kernels/collision/test_bgk_collision_jax.py +++ b/tests/kernels/collision/test_bgk_collision_jax.py @@ -28,7 +28,7 @@ def init_xlb_env(velocity_set): (3, xlb.velocity_set.D3Q27, (50, 50, 50), 1.0), ], ) -def test_bgk_ollision(dim, velocity_set, grid_shape, omega): +def test_bgk_collision(dim, velocity_set, grid_shape, omega): init_xlb_env(velocity_set) my_grid = grid_factory(grid_shape) @@ -45,7 +45,7 @@ def test_bgk_ollision(dim, velocity_set, grid_shape, omega): f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) - f_out = compute_collision(f_orig, f_eq, rho, u, omega) + f_out = compute_collision(f_orig, f_eq, omega) assert jnp.allclose(f_out, f_orig - omega * (f_orig - f_eq)) diff --git a/tests/kernels/collision/test_bgk_collision_warp.py b/tests/kernels/collision/test_bgk_collision_warp.py index aa51ea1d..fa6884b2 100644 --- a/tests/kernels/collision/test_bgk_collision_warp.py +++ b/tests/kernels/collision/test_bgk_collision_warp.py @@ -44,7 +44,7 @@ def test_bgk_collision_warp(dim, velocity_set, grid_shape, omega): f_orig = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) f_out = my_grid.create_field(cardinality=DefaultConfig.velocity_set.q) - f_out = compute_collision(f_orig, f_eq, f_out, rho, u, omega) + f_out = compute_collision(f_orig, f_eq, f_out, omega) f_eq = f_eq.numpy() f_out = f_out.numpy() diff --git a/xlb/__init__.py b/xlb/__init__.py index d03a368a..02b7a994 100644 --- a/xlb/__init__.py +++ b/xlb/__init__.py @@ -9,6 +9,7 @@ from xlb.compute_backend import ComputeBackend as ComputeBackend from xlb.precision_policy import PrecisionPolicy as PrecisionPolicy, Precision as Precision from xlb.physics_type import PhysicsType as PhysicsType +from xlb.mres_perf_optimization_type import MresPerfOptimizationType as MresPerfOptimizationType # Config from .default_config import init as init, DefaultConfig as DefaultConfig diff --git 
class ComputeBackend(Enum):
    """Available compute backends.

    ``JAX``  — single-resolution; multi-GPU/TPU via JAX.
    ``WARP`` — single-resolution; single-GPU CUDA via NVIDIA Warp.
    ``NEON`` — via Neon: single-resolution dense grids (multi-GPU capable,
    selected through a device list) and multi-resolution grids (single GPU).
    Warp is initialized alongside Neon when this backend is selected.
    """

    JAX = auto()
    WARP = auto()
    NEON = auto()
+ + Attributes are set by :func:`init` and read by operators, grids, and + helpers throughout XLB. + + Attributes + ---------- + default_precision_policy : PrecisionPolicy or None + Active precision policy (compute / store dtype pair). + velocity_set : VelocitySet or None + Active lattice velocity set. + default_backend : ComputeBackend or None + Active compute backend. + """ + default_precision_policy = None velocity_set = None default_backend = None def init(velocity_set, default_backend, default_precision_policy): + """Initialize the global XLB configuration. + + Must be called before creating any grid, operator, or field. + + Parameters + ---------- + velocity_set : VelocitySet + Lattice velocity set (e.g. ``D3Q19``). + default_backend : ComputeBackend + Compute backend to use (JAX, WARP, or NEON). + default_precision_policy : PrecisionPolicy + Precision policy for compute and storage dtypes. + """ DefaultConfig.velocity_set = velocity_set DefaultConfig.default_backend = default_backend DefaultConfig.default_precision_policy = default_precision_policy @@ -20,6 +55,23 @@ def init(velocity_set, default_backend, default_precision_policy): import warp as wp wp.init() # TODO: Must be removed in the future versions of WARP + elif default_backend == ComputeBackend.NEON: + import warp as wp + import neon + + # wp.config.mode = "release" + # wp.config.llvm_cuda = False + # wp.config.verbose = True + # wp.verbose_warnings = True + + wp.init() + + # It's a good idea to always clear the kernel cache when developing new native or codegen features + wp.build.clear_kernel_cache() + + # !!! 
DO THIS BEFORE DEFINING/USING ANY KERNELS WITH CUSTOM TYPES + neon.init() + elif default_backend == ComputeBackend.JAX: check_backend_support() else: @@ -27,10 +79,14 @@ def init(velocity_set, default_backend, default_precision_policy): def default_backend() -> ComputeBackend: + """Return the currently configured compute backend.""" return DefaultConfig.default_backend def check_backend_support(): + """Print a summary of available JAX hardware accelerators.""" + import jax + if jax.devices()[0].platform == "gpu": gpus = jax.devices("gpu") if len(gpus) > 1: diff --git a/xlb/grid/__init__.py b/xlb/grid/__init__.py index f242c387..c4a80a8b 100644 --- a/xlb/grid/__init__.py +++ b/xlb/grid/__init__.py @@ -1,4 +1,5 @@ from xlb.grid.grid import grid_factory as grid_factory +from xlb.grid.grid import multires_grid_factory as multires_grid_factory from xlb.grid.warp_grid import WarpGrid from xlb.grid.jax_grid import JaxGrid diff --git a/xlb/grid/grid.py b/xlb/grid/grid.py index 53139fc1..7ae9236f 100644 --- a/xlb/grid/grid.py +++ b/xlb/grid/grid.py @@ -1,17 +1,55 @@ +""" +Grid abstraction and factory functions for XLB. + +Defines the :class:`Grid` abstract base class that every backend-specific +grid must implement, plus two factory helpers: + +* :func:`grid_factory` — creates a single-resolution grid for any backend. +* :func:`multires_grid_factory` — creates a multi-resolution grid (Neon only). +""" + from abc import ABC, abstractmethod -from typing import Tuple +from typing import Tuple, List import numpy as np from xlb import DefaultConfig from xlb.compute_backend import ComputeBackend -def grid_factory(shape: Tuple[int, ...], compute_backend: ComputeBackend = None): +def grid_factory( + shape: Tuple[int, ...], + compute_backend: ComputeBackend = None, + velocity_set=None, + backend_config=None, +): + """Create a single-resolution grid for the specified backend. + + Parameters + ---------- + shape : tuple of int + Domain dimensions, e.g. ``(nx, ny, nz)``. 
def multires_grid_factory(
    shape: Tuple[int, ...],
    compute_backend: ComputeBackend = None,
    velocity_set=None,
    sparsity_pattern_list: List[np.ndarray] = None,
    sparsity_pattern_origins=None,
):
    """Create a multi-resolution grid (Neon backend only).

    Parameters
    ----------
    shape : tuple of int
        Bounding-box dimensions at the finest level.
    compute_backend : ComputeBackend, optional
        Must resolve to ``ComputeBackend.NEON``. Defaults to
        ``DefaultConfig.default_backend``.
    velocity_set : VelocitySet, optional
        Lattice velocity set. Defaults to ``DefaultConfig.velocity_set``.
    sparsity_pattern_list : list of np.ndarray, optional
        Active-voxel masks, one per level (finest first). ``None`` is
        treated as an empty list.
    sparsity_pattern_origins : list of neon.Index_3d, optional
        Origin of each level's pattern in finest-level coordinates.
        ``None`` is treated as an empty list.

    Returns
    -------
    NeonMultiresGrid
        A multi-resolution Neon grid.

    Raises
    ------
    ValueError
        If the resolved compute backend is not ``ComputeBackend.NEON``.
    """
    compute_backend = compute_backend or DefaultConfig.default_backend
    velocity_set = velocity_set or DefaultConfig.velocity_set
    if compute_backend != ComputeBackend.NEON:
        raise ValueError(f"Compute backend {compute_backend} is not supported for multires grid")

    # Kept function-local, mirroring how grid_factory imports its backends.
    from xlb.grid.multires_grid import NeonMultiresGrid

    # Fresh lists per call: a shared mutable default would leak state
    # between invocations if a callee ever mutated it.
    if sparsity_pattern_list is None:
        sparsity_pattern_list = []
    if sparsity_pattern_origins is None:
        sparsity_pattern_origins = []

    return NeonMultiresGrid(
        shape=shape,
        velocity_set=velocity_set,
        sparsity_pattern_list=sparsity_pattern_list,
        sparsity_pattern_origins=sparsity_pattern_origins,
    )
""" + # If shape is not give, use self.shape + if shape is None: + shape = self.shape + # Get the shape of the grid origin = np.array([0, 0, 0]) - bounds = np.array(self.shape) + bounds = np.array(shape) if remove_edges: origin += 1 bounds -= 1 @@ -60,11 +165,11 @@ def bounding_box_indices(self, remove_edges=False): dim = len(bounds) # Generate bounding box indices for each face - grid = np.indices(self.shape) + grid = np.indices(shape) boundingBoxIndices = {} if dim == 2: - nx, ny = self.shape + nx, ny = shape boundingBoxIndices = { "bottom": grid[:, slice_x, 0], "top": grid[:, slice_x, ny - 1], @@ -72,7 +177,7 @@ def bounding_box_indices(self, remove_edges=False): "right": grid[:, nx - 1, slice_y], } elif dim == 3: - nx, ny, nz = self.shape + nx, ny, nz = shape slice_z = slice(origin[2], bounds[2]) boundingBoxIndices = { "bottom": grid[:, slice_x, slice_y, 0].reshape(3, -1), diff --git a/xlb/grid/multires_grid.py b/xlb/grid/multires_grid.py new file mode 100644 index 00000000..582dd27e --- /dev/null +++ b/xlb/grid/multires_grid.py @@ -0,0 +1,225 @@ +""" +Multi-resolution sparse grid backed by the Neon ``mGrid`` runtime. + +This module wraps ``neon.multires.mGrid`` and exposes it through the +:class:`Grid` interface. The grid is hierarchical: level 0 is the finest +and level *N-1* is the coarsest. Each coarser level has half the +resolution of the level below it (refinement factor 2). +""" + +import numpy as np +import warp as wp +import neon +from .grid import Grid +from xlb.precision_policy import Precision +from xlb.compute_backend import ComputeBackend +from typing import Literal, List +from xlb import DefaultConfig + + +class NeonMultiresGrid(Grid): + """Hierarchical multi-resolution grid on the Neon backend. + + Wraps ``neon.multires.mGrid``. Each level is described by a boolean + sparsity pattern (active-voxel mask) and an integer origin that + places it within the finest-level coordinate system. 
+ + Parameters + ---------- + shape : tuple of int + Bounding-box dimensions at the **finest** level ``(nx, ny, nz)``. + velocity_set : VelocitySet + Lattice velocity set defining neighbour connectivity. + sparsity_pattern_list : list of np.ndarray + One boolean/int array per level indicating which voxels are active. + Index 0 = finest level, index *N-1* = coarsest. + sparsity_pattern_origins : list of neon.Index_3d + Origin offset for each level's pattern in the finest-level + coordinate system. + """ + + def __init__( + self, + shape, + velocity_set, + sparsity_pattern_list: List[np.ndarray], + sparsity_pattern_origins: List[neon.Index_3d], + ): + self.bk = None + self.dim = None + self.grid = None + self.velocity_set = velocity_set + self.sparsity_pattern_list = sparsity_pattern_list + self.sparsity_pattern_origins = sparsity_pattern_origins + self.count_levels = len(sparsity_pattern_list) + self.refinement_factor = 2 + + super().__init__(shape, ComputeBackend.NEON) + + def _get_velocity_set(self): + return self.velocity_set + + def _initialize_backend(self): + # FIXME@max: for now we hardcode the number of devices to 0 + num_devs = 1 + dev_idx_list = list(range(num_devs)) + + if len(self.shape) == 2: + import py_neon + + self.dim = py_neon.Index_3d(self.shape[0], 1, self.shape[1]) + self.neon_stencil = [] + for q in range(self.velocity_set.q): + xval, yval = self.velocity_set._c[:, q] + self.neon_stencil.append([xval, 0, yval]) + + else: + self.dim = neon.Index_3d(self.shape[0], self.shape[1], self.shape[2]) + + self.neon_stencil = [] + for q in range(self.velocity_set.q): + xval, yval, zval = self.velocity_set._c[:, q] + self.neon_stencil.append([xval, yval, zval]) + + self.bk = neon.Backend(runtime=neon.Backend.Runtime.stream, dev_idx_list=dev_idx_list) + + self.grid = neon.multires.mGrid( + backend=self.bk, + dim=self.dim, + sparsity_pattern_list=self.sparsity_pattern_list, + sparsity_pattern_origins=self.sparsity_pattern_origins, + 
stencil=self.neon_stencil, + ) + # Print grid stats about voxel distribution between levels. + self.grid.print_info() + pass + + def create_field( + self, + cardinality: int, + dtype: Literal[Precision.FP32, Precision.FP64, Precision.FP16] = None, + fill_value=None, + neon_memory_type: neon.MemoryType = neon.MemoryType.host_device(), + ): + """Allocate a new multi-resolution Neon field. + + The field spans all grid levels. Each level is either zero-filled + or filled with *fill_value*. + + Parameters + ---------- + cardinality : int + Number of components per voxel. + dtype : Precision, optional + Element precision. Defaults to the store precision from the + global config. + fill_value : float, optional + Value to fill every element with. ``None`` means zero. + neon_memory_type : neon.MemoryType + Memory residency (host, device, or both). + + Returns + ------- + neon.multires.mField + The newly allocated multi-resolution field. + """ + dtype = dtype.wp_dtype if dtype else DefaultConfig.default_precision_policy.store_precision.wp_dtype + field = self.grid.new_field( + cardinality=cardinality, + dtype=dtype, + memory_type=neon_memory_type, + ) + for l in range(self.count_levels): + if fill_value is None: + field.zero_run(l, stream_idx=0) + else: + field.fill_run(level=l, value=fill_value, stream_idx=0) + return field + + def get_neon_backend(self): + """Return the underlying ``neon.Backend`` instance.""" + return self.bk + + def level_to_shape(self, level): + """Return the bounding-box shape at the given grid level. + + Level 0 is the finest and has shape ``self.shape``. Each subsequent + level halves each dimension. 
+ """ + # level = 0 corresponds to the finest level + return tuple(x // self.refinement_factor**level for x in self.shape) + + def boundary_indices_across_levels(self, level_data, box_side: str = "front", remove_edges: bool = False): + """ + Get indices for creating a boundary condition on the specified box side that crosses multiples levels of a multiresolution grid. + The indices are returned as a list of lists, where each sublist corresponds to a level + + Parameters + ---------- + - level_data: Level data containing the origins and sparsity patterns for each level as prepared by mesher/make_cuboid_mesh function! + - box_side: The side of the bounding box to get indices for (default is "front"). + returns: + - A list of lists, where each sublist contains the indices for the boundary condition at that level. + """ + num_levels = len(level_data) + bc_indices_list = [] + d = self.velocity_set.d # Dimensionality (2 or 3) + + # Define side configurations (adjust if your conventions differ) + if d == 3: + side_config = { + "left": {"dim": 0, "value": 0}, + "right": {"dim": 0, "value": lambda s: s[0] - 1}, + "front": {"dim": 1, "value": 0}, + "back": {"dim": 1, "value": lambda s: s[1] - 1}, + "bottom": {"dim": 2, "value": 0}, + "top": {"dim": 2, "value": lambda s: s[2] - 1}, + } + elif d == 2: + side_config = { + "left": {"dim": 0, "value": 0}, + "right": {"dim": 0, "value": lambda s: s[0] - 1}, + "bottom": {"dim": 1, "value": 0}, + "top": {"dim": 1, "value": lambda s: s[1] - 1}, + } + else: + raise ValueError(f"Unsupported dimensionality: {d}") + + if box_side not in side_config: + raise ValueError(f"Unsupported box_side: {box_side}") + + for level in range(num_levels): + mask = level_data[level][0] + origin = level_data[level][2] # Assume np.array of shape (d,) + grid_shape = self.level_to_shape(level) # tuple of length d + + conf = side_config[box_side] + dim_idx = conf["dim"] + grid_bounds = conf["value"](grid_shape) if callable(conf["value"]) else conf["value"] + 
class NeonGrid(Grid):
    """Dense single-resolution grid on the Neon backend.

    Wraps a ``neon.dense.dGrid``. The grid is created with a stencil built
    from the columns of ``velocity_set._c`` so that Neon knows the
    neighbour offsets used by the lattice.

    Parameters
    ----------
    shape : tuple of int
        Bounding-box dimensions of the domain ``(nx, ny, nz)`` (or
        ``(nx, ny)`` for 2-D).
    velocity_set : VelocitySet
        Lattice velocity set whose stencil defines neighbour connectivity.
    backend_config : dict, optional
        Neon backend configuration. Must contain ``"device_list"`` (list
        of integer GPU device indices). When omitted, defaults to
        ``{"device_list": [0], "skeleton_config": neon.SkeletonConfig.OCC.none()}``;
        only ``"device_list"`` is read by the code in this class.

    Raises
    ------
    ValueError
        If ``backend_config`` lacks ``"device_list"`` or that entry is not
        a list of integers.
    """

    def __init__(
        self,
        shape,
        velocity_set,
        backend_config=None,
    ):
        if backend_config is None:
            backend_config = {
                "device_list": [0],
                "skeleton_config": neon.SkeletonConfig.OCC.none(),
            }

        # Validate the configuration before touching any Neon state.
        if "device_list" not in backend_config:
            raise ValueError("backend_config must contain a 'device_list' key")
        devices = backend_config["device_list"]
        if not isinstance(devices, list) or not all(isinstance(device, int) for device in devices):
            raise ValueError("backend_config['device_list'] must be a list of integers")

        # These attributes must exist before Grid.__init__ runs, because
        # _initialize_backend reads self.config and self.velocity_set.
        self.config = backend_config
        self.bk = None
        self.dim = None
        self.grid = None
        self.velocity_set = velocity_set

        super().__init__(shape, ComputeBackend.NEON)

    def _get_velocity_set(self):
        """Return the velocity set this grid was created with."""
        return self.velocity_set

    def _initialize_backend(self):
        """Create the Neon backend, the 3-D extent, the stencil and the dense grid."""
        dev_idx_list = self.config["device_list"]
        directions = [self.velocity_set._c[:, q] for q in range(self.velocity_set.q)]

        if len(self.shape) == 2:
            import py_neon

            # A 2-D domain is embedded in 3-D with a unit-thick middle axis:
            # the lattice y-component is mapped onto the third Neon axis.
            self.dim = py_neon.Index_3d(self.shape[0], 1, self.shape[1])
            self.neon_stencil = [[xval, 0, yval] for xval, yval in directions]
        else:
            self.dim = neon.Index_3d(self.shape[0], self.shape[1], self.shape[2])
            self.neon_stencil = [[xval, yval, zval] for xval, yval, zval in directions]

        self.bk = neon.Backend(runtime=neon.Backend.Runtime.stream, dev_idx_list=dev_idx_list)
        self.bk.info_print()
        self.grid = neon.dense.dGrid(backend=self.bk, dim=self.dim, sparsity=None, stencil=self.neon_stencil)

    def create_field(
        self,
        cardinality: int,
        dtype: Literal[Precision.FP32, Precision.FP64, Precision.FP16] = None,
        fill_value=None,
    ):
        """Allocate a new Neon field on this grid.

        Parameters
        ----------
        cardinality : int
            Number of components per voxel (e.g. ``q`` for populations).
        dtype : Precision, optional
            Element precision. Defaults to the store precision from the
            global config.
        fill_value : float, optional
            If provided every element is set to this value; otherwise the
            field is zero-initialized.

        Returns
        -------
        field
            The object returned by ``self.grid.new_field`` after it has
            been zeroed or filled on stream 0.
        """
        dtype = dtype.wp_dtype if dtype else DefaultConfig.default_precision_policy.store_precision.wp_dtype
        field = self.grid.new_field(
            cardinality=cardinality,
            dtype=dtype,
        )

        if fill_value is None:
            field.zero_run(stream_idx=0)
        else:
            field.fill_run(value=fill_value, stream_idx=0)
        return field

    def get_neon_backend(self):
        """Return the underlying ``neon.Backend`` instance."""
        return self.bk
+ +Provides helper functions and Operator subclasses that populate +distribution-function fields with equilibrium values. Two usage patterns +are supported: + +* **Functional helpers** (`initialize_eq`, `initialize_multires_eq`) — + one-shot initialization used during simulation setup. +* **Operator classes** (`CustomInitializer`, `CustomMultiresInitializer`) — + reusable operators that can target the whole domain or a single boundary + condition region, with support for JAX, Warp, and Neon backends. +""" + +import warp as wp +from typing import Any +from xlb import DefaultConfig +from xlb.operator import Operator +from xlb.velocity_set import VelocitySet from xlb.compute_backend import ComputeBackend from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.equilibrium import MultiresQuadraticEquilibrium def initialize_eq(f, grid, velocity_set, precision_policy, compute_backend, rho=None, u=None): + """Initialize a distribution-function field to equilibrium. + + Computes the quadratic equilibrium for the given density and velocity + fields and writes it into *f*. When *rho* or *u* are ``None`` the + defaults are uniform density 1 and zero velocity. + + Parameters + ---------- + f : field + Distribution-function field to populate (modified in-place for + Warp / Neon backends; replaced for JAX). + grid : Grid + Computational grid used to allocate temporary fields. + velocity_set : VelocitySet + Lattice velocity set (e.g. D3Q19). + precision_policy : PrecisionPolicy + Precision policy for compute / store dtypes. + compute_backend : ComputeBackend + Active compute backend (JAX, WARP, or NEON). + rho : field, optional + Density field. Defaults to uniform 1.0. + u : field, optional + Velocity field. Defaults to uniform 0.0. + + Returns + ------- + field + The initialized distribution-function field. 
+ """ if rho is None: rho = grid.create_field(cardinality=1, fill_value=1.0, dtype=precision_policy.compute_precision) if u is None: @@ -11,10 +60,250 @@ def initialize_eq(f, grid, velocity_set, precision_policy, compute_backend, rho= if compute_backend == ComputeBackend.JAX: f = equilibrium(rho, u) - elif compute_backend == ComputeBackend.WARP: f = equilibrium(rho, u, f) + elif compute_backend == ComputeBackend.NEON: + f = equilibrium(rho, u, f) + else: + raise NotImplementedError(f"Backend {compute_backend} not implemented") del rho, u return f + + +def initialize_multires_eq(f, grid, velocity_set, precision_policy, backend, rho, u): + """Initialize a multi-resolution distribution-function field to equilibrium. + + Parameters + ---------- + f : field + Multi-resolution distribution-function field to populate. + grid : NeonMultiresGrid + Multi-resolution grid. + velocity_set : VelocitySet + Lattice velocity set. + precision_policy : PrecisionPolicy + Precision policy. + backend : ComputeBackend + Compute backend (expected to be NEON). + rho : field + Density field across all grid levels. + u : field + Velocity field across all grid levels. + + Returns + ------- + field + The initialized multi-resolution distribution-function field. + """ + equilibrium = MultiresQuadraticEquilibrium() + return equilibrium(rho, u, f, stream=0) + + +class CustomInitializer(Operator): + """Operator that initializes distribution functions to equilibrium. + + When ``bc_id == -1`` (default) the entire domain is initialized with the + given constant velocity and density. Otherwise only voxels whose + ``bc_mask`` matches *bc_id* are set while the rest receive the + weight-only equilibrium (zero velocity, unit density). + + Supports JAX, Warp, and Neon backends. + + Parameters + ---------- + constant_velocity_vector : list of float + Macroscopic velocity [ux, uy, uz] used for initialization. + constant_density : float + Macroscopic density used for initialization. 
+ bc_id : int + Boundary-condition ID to target. ``-1`` means the whole domain. + initialization_operator : Operator, optional + Equilibrium operator to use. Defaults to ``QuadraticEquilibrium``. + velocity_set : VelocitySet, optional + precision_policy : PrecisionPolicy, optional + compute_backend : ComputeBackend, optional + """ + + def __init__( + self, + constant_velocity_vector=[0.0, 0.0, 0.0], + constant_density: float = 1.0, + bc_id: int = -1, + initialization_operator=None, + velocity_set: VelocitySet = None, + precision_policy=None, + compute_backend=None, + ): + self.bc_id = bc_id + self.constant_velocity_vector = constant_velocity_vector + self.constant_density = constant_density + if initialization_operator is None: + compute_backend = compute_backend or DefaultConfig.default_backend + self.initialization_operator = QuadraticEquilibrium( + velocity_set=velocity_set or DefaultConfig.velocity_set, + precision_policy=precision_policy or DefaultConfig.precision_policy, + compute_backend=compute_backend if compute_backend == ComputeBackend.JAX else ComputeBackend.WARP, + ) + super().__init__(velocity_set, precision_policy, compute_backend) + + @Operator.register_backend(ComputeBackend.JAX) + def jax_implementation(self, bc_mask, f_field): + from xlb.grid import grid_factory + import jax.numpy as jnp + + grid_shape = f_field.shape[1:] + grid = grid_factory(grid_shape) + rho_init = grid.create_field(cardinality=1, fill_value=self.constant_density, dtype=self.precision_policy.compute_precision) + u_init = grid.create_field(cardinality=self.velocity_set.d, fill_value=0.0, dtype=self.precision_policy.compute_precision) + _vel = jnp.array(self.constant_velocity_vector)[(...,) + (None,) * self.velocity_set.d] + if self.bc_id == -1: + u_init += _vel + else: + u_init = jnp.where(bc_mask[0] == self.bc_id, u_init + _vel, u_init) + return self.initialization_operator(rho_init, u_init) + + def _construct_warp(self): + _q = self.velocity_set.q + _u_vec = 
wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + _u = _u_vec(self.constant_velocity_vector[0], self.constant_velocity_vector[1], self.constant_velocity_vector[2]) + _rho = self.compute_dtype(self.constant_density) + _w = self.velocity_set.w + bc_id = self.bc_id + + @wp.func + def functional_local(index: Any, bc_mask: Any, f_field: Any): + # Check whether this voxel is tagged with the targeted boundary-condition ID + if self.read_field(bc_mask, index, 0) == bc_id: + _f_init = self.initialization_operator.warp_functional(_rho, _u) + for l in range(_q): + self.write_field(f_field, index, l, self.store_dtype(_f_init[l])) + else: + # In the rest of the domain, we assume zero velocity and unit density, so the equilibrium distribution reduces to the lattice weights. + for l in range(_q): + self.write_field(f_field, index, l, self.store_dtype(_w[l])) + + @wp.func + def functional_domain(index: Any, bc_mask: Any, f_field: Any): + # If bc_id is -1, initialize the entire domain according to the custom initialization operator for the given velocity + _f_init = self.initialization_operator.warp_functional(_rho, _u) + for l in range(_q): + self.write_field(f_field, index, l, self.store_dtype(_f_init[l])) + + # Set the functional based on whether we are initializing a specific BC or the entire domain + functional = functional_local if self.bc_id != -1 else functional_domain + + # Construct the warp kernel + @wp.kernel + def kernel( + bc_mask: wp.array4d(dtype=wp.uint8), + f_field: wp.array4d(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # Apply the selected initializer functional (single BC region or whole domain) 
to the current voxel + functional(index, bc_mask, f_field) + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, bc_mask, f_field): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[bc_mask, f_field], + dim=f_field.shape[1:], + ) + return f_field + + def _construct_neon(self): + import neon + + # Use the warp functional for the NEON backend + functional, _ = self._construct_warp() + + @neon.Container.factory(name="CustomInitializer") + def container( + bc_mask: Any, + f_field: Any, + ): + def launcher(loader: neon.Loader): + loader.set_grid(f_field.get_grid()) + f_field_pn = loader.get_write_handle(f_field) + bc_mask_pn = loader.get_read_handle(bc_mask) + + @wp.func + def kernel(index: Any): + # apply the functional + functional(index, bc_mask_pn, f_field_pn) + + loader.declare_kernel(kernel) + + return launcher + + return _, container + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, bc_mask, f_field, stream=0): + # Launch the neon container + c = self.neon_container(bc_mask, f_field) + c.run(stream, container_runtime=neon.Container.ContainerRuntime.neon) + return f_field + + +class CustomMultiresInitializer(CustomInitializer): + """Multi-resolution variant of :class:`CustomInitializer`. + + Iterates over all grid levels and initializes distribution functions + using the Neon multi-resolution container API. 
+ """ + + def __init__( + self, + constant_velocity_vector=[0.0, 0.0, 0.0], + constant_density: float = 1.0, + bc_id: int = -1, + initialization_operator=None, + velocity_set: VelocitySet = None, + precision_policy=None, + compute_backend=None, + ): + super().__init__(constant_velocity_vector, constant_density, bc_id, initialization_operator, velocity_set, precision_policy, compute_backend) + + def _construct_neon(self): + # Use the warp functional for the NEON backend + functional, _ = self._construct_warp() + + @neon.Container.factory(name="CustomMultiresInitializer") + def container( + bc_mask: Any, + f_field: Any, + level: Any, + ): + def launcher(loader: neon.Loader): + loader.set_mres_grid(f_field.get_grid(), level) + f_field_pn = loader.get_mres_write_handle(f_field) + bc_mask_pn = loader.get_mres_read_handle(bc_mask) + + @wp.func + def kernel(index: Any): + # apply the functional + functional(index, bc_mask_pn, f_field_pn) + + loader.declare_kernel(kernel) + + return launcher + + return _, container + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, bc_mask, f_field, stream=0): + grid = bc_mask.get_grid() + for level in range(grid.num_levels): + # Launch the neon container + c = self.neon_container(bc_mask, f_field, level) + c.run(stream, container_runtime=neon.Container.ContainerRuntime.neon) + return f_field diff --git a/xlb/helper/nse_solver.py b/xlb/helper/nse_fields.py similarity index 70% rename from xlb/helper/nse_solver.py rename to xlb/helper/nse_fields.py index 361dc6e3..81e01006 100644 --- a/xlb/helper/nse_solver.py +++ b/xlb/helper/nse_fields.py @@ -1,6 +1,15 @@ +""" +Factory function for creating the standard Navier-Stokes field arrays. + +Returns the distribution-function pair (*f_0*, *f_1*), the boundary- +condition mask, and the missing-population mask, all allocated on the +given grid and backend. 
+""" + from xlb import DefaultConfig from xlb.grid import grid_factory from xlb.precision_policy import Precision +from xlb.compute_backend import ComputeBackend from typing import Tuple @@ -30,12 +39,17 @@ def create_nse_fields( if grid is None: if grid_shape is None: raise ValueError("grid_shape must be provided when grid is None") - grid = grid_factory(grid_shape, compute_backend=compute_backend) + grid = grid_factory(grid_shape, compute_backend=compute_backend, velocity_set=velocity_set) # Create fields f_0 = grid.create_field(cardinality=velocity_set.q, dtype=precision_policy.store_precision) f_1 = grid.create_field(cardinality=velocity_set.q, dtype=precision_policy.store_precision) - missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=Precision.BOOL) bc_mask = grid.create_field(cardinality=1, dtype=Precision.UINT8) + if compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON]: + # For WARP and NEON, we use UINT8 for missing mask + missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=Precision.UINT8) + else: + # For JAX, we use bool for missing mask + missing_mask = grid.create_field(cardinality=velocity_set.q, dtype=Precision.BOOL) return grid, f_0, f_1, missing_mask, bc_mask diff --git a/xlb/helper/simulation_manager.py b/xlb/helper/simulation_manager.py new file mode 100644 index 00000000..ad60b939 --- /dev/null +++ b/xlb/helper/simulation_manager.py @@ -0,0 +1,244 @@ +""" +High-level simulation manager for multi-resolution LBM on the Neon backend. + +:class:`MultiresSimulationManager` orchestrates the complete simulation +lifecycle: field allocation, boundary-condition setup, coalescence-factor +precomputation, and the recursive time-stepping skeleton that correctly +interleaves coarse and fine grid updates. 
+""" + +import warp as wp +from xlb.operator.stepper import MultiresIncompressibleNavierStokesStepper +from xlb.operator.macroscopic import MultiresMacroscopic +from xlb.mres_perf_optimization_type import MresPerfOptimizationType + + +class MultiresSimulationManager(MultiresIncompressibleNavierStokesStepper): + """Orchestrates multi-resolution LBM simulations on the Neon backend. + + Inherits from :class:`MultiresIncompressibleNavierStokesStepper` and + adds field management, omega computation across levels, and the + recursive skeleton builder that encodes the multi-resolution + time-stepping order. + + Parameters + ---------- + omega_finest : float + Relaxation parameter at the finest grid level. + grid : NeonMultiresGrid + Multi-resolution grid. + boundary_conditions : list of BoundaryCondition + Boundary conditions to apply. + collision_type : str + ``"BGK"`` or ``"KBC"``. + forcing_scheme : str + Forcing scheme (used only when *force_vector* is given). + force_vector : array-like, optional + External body force. + initializer : Operator, optional + Custom initializer for distribution functions. If ``None`` + the default equilibrium initialization is used. + mres_perf_opt : MresPerfOptimizationType + Performance optimization strategy. 
+ """ + + def __init__( + self, + omega_finest, + grid, + boundary_conditions=[], + collision_type="BGK", + forcing_scheme="exact_difference", + force_vector=None, + initializer=None, + mres_perf_opt: MresPerfOptimizationType = MresPerfOptimizationType.NAIVE_COLLIDE_STREAM, + ): + super().__init__(grid, boundary_conditions, collision_type, forcing_scheme, force_vector) + + self.initializer = initializer + self.count_levels = grid.count_levels + self.omega_list = [self.compute_omega(omega_finest, level) for level in range(self.count_levels)] + self.mres_perf_opt = mres_perf_opt + # Create fields + self.rho = grid.create_field(cardinality=1, dtype=self.precision_policy.store_precision) + self.u = grid.create_field(cardinality=3, dtype=self.precision_policy.store_precision) + self.coalescence_factor = grid.create_field(cardinality=self.velocity_set.q, dtype=self.precision_policy.store_precision) + + for level in range(self.count_levels): + self.u.fill_run(level, 0.0, 0) + self.rho.fill_run(level, 1.0, 0) + self.coalescence_factor.fill_run(level, 0.0, 0) + + # Prepare fields + self.f_0, self.f_1, self.bc_mask, self.missing_mask = self.prepare_fields(self.rho, self.u, self.initializer) + self.prepare_coalescence_count(coalescence_factor=self.coalescence_factor, bc_mask=self.bc_mask) + + self.iteration_idx = -1 + self.macro = MultiresMacroscopic( + compute_backend=self.compute_backend, + precision_policy=self.precision_policy, + velocity_set=self.velocity_set, + ) + + # Construct the stepper skeleton + self._construct_stepper_skeleton() + + def compute_omega(self, omega_finest, level): + """ + Compute the relaxation parameter omega at a given grid level based on the finest level omega. + We select a refinement ratio of 2, where a coarse cell at level L is uniformly divided into 2^d cells + (d being the spatial dimension) to arrive at level L - 1; in other words, ∆x_{L-1} = ∆x_L/2. 
+ For neighboring cells that interface two grid levels, a maximum jump in grid level of ∆L = 1 is + allowed. Due to acoustic scaling, which requires the speed of sound c_s to remain constant across grid levels, + ∆t_L ∝ ∆x_L and hence ∆t_{L-1} = ∆t_L/2. In addition, the fluid viscosity ν must also remain constant on each + grid level, which leads to the following relationship for the relaxation parameter omega at grid level L based + on the finest grid level omega_finest: omega_L = 2^(L+1) * omega_finest / ((2^L - 1) * omega_finest + 2). + + Args: + omega_finest: Relaxation parameter at the finest grid level. + level: Current grid level (0-indexed, with 0 being the finest level). + + Returns: + Relaxation parameter omega at the specified grid level. + """ + omega0 = omega_finest + return 2 ** (level + 1) * omega0 / ((2**level - 1.0) * omega0 + 2.0) + + def export_macroscopic(self, fname_prefix): + """Compute macroscopic fields and export velocity to a VTI file. + + Parameters + ---------- + fname_prefix : str + Output filename prefix. The iteration index is appended + automatically (e.g. ``"u_"`` → ``"u_42.vti"``). + """ + print(f"exporting macroscopic: #levels {self.count_levels}") + self.macro(self.f_0, self.bc_mask, self.rho, self.u, streamId=0) + + wp.synchronize() + self.u.update_host(0) + wp.synchronize() + self.u.export_vti(f"{fname_prefix}{self.iteration_idx}.vti", "u") + print("DONE exporting macroscopic") + + return + + def step(self): + """Advance the simulation by one coarsest-level timestep. + + Internally this executes the pre-compiled Neon skeleton which + performs the correct number of sub-steps at each finer level + according to the acoustic-scaling time refinement ratio. + """ + self.iteration_idx = self.iteration_idx + 1 + self.sk.run() + + def _build_recursion(self, level, app, config): + """Unified multi-resolution recursion builder. + + config keys: + finest_ops: list of (op_name, swap_fields, extra_kwargs) for level 0, + or None to treat level 0 like any coarse level. 
+ coarse_collide_ops: list of op_names for coarse collision. + coarse_stream_ops: list of (op_name, extra_kwargs) for coarse streaming. + fuse_finest: if True, recurse once (not twice) when child is at level 0. + """ + if level < 0: + return + + omega = self.omega_list[level] + fields = dict(f_0_fd=self.f_0, f_1_fd=self.f_1, bc_mask_fd=self.bc_mask, missing_mask_fd=self.missing_mask) + fields_swapped = dict(f_0_fd=self.f_1, f_1_fd=self.f_0, bc_mask_fd=self.bc_mask, missing_mask_fd=self.missing_mask) + + if level == 0 and config["finest_ops"] is not None: + for op_name, swap, extra in config["finest_ops"]: + base = fields_swapped if swap else fields + self.add_to_app(app=app, op_name=op_name, level=level, **base, omega=omega, **extra) + return + + for op_name in config["coarse_collide_ops"]: + self.add_to_app(app=app, op_name=op_name, level=level, **fields, omega=omega, timestep=0) + + if config["fuse_finest"] and level - 1 == 0: + self._build_recursion(level - 1, app, config) + else: + self._build_recursion(level - 1, app, config) + self._build_recursion(level - 1, app, config) + + for op_name, extra in config["coarse_stream_ops"]: + self.add_to_app(app=app, op_name=op_name, level=level, **fields_swapped, **extra) + + def _construct_stepper_skeleton(self): + import neon + + """Build the Neon skeleton that encodes the recursive time-stepping order. + + The skeleton is a list of Neon container invocations that, when + executed in sequence, perform one coarsest-level timestep with the + correct sub-cycling at finer levels. The structure depends on + ``self.mres_perf_opt``. 
+ """ + self.app = [] + + stream_abc = {"omega": self.coalescence_factor, "timestep": 0} + + # Finest-level op descriptors: (op_name, swap_f0_f1, extra_kwargs) + fused_pull_finest = [ + ("finest_fused_pull", False, {"timestep": 0, "is_f1_the_explosion_src_field": True}), + ("finest_fused_pull", True, {"timestep": 0, "is_f1_the_explosion_src_field": False}), + ] + sfv_fused_pull_finest = [ + ("CFV_finest_fused_pull", False, {"timestep": 0, "is_f1_the_explosion_src_field": True}), + ("SFV_finest_fused_pull", False, {}), + ("CFV_finest_fused_pull", True, {"timestep": 0, "is_f1_the_explosion_src_field": False}), + ("SFV_finest_fused_pull", True, {}), + ] + + configs = { + MresPerfOptimizationType.NAIVE_COLLIDE_STREAM: { + "finest_ops": None, + "coarse_collide_ops": ["collide_coarse"], + "coarse_stream_ops": [("stream_coarse_step_ABC", stream_abc)], + "fuse_finest": False, + }, + MresPerfOptimizationType.FUSION_AT_FINEST: { + "finest_ops": fused_pull_finest, + "coarse_collide_ops": ["collide_coarse"], + "coarse_stream_ops": [("stream_coarse_step_ABC", stream_abc)], + "fuse_finest": True, + }, + MresPerfOptimizationType.FUSION_AT_FINEST_SFV: { + "finest_ops": sfv_fused_pull_finest, + "coarse_collide_ops": ["collide_coarse"], + "coarse_stream_ops": [("stream_coarse_step_ABC", stream_abc)], + "fuse_finest": True, + }, + MresPerfOptimizationType.FUSION_AT_FINEST_SFV_ALL: { + "finest_ops": sfv_fused_pull_finest, + "coarse_collide_ops": ["CFV_collide_coarse", "SFV_collide_coarse"], + "coarse_stream_ops": [("SFV_stream_coarse_step_ABC", stream_abc), ("SFV_stream_coarse_step", {})], + "fuse_finest": True, + }, + } + + config = configs.get(self.mres_perf_opt) + if config is None: + raise ValueError(f"Unknown optimization level: {self.mres_perf_opt}") + + # Pre-recursion SFV mask setup + if self.mres_perf_opt == MresPerfOptimizationType.FUSION_AT_FINEST_SFV: + wp.synchronize() + self.neon_container["SFV_reset_bc_mask"](0, self.f_0, self.f_1, self.bc_mask, self.bc_mask).run(0) + 
wp.synchronize() + elif self.mres_perf_opt == MresPerfOptimizationType.FUSION_AT_FINEST_SFV_ALL: + wp.synchronize() + for l in range(self.f_0.get_grid().num_levels): + self.neon_container["SFV_reset_bc_mask"](l, self.f_0, self.f_1, self.bc_mask, self.bc_mask).run(0) + wp.synchronize() + + self._build_recursion(self.count_levels - 1, self.app, config) + + bk = self.grid.get_neon_backend() + self.sk = neon.Skeleton(backend=bk) + self.sk.sequence("mres_nse_stepper", self.app) diff --git a/xlb/mres_perf_optimization_type.py b/xlb/mres_perf_optimization_type.py new file mode 100644 index 00000000..797699f5 --- /dev/null +++ b/xlb/mres_perf_optimization_type.py @@ -0,0 +1,78 @@ +""" +Multi-resolution performance-optimization strategies. + +Defines the kernel-fusion levels available for the multi-resolution LBM +stepper and provides CLI argument parsing helpers. +""" + +import argparse +from enum import Enum, auto + + +class MresPerfOptimizationType(Enum): + """ + Enumeration of available optimization strategies for the LBM solver. + + Supports parsing from either the enum member name (case-insensitive) + or its integer value, and provides a method to build the CLI parser. + """ + + NAIVE_COLLIDE_STREAM = auto() + FUSION_AT_FINEST = auto() + FUSION_AT_FINEST_SFV = auto() + FUSION_AT_FINEST_SFV_ALL = auto() + + @staticmethod + def from_string(value: str) -> "MresPerfOptimizationType": + """ + Parse a string to a MresPerfOptimizationType. + + Accepts either the enum member name (case-insensitive) or its integer value. + + Args: + value: The enum name (e.g. 'naive_collide_stream') or integer value (e.g. '1'; values start at 1). + + Returns: + A MresPerfOptimizationType member. + + Raises: + argparse.ArgumentTypeError: If the input is invalid. 
+ """ + # Attempt to parse by name (case-insensitive) + key = value.strip().upper() + if key in MresPerfOptimizationType.__members__: + return MresPerfOptimizationType[key] + + # Attempt to parse by integer value + try: + int_value = int(value) + return MresPerfOptimizationType(int_value) + except (ValueError, KeyError): + valid_options = ", ".join(f"{member.name}({member.value})" for member in MresPerfOptimizationType) + raise argparse.ArgumentTypeError(f"Invalid OptimizationType {value!r}. Choose from: {valid_options}.") + + def __str__(self) -> str: + """ + Return a human-readable string for the enum member. + """ + return self.name + + @staticmethod + def build_arg_parser() -> argparse.ArgumentParser: + """ + Create and configure the argument parser with optimization option. + + Returns: + A configured ArgumentParser instance. + """ + parser = argparse.ArgumentParser(description="Run the LBM multiresolution simulation with specified optimizations.") + # Dynamically generate help text from enum members + valid_options = ", ".join(f"{member.name}({member.value})" for member in MresPerfOptimizationType) + parser.add_argument( + "-o", + "--optimization", + type=MresPerfOptimizationType.from_string, + default=MresPerfOptimizationType.NAIVE_COLLIDE_STREAM, + help=f"Select optimization strategy: {valid_options}", + ) + return parser diff --git a/xlb/operator/boundary_condition/__init__.py b/xlb/operator/boundary_condition/__init__.py index 7c87f58c..8be2f226 100644 --- a/xlb/operator/boundary_condition/__init__.py +++ b/xlb/operator/boundary_condition/__init__.py @@ -1,4 +1,4 @@ -from xlb.operator.boundary_condition.helper_functions_bc import HelperFunctionsBC +from xlb.operator.boundary_condition.helper_functions_bc import HelperFunctionsBC, EncodeAuxiliaryData, MultiresEncodeAuxiliaryData from xlb.operator.boundary_condition.boundary_condition import BoundaryCondition from xlb.operator.boundary_condition.boundary_condition_registry import BoundaryConditionRegistry 
from xlb.operator.boundary_condition.bc_equilibrium import EquilibriumBC @@ -8,4 +8,4 @@ from xlb.operator.boundary_condition.bc_zouhe import ZouHeBC from xlb.operator.boundary_condition.bc_regularized import RegularizedBC from xlb.operator.boundary_condition.bc_extrapolation_outflow import ExtrapolationOutflowBC -from xlb.operator.boundary_condition.bc_grads_approximation import GradsApproximationBC +from xlb.operator.boundary_condition.bc_hybrid import HybridBC diff --git a/xlb/operator/boundary_condition/bc_do_nothing.py b/xlb/operator/boundary_condition/bc_do_nothing.py index aeefd788..7c5ae0b5 100644 --- a/xlb/operator/boundary_condition/bc_do_nothing.py +++ b/xlb/operator/boundary_condition/bc_do_nothing.py @@ -1,5 +1,8 @@ """ -Base class for boundary conditions in a LBM simulation. +Do-nothing boundary condition. + +Skips the streaming step at tagged boundary voxels, leaving the +populations unchanged. """ import jax.numpy as jnp @@ -16,6 +19,7 @@ ImplementationStep, BoundaryCondition, ) +from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod class DoNothingBC(BoundaryCondition): @@ -31,6 +35,7 @@ def __init__( compute_backend: ComputeBackend = None, indices=None, mesh_vertices=None, + voxelization_method: MeshVoxelizationMethod = None, ): super().__init__( ImplementationStep.STREAMING, @@ -39,6 +44,7 @@ def __init__( compute_backend, indices, mesh_vertices, + voxelization_method, ) @Operator.register_backend(ComputeBackend.JAX) @@ -74,3 +80,12 @@ def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): dim=f_pre.shape[1:], ) return f_post + + def _construct_neon(self): + functional, _ = self._construct_warp() + return functional, None + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, f_pre, f_post, bc_mask, missing_mask): + # raise exception as this feature is not implemented yet + raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.") 
diff --git a/xlb/operator/boundary_condition/bc_equilibrium.py b/xlb/operator/boundary_condition/bc_equilibrium.py index 85cfd653..85ebe92b 100644 --- a/xlb/operator/boundary_condition/bc_equilibrium.py +++ b/xlb/operator/boundary_condition/bc_equilibrium.py @@ -12,18 +12,30 @@ from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend -from xlb.operator.equilibrium.equilibrium import Equilibrium -from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.equilibrium import Equilibrium, QuadraticEquilibrium from xlb.operator.operator import Operator from xlb.operator.boundary_condition.boundary_condition import ( ImplementationStep, BoundaryCondition, ) +from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod class EquilibriumBC(BoundaryCondition): - """ - Full Bounce-back boundary condition for a lattice Boltzmann method simulation. + """Equilibrium boundary condition. + + Sets populations at tagged voxels to the equilibrium distribution + computed from the prescribed macroscopic density *rho* and velocity + *u*. Commonly used as an inlet or outlet condition. + + Parameters + ---------- + rho : float + Prescribed macroscopic density. + u : tuple of float + Prescribed macroscopic velocity ``(ux, uy, uz)``. + equilibrium_operator : Operator, optional + Equilibrium operator. Defaults to ``QuadraticEquilibrium``. 
""" def __init__( @@ -36,6 +48,7 @@ def __init__( compute_backend: ComputeBackend = None, indices=None, mesh_vertices=None, + voxelization_method: MeshVoxelizationMethod = None, ): # Store the equilibrium information self.rho = rho @@ -53,6 +66,7 @@ def __init__( compute_backend, indices, mesh_vertices, + voxelization_method, ) @Operator.register_backend(ComputeBackend.JAX) @@ -90,8 +104,17 @@ def functional( return functional, kernel + def _construct_neon(self): + # Redefine the equilibrium operators for the neon backend + # This is because the neon backend relies on the warp functionals for its operations. + self.equilibrium_operator = QuadraticEquilibrium(compute_backend=ComputeBackend.WARP) + + # Use the warp functional for the NEON backend + functional, _ = self._construct_warp() + return functional, None + @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): + def warp_launch(self, f_pre, f_post, bc_mask, missing_mask): # Launch the warp kernel wp.launch( self.warp_kernel, diff --git a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py index 884e691e..a1a26e0b 100644 --- a/xlb/operator/boundary_condition/bc_extrapolation_outflow.py +++ b/xlb/operator/boundary_condition/bc_extrapolation_outflow.py @@ -1,5 +1,14 @@ """ -Base class for boundary conditions in a LBM simulation. +Extrapolation outflow boundary condition. + +Uses first-order extrapolation from the interior to set the unknown +populations at outflow boundaries, avoiding strong wave reflections. + +Reference +--------- +Geier, M. et al. (2015). "The cumulant lattice Boltzmann equation in +three dimensions: Theory and validation." *Computers & Mathematics +with Applications*, 70(4), 507-547. 
""" import jax.numpy as jnp @@ -19,6 +28,7 @@ ImplementationStep, BoundaryCondition, ) +from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod class ExtrapolationOutflowBC(BoundaryCondition): @@ -42,6 +52,7 @@ def __init__( compute_backend: ComputeBackend = None, indices=None, mesh_vertices=None, + voxelization_method: MeshVoxelizationMethod = None, ): # Call the parent constructor super().__init__( @@ -51,16 +62,20 @@ def __init__( compute_backend, indices, mesh_vertices, + voxelization_method, ) # find and store the normal vector using indices - self._get_normal_vec(indices) + if self.compute_backend == ComputeBackend.JAX: + self._get_normal_vectors(indices) # Unpack the two warp functionals needed for this BC! if self.compute_backend == ComputeBackend.WARP: - self.warp_functional, self.update_bc_auxilary_data = self.warp_functional + self.warp_functional, self.assemble_auxiliary_data = self.warp_functional + elif self.compute_backend == ComputeBackend.NEON: + self.neon_functional, self.assemble_auxiliary_data = self.neon_functional - def _get_normal_vec(self, indices): + def _get_normal_vectors(self, indices): # Get the frequency count and most common element directly freq_counts = [Counter(coord).most_common(1)[0] for coord in indices] @@ -89,9 +104,10 @@ def _roll(self, fld, vec): return jnp.roll(fld, (vec[0], vec[1], vec[2]), axis=(1, 2, 3)) @partial(jit, static_argnums=(0,), inline=True) - def update_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): + def assemble_auxiliary_data(self, f_pre, f_post, bc_mask, missing_mask): """ - Update the auxilary distribution functions for the boundary condition. + Prepare time-dependent dynamic data for imposing the boundary condition in the next iteration after streaming. + We use directions that leave the domain for storing this prepared data. 
Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision """ sound_speed = 1.0 / jnp.sqrt(3.0) @@ -102,7 +118,7 @@ def update_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): # Roll boundary mask in the opposite of the normal vector to mask its next immediate neighbour neighbour = self._roll(boundary, -self.normal) - # gather post-streaming values associated with previous time-step to construct the auxilary data for BC + # gather post-streaming values associated with previous time-step to construct the required data for BC fpop = jnp.where(boundary, f_pre, f_post) fpop_neighbour = jnp.where(neighbour, f_pre, f_post) @@ -168,7 +184,7 @@ def functional( return _f @wp.func - def update_bc_auxilary_data( + def assemble_auxiliary_data_warp( index: Any, timestep: Any, missing_mask: Any, @@ -177,9 +193,9 @@ def update_bc_auxilary_data( _f_pre: Any, _f_post: Any, ): - # Update the auxilary data for this BC using the neighbour's populations stored in f_aux and - # f_pre (post-streaming values of the current voxel). We use directions that leave the domain - # for storing this prepared data. + # Prepare time-dependent dynamic data for imposing the boundary condition in the next iteration after streaming. + # We use directions that leave the domain for storing this prepared data. + # Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision _f = _f_post nv = get_normal_vectors(missing_mask) for l in range(self.velocity_set.q): @@ -194,9 +210,42 @@ def update_bc_auxilary_data( _f[_opp_indices[l]] = (self.compute_dtype(1.0) - sound_speed) * _f_pre[l] + sound_speed * f_aux return _f + @wp.func + def assemble_auxiliary_data_neon( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, + _f_pre: Any, + _f_post: Any, + level: Any = 0, + ): + # Prepare time-dependent dynamic data for imposing the boundary condition in the next iteration after streaming. 
+ # We use directions that leave the domain for storing this prepared data. + # Since this function is called post-collisiotn: f_pre = f_post_stream and f_post = f_post_collision + _f = _f_post + nv = get_normal_vectors(missing_mask) + for lattice_dir in range(self.velocity_set.q): + if missing_mask[lattice_dir] == wp.uint8(1): + # f_0 is the post-collision values of the current time-step + # Get pull index associated with the "neighbours" pull_index + offset = wp.vec3i(-_c[0, lattice_dir], -_c[1, lattice_dir], -_c[2, lattice_dir]) + for d in range(self.velocity_set.d): + offset[d] = offset[d] - nv[d] + offset_pull_index = wp.neon_ngh_idx(wp.int8(offset[0]), wp.int8(offset[1]), wp.int8(offset[2])) + + # The following is the post-streaming values of the neighbor cell + # This function reads a field value at a given neighboring index and direction. + unused_is_valid = wp.bool(False) + f_aux = self.compute_dtype(wp.neon_read_ngh(f_0, index, offset_pull_index, lattice_dir, self.store_dtype(0.0), unused_is_valid)) + _f[_opp_indices[lattice_dir]] = (self.compute_dtype(1.0) - sound_speed) * _f_pre[lattice_dir] + sound_speed * f_aux + return _f + kernel = self._construct_kernel(functional) + assemble_auxiliary_data = assemble_auxiliary_data_warp if self.compute_backend == ComputeBackend.WARP else assemble_auxiliary_data_neon - return (functional, update_bc_auxilary_data), kernel + return (functional, assemble_auxiliary_data), kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, _f_pre, _f_post, bc_mask, missing_mask): @@ -207,3 +256,12 @@ def warp_implementation(self, _f_pre, _f_post, bc_mask, missing_mask): dim=_f_pre.shape[1:], ) return _f_post + + def _construct_neon(self): + functional, _ = self._construct_warp() + return functional, None + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, f_pre, f_post, bc_mask, missing_mask): + # raise exception as this feature is not implemented yet + raise 
NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.") diff --git a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py index 995e2ff9..4b7f8f0f 100644 --- a/xlb/operator/boundary_condition/bc_fullway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_fullway_bounce_back.py @@ -1,5 +1,8 @@ """ -Base class for boundary conditions in a LBM simulation. +Full-way bounce-back boundary condition. + +Reverses every population at tagged solid voxels, effectively +imposing a no-slip wall located *on* the grid node. """ import jax.numpy as jnp @@ -17,6 +20,7 @@ BoundaryCondition, ImplementationStep, ) +from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod class FullwayBounceBackBC(BoundaryCondition): @@ -31,6 +35,7 @@ def __init__( compute_backend: ComputeBackend = None, indices=None, mesh_vertices=None, + voxelization_method: MeshVoxelizationMethod = None, ): super().__init__( ImplementationStep.COLLISION, @@ -39,6 +44,7 @@ def __init__( compute_backend, indices, mesh_vertices, + voxelization_method, ) @Operator.register_backend(ComputeBackend.JAX) @@ -84,3 +90,7 @@ def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): dim=f_pre.shape[1:], ) return f_post + + def _construct_neon(self): + functional, _ = self._construct_warp() + return functional, None diff --git a/xlb/operator/boundary_condition/bc_grads_approximation.py b/xlb/operator/boundary_condition/bc_grads_approximation.py deleted file mode 100644 index 22fbb4ec..00000000 --- a/xlb/operator/boundary_condition/bc_grads_approximation.py +++ /dev/null @@ -1,321 +0,0 @@ -""" -Base class for boundary conditions in a LBM simulation. 
-""" - -import jax.numpy as jnp -from jax import jit -import jax.lax as lax -from functools import partial -import warp as wp -from typing import Any -from collections import Counter -import numpy as np - -from xlb.velocity_set.velocity_set import VelocitySet -from xlb.precision_policy import PrecisionPolicy -from xlb.compute_backend import ComputeBackend -from xlb.operator.operator import Operator -from xlb.operator.macroscopic import Macroscopic -from xlb.operator.macroscopic.zero_moment import ZeroMoment -from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux -from xlb.operator.equilibrium import QuadraticEquilibrium -from xlb.operator.boundary_condition.boundary_condition import ( - ImplementationStep, - BoundaryCondition, -) - - -class GradsApproximationBC(BoundaryCondition): - """ - Purpose: Using Grad's approximation to represent fpop based on macroscopic inputs used for outflow [1] and - Dirichlet BCs [2] - [1] S. Chikatamarla, S. Ansumali, and I. Karlin, "Grad's approximation for missing data in lattice Boltzmann - simulations", Europhys. Lett. 74, 215 (2006). - [2] Dorschner, B., Chikatamarla, S. S., Bösch, F., & Karlin, I. V. (2015). Grad's approximation for moving and - stationary walls in entropic lattice Boltzmann simulations. Journal of Computational Physics, 295, 340-354. - - """ - - def __init__( - self, - velocity_set: VelocitySet = None, - precision_policy: PrecisionPolicy = None, - compute_backend: ComputeBackend = None, - indices=None, - mesh_vertices=None, - ): - # TODO: the input velocity must be suitably stored elesewhere when mesh is moving. 
- self.u = (0, 0, 0) - - # Call the parent constructor - super().__init__( - ImplementationStep.STREAMING, - velocity_set, - precision_policy, - compute_backend, - indices, - mesh_vertices, - ) - - # Instantiate the operator for computing macroscopic values - self.macroscopic = Macroscopic() - self.zero_moment = ZeroMoment() - self.equilibrium = QuadraticEquilibrium() - self.momentum_flux = MomentumFlux() - - # This BC needs implicit distance to the mesh - self.needs_mesh_distance = True - - # If this BC is defined using indices, it would need padding in order to find missing directions - # when imposed on a geometry that is in the domain interior - if self.mesh_vertices is None: - assert self.indices is not None - self.needs_padding = True - - # Raise error if used for 2d examples: - if self.velocity_set.d == 2: - raise NotImplementedError("This BC is not implemented in 2D!") - - # if indices is not None: - # # this BC would be limited to stationary boundaries - # # assert mesh_vertices is None - # if mesh_vertices is not None: - # # this BC would be applicable for stationary and moving boundaries - # assert indices is None - # if mesh_velocity_function is not None: - # # mesh is moving and/or deforming - - assert self.compute_backend == ComputeBackend.WARP, "This BC is currently only implemented with the Warp backend!" 
- - @Operator.register_backend(ComputeBackend.JAX) - @partial(jit, static_argnums=(0)) - def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): - # TODO - raise NotImplementedError(f"Operation {self.__class__.__name} not implemented in JAX!") - return - - def _construct_warp(self): - # Set local variables and constants - _c = self.velocity_set.c - _q = self.velocity_set.q - _d = self.velocity_set.d - _w = self.velocity_set.w - _qi = self.velocity_set.qi - _opp_indices = self.velocity_set.opp_indices - _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) - _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) - _u_wall = _u_vec(self.u[0], self.u[1], self.u[2]) if _d == 3 else _u_vec(self.u[0], self.u[1]) - # diagonal = wp.vec3i(0, 3, 5) if _d == 3 else wp.vec2i(0, 2) - - @wp.func - def regularize_fpop( - missing_mask: Any, - rho: Any, - u: Any, - fpop: Any, - ): - """ - Regularizes the distribution functions by adding non-equilibrium contributions based on second moments of fpop. - """ - # Compute momentum flux of off-equilibrium populations for regularization: Pi^1 = Pi^{neq} - feq = self.equilibrium.warp_functional(rho, u) - f_neq = fpop - feq - PiNeq = self.momentum_flux.warp_functional(f_neq) - - # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) - nt = _d * (_d + 1) // 2 - for l in range(_q): - QiPi1 = self.compute_dtype(0.0) - for t in range(nt): - QiPi1 += _qi[l, t] * PiNeq[t] - - # assign all populations based on eq 45 of Latt et al (2008) - # fneq ~ f^1 - fpop1 = self.compute_dtype(4.5) * _w[l] * QiPi1 - fpop[l] = feq[l] + fpop1 - return fpop - - @wp.func - def grads_approximate_fpop( - missing_mask: Any, - rho: Any, - u: Any, - f_post: Any, - ): - # Purpose: Using Grad's approximation to represent fpop based on macroscopic inputs used for outflow [1] and - # Dirichlet BCs [2] - # [1] S. Chikatax`marla, S. Ansumali, and I. Karlin, "Grad's approximation for missing data in lattice Boltzmann - # simulations", Europhys. 
Lett. 74, 215 (2006). - # [2] Dorschner, B., Chikatamarla, S. S., Bösch, F., & Karlin, I. V. (2015). Grad's approximation for moving and - # stationary walls in entropic lattice Boltzmann simulations. Journal of Computational Physics, 295, 340-354. - - # Note: See also self.regularize_fpop function which is somewhat similar. - - # Compute pressure tensor Pi using all f_post-streaming values - Pi = self.momentum_flux.warp_functional(f_post) - - # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) - nt = _d * (_d + 1) // 2 - for l in range(_q): - # if missing_mask[l] == wp.uint8(1): - QiPi = self.compute_dtype(0.0) - for t in range(nt): - if t == 0 or t == 3 or t == 5: - QiPi += _qi[l, t] * (Pi[t] - rho / self.compute_dtype(3.0)) - else: - QiPi += _qi[l, t] * Pi[t] - - # Compute c.u - cu = self.compute_dtype(0.0) - for d in range(self.velocity_set.d): - if _c[d, l] == 1: - cu += u[d] - elif _c[d, l] == -1: - cu -= u[d] - cu *= self.compute_dtype(3.0) - - # change f_post using the Grad's approximation - f_post[l] = rho * _w[l] * (self.compute_dtype(1.0) + cu) + _w[l] * self.compute_dtype(4.5) * QiPi - - return f_post - - # Construct the functionals for this BC - @wp.func - def functional_method1( - index: Any, - timestep: Any, - missing_mask: Any, - f_0: Any, - f_1: Any, - f_pre: Any, - f_post: Any, - ): - # NOTE: this BC has been reformulated to become entirely local and so has differences compared to the original paper. - # Here we use the current time-step populations (f_pre = f_post_collision and f_post = f_post_streaming). 
- one = self.compute_dtype(1.0) - for l in range(_q): - # If the mask is missing then take the opposite index - if missing_mask[l] == wp.uint8(1): - # The implicit distance to the boundary or "weights" have been stored in known directions of f_1 - # weight = f_1[_opp_indices[l], index[0], index[1], index[2]] - weight = self.compute_dtype(0.5) - - # Use differentiable interpolated BB to find f_missing: - f_post[l] = ((one - weight) * f_post[_opp_indices[l]] + weight * (f_pre[l] + f_pre[_opp_indices[l]])) / (one + weight) - - # # Add contribution due to moving_wall to f_missing as is usual in regular Bouzidi BC - # cu = self.compute_dtype(0.0) - # for d in range(_d): - # if _c[d, l] == 1: - # cu += _u_wall[d] - # elif _c[d, l] == -1: - # cu -= _u_wall[d] - # cu *= self.compute_dtype(-6.0) * _w[l] - # f_post[l] += cu - - # Compute density, velocity using all f_post-streaming values - rho, u = self.macroscopic.warp_functional(f_post) - - # Compute Grad's appriximation using full equation as in Eq (10) of Dorschner et al. - f_post = regularize_fpop(missing_mask, rho, u, f_post) - # f_post = grads_approximate_fpop(missing_mask, rho, u, f_post) - return f_post - - # Construct the functionals for this BC - @wp.func - def functional_method2( - index: Any, - timestep: Any, - missing_mask: Any, - f_0: Any, - f_1: Any, - f_pre: Any, - f_post: Any, - ): - # NOTE: this BC has been reformulated to become entirely local and so has differences compared to the original paper. - # Here we use the current time-step populations (f_pre = f_post_collision and f_post = f_post_streaming). - # NOTE: f_aux should contain populations at "x_f" (see their fig 1) in the missign direction of the BC which amounts - # to post-collision values being pulled from appropriate cells like ExtrapolationBC - # - # here I need to compute all terms in Eq (10) - # Strategy: - # 1) "weights" should have been stored somewhere to be used here. 
- # 2) Given "weights", "u_w" (input to the BC) and "u_f" (computed from f_aux), compute "u_target" as per Eq (14) - # NOTE: in the original paper "u_target" is associated with the previous time step not current time. - # 3) Given "weights" use differentiable interpolated BB to find f_missing as I had before: - # fmissing = ((1. - weights) * f_poststreaming_iknown + weights * (f_postcollision_imissing + f_postcollision_iknown)) / (1.0 + weights) - # 4) Add contribution due to u_w to f_missing as is usual in regular Bouzidi BC (ie. -6.0 * self.lattice.w * jnp.dot(self.vel, c) - # 5) Compute rho_target = \sum(f_ibb) based on these values - # 6) Compute feq using feq = self.equilibrium(rho_target, u_target) - # 7) Compute Pi_neq and Pi_eq using all f_post-streaming values as per: - # Pi_neq = self.momentum_flux(fneq) and Pi_eq = self.momentum_flux(feq) - # 8) Compute Grad's appriximation using full equation as in Eq (10) - # NOTE: this is very similar to the regularization procedure. - - _f_nbr = _f_vec() - u_target = _u_vec(0.0, 0.0, 0.0) if _d == 3 else _u_vec(0.0, 0.0) - num_missing = 0 - one = self.compute_dtype(1.0) - for l in range(_q): - # If the mask is missing then take the opposite index - if missing_mask[l] == wp.uint8(1): - # Find the neighbour and its velocity value - for ll in range(_q): - # f_0 is the post-collision values of the current time-step - # Get index associated with the fluid neighbours - fluid_nbr_index = type(index)() - for d in range(_d): - fluid_nbr_index[d] = index[d] + _c[d, l] - # The following is the post-collision values of the fluid neighbor cell - _f_nbr[ll] = self.compute_dtype(f_0[ll, fluid_nbr_index[0], fluid_nbr_index[1], fluid_nbr_index[2]]) - - # Compute the velocity vector at the fluid neighbouring cells - _, u_f = self.macroscopic.warp_functional(_f_nbr) - - # Record the number of missing directions - num_missing += 1 - - # The implicit distance to the boundary or "weights" have been stored in known directions of f_1 - 
weight = f_1[_opp_indices[l], index[0], index[1], index[2]] - - # Given "weights", "u_w" (input to the BC) and "u_f" (computed from f_aux), compute "u_target" as per Eq (14) - for d in range(_d): - u_target[d] += (weight * u_f[d] + _u_wall[d]) / (one + weight) - - # Use differentiable interpolated BB to find f_missing: - f_post[l] = ((one - weight) * f_post[_opp_indices[l]] + weight * (f_pre[l] + f_pre[_opp_indices[l]])) / (one + weight) - - # Add contribution due to moving_wall to f_missing as is usual in regular Bouzidi BC - cu = self.compute_dtype(0.0) - for d in range(_d): - if _c[d, l] == 1: - cu += _u_wall[d] - elif _c[d, l] == -1: - cu -= _u_wall[d] - cu *= self.compute_dtype(-6.0) * _w[l] - f_post[l] += cu - - # Compute rho_target = \sum(f_ibb) based on these values - rho_target = self.zero_moment.warp_functional(f_post) - for d in range(_d): - u_target[d] /= num_missing - - # Compute Grad's appriximation using full equation as in Eq (10) of Dorschner et al. - f_post = grads_approximate_fpop(missing_mask, rho_target, u_target, f_post) - return f_post - - functional = functional_method1 - - kernel = self._construct_kernel(functional) - - return functional, kernel - - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): - # Launch the warp kernel - wp.launch( - self.warp_kernel, - inputs=[f_pre, f_post, bc_mask, missing_mask], - dim=f_pre.shape[1:], - ) - return f_post diff --git a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py index 8ede0c8b..3ee31584 100644 --- a/xlb/operator/boundary_condition/bc_halfway_bounce_back.py +++ b/xlb/operator/boundary_condition/bc_halfway_bounce_back.py @@ -1,5 +1,10 @@ """ -Base class for boundary conditions in a LBM simulation. +Halfway bounce-back boundary condition. 
+ +Implements the standard halfway bounce-back scheme where the no-slip +wall is located halfway between a solid node and a fluid node. +Optionally supports prescribed wall velocity (moving walls) and +interpolated variants that use wall-distance data. """ import jax.numpy as jnp @@ -7,7 +12,8 @@ import jax.lax as lax from functools import partial import warp as wp -from typing import Any +from typing import Any, Union, Tuple, Callable +import numpy as np from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -16,7 +22,9 @@ from xlb.operator.boundary_condition.boundary_condition import ( ImplementationStep, BoundaryCondition, + HelperFunctionsBC, ) +from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod class HalfwayBounceBackBC(BoundaryCondition): @@ -33,6 +41,9 @@ def __init__( compute_backend: ComputeBackend = None, indices=None, mesh_vertices=None, + voxelization_method: MeshVoxelizationMethod = None, + profile: Callable = None, + prescribed_value: Union[float, Tuple[float, ...], np.ndarray] = None, ): # Call the parent constructor super().__init__( @@ -42,24 +53,86 @@ def __init__( compute_backend, indices, mesh_vertices, + voxelization_method, ) # This BC needs padding for finding missing directions when imposed on a geometry that is in the domain interior self.needs_padding = True + # This BC class accepts both constant prescribed values of velocity with keyword "prescribed_value" or + # velocity profiles given by keyword "profile" which must be a callable function. + self.profile = profile + + # A flag to enable moving wall treatment when either "prescribed_value" or "profile" are provided. + self.needs_moving_wall_treatment = False + + if (profile is not None) or (prescribed_value is not None): + self.needs_moving_wall_treatment = True + + # Handle no-slip BCs if neither prescribed_value or profile are provided. 
+ if prescribed_value is None and profile is None: + print(f"WARNING! Assuming no-slip condition for BC type = {self.__class__.__name__}!") + prescribed_value = [0] * self.velocity_set.d + + # Handle prescribed value if provided + if prescribed_value is not None: + if profile is not None: + raise ValueError("Cannot specify both profile and prescribed_value") + + # Ensure prescribed_value is a NumPy array of floats + if isinstance(prescribed_value, (tuple, list, np.ndarray)): + prescribed_value = np.asarray(prescribed_value, dtype=np.float64) + else: + raise ValueError("Velocity prescribed_value must be a tuple, list, or array") + + # Create a constant prescribed profile function + if self.compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON]: + if self.velocity_set.d == 2: + prescribed_value = np.array([prescribed_value[0], prescribed_value[1], 0.0], dtype=np.float64) + prescribed_value = wp.vec(3, dtype=self.compute_dtype)(prescribed_value) + self.profile = self._create_constant_prescribed_profile(prescribed_value) + + def _create_constant_prescribed_profile(self, prescribed_value): + _u_vec = wp.vec(3, dtype=self.compute_dtype) + + @wp.func + def prescribed_profile_warp(index: Any, time: Any): + return _u_vec(prescribed_value[0], prescribed_value[1], prescribed_value[2]) + + def prescribed_profile_jax(): + return jnp.array(prescribed_value, dtype=self.precision_policy.store_precision.jax_dtype).reshape(-1, 1) + + if self.compute_backend == ComputeBackend.JAX: + return prescribed_profile_jax + elif self.compute_backend == ComputeBackend.WARP: + return prescribed_profile_warp + elif self.compute_backend == ComputeBackend.NEON: + return prescribed_profile_warp + @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): boundary = bc_mask == self.id new_shape = (self.velocity_set.q,) + boundary.shape[1:] boundary = lax.broadcast_in_dim(boundary, new_shape, 
tuple(range(self.velocity_set.d + 1))) - return jnp.where( - jnp.logical_and(missing_mask, boundary), - f_pre[self.velocity_set.opp_indices], - f_post, - ) + + # Add contribution due to moving_wall to f_missing + moving_wall_component = 0.0 + if self.needs_moving_wall_treatment: + u_wall = self.profile() + cu = self.velocity_set.w[:, None] * jnp.tensordot(self.velocity_set.c, u_wall, axes=(0, 0)) + cu = cu.reshape((-1,) + (1,) * (len(f_post[1:].shape) - 1)) + moving_wall_component = 6.0 * cu + + # Apply the halfway bounce-back condition + f_post = jnp.where(jnp.logical_and(missing_mask, boundary), f_pre[self.velocity_set.opp_indices] + moving_wall_component, f_post) + + return f_post def _construct_warp(self): + # load helper functions. Explicitly using the WARP backend for helper functions as it may also be called by the Neon backend. + bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=ComputeBackend.WARP) + # Set local constants _opp_indices = self.velocity_set.opp_indices @@ -74,6 +147,9 @@ def functional( f_pre: Any, f_post: Any, ): + # Get wall velocity + u_wall = self.profile(index, timestep) + # Post-streaming values are only modified at missing direction _f = f_post for l in range(self.velocity_set.q): @@ -82,6 +158,10 @@ def functional( # Get the pre-streaming distribution function in oppisite direction _f[l] = f_pre[_opp_indices[l]] + # Add contribution due to moving_wall to f_missing + if wp.static(self.needs_moving_wall_treatment): + _f[l] += bc_helper.moving_wall_fpop_correction(u_wall, l) + return _f kernel = self._construct_kernel(functional) @@ -97,3 +177,12 @@ def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): dim=f_pre.shape[1:], ) return f_post + + def _construct_neon(self): + functional, _ = self._construct_warp() + return functional, None + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, f_pre, f_post, bc_mask, missing_mask): 
+ # raise exception as this feature is not implemented yet + raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.") diff --git a/xlb/operator/boundary_condition/bc_hybrid.py b/xlb/operator/boundary_condition/bc_hybrid.py new file mode 100644 index 00000000..62584160 --- /dev/null +++ b/xlb/operator/boundary_condition/bc_hybrid.py @@ -0,0 +1,391 @@ +""" +Hybrid boundary condition combining interpolated bounce-back with regularization. + +Provides three wall-treatment strategies, selectable via *bc_method*: + +* ``"bounceback_regularized"`` — interpolated bounce-back + Latt regularization. +* ``"bounceback_grads"`` — interpolated bounce-back + Grad's approximation. +* ``"nonequilibrium_regularized"`` — Tao non-equilibrium bounce-back + Latt + regularization. + +All variants optionally support: + +* Moving walls (via *prescribed_value* or *profile*). +* Curved boundaries with fractional distance to the mesh surface (via + *use_mesh_distance*). +""" + +import inspect +from jax import jit +from functools import partial +import warp as wp +from typing import Any, Union, Tuple, Callable +import numpy as np + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.macroscopic import Macroscopic +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.boundary_condition.boundary_condition import ( + ImplementationStep, + BoundaryCondition, + HelperFunctionsBC, +) +from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod + + +class HybridBC(BoundaryCondition): + """ + The hybrid BC methods in this boundary condition have been originally developed by H. Salehipour and are inspired from + various previous publications, in particular [1]. 
The reformulations are aimed to provide local formulations that are + computationally efficient and numerically stable at high Reynolds numbers. + + [1] Dorschner, B., Chikatamarla, S. S., Bösch, F., & Karlin, I. V. (2015). Grad's approximation for moving and + stationary walls in entropic lattice Boltzmann simulations. Journal of Computational Physics, 295, 340-354. + """ + + def __init__( + self, + bc_method, + profile: Callable = None, + prescribed_value: Union[float, Tuple[float, ...], np.ndarray] = None, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + indices=None, + mesh_vertices=None, + voxelization_method: MeshVoxelizationMethod = None, + use_mesh_distance=False, + ): + """ + Parameters + ---------- + bc_method : str + Wall-treatment strategy. One of ``"bounceback_regularized"``, + ``"bounceback_grads"``, or ``"nonequilibrium_regularized"``. + profile : callable, optional + Warp function ``(index) -> u_vec`` or ``(index, timestep) -> u_vec`` + defining the wall velocity. Mutually exclusive with *prescribed_value*. + prescribed_value : float or array-like, optional + Constant wall velocity vector. Mutually exclusive with *profile*. + If neither is given, a no-slip wall is assumed. + velocity_set : VelocitySet, optional + precision_policy : PrecisionPolicy, optional + compute_backend : ComputeBackend, optional + indices : list of array-like, optional + Boundary voxel indices (use this **or** *mesh_vertices*, not both). + mesh_vertices : np.ndarray, optional + Mesh triangle vertices for mesh-based voxelization. + voxelization_method : MeshVoxelizationMethod, optional + Voxelization strategy (AABB, RAY, AABB_CLOSE, etc.). + use_mesh_distance : bool + If ``True``, fractional distances to the mesh surface are + computed and stored for interpolated boundary schemes. 
+ """ + assert bc_method in [ + "bounceback_regularized", + "bounceback_grads", + "nonequilibrium_regularized", + ], f"type = {bc_method} not supported! Use 'bounceback_regularized', 'bounceback_grads' or 'nonequilibrium_regularized'." + self.bc_method = bc_method + + # Call the parent constructor + super().__init__( + ImplementationStep.STREAMING, + velocity_set, + precision_policy, + compute_backend, + indices, + mesh_vertices, + voxelization_method, + ) + + # Raise error if used for 2d examples: + if self.velocity_set.d == 2: + raise NotImplementedError("This BC is not implemented in 2D!") + + # Check if the compute backend is Warp + assert self.compute_backend in (ComputeBackend.WARP, ComputeBackend.NEON), "This BC is currently not supported by JAX backend!" + + # Instantiate the operator for computing macroscopic values + # Explicitly using the WARP backend for these operators as they may also be called by the Neon backend. + self.macroscopic = Macroscopic(compute_backend=ComputeBackend.WARP) + self.equilibrium = QuadraticEquilibrium(compute_backend=ComputeBackend.WARP) + + # Define BC helper functions. Explicitly using the WARP backend for helper functions as it may also be called by the Neon backend. + self.bc_helper = HelperFunctionsBC( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=ComputeBackend.WARP, + distance_decoder_function=self._construct_distance_decoder_function(), + ) + + # A flag to enable moving wall treatment when either "prescribed_value" or "profile" are provided. + self.needs_moving_wall_treatment = False + + if (profile is not None) or (prescribed_value is not None): + self.needs_moving_wall_treatment = True + + # Handle no-slip BCs if neither prescribed_value or profile are provided. + if prescribed_value is None and profile is None: + print(f"WARNING! 
Assuming no-slip condition for BC type = {self.__class__.__name__}_{self.bc_method}!") + prescribed_value = [0, 0, 0] + + # Handle prescribed value if provided + if prescribed_value is not None: + assert profile is None, "Cannot specify both profile and prescribed_value" + + # Ensure prescribed_value is a NumPy array of floats + if isinstance(prescribed_value, (tuple, list, np.ndarray)): + prescribed_value = np.asarray(prescribed_value, dtype=np.float64) + else: + raise ValueError("Velocity prescribed_value must be a tuple, list, or array") + + # Handle 2D velocity sets + if self.velocity_set.d == 2: + assert len(prescribed_value) == 2, "For 2D velocity set, prescribed_value must be a tuple or array of length 2!" + prescribed_value = np.array([prescribed_value[0], prescribed_value[1], 0.0], dtype=np.float64) + + # create a constant prescribed profile + _u_vec = wp.vec(3, dtype=self.compute_dtype) + prescribed_value = _u_vec(prescribed_value) + + @wp.func + def prescribed_profile_warp(index: Any): + return _u_vec(prescribed_value[0], prescribed_value[1], prescribed_value[2]) + + profile = prescribed_profile_warp + + # Inspect the function signature and add time parameter if needed + self.is_time_dependent = False + sig = inspect.signature(profile) + if len(sig.parameters) > 1: + # We assume the profile function takes only the index as input and is hence time-independent. + # In case it is defined with more than 1 input, we assume the second input is time and create + # a wrapper function that also accepts time as a parameter. + self.is_time_dependent = True + + # This BC class accepts both constant prescribed values of velocity with keyword "prescribed_value" or + # velocity profiles given by keyword "profile" which must be a callable function. 
+ self.profile = profile + + # Set whether this BC needs mesh distance + self.needs_mesh_distance = use_mesh_distance + + # This BC needs normalized distance to the mesh + if self.needs_mesh_distance: + # This BC needs auxiliary data recovery after streaming + self.needs_aux_recovery = True + + # If this BC is defined using indices, it would need padding in order to find missing directions + # when imposed on a geometry that is in the domain interior + if self.mesh_vertices is None: + assert self.indices is not None + assert self.needs_mesh_distance is False, 'To use mesh distance, please provide the mesh vertices using keyword "mesh_vertices"!' + assert self.voxelization_method is None, "Voxelization method is only applicable when using mesh vertices!" + self.needs_padding = True + else: + assert self.indices is None, "Cannot use indices with mesh vertices! Please provide mesh vertices only." + + # Define the profile functional + self.profile_functional = self._construct_profile_functional() + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): + raise NotImplementedError(f"Operation {self.__class__.__name__} not implemented in JAX!") + + def _construct_distance_decoder_function(self): + """ + Constructs the distance decoder function for this BC. + """ + # Get the opposite indices for the velocity set + _opp_indices = self.velocity_set.opp_indices + + # Define the distance decoder function for this BC + @wp.func + def distance_decoder_function(f_1: Any, index: Any, direction: Any): + return self.read_field(f_1, index, _opp_indices[direction]) + + return distance_decoder_function + + def _construct_profile_functional(self): + """ + Get the profile functional for this BC. + TODO@Hesam: + Right now, we can impose a profile on a boundary which requires mesh-distance only if that boundary lives on the finest level. 
+ In order to extract "level" from the "neon_field_hdl" we can use the function wp.neon_level(neon_field_hdl). This will allow us + to do the following and get rid of the above limitation. + cIdx = wp.neon_global_idx(field_neon_hdl, index) + gx = wp.neon_get_x(cIdx) // 2 ** level + gy = wp.neon_get_y(cIdx) // 2 ** level + gz = wp.neon_get_z(cIdx) // 2 ** level + """ + + @wp.func + def profile_functional_neon(f_1: Any, index: Any, timestep: Any): + # Convert neon index to warp index + warp_index = self.bc_helper.neon_index_to_warp(f_1, index) + if wp.static(self.is_time_dependent): + return self.profile(warp_index, timestep) + else: + return self.profile(warp_index) + + @wp.func + def profile_functional_warp(f_1: Any, index: Any, timestep: Any): + if wp.static(self.is_time_dependent): + return self.profile(index, timestep) + else: + return self.profile(index) + + return profile_functional_warp if self.compute_backend == ComputeBackend.WARP else profile_functional_neon + + def _construct_warp(self): + # Construct the functionals for this BC + @wp.func + def hybrid_bounceback_regularized( + index: Any, + timestep: Any, + _missing_mask: Any, + f_0: Any, + f_1: Any, + f_pre: Any, + f_post: Any, + ): + # Using regularization technique [1] to represent fpop using macroscopic values derived from interpolated bounceback scheme of [2]. + # missing data in lattice Boltzmann. + # [1] Latt, J., Chopard, B., Malaspinas, O., Deville, M., Michler, A., 2008. Straight velocity + # boundaries in the lattice Boltzmann method. Physical Review E 77, 056703. + # [2] Yu, D., Mei, R., Shyy, W., 2003. A unified boundary treatment in lattice boltzmann method, + # in: 41st aerospace sciences meeting and exhibit, p. 953. 
+ + # Apply interpolated bounceback first to find missing populations at the boundary + u_wall = self.profile_functional(f_1, index, timestep) + f_post = self.bc_helper.interpolated_bounceback( + index, + _missing_mask, + f_0, + f_1, + f_pre, + f_post, + u_wall, + wp.static(self.needs_moving_wall_treatment), + wp.static(self.needs_mesh_distance), + ) + + # Compute density, velocity using all f_post-streaming values + rho, u = self.macroscopic.warp_functional(f_post) + + # Regularize the resulting populations + feq = self.equilibrium.warp_functional(rho, u) + f_post = self.bc_helper.regularize_fpop(f_post, feq) + return f_post + + @wp.func + def hybrid_bounceback_grads( + index: Any, + timestep: Any, + _missing_mask: Any, + f_0: Any, + f_1: Any, + f_pre: Any, + f_post: Any, + ): + # Using Grad's approximation [1] to represent fpop using macroscopic values derived from interpolated bounceback scheme of [2]. + # missing data in lattice Boltzmann. + # [1] Dorschner, B., Chikatamarla, S. S., Bösch, F., & Karlin, I. V. (2015). Grad's approximation for moving and + # stationary walls in entropic lattice Boltzmann simulations. Journal of Computational Physics, 295, 340-354. + # [2] Yu, D., Mei, R., Shyy, W., 2003. A unified boundary treatment in lattice boltzmann method, + # in: 41st aerospace sciences meeting and exhibit, p. 953. + + # Apply interpolated bounceback first to find missing populations at the boundary + u_wall = self.profile_functional(f_1, index, timestep) + f_post = self.bc_helper.interpolated_bounceback( + index, + _missing_mask, + f_0, + f_1, + f_pre, + f_post, + u_wall, + wp.static(self.needs_moving_wall_treatment), + wp.static(self.needs_mesh_distance), + ) + + # Compute density, velocity using all f_post-streaming values + rho, u = self.macroscopic.warp_functional(f_post) + + # Compute Grad's approximation using full equation as in Eq (10) of Dorschner et al. 
+ f_post = self.bc_helper.grads_approximate_fpop(_missing_mask, rho, u, f_post) + return f_post + + @wp.func + def hybrid_nonequilibrium_regularized( + index: Any, + timestep: Any, + _missing_mask: Any, + f_0: Any, + f_1: Any, + f_pre: Any, + f_post: Any, + ): + # This boundary condition uses the method of Tao et al (2018) [1] to get unknown populations on curved boundaries (denoted here by + # interpolated_nonequilibrium_bounceback method). To further stabilize this BC, we add regularization technique of [2]. + # [1] Tao, Shi, et al. "One-point second-order curved boundary condition for lattice Boltzmann simulation of suspended particles." + # Computers & Mathematics with Applications 76.7 (2018): 1593-1607. + # [2] Latt, J., Chopard, B., Malaspinas, O., Deville, M., Michler, A., 2008. Straight velocity + # boundaries in the lattice Boltzmann method. Physical Review E 77, 056703. + + # Apply interpolated bounceback first to find missing populations at the boundary + u_wall = self.profile_functional(f_1, index, timestep) + f_post = self.bc_helper.interpolated_nonequilibrium_bounceback( + index, + _missing_mask, + f_0, + f_1, + f_pre, + f_post, + u_wall, + wp.static(self.needs_moving_wall_treatment), + wp.static(self.needs_mesh_distance), + ) + + # Compute density, velocity using all f_post-streaming values + rho, u = self.macroscopic.warp_functional(f_post) + + # Regularize the resulting populations + feq = self.equilibrium.warp_functional(rho, u) + f_post = self.bc_helper.regularize_fpop(f_post, feq) + return f_post + + if self.bc_method == "bounceback_regularized": + functional = hybrid_bounceback_regularized + elif self.bc_method == "bounceback_grads": + functional = hybrid_bounceback_grads + elif self.bc_method == "nonequilibrium_regularized": + functional = hybrid_nonequilibrium_regularized + + kernel = self._construct_kernel(functional) + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_pre, 
f_post, bc_mask, _missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_pre, f_post, bc_mask, _missing_mask], + dim=f_pre.shape[1:], + ) + return f_post + + def _construct_neon(self): + functional, _ = self._construct_warp() + return functional, None + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, f_pre, f_post, bc_mask, missing_mask): + # raise exception as this feature is not implemented yet + raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.") diff --git a/xlb/operator/boundary_condition/bc_regularized.py b/xlb/operator/boundary_condition/bc_regularized.py index 8d741798..1ea1f84a 100644 --- a/xlb/operator/boundary_condition/bc_regularized.py +++ b/xlb/operator/boundary_condition/bc_regularized.py @@ -1,5 +1,13 @@ """ -Base class for boundary conditions in a LBM simulation. +Regularized boundary condition. + +A non-equilibrium bounce-back scheme with additional regularization of the +distribution function. Applicable as velocity or pressure boundary conditions. + +Reference +--------- +Latt, J. et al. (2008). "Straight velocity boundaries in the lattice +Boltzmann method." *Physical Review E*, 77(5), 056703. 
""" import jax.numpy as jnp @@ -7,7 +15,7 @@ import jax.lax as lax from functools import partial import warp as wp -from typing import Any, Union, Tuple +from typing import Any, Union, Tuple, Callable import numpy as np from xlb.velocity_set.velocity_set import VelocitySet @@ -16,6 +24,7 @@ from xlb.operator.operator import Operator from xlb.operator.boundary_condition import ZouHeBC, HelperFunctionsBC from xlb.operator.macroscopic import SecondMoment as MomentumFlux +from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod class RegularizedBC(ZouHeBC): @@ -43,13 +52,14 @@ class RegularizedBC(ZouHeBC): def __init__( self, bc_type, - profile=None, + profile: Callable = None, prescribed_value: Union[float, Tuple[float, ...], np.ndarray] = None, velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, mesh_vertices=None, + voxelization_method: MeshVoxelizationMethod = None, ): # Call the parent constructor super().__init__( @@ -61,6 +71,7 @@ def __init__( compute_backend, indices, mesh_vertices, + voxelization_method, ) self.momentum_flux = MomentumFlux() @@ -124,18 +135,17 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): return f_post def _construct_warp(self): - # load helper functions - bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend) + # load helper functions. Always use warp backend for helper functions as it may also be called by the Neon backend. 
+ bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=ComputeBackend.WARP) # Set local constants _d = self.velocity_set.d - _q = self.velocity_set.q - _opp_indices = self.velocity_set.opp_indices + lattice_central_index = self.velocity_set.center_index @wp.func def functional_velocity( index: Any, timestep: Any, - missing_mask: Any, + _missing_mask: Any, f_0: Any, f_1: Any, f_pre: Any, @@ -145,16 +155,16 @@ def functional_velocity( _f = f_post # Find normal vector - normals = bc_helper.get_normal_vectors(missing_mask) + normals = bc_helper.get_normal_vectors(_missing_mask) # Find the value of u from the missing directions # Since we are only considering normal velocity, we only need to find one value (stored at the center of f_1) # Create velocity vector by multiplying the prescribed value with the normal vector - prescribed_value = self.compute_dtype(f_1[0, index[0], index[1], index[2]]) + prescribed_value = self.decoder_functional(f_1, index, _missing_mask)[0] _u = -prescribed_value * normals # calculate rho - fsum = bc_helper.get_bc_fsum(_f, missing_mask) + fsum = bc_helper.get_bc_fsum(_f, _missing_mask) unormal = self.compute_dtype(0.0) for d in range(_d): unormal += _u[d] * normals[d] @@ -162,7 +172,7 @@ def functional_velocity( # impose non-equilibrium bounceback feq = self.equilibrium_operator.warp_functional(_rho, _u) - _f = bc_helper.bounceback_nonequilibrium(_f, feq, missing_mask) + _f = bc_helper.bounceback_nonequilibrium(_f, feq, _missing_mask) # Regularize the boundary fpop _f = bc_helper.regularize_fpop(_f, feq) @@ -172,7 +182,7 @@ def functional_velocity( def functional_pressure( index: Any, timestep: Any, - missing_mask: Any, + _missing_mask: Any, f_0: Any, f_1: Any, f_pre: Any, @@ -182,20 +192,20 @@ def functional_pressure( _f = f_post # Find normal vector - normals = bc_helper.get_normal_vectors(missing_mask) + normals = bc_helper.get_normal_vectors(_missing_mask) # Find the value 
of rho from the missing directions # Since we need only one scalar value, we only need to find one value (stored at the center of f_1) - _rho = self.compute_dtype(f_1[0, index[0], index[1], index[2]]) + _rho = self.decoder_functional(f_1, index, _missing_mask)[0] # calculate velocity - fsum = bc_helper.get_bc_fsum(_f, missing_mask) + fsum = bc_helper.get_bc_fsum(_f, _missing_mask) unormal = -self.compute_dtype(1.0) + fsum / _rho _u = unormal * normals # impose non-equilibrium bounceback feq = self.equilibrium_operator.warp_functional(_rho, _u) - _f = bc_helper.bounceback_nonequilibrium(_f, feq, missing_mask) + _f = bc_helper.bounceback_nonequilibrium(_f, feq, _missing_mask) # Regularize the boundary fpop _f = bc_helper.regularize_fpop(_f, feq) diff --git a/xlb/operator/boundary_condition/bc_zouhe.py b/xlb/operator/boundary_condition/bc_zouhe.py index 5cad5048..9865585c 100644 --- a/xlb/operator/boundary_condition/bc_zouhe.py +++ b/xlb/operator/boundary_condition/bc_zouhe.py @@ -1,5 +1,14 @@ """ -Base class for boundary conditions in a LBM simulation. +Zou-He boundary condition. + +Sets unknown populations at velocity or pressure boundaries using +mass and momentum conservation combined with non-equilibrium +bounce-back. Commonly used for inlets and outlets. + +Reference +--------- +Zou, Q. & He, X. (1997). "On pressure and velocity boundary conditions +for the lattice Boltzmann BGK model." *Physics of Fluids*, 9(6), 1591. 
""" import jax.numpy as jnp @@ -7,7 +16,7 @@ import jax.lax as lax from functools import partial import warp as wp -from typing import Any, Union, Tuple +from typing import Any, Union, Tuple, Callable import numpy as np from xlb.velocity_set.velocity_set import VelocitySet @@ -20,6 +29,8 @@ ) from xlb.operator.boundary_condition import HelperFunctionsBC from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod +from xlb.operator.boundary_condition.helper_functions_bc import EncodeAuxiliaryData class ZouHeBC(BoundaryCondition): @@ -37,20 +48,20 @@ class ZouHeBC(BoundaryCondition): def __init__( self, bc_type, - profile=None, + profile: Callable = None, prescribed_value: Union[float, Tuple[float, ...], np.ndarray] = None, velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, indices=None, mesh_vertices=None, + voxelization_method: MeshVoxelizationMethod = None, ): # Important Note: it is critical to add id inside __init__ for this BC because different instantiations of this BC # may have different types (velocity or pressure). assert bc_type in ["velocity", "pressure"], f"type = {bc_type} not supported! Use 'pressure' or 'velocity'." self.bc_type = bc_type self.equilibrium_operator = QuadraticEquilibrium() - self.profile = profile # Call the parent constructor super().__init__( @@ -60,28 +71,29 @@ def __init__( compute_backend, indices, mesh_vertices, + voxelization_method, ) + # This BC class accepts both constant prescribed values of velocity with keyword "prescribed_value" or + # velocity profiles given by keyword "profile" which must be a callable function. 
+ self.profile = profile + # Handle prescribed value if provided if prescribed_value is not None: if profile is not None: raise ValueError("Cannot specify both profile and prescribed_value") - # Convert input to numpy array for validation - if isinstance(prescribed_value, (tuple, list)): - prescribed_value = np.array(prescribed_value, dtype=np.float64) - elif isinstance(prescribed_value, (int, float)): - if bc_type == "pressure": + # Ensure prescribed_value is a NumPy array of floats + if bc_type == "velocity": + if isinstance(prescribed_value, (tuple, list, np.ndarray)): + prescribed_value = np.asarray(prescribed_value, dtype=np.float64) + else: + raise ValueError("Velocity prescribed_value must be a tuple, list, or array-like") + elif bc_type == "pressure": + if isinstance(prescribed_value, (int, float)): prescribed_value = float(prescribed_value) else: - raise ValueError("Velocity prescribed_value must be a tuple or array") - elif isinstance(prescribed_value, np.ndarray): - prescribed_value = prescribed_value.astype(np.float64) - - # Validate prescribed value - if bc_type == "velocity": - if not isinstance(prescribed_value, np.ndarray): - raise ValueError("Velocity prescribed_value must be an array-like") + raise ValueError("Pressure prescribed_value must be a scalar (int or float)") # Check for non-zero elements - only one element should be non-zero non_zero_count = np.count_nonzero(prescribed_value) @@ -93,21 +105,38 @@ def __init__( # a single non-zero number associated with pressure BC OR # a vector of zeros associated with no-slip BC. # Accounting for all scenarios here. 
- if self.compute_backend is ComputeBackend.WARP: + if self.compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON]: idx = np.nonzero(prescribed_value)[0] prescribed_value = prescribed_value[idx][0] if idx.size else 0.0 prescribed_value = self.precision_policy.store_precision.wp_dtype(prescribed_value) self.prescribed_value = prescribed_value self.profile = self._create_constant_prescribed_profile() - # This BC needs auxilary data initialization before streaming - self.needs_aux_init = True + if self.compute_backend == ComputeBackend.JAX: + self.prescribed_values = self.profile() + else: + # This BC needs auxiliary data initialization before streaming + self.needs_aux_init = True + + # This BC needs auxiliary data recovery after streaming + self.needs_aux_recovery = True - # This BC needs auxilary data recovery after streaming - self.needs_aux_recovery = True + # This BC needs one auxiliary data for the density or normal velocity + self.num_of_aux_data = 1 - # This BC needs one auxilary data for the density or normal velocity - self.num_of_aux_data = 1 + # Create the encoder operator for storing the auxiliary data + encode_auxiliary_data = EncodeAuxiliaryData( + self.id, + self.num_of_aux_data, + self.profile, + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + + # get decoder functional + functional_dict, _ = encode_auxiliary_data._construct_warp() + self.decoder_functional = functional_dict["decoder"] # This BC needs padding for finding missing directions when imposed on a geometry that is in the domain interior self.needs_padding = True @@ -126,6 +155,8 @@ def prescribed_profile_jax(): return prescribed_profile_jax elif self.compute_backend == ComputeBackend.WARP: return prescribed_profile_warp + elif self.compute_backend == ComputeBackend.NEON: + return prescribed_profile_warp @partial(jit, static_argnums=(0,), inline=True) def _get_known_middle_mask(self, missing_mask): @@ -269,12 
+300,11 @@ def jax_implementation(self, f_pre, f_post, bc_mask, missing_mask): return f_post def _construct_warp(self): - # load helper functions - bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend) + # load helper functions. Always use warp backend for helper functions as it may also be called by the Neon backend. + bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=ComputeBackend.WARP) + # Set local constants _d = self.velocity_set.d - _q = self.velocity_set.q - _opp_indices = self.velocity_set.opp_indices @wp.func def functional_velocity( @@ -299,7 +329,7 @@ def functional_velocity( # Find the value of u from the missing directions # Since we are only considering normal velocity, we only need to find one value (stored at the center of f_1) # Create velocity vector by multiplying the prescribed value with the normal vector - prescribed_value = f_1[0, index[0], index[1], index[2]] + prescribed_value = self.decoder_functional(f_1, index, _missing_mask)[0] _u = -prescribed_value * normals for d in range(_d): @@ -330,7 +360,7 @@ def functional_pressure( # Find the value of rho from the missing directions # Since we need only one scalar value, we only need to find one value (stored at the center of f_1) - _rho = f_1[0, index[0], index[1], index[2]] + _rho = self.decoder_functional(f_1, index, _missing_mask)[0] # calculate velocity fsum = bc_helper.get_bc_fsum(_f, _missing_mask) @@ -360,3 +390,15 @@ def warp_implementation(self, f_pre, f_post, bc_mask, missing_mask): dim=f_pre.shape[1:], ) return f_post + + def _construct_neon(self): + # Redefine the quadratic eq operator for the neon backend + # This is because the neon backend relies on the warp functionals for its operations. 
+ self.equilibrium_operator = QuadraticEquilibrium(compute_backend=ComputeBackend.WARP) + functional, _ = self._construct_warp() + return functional, None + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, f_pre, f_post, bc_mask, missing_mask): + # raise exception as this feature is not implemented yet + raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.") diff --git a/xlb/operator/boundary_condition/boundary_condition.py b/xlb/operator/boundary_condition/boundary_condition.py index 0b8a93c6..4ac2a96a 100644 --- a/xlb/operator/boundary_condition/boundary_condition.py +++ b/xlb/operator/boundary_condition/boundary_condition.py @@ -1,5 +1,9 @@ """ -Base class for boundary conditions in a LBM simulation. +Base class for boundary conditions in a Lattice Boltzmann simulation. + +Every concrete BC inherits from :class:`BoundaryCondition`, which provides +a registration mechanism, helper-function access, and the boilerplate +needed to encode auxiliary data into the ``f_1`` buffer. """ from enum import Enum, auto @@ -7,8 +11,7 @@ from typing import Any from jax import jit from functools import partial -import jax -import jax.numpy as jnp +import numpy as np from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -17,17 +20,39 @@ from xlb import DefaultConfig from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry from xlb.operator.boundary_condition import HelperFunctionsBC +from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod -# Enum for implementation step class ImplementationStep(Enum): + """At which algorithmic stage the boundary condition is applied.""" + COLLISION = auto() STREAMING = auto() class BoundaryCondition(Operator): - """ - Base class for boundary conditions in a LBM simulation. + """Abstract base class for all LBM boundary conditions. 
+ + Each BC is registered with a unique numeric *id* and annotated with: + + * ``implementation_step`` - whether it executes after streaming or after + collision. + * ``needs_aux_recovery`` / ``needs_aux_init`` - whether the BC stores + auxiliary data in the ``f_1`` distribution buffer. + + Parameters + ---------- + implementation_step : ImplementationStep + Phase in the LBM algorithm where this BC is applied. + velocity_set : VelocitySet, optional + precision_policy : PrecisionPolicy, optional + compute_backend : ComputeBackend, optional + indices : array-like, optional + Explicit voxel indices for this BC. + mesh_vertices : array-like, optional + Mesh vertices for geometry-based BCs. + voxelization_method : MeshVoxelizationMethod, optional + Voxelization strategy when *mesh_vertices* is provided. """ def __init__( @@ -38,6 +63,7 @@ def __init__( compute_backend: ComputeBackend = None, indices=None, mesh_vertices=None, + voxelization_method: MeshVoxelizationMethod = None, ): self.id = boundary_condition_registry.register_boundary_condition(self.__class__.__name__ + "_" + str(hash(self))) velocity_set = velocity_set or DefaultConfig.velocity_set @@ -54,28 +80,65 @@ def __init__( self.implementation_step = implementation_step # A flag to indicate whether bc indices need to be padded in both normal directions to identify missing directions - # when inside/outside of the geoemtry is not known + # when inside/outside of the geometry is not known self.needs_padding = False - # A flag for BCs that need implicit boundary distance between the grid and a mesh (to be set to True if applicable inside each BC) + # A flag for BCs that need normalized distance between the grid and a mesh (to be set to True if applicable inside each BC) self.needs_mesh_distance = False - # A flag for BCs that need auxilary data initialization before stepper + # A flag for BCs that need auxiliary data initialization before stepper self.needs_aux_init = False - # A flag to track if the BC is 
initialized with auxilary data + # A flag to track if the BC is initialized with auxiliary data self.is_initialized_with_aux_data = False - # Number of auxilary data needed for the BC (for prescribed values) + # Number of auxiliary data needed for the BC (for prescribed values) self.num_of_aux_data = 0 - # A flag for BCs that need auxilary data recovery after streaming + # A flag for BCs that need auxiliary data recovery after streaming self.needs_aux_recovery = False + # Voxelization method. For BC's specified on a mesh, the user can specify the voxelization scheme. + # Currently we support three methods based on (a) aabb method (b) ray casting and (c) winding number. + self.voxelization_method = voxelization_method + + # Construct a default warp functional for assembling auxiliary data if needed + if self.compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON]: + + @wp.func + def assemble_auxiliary_data( + index: Any, + timestep: Any, + missing_mask: Any, + f_0: Any, + f_1: Any, + f_pre: Any, + f_post: Any, + level: Any = 0, + ): + return f_post + + self.assemble_auxiliary_data = assemble_auxiliary_data + + def pad_indices(self): + """ + This method pads the indices to ensure that the boundary condition can be applied correctly. + It is used to find missing directions in indices_boundary_masker when the BC is imposed on a + geometry that is in the domain interior. 
+ """ + _d = self.velocity_set.d + bc_indices = np.array(self.indices) + lattice_velocity_np = self.velocity_set._c + if self.needs_padding: + bc_indices_padded = bc_indices[:, :, None] + lattice_velocity_np[:, None, :] + return np.unique(bc_indices_padded.reshape(_d, -1), axis=1) + else: + return bc_indices + @partial(jit, static_argnums=(0,), inline=True) - def update_bc_auxilary_data(self, f_pre, f_post, bc_mask, missing_mask): + def assemble_auxiliary_data(self, f_pre, f_post, bc_mask, missing_mask): """ - A placeholder function for prepare the auxilary distribution functions for the boundary condition. + A placeholder function for prepare the auxiliary distribution functions for the boundary condition. currently being called after collision only. """ return f_post @@ -94,14 +157,14 @@ def kernel( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), bc_mask: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), + missing_mask: wp.array4d(dtype=wp.uint8), ): # Get the global index i, j, k = wp.tid() index = wp.vec3i(i, j, k) # read tid data - _f_pre, _f_post, _boundary_id, _missing_mask = bc_helper.get_thread_data(f_pre, f_post, bc_mask, missing_mask, index) + _f_pre, _f_post, _boundary_id, _missing_mask = bc_helper.get_bc_thread_data(f_pre, f_post, bc_mask, missing_mask, index) # Apply the boundary condition if _boundary_id == _id: @@ -115,61 +178,3 @@ def kernel( f_post[l, index[0], index[1], index[2]] = self.store_dtype(_f[l]) return kernel - - def _construct_aux_data_init_kernel(self, functional): - """ - Constructs the warp kernel for the auxilary data recovery. 
- """ - bc_helper = HelperFunctionsBC(velocity_set=self.velocity_set, precision_policy=self.precision_policy, compute_backend=self.compute_backend) - - _id = wp.uint8(self.id) - _opp_indices = self.velocity_set.opp_indices - _num_of_aux_data = self.num_of_aux_data - - # Construct the warp kernel - @wp.kernel - def aux_data_init_kernel( - f_0: wp.array4d(dtype=Any), - f_1: wp.array4d(dtype=Any), - bc_mask: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), - ): - # Get the global index - i, j, k = wp.tid() - index = wp.vec3i(i, j, k) - - # read tid data - _f_0, _f_1, _boundary_id, _missing_mask = bc_helper.get_thread_data(f_0, f_1, bc_mask, missing_mask, index) - - # Apply the functional - if _boundary_id == _id: - # prescribed_values is a q-sized vector of type wp.vec - prescribed_values = functional(index) - # Write the result for all q directions, but only store up to num_of_aux_data - # TODO: Somehow raise an error if the number of prescribed values does not match the number of missing directions - - # The first BC auxiliary data is stored in the zero'th index of f_1 associated with its center. - f_1[0, index[0], index[1], index[2]] = self.store_dtype(prescribed_values[0]) - counter = wp.int32(1) - - # The other remaining BC auxiliary data are stored in missing directions of f_1. 
- for l in range(1, self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1) and counter < _num_of_aux_data: - f_1[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(prescribed_values[counter]) - counter += 1 - - return aux_data_init_kernel - - def aux_data_init(self, f_0, f_1, bc_mask, missing_mask): - if self.compute_backend == ComputeBackend.WARP: - # Launch the warp kernel - wp.launch( - self._construct_aux_data_init_kernel(self.profile), - inputs=[f_0, f_1, bc_mask, missing_mask], - dim=f_0.shape[1:], - ) - elif self.compute_backend == ComputeBackend.JAX: - # We don't use boundary aux encoding/decoding in JAX - self.prescribed_values = self.profile() - self.is_initialized_with_aux_data = True - return f_0, f_1 diff --git a/xlb/operator/boundary_condition/helper_functions_bc.py b/xlb/operator/boundary_condition/helper_functions_bc.py index 6f8e768b..c613573b 100644 --- a/xlb/operator/boundary_condition/helper_functions_bc.py +++ b/xlb/operator/boundary_condition/helper_functions_bc.py @@ -1,11 +1,44 @@ -from xlb import DefaultConfig, ComputeBackend -from xlb.operator.macroscopic.second_moment import SecondMoment as MomentumFlux +""" +Warp/Neon helper functions shared by multiple boundary conditions. + +:class:`HelperFunctionsBC` exposes ``@wp.func`` helpers for bounce-back, +regularization, Grad's approximation, moving-wall corrections, +interpolated BCs, and BC thread-data loading. These are used as building +blocks by the concrete BC classes. + +Also contains :class:`EncodeAuxiliaryData` and +:class:`MultiresEncodeAuxiliaryData` operators for writing user-prescribed +BC profiles into the ``f_1`` buffer during initialization. 
+""" + +import inspect +from typing import Any, Callable + import warp as wp -from typing import Any + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb import DefaultConfig, ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.macroscopic import SecondMoment as MomentumFlux +from xlb.operator.macroscopic import Macroscopic +from xlb.operator.equilibrium import QuadraticEquilibrium class HelperFunctionsBC(object): - def __init__(self, velocity_set=None, precision_policy=None, compute_backend=None): + """Collection of Warp/Neon ``@wp.func`` helpers for boundary conditions. + + Parameters + ---------- + velocity_set : VelocitySet, optional + precision_policy : PrecisionPolicy, optional + compute_backend : ComputeBackend, optional + Must be ``WARP`` or ``NEON`` (JAX not supported). + distance_decoder_function : callable, optional + Function to decode wall-distance data for interpolated BCs. + """ + + def __init__(self, velocity_set=None, precision_policy=None, compute_backend=None, distance_decoder_function=None): if compute_backend == ComputeBackend.JAX: raise ValueError("This helper class contains helper functions only for the WARP implementation of some BCs not JAX!") @@ -13,6 +46,7 @@ def __init__(self, velocity_set=None, precision_policy=None, compute_backend=Non self.velocity_set = velocity_set or DefaultConfig.velocity_set self.precision_policy = precision_policy or DefaultConfig.default_precision_policy self.compute_backend = compute_backend or DefaultConfig.default_backend + self.distance_decoder_function = distance_decoder_function # Set the compute and Store dtypes compute_dtype = self.precision_policy.compute_precision.wp_dtype @@ -30,15 +64,21 @@ def __init__(self, velocity_set=None, precision_policy=None, compute_backend=Non _f_vec = wp.vec(_q, dtype=compute_dtype) _missing_mask_vec = wp.vec(_q, dtype=wp.uint8) # TODO fix vec bool + # Define the operator needed for 
computing equilibrium + equilibrium = QuadraticEquilibrium(velocity_set, precision_policy, compute_backend) + + # Define the operator needed for computing macroscopic variables + macroscopic = Macroscopic(velocity_set, precision_policy, compute_backend) + # Define the operator needed for computing the momentum flux momentum_flux = MomentumFlux(velocity_set, precision_policy, compute_backend) @wp.func - def get_thread_data( + def get_bc_thread_data( f_pre: wp.array4d(dtype=Any), f_post: wp.array4d(dtype=Any), bc_mask: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), + missing_mask: wp.array4d(dtype=wp.uint8), index: wp.vec3i, ): # Get the boundary id and missing mask @@ -58,41 +98,62 @@ def get_thread_data( _missing_mask[l] = wp.uint8(0) return _f_pre, _f_post, _boundary_id, _missing_mask + @wp.func + def neon_get_bc_thread_data( + f_pre_pn: Any, + f_post_pn: Any, + bc_mask_pn: Any, + missing_mask_pn: Any, + index: Any, + ): + # Get the boundary id and missing mask + _f_pre = _f_vec() + _f_post = _f_vec() + _boundary_id = wp.neon_read(bc_mask_pn, index, 0) + _missing_mask = _missing_mask_vec() + for l in range(_q): + # q-sized vector of populations + _f_pre[l] = compute_dtype(wp.neon_read(f_pre_pn, index, l)) + _f_post[l] = compute_dtype(wp.neon_read(f_post_pn, index, l)) + _missing_mask[l] = wp.neon_read(missing_mask_pn, index, l) + + return _f_pre, _f_post, _boundary_id, _missing_mask + @wp.func def get_bc_fsum( fpop: Any, - missing_mask: Any, + _missing_mask: Any, ): fsum_known = compute_dtype(0.0) fsum_middle = compute_dtype(0.0) for l in range(_q): - if missing_mask[_opp_indices[l]] == wp.uint8(1): + if _missing_mask[_opp_indices[l]] == wp.uint8(1): fsum_known += compute_dtype(2.0) * fpop[l] - elif missing_mask[l] != wp.uint8(1): + elif _missing_mask[l] != wp.uint8(1): fsum_middle += fpop[l] return fsum_known + fsum_middle @wp.func def get_normal_vectors( - missing_mask: Any, + _missing_mask: Any, ): if wp.static(_d == 3): for l in 
range(_q): - if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: + if _missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) + wp.abs(_c[2, l]) == 1: return -_u_vec(_c_float[0, l], _c_float[1, l], _c_float[2, l]) else: for l in range(_q): - if missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: + if _missing_mask[l] == wp.uint8(1) and wp.abs(_c[0, l]) + wp.abs(_c[1, l]) == 1: return -_u_vec(_c_float[0, l], _c_float[1, l]) @wp.func def bounceback_nonequilibrium( fpop: Any, feq: Any, - missing_mask: Any, + _missing_mask: Any, ): for l in range(_q): - if missing_mask[l] == wp.uint8(1): + if _missing_mask[l] == wp.uint8(1): fpop[l] = fpop[_opp_indices[l]] + feq[l] - feq[_opp_indices[l]] return fpop @@ -121,8 +182,461 @@ def regularize_fpop( fpop[l] = feq[l] + fpop1 return fpop - self.get_thread_data = get_thread_data + @wp.func + def grads_approximate_fpop( + _missing_mask: Any, + rho: Any, + u: Any, + f_post: Any, + ): + # Purpose: Using Grad's approximation to represent fpop based on macroscopic inputs used for outflow [1] and + # Dirichlet BCs [2] + # [1] S. Chikatamarla, S. Ansumali, and I. Karlin, "Grad's approximation for missing data in lattice Boltzmann + # simulations", Europhys. Lett. 74, 215 (2006). + # [2] Dorschner, B., Chikatamarla, S. S., Bösch, F., & Karlin, I. V. (2015). Grad's approximation for moving and + # stationary walls in entropic lattice Boltzmann simulations. Journal of Computational Physics, 295, 340-354. + + # Note: See also self.regularize_fpop function which is somewhat similar. 
+ + # Compute pressure tensor Pi using all f_post-streaming values + Pi = momentum_flux.warp_functional(f_post) + + # Compute double dot product Qi:Pi1 (where Pi1 = PiNeq) + nt = _d * (_d + 1) // 2 + for l in range(_q): + if _missing_mask[l] == wp.uint8(1): + # compute dot product of qi and Pi + QiPi = compute_dtype(0.0) + for t in range(nt): + if t == 0 or t == 3 or t == 5: + QiPi += _qi[l, t] * (Pi[t] - rho / compute_dtype(3.0)) + else: + QiPi += _qi[l, t] * Pi[t] + + # Compute c.u + cu = compute_dtype(0.0) + for d in range(_d): + if _c[d, l] == 1: + cu += u[d] + elif _c[d, l] == -1: + cu -= u[d] + cu *= compute_dtype(3.0) + + # change f_post using the Grad's approximation + f_post[l] = rho * _w[l] * (compute_dtype(1.0) + cu) + _w[l] * compute_dtype(4.5) * QiPi + + return f_post + + @wp.func + def moving_wall_fpop_correction( + u_wall: Any, + lattice_direction: Any, + ): + # Add forcing term necessary to account for the local density changes caused by the mass displacement + # as the object moves with velocity u_wall. + # [1] L.-S. Luo, Unified theory of lattice Boltzmann models for nonideal gases, Phys. Rev. Lett. 81 (1998) 1618-1621. + # [2] L.-S. Luo, Theory of the lattice Boltzmann method: Lattice Boltzmann models for nonideal gases, Phys. Rev. E 62 (2000) 4982-4996. + # + # Note: this function must be called within a for-loop over all lattice directions and the populations to be modified must + # be only those in the missing direction (the check for missing direction must be outside of this function). 
+ cu = compute_dtype(0.0) + l = lattice_direction + for d in range(_d): + if _c[d, l] == 1: + cu += u_wall[d] + elif _c[d, l] == -1: + cu -= u_wall[d] + cu *= compute_dtype(6.0) * _w[l] + return cu + + @wp.func + def interpolated_bounceback( + index: Any, + _missing_mask: Any, + f_0: Any, + f_1: Any, + f_pre: Any, + f_post: Any, + u_wall: Any, + needs_moving_wall_treatment: bool, + needs_mesh_distance: bool, + ): + # A local single-node version of the interpolated bounce-back boundary condition due to Bouzidi for a lattice + # Boltzmann method simulation. + # Ref: + # [1] Yu, D., Mei, R., Shyy, W., 2003. A unified boundary treatment in lattice boltzmann method, + # in: 41st aerospace sciences meeting and exhibit, p. 953. + + one = compute_dtype(1.0) + for l in range(_q): + # If the mask is missing then take the opposite index + if _missing_mask[l] == wp.uint8(1): + # The normalized distance to the mesh or "weights" have been stored in known directions of f_1 + if needs_mesh_distance: + # use weights associated with curved boundaries that are properly stored in f_1. + weight = compute_dtype(self.distance_decoder_function(f_1, index, l)) + + # Use differentiable interpolated BB to find f_missing: + f_post[l] = ((one - weight) * f_post[_opp_indices[l]] + weight * (f_pre[l] + f_pre[_opp_indices[l]])) / (one + weight) + else: + # Use regular halfway bounceback + f_post[l] = f_pre[_opp_indices[l]] + + if _missing_mask[_opp_indices[l]] == wp.uint8(1): + # These are cases where the boundary is sandwiched between 2 solid cells and so both opposite directions are missing. 
+ f_post[l] = f_pre[_opp_indices[l]] + + # Add contribution due to moving_wall to f_missing as is usual in regular Bouzidi BC + if needs_moving_wall_treatment: + f_post[l] += moving_wall_fpop_correction(u_wall, l) + return f_post + + @wp.func + def interpolated_nonequilibrium_bounceback( + index: Any, + _missing_mask: Any, + f_0: Any, + f_1: Any, + f_pre: Any, + f_post: Any, + u_wall: Any, + needs_moving_wall_treatment: bool, + needs_mesh_distance: bool, + ): + # Compute density, velocity using all f_post-collision values + rho, u = macroscopic.warp_functional(f_pre) + feq = equilibrium.warp_functional(rho, u) + + # Compute equilibrium distribution at the wall + if needs_moving_wall_treatment: + feq_wall = equilibrium.warp_functional(rho, u_wall) + else: + feq_wall = _f_vec() + + # Apply method in Tao et al (2018) [1] to find missing populations at the boundary + one = compute_dtype(1.0) + for l in range(_q): + # If the mask is missing then take the opposite index + if _missing_mask[l] == wp.uint8(1): + # The normalized distance to the mesh or "weights" have been stored in known directions of f_1 + if needs_mesh_distance: + # use weights associated with curved boundaries that are properly stored in f_1. 
+ weight = compute_dtype(self.distance_decoder_function(f_1, index, l)) + else: + weight = compute_dtype(0.5) + + # Use non-equilibrium bounceback to find f_missing: + fneq = f_pre[_opp_indices[l]] - feq[_opp_indices[l]] + + # Compute equilibrium distribution at the wall + # Same quadratic equilibrium but accounting for zero velocity (no-slip) + if not needs_moving_wall_treatment: + feq_wall[l] = _w[l] * rho + + # Assemble wall population for doing interpolation at the boundary + f_wall = feq_wall[l] + fneq + f_post[l] = (f_wall + weight * f_pre[l]) / (one + weight) + + return f_post + + @wp.func + def neon_index_to_warp(neon_field_hdl: Any, index: Any): + # Unpack the global index in Neon at the finest level and convert it to a warp vector + cIdx = wp.neon_global_idx(neon_field_hdl, index) + gx = wp.neon_get_x(cIdx) + gy = wp.neon_get_y(cIdx) + gz = wp.neon_get_z(cIdx) + + # XLB is flattening the z dimension in 3D, while neon uses the y dimension + if _d == 2: + gy, gz = gz, gy + + # Get warp indices + index_wp = wp.vec3i(gx, gy, gz) + return index_wp + + self.get_bc_thread_data = get_bc_thread_data self.get_bc_fsum = get_bc_fsum self.get_normal_vectors = get_normal_vectors self.bounceback_nonequilibrium = bounceback_nonequilibrium self.regularize_fpop = regularize_fpop + self.grads_approximate_fpop = grads_approximate_fpop + self.moving_wall_fpop_correction = moving_wall_fpop_correction + self.interpolated_bounceback = interpolated_bounceback + self.interpolated_nonequilibrium_bounceback = interpolated_nonequilibrium_bounceback + self.neon_get_bc_thread_data = neon_get_bc_thread_data + self.neon_index_to_warp = neon_index_to_warp + + +class EncodeAuxiliaryData(Operator): + """ + Operator for encoding boundary auxiliary data during initialization. 
+ """ + + def __init__( + self, + boundary_id: int, + num_of_aux_data: int, + user_defined_functional: Callable, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + ): + self.user_defined_functional = user_defined_functional + self.boundary_id = wp.uint8(boundary_id) + self.num_of_aux_data = num_of_aux_data + + super().__init__(velocity_set, precision_policy, compute_backend) + + # Inspect the signature of the user-defined functional. + # We assume the profile function takes only the index as input and is hence time-independent. + sig = inspect.signature(user_defined_functional) + assert self.compute_backend != ComputeBackend.JAX, "Encoding/decoding of auxiliary data are not required for boundary conditions in JAX" + assert len(sig.parameters) == 1, f"User-defined functional must take exactly one argument (the index), it received {len(sig.parameters)}." + + # Define a HelperFunctionsBC instance + self.bc_helper = HelperFunctionsBC( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) + + # TODO: Somehow raise an error if the number of prescribed values does not match the number of missing directions + + def _construct_warp(self): + """ + Constructs the warp kernel for the auxiliary data recovery. 
+ """ + # Find velocity index for (0, 0, 0) + lattice_central_index = self.velocity_set.center_index + _opp_indices = self.velocity_set.opp_indices + _id = self.boundary_id + _num_of_aux_data = self.num_of_aux_data + _aux_vec = wp.vec(_num_of_aux_data, dtype=self.compute_dtype) + + @wp.func + def encoder_functional( + index: Any, + _missing_mask: Any, + field_storage: Any, + prescribed_values: Any, + ): + if len(prescribed_values) != _num_of_aux_data: + wp.printf("Error: User-defined profile must return a vector of size %d\n", _num_of_aux_data) + return + + # Visit all q directions and store prescribed_values in visiting order (indexed by counter, the same order decoder_functional reads them back), up to _num_of_aux_data + counter = wp.int32(0) + for l in range(self.velocity_set.q): + # Only store up to _num_of_aux_data + if counter == _num_of_aux_data: + return + + if l == lattice_central_index: + # The first BC auxiliary data is stored in the zero'th index of f_1 associated with its center. + self.write_field(field_storage, index, l, self.store_dtype(prescribed_values[counter])) + counter += 1 + elif _missing_mask[l] == wp.uint8(1): + # The other remaining BC auxiliary data are stored in missing directions of f_1. + self.write_field(field_storage, index, _opp_indices[l], self.store_dtype(prescribed_values[counter])) + counter += 1 + + @wp.func + def decoder_functional( + field_storage: Any, + index: Any, + _missing_mask: Any, + ): + """ + Decode the encoded values needed for the boundary condition treatment from the center location in field_storage. + """ + + # Define a vector to hold prescribed_values + prescribed_values = _aux_vec() + + # Read all q directions, but only retrieve up to _num_of_aux_data + counter = wp.int32(0) + for l in range(self.velocity_set.q): + # Only retrieve up to _num_of_aux_data + if counter == _num_of_aux_data: + return prescribed_values + + if l == lattice_central_index: + # The first BC auxiliary data is stored in the zero'th index of f_1 associated with its center. 
+ value = self.read_field(field_storage, index, l) + prescribed_values[counter] = self.compute_dtype(value) + counter += 1 + elif _missing_mask[l] == wp.uint8(1): + # The other remaining BC auxiliary data are stored in missing directions of f_1. + value = self.read_field(field_storage, index, _opp_indices[l]) + prescribed_values[counter] = self.compute_dtype(value) + counter += 1 + + # Construct the warp kernel + @wp.kernel + def kernel( + f_1: wp.array4d(dtype=Any), + bc_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.uint8), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # read tid data + _, _, _boundary_id, _missing_mask = self.bc_helper.get_bc_thread_data(f_1, f_1, bc_mask, missing_mask, index) + + # Apply the functional + # change this to use central location + if _boundary_id == _id: + # prescribed_values is a wp.vec with num_of_aux_data entries (encoder_functional rejects any other length) + prescribed_values = self.user_defined_functional(index) + + # call the functional + encoder_functional(index, _missing_mask, f_1, prescribed_values) + + functional_dict = {"encoder": encoder_functional, "decoder": decoder_functional} + return functional_dict, kernel + + def _construct_neon(self): + import neon + + """ + Constructs the Neon container for encoding auxiliary data recovery. 
+ """ + # Use the warp functional for the Neon backend + functional_dict, _ = self._construct_warp() + encoder_functional = functional_dict["encoder"] + _id = self.boundary_id + + # Construct the Neon container + @neon.Container.factory(name="EncodingAuxData_" + str(_id)) + def aux_data_init_container( + f_1: Any, + bc_mask: Any, + missing_mask: Any, + ): + def aux_data_init_ll(loader: neon.Loader): + loader.set_grid(f_1.get_grid()) + + f_1_pn = loader.get_write_handle(f_1) + bc_mask_pn = loader.get_read_handle(bc_mask) + missing_mask_pn = loader.get_read_handle(missing_mask) + + @wp.func + def aux_data_init_cl(index: Any): + # read tid data + _, _, _boundary_id, _missing_mask = self.bc_helper.neon_get_bc_thread_data(f_1_pn, f_1_pn, bc_mask_pn, missing_mask_pn, index) + + # Apply the functional + if _boundary_id == _id: + warp_index = self.bc_helper.neon_index_to_warp(f_1_pn, index) + prescribed_values = self.user_defined_functional(warp_index) + + # Call the functional + encoder_functional(index, _missing_mask, f_1_pn, prescribed_values) + + # Declare the kernel in the Neon loader + loader.declare_kernel(aux_data_init_cl) + + return aux_data_init_ll + + return functional_dict, aux_data_init_container + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, f_1, bc_mask, missing_mask): + # Launch the warp kernel + wp.launch( + self.warp_kernel, + inputs=[f_1, bc_mask, missing_mask], + dim=f_1.shape[1:], + ) + return f_1 + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, f_1, bc_mask, missing_mask): + c = self.neon_container(f_1, bc_mask, missing_mask) + c.run(0, container_runtime=neon.Container.ContainerRuntime.neon) + return f_1 + + +class MultiresEncodeAuxiliaryData(EncodeAuxiliaryData): + """ + Operator for encoding boundary auxiliary data during initialization. 
+ """ + + def __init__( + self, + boundary_id: int, + num_of_aux_data: int, + user_defined_functional: Callable, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + ): + super().__init__( + boundary_id=boundary_id, + num_of_aux_data=num_of_aux_data, + user_defined_functional=user_defined_functional, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) + + assert self.compute_backend == ComputeBackend.NEON, f"Operator {self.__class__.__name__} not supported in {self.compute_backend} backend." + + def _construct_neon(self): + """ + Constructs the Neon container for encoding auxiliary data recovery. + """ + + # Borrow the functional from the warp implementation + functional_dict, _ = self._construct_warp() + encoder_functional = functional_dict["encoder"] + _id = self.boundary_id + + # Construct the Neon container + @neon.Container.factory(name="MultiresEncodingAuxData_" + str(_id)) + def aux_data_init_container( + f_1: Any, + bc_mask: Any, + missing_mask: Any, + level: Any, + ): + def aux_data_init_ll(loader: neon.Loader): + loader.set_mres_grid(f_1.get_grid(), level) + + f_1_pn = loader.get_mres_write_handle(f_1) + bc_mask_pn = loader.get_mres_read_handle(bc_mask) + missing_mask_pn = loader.get_mres_read_handle(missing_mask) + + @wp.func + def aux_data_init_cl(index: Any): + # read tid data + _, _, _boundary_id, _missing_mask = self.bc_helper.neon_get_bc_thread_data(f_1_pn, f_1_pn, bc_mask_pn, missing_mask_pn, index) + + # Apply the functional + if _boundary_id == _id: + # IMPORTANT NOTE: + # It is assumed in XLB that the user_defined_functional in multi-res simulations is defined in terms of the indices at the finest level. 
+ # This assumption enables handling of BCs whose indices span multiple levels + warp_index = self.bc_helper.neon_index_to_warp(f_1_pn, index) + prescribed_values = self.user_defined_functional(warp_index) + + # Call the functional + encoder_functional(index, _missing_mask, f_1_pn, prescribed_values) + + # Declare the kernel in the Neon loader + loader.declare_kernel(aux_data_init_cl) + + return aux_data_init_ll + + return functional_dict, aux_data_init_container + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, f_1, bc_mask, missing_mask, stream): + grid = bc_mask.get_grid() + for level in range(grid.num_levels): + c = self.neon_container(f_1, bc_mask, missing_mask, level) + c.run(stream, container_runtime=neon.Container.ContainerRuntime.neon) + return f_1 diff --git a/xlb/operator/boundary_masker/__init__.py b/xlb/operator/boundary_masker/__init__.py index 3417c3c8..5a1ceb75 100644 --- a/xlb/operator/boundary_masker/__init__.py +++ b/xlb/operator/boundary_masker/__init__.py @@ -1,2 +1,12 @@ +from xlb.operator.boundary_masker.helper_functions_masker import HelperFunctionsMasker from xlb.operator.boundary_masker.indices_boundary_masker import IndicesBoundaryMasker from xlb.operator.boundary_masker.mesh_boundary_masker import MeshBoundaryMasker +from xlb.operator.boundary_masker.aabb import MeshMaskerAABB +from xlb.operator.boundary_masker.ray import MeshMaskerRay +from xlb.operator.boundary_masker.winding import MeshMaskerWinding +from xlb.operator.boundary_masker.aabb_close import MeshMaskerAABBClose +from xlb.operator.boundary_masker.mesh_voxelization_method import MeshVoxelizationMethod +from xlb.operator.boundary_masker.multires_aabb import MultiresMeshMaskerAABB +from xlb.operator.boundary_masker.multires_aabb_close import MultiresMeshMaskerAABBClose +from xlb.operator.boundary_masker.multires_indices_boundary_masker import MultiresIndicesBoundaryMasker +from xlb.operator.boundary_masker.multires_ray import 
MultiresMeshMaskerRay diff --git a/xlb/operator/boundary_masker/aabb.py b/xlb/operator/boundary_masker/aabb.py new file mode 100644 index 00000000..94def41b --- /dev/null +++ b/xlb/operator/boundary_masker/aabb.py @@ -0,0 +1,198 @@ +""" +AABB mesh-based boundary masker. + +Voxelizes an STL mesh using ``warp.mesh_query_aabb`` for approximate +one-voxel-thick surface detection around the geometry. +""" + +import warp as wp +from typing import Any +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.boundary_masker.mesh_boundary_masker import MeshBoundaryMasker +from xlb.operator.operator import Operator +from xlb.cell_type import BC_SOLID + + +class MeshMaskerAABB(MeshBoundaryMasker): + """ + Operator for creating boundary missing_mask from mesh using Axis-Aligned Bounding Box (AABB) voxelization. + + This implementation uses warp.mesh_query_aabb for efficient mesh-voxel intersection testing, + providing approximate 1-voxel thick surface detection around the mesh geometry. + Suitable for scenarios where fast, approximate boundary detection is sufficient. 
+ """ + + def __init__( + self, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + ): + # Call super + super().__init__(velocity_set, precision_policy, compute_backend) + + def _construct_warp(self): + # Make constants for warp + _c = self.velocity_set.c + _q = self.velocity_set.q + _opp_indices = self.velocity_set.opp_indices + + # Set local constants + lattice_central_index = self.velocity_set.center_index + + @wp.func + def functional( + index: Any, + mesh_id: Any, + id_number: Any, + distances: Any, + bc_mask: Any, + missing_mask: Any, + needs_mesh_distance: Any, + ): + # position of the point + cell_center_pos = self.helper_masker.index_to_position(bc_mask, index) + HALF_VOXEL = wp.vec3(0.5, 0.5, 0.5) + + if self.read_field(bc_mask, index, 0) == wp.uint8(BC_SOLID) or self.mesh_voxel_intersect( + mesh_id=mesh_id, low=cell_center_pos - HALF_VOXEL + ): + # Make solid voxel + self.write_field(bc_mask, index, 0, wp.uint8(BC_SOLID)) + else: + # Find the boundary voxels and their missing directions + for direction_idx in range(_q): + if direction_idx == lattice_central_index: + # Skip the central index as it is not relevant for boundary masking + continue + + # Get the lattice direction vector + direction_vec = wp.vec3f(wp.float32(_c[0, direction_idx]), wp.float32(_c[1, direction_idx]), wp.float32(_c[2, direction_idx])) + + # Check to see if this neighbor is solid + if self.mesh_voxel_intersect(mesh_id=mesh_id, low=cell_center_pos + direction_vec - HALF_VOXEL): + # We know we have a solid neighbor + # Set the boundary id and missing_mask + self.write_field(bc_mask, index, 0, wp.uint8(id_number)) + self.write_field(missing_mask, index, _opp_indices[direction_idx], wp.uint8(True)) + + # If we don't need the mesh distance, we can return early + if not needs_mesh_distance: + continue + + # Find the fractional distance to the mesh in each direction + # We increase max_length to find intersections in 
neighboring cells + max_length = wp.length(direction_vec) + query = wp.mesh_query_ray(mesh_id, cell_center_pos, direction_vec / max_length, 1.5 * max_length) + if query.result: + # get position of the mesh triangle that intersects with the ray + pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v) + # We reduce the distance to give some wall thickness + dist = wp.length(pos_mesh - cell_center_pos) - 0.5 * max_length + weight = dist / max_length + self.write_field(distances, index, direction_idx, self.store_dtype(weight)) + else: + # Expected an intersection in this direction but none was found. + # Assume the solid extends one lattice unit beyond the BC voxel leading to a distance fraction of 1. + self.write_field(distances, index, direction_idx, self.store_dtype(1.0)) + + @wp.kernel + def kernel( + mesh_id: wp.uint64, + id_number: wp.int32, + distances: wp.array4d(dtype=Any), + bc_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.uint8), + needs_mesh_distance: bool, + ): + # get index + i, j, k = wp.tid() + + # Get local indices + index = wp.vec3i(i, j, k) + + # apply the functional + functional( + index, + mesh_id, + id_number, + distances, + bc_mask, + missing_mask, + needs_mesh_distance, + ) + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation( + self, + bc, + distances, + bc_mask, + missing_mask, + ): + return self.warp_implementation_base( + bc, + distances, + bc_mask, + missing_mask, + ) + + def _construct_neon(self): + import neon + + # Use the warp functional for the NEON backend + functional, _ = self._construct_warp() + + @neon.Container.factory(name="MeshMaskerAABB") + def container( + mesh_id: Any, + id_number: Any, + distances: Any, + bc_mask: Any, + missing_mask: Any, + needs_mesh_distance: Any, + ): + def aabb_launcher(loader: neon.Loader): + loader.set_grid(bc_mask.get_grid()) + bc_mask_pn = loader.get_write_handle(bc_mask) + missing_mask_pn = 
loader.get_write_handle(missing_mask) + distances_pn = loader.get_write_handle(distances) + + @wp.func + def aabb_kernel(index: Any): + # apply the functional + functional( + index, + mesh_id, + id_number, + distances_pn, + bc_mask_pn, + missing_mask_pn, + needs_mesh_distance, + ) + + loader.declare_kernel(aabb_kernel) + + return aabb_launcher + + return functional, container + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation( + self, + bc, + distances, + bc_mask, + missing_mask, + ): + # Prepare inputs + mesh_id, bc_id = self._prepare_kernel_inputs(bc, bc_mask) + + # Launch the appropriate neon container + c = self.neon_container(mesh_id, bc_id, distances, bc_mask, missing_mask, wp.static(bc.needs_mesh_distance)) + c.run(0, container_runtime=neon.Container.ContainerRuntime.neon) + return distances, bc_mask, missing_mask diff --git a/xlb/operator/boundary_masker/aabb_close.py b/xlb/operator/boundary_masker/aabb_close.py new file mode 100644 index 00000000..21ee31f9 --- /dev/null +++ b/xlb/operator/boundary_masker/aabb_close.py @@ -0,0 +1,365 @@ +""" +AABB-Close boundary masker with morphological close operation. + +Identifies solid voxels via axis-aligned bounding-box (AABB) intersection, +then applies a morphological *close* (dilate followed by erode) to fill +thin gaps and small cavities in the mesh surface. The resulting solid +mask is used to determine boundary voxels and their missing population +directions. + +Supports both Warp (single-resolution) and Neon (multi-resolution) +backends. 
+""" + +import numpy as np +import warp as wp +import jax +from typing import Any +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.boundary_masker.mesh_boundary_masker import MeshBoundaryMasker +from xlb.cell_type import BC_SOLID + + +class MeshMaskerAABBClose(MeshBoundaryMasker): + """Boundary masker using AABB voxelization with morphological close. + + The *close* operation (dilate then erode) thickens the raw solid mask + by ``close_voxels`` layers before shrinking it back, sealing small + holes and thin slits in the mesh surface. + + Parameters + ---------- + velocity_set : VelocitySet, optional + precision_policy : PrecisionPolicy, optional + compute_backend : ComputeBackend, optional + close_voxels : int + Half-width of the morphological structuring element. Must be + provided explicitly. + """ + + def __init__( + self, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + close_voxels: int = None, + ): + assert close_voxels is not None, ( + "Please provide the number of close voxels using the 'close_voxels' argument! 
e.g., MeshVoxelizationMethod('AABB_CLOSE', close_voxels=3)" + ) + self.close_voxels = close_voxels + # Call super + self.tile_half = close_voxels + self.tile_size = self.tile_half * 2 + 1 + super().__init__(velocity_set, precision_policy, compute_backend) + + def _construct_warp(self): + # Make constants for warp + _c = self.velocity_set.c + _q = self.velocity_set.q + _opp_indices = self.velocity_set.opp_indices + TILE_SIZE = wp.constant(self.tile_size) + TILE_HALF = wp.constant(self.tile_half) + lattice_central_index = self.velocity_set.center_index + + # Erode the solid mask in mask_field, removing a layer of outer solid voxels, storing output in mask_field_out + @wp.kernel + def erode_tile(mask_field: wp.array3d(dtype=Any), mask_field_out: wp.array3d(dtype=Any)): + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + if not self.helper_masker.is_in_bounds(index, wp.vec3i(mask_field.shape[0], mask_field.shape[1], mask_field.shape[2]), TILE_HALF): + mask_field_out[i, j, k] = mask_field[i, j, k] + return + t = wp.tile_load(mask_field, shape=(TILE_SIZE, TILE_SIZE, TILE_SIZE), offset=(i - TILE_HALF, j - TILE_HALF, k - TILE_HALF)) + min_val = wp.tile_min(t) + mask_field_out[i, j, k] = min_val[0] + + # Dilate the solid mask in mask_field, adding a layer of outer solid voxels, storing output in mask_field_out + @wp.kernel + def dilate_tile(mask_field: wp.array3d(dtype=Any), mask_field_out: wp.array3d(dtype=Any)): + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + if not self.helper_masker.is_in_bounds(index, wp.vec3i(mask_field.shape[0], mask_field.shape[1], mask_field.shape[2]), TILE_HALF): + mask_field_out[i, j, k] = mask_field[i, j, k] + return + t = wp.tile_load(mask_field, shape=(TILE_SIZE, TILE_SIZE, TILE_SIZE), offset=(i - TILE_HALF, j - TILE_HALF, k - TILE_HALF)) + max_val = wp.tile_max(t) + mask_field_out[i, j, k] = max_val[0] + + # Erode the solid mask in mask_field, removing a layer of outer solid voxels, storing output in mask_field_out + @wp.func + def 
functional_erode(index: Any, mask_field: Any, mask_field_out: Any): + min_val = wp.uint8(BC_SOLID) + for l in range(_q): + if l == lattice_central_index: + continue + is_valid = wp.bool(False) + ngh = wp.neon_ngh_idx(wp.int8(_c[0, l]), wp.int8(_c[1, l]), wp.int8(_c[2, l])) + ngh_val = wp.neon_read_ngh(mask_field, index, ngh, 0, wp.uint8(0), is_valid) + if is_valid: + # Take the min value of all neighbors in bounds + min_val = wp.min(min_val, ngh_val) + self.write_field(mask_field_out, index, 0, min_val) + + # Dilate the solid mask in mask_field, adding a layer of outer solid voxels, storing output in mask_field_out + @wp.func + def functional_dilate(index: Any, mask_field: Any, mask_field_out: Any): + max_val = wp.uint8(0) + for l in range(_q): + if l == lattice_central_index: + continue + is_valid = wp.bool(False) + ngh = wp.neon_ngh_idx(wp.int8(_c[0, l]), wp.int8(_c[1, l]), wp.int8(_c[2, l])) + ngh_val = wp.neon_read_ngh(mask_field, index, ngh, 0, wp.uint8(0), is_valid) + if is_valid: + max_val = wp.max(max_val, ngh_val) + self.write_field(mask_field_out, index, 0, max_val) + + # Construct the warp kernel + # Find solid voxels that intersect the mesh + @wp.func + def functional_solid(index: Any, mesh_id: Any, solid_mask: Any, offset: Any): + # position of the point + cell_center_pos = self.helper_masker.index_to_position(solid_mask, index) + offset + half = wp.vec3(0.5, 0.5, 0.5) + + if self.mesh_voxel_intersect(mesh_id=mesh_id, low=cell_center_pos - half): + # Make solid voxel + self.write_field(solid_mask, index, 0, wp.uint8(BC_SOLID)) + + @wp.kernel + def kernel_solid( + mesh_id: wp.uint64, + solid_mask: wp.array3d(dtype=wp.int32), + offset: wp.vec3f, + ): + # get index + i, j, k = wp.tid() + + # Get local indices + index = wp.vec3i(i, j, k) + + functional_solid(index, mesh_id, solid_mask, offset) + + return + + @wp.func + def functional_aabb( + index: Any, + mesh_id: wp.uint64, + id_number: wp.int32, + distances: wp.array4d(dtype=Any), + bc_mask: 
wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.uint8), + solid_mask: wp.array3d(dtype=wp.uint8), + needs_mesh_distance: bool, + ): + # position of the point + cell_center_pos = self.helper_masker.index_to_position(bc_mask, index) + HALF_VOXEL = wp.vec3(0.5, 0.5, 0.5) + + if self.read_field(solid_mask, index, 0) == wp.uint8(BC_SOLID) or self.read_field(bc_mask, index, 0) == wp.uint8(BC_SOLID): + # Make solid voxel + self.write_field(bc_mask, index, 0, wp.uint8(BC_SOLID)) + else: + # Find the boundary voxels and their missing directions + for direction_idx in range(_q): + if direction_idx == lattice_central_index: + # Skip the central index as it is not relevant for boundary masking + continue + + # Get the lattice direction vector + direction_vec = wp.vec3f(wp.float32(_c[0, direction_idx]), wp.float32(_c[1, direction_idx]), wp.float32(_c[2, direction_idx])) + + # Check to see if this neighbor is solid + if self.helper_masker.is_in_bounds(index, wp.vec3i(solid_mask.shape[0], solid_mask.shape[1], solid_mask.shape[2]), 1): + if self.read_field(solid_mask, index + direction_idx, 0) == wp.uint8(BC_SOLID): + # We know we have a solid neighbor + # Set the boundary id and missing_mask + self.write_field(bc_mask, index, 0, wp.uint8(id_number)) + self.write_field(missing_mask, index, _opp_indices[direction_idx], wp.uint8(True)) + + # If we don't need the mesh distance, we can return early + if not needs_mesh_distance: + continue + + # Find the fractional distance to the mesh in each direction + # We increase max_length to find intersections in neighboring cells + max_length = wp.length(direction_vec) + query = wp.mesh_query_ray(mesh_id, cell_center_pos, direction_vec / max_length, 1.5 * max_length) + if query.result: + # get position of the mesh triangle that intersects with the ray + pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v) + # We reduce the distance to give some wall thickness + dist = wp.length(pos_mesh - cell_center_pos) - 
0.5 * max_length + weight = dist / max_length + self.write_field(distances, index, direction_idx, self.store_dtype(weight)) + else: + # Expected an intersection in this direction but none was found. + # Assume the solid extends one lattice unit beyond the BC voxel leading to a distance fraction of 1. + self.write_field(distances, index, direction_idx, self.store_dtype(1.0)) + + # Assign the bc_mask and distances based on the solid_mask we already computed + @wp.kernel + def kernel( + mesh_id: wp.uint64, + id_number: wp.int32, + distances: wp.array4d(dtype=Any), + bc_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.uint8), + solid_mask: wp.array3d(dtype=wp.uint8), + needs_mesh_distance: bool, + ): + # get index + i, j, k = wp.tid() + + # Get local indices + index = wp.vec3i(i, j, k) + + # position of the point + cell_center_pos = self.helper_masker.index_to_position(bc_mask, index) + + if solid_mask[i, j, k] == wp.uint8(BC_SOLID) or bc_mask[0, index[0], index[1], index[2]] == wp.uint8(BC_SOLID): + # Make solid voxel + bc_mask[0, index[0], index[1], index[2]] = wp.uint8(BC_SOLID) + else: + # Find the boundary voxels and their missing directions + for direction_idx in range(_q): + if direction_idx == lattice_central_index: + # Skip the central index as it is not relevant for boundary masking + continue + direction_vec = wp.vec3f(wp.float32(_c[0, direction_idx]), wp.float32(_c[1, direction_idx]), wp.float32(_c[2, direction_idx])) + + # Check to see if this neighbor is solid - this is super inefficient TODO: make it way better + # if solid_mask[i,j,k] == wp.uint8(BC_SOLID): + if solid_mask[i + _c[0, direction_idx], j + _c[1, direction_idx], k + _c[2, direction_idx]] == wp.uint8(BC_SOLID): + # We know we have a solid neighbor + # Set the boundary id and missing_mask + bc_mask[0, index[0], index[1], index[2]] = wp.uint8(id_number) + missing_mask[_opp_indices[direction_idx], index[0], index[1], index[2]] = wp.uint8(True) + + # If we don't need the mesh 
distance, we can return early + if not needs_mesh_distance: + continue + + # Find the fractional distance to the mesh in each direction + # We increase max_length to find intersections in neighboring cells + max_length = wp.length(direction_vec) + query = wp.mesh_query_ray(mesh_id, cell_center_pos, direction_vec / max_length, 1.5 * max_length) + if query.result: + # get position of the mesh triangle that intersects with the ray + pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v) + # We reduce the distance to give some wall thickness + dist = wp.length(pos_mesh - cell_center_pos) - 0.5 * max_length + weight = self.store_dtype(dist / max_length) + distances[direction_idx, index[0], index[1], index[2]] = weight + else: + # We didn't have an intersection in the given direction but we know we should so we assume the solid is slightly thicker + # and one lattice direction away from the BC voxel + distances[direction_idx, index[0], index[1], index[2]] = self.store_dtype(1.0) + + functional_dict = { + "functional_erode": functional_erode, + "functional_dilate": functional_dilate, + "functional_solid": functional_solid, + "functional_aabb": functional_aabb, + } + kernel_dict = { + "kernel": kernel, + "kernel_solid": kernel_solid, + "erode_tile": erode_tile, + "dilate_tile": dilate_tile, + } + return functional_dict, kernel_dict + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation( + self, + bc, + distances, + bc_mask, + missing_mask, + ): + assert bc.mesh_vertices is not None, f'Please provide the mesh vertices for {bc.__class__.__name__} BC using keyword "mesh_vertices"!' + assert bc.indices is None, f"Please use IndicesBoundaryMasker operator if {bc.__class__.__name__} is imposed on known indices of the grid!" + assert bc.mesh_vertices.shape[1] == self.velocity_set.d, ( + "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!" 
+ ) + + domain_shape = bc_mask.shape[1:] # (nx, ny, nz) + mesh_vertices = bc.mesh_vertices + mesh_min = np.min(mesh_vertices, axis=0) + mesh_max = np.max(mesh_vertices, axis=0) + + if any(mesh_min < 0) or any(mesh_max >= domain_shape): + raise ValueError( + f"Mesh extents ({mesh_min}, {mesh_max}) exceed domain dimensions {domain_shape}. The mesh must be fully contained within the domain." + ) + + # We are done with bc.mesh_vertices. Remove them from BC objects + bc.__dict__.pop("mesh_vertices", None) + + mesh_indices = np.arange(mesh_vertices.shape[0]) + mesh = wp.Mesh( + points=wp.array(mesh_vertices, dtype=wp.vec3), + indices=wp.array(mesh_indices, dtype=wp.int32), + ) + mesh_id = wp.uint64(mesh.id) + bc_id = bc.id + + # Create a padded mask for the solid voxels to account for the tile size + # It needs to be padded by twice the tile size on each side since we run two tile operations + tile_length = 2 * self.tile_half + offset = wp.vec3f(-tile_length, -tile_length, -tile_length) + pad = 2 * tile_length + nx, ny, nz = domain_shape + solid_mask = wp.zeros((nx + pad, ny + pad, nz + pad), dtype=wp.int32) + solid_mask_out = wp.zeros((nx + pad, ny + pad, nz + pad), dtype=wp.int32) + + # Prepare the warp kernel dictionary + kernel_dict = self.warp_kernel + + # Launch all required kernels for creating the solid mask + wp.launch( + kernel=kernel_dict["kernel_solid"], + inputs=[ + mesh_id, + solid_mask, + offset, + ], + dim=solid_mask.shape, + ) + wp.launch_tiled( + kernel=kernel_dict["dilate_tile"], + dim=solid_mask.shape, + block_dim=32, + inputs=[solid_mask, solid_mask_out], + ) + wp.launch_tiled( + kernel=kernel_dict["erode_tile"], + dim=solid_mask.shape, + block_dim=32, + inputs=[solid_mask_out, solid_mask], + ) + solid_mask_cropped = wp.array( + solid_mask[tile_length:-tile_length, tile_length:-tile_length, tile_length:-tile_length], + dtype=wp.uint8, + ) + + # Launch the main kernel for boundary masker + wp.launch( + kernel_dict["kernel"], + inputs=[mesh_id, bc_id, 
distances, bc_mask, missing_mask, solid_mask_cropped, wp.static(bc.needs_mesh_distance)], + dim=bc_mask.shape[1:], + ) + + # Resolve out of bound indices + wp.launch( + self.resolve_out_of_bound_kernel, + inputs=[bc_id, bc_mask, missing_mask], + dim=bc_mask.shape[1:], + ) + return distances, bc_mask, missing_mask diff --git a/xlb/operator/boundary_masker/helper_functions_masker.py b/xlb/operator/boundary_masker/helper_functions_masker.py new file mode 100644 index 00000000..565455fe --- /dev/null +++ b/xlb/operator/boundary_masker/helper_functions_masker.py @@ -0,0 +1,135 @@ +""" +Warp/Neon helper functions shared by boundary masker operators. +""" + +import warp as wp +from typing import Any +from xlb import DefaultConfig, ComputeBackend + + +class HelperFunctionsMasker(object): + """Warp ``@wp.func`` helpers for boundary masker operators. + + Provides coordinate-conversion, bounds-checking, pull-index + computation, and BC-index membership tests used by the mesh and + indices boundary maskers on both Warp and Neon backends. 
+ """ + + def __init__(self, velocity_set=None, precision_policy=None, compute_backend=None): + if compute_backend == ComputeBackend.JAX: + raise ValueError("This helper class contains helper functions only for the WARP implementation of some BCs not JAX!") + + # Set the default values from the global config + self.velocity_set = velocity_set or DefaultConfig.velocity_set + self.precision_policy = precision_policy or DefaultConfig.default_precision_policy + self.compute_backend = compute_backend or DefaultConfig.default_backend + + # Set local constants + _d = self.velocity_set.d + _c = self.velocity_set.c + + @wp.func + def neon_index_to_warp(neon_field_hdl: Any, index: Any): + # Unpack the global index in Neon at the finest level and convert it to a warp vector + cIdx = wp.neon_global_idx(neon_field_hdl, index) + gx = wp.neon_get_x(cIdx) + gy = wp.neon_get_y(cIdx) + gz = wp.neon_get_z(cIdx) + + # XLB is flattening the z dimension in 3D, while neon uses the y dimension + if _d == 2: + gy, gz = gz, gy + + # Get warp indices + index_wp = wp.vec3i(gx, gy, gz) + return index_wp + + @wp.func + def index_to_position_warp(field: Any, index: wp.vec3i): + # position of the point + ijk = wp.vec3(wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2])) + pos = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center + return pos + + @wp.func + def index_to_position_neon(field: Any, index: Any): + # position of the point + index_wp = neon_index_to_warp(field, index) + return index_to_position_warp(field, index_wp) + + @wp.func + def is_in_bounds(index: wp.vec3i, grid_shape: wp.vec3i, SHIFT: Any = 0): + return ( + index[0] >= SHIFT + and index[0] < grid_shape[0] - SHIFT + and index[1] >= SHIFT + and index[1] < grid_shape[1] - SHIFT + and index[2] >= SHIFT + and index[2] < grid_shape[2] - SHIFT + ) + + @wp.func + def get_pull_index_warp( + field: Any, + lattice_dir: wp.int32, + index: wp.vec3i, + level: Any, + ): + pull_index = wp.vec3i() + offset = wp.vec3i() + for d in 
range(self.velocity_set.d): + offset[d] = -_c[d, lattice_dir] + for _ in range(level): + offset[d] *= 2 + pull_index[d] = index[d] + offset[d] + + return pull_index, offset + + @wp.func + def get_pull_index_neon( + field: Any, + lattice_dir: wp.int32, + index: Any, + level: Any, + ): + # Convert the index to warp + index_wp = neon_index_to_warp(field, index) + pull_index_wp, _ = get_pull_index_warp(field, lattice_dir, index_wp, level) + offset = wp.neon_ngh_idx(wp.int8(-_c[0, lattice_dir]), wp.int8(-_c[1, lattice_dir]), wp.int8(-_c[2, lattice_dir])) + return pull_index_wp, offset + + @wp.func + def is_in_bc_indices_warp( + field: Any, + index: Any, + bc_indices: wp.array2d(dtype=wp.int32), + ii: wp.int32, + ): + return bc_indices[0, ii] == index[0] and bc_indices[1, ii] == index[1] and bc_indices[2, ii] == index[2] + + @wp.func + def is_in_bc_indices_neon( + field: Any, + index: Any, + bc_indices: wp.array2d(dtype=wp.int32), + ii: wp.int32, + ): + index_wp = neon_index_to_warp(field, index) + return is_in_bc_indices_warp(field, index_wp, bc_indices, ii) + + # Construct some helper warp functions + self.is_in_bounds = is_in_bounds + self.index_to_position = index_to_position_warp if self.compute_backend == ComputeBackend.WARP else index_to_position_neon + self.get_pull_index = get_pull_index_warp if self.compute_backend == ComputeBackend.WARP else get_pull_index_neon + self.is_in_bc_indices = is_in_bc_indices_warp if self.compute_backend == ComputeBackend.WARP else is_in_bc_indices_neon + + def get_grid_shape(self, field): + """ + Get the grid shape from the boundary mask. 
This is a CPU function that returns the shape of the grid + """ + if self.compute_backend == ComputeBackend.WARP: + return field.shape[1:] + elif self.compute_backend == ComputeBackend.NEON: + return wp.vec3i(field.get_grid().dim.x, field.get_grid().dim.y, field.get_grid().dim.z) + else: + raise ValueError(f"Unsupported compute backend: {self.compute_backend}") diff --git a/xlb/operator/boundary_masker/indices_boundary_masker.py b/xlb/operator/boundary_masker/indices_boundary_masker.py index c779972c..e888c284 100644 --- a/xlb/operator/boundary_masker/indices_boundary_masker.py +++ b/xlb/operator/boundary_masker/indices_boundary_masker.py @@ -1,12 +1,25 @@ -import numpy as np -import warp as wp +""" +Indices-based boundary masker. + +Creates boundary masks from explicit arrays of voxel indices, computing +missing-population masks via pull-index tests for each tagged voxel. +""" + +from typing import Any +import copy + import jax import jax.numpy as jnp +import numpy as np +import warp as wp + from xlb.compute_backend import ComputeBackend +from xlb.grid import grid_factory from xlb.operator.operator import Operator from xlb.operator.stream.stream import Stream -from xlb.grid import grid_factory from xlb.precision_policy import Precision +from xlb.operator.boundary_masker.helper_functions_masker import HelperFunctionsMasker +from xlb.cell_type import BC_SOLID class IndicesBoundaryMasker(Operator): @@ -19,12 +32,21 @@ def __init__( velocity_set=None, precision_policy=None, compute_backend=None, + grid=None, ): - # Make stream operator - self.stream = Stream(velocity_set, precision_policy, compute_backend) - # Call super super().__init__(velocity_set, precision_policy, compute_backend) + self.grid = grid + if self.compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON]: + # Define masker helper functions + self.helper_masker = HelperFunctionsMasker( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + 
compute_backend=self.compute_backend, + ) + else: + # Make stream operator + self.stream = Stream(velocity_set, precision_policy, compute_backend) def are_indices_in_interior(self, indices, shape): """ @@ -35,141 +57,262 @@ def are_indices_in_interior(self, indices, shape): :param shape: Tuple representing the shape of the domain (nx, ny) for 2D or (nx, ny, nz) for 3D. :return: Array of boolean flags where each flag indicates whether the corresponding index is inside the bounds. """ - d = self.velocity_set.d + _d = self.velocity_set.d shape_array = np.array(shape) - return np.all((indices[:d] > 0) & (indices[:d] < shape_array[:d, np.newaxis] - 1), axis=0) + return np.all((indices[:_d] > 0) & (indices[:_d] < shape_array[:_d, np.newaxis] - 1), axis=0) + + def _find_bclist_interior(self, bclist, grid_shape): + bc_interior = [] + for bc in bclist: + if any(self.are_indices_in_interior(np.array(bc.indices), grid_shape)): + bc_copy = copy.copy(bc) # shallow copy of the whole object + bc_copy.indices = copy.deepcopy(bc.pad_indices()) # deep copy only the modified part + bc_interior.append(bc_copy) + return bc_interior @Operator.register_backend(ComputeBackend.JAX) # TODO HS: figure out why uncommenting the line below fails unlike other operators! # @partial(jit, static_argnums=(0)) def jax_implementation(self, bclist, bc_mask, missing_mask, start_index=None): - # Pad the missing mask to create a grid mask to identify out of bound boundaries - # Set padded regin to True (i.e. boundary) + # Extend the missing mask by padding to identify out of bound boundaries + # Set padded region to True (i.e. boundary) dim = missing_mask.ndim - 1 + grid_shape = bc_mask[0].shape nDevices = jax.device_count() pad_x, pad_y, pad_z = nDevices, 1, 1 - # TODO MEHDI: There is sometimes a halting problem here when padding is used in a multi-GPU setting since we're not jitting this function. - # For now, we compute the bmap on GPU zero. 
- if dim == 2: - bmap = jnp.zeros((pad_x * 2 + bc_mask[0].shape[0], pad_y * 2 + bc_mask[0].shape[1]), dtype=jnp.uint8) - bmap = bmap.at[pad_x:-pad_x, pad_y:-pad_y].set(bc_mask[0]) - grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y)), constant_values=True) - # bmap = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y)), constant_values=0) - if dim == 3: - bmap = jnp.zeros((pad_x * 2 + bc_mask[0].shape[0], pad_y * 2 + bc_mask[0].shape[1], pad_z * 2 + bc_mask[0].shape[2]), dtype=jnp.uint8) - bmap = bmap.at[pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z].set(bc_mask[0]) - grid_mask = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=True) - # bmap = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=0) - # shift indices - shift_tup = (pad_x, pad_y) if dim == 2 else (pad_x, pad_y, pad_z) + # Shift indices due to padding + shift = np.array((pad_x, pad_y) if dim == 2 else (pad_x, pad_y, pad_z))[:, np.newaxis] if start_index is None: start_index = (0,) * dim - domain_shape = bc_mask[0].shape + # TODO MEHDI: There is sometimes a halting problem here when padding is used in a multi-GPU setting since we're not jitting this function. + # For now, we compute the bc_mask_extended on GPU zero. + if dim == 2: + bc_mask_extended = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y)), constant_values=0) + missing_mask_extended = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y)), constant_values=True) + if dim == 3: + bc_mask_extended = jnp.pad(bc_mask[0], ((pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=0) + missing_mask_extended = jnp.pad(missing_mask, ((0, 0), (pad_x, pad_x), (pad_y, pad_y), (pad_z, pad_z)), constant_values=True) + + # Iterate over boundary conditions and set the mask for bc in bclist: assert bc.indices is not None, f"Please specify indices associated with the {bc.__class__.__name__} BC!" 
- assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" + assert bc.mesh_vertices is None, ( + f"Please use operators based on MeshBoundaryMasker if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" + ) id_number = bc.id bc_indices = np.array(bc.indices) - local_indices = bc_indices - np.array(start_index)[:, np.newaxis] - padded_indices = local_indices + np.array(shift_tup)[:, np.newaxis] - bmap = bmap.at[tuple(padded_indices)].set(id_number) - if any(self.are_indices_in_interior(bc_indices, domain_shape)) and bc.needs_padding: - # checking if all indices associated with this BC are in the interior of the domain. - # This flag is needed e.g. if the no-slip geometry is anywhere but at the boundaries of the computational domain. + indices_origin = np.array(start_index)[:, np.newaxis] + if any(self.are_indices_in_interior(bc_indices, grid_shape)): + # If the indices are in the interior, we assume the user-specified indices are solid indices + solid_indices = bc_indices - indices_origin + solid_indices_shifted = solid_indices + shift + + # We obtain the boundary indices by padding the solid indices in all lattice directions + indices_padded = bc.pad_indices() - indices_origin + indices_shifted = indices_padded + shift + + # The missing mask is set to True meaning (exterior or solid nodes) using the original indices. + # This is because of the following streaming step which will assign missing directions for the boundary nodes.
if dim == 2: - grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1]].set(True) - if dim == 3: - grid_mask = grid_mask.at[:, padded_indices[0], padded_indices[1], padded_indices[2]].set(True) + missing_mask_extended = missing_mask_extended.at[:, solid_indices_shifted[0], solid_indices_shifted[1]].set(True) + else: + missing_mask_extended = missing_mask_extended.at[ + :, solid_indices_shifted[0], solid_indices_shifted[1], solid_indices_shifted[2] + ].set(True) + else: + indices_shifted = bc_indices - indices_origin + shift - # Assign the boundary id to the push indices - push_indices = padded_indices[:, :, None] + self.velocity_set.c[:, None, :] - push_indices = push_indices.reshape(dim, -1) - bmap = bmap.at[tuple(push_indices)].set(id_number) + # Assign the boundary id to the shifted (and possibly padded) indices + bc_mask_extended = bc_mask_extended.at[tuple(indices_shifted)].set(id_number) # We are done with bc.indices. Remove them from BC objects bc.__dict__.pop("indices", None) - grid_mask = self.stream(grid_mask) + # Stream the missing mask to identify missing directions + missing_mask_extended = self.stream(missing_mask_extended) + + # Crop the extended masks to remove padding if dim == 2: - missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y] - bc_mask = bc_mask.at[0].set(bmap[pad_x:-pad_x, pad_y:-pad_y]) + missing_mask = missing_mask_extended[:, pad_x:-pad_x, pad_y:-pad_y] + bc_mask = bc_mask.at[0].set(bc_mask_extended[pad_x:-pad_x, pad_y:-pad_y]) if dim == 3: - missing_mask = grid_mask[:, pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z] - bc_mask = bc_mask.at[0].set(bmap[pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z]) + missing_mask = missing_mask_extended[:, pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z] + bc_mask = bc_mask.at[0].set(bc_mask_extended[pad_x:-pad_x, pad_y:-pad_y, pad_z:-pad_z]) return bc_mask, missing_mask def _construct_warp(self): # Make constants for warp - _c = self.velocity_set.c - _q = wp.constant(self.velocity_set.q) + _q = 
self.velocity_set.q @wp.func - def check_index_bounds(index: wp.vec3i, shape: wp.vec3i): - is_in_bounds = index[0] >= 0 and index[0] < shape[0] and index[1] >= 0 and index[1] < shape[1] and index[2] >= 0 and index[2] < shape[2] - return is_in_bounds + def functional_domain_bounds( + index: Any, + bc_indices: Any, + id_number: Any, + is_interior: Any, + bc_mask: Any, + missing_mask: Any, + grid_shape: Any, + level: Any = 0, + ): + for ii in range(bc_indices.shape[1]): + # If the current index does not match the boundary condition index, we skip it + if not self.helper_masker.is_in_bc_indices(bc_mask, index, bc_indices, ii): + continue + + if is_interior[ii] == wp.uint8(True): + # If the index is in the interior, we set that index to be a solid node (identified by BC_SOLID) + # This information will be used in the next kernel to identify missing directions using the + # padded indices of the solid node that are associated with the boundary condition. + self.write_field(bc_mask, index, 0, wp.uint8(BC_SOLID)) + return + + # Set bc_mask for all bc indices + self.write_field(bc_mask, index, 0, wp.uint8(id_number[ii])) + + # Stream indices + for l in range(_q): + # Get the pull index which is the index of the neighboring node where information is pulled from + pull_index, _ = self.helper_masker.get_pull_index(bc_mask, l, index, level) + + # Check if pull index is out of bound + # These directions will have missing information after streaming + if not self.helper_masker.is_in_bounds(pull_index, grid_shape): + # Set the missing mask + self.write_field(missing_mask, index, l, wp.uint8(True)) + + @wp.func + def functional_interior_bc_mask( + index: Any, + bc_indices: Any, + id_number: Any, + bc_mask: Any, + ): + for ii in range(bc_indices.shape[1]): + # If the current index does not match the boundary condition index, we skip it + if not self.helper_masker.is_in_bc_indices(bc_mask, index, bc_indices, ii): + continue + # Set bc_mask for all interior bc indices + 
self.write_field(bc_mask, index, 0, wp.uint8(id_number[ii])) + + @wp.func + def functional_interior_missing_mask( + index: Any, + bc_indices: Any, + bc_mask: Any, + missing_mask: Any, + grid_shape: Any, + level: Any = 0, + ): + for ii in range(bc_indices.shape[1]): + # If the current index does not match the boundary condition index, we skip it + if not self.helper_masker.is_in_bc_indices(bc_mask, index, bc_indices, ii): + continue + for l in range(_q): + # Get the index of the streaming direction + pull_index, offset = self.helper_masker.get_pull_index(bc_mask, l, index, level) + + # Check if the pull index is in bounds and is a solid node (bc_mask equals BC_SOLID) + bc_mask_ngh = self.read_field_neighbor(bc_mask, index, offset, 0) + if (self.helper_masker.is_in_bounds(pull_index, grid_shape)) and (bc_mask_ngh == wp.uint8(BC_SOLID)): + self.write_field(missing_mask, index, l, wp.uint8(True)) # Construct the warp 3D kernel @wp.kernel - def kernel( - indices: wp.array2d(dtype=wp.int32), + def kernel_domain_bounds( + bc_indices: wp.array2d(dtype=wp.int32), id_number: wp.array1d(dtype=wp.uint8), - is_interior: wp.array1d(dtype=wp.bool), + is_interior: wp.array1d(dtype=wp.uint8), bc_mask: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), + missing_mask: wp.array4d(dtype=wp.uint8), + grid_shape: wp.vec3i, ): - # Get the index of indices - ii = wp.tid() + # get index + i, j, k = wp.tid() # Get local indices - index = wp.vec3i() - index[0] = indices[0, ii] - index[1] = indices[1, ii] - index[2] = indices[2, ii] - - # Check if index is in bounds - shape = wp.vec3i(missing_mask.shape[1], missing_mask.shape[2], missing_mask.shape[3]) - if check_index_bounds(index, shape): - # Stream indices - for l in range(_q): - # Get the index of the streaming direction - pull_index = wp.vec3i() - push_index = wp.vec3i() - for d in range(self.velocity_set.d): - pull_index[d] = index[d] - _c[d, l] - push_index[d] = index[d] + _c[d, l] + index = wp.vec3i(i, j, k) + + # Call the
functional + functional_domain_bounds( + index, + bc_indices, + id_number, + is_interior, + bc_mask, + missing_mask, + grid_shape, + ) - # set bc_mask for all bc indices - bc_mask[0, index[0], index[1], index[2]] = id_number[ii] + @wp.kernel + def kernel_interior_bc_mask( + bc_indices: wp.array2d(dtype=wp.int32), + id_number: wp.array1d(dtype=wp.uint8), + bc_mask: wp.array4d(dtype=wp.uint8), + ): + # get index + i, j, k = wp.tid() - # check if pull index is out of bound - # These directions will have missing information after streaming - if not check_index_bounds(pull_index, shape): - # Set the missing mask - missing_mask[l, index[0], index[1], index[2]] = True + # Get local indices + index = wp.vec3i(i, j, k) - # handling geometries in the interior of the computational domain - elif check_index_bounds(pull_index, shape) and is_interior[ii]: - # Set the missing mask - missing_mask[l, push_index[0], push_index[1], push_index[2]] = True - bc_mask[0, push_index[0], push_index[1], push_index[2]] = id_number[ii] + # Set bc_mask for all interior bc indices + functional_interior_bc_mask( + index, + bc_indices, + id_number, + bc_mask, + ) + return + + @wp.kernel + def kernel_interior_missing_mask( + bc_indices: wp.array2d(dtype=wp.int32), + bc_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.uint8), + grid_shape: wp.vec3i, + ): + # get index + i, j, k = wp.tid() - return None, kernel + # Get local indices + index = wp.vec3i(i, j, k) + + functional_interior_missing_mask(index, bc_indices, bc_mask, missing_mask, grid_shape) + + functional_dict = { + "functional_domain_bounds": functional_domain_bounds, + "functional_interior_bc_mask": functional_interior_bc_mask, + "functional_interior_missing_mask": functional_interior_missing_mask, + } + kernel_dict = { + "kernel_domain_bounds": kernel_domain_bounds, + "kernel_interior_bc_mask": kernel_interior_bc_mask, + "kernel_interior_missing_mask": kernel_interior_missing_mask, + } + return functional_dict, 
kernel_dict + + def _prepare_kernel_inputs(self, bclist, grid_shape, start_index=None): + """ + Prepare the inputs for the warp kernel by pre-allocating arrays and filling them with boundary condition information. + """ - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, bclist, bc_mask, missing_mask, start_index=None): # Pre-allocate arrays with maximum possible size - max_size = sum(len(bc.indices[0]) if isinstance(bc.indices, list) else bc.indices.shape[1] for bc in bclist if bc.indices is not None) + max_size = sum( + len(bc.indices[0]) if isinstance(bc.indices, (list, tuple)) else bc.indices.shape[1] for bc in bclist if bc.indices is not None + ) indices = np.zeros((3, max_size), dtype=np.int32) id_numbers = np.zeros(max_size, dtype=np.uint8) - is_interior = np.zeros(max_size, dtype=bool) + is_interior = np.zeros(max_size, dtype=np.uint8) current_index = 0 for bc in bclist: assert bc.indices is not None, f'Please specify indices associated with the {bc.__class__.__name__} BC using keyword "indices"!' - assert bc.mesh_vertices is None, f"Please use MeshBoundaryMasker operator if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" - + assert bc.mesh_vertices is None, ( + f"Please use operators based on MeshBoundaryMasker if {bc.__class__.__name__} is imposed on a mesh (e.g. STL)!" 
+ ) bc_indices = np.asarray(bc.indices) num_indices = bc_indices.shape[1] @@ -188,10 +331,7 @@ def warp_implementation(self, bclist, bc_mask, missing_mask, start_index=None): id_numbers[current_index : current_index + num_indices] = bc.id # Set is_interior flags - if bc.needs_padding: - is_interior[current_index : current_index + num_indices] = self.are_indices_in_interior(bc_indices, bc_mask[0].shape) - else: - is_interior[current_index : current_index + num_indices] = False + is_interior[current_index : current_index + num_indices] = self.are_indices_in_interior(bc_indices, grid_shape) current_index += num_indices @@ -199,26 +339,232 @@ def warp_implementation(self, bclist, bc_mask, missing_mask, start_index=None): # bc.__dict__.pop("indices", None) # Trim arrays to actual size - indices = indices[:, :current_index] - id_numbers = id_numbers[:current_index] - is_interior = is_interior[:current_index] + total_index = current_index + indices = indices[:, :total_index] + id_numbers = id_numbers[:total_index] + is_interior = is_interior[:total_index] # Convert to Warp arrays - wp_indices = wp.array(indices, dtype=wp.int32) - wp_id_numbers = wp.array(id_numbers, dtype=wp.uint8) - wp_is_interior = wp.array(is_interior, dtype=wp.bool) + def _to_wp_arrays(indices, id_numbers, is_interior, device=None): + return ( + wp.array(indices, dtype=wp.int32, device=device), + wp.array(id_numbers, dtype=wp.uint8, device=device), + wp.array(is_interior, dtype=wp.uint8, device=device), + ) + + if self.compute_backend == ComputeBackend.NEON: + grid = self.grid + ndevice = 1 if grid is None else grid.bk.get_num_devices() + + if ndevice == 1: + return _to_wp_arrays(indices, id_numbers, is_interior) + else: + # For multi-device, we need to split the indices across devices + wp_bc_indices = [] + wp_id_numbers = [] + wp_is_interior = [] + for i in range(ndevice): + device_name = grid.bk.get_device_name(i) + wp_bc_indices.append(wp.array(indices, dtype=wp.int32, device=device_name)) + 
wp_id_numbers.append(wp.array(id_numbers, dtype=wp.uint8, device=device_name)) + wp_is_interior.append(wp.array(is_interior, dtype=wp.uint8, device=device_name)) + return wp_bc_indices, wp_id_numbers, wp_is_interior + else: + return _to_wp_arrays(indices, id_numbers, is_interior) + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation(self, bclist, bc_mask, missing_mask, start_index=None): + # get the grid shape + grid_shape = self.helper_masker.get_grid_shape(bc_mask) + + # Find interior boundary conditions + bc_interior = self._find_bclist_interior(bclist, grid_shape) + + # Prepare the first kernel inputs for all items in boundary condition list + wp_bc_indices, wp_id_numbers, wp_is_interior = self._prepare_kernel_inputs(bclist, grid_shape, start_index) # Launch the warp kernel wp.launch( - self.warp_kernel, - dim=current_index, + self.warp_kernel["kernel_domain_bounds"], + dim=bc_mask.shape[1:], + inputs=[wp_bc_indices, wp_id_numbers, wp_is_interior, bc_mask, missing_mask, grid_shape], + ) + + # If there are no interior boundary conditions, skip the rest and return early + if not bc_interior: + return bc_mask, missing_mask + + # Prepare the second and third kernel inputs for only a subset of boundary conditions associated with the interior + # Note 1: the launch order of the following kernels is important here! + # Note 2: Due to race conditions, the two kernels cannot be fused together.
+ wp_bc_indices, wp_id_numbers, _ = self._prepare_kernel_inputs(bc_interior, grid_shape) + wp.launch( + self.warp_kernel["kernel_interior_missing_mask"], + dim=bc_mask.shape[1:], + inputs=[wp_bc_indices, bc_mask, missing_mask, grid_shape], + ) + wp.launch( + self.warp_kernel["kernel_interior_bc_mask"], + dim=bc_mask.shape[1:], inputs=[ - wp_indices, + wp_bc_indices, wp_id_numbers, - wp_is_interior, bc_mask, - missing_mask, ], ) return bc_mask, missing_mask + + def _construct_neon(self): + import neon + + # Use the warp functional for the NEON backend + functional_dict, _ = self._construct_warp() + functional_domain_bounds = functional_dict.get("functional_domain_bounds") + functional_interior_bc_mask = functional_dict.get("functional_interior_bc_mask") + functional_interior_missing_mask = functional_dict.get("functional_interior_missing_mask") + + @neon.Container.factory(name="IndicesBoundaryMasker_DomainBounds") + def container_domain_bounds( + wp_bc_indices_, + wp_id_numbers_, + wp_is_interior_, + bc_mask, + missing_mask, + grid_shape, + ): + def domain_bounds_launcher(loader: neon.Loader): + loader.set_grid(bc_mask.get_grid()) + bc_mask_pn = loader.get_write_handle(bc_mask) + missing_mask_pn = loader.get_write_handle(missing_mask) + grid = bc_mask.get_grid() + bk = grid.backend + if bk.get_num_devices() == 1: + # If there is only one device, we can use the warp arrays directly + wp_bc_indices = wp_bc_indices_ + wp_id_numbers = wp_id_numbers_ + wp_is_interior = wp_is_interior_ + else: + device_id = loader.get_device_id() + wp_bc_indices = wp_bc_indices_[device_id] + wp_id_numbers = wp_id_numbers_[device_id] + wp_is_interior = wp_is_interior_[device_id] + + @wp.func + def domain_bounds_kernel(index: Any): + # apply the functional + functional_domain_bounds( + index, + wp_bc_indices, + wp_id_numbers, + wp_is_interior, + bc_mask_pn, + missing_mask_pn, + grid_shape, + ) + + loader.declare_kernel(domain_bounds_kernel) + + return domain_bounds_launcher + + 
@neon.Container.factory(name="IndicesBoundaryMasker_InteriorBcMask") + def container_interior_bc_mask( + wp_bc_indices, + wp_id_numbers, + bc_mask, + ): + def interior_bc_mask_launcher(loader: neon.Loader): + loader.set_grid(bc_mask.get_grid()) + bc_mask_pn = loader.get_write_handle(bc_mask) + + @wp.func + def interior_bc_mask_kernel(index: Any): + # apply the functional + functional_interior_bc_mask( + index, + wp_bc_indices, + wp_id_numbers, + bc_mask_pn, + ) + + loader.declare_kernel(interior_bc_mask_kernel) + + return interior_bc_mask_launcher + + @neon.Container.factory(name="IndicesBoundaryMasker_InteriorMissingMask") + def container_interior_missing_mask( + wp_bc_indices, + bc_mask, + missing_mask, + grid_shape, + ): + def interior_bc_mask_launcher(loader: neon.Loader): + loader.set_grid(bc_mask.get_grid()) + bc_mask_pn = loader.get_write_handle(bc_mask) + missing_mask_pn = loader.get_write_handle(missing_mask) + + @wp.func + def interior_missing_mask_kernel(index: Any): + # apply the functional + functional_interior_missing_mask( + index, + wp_bc_indices, + bc_mask_pn, + missing_mask_pn, + grid_shape, + ) + + loader.declare_kernel(interior_missing_mask_kernel) + + return interior_bc_mask_launcher + + container_dict = { + "container_domain_bounds": container_domain_bounds, + "container_interior_bc_mask": container_interior_bc_mask, + "container_interior_missing_mask": container_interior_missing_mask, + } + + return functional_dict, container_dict + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, bclist, bc_mask, missing_mask, start_index=None): + # get the grid shape + grid_shape = self.helper_masker.get_grid_shape(bc_mask) + + # Find interior boundary conditions + bc_interior = self._find_bclist_interior(bclist, grid_shape) + + # Prepare the first kernel inputs for all items in boundary condition list + wp_bc_indices, wp_id_numbers, wp_is_interior = self._prepare_kernel_inputs(bclist, grid_shape, start_index) + + # Launch 
the first container + container_domain_bounds = self.neon_container["container_domain_bounds"]( + wp_bc_indices, + wp_id_numbers, + wp_is_interior, + bc_mask, + missing_mask, + grid_shape, + ) + container_domain_bounds.run(0, container_runtime=neon.Container.ContainerRuntime.neon) + + # If there are no interior boundary conditions, skip the rest and return early + if not bc_interior: + return bc_mask, missing_mask + + # Prepare the second and third kernel inputs for only a subset of boundary conditions associated with the interior + # Note 1: the launch order of the following kernels is important here! + # Note 2: Due to race conditions, the two kernels cannot be fused together. + wp_bc_indices, wp_id_numbers, _ = self._prepare_kernel_inputs(bc_interior, grid_shape) + container_interior_missing_mask = self.neon_container["container_interior_missing_mask"](wp_bc_indices, bc_mask, missing_mask, grid_shape) + container_interior_missing_mask.run(0, container_runtime=neon.Container.ContainerRuntime.neon) + + # Launch the third container + container_interior_bc_mask = self.neon_container["container_interior_bc_mask"]( + wp_bc_indices, + wp_id_numbers, + bc_mask, + ) + container_interior_bc_mask.run(0, container_runtime=neon.Container.ContainerRuntime.neon) + + return bc_mask, missing_mask diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py index 40dd0311..c6fb778d 100644 --- a/xlb/operator/boundary_masker/mesh_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -1,63 +1,71 @@ -# Base class for all equilibriums +""" +Abstract base class for mesh-based boundary maskers. + +Provides shared input preparation logic (mesh construction, kernel arrays) +used by AABB, Ray, Winding, and AABB-Close masker subclasses.
+""" import numpy as np import warp as wp -import jax +from typing import Any from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy from xlb.compute_backend import ComputeBackend from xlb.operator.operator import Operator +from xlb.operator.boundary_masker.helper_functions_masker import HelperFunctionsMasker class MeshBoundaryMasker(Operator): """ - Operator for creating a boundary missing_mask from an STL file + Operator for creating a boundary missing_mask from a mesh file """ def __init__( self, - velocity_set: VelocitySet, - precision_policy: PrecisionPolicy, - compute_backend: ComputeBackend.WARP, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, ): # Call super super().__init__(velocity_set, precision_policy, compute_backend) + assert self.compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON], ( + f"MeshBoundaryMasker is only implemented for {ComputeBackend.WARP} and {ComputeBackend.NEON} backends!" + ) + + assert self.velocity_set.d == 3, "MeshBoundaryMasker is only implemented for 3D velocity sets!" # Raise error if used for 2d examples: if self.velocity_set.d == 2: raise NotImplementedError("This Operator is not implemented in 2D!") - # Also using Warp kernels for JAX implementation - if self.compute_backend == ComputeBackend.JAX: - self.warp_functional, self.warp_kernel = self._construct_warp() - - @Operator.register_backend(ComputeBackend.JAX) - def jax_implementation( - self, - bc, - bc_mask, - missing_mask, - ): - raise NotImplementedError(f"Operation {self.__class__.__name} not implemented in JAX!") - # Use Warp backend even for this particular operation. 
- wp.init() - bc_mask = wp.from_jax(bc_mask) - missing_mask = wp.from_jax(missing_mask) - bc_mask, missing_mask = self.warp_implementation(bc, bc_mask, missing_mask) - return wp.to_jax(bc_mask), wp.to_jax(missing_mask) - - def _construct_warp(self): # Make constants for warp - _c_float = self.velocity_set.c_float - _q = wp.constant(self.velocity_set.q) - _opp_indices = self.velocity_set.opp_indices + _c = self.velocity_set.c + _q = self.velocity_set.q + + if self.compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON]: + # Define masker helper functions + self.helper_masker = HelperFunctionsMasker( + velocity_set=self.velocity_set, + precision_policy=self.precision_policy, + compute_backend=self.compute_backend, + ) @wp.func - def index_to_position(index: wp.vec3i): - # position of the point - ijk = wp.vec3(wp.float32(index[0]), wp.float32(index[1]), wp.float32(index[2])) - pos = ijk + wp.vec3(0.5, 0.5, 0.5) # cell center - return pos + def out_of_bound_pull_index( + lattice_dir: wp.int32, + index: wp.vec3i, + field: wp.array4d(dtype=wp.uint8), + grid_shape: wp.vec3i, + ): + # Get the index of the streaming direction + pull_index = wp.vec3i() + for d in range(self.velocity_set.d): + pull_index[d] = index[d] - _c[d, lattice_dir] + + # check if pull index is out of bound + # These directions will have missing information after streaming + missing = not self.helper_masker.is_in_bounds(pull_index, grid_shape) + return missing # Function to precompute useful values per triangle, assuming spacing is (1,1,1) # inputs: verts: triangle vertices, normal: triangle unit normal @@ -78,6 +86,7 @@ def pre_compute( dist_edge = wp.mat33f(0.0) for axis0 in range(0, 3): + axis1 = (axis0 + 1) % 3 axis2 = (axis0 + 2) % 3 sgn = 1.0 @@ -85,21 +94,18 @@ def pre_compute( sgn = -1.0 for i in range(0, 3): - normal_edge0[i][axis0] = -1.0 * sgn * edges[i][axis0] - normal_edge1[i][axis0] = sgn * edges[i][axis0] + normal_edge0[i, axis0] = -1.0 * sgn * edges[i, axis1] + normal_edge1[i, 
axis0] = sgn * edges[i, axis0] - dist_edge[i][axis0] = ( - -1.0 * (normal_edge0[i][axis0] * verts[i][axis0] + normal_edge1[i][axis0] * verts[i][axis0]) - + wp.max(0.0, normal_edge0[i][axis0]) - + wp.max(0.0, normal_edge1[i][axis0]) + dist_edge[i, axis0] = ( + -1.0 * (normal_edge0[i, axis0] * verts[i, axis0] + normal_edge1[i, axis0] * verts[i, axis1]) + + wp.max(0.0, normal_edge0[i, axis0]) + + wp.max(0.0, normal_edge1[i, axis0]) ) return dist1, dist2, normal_edge0, normal_edge1, dist_edge # Check whether this triangle intersects the unit cube at position low - # inputs: low: position of the cube, normal: triangle unit normal, dist1, dist2, normal_edge0, normal_edge1, dist_edge: precomputed values - # outputs: True if intersection, False otherwise - # reference: Fast parallel surface and solid voxelization on GPUs, M. Schwarz, H-P. Siedel, https://dl.acm.org/doi/10.1145/1882261.1866201 @wp.func def triangle_box_intersect( low: wp.vec3f, @@ -116,7 +122,7 @@ def triangle_box_intersect( for ax0 in range(0, 3): ax1 = (ax0 + 1) % 3 for i in range(0, 3): - intersect = intersect and (normal_edge0[i][ax0] * low[ax0] + normal_edge1[i][ax0] * low[ax1] + dist_edge[i][ax0] >= 0.0) + intersect = intersect and (normal_edge0[i, ax0] * low[ax0] + normal_edge1[i, ax0] * low[ax1] + dist_edge[i, ax0] >= 0.0) return intersect else: @@ -147,15 +153,11 @@ def mesh_voxel_intersect(mesh_id: wp.uint64, low: wp.vec3): return False - # Construct the warp kernel - # Do voxelization mesh query (warp.mesh_query_aabb) to find solid voxels - # - this gives an approximate 1 voxel thick surface around mesh @wp.kernel - def kernel( - mesh_id: wp.uint64, + def resolve_out_of_bound_kernel( id_number: wp.int32, bc_mask: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), + missing_mask: wp.array4d(dtype=wp.uint8), ): # get index i, j, k = wp.tid() @@ -163,74 +165,81 @@ def kernel( # Get local indices index = wp.vec3i(i, j, k) - # position of the point - pos_bc_cell = 
index_to_position(index) - half = wp.vec3(0.5, 0.5, 0.5) + # domain shape to check for out of bounds + grid_shape = wp.vec3i(bc_mask.shape[1], bc_mask.shape[2], bc_mask.shape[3]) - if mesh_voxel_intersect(mesh_id=mesh_id, low=pos_bc_cell - half): - # Make solid voxel - bc_mask[0, index[0], index[1], index[2]] = wp.uint8(255) - else: - # Find the fractional distance to the mesh in each direction + # Find the fractional distance to the mesh in each direction + if bc_mask[0, index[0], index[1], index[2]] == wp.uint8(id_number): for l in range(1, _q): - _dir = wp.vec3f(_c_float[0, l], _c_float[1, l], _c_float[2, l]) - - # Check to see if this neighbor is solid - this is super inefficient TODO: make it way better - if mesh_voxel_intersect(mesh_id=mesh_id, low=pos_bc_cell + _dir - half): - # We know we have a solid neighbor - # Set the boundary id and missing_mask - bc_mask[0, index[0], index[1], index[2]] = wp.uint8(id_number) - missing_mask[_opp_indices[l], index[0], index[1], index[2]] = True + # Ensuring out of bound pull indices are properly considered in the missing_mask + if out_of_bound_pull_index(l, index, missing_mask, grid_shape): + missing_mask[l, index[0], index[1], index[2]] = wp.uint8(True) - return None, kernel + # Construct some helper warp functions + self.mesh_voxel_intersect = mesh_voxel_intersect + self.resolve_out_of_bound_kernel = resolve_out_of_bound_kernel - @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation( + def _prepare_kernel_inputs( self, bc, bc_mask, - missing_mask, ): assert bc.mesh_vertices is not None, f'Please provide the mesh vertices for {bc.__class__.__name__} BC using keyword "mesh_vertices"!' assert bc.indices is None, f"Please use IndicesBoundaryMasker operator if {bc.__class__.__name__} is imposed on known indices of the grid!" assert bc.mesh_vertices.shape[1] == self.velocity_set.d, ( "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!" 
) - mesh_vertices = bc.mesh_vertices - id_number = bc.id - # Check mesh extents against domain dimensions - domain_shape = bc_mask.shape[1:] # (nx, ny, nz) + grid_shape = self.helper_masker.get_grid_shape(bc_mask) # (nx, ny, nz) + mesh_vertices = bc.mesh_vertices mesh_min = np.min(mesh_vertices, axis=0) mesh_max = np.max(mesh_vertices, axis=0) - if any(mesh_min < 0) or any(mesh_max >= domain_shape): + if any(mesh_min < 0) or any(mesh_max >= grid_shape): raise ValueError( - f"Mesh extents ({mesh_min}, {mesh_max}) exceed domain dimensions {domain_shape}. The mesh must be fully contained within the domain." + f"Mesh extents ({mesh_min}, {mesh_max}) exceed domain dimensions {grid_shape}. The mesh must be fully contained within the domain." ) # We are done with bc.mesh_vertices. Remove them from BC objects bc.__dict__.pop("mesh_vertices", None) - # Ensure this masker is called only for BCs that need implicit distance to the mesh - assert not bc.needs_mesh_distance, 'Please use "MeshDistanceBoundaryMasker" if this BC needs mesh distance!' 
- mesh_indices = np.arange(mesh_vertices.shape[0]) mesh = wp.Mesh( points=wp.array(mesh_vertices, dtype=wp.vec3), - indices=wp.array(mesh_indices, dtype=int), + indices=wp.array(mesh_indices, dtype=wp.int32), ) + mesh_id = wp.uint64(mesh.id) + bc_id = bc.id + return mesh_id, bc_id + + @Operator.register_backend(ComputeBackend.JAX) + def jax_implementation( + self, + bc, + bc_mask, + missing_mask, + ): + raise NotImplementedError(f"Operation {self.__class__.__name__} not implemented in JAX!") - # Launch the warp kernel + def warp_implementation_base( + self, + bc, + distances, + bc_mask, + missing_mask, + ): + # Prepare inputs + mesh_id, bc_id = self._prepare_kernel_inputs(bc, bc_mask) + + # Launch the appropriate warp kernel wp.launch( self.warp_kernel, - inputs=[ - mesh.id, - id_number, - bc_mask, - missing_mask, - ], + inputs=[mesh_id, bc_id, distances, bc_mask, missing_mask, wp.static(bc.needs_mesh_distance)], dim=bc_mask.shape[1:], ) - - return bc_mask, missing_mask + wp.launch( + self.resolve_out_of_bound_kernel, + inputs=[bc_id, bc_mask, missing_mask], + dim=bc_mask.shape[1:], + ) + return distances, bc_mask, missing_mask diff --git a/xlb/operator/boundary_masker/mesh_voxelization_method.py b/xlb/operator/boundary_masker/mesh_voxelization_method.py new file mode 100644 index 00000000..b0162de7 --- /dev/null +++ b/xlb/operator/boundary_masker/mesh_voxelization_method.py @@ -0,0 +1,55 @@ +""" +Mesh voxelization method registry. + +Defines the available voxelization strategies (AABB, Ray, AABB-Close, +Winding) and provides a factory function to create the corresponding +:class:`VoxelizationMethod` data object. +""" + +from dataclasses import dataclass + + +METHODS = { + "AABB": 1, + "RAY": 2, + "AABB_CLOSE": 3, + "WINDING": 4, +} + + +@dataclass +class VoxelizationMethod: + """Describes a mesh voxelization strategy. + + Attributes + ---------- + id : int + Numeric identifier for the method. + name : str + Human-readable name (``"AABB"``, ``"RAY"``, etc.). 
+ options : dict + Extra options (e.g. ``close_voxels`` for AABB_CLOSE). + """ + + id: int + name: str + options: dict + + +def MeshVoxelizationMethod(name: str, **options): + """Create a :class:`VoxelizationMethod` by name. + + Parameters + ---------- + name : str + One of ``"AABB"``, ``"RAY"``, ``"AABB_CLOSE"``, ``"WINDING"``. + **options + Additional keyword arguments forwarded to + ``VoxelizationMethod.options``. + + Returns + ------- + VoxelizationMethod + """ + assert name in METHODS.keys(), f"Unsupported voxelization method: {name}" + return VoxelizationMethod(METHODS[name], name, options) diff --git a/xlb/operator/boundary_masker/multires_aabb.py b/xlb/operator/boundary_masker/multires_aabb.py new file mode 100644 index 00000000..f9cc5887 --- /dev/null +++ b/xlb/operator/boundary_masker/multires_aabb.py @@ -0,0 +1,99 @@ +""" +Multi-resolution AABB mesh-based boundary masker for the Neon backend. +""" + +import warp as wp +from typing import Any +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.boundary_masker import MeshMaskerAABB +from xlb.operator.operator import Operator + + +class MultiresMeshMaskerAABB(MeshMaskerAABB): + """ + Operator for creating boundary missing_mask from mesh using Axis-Aligned Bounding Box (AABB) voxelization in multiresolution simulations. + + This implementation uses warp.mesh_query_aabb for efficient mesh-voxel intersection testing, + providing approximate 1-voxel thick surface detection around the mesh geometry. + Suitable for scenarios where fast, approximate boundary detection is sufficient. + TODO@Hesam: + Right now, we cannot properly mask a mesh file if it lives on any level other than the finest. 
This issue can be easily solved by adding + gx = wp.neon_get_x(cIdx) // 2 ** level + gy = wp.neon_get_y(cIdx) // 2 ** level + gz = wp.neon_get_z(cIdx) // 2 ** level + to the "neon_index_to_warp" and subsequently add "level" to the arguments of "index_to_position_neon", "get_pull_index_neon" and + "is_in_bc_indices_neon". In order to extract "level" from the "neon_field_hdl" we can use the function wp.neon_level(neon_field_hdl). + """ + + def __init__( + self, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + ): + # Call super + super().__init__(velocity_set, precision_policy, compute_backend) + if self.compute_backend in [ComputeBackend.JAX, ComputeBackend.WARP]: + raise NotImplementedError(f"Operator {self.__class__.__name__} not supported in {self.compute_backend} backend.") + + def _construct_neon(self): + # Use the warp functional for the NEON backend + functional, _ = self._construct_warp() + + @neon.Container.factory(name="MeshMaskerAABB") + def container( + mesh_id: Any, + id_number: Any, + distances: Any, + bc_mask: Any, + missing_mask: Any, + needs_mesh_distance: Any, + level: Any, + ): + def aabb_launcher(loader: neon.Loader): + loader.set_mres_grid(bc_mask.get_grid(), level) + distances_pn = loader.get_mres_write_handle(distances) + bc_mask_pn = loader.get_mres_write_handle(bc_mask) + missing_mask_pn = loader.get_mres_write_handle(missing_mask) + + @wp.func + def aabb_kernel(index: Any): + # apply the functional + functional( + index, + mesh_id, + id_number, + distances_pn, + bc_mask_pn, + missing_mask_pn, + needs_mesh_distance, + ) + + loader.declare_kernel(aabb_kernel) + + return aabb_launcher + + return functional, container + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation( + self, + bc, + distances, + bc_mask, + missing_mask, + stream=0, + ): + import neon + + # Prepare inputs + mesh_id, bc_id = self._prepare_kernel_inputs(bc, bc_mask) + + grid = 
bc_mask.get_grid() + for level in range(grid.num_levels): + # Launch the neon container + c = self.neon_container(mesh_id, bc_id, distances, bc_mask, missing_mask, wp.static(bc.needs_mesh_distance), level) + c.run(stream, container_runtime=neon.Container.ContainerRuntime.neon) + return distances, bc_mask, missing_mask diff --git a/xlb/operator/boundary_masker/multires_aabb_close.py b/xlb/operator/boundary_masker/multires_aabb_close.py new file mode 100644 index 00000000..df461706 --- /dev/null +++ b/xlb/operator/boundary_masker/multires_aabb_close.py @@ -0,0 +1,273 @@ +""" +Multi-resolution AABB-Close boundary masker with morphological closing. + +Extends the AABB-Close masker for Neon multi-resolution grids, applying +dilate-then-erode operations to fill narrow channels with solid voxels. +""" + +import warp as wp +from typing import Any +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.boundary_masker import MeshMaskerAABBClose +from xlb.operator.operator import Operator +from xlb.cell_type import BC_SOLID + + +class MultiresMeshMaskerAABBClose(MeshMaskerAABBClose): + """ + Operator for creating boundary missing_mask from mesh using Axis-Aligned Bounding Box (AABB) voxelization + in multiresolution simulations (NEON backend). It takes in a number of close_voxels to perform morphological + operations (dilate followed by erode) to ensure small channels are filled with solid voxels. + + This version provides NEON-specific functionals working on multires partitions (mPartition) and bIndex. 
+ """ + + def __init__( + self, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + close_voxels: int = None, + ): + super().__init__(velocity_set, precision_policy, compute_backend, close_voxels) + if self.compute_backend in [ComputeBackend.JAX, ComputeBackend.WARP]: + raise NotImplementedError(f"Operator {self.__class__.__name__} not supported in {self.compute_backend} backend.") + + # Build and store NEON dicts + self.neon_functional_dict, self.neon_container_dict = self._construct_neon() + + def _construct_neon(self): + import neon + + # Use the warp functionals from the base (for reference), but implement NEON variants here + functional_dict_warp, _ = self._construct_warp() + functional_erode_warp = functional_dict_warp.get("functional_erode") + functional_dilate_warp = functional_dict_warp.get("functional_dilate") + functional_solid = functional_dict_warp.get("functional_solid") + # We will not directly reuse functional_solid / functional_aabb from warp; we write NEON-specific ones. 
+ + # We also need lattice info for neighbor iteration + _c = self.velocity_set.c + _q = self.velocity_set.q + _opp_indices = self.velocity_set.opp_indices + + # Set local constants + lattice_central_index = self.velocity_set.center_index + + # Main AABB close: sets bc_mask, missing_mask, distances based on solid_mask + # bc_mask: wp.uint8, missing_mask: wp.uint8, distances: dtype from precision policy (float) + @wp.func + def mres_functional_aabb( + index: Any, + mesh_id: wp.uint64, + id_number: wp.int32, + distances_pn: Any, # mPartition(dtype=distance type), cardinality=_q + bc_mask_pn: Any, # mPartition_uint8, cardinality=1 + missing_mask_pn: Any, # mPartition_uint8, cardinality=_q + solid_mask_pn: Any, # mPartition_uint8, cardinality=1 + needs_mesh_distance: bool, + ): + # Cell center from bc_mask partition + cell_center = self.helper_masker.index_to_position(bc_mask_pn, index) + + # If already solid or bc, mark solid + solid_val = wp.neon_read(solid_mask_pn, index, 0) + bc_val = wp.neon_read(bc_mask_pn, index, 0) + if solid_val == wp.uint8(BC_SOLID) or bc_val == wp.uint8(BC_SOLID): + wp.neon_write(bc_mask_pn, index, 0, wp.uint8(BC_SOLID)) + return + + # loop lattice directions + for direction_idx in range(_q): + # skip central if provided by velocity set + if direction_idx == lattice_central_index: + continue + + # If neighbor index is valid at this resolution level + ngh = wp.neon_ngh_idx(wp.int8(_c[0, direction_idx]), wp.int8(_c[1, direction_idx]), wp.int8(_c[2, direction_idx])) + is_valid = wp.bool(False) + nval = wp.neon_read_ngh(solid_mask_pn, index, ngh, 0, wp.uint8(0), is_valid) + if is_valid: + if nval == wp.uint8(BC_SOLID): + # Found solid neighbor -> boundary cell + self.write_field(bc_mask_pn, index, 0, wp.uint8(id_number)) + self.write_field(missing_mask_pn, index, _opp_indices[direction_idx], wp.uint8(True)) + + if not needs_mesh_distance: + # No distance needed; continue to next direction + continue + + # Compute mesh distance along lattice 
direction + dir_vec = wp.vec3f( + wp.float32(_c[0, direction_idx]), + wp.float32(_c[1, direction_idx]), + wp.float32(_c[2, direction_idx]), + ) + max_length = wp.length(dir_vec) + # Avoid division by zero for any pathological dir (shouldn't happen) + norm_dir = dir_vec / (max_length if max_length > 0.0 else 1.0) + query = wp.mesh_query_ray(mesh_id, cell_center, norm_dir, 1.5 * max_length) + if query.result: + pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v) + dist = wp.length(pos_mesh - cell_center) - 0.5 * max_length + weight = dist / (max_length if max_length > 0.0 else 1.0) + # distances has cardinality _q; store into this channel + self.write_field(distances_pn, index, direction_idx, self.store_dtype(weight)) + else: + self.write_field(distances_pn, index, direction_idx, self.store_dtype(1.0)) + + # Containers + + # Erode: f_field -> f_field_out + @neon.Container.factory(name="Erode") + def container_erode(f_field: wp.array3d(dtype=Any), f_field_out: wp.array3d(dtype=Any), level: int): + def erode_launcher(loader: neon.Loader): + loader.set_mres_grid(f_field.get_grid(), level) + f_field_pn = loader.get_mres_read_handle(f_field) + f_field_out_pn = loader.get_mres_write_handle(f_field_out) + + @wp.func + def erode_kernel(index: Any): + functional_erode_warp(index, f_field_pn, f_field_out_pn) + + loader.declare_kernel(erode_kernel) + + return erode_launcher + + # Dilate: f_field -> f_field_out + @neon.Container.factory(name="Dilate") + def container_dilate(f_field: wp.array3d(dtype=Any), f_field_out: wp.array3d(dtype=Any), level: int): + def dilate_launcher(loader: neon.Loader): + loader.set_mres_grid(f_field.get_grid(), level) + f_field_pn = loader.get_mres_read_handle(f_field) + f_field_out_pn = loader.get_mres_write_handle(f_field_out) + + @wp.func + def dilate_kernel(index: Any): + functional_dilate_warp(index, f_field_pn, f_field_out_pn) + + loader.declare_kernel(dilate_kernel) + + return dilate_launcher + + # Solid mask: voxelize mesh 
into solid_mask + @neon.Container.factory(name="Solid") + def container_solid(mesh_id: wp.uint64, solid_mask: wp.array3d(dtype=wp.uint8), level: int): + def solid_launcher(loader: neon.Loader): + loader.set_mres_grid(solid_mask.get_grid(), level) + solid_mask_pn = loader.get_mres_write_handle(solid_mask) + + @wp.func + def solid_kernel(index: Any): + # apply the functional + functional_solid(index, mesh_id, solid_mask_pn, wp.vec3f(0.0, 0.0, 0.0)) + + loader.declare_kernel(solid_kernel) + + return solid_launcher + + # Main AABB container + @neon.Container.factory(name="MeshMaskerAABBClose") + def container( + mesh_id: Any, + id_number: Any, + distances: Any, + bc_mask: Any, + missing_mask: Any, + solid_mask: Any, + needs_mesh_distance: Any, + level: Any, + ): + def aabb_launcher(loader: neon.Loader): + loader.set_mres_grid(bc_mask.get_grid(), level) + distances_pn = loader.get_mres_write_handle(distances) + bc_mask_pn = loader.get_mres_write_handle(bc_mask) + missing_mask_pn = loader.get_mres_write_handle(missing_mask) + solid_mask_pn = loader.get_mres_write_handle(solid_mask) + + @wp.func + def aabb_kernel(index: Any): + mres_functional_aabb( + index, + mesh_id, + id_number, + distances_pn, + bc_mask_pn, + missing_mask_pn, + solid_mask_pn, + needs_mesh_distance, + ) + + loader.declare_kernel(aabb_kernel) + + return aabb_launcher + + container_dict = { + "container_erode": container_erode, + "container_dilate": container_dilate, + "container_solid": container_solid, + "container_aabb": container, + } + + # Expose NEON functionals too (in case callers want to reuse) + functional_dict = { + "mres_functional_aabb": mres_functional_aabb, + } + + return functional_dict, container_dict + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation( + self, + bc, + distances, + bc_mask, + missing_mask, + stream=0, + ): + # Prepare inputs + mesh_id, bc_id = self._prepare_kernel_inputs(bc, bc_mask) + + grid = bc_mask.get_grid() + # Create fields using 
new_field
+        solid_mask = grid.new_field(cardinality=1, dtype=wp.uint8, memory_type=neon.MemoryType.device())
+        solid_mask_out = grid.new_field(
+            cardinality=1,
+            dtype=wp.uint8,
+            memory_type=neon.MemoryType.device(),
+            # memory_type=neon.MemoryType.host_device()
+        )
+
+        for level in range(grid.num_levels):
+            # Initialize to 0
+            solid_mask.fill_run(level=level, value=wp.uint8(0), stream_idx=stream)
+            solid_mask_out.fill_run(level=level, value=wp.uint8(0), stream_idx=stream)
+
+            # Launch the neon containers on the stream requested by the caller
+            container_solid = self.neon_container_dict["container_solid"](mesh_id, solid_mask, level)
+            container_solid.run(stream, container_runtime=neon.Container.ContainerRuntime.neon)
+
+            # Morphological closing: close_voxels dilations followed by close_voxels erosions.
+            # Every pass reads solid_mask and writes solid_mask_out; swapping the two names right
+            # after keeps the latest result in solid_mask for both odd and even close_voxels.
+            for _ in range(self.close_voxels):
+                container_dilate = self.neon_container_dict["container_dilate"](solid_mask, solid_mask_out, level)
+                container_dilate.run(stream, container_runtime=neon.Container.ContainerRuntime.neon)
+                solid_mask, solid_mask_out = solid_mask_out, solid_mask
+
+            for _ in range(self.close_voxels):
+                container_erode = self.neon_container_dict["container_erode"](solid_mask, solid_mask_out, level)
+                container_erode.run(stream, container_runtime=neon.Container.ContainerRuntime.neon)
+                solid_mask, solid_mask_out = solid_mask_out, solid_mask
+
+            # solid_mask now holds the closed solid region for this level
+            container_aabb = self.neon_container_dict["container_aabb"](
+                mesh_id, bc_id, distances, bc_mask, missing_mask, solid_mask, wp.static(bc.needs_mesh_distance), level
+            )
+            container_aabb.run(stream, container_runtime=neon.Container.ContainerRuntime.neon)
+
+        return distances, bc_mask, missing_mask
diff --git a/xlb/operator/boundary_masker/multires_indices_boundary_masker.py b/xlb/operator/boundary_masker/multires_indices_boundary_masker.py
new file mode 100644
index 00000000..bf7fc1d7
--- /dev/null
+++
b/xlb/operator/boundary_masker/multires_indices_boundary_masker.py
@@ -0,0 +1,215 @@
+"""
+Multi-resolution indices-based boundary masker for the Neon backend.
+
+Creates boundary masks from explicit voxel indices on multi-resolution
+grids, computing missing-population masks for each tagged voxel.
+"""
+
+from typing import Any
+import copy
+import numpy as np
+
+import warp as wp
+
+from xlb.operator.operator import Operator
+from xlb.velocity_set.velocity_set import VelocitySet
+from xlb.precision_policy import PrecisionPolicy
+from xlb.compute_backend import ComputeBackend
+from xlb.operator.boundary_masker import IndicesBoundaryMasker
+
+
+class MultiresIndicesBoundaryMasker(IndicesBoundaryMasker):
+    """
+    Operator for creating a boundary mask using indices of boundary conditions in a multi-resolution setting.
+    """
+
+    def __init__(
+        self,
+        velocity_set: VelocitySet = None,
+        precision_policy: PrecisionPolicy = None,
+        compute_backend: ComputeBackend = None,
+    ):
+        # Call super
+        super().__init__(velocity_set, precision_policy, compute_backend)
+        if self.compute_backend in [ComputeBackend.JAX, ComputeBackend.WARP]:
+            raise NotImplementedError(f"Operator {self.__class__.__name__} not supported in {self.compute_backend} backend.")
+
+    def _construct_neon(self):
+        """Wrap the Warp functionals in Neon multi-resolution containers."""
+        # `neon` is not imported at module level; bind it here (as `neon_implementation` does)
+        # because the decorators and annotations below reference it.
+        import neon
+
+        # Use the warp functional for the NEON backend
+        functional_dict, _ = self._construct_warp()
+        functional_domain_bounds = functional_dict.get("functional_domain_bounds")
+        functional_interior_bc_mask = functional_dict.get("functional_interior_bc_mask")
+        functional_interior_missing_mask = functional_dict.get("functional_interior_missing_mask")
+
+        @neon.Container.factory(name="IndicesBoundaryMasker_DomainBounds")
+        def container_domain_bounds(
+            wp_bc_indices,
+            wp_id_numbers,
+            wp_is_interior,
+            bc_mask,
+            missing_mask,
+            grid_shape,
+            level,
+        ):
+            def domain_bounds_launcher(loader: neon.Loader):
+                loader.set_mres_grid(bc_mask.get_grid(), level)
+                bc_mask_pn = loader.get_mres_write_handle(bc_mask)
+ missing_mask_pn = loader.get_mres_write_handle(missing_mask) + + @wp.func + def domain_bounds_kernel(index: Any): + # apply the functional + functional_domain_bounds( + index, + wp_bc_indices, + wp_id_numbers, + wp_is_interior, + bc_mask_pn, + missing_mask_pn, + grid_shape, + level, + ) + + loader.declare_kernel(domain_bounds_kernel) + + return domain_bounds_launcher + + @neon.Container.factory(name="IndicesBoundaryMasker_InteriorBcMask") + def container_interior_bc_mask( + wp_bc_indices, + wp_id_numbers, + bc_mask, + level, + ): + def interior_bc_mask_launcher(loader: neon.Loader): + loader.set_mres_grid(bc_mask.get_grid(), level) + bc_mask_pn = loader.get_mres_write_handle(bc_mask) + + @wp.func + def interior_bc_mask_kernel(index: Any): + # apply the functional + functional_interior_bc_mask( + index, + wp_bc_indices, + wp_id_numbers, + bc_mask_pn, + ) + + loader.declare_kernel(interior_bc_mask_kernel) + + return interior_bc_mask_launcher + + @neon.Container.factory(name="IndicesBoundaryMasker_InteriorMissingMask") + def container_interior_missing_mask( + wp_bc_indices, + bc_mask, + missing_mask, + grid_shape, + level, + ): + def interior_bc_mask_launcher(loader: neon.Loader): + loader.set_mres_grid(bc_mask.get_grid(), level) + bc_mask_pn = loader.get_mres_write_handle(bc_mask) + missing_mask_pn = loader.get_mres_write_handle(missing_mask) + + @wp.func + def interior_missing_mask_kernel(index: Any): + # apply the functional + functional_interior_missing_mask( + index, + wp_bc_indices, + bc_mask_pn, + missing_mask_pn, + grid_shape, + level, + ) + + loader.declare_kernel(interior_missing_mask_kernel) + + return interior_bc_mask_launcher + + container_dict = { + "container_domain_bounds": container_domain_bounds, + "container_interior_bc_mask": container_interior_bc_mask, + "container_interior_missing_mask": container_interior_missing_mask, + } + + return functional_dict, container_dict + + @Operator.register_backend(ComputeBackend.NEON) + def 
neon_implementation(self, bclist, bc_mask, missing_mask, start_index=None): + import neon + + grid = bc_mask.get_grid() + num_levels = grid.num_levels + grid_shape_finest = self.helper_masker.get_grid_shape(bc_mask) + for level in range(num_levels): + # Create a copy of the boundary condition list for the current level if the indices at that level are not empty + bclist_at_level = [] + for bc in bclist: + if bc.indices is not None and bc.indices[level]: + bc_copy = copy.copy(bc) # shallow copy of the whole object + indices = copy.deepcopy(bc.indices[level]) # deep copy only the modified part + indices = np.array(indices) * 2**level # TODO: This is a hack + bc_copy.indices = tuple(indices.tolist()) # convert to tuple + bclist_at_level.append(bc_copy) + + # If the boundary condition list is empty, skip to the next level + if not bclist_at_level: + continue + + # find grid shape at current level + # TODO: this is a hack. Should be corrected in the helper function when getting neon global indices + grid_shape_at_level = tuple([shape // 2**level for shape in grid_shape_finest]) + grid_shape_finest_warp = wp.vec3i(*grid_shape_finest) + + # find interior boundary conditions + bc_interior = self._find_bclist_interior(bclist_at_level, grid_shape_at_level) + + # Prepare the first kernel inputs for all items in boundary condition list + wp_bc_indices, wp_id_numbers, wp_is_interior = self._prepare_kernel_inputs(bclist_at_level, grid_shape_at_level) + + # Launch the first container + container_domain_bounds = self.neon_container["container_domain_bounds"]( + wp_bc_indices, + wp_id_numbers, + wp_is_interior, + bc_mask, + missing_mask, + grid_shape_finest_warp, + level, + ) + container_domain_bounds.run(0, container_runtime=neon.Container.ContainerRuntime.neon) + + # If there are no interior boundary conditions, skip the rest of the processing for this level + if not bc_interior: + continue + + # Prepare the second and third kernel inputs for only a subset of boundary conditions 
associated with the interior + # Note 1: launching order of the following kernels are important here! + # Note 2: Due to race conditioning, the two kernels cannot be fused together. + wp_bc_indices, wp_id_numbers, _ = self._prepare_kernel_inputs(bc_interior, grid_shape_at_level) + container_interior_missing_mask = self.neon_container["container_interior_missing_mask"]( + wp_bc_indices, + bc_mask, + missing_mask, + grid_shape_finest_warp, + level, + ) + container_interior_missing_mask.run(0, container_runtime=neon.Container.ContainerRuntime.neon) + + # Launch the third container + container_interior_bc_mask = self.neon_container["container_interior_bc_mask"]( + wp_bc_indices, + wp_id_numbers, + bc_mask, + level, + ) + container_interior_bc_mask.run(0, container_runtime=neon.Container.ContainerRuntime.neon) + + return bc_mask, missing_mask diff --git a/xlb/operator/boundary_masker/multires_ray.py b/xlb/operator/boundary_masker/multires_ray.py new file mode 100644 index 00000000..9a5a01d8 --- /dev/null +++ b/xlb/operator/boundary_masker/multires_ray.py @@ -0,0 +1,90 @@ +""" +Multi-resolution ray-cast mesh-based boundary masker for the Neon backend. +""" + +import warp as wp +from typing import Any +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.boundary_masker import MeshMaskerRay +from xlb.operator.operator import Operator + + +class MultiresMeshMaskerRay(MeshMaskerRay): + """ + Operator for creating a boundary missing_mask from an STL file in multiresolution simulations. + + This implementation uses warp.mesh_query_ray for efficient mesh-voxel intersection testing. 
+    """
+
+    def __init__(
+        self,
+        velocity_set: VelocitySet = None,
+        precision_policy: PrecisionPolicy = None,
+        compute_backend: ComputeBackend = None,
+    ):
+        # Call super
+        super().__init__(velocity_set, precision_policy, compute_backend)
+        if self.compute_backend in [ComputeBackend.JAX, ComputeBackend.WARP]:
+            raise NotImplementedError(f"Operator {self.__class__.__name__} not supported in {self.compute_backend} backend.")
+
+    def _construct_neon(self):
+        """Wrap the Warp ray-cast functional in a per-level Neon container."""
+        # `neon` is not imported at module level; bind it here (as `neon_implementation` does)
+        # because the decorator and the `neon.Loader` annotation below reference it.
+        import neon
+
+        # Use the warp functional for the NEON backend
+        functional, _ = self._construct_warp()
+
+        @neon.Container.factory(name="MeshMaskerRay")
+        def container(
+            mesh_id: Any,
+            id_number: Any,
+            distances: Any,
+            bc_mask: Any,
+            missing_mask: Any,
+            needs_mesh_distance: Any,
+            level: Any,
+        ):
+            def ray_launcher(loader: neon.Loader):
+                loader.set_mres_grid(bc_mask.get_grid(), level)
+                distances_pn = loader.get_mres_write_handle(distances)
+                bc_mask_pn = loader.get_mres_write_handle(bc_mask)
+                missing_mask_pn = loader.get_mres_write_handle(missing_mask)
+
+                @wp.func
+                def ray_kernel(index: Any):
+                    # apply the functional
+                    functional(
+                        index,
+                        mesh_id,
+                        id_number,
+                        distances_pn,
+                        bc_mask_pn,
+                        missing_mask_pn,
+                        needs_mesh_distance,
+                    )
+
+                loader.declare_kernel(ray_kernel)
+
+            return ray_launcher
+
+        return functional, container
+
+    @Operator.register_backend(ComputeBackend.NEON)
+    def neon_implementation(
+        self,
+        bc,
+        distances,
+        bc_mask,
+        missing_mask,
+        stream=0,
+    ):
+        import neon
+
+        # Prepare inputs
+        mesh_id, bc_id = self._prepare_kernel_inputs(bc, bc_mask)
+
+        grid = bc_mask.get_grid()
+        for level in range(grid.num_levels):
+            # Launch the neon container
+            c = self.neon_container(mesh_id, bc_id, distances, bc_mask, missing_mask, wp.static(bc.needs_mesh_distance), level)
+            c.run(stream, container_runtime=neon.Container.ContainerRuntime.neon)
+        return distances, bc_mask, missing_mask
diff --git a/xlb/operator/boundary_masker/ray.py b/xlb/operator/boundary_masker/ray.py
new file mode 100644
index
00000000..f5a2d96b --- /dev/null +++ b/xlb/operator/boundary_masker/ray.py @@ -0,0 +1,175 @@ +""" +Ray-cast mesh-based boundary masker. + +Voxelizes a mesh file by casting rays along each lattice direction using +``warp.mesh_query_ray`` to detect surface crossings. +""" + +import warp as wp +from typing import Any +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.boundary_masker.mesh_boundary_masker import MeshBoundaryMasker +from xlb.operator.operator import Operator + + +class MeshMaskerRay(MeshBoundaryMasker): + """ + Operator for creating a boundary missing_mask from a mesh file + """ + + def __init__( + self, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + ): + # Call super + super().__init__(velocity_set, precision_policy, compute_backend) + + def _construct_warp(self): + # Make constants for warp + _c = self.velocity_set.c + _q = self.velocity_set.q + _opp_indices = self.velocity_set.opp_indices + + # Set local constants + lattice_central_index = self.velocity_set.center_index + + @wp.func + def functional( + index: Any, + mesh_id: Any, + id_number: Any, + distances: Any, + bc_mask: Any, + missing_mask: Any, + needs_mesh_distance: Any, + ): + # position of the point + cell_center_pos = self.helper_masker.index_to_position(bc_mask, index) + + # Find the fractional distance to the mesh in each direction + for direction_idx in range(_q): + if direction_idx == lattice_central_index: + # Skip the central index as it is not relevant for boundary masking + continue + + direction_vec = wp.vec3f(wp.float32(_c[0, direction_idx]), wp.float32(_c[1, direction_idx]), wp.float32(_c[2, direction_idx])) + # Max length depends on ray direction (diagonals are longer) + max_length = wp.length(direction_vec) + query = wp.mesh_query_ray(mesh_id, cell_center_pos, direction_vec / 
max_length, max_length) + if query.result: + # Set the boundary id and missing_mask + self.write_field(bc_mask, index, 0, wp.uint8(id_number)) + self.write_field(missing_mask, index, _opp_indices[direction_idx], wp.uint8(True)) + + # If we don't need the mesh distance, we can return early + if not needs_mesh_distance: + continue + + # get position of the mesh triangle that intersects with the ray + pos_mesh = wp.mesh_eval_position(mesh_id, query.face, query.u, query.v) + dist = wp.length(pos_mesh - cell_center_pos) + weight = self.store_dtype(dist / max_length) + self.write_field(distances, index, direction_idx, self.store_dtype(weight)) + + @wp.kernel + def kernel( + mesh_id: wp.uint64, + id_number: wp.int32, + distances: wp.array4d(dtype=Any), + bc_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.uint8), + needs_mesh_distance: bool, + ): + # get index + i, j, k = wp.tid() + + # Get local indices + index = wp.vec3i(i, j, k) + + # apply the functional + functional( + index, + mesh_id, + id_number, + distances, + bc_mask, + missing_mask, + needs_mesh_distance, + ) + + return functional, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation( + self, + bc, + distances, + bc_mask, + missing_mask, + ): + return self.warp_implementation_base( + bc, + distances, + bc_mask, + missing_mask, + ) + + def _construct_neon(self): + import neon + + # Use the warp functional for the NEON backend + functional, _ = self._construct_warp() + + @neon.Container.factory(name="MeshMaskerRay") + def container( + mesh_id: Any, + id_number: Any, + distances: Any, + bc_mask: Any, + missing_mask: Any, + needs_mesh_distance: Any, + ): + def ray_launcher(loader: neon.Loader): + loader.set_grid(bc_mask.get_grid()) + bc_mask_pn = loader.get_write_handle(bc_mask) + missing_mask_pn = loader.get_write_handle(missing_mask) + distances_pn = loader.get_write_handle(distances) + + @wp.func + def ray_kernel(index: Any): + # apply the functional + 
functional(
+                        index,
+                        mesh_id,
+                        id_number,
+                        distances_pn,
+                        bc_mask_pn,
+                        missing_mask_pn,
+                        needs_mesh_distance,
+                    )
+
+                loader.declare_kernel(ray_kernel)
+
+            return ray_launcher
+
+        return functional, container
+
+    @Operator.register_backend(ComputeBackend.NEON)
+    def neon_implementation(
+        self,
+        bc,
+        distances,
+        bc_mask,
+        missing_mask,
+    ):
+        # `neon` is only bound locally inside `_construct_neon`, so it must be imported here as
+        # well before `neon.Container.ContainerRuntime` is referenced below.
+        import neon
+
+        # Prepare inputs
+        mesh_id, bc_id = self._prepare_kernel_inputs(bc, bc_mask)
+
+        # Launch the appropriate neon container
+        c = self.neon_container(mesh_id, bc_id, distances, bc_mask, missing_mask, wp.static(bc.needs_mesh_distance))
+        c.run(0, container_runtime=neon.Container.ContainerRuntime.neon)
+        return distances, bc_mask, missing_mask
diff --git a/xlb/operator/boundary_masker/winding.py b/xlb/operator/boundary_masker/winding.py
new file mode 100644
index 00000000..1510f3a5
--- /dev/null
+++ b/xlb/operator/boundary_masker/winding.py
@@ -0,0 +1,115 @@
+"""
+Winding-number mesh-based boundary masker.
+
+Uses the generalized winding-number test (``warp.mesh_query_point``) to
+classify voxels as inside or outside the mesh, providing a
+solid-detection method even for non-watertight geometries.
+""" + +import warp as wp +from typing import Any +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.boundary_masker.mesh_boundary_masker import MeshBoundaryMasker +from xlb.operator.operator import Operator +from xlb.cell_type import BC_SOLID + + +class MeshMaskerWinding(MeshBoundaryMasker): + """ + Operator for creating a boundary missing_mask from a mesh file + """ + + def __init__( + self, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + ): + # Call super + super().__init__(velocity_set, precision_policy, compute_backend) + assert self.compute_backend != ComputeBackend.NEON, ( + 'MeshVoxelizationMethod("WINDING") is not implemented in Neon yet! Please use a different method of mesh voxelization!' + ) + + def _construct_warp(self): + # Make constants for warp + _c = self.velocity_set.c + _q = self.velocity_set.q + _opp_indices = self.velocity_set.opp_indices + + @wp.kernel + def kernel( + mesh_id: wp.uint64, + id_number: wp.int32, + distances: wp.array4d(dtype=Any), + bc_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.uint8), + needs_mesh_distance: bool, + ): + # get index + i, j, k = wp.tid() + + # Get local indices + index = wp.vec3i(i, j, k) + + # position of the point + pos_cell = self.helper_masker.index_to_position(bc_mask, index) + + # Compute the maximum length + max_length = wp.sqrt( + (wp.float32(bc_mask.shape[1])) ** 2.0 + (wp.float32(bc_mask.shape[2])) ** 2.0 + (wp.float32(bc_mask.shape[3])) ** 2.0 + ) + + # evaluate if point is inside mesh + query = wp.mesh_query_point_sign_winding_number(mesh_id, pos_cell, max_length) + if query.result: + # set point to be solid + if query.sign <= 0: # TODO: fix this + # Make solid voxel + bc_mask[0, index[0], index[1], index[2]] = wp.uint8(BC_SOLID) + + # Find the fractional distance to the mesh in 
each direction + for direction_idx in range(1, _q): + direction_vec = wp.vec3f(wp.float32(_c[0, direction_idx]), wp.float32(_c[1, direction_idx]), wp.float32(_c[2, direction_idx])) + # Max length depends on ray direction (diagonals are longer) + max_length = wp.length(direction_vec) + query_dir = wp.mesh_query_ray(mesh_id, pos_cell, direction_vec / max_length, max_length) + if query_dir.result: + # Get the index of the streaming direction + push_index = wp.vec3i() + for d in range(self.velocity_set.d): + push_index[d] = index[d] + _c[d, direction_idx] + + # Set the boundary id and missing_mask + bc_mask[0, push_index[0], push_index[1], push_index[2]] = wp.uint8(id_number) + missing_mask[direction_idx, push_index[0], push_index[1], push_index[2]] = wp.uint8(True) + + # If we don't need the mesh distance, we can return early + if not needs_mesh_distance: + continue + + # get position of the mesh triangle that intersects with the ray + pos_mesh = wp.mesh_eval_position(mesh_id, query_dir.face, query_dir.u, query_dir.v) + cell_center_pos = self.helper_masker.index_to_position(bc_mask, push_index) + dist = wp.length(pos_mesh - cell_center_pos) + weight = self.store_dtype(dist / max_length) + distances[_opp_indices[direction_idx], push_index[0], push_index[1], push_index[2]] = weight + + return None, kernel + + @Operator.register_backend(ComputeBackend.WARP) + def warp_implementation( + self, + bc, + distances, + bc_mask, + missing_mask, + ): + return self.warp_implementation_base( + bc, + distances, + bc_mask, + missing_mask, + ) diff --git a/xlb/operator/collision/bgk.py b/xlb/operator/collision/bgk.py index ac2da2e0..29331dc7 100644 --- a/xlb/operator/collision/bgk.py +++ b/xlb/operator/collision/bgk.py @@ -1,3 +1,7 @@ +""" +Bhatnagar-Gross-Krook (BGK) single-relaxation-time collision operator. +""" + import jax.numpy as jnp from jax import jit import warp as wp @@ -10,13 +14,19 @@ class BGK(Collision): - """ - BGK collision operator for LBM. 
+ """Single-relaxation-time BGK collision operator. + + Relaxes the distribution function toward equilibrium at a rate + controlled by the relaxation parameter *omega*:: + + f_out = f - omega * (f - f_eq) + + Supports JAX, Warp, and Neon backends. """ @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0,)) - def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u, omega): + def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, omega): fneq = f - feq fout = f - self.compute_dtype(omega) * fneq return fout @@ -28,7 +38,7 @@ def _construct_warp(self): # Construct the functional @wp.func - def functional(f: Any, feq: Any, rho: Any, u: Any, omega: Any): + def functional(f: Any, feq: Any, omega: Any): fneq = f - feq fout = f - self.compute_dtype(omega) * fneq return fout @@ -39,8 +49,6 @@ def kernel( f: wp.array4d(dtype=Any), feq: wp.array4d(dtype=Any), fout: wp.array4d(dtype=Any), - rho: wp.array4d(dtype=Any), - u: wp.array4d(dtype=Any), omega: Any, ): # Get the global index @@ -55,7 +63,7 @@ def kernel( _feq[l] = feq[l, index[0], index[1], index[2]] # Compute the collision - _fout = functional(_f, _feq, rho, u, omega) + _fout = functional(_f, _feq, omega) # Write the result for l in range(self.velocity_set.q): @@ -63,8 +71,12 @@ def kernel( return functional, kernel + def _construct_neon(self): + functional, _ = self._construct_warp() + return functional, None + @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, feq, fout, rho, u, omega): + def warp_implementation(self, f, feq, fout, omega): # Launch the warp kernel wp.launch( self.warp_kernel, @@ -72,8 +84,6 @@ def warp_implementation(self, f, feq, fout, rho, u, omega): f, feq, fout, - rho, - u, omega, ], dim=f.shape[1:], diff --git a/xlb/operator/collision/forced_collision.py b/xlb/operator/collision/forced_collision.py index 80c9b0b2..55fa631c 100644 --- a/xlb/operator/collision/forced_collision.py +++ 
b/xlb/operator/collision/forced_collision.py @@ -1,3 +1,7 @@ +""" +Collision operator with external body-force correction. +""" + import jax.numpy as jnp from jax import jit import warp as wp @@ -11,8 +15,19 @@ class ForcedCollision(Collision): - """ - A collision operator for LBM with external force. + """Collision operator that wraps another collision with a forcing term. + + After the inner collision the forcing operator is applied to + incorporate the effect of an external body force. + + Parameters + ---------- + collision_operator : Operator + The base collision operator (e.g. :class:`BGK`). + forcing_scheme : str + Forcing scheme. Currently only ``"exact_difference"`` is supported. + force_vector : array-like + External force vector of length ``d`` (number of spatial dimensions). """ def __init__( @@ -33,9 +48,9 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0,)) - def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho, u, omega): - fout = self.collision_operator(f, feq, rho, u, omega) - fout = self.forcing_operator(fout, feq, rho, u) + def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, omega): + fout = self.collision_operator(f, feq, omega) + fout = self.forcing_operator(fout, feq) return fout def _construct_warp(self): @@ -45,9 +60,9 @@ def _construct_warp(self): # Construct the functional @wp.func - def functional(f: Any, feq: Any, rho: Any, u: Any, omega: Any): - fout = self.collision_operator.warp_functional(f, feq, rho, u, omega) - fout = self.forcing_operator.warp_functional(fout, feq, rho, u) + def functional(f: Any, feq: Any, omega: Any): + fout = self.collision_operator.warp_functional(f, feq, omega) + fout = self.forcing_operator.warp_functional(fout, feq) return fout # Construct the warp kernel @@ -56,8 +71,6 @@ def kernel( f: wp.array4d(dtype=Any), feq: wp.array4d(dtype=Any), fout: wp.array4d(dtype=Any), - rho: wp.array4d(dtype=Any), - u: wp.array4d(dtype=Any), omega: Any, ): 
# Get the global index @@ -71,13 +84,9 @@ def kernel( for l in range(self.velocity_set.q): _f[l] = f[l, index[0], index[1], index[2]] _feq[l] = feq[l, index[0], index[1], index[2]] - _u = _u_vec() - for l in range(_d): - _u[l] = u[l, index[0], index[1], index[2]] - _rho = rho[0, index[0], index[1], index[2]] # Compute the collision - _fout = functional(_f, _feq, _rho, _u, omega) + _fout = functional(_f, _feq, omega) # Write the result for l in range(self.velocity_set.q): @@ -86,7 +95,7 @@ def kernel( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, feq, fout, rho, u, omega): + def warp_implementation(self, f, feq, fout, omega): # Launch the warp kernel wp.launch( self.warp_kernel, @@ -94,8 +103,6 @@ def warp_implementation(self, f, feq, fout, rho, u, omega): f, feq, fout, - rho, - u, omega, ], dim=f.shape[1:], diff --git a/xlb/operator/collision/kbc.py b/xlb/operator/collision/kbc.py index 7724e07e..d814a7cf 100644 --- a/xlb/operator/collision/kbc.py +++ b/xlb/operator/collision/kbc.py @@ -43,8 +43,6 @@ def jax_implementation( self, f: jnp.ndarray, feq: jnp.ndarray, - rho: jnp.ndarray, - u: jnp.ndarray, omega, ): """ @@ -56,18 +54,14 @@ def jax_implementation( Distribution function. feq : jax.numpy.array Equilibrium distribution function. - rho : jax.numpy.array - Density. - u : jax.numpy.array - Velocity. 
""" fneq = f - feq if isinstance(self.velocity_set, D2Q9): shear = self.decompose_shear_d2q9_jax(fneq) - delta_s = shear * rho / 4.0 + delta_s = shear / 4.0 elif isinstance(self.velocity_set, D3Q27): shear = self.decompose_shear_d3q27_jax(fneq) - delta_s = shear * rho + delta_s = shear else: raise NotImplementedError("Velocity set not supported: {}".format(type(self.velocity_set))) @@ -269,18 +263,16 @@ def compute_entropic_scalar_products( def functional( f: Any, feq: Any, - rho: Any, - u: Any, omega: Any, ): # Compute shear and delta_s fneq = f - feq if wp.static(self.velocity_set.d == 3): shear = decompose_shear_d3q27(fneq) - delta_s = shear * rho + delta_s = shear else: shear = decompose_shear_d2q9(fneq) - delta_s = shear * rho / self.compute_dtype(4.0) + delta_s = shear / self.compute_dtype(4.0) # Compute required constants based on the input omega (omega is the inverse relaxation time) _beta = self.compute_dtype(0.5) * self.compute_dtype(omega) @@ -301,8 +293,6 @@ def kernel( f: wp.array4d(dtype=Any), feq: wp.array4d(dtype=Any), fout: wp.array4d(dtype=Any), - rho: wp.array4d(dtype=Any), - u: wp.array4d(dtype=Any), omega: Any, ): # Get the global index @@ -316,13 +306,9 @@ def kernel( for l in range(self.velocity_set.q): _f[l] = f[l, index[0], index[1], index[2]] _feq[l] = feq[l, index[0], index[1], index[2]] - _u = _u_vec() - for l in range(_d): - _u[l] = u[l, index[0], index[1], index[2]] - _rho = rho[0, index[0], index[1], index[2]] # Compute the collision - _fout = functional(_f, _feq, _rho, _u, omega) + _fout = functional(_f, _feq, omega) # Write the result for l in range(self.velocity_set.q): @@ -330,8 +316,15 @@ def kernel( return functional, kernel + def _construct_neon(self): + # Redefine the momentum flux operator for the neon backend + # This is because the neon backend relies on the warp functionals for its operations. 
+ self.momentum_flux = MomentumFlux(compute_backend=ComputeBackend.WARP) + functional, _ = self._construct_warp() + return functional, None + @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, feq, fout, rho, u, omega): + def warp_implementation(self, f, feq, fout, omega): # Launch the warp kernel wp.launch( self.warp_kernel, @@ -339,8 +332,6 @@ def warp_implementation(self, f, feq, fout, rho, u, omega): f, feq, fout, - rho, - u, omega, ], dim=f.shape[1:], diff --git a/xlb/operator/collision/smagorinsky_les_bgk.py b/xlb/operator/collision/smagorinsky_les_bgk.py index 92ae23d2..aa552094 100644 --- a/xlb/operator/collision/smagorinsky_les_bgk.py +++ b/xlb/operator/collision/smagorinsky_les_bgk.py @@ -1,3 +1,7 @@ +""" +BGK collision operator with Smagorinsky large-eddy-simulation sub-grid model. +""" + import jax.numpy as jnp from jax import jit import warp as wp @@ -12,8 +16,19 @@ class SmagorinskyLESBGK(Collision): - """ - BGK collision operator for LBM with Smagorinsky LES model. + """BGK collision with Smagorinsky LES turbulence modelling. + + Adjusts the effective relaxation time based on the local strain rate + estimated from the non-equilibrium stress tensor, using the + Smagorinsky model constant *C_s*. + + Parameters + ---------- + velocity_set : VelocitySet, optional + precision_policy : PrecisionPolicy, optional + compute_backend : ComputeBackend, optional + smagorinsky_coef : float + Smagorinsky model constant (default 0.17). 
""" def __init__( @@ -28,7 +43,7 @@ def __init__( @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0,)) - def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray, u: jnp.ndarray, omega): + def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, omega): fneq = f - feq pi_neq = jnp.tensordot(self.velocity_set.cc, fneq, axes=(0, 0)) @@ -44,9 +59,7 @@ def jax_implementation(self, f: jnp.ndarray, feq: jnp.ndarray, rho: jnp.ndarray, tau0 = self.compute_dtype(1.0) / self.compute_dtype(omega) cs = self.compute_dtype(self.smagorinsky_coef) - tau = self.compute_dtype(0.5) * ( - tau0 + jnp.sqrt(tau0 * tau0 + self.compute_dtype(36.0) * (cs * cs) * jnp.sqrt(strain)) - ) + tau = self.compute_dtype(0.5) * (tau0 + jnp.sqrt(tau0 * tau0 + self.compute_dtype(36.0) * (cs * cs) * jnp.sqrt(strain))) omega_eff = self.compute_dtype(1.0) / tau fout = f - omega_eff[None, ...] * fneq @@ -67,38 +80,11 @@ def _construct_warp(self): def functional( f: Any, feq: Any, - rho: Any, - u: Any, omega: Any, ): # Compute the non-equilibrium distribution fneq = f - feq - # Sailfish implementation - # { - # float tmp, strain; - - # strain = 0.0f; - - # // Off-diagonal components, count twice for symmetry reasons. - # %for a in range(0, dim): - # %for b in range(a + 1, dim): - # tmp = ${cex(sym.ex_flux(grid, 'd0', a, b, config), pointers=True)} - - # ${cex(sym.ex_eq_flux(grid, a, b))}; - # strain += 2.0f * tmp * tmp; - # %endfor - # %endfor - - # // Diagonal components. 
- # %for a in range(0, dim): - # tmp = ${cex(sym.ex_flux(grid, 'd0', a, a, config), pointers=True)} - - # ${cex(sym.ex_eq_flux(grid, a, a))}; - # strain += tmp * tmp; - # %endfor - - # tau0 += 0.5f * (sqrtf(tau0 * tau0 + 36.0f * ${cex(smagorinsky_const**2)} * sqrtf(strain)) - tau0); - # } - # Compute strain pi_neq = _pi_vec() for a in range(_pi_dim): @@ -117,8 +103,7 @@ def functional( # Compute the Smagorinsky model _tau = self.compute_dtype(1.0) / self.compute_dtype(omega) tau = _tau + ( - self.compute_dtype(0.5) - * (wp.sqrt(_tau * _tau + self.compute_dtype(36.0) * (_smagorinsky_coef**2.0) * wp.sqrt(strain)) - _tau) + self.compute_dtype(0.5) * (wp.sqrt(_tau * _tau + self.compute_dtype(36.0) * (_smagorinsky_coef**2.0) * wp.sqrt(strain)) - _tau) ) # Compute the collision @@ -130,8 +115,6 @@ def functional( def kernel( f: wp.array4d(dtype=Any), feq: wp.array4d(dtype=Any), - rho: wp.array4d(dtype=Any), - u: wp.array4d(dtype=Any), fout: wp.array4d(dtype=Any), omega: wp.float32, ): @@ -145,13 +128,9 @@ def kernel( for l in range(self.velocity_set.q): _f[l] = f[l, index[0], index[1], index[2]] _feq[l] = feq[l, index[0], index[1], index[2]] - _u = _u_vec() - for l in range(_d): - _u[l] = u[l, index[0], index[1], index[2]] - _rho = rho[0, index[0], index[1], index[2]] # Compute the collision - _fout = functional(_f, _feq, _rho, _u, omega) + _fout = functional(_f, _feq, omega) # Write the result for l in range(self.velocity_set.q): @@ -160,15 +139,13 @@ def kernel( return functional, kernel @Operator.register_backend(ComputeBackend.WARP) - def warp_implementation(self, f, feq, rho, u, fout, omega): + def warp_implementation(self, f, feq, fout, omega): # Launch the warp kernel wp.launch( self.warp_kernel, inputs=[ f, feq, - rho, - u, fout, omega, ], diff --git a/xlb/operator/equilibrium/__init__.py b/xlb/operator/equilibrium/__init__.py index 987aa74a..beb7bb5e 100644 --- a/xlb/operator/equilibrium/__init__.py +++ b/xlb/operator/equilibrium/__init__.py @@ -1 +1,3 @@ -from 
xlb.operator.equilibrium.quadratic_equilibrium import Equilibrium, QuadraticEquilibrium +from xlb.operator.equilibrium.equilibrium import Equilibrium +from xlb.operator.equilibrium.quadratic_equilibrium import QuadraticEquilibrium +from xlb.operator.equilibrium.multires_quadratic_equilibrium import MultiresQuadraticEquilibrium diff --git a/xlb/operator/equilibrium/multires_quadratic_equilibrium.py b/xlb/operator/equilibrium/multires_quadratic_equilibrium.py new file mode 100644 index 00000000..0a3bb959 --- /dev/null +++ b/xlb/operator/equilibrium/multires_quadratic_equilibrium.py @@ -0,0 +1,75 @@ +""" +Multi-resolution quadratic equilibrium operator for the Neon backend. +""" + +import warp as wp +from typing import Any +from xlb.compute_backend import ComputeBackend +from xlb.operator.equilibrium import QuadraticEquilibrium +from xlb.operator import Operator + + +class MultiresQuadraticEquilibrium(QuadraticEquilibrium): + """Quadratic equilibrium operator for multi-resolution grids (Neon only). + + Computes the second-order Hermite-polynomial equilibrium distribution + from density and velocity at every active cell on each grid level. + Cells that have child refinement (halo cells) are zeroed out. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.compute_backend in [ComputeBackend.JAX, ComputeBackend.WARP]: + raise NotImplementedError(f"Operator {self.__class__.__name__} not supported in {self.compute_backend} backend.") + + def _construct_neon(self): + import neon + + # Use the warp functional for the NEON backend + functional, _ = self._construct_warp() + + # Set local constants TODO: This is a hack and should be fixed with warp update + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + + @neon.Container.factory(name="QuadraticEquilibrium") + def container( + rho: Any, + u: Any, + f: Any, + level, + ): + def quadratic_equilibrium_ll(loader: neon.Loader): + loader.set_mres_grid(rho.get_grid(), level) + + rho_pn = loader.get_mres_read_handle(rho) + u_pn = loader.get_mres_read_handle(u) + f_pn = loader.get_mres_write_handle(f) + + @wp.func + def quadratic_equilibrium_cl(index: Any): + _u = _u_vec() + for d in range(self.velocity_set.d): + _u[d] = self.compute_dtype(wp.neon_read(u_pn, index, d)) + _rho = self.compute_dtype(wp.neon_read(rho_pn, index, 0)) + feq = functional(_rho, _u) + + if wp.neon_has_child(f_pn, index): + for l in range(self.velocity_set.q): + feq[l] = self.compute_dtype(0.0) + # Set the output + for l in range(self.velocity_set.q): + wp.neon_write(f_pn, index, l, self.store_dtype(feq[l])) + + loader.declare_kernel(quadratic_equilibrium_cl) + + return quadratic_equilibrium_ll + + return functional, container + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, rho, u, f, stream=0): + grid = f.get_grid() + for level in range(grid.num_levels): + c = self.neon_container(rho, u, f, level) + c.run(stream, container_runtime=neon.Container.ContainerRuntime.neon) + return f diff --git a/xlb/operator/equilibrium/quadratic_equilibrium.py b/xlb/operator/equilibrium/quadratic_equilibrium.py index f3572d74..5ec14618 100644 --- 
a/xlb/operator/equilibrium/quadratic_equilibrium.py +++ b/xlb/operator/equilibrium/quadratic_equilibrium.py @@ -2,10 +2,12 @@ import jax.numpy as jnp from jax import jit import warp as wp +import os + from typing import Any from xlb.compute_backend import ComputeBackend -from xlb.operator.equilibrium.equilibrium import Equilibrium +from xlb.operator.equilibrium import Equilibrium from xlb.operator import Operator @@ -15,6 +17,9 @@ class QuadraticEquilibrium(Equilibrium): Standard equilibrium model for LBM. """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) def jax_implementation(self, rho, u): @@ -96,3 +101,48 @@ def warp_implementation(self, rho, u, f): dim=rho.shape[1:], ) return f + + def _construct_neon(self): + import neon + + # Use the warp functional for the NEON backend + functional, _ = self._construct_warp() + + # Set local constants TODO: This is a hack and should be fixed with warp update + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + + @neon.Container.factory(name="QuadraticEquilibrium") + def container( + rho: Any, + u: Any, + f: Any, + ): + def quadratic_equilibrium_ll(loader: neon.Loader): + loader.set_grid(rho.get_grid()) + rho_pn = loader.get_read_handle(rho) + u_pn = loader.get_read_handle(u) + f_pn = loader.get_write_handle(f) + + @wp.func + def quadratic_equilibrium_cl(index: typing.Any): + _u = _u_vec() + for d in range(self.velocity_set.d): + _u[d] = self.compute_dtype(wp.neon_read(u_pn, index, d)) + _rho = self.compute_dtype(wp.neon_read(rho_pn, index, 0)) + feq = functional(_rho, _u) + + # Set the output + for l in range(self.velocity_set.q): + wp.neon_write(f_pn, index, l, self.store_dtype(feq[l])) + + loader.declare_kernel(quadratic_equilibrium_cl) + + return quadratic_equilibrium_ll + + return functional, container + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, rho, u, f): 
+ c = self.neon_container(rho, u, f) + c.run(0, container_runtime=neon.Container.ContainerRuntime.neon) + return f diff --git a/xlb/operator/force/__init__.py b/xlb/operator/force/__init__.py index ba8a13c3..f3ceec57 100644 --- a/xlb/operator/force/__init__.py +++ b/xlb/operator/force/__init__.py @@ -1,2 +1,3 @@ from xlb.operator.force.momentum_transfer import MomentumTransfer from xlb.operator.force.exact_difference_force import ExactDifference +from xlb.operator.force.multires_momentum_transfer import MultiresMomentumTransfer diff --git a/xlb/operator/force/momentum_transfer.py b/xlb/operator/force/momentum_transfer.py index 1c6255d3..05952d7d 100644 --- a/xlb/operator/force/momentum_transfer.py +++ b/xlb/operator/force/momentum_transfer.py @@ -3,6 +3,7 @@ from jax import jit, lax import warp as wp from typing import Any +from enum import Enum, auto from xlb.velocity_set.velocity_set import VelocitySet from xlb.precision_policy import PrecisionPolicy @@ -11,6 +12,112 @@ from xlb.operator.stream import Stream +# Enum used to keep track of LBM operations +class LBMOperationSequence(Enum): + """ + Note that for dense and single resolution simulations in XLB, the order of operations in the stepper is "stream-then-collide". + For MultiRes stepper however the order of operations is always "collide-then-stream" except at the finest level when the FUSION_AT_FINEST + optimization is used. + In that case the order of operations is "stream-then-collide" ONLY at the finest level. + """ + + STREAM_THEN_COLLIDE = auto() + COLLIDE_THEN_STREAM = auto() + + +class FetchPopulations(Operator): + """ + This operator is used to get the post-collision and post-streaming populations + Note that for dense and single resolution simulations in XLB, the order of operations in the stepper is "stream-then-collide". 
+ Therefore, f_0 represents the post-collision values and post_streaming values of the current time step need to be reconstructed + by applying the streaming and boundary conditions. These populations are readily available in XLB when using multi-resolution + grids because the mres stepper relies on "collide-then-stream". + """ + + def __init__( + self, + no_slip_bc_instance, + operation_sequence: LBMOperationSequence = LBMOperationSequence.STREAM_THEN_COLLIDE, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + ): + self.no_slip_bc_instance = no_slip_bc_instance + self.stream = Stream(velocity_set, precision_policy, compute_backend) + self.operation_sequence = operation_sequence + + if compute_backend == ComputeBackend.WARP: + self.stream_functional = self.stream.warp_functional + self.bc_functional = self.no_slip_bc_instance.warp_functional + elif compute_backend == ComputeBackend.NEON: + self.stream_functional = self.stream.neon_functional + self.bc_functional = self.no_slip_bc_instance.neon_functional + + # Call the parent constructor + super().__init__( + velocity_set, + precision_policy, + compute_backend, + ) + + @Operator.register_backend(ComputeBackend.JAX) + @partial(jit, static_argnums=(0)) + def jax_implementation(self, f_0, f_1, bc_mask, missing_mask): + # Give the input post-collision populations, streaming once and apply the BC the find post-stream values. 
+ f_post_collision = f_0 + f_post_stream = self.stream(f_post_collision) + f_post_stream = self.no_slip_bc_instance(f_post_collision, f_post_stream, bc_mask, missing_mask) + return f_post_collision, f_post_stream + + def _construct_warp(self): + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + + @wp.func + def functional_stream_then_collide( + index: Any, + f_0: Any, + f_1: Any, + _missing_mask: Any, + ): + # Get the distribution function + f_post_collision = _f_vec() + for l in range(self.velocity_set.q): + f_post_collision[l] = self.compute_dtype(self.read_field(f_0, index, l)) + + # Apply streaming (pull method) + timestep = 0 + f_post_stream = self.stream_functional(f_0, index) + f_post_stream = self.bc_functional(index, timestep, _missing_mask, f_0, f_1, f_post_collision, f_post_stream) + return f_post_collision, f_post_stream + + @wp.func + def functional_collide_then_stream( + index: Any, + f_0: Any, + f_1: Any, + _missing_mask: Any, + ): + # Get the distribution function + f_post_collision = _f_vec() + f_post_stream = _f_vec() + for l in range(self.velocity_set.q): + f_post_stream[l] = self.compute_dtype(self.read_field(f_0, index, l)) + f_post_collision[l] = self.compute_dtype(self.read_field(f_1, index, l)) + return f_post_collision, f_post_stream + + if self.operation_sequence == LBMOperationSequence.STREAM_THEN_COLLIDE: + return functional_stream_then_collide, None + elif self.operation_sequence == LBMOperationSequence.COLLIDE_THEN_STREAM: + return functional_collide_then_stream, None + else: + raise ValueError(f"Unknown operation sequence: {self.operation_sequence}") + + def _construct_neon(self): + # Use the warp functional for the NEON backend + functional, _ = self._construct_warp() + return functional, None + + class MomentumTransfer(Operator): """ An opertor for the momentum exchange method to compute the boundary force vector exerted on the solid geometry @@ -34,12 +141,23 @@ class MomentumTransfer(Operator): def __init__( self, 
no_slip_bc_instance, + operation_sequence: LBMOperationSequence = LBMOperationSequence.STREAM_THEN_COLLIDE, velocity_set: VelocitySet = None, precision_policy: PrecisionPolicy = None, compute_backend: ComputeBackend = None, ): + # Assign the no-slip boundary condition instance self.no_slip_bc_instance = no_slip_bc_instance - self.stream = Stream(velocity_set, precision_policy, compute_backend) + self.operation_sequence = operation_sequence + + # Define the needed for the momentum transfer + self.fetcher = FetchPopulations( + no_slip_bc_instance=self.no_slip_bc_instance, + operation_sequence=self.operation_sequence, + velocity_set=velocity_set, + precision_policy=precision_policy, + compute_backend=compute_backend, + ) # Call the parent constructor super().__init__( @@ -48,6 +166,11 @@ def __init__( compute_backend, ) + if self.compute_backend != ComputeBackend.JAX: + # Allocate the force vector (the total integral value will be computed) + _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) + self.force = wp.zeros((1), dtype=_u_vec) + @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0)) def jax_implementation(self, f_0, f_1, bc_mask, missing_mask): @@ -71,9 +194,7 @@ def jax_implementation(self, f_0, f_1, bc_mask, missing_mask): The force exerted on the solid geometry at each boundary node. """ # Give the input post-collision populations, streaming once and apply the BC the find post-stream values. 
- f_post_collision = f_0 - f_post_stream = self.stream(f_post_collision) - f_post_stream = self.no_slip_bc_instance(f_post_collision, f_post_stream, bc_mask, missing_mask) + f_post_collision, f_post_stream = self.fetcher(f_0, f_1, bc_mask, missing_mask) # Compute momentum transfer boundary = bc_mask == self.no_slip_bc_instance.id @@ -90,61 +211,42 @@ def jax_implementation(self, f_0, f_1, bc_mask, missing_mask): return force_net def _construct_warp(self): - # Set local constants TODO: This is a hack and should be fixed with warp update + # Set local constants _c = self.velocity_set.c _opp_indices = self.velocity_set.opp_indices - _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) - _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) # TODO fix vec bool + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) _no_slip_id = self.no_slip_bc_instance.id - # Find velocity index for 0, 0, 0 - for l in range(self.velocity_set.q): - if _c[0, l] == 0 and _c[1, l] == 0 and _c[2, l] == 0: - zero_index = l - _zero_index = wp.int32(zero_index) + # Find velocity index for (0, 0, 0) + lattice_central_index = self.velocity_set.center_index - # Construct the warp kernel - @wp.kernel - def kernel( - f_0: wp.array4d(dtype=Any), - f_1: wp.array4d(dtype=Any), - bc_mask: wp.array4d(dtype=wp.uint8), - missing_mask: wp.array4d(dtype=wp.bool), - force: wp.array(dtype=Any), + @wp.func + def functional( + index: Any, + f_0: Any, + f_1: Any, + bc_mask: Any, + missing_mask: Any, + force: Any, ): - # Get the global index - i, j, k = wp.tid() - index = wp.vec3i(i, j, k) - # Get the boundary id - _boundary_id = bc_mask[0, index[0], index[1], index[2]] + _boundary_id = self.read_field(bc_mask, index, 0) _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): - # TODO fix vec bool - if missing_mask[l, index[0], index[1], index[2]]: - _missing_mask[l] = wp.uint8(1) - else: - 
_missing_mask[l] = wp.uint8(0) + _missing_mask[l] = self.read_field(missing_mask, index, l) # Determin if boundary is an edge by checking if center is missing is_edge = wp.bool(False) if _boundary_id == wp.uint8(_no_slip_id): - if _missing_mask[_zero_index] == wp.uint8(0): + if _missing_mask[lattice_central_index] == wp.uint8(0): is_edge = wp.bool(True) # If the boundary is an edge then add the momentum transfer m = _u_vec() if is_edge: - # Get the distribution function - f_post_collision = _f_vec() - for l in range(self.velocity_set.q): - f_post_collision[l] = f_0[l, index[0], index[1], index[2]] - - # Apply streaming (pull method) - timestep = 0 - f_post_stream = self.stream.warp_functional(f_0, index) - f_post_stream = self.no_slip_bc_instance.warp_functional(index, timestep, _missing_mask, f_0, f_1, f_post_collision, f_post_stream) + # fetch the post-collision and post-streaming populations + f_post_collision, f_post_stream = self.fetcher_functional(index, f_0, f_1, _missing_mask) # Compute the momentum transfer for d in range(self.velocity_set.d): @@ -156,21 +258,105 @@ def kernel( m[d] += phi elif _c[d, _opp_indices[l]] == -1: m[d] -= phi - + # Atomic sum to get the total force vector wp.atomic_add(force, 0, m) - return None, kernel + # Construct the warp kernel + @wp.kernel + def kernel( + f_0: wp.array4d(dtype=Any), + f_1: wp.array4d(dtype=Any), + bc_mask: wp.array4d(dtype=wp.uint8), + missing_mask: wp.array4d(dtype=wp.uint8), + force: wp.array(dtype=Any), + ): + # Get the global index + i, j, k = wp.tid() + index = wp.vec3i(i, j, k) + + # Call the functional to compute the force + functional( + index, + f_0, + f_1, + bc_mask, + missing_mask, + force, + ) + + return functional, kernel @Operator.register_backend(ComputeBackend.WARP) def warp_implementation(self, f_0, f_1, bc_mask, missing_mask): - # Allocate the force vector (the total integral value will be computed) - _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) - force = wp.zeros((1), 
dtype=_u_vec) + # Ensure the force is initialized to zero + self.force *= self.compute_dtype(0.0) + + # Define the warp functionals needed for this operation + self.fetcher_functional = self.fetcher.warp_functional # Launch the warp kernel wp.launch( self.warp_kernel, - inputs=[f_0, f_1, bc_mask, missing_mask, force], + inputs=[f_0, f_1, bc_mask, missing_mask, self.force], dim=f_0.shape[1:], ) - return force.numpy()[0] + return self.force.numpy()[0] + + def _construct_neon(self): + import neon + + # Use the warp functional for the NEON backend + functional, _ = self._construct_warp() + + @neon.Container.factory(name="MomentumTransfer") + def container( + f_0: Any, + f_1: Any, + bc_mask: Any, + missing_mask: Any, + force: Any, + ): + def container_launcher(loader: neon.Loader): + loader.set_grid(bc_mask.get_grid()) + bc_mask_pn = loader.get_write_handle(bc_mask) + missing_mask_pn = loader.get_write_handle(missing_mask) + f_0_pn = loader.get_write_handle(f_0) + f_1_pn = loader.get_write_handle(f_1) + + @wp.func + def container_kernel(index: Any): + # apply the functional + functional( + index, + f_0_pn, + f_1_pn, + bc_mask_pn, + missing_mask_pn, + force, + ) + + loader.declare_kernel(container_kernel) + + return container_launcher + + return functional, container + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation( + self, + f_0, + f_1, + bc_mask, + missing_mask, + stream=0, + ): + # Ensure the force is initialized to zero + self.force *= self.compute_dtype(0.0) + + # Define the neon functionals needed for this operation + self.fetcher_functional = self.fetcher.neon_functional + + # Launch the neon container + c = self.neon_container(f_0, f_1, bc_mask, missing_mask, self.force) + c.run(stream, container_runtime=neon.Container.ContainerRuntime.neon) + return self.force.numpy()[0] diff --git a/xlb/operator/force/multires_momentum_transfer.py b/xlb/operator/force/multires_momentum_transfer.py new file mode 100644 index 00000000..f7e04a43 --- 
/dev/null +++ b/xlb/operator/force/multires_momentum_transfer.py @@ -0,0 +1,137 @@ +""" +Multi-resolution momentum-transfer force operator for the Neon backend. +""" + +from typing import Any + +import warp as wp + +from xlb.velocity_set.velocity_set import VelocitySet +from xlb.precision_policy import PrecisionPolicy +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.force import MomentumTransfer +from xlb.mres_perf_optimization_type import MresPerfOptimizationType + + +class MultiresMomentumTransfer(MomentumTransfer): + """Momentum-transfer force computation on a multi-resolution grid. + + Extends :class:`MomentumTransfer` with Neon-specific container code that + iterates over all grid levels. The LBM operation sequence (collide-then- + stream vs. stream-then-collide) is inferred from the performance + optimization type. + + Parameters + ---------- + no_slip_bc_instance : BoundaryCondition + The no-slip BC whose tagged voxels define the force integration + surface. + mres_perf_opt : MresPerfOptimizationType + Multi-resolution performance strategy. 
+ velocity_set : VelocitySet, optional + precision_policy : PrecisionPolicy, optional + compute_backend : ComputeBackend, optional + """ + + def __init__( + self, + no_slip_bc_instance, + mres_perf_opt=MresPerfOptimizationType.NAIVE_COLLIDE_STREAM, + velocity_set: VelocitySet = None, + precision_policy: PrecisionPolicy = None, + compute_backend: ComputeBackend = None, + ): + from xlb.operator.force.momentum_transfer import LBMOperationSequence + + if compute_backend in [ComputeBackend.JAX, ComputeBackend.WARP]: + raise NotImplementedError(f"Operator {self.__class__.__name__} not supported in {compute_backend} backend.") + + # Set the sequence of operations based on the performance optimization type + if mres_perf_opt == MresPerfOptimizationType.NAIVE_COLLIDE_STREAM: + operation_sequence = LBMOperationSequence.COLLIDE_THEN_STREAM + elif mres_perf_opt in ( + MresPerfOptimizationType.FUSION_AT_FINEST, + MresPerfOptimizationType.FUSION_AT_FINEST_SFV, + MresPerfOptimizationType.FUSION_AT_FINEST_SFV_ALL, + ): + operation_sequence = LBMOperationSequence.STREAM_THEN_COLLIDE + else: + raise ValueError(f"Unknown performance optimization type: {mres_perf_opt}") + + # Check if the performance optimization type is compatible with the use of mesh distance + if operation_sequence != LBMOperationSequence.STREAM_THEN_COLLIDE: + assert not no_slip_bc_instance.needs_mesh_distance, ( + "Mesh distance is only supported in the MultiresMomentumTransfer operator when the LBM operation sequence is STREAM_THEN_COLLIDE." + ) + + # Print a warning to the user about the boundary voxels + print( + "WARNING! make sure boundary voxels are all at the same level and not among the transition regions from one level to another. 
" + "Otherwise, the results of force calculation are not correct!\n" + ) + + # Call super + super().__init__(no_slip_bc_instance, operation_sequence, velocity_set, precision_policy, compute_backend) + + def _construct_neon(self): + import neon + + # Use the warp functional for the NEON backend + functional, _ = self._construct_warp() + + @neon.Container.factory(name="MomentumTransfer") + def container( + f_0: Any, + f_1: Any, + bc_mask: Any, + missing_mask: Any, + force: Any, + level: Any, + ): + def container_launcher(loader: neon.Loader): + loader.set_mres_grid(bc_mask.get_grid(), level) + bc_mask_pn = loader.get_mres_write_handle(bc_mask) + missing_mask_pn = loader.get_mres_write_handle(missing_mask) + f_0_pn = loader.get_mres_write_handle(f_0) + f_1_pn = loader.get_mres_write_handle(f_1) + + @wp.func + def container_kernel(index: Any): + # apply the functional + functional( + index, + f_0_pn, + f_1_pn, + bc_mask_pn, + missing_mask_pn, + force, + ) + + loader.declare_kernel(container_kernel) + + return container_launcher + + return functional, container + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation( + self, + f_0, + f_1, + bc_mask, + missing_mask, + stream=0, + ): + # Ensure the force is initialized to zero + self.force *= self.compute_dtype(0.0) + + # Define the neon functionals needed for this operation + self.fetcher_functional = self.fetcher.neon_functional + + grid = bc_mask.get_grid() + for level in range(grid.num_levels): + # Launch the neon container + c = self.neon_container(f_0, f_1, bc_mask, missing_mask, self.force, level) + c.run(stream, container_runtime=neon.Container.ContainerRuntime.neon) + return self.force.numpy()[0] diff --git a/xlb/operator/macroscopic/__init__.py b/xlb/operator/macroscopic/__init__.py index 75dec9ea..75eacee6 100644 --- a/xlb/operator/macroscopic/__init__.py +++ b/xlb/operator/macroscopic/__init__.py @@ -2,3 +2,4 @@ from xlb.operator.macroscopic.second_moment import SecondMoment from 
xlb.operator.macroscopic.zero_moment import ZeroMoment from xlb.operator.macroscopic.first_moment import FirstMoment +from xlb.operator.macroscopic.multires_macroscopic import MultiresMacroscopic diff --git a/xlb/operator/macroscopic/first_moment.py b/xlb/operator/macroscopic/first_moment.py index cb99a9ff..626767fd 100644 --- a/xlb/operator/macroscopic/first_moment.py +++ b/xlb/operator/macroscopic/first_moment.py @@ -23,17 +23,31 @@ def _construct_warp(self): _u_vec = wp.vec(self.velocity_set.d, dtype=self.compute_dtype) @wp.func - def functional( - f: _f_vec, - rho: Any, - ): - u = _u_vec() + def neumaier_sum_component(d: int, f: _f_vec): + total = self.compute_dtype(0.0) + compensation = self.compute_dtype(0.0) for l in range(self.velocity_set.q): - for d in range(self.velocity_set.d): - if _c[d, l] == 1: - u[d] += f[l] - elif _c[d, l] == -1: - u[d] -= f[l] + # Get contribution based on the sign of _c[d, l] + if _c[d, l] == 1: + val = f[l] + elif _c[d, l] == -1: + val = -f[l] + else: + val = self.compute_dtype(0.0) + t = total + val + if wp.abs(total) >= wp.abs(val): + compensation = compensation + ((total - t) + val) + else: + compensation = compensation + ((val - t) + total) + total = t + return total + compensation + + @wp.func + def functional(f: _f_vec, rho: Any): + u = _u_vec() + # Use Neumaier summation for each spatial component + for d in range(self.velocity_set.d): + u[d] = neumaier_sum_component(d, f) u /= rho return u @@ -65,3 +79,12 @@ def warp_implementation(self, f, rho, u): dim=u.shape[1:], ) return u + + def _construct_neon(self): + functional, _ = self._construct_warp() + return functional, None + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, f, rho): + # raise exception as this feature is not implemented yet + raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.") diff --git a/xlb/operator/macroscopic/macroscopic.py b/xlb/operator/macroscopic/macroscopic.py index 
ab1193b0..971081c9 100644 --- a/xlb/operator/macroscopic/macroscopic.py +++ b/xlb/operator/macroscopic/macroscopic.py @@ -20,20 +20,18 @@ def __init__(self, *args, **kwargs): @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0), inline=True) - def jax_implementation(self, f): + def jax_implementation(self, f, rho=None, u=None): rho = self.zero_moment(f) u = self.first_moment(f, rho) return rho, u def _construct_warp(self): - zero_moment_func = self.zero_moment.warp_functional - first_moment_func = self.first_moment.warp_functional _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) @wp.func def functional(f: _f_vec): - rho = zero_moment_func(f) - u = first_moment_func(f, rho) + rho = self.zero_moment.warp_functional(f) + u = self.first_moment.warp_functional(f, rho) return rho, u @wp.kernel @@ -64,3 +62,52 @@ def warp_implementation(self, f, rho, u): dim=rho.shape[1:], ) return rho, u + + def _construct_neon(self): + import neon + + # Redefine the zero and first moment operators for the neon backend + # This is because the neon backend relies on the warp functionals for its operations. 
+ self.zero_moment = ZeroMoment(compute_backend=ComputeBackend.WARP) + self.first_moment = FirstMoment(compute_backend=ComputeBackend.WARP) + functional, _ = self._construct_warp() + + # Set local vectors + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + + @neon.Container.factory("macroscopic") + def container( + f_field: Any, + rho_field: Any, + u_fild: Any, + ): + _d = self.velocity_set.d + + def macroscopic_ll(loader: neon.Loader): + loader.set_grid(f_field.get_grid()) + + rho = loader.get_read_handle(rho_field) + u = loader.get_read_handle(u_fild) + f = loader.get_read_handle(f_field) + + @wp.func + def macroscopic_cl(gIdx: typing.Any): + _f = _f_vec() + for l in range(self.velocity_set.q): + _f[l] = self.compute_dtype(wp.neon_read(f, gIdx, l)) + _rho, _u = functional(_f) + wp.neon_write(rho, gIdx, 0, self.store_dtype(_rho)) + for d in range(_d): + wp.neon_write(u, gIdx, d, self.store_dtype(_u[d])) + + loader.declare_kernel(macroscopic_cl) + + return macroscopic_ll + + return functional, container + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, f, rho, u): + c = self.neon_container(f, rho, u) + c.run(0) + return rho, u diff --git a/xlb/operator/macroscopic/multires_macroscopic.py b/xlb/operator/macroscopic/multires_macroscopic.py new file mode 100644 index 00000000..b7df10c7 --- /dev/null +++ b/xlb/operator/macroscopic/multires_macroscopic.py @@ -0,0 +1,89 @@ +""" +Multi-resolution macroscopic moment computation for the Neon backend. +""" + +from functools import partial +import jax.numpy as jnp +from jax import jit +import warp as wp +from typing import Any + +from xlb.compute_backend import ComputeBackend +from xlb.operator.operator import Operator +from xlb.operator.macroscopic import Macroscopic, ZeroMoment, FirstMoment +from xlb.cell_type import BC_SOLID + + +class MultiresMacroscopic(Macroscopic): + """Compute density and velocity on a multi-resolution grid (Neon only). 
+ + Iterates over all grid levels, computing zero-th and first moments of + the distribution function. Solid voxels and voxels that have child + refinement (halo cells) are set to zero. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.compute_backend in [ComputeBackend.JAX, ComputeBackend.WARP]: + raise NotImplementedError(f"Operator {self.__class__.__name__} not supported in {self.compute_backend} backend.") + + def _construct_neon(self): + import neon + + # Redefine the zero and first moment operators for the neon backend + # This is because the neon backend relies on the warp functionals for its operations. + self.zero_moment = ZeroMoment(compute_backend=ComputeBackend.WARP) + self.first_moment = FirstMoment(compute_backend=ComputeBackend.WARP) + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + functional, _ = self._construct_warp() + + @neon.Container.factory("macroscopic") + def container( + level: int, + f_field: Any, + bc_mask: Any, + rho_field: Any, + u_fild: Any, + ): + _d = self.velocity_set.d + + def macroscopic_ll(loader: neon.Loader): + loader.set_mres_grid(f_field.get_grid(), level) + + rho = loader.get_mres_write_handle(rho_field) + u = loader.get_mres_write_handle(u_fild) + f = loader.get_mres_read_handle(f_field) + bc_mask_pn = loader.get_mres_read_handle(bc_mask) + + @wp.func + def macroscopic_cl(gIdx: typing.Any): + _f = _f_vec() + _boundary_id = wp.neon_read(bc_mask_pn, gIdx, 0) + + for l in range(self.velocity_set.q): + _f[l] = self.compute_dtype(wp.neon_read(f, gIdx, l)) + + _rho, _u = functional(_f) + + if _boundary_id == wp.uint8(BC_SOLID) or wp.neon_has_child(f, gIdx): + _rho = self.compute_dtype(0.0) + for d in range(_d): + _u[d] = self.compute_dtype(0.0) + + wp.neon_write(rho, gIdx, 0, self.store_dtype(_rho)) + for d in range(_d): + wp.neon_write(u, gIdx, d, self.store_dtype(_u[d])) + + loader.declare_kernel(macroscopic_cl) + + return macroscopic_ll + + return functional, 
container + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, f, bc_mask, rho, u, streamId=0): + grid = f.get_grid() + for level in range(grid.num_levels): + c = self.neon_container(level, f, bc_mask, rho, u) + c.run(streamId) + return rho, u diff --git a/xlb/operator/macroscopic/second_moment.py b/xlb/operator/macroscopic/second_moment.py index 6c7e70ea..1a0a0f07 100644 --- a/xlb/operator/macroscopic/second_moment.py +++ b/xlb/operator/macroscopic/second_moment.py @@ -104,3 +104,12 @@ def warp_implementation(self, f, pi): # Launch the warp kernel wp.launch(self.warp_kernel, inputs=[f, pi], dim=pi.shape[1:]) return pi + + def _construct_neon(self): + functional, _ = self._construct_warp() + return functional, None + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, f, rho): + # raise exception as this feature is not implemented yet + raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.") diff --git a/xlb/operator/macroscopic/zero_moment.py b/xlb/operator/macroscopic/zero_moment.py index 8abb4de7..f536f8d7 100644 --- a/xlb/operator/macroscopic/zero_moment.py +++ b/xlb/operator/macroscopic/zero_moment.py @@ -20,11 +20,23 @@ def _construct_warp(self): _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) @wp.func - def functional(f: _f_vec): - rho = self.compute_dtype(0.0) + def neumaier_sum(f: _f_vec): + total = self.compute_dtype(0.0) + compensation = self.compute_dtype(0.0) for l in range(self.velocity_set.q): - rho += f[l] - return rho + x = f[l] + t = total + x + # Using wp.abs to compute absolute value + if wp.abs(total) >= wp.abs(x): + compensation = compensation + ((total - t) + x) + else: + compensation = compensation + ((x - t) + total) + total = t + return total + compensation + + @wp.func + def functional(f: _f_vec): + return neumaier_sum(f) @wp.kernel def kernel( @@ -47,3 +59,12 @@ def kernel( def warp_implementation(self, f, rho): 
wp.launch(self.warp_kernel, inputs=[f, rho], dim=rho.shape[1:]) return rho + + def _construct_neon(self): + functional, _ = self._construct_warp() + return functional, None + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, f, rho): + # raise exception as this feature is not implemented yet + raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.") diff --git a/xlb/operator/operator.py b/xlb/operator/operator.py index fcbd07b0..4405708c 100644 --- a/xlb/operator/operator.py +++ b/xlb/operator/operator.py @@ -1,6 +1,22 @@ +""" +Base operator module for XLB. + +Every LBM operator (collision, streaming, equilibrium, boundary condition, +masker, stepper, etc.) inherits from :class:`Operator`. The class provides: + +* **Backend dispatch** — ``__call__`` automatically selects the registered + implementation for the active compute backend. +* **Precision management** — ``compute_dtype`` and ``store_dtype`` properties + return the correct type for the active backend and precision policy. +* **Kernel construction hooks** — ``_construct_warp()`` / ``_construct_neon()`` + are called at init time to compile backend-specific kernels and functionals. +""" + import inspect import traceback import jax +import warp as wp +from typing import Any from xlb.compute_backend import ComputeBackend from xlb import DefaultConfig @@ -17,6 +33,17 @@ class Operator: _backends = {} def __init__(self, velocity_set=None, precision_policy=None, compute_backend=None): + """Initialize the operator. + + Parameters + ---------- + velocity_set : VelocitySet, optional + Lattice velocity set. Defaults to ``DefaultConfig.velocity_set``. + precision_policy : PrecisionPolicy, optional + Precision policy. Defaults to ``DefaultConfig.default_precision_policy``. + compute_backend : ComputeBackend, optional + Compute backend. Defaults to ``DefaultConfig.default_backend``. 
+ """ # Set the default values from the global config self.velocity_set = velocity_set or DefaultConfig.velocity_set self.precision_policy = precision_policy or DefaultConfig.default_precision_policy @@ -26,10 +53,18 @@ def __init__(self, velocity_set=None, precision_policy=None, compute_backend=Non if self.compute_backend not in ComputeBackend: raise ValueError(f"Compute_backend {compute_backend} is not supported") + # Construct read/write functions for the compute backend + if self.compute_backend in [ComputeBackend.WARP, ComputeBackend.NEON]: + self.read_field, self.write_field = self._construct_read_write_functions() + self.read_field_neighbor = self._construct_read_field_neighbor() + # Construct the kernel based compute_backend functions TODO: Maybe move this to the register or something if self.compute_backend == ComputeBackend.WARP: self.warp_functional, self.warp_kernel = self._construct_warp() + if self.compute_backend == ComputeBackend.NEON: + self.neon_functional, self.neon_container = self._construct_neon() + # Updating JAX config in case fp64 is requested if self.compute_backend == ComputeBackend.JAX and ( precision_policy == PrecisionPolicy.FP64FP64 or precision_policy == PrecisionPolicy.FP64FP32 @@ -52,6 +87,20 @@ def decorator(func): return decorator def __call__(self, *args, callback=None, **kwargs): + """Dispatch to the registered backend implementation. + + Iterates over all registered implementations for this operator class + and the active backend, attempts to bind the provided arguments, and + executes the first matching signature. An optional *callback* is + invoked with the result after successful execution. + + Raises + ------ + NotImplementedError + If no implementation is registered for the active backend. + Exception + If all candidate implementations raise errors. 
+ """ method_candidates = [ (key, method) for key, method in self._backends.items() if key[0] == self.__class__.__name__ and key[1] == self.compute_backend ] @@ -80,7 +129,7 @@ def __call__(self, *args, callback=None, **kwargs): error = e traceback_str = traceback.format_exc() continue # This skips to the next candidate if binding fails - + method_candidates = [(key, method) for key, method in self._backends.items() if key[1] == self.compute_backend] raise Exception(f"Error captured for backend with key {key} for operator {self.__class__.__name__}: {error}\n {traceback_str}") @property @@ -123,6 +172,8 @@ def compute_dtype(self): return self.precision_policy.compute_precision.jax_dtype elif self.compute_backend == ComputeBackend.WARP: return self.precision_policy.compute_precision.wp_dtype + elif self.compute_backend == ComputeBackend.NEON: + return self.precision_policy.compute_precision.wp_dtype @property def store_dtype(self): @@ -133,6 +184,20 @@ def store_dtype(self): return self.precision_policy.store_precision.jax_dtype elif self.compute_backend == ComputeBackend.WARP: return self.precision_policy.store_precision.wp_dtype + elif self.compute_backend == ComputeBackend.NEON: + return self.precision_policy.store_precision.wp_dtype + + def get_precision_policy(self): + """ + Returns the precision policy + """ + return self.precision_policy + + def get_grid(self): + """ + Returns the grid object + """ + return self.grid def _construct_warp(self): """ @@ -142,3 +207,110 @@ def _construct_warp(self): Leave it for now, as it is not clear how the warp compute backend will evolve """ return None, None + + def _construct_neon(self): + """ + Construct the Neon functional and Neon container of the operator + TODO: Maybe a better way to do this? + Maybe add this to the backend decorator? 
+ Leave it for now, as it is not clear how the neon backend will evolve + """ + return None, None + + def _construct_read_write_functions(self): + """Build backend-specific ``read_field`` / ``write_field`` helpers. + + For the Warp backend these are direct 4-D array accesses. For the + Neon backend they wrap ``wp.neon_read`` / ``wp.neon_write``. + + Returns + ------- + tuple of wp.func + ``(read_field, write_field)`` + """ + if self.compute_backend == ComputeBackend.WARP: + + @wp.func + def read_field( + field: Any, + index: Any, + direction: Any, + ): + # This function reads a field value at a given index and direction. + return field[direction, index[0], index[1], index[2]] + + @wp.func + def write_field( + field: Any, + index: Any, + direction: Any, + value: Any, + ): + # This function writes a value to a field at a given index and direction. + field[direction, index[0], index[1], index[2]] = value + + elif self.compute_backend == ComputeBackend.NEON: + import neon + + @wp.func + def read_field( + field: Any, + index: Any, + direction: Any, + ): + # This function reads a field value at a given index and direction. + return wp.neon_read(field, index, direction) + + @wp.func + def write_field( + field: Any, + index: Any, + direction: Any, + value: Any, + ): + # This function writes a value to a field at a given index and direction. + wp.neon_write(field, index, direction, value) + + else: + raise ValueError(f"Unsupported compute backend: {self.compute_backend}") + + return read_field, write_field + + def _construct_read_field_neighbor(self): + """ + Construct a function to read a field value at a neighboring index along a given direction. + """ + + if self.compute_backend == ComputeBackend.WARP: + + @wp.func + def read_field_neighbor( + field: Any, + index: Any, + offset: Any, + direction: Any, + ): + # This function reads a field value at a given neighboring index and direction. 
+ neighbor = index + offset + return field[direction, neighbor[0], neighbor[1], neighbor[2]] + + elif self.compute_backend == ComputeBackend.NEON: + import neon + # from neon.multires.mPartition import neon_get_type + + @wp.func + def read_field_neighbor( + field: Any, + index: Any, + offset: Any, + direction: Any, + ): + # This function reads a field value at a given neighboring index and direction. + unused_is_valid = wp.bool(False) + # dtype = neon_get_type(field) # This is a placeholder to ensure the dtype is set correctly + return wp.neon_read_ngh(field, index, offset, direction, wp.uint8(0.0), unused_is_valid) + + else: + raise ValueError(f"Unsupported compute backend: {self.compute_backend}") + + return read_field_neighbor diff --git a/xlb/operator/stepper/__init__.py b/xlb/operator/stepper/__init__.py index 1eab1668..87c1274f 100644 --- a/xlb/operator/stepper/__init__.py +++ b/xlb/operator/stepper/__init__.py @@ -1,3 +1,4 @@ from xlb.operator.stepper.stepper import Stepper from xlb.operator.stepper.nse_stepper import IncompressibleNavierStokesStepper +from xlb.operator.stepper.nse_multires_stepper import MultiresIncompressibleNavierStokesStepper from xlb.operator.stepper.ibm_stepper import IBMStepper diff --git a/xlb/operator/stepper/nse_multires_stepper.py b/xlb/operator/stepper/nse_multires_stepper.py new file mode 100644 index 00000000..629f1b07 --- /dev/null +++ b/xlb/operator/stepper/nse_multires_stepper.py @@ -0,0 +1,1196 @@ +""" +Multi-Resolution Navier-Stokes Stepper for the NEON Backend + +This module implements the multi-resolution LBM stepper using Warp kernels on the +Neon multi-GPU runtime. It uses several programming patterns specific to Warp's +compile-time code generation model. + +Compile-Time Specialization Pattern +----------------------------------- +Warp's @wp.func decorator traces Python code at kernel compilation time, not runtime. 
+This means runtime boolean parameters cause Warp to emit branching code for both paths, +increasing register pressure even when only one path is ever taken. + +To generate optimized, branch-free kernels, we use a **factory pattern** that captures +boolean configuration at function-definition time: + + def make_specialized_func(do_feature: bool): + @wp.func + def impl(...): + if wp.static(do_feature): # Evaluated at compile time + # This code is only emitted when do_feature=True + ... + else: + # This code is only emitted when do_feature=False + ... + return impl + + # Generate specialized variants + func_with_feature = make_specialized_func(do_feature=True) + func_without_feature = make_specialized_func(do_feature=False) + +The `wp.static()` call evaluates its argument during Warp's tracing phase. Since +`do_feature` is a Python bool captured in the closure, Warp sees a constant and +eliminates the dead branch entirely. + +This pattern is used for: +- `apply_bc_post_streaming` / `apply_bc_post_collision`: Specialized BC application + for streaming vs collision implementation steps +- `collide_bc_accum` / `collide_simple`: Collision pipeline variants with/without + BC application and multi-resolution accumulation + +Closure Capture for Self Attributes +----------------------------------- +Warp cannot resolve `self.X` in plain assignments inside @wp.func bodies (e.g., +`_c = self.velocity_set.c` fails with "Invalid external reference type"). 
However, +it can resolve `self.X` in: +- Function call contexts: `self.stream.neon_functional(...)` +- Range arguments: `range(self.velocity_set.q)` +- Type casts: `self.compute_dtype(0)` + +For other uses, we pre-capture attributes at the Python level before defining the +@wp.func, making them available as simple closure variables: + + _c = self.velocity_set.c # Captured in Python scope + + @wp.func + def my_kernel(...): + # Use _c directly — Warp sees it as a closure variable + direction = wp.neon_ngh_idx(wp.int8(_c[0, l]), ...) + +Cell Type Constants +------------------- +Cell types are defined in `xlb.cell_type`: +- BC_SFV (254): Simple Fluid Voxel — no BC, no explosion/coalescence +- BC_SOLID (255): Solid obstacle voxel +- BC_NONE (0): Regular fluid voxel with potential BCs or multi-res interactions +""" + +import nvtx +import warp as wp +from typing import Any + +from xlb import DefaultConfig +from xlb.compute_backend import ComputeBackend +from xlb.precision_policy import Precision +from xlb.operator import Operator +from xlb.operator.stream import Stream +from xlb.operator.collision import BGK, KBC +from xlb.operator.equilibrium import MultiresQuadraticEquilibrium +from xlb.operator.macroscopic import MultiresMacroscopic +from xlb.operator.stepper import Stepper +from xlb.operator.boundary_condition.boundary_condition import ImplementationStep +from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry +from xlb.operator.collision import ForcedCollision +from xlb.helper import check_bc_overlaps +from xlb.operator.boundary_masker import ( + MeshVoxelizationMethod, + MultiresMeshMaskerAABB, + MultiresMeshMaskerAABBClose, + MultiresIndicesBoundaryMasker, + MultiresMeshMaskerRay, +) +from xlb.operator.boundary_condition.helper_functions_bc import MultiresEncodeAuxiliaryData +from xlb.cell_type import BC_SFV, BC_SOLID + +""" +SFV = Simple Fluid Voxel: a fluid voxel that is not a BC nor is involved in explosion or 
coalescence +CFV = Complex Fluid Voxel: a fluid voxel that is not a SFV +""" + + +class MultiresIncompressibleNavierStokesStepper(Stepper): + """Multi-resolution incompressible Navier-Stokes stepper for the Neon backend. + + Implements the full LBM step (stream, collide, boundary conditions) across + a hierarchy of grid levels using Neon containers. Each container is a + compile-time specialized Warp kernel wrapped in a Neon execution-graph + node. + + The stepper supports several performance optimization strategies (see + :class:`MresPerfOptimizationType`): + + * **NAIVE_COLLIDE_STREAM** — separate collide and stream containers at + every level. + * **FUSION_AT_FINEST** — fused stream+collide at the finest level. + * **FUSION_AT_FINEST_SFV** — additionally splits SFV / CFV voxels at + the finest level for reduced branching. + * **FUSION_AT_FINEST_SFV_ALL** — SFV / CFV splitting at all levels. + + Parameters + ---------- + grid : NeonMultiresGrid + The multi-resolution grid. + boundary_conditions : list of BoundaryCondition + Boundary conditions to apply. + collision_type : str + Collision operator type: ``"BGK"`` or ``"KBC"``. + forcing_scheme : str + Forcing scheme name (only used when *force_vector* is given). + force_vector : array-like, optional + External body force vector. 
+ """ + + def __init__( + self, + grid, + boundary_conditions=[], + collision_type="BGK", + forcing_scheme="exact_difference", + force_vector=None, + ): + super().__init__(grid, boundary_conditions) + + # Construct the collision operator + if collision_type == "BGK": + self.collision = BGK(self.velocity_set, self.precision_policy, self.compute_backend) + elif collision_type == "KBC": + self.collision = KBC(self.velocity_set, self.precision_policy, self.compute_backend) + + if force_vector is not None: + self.collision = ForcedCollision(collision_operator=self.collision, forcing_scheme=forcing_scheme, force_vector=force_vector) + + # Construct the operators + self.stream = Stream(self.velocity_set, self.precision_policy, self.compute_backend) + self.equilibrium = MultiresQuadraticEquilibrium(self.velocity_set, self.precision_policy, self.compute_backend) + self.macroscopic = MultiresMacroscopic(self.velocity_set, self.precision_policy, self.compute_backend) + + def prepare_fields(self, rho, u, initializer=None): + import neon + + """Prepare the fields required for the stepper. + + Args: + initializer: Optional operator to initialize the distribution functions. + If provided, it should be a callable that takes (grid, velocity_set, + precision_policy, compute_backend) as arguments and returns initialized f_0. + If None, default equilibrium initialization is used with rho=1 and u=0. 
+ + Returns: + Tuple of (f_0, f_1, bc_mask, missing_mask): + - f_0: Initial distribution functions + - f_1: Copy of f_0 for double-buffering + - bc_mask: Boundary condition mask indicating which BC applies to each node + - missing_mask: Mask indicating which populations are missing at boundary nodes + """ + + f_0 = self.grid.create_field( + cardinality=self.velocity_set.q, dtype=self.precision_policy.store_precision, neon_memory_type=neon.MemoryType.device() + ) + + f_1 = self.grid.create_field( + cardinality=self.velocity_set.q, dtype=self.precision_policy.store_precision, neon_memory_type=neon.MemoryType.device() + ) + + missing_mask = self.grid.create_field(cardinality=self.velocity_set.q, dtype=Precision.UINT8) + bc_mask = self.grid.create_field(cardinality=1, dtype=Precision.UINT8) + + for level in range(self.grid.count_levels): + f_1.copy_from_run(level, f_0, 0) + + # Process boundary conditions and update masks + f_1, bc_mask, missing_mask = self._process_boundary_conditions(self.boundary_conditions, f_1, bc_mask, missing_mask) + # Initialize auxiliary data if needed + f_1 = self._initialize_auxiliary_data(self.boundary_conditions, f_1, bc_mask, missing_mask) + + # Initialize distribution functions if initializer is provided + if initializer is not None: + # Refer to xlb.helper.initializers for available initializers + f_0 = initializer(bc_mask, f_0) + else: + from xlb.helper.initializers import initialize_multires_eq + + f_0 = initialize_multires_eq(f_0, self.grid, self.velocity_set, self.precision_policy, self.compute_backend, rho=rho, u=u) + + return f_0, f_1, bc_mask, missing_mask + + def prepare_coalescence_count(self, coalescence_factor, bc_mask): + """Precompute coalescence weighting factors for multi-resolution streaming. 
+ + For each non-halo voxel at every level, this method accumulates + the number of finer neighbours that contribute populations via + coalescence (child-to-parent transfer), then inverts the count + so that the streaming kernel can apply the correct averaging weight. + + Parameters + ---------- + coalescence_factor : field + Multi-resolution field to store the per-direction coalescence + weights (modified in-place). + bc_mask : field + Boundary-condition mask used to skip solid voxels. + """ + # neon is not among this module's top-level imports, so bring it into scope here + # (same local-import approach as prepare_fields) + import neon + + lattice_central_index = self.velocity_set.center_index + num_levels = coalescence_factor.get_grid().num_levels + + @neon.Container.factory(name="sum_kernel_by_level") + def sum_kernel_by_level(level): + def ll_coalescence_count(loader: neon.Loader): + loader.set_mres_grid(coalescence_factor.get_grid(), level) + + coalescence_factor_pn = loader.get_mres_read_handle(coalescence_factor) + bc_mask_pn = loader.get_mres_read_handle(bc_mask) + + _c = self.velocity_set.c + + @wp.func + def cl_collide_coarse(index: Any): + _boundary_id = wp.neon_read(bc_mask_pn, index, 0) + if _boundary_id == wp.uint8(BC_SOLID): + return + if not wp.neon_has_child(coalescence_factor_pn, index): + for l in range(self.velocity_set.q): + if level < num_levels - 1: + push_direction = wp.neon_ngh_idx(wp.int8(_c[0, l]), wp.int8(_c[1, l]), wp.int8(_c[2, l])) + val = self.store_dtype(1) + wp.neon_mres_lbm_store_op(coalescence_factor_pn, index, l, push_direction, val) + + loader.declare_kernel(cl_collide_coarse) + + return ll_coalescence_count + + for level in range(num_levels): + sum_kernel = sum_kernel_by_level(level) + sum_kernel.run(0) + + # Give the second container its own name; it previously reused "sum_kernel_by_level" + @neon.Container.factory(name="invert_count") + def invert_count(level): + def loading(loader: neon.Loader): + loader.set_mres_grid(coalescence_factor.get_grid(), level) + + coalescence_factor_pn = loader.get_mres_read_handle(coalescence_factor) + bc_mask_pn = loader.get_mres_read_handle(bc_mask) + + _c = self.velocity_set.c
+ + @wp.func + def compute(index: Any): + _boundary_id = wp.neon_read(bc_mask_pn, index, 0) + if _boundary_id == wp.uint8(BC_SOLID): + return + + if wp.neon_has_child(coalescence_factor_pn, index): + # we are a halo cell so we just exit + return + + for l in range(self.velocity_set.q): + if l == lattice_central_index: + continue + + pull_direction = wp.neon_ngh_idx(wp.int8(-_c[0, l]), wp.int8(-_c[1, l]), wp.int8(-_c[2, l])) + + has_ngh_at_same_level = wp.bool(False) + coalescence_factor = self.compute_dtype( + wp.neon_read_ngh(coalescence_factor_pn, index, pull_direction, l, self.store_dtype(0), has_ngh_at_same_level) + ) + + if not wp.neon_has_finer_ngh(coalescence_factor_pn, index, pull_direction): + pass + else: + # Finer neighbour exists in the pull direction (opposite of l). + # Read from the halo sitting on top of that finer neighbour. + if has_ngh_at_same_level: + # Finer ngh in pull direction: YES + # Same-level ngh: YES + # Compute coalescence factor + if coalescence_factor > self.compute_dtype(0): + coalescence_factor = self.compute_dtype(1) / (self.compute_dtype(2) * coalescence_factor) + wp.neon_write(coalescence_factor_pn, index, l, self.store_dtype(coalescence_factor)) + + loader.declare_kernel(compute) + + return loading + + for level in range(num_levels): + invert_kernel = invert_count(level) + invert_kernel.run(0) + + @classmethod + def _process_boundary_conditions(cls, boundary_conditions, f_1, bc_mask, missing_mask): + """Process boundary conditions and update boundary masks.""" + + # Check for boundary condition overlaps + # TODO! 
check_bc_overlaps(boundary_conditions, DefaultConfig.velocity_set.d, DefaultConfig.default_backend) + + # Create boundary maskers + indices_masker = MultiresIndicesBoundaryMasker( + velocity_set=DefaultConfig.velocity_set, + precision_policy=DefaultConfig.default_precision_policy, + compute_backend=DefaultConfig.default_backend, + ) + + # Split boundary conditions by type + bc_with_vertices = [bc for bc in boundary_conditions if bc.mesh_vertices is not None] + bc_with_indices = [bc for bc in boundary_conditions if bc.indices is not None] + + # Process indices-based boundary conditions + if bc_with_indices: + bc_mask, missing_mask = indices_masker(bc_with_indices, bc_mask, missing_mask) + + # Process mesh-based boundary conditions for 3D + if DefaultConfig.velocity_set.d == 3 and bc_with_vertices: + for bc in bc_with_vertices: + # Compare method ids by value (==); identity (is) only holds for objects that happen to be shared + if bc.voxelization_method.id == MeshVoxelizationMethod("AABB").id: + mesh_masker = MultiresMeshMaskerAABB( + velocity_set=DefaultConfig.velocity_set, + precision_policy=DefaultConfig.default_precision_policy, + compute_backend=DefaultConfig.default_backend, + ) + elif bc.voxelization_method.id == MeshVoxelizationMethod("RAY").id: + mesh_masker = MultiresMeshMaskerRay( + velocity_set=DefaultConfig.velocity_set, + precision_policy=DefaultConfig.default_precision_policy, + compute_backend=DefaultConfig.default_backend, + ) + elif bc.voxelization_method.id == MeshVoxelizationMethod("AABB_CLOSE").id: + mesh_masker = MultiresMeshMaskerAABBClose( + velocity_set=DefaultConfig.velocity_set, + precision_policy=DefaultConfig.default_precision_policy, + compute_backend=DefaultConfig.default_backend, + close_voxels=bc.voxelization_method.options.get("close_voxels"), + ) + else: + raise ValueError(f"Unsupported voxelization method for multi-res: {bc.voxelization_method}") + # Apply the mesh masker to the boundary condition + f_1, bc_mask, missing_mask = mesh_masker(bc, f_1, bc_mask, missing_mask) + + return f_1, bc_mask, missing_mask + + @staticmethod + 
def _initialize_auxiliary_data(boundary_conditions, f_1, bc_mask, missing_mask): + """Initialize auxiliary data for boundary conditions that require it.""" + for bc in boundary_conditions: + if bc.needs_aux_init and not bc.is_initialized_with_aux_data: + # Create the encoder operator for storing the auxiliary data + encode_auxiliary_data = MultiresEncodeAuxiliaryData( + bc.id, + bc.num_of_aux_data, + bc.profile, + velocity_set=bc.velocity_set, + precision_policy=bc.precision_policy, + compute_backend=bc.compute_backend, + ) + + # Encode the auxiliary data in f_1 + f_1 = encode_auxiliary_data(f_1, bc_mask, missing_mask, stream=0) + bc.is_initialized_with_aux_data = True + return f_1 + + def _construct_neon(self): + # Pre-capture self attributes that Warp cannot resolve inside @wp.func bodies. + # Warp rejects `self` as an "Invalid external reference type" when it appears + # in a plain assignment (e.g. `_c = self.velocity_set.c`). Capturing here + # makes these values available as simple closure variables. 
+ lattice_central_index = self.velocity_set.center_index + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) + _opp_indices = self.velocity_set.opp_indices + _c = self.velocity_set.c + + # Read the list of bc_to_id created upon instantiation + bc_to_id = boundary_condition_registry.bc_to_id + + # Gather IDs of ExtrapolationOutflowBC boundary conditions + extrapolation_outflow_bc_ids = [] + for bc_name, bc_id in bc_to_id.items(): + if bc_name.startswith("ExtrapolationOutflowBC"): + extrapolation_outflow_bc_ids.append(bc_id) + + # Factory for apply_bc: generates compile-time specialized variants + def make_apply_bc(is_post_streaming: bool): + @wp.func + def apply_bc_impl( + index: Any, + timestep: Any, + _boundary_id: Any, + _missing_mask: Any, + f_0: Any, + f_1: Any, + f_pre: Any, + f_post: Any, + ): + f_result = f_post + + for i in range(wp.static(len(self.boundary_conditions))): + if wp.static(is_post_streaming): + if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.STREAMING): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + f_result = wp.static(self.boundary_conditions[i].neon_functional)( + index, timestep, _missing_mask, f_0, f_1, f_pre, f_post + ) + else: + if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.COLLISION): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + f_result = wp.static(self.boundary_conditions[i].neon_functional)( + index, timestep, _missing_mask, f_0, f_1, f_pre, f_post + ) + if wp.static(self.boundary_conditions[i].id in extrapolation_outflow_bc_ids): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + f_result = wp.static(self.boundary_conditions[i].assemble_auxiliary_data)( + index, timestep, _missing_mask, f_0, f_1, f_pre, f_post + ) + return f_result + + return apply_bc_impl + + # Compile-time specialized BC application variants + 
apply_bc_post_streaming = make_apply_bc(is_post_streaming=True) + apply_bc_post_collision = make_apply_bc(is_post_streaming=False) + + @wp.func + def neon_get_thread_data( + f0_pn: Any, + missing_mask_pn: Any, + index: Any, + ): + # Read thread data for populations + _f0_thread = _f_vec() + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # q-sized vector of pre-streaming populations + _f0_thread[l] = self.compute_dtype(wp.neon_read(f0_pn, index, l)) + _missing_mask[l] = wp.neon_read(missing_mask_pn, index, l) + + return _f0_thread, _missing_mask + + @wp.func + def neon_apply_aux_recovery_bc( + index: Any, + _boundary_id: Any, + _missing_mask: Any, + f_0_pn: Any, + f_1_pn: Any, + ): + # Note: + # In XLB, the BC auxiliary data (e.g. prescribed values of pressure or normal velocity) are stored in (i) central index of f_1 and/or + # (ii) missing directions of f_1. Some BCs may or may not need all these available storage space. This function checks whether + # the BC needs recovery of auxiliary data and then recovers the information for the next iteration (due to buffer swapping) by + # writting the values of f_1 into f_0. 
+ + # Unroll the loop over boundary conditions + for i in range(wp.static(len(self.boundary_conditions))): + if wp.static(self.boundary_conditions[i].needs_aux_recovery): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + for l in range(self.velocity_set.q): + # Perform the swapping of data + if l == lattice_central_index: + # (i) Recover the values stored in the central index of f_1 + _f1_thread = wp.neon_read(f_1_pn, index, l) + wp.neon_write(f_0_pn, index, l, self.store_dtype(_f1_thread)) + elif _missing_mask[l] == wp.uint8(1): + # (ii) Recover the values stored in the missing directions of f_1 + _f1_thread = wp.neon_read(f_1_pn, index, _opp_indices[l]) + wp.neon_write(f_0_pn, index, _opp_indices[l], self.store_dtype(_f1_thread)) + + # Factory for neon_collide_pipeline: generates compile-time specialized variants + def make_collide_pipeline(do_bc: bool, do_accumulation: bool): + @wp.func + def collide_pipeline_impl( + index: Any, + timestep: Any, + _boundary_id: Any, + _missing_mask: Any, + f_0_pn: Any, + f_1_pn: Any, + _f_post_stream: Any, + omega: Any, + num_levels: int, + level: int, + accumulation_pn: Any, + ): + _rho, _u = self.macroscopic.neon_functional(_f_post_stream) + _feq = self.equilibrium.neon_functional(_rho, _u) + _f_post_collision = self.collision.neon_functional(_f_post_stream, _feq, omega) + + if wp.static(do_bc): + _f_post_collision = apply_bc_post_collision( + index, timestep, _boundary_id, _missing_mask, f_0_pn, f_1_pn, _f_post_stream, _f_post_collision + ) + neon_apply_aux_recovery_bc(index, _boundary_id, _missing_mask, f_0_pn, f_1_pn) + + if wp.static(do_accumulation): + for l in range(self.velocity_set.q): + push_direction = wp.neon_ngh_idx(wp.int8(_c[0, l]), wp.int8(_c[1, l]), wp.int8(_c[2, l])) + if level < num_levels - 1: + wp.neon_mres_lbm_store_op(accumulation_pn, index, l, push_direction, self.store_dtype(_f_post_collision[l])) + wp.neon_write(f_1_pn, index, l, self.store_dtype(_f_post_collision[l])) + else: + for l 
in range(self.velocity_set.q): + wp.neon_write(f_1_pn, index, l, self.store_dtype(_f_post_collision[l])) + + return _f_post_collision + + return collide_pipeline_impl + + # Compile-time specialized collision pipeline variants + collide_bc_accum = make_collide_pipeline(do_bc=True, do_accumulation=True) + collide_bc_only = make_collide_pipeline(do_bc=True, do_accumulation=False) + collide_simple = make_collide_pipeline(do_bc=False, do_accumulation=False) + + @wp.func + def neon_stream_explode_coalesce( + index: Any, + f_0_pn: Any, + coalescence_factor_pn: Any, + ): + _f_post_stream = self.stream.neon_functional(f_0_pn, index) + + for l in range(self.velocity_set.q): + if l == lattice_central_index: + continue + + pull_direction = wp.neon_ngh_idx(wp.int8(-_c[0, l]), wp.int8(-_c[1, l]), wp.int8(-_c[2, l])) + + has_ngh_at_same_level = wp.bool(False) + accumulated = wp.neon_read_ngh(f_0_pn, index, pull_direction, l, self.store_dtype(0), has_ngh_at_same_level) + + if not wp.neon_has_finer_ngh(f_0_pn, index, pull_direction): + # No finer ngh in the pull direction (opposite of l) + if not has_ngh_at_same_level: + # No same-level ngh — could we have a coarser-level ngh? + if wp.neon_has_parent(f_0_pn, index): + # Halo cell on top of us (parent exists) + has_a_coarser_ngh = wp.bool(False) + exploded_pop = wp.neon_lbm_read_coarser_ngh(f_0_pn, index, pull_direction, l, self.store_dtype(0), has_a_coarser_ngh) + if has_a_coarser_ngh: + # No finer ngh in pull direction, no same-level ngh, + # but a parent (ghost cell) exists with a coarser ngh + # -> Explosion: read the exploded population from the + # coarser level's halo. + _f_post_stream[l] = self.compute_dtype(exploded_pop) + else: + # Finer ngh exists in the pull direction (opposite of l). + # Read from the halo on top of that finer ngh. 
+ if has_ngh_at_same_level: + # Finer ngh in pull direction: YES + # Same-level ngh: YES + # -> Coalescence + coalescence_factor = wp.neon_read(coalescence_factor_pn, index, l) + accumulated = accumulated * coalescence_factor + _f_post_stream[l] = self.compute_dtype(accumulated) + + return _f_post_stream + + @neon.Container.factory(name="collide_coarse") + def collide_coarse(level: int, f_0_fd: Any, f_1_fd: Any, bc_mask_fd: Any, missing_mask_fd: Any, omega: Any, timestep: int): + num_levels = f_0_fd.get_grid().num_levels + + def ll(loader: neon.Loader): + loader.set_mres_grid(bc_mask_fd.get_grid(), level) + if level + 1 < f_0_fd.get_grid().num_levels: + f_0_pn = loader.get_mres_write_handle(f_0_fd, neon.Loader.Operation.stencil_up) + f_1_pn = loader.get_mres_write_handle(f_1_fd, neon.Loader.Operation.stencil_up) + else: + f_0_pn = loader.get_mres_read_handle(f_0_fd) + f_1_pn = loader.get_mres_write_handle(f_1_fd) + bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd) + missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd) + + @wp.func + def device(index: Any): + _boundary_id = wp.neon_read(bc_mask_pn, index, 0) + if _boundary_id == wp.uint8(BC_SOLID): + return + if not wp.neon_has_child(f_0_pn, index): + _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index) + collide_bc_accum( + index, + timestep, + _boundary_id, + _missing_mask, + f_0_pn, + f_1_pn, + _f0_thread, + omega, + num_levels, + level, + f_1_pn, + ) + else: + for l in range(self.velocity_set.q): + wp.neon_write(f_1_pn, index, l, self.store_dtype(0)) + + loader.declare_kernel(device) + + return ll + + @neon.Container.factory(name="SFV_collide_coarse") + def SFV_collide_coarse(level: int, f_0_fd: Any, f_1_fd: Any, bc_mask_fd: Any, missing_mask_fd: Any, omega: Any, timestep: int): + """Collision on SFV voxels only — no BCs, no multi-resolution accumulation.""" + + def ll(loader: neon.Loader): + loader.set_mres_grid(bc_mask_fd.get_grid(), level) + f_0_pn = 
loader.get_mres_read_handle(f_0_fd) + f_1_pn = loader.get_mres_write_handle(f_1_fd) + bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd) + missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd) + + @wp.func + def device(index: Any): + _boundary_id = wp.neon_read(bc_mask_pn, index, 0) + if _boundary_id != wp.uint8(BC_SFV): + return + _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index) + collide_simple( + index, + 0, + _boundary_id, + _missing_mask, + f_0_pn, + f_1_pn, + _f0_thread, + omega, + 0, + level, + f_1_pn, + ) + + loader.declare_kernel(device) + + return ll + + @neon.Container.factory(name="CFV_collide_coarse") + def CFV_collide_coarse(level: int, f_0_fd: Any, f_1_fd: Any, bc_mask_fd: Any, missing_mask_fd: Any, omega: Any, timestep: int): + """Collision on CFV voxels only — skips both solid and SFV.""" + num_levels = f_0_fd.get_grid().num_levels + + def ll(loader: neon.Loader): + loader.set_mres_grid(bc_mask_fd.get_grid(), level) + if level + 1 < f_0_fd.get_grid().num_levels: + f_0_pn = loader.get_mres_write_handle(f_0_fd, neon.Loader.Operation.stencil_up) + f_1_pn = loader.get_mres_write_handle(f_1_fd, neon.Loader.Operation.stencil_up) + else: + f_0_pn = loader.get_mres_read_handle(f_0_fd) + f_1_pn = loader.get_mres_write_handle(f_1_fd) + bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd) + missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd) + + @wp.func + def device(index: Any): + _boundary_id = wp.neon_read(bc_mask_pn, index, 0) + if _boundary_id == wp.uint8(BC_SOLID): + return + if _boundary_id == wp.uint8(BC_SFV): + return + if not wp.neon_has_child(f_0_pn, index): + _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index) + collide_bc_accum( + index, + timestep, + _boundary_id, + _missing_mask, + f_0_pn, + f_1_pn, + _f0_thread, + omega, + num_levels, + level, + f_1_pn, + ) + else: + for l in range(self.velocity_set.q): + wp.neon_write(f_1_pn, index, l, self.store_dtype(0)) + + 
loader.declare_kernel(device) + + return ll + + @neon.Container.factory(name="stream_coarse_step_ABC") + def stream_coarse_step_ABC(level: int, f_0_fd: Any, f_1_fd: Any, bc_mask_fd: Any, missing_mask_fd: Any, omega: Any, timestep: int): + def ll(loader: neon.Loader): + loader.set_mres_grid(bc_mask_fd.get_grid(), level) + f_0_pn = loader.get_mres_read_handle(f_0_fd) + f_1_pn = loader.get_mres_write_handle(f_1_fd) + bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd) + missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd) + coalescence_factor_pn = loader.get_mres_read_handle(omega) + + @wp.func + def device(index: Any): + _boundary_id = wp.neon_read(bc_mask_pn, index, 0) + if _boundary_id == wp.uint8(BC_SOLID): + return + if wp.neon_has_child(f_0_pn, index): + return + + _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index) + _f_post_collision = _f0_thread + _f_post_stream = neon_stream_explode_coalesce(index, f_0_pn, coalescence_factor_pn) + + _f_post_stream = apply_bc_post_streaming( + index, timestep, _boundary_id, _missing_mask, f_0_pn, f_1_pn, _f_post_collision, _f_post_stream + ) + neon_apply_aux_recovery_bc(index, _boundary_id, _missing_mask, f_0_pn, f_1_pn) + + for l in range(self.velocity_set.q): + wp.neon_write(f_1_pn, index, l, self.store_dtype(_f_post_stream[l])) + + loader.declare_kernel(device) + + return ll + + @neon.Container.factory(name="SFV_stream_coarse_step_ABC") + def SFV_stream_coarse_step_ABC(level: int, f_0_fd: Any, f_1_fd: Any, bc_mask_fd: Any, missing_mask_fd: Any, omega: Any, timestep: int): + """Stream on CFV voxels only — skips SFV and solid.""" + + def ll(loader: neon.Loader): + loader.set_mres_grid(bc_mask_fd.get_grid(), level) + f_0_pn = loader.get_mres_read_handle(f_0_fd) + f_1_pn = loader.get_mres_write_handle(f_1_fd) + bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd) + missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd) + coalescence_factor_pn = loader.get_mres_read_handle(omega) 
+ + @wp.func + def device(index: Any): + _boundary_id = wp.neon_read(bc_mask_pn, index, 0) + if _boundary_id == wp.uint8(BC_SFV): + return + if _boundary_id == wp.uint8(BC_SOLID): + return + if wp.neon_has_child(f_0_pn, index): + return + + _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index) + _f_post_collision = _f0_thread + _f_post_stream = neon_stream_explode_coalesce(index, f_0_pn, coalescence_factor_pn) + + _f_post_stream = apply_bc_post_streaming( + index, timestep, _boundary_id, _missing_mask, f_0_pn, f_1_pn, _f_post_collision, _f_post_stream + ) + neon_apply_aux_recovery_bc(index, _boundary_id, _missing_mask, f_0_pn, f_1_pn) + + for l in range(self.velocity_set.q): + wp.neon_write(f_1_pn, index, l, self.store_dtype(_f_post_stream[l])) + + loader.declare_kernel(device) + + return ll + + @neon.Container.factory(name="SFV_reset_bc_mask") + def SFV_reset_bc_mask( + level: int, + f_0_fd: Any, + f_1_fd: Any, + bc_mask_fd: Any, + missing_mask_fd: Any, + ): + """ + Setting the BC type to BC_SFV + """ + + def ll_stream_coarse(loader: neon.Loader): + loader.set_mres_grid(bc_mask_fd.get_grid(), level) + + f_0_pn = loader.get_mres_read_handle(f_0_fd) + + bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd) + missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd) + + _c = self.velocity_set.c + + @wp.func + def cl_stream_coarse(index: Any): + _boundary_id = wp.neon_read(bc_mask_pn, index, 0) + if _boundary_id == wp.uint8(BC_SOLID): + return + if _boundary_id != 0: + return + + if wp.neon_has_child(f_0_pn, index): + # we are a halo cell so we just exit + return + + # do stream normally + _missing_mask = _missing_mask_vec() + _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index) + _f_post_collision = _f0_thread + _f_post_stream = self.stream.neon_functional(f_0_pn, index) + + for l in range(self.velocity_set.q): + if l == lattice_central_index: + continue + + pull_direction = wp.neon_ngh_idx(wp.int8(-_c[0, 
l]), wp.int8(-_c[1, l]), wp.int8(-_c[2, l])) + + has_ngh_at_same_level = wp.bool(False) + wp.neon_read_ngh(f_0_pn, index, pull_direction, l, self.store_dtype(0), has_ngh_at_same_level) + + if not wp.neon_has_finer_ngh(f_0_pn, index, pull_direction): + if not has_ngh_at_same_level: + if wp.neon_has_parent(f_0_pn, index): + has_a_coarser_ngh = wp.bool(False) + wp.neon_lbm_read_coarser_ngh(f_0_pn, index, pull_direction, l, self.store_dtype(0), has_a_coarser_ngh) + if has_a_coarser_ngh: + # Explosion: not an SFV + return + else: + if has_ngh_at_same_level: + # Coalescence: not an SFV + return + + # Voxel is a pure fluid cell with no multi-resolution interactions — mark as SFV + wp.neon_write(bc_mask_pn, index, 0, wp.uint8(BC_SFV)) + + loader.declare_kernel(cl_stream_coarse) + + return ll_stream_coarse + + @neon.Container.factory(name="SFV_stream_coarse_step") + def SFV_stream_coarse_step(level: int, f_0_fd: Any, f_1_fd: Any, bc_mask_fd: Any, missing_mask_fd: Any): + def ll_stream_coarse(loader: neon.Loader): + loader.set_mres_grid(bc_mask_fd.get_grid(), level) + + f_0_pn = loader.get_mres_read_handle(f_0_fd) + f_1_pn = loader.get_mres_write_handle(f_1_fd) + + bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd) + missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd) + + _c = self.velocity_set.c + + @wp.func + def cl_stream_coarse(index: Any): + _boundary_id = wp.neon_read(bc_mask_pn, index, 0) + if _boundary_id != wp.uint8(BC_SFV): + return + # BC_SFV voxel type: + # - They are not BC voxels + # - They are not on a resolution jump -> they do not do coalescence or explosion + # - They are not mr halo cells + + _missing_mask = _missing_mask_vec() + _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index) + _f_post_collision = _f0_thread + _f_post_stream = self.stream.neon_functional(f_0_pn, index) + + for l in range(self.velocity_set.q): + wp.neon_write(f_1_pn, index, l, self.store_dtype(_f_post_stream[l])) + + 
loader.declare_kernel(cl_stream_coarse) + + return ll_stream_coarse + + @wp.func + def neon_stream_finest_with_explosion( + index: Any, + f_0_pn: Any, + explosion_src_pn: Any, + ): + _f_post_stream = self.stream.neon_functional(f_0_pn, index) + + for l in range(self.velocity_set.q): + if l == lattice_central_index: + continue + + pull_direction = wp.neon_ngh_idx(wp.int8(-_c[0, l]), wp.int8(-_c[1, l]), wp.int8(-_c[2, l])) + + has_ngh_at_same_level = wp.bool(False) + wp.neon_read_ngh(f_0_pn, index, pull_direction, l, self.store_dtype(0), has_ngh_at_same_level) + + if not has_ngh_at_same_level: + # No same-level ngh — could we have a coarser-level ngh? + if wp.neon_has_parent(f_0_pn, index): + # Parent exists — try to read the exploded population from the coarser level + has_a_coarser_ngh = wp.bool(False) + exploded_pop = wp.neon_lbm_read_coarser_ngh( + explosion_src_pn, index, pull_direction, l, self.store_dtype(0), has_a_coarser_ngh + ) + if has_a_coarser_ngh: + # No finer ngh in pull direction, no same-level ngh, + # but a parent (ghost cell) exists with a coarser ngh + # -> Explosion: read the exploded population from the + # coarser level's halo. 
+ _f_post_stream[l] = self.compute_dtype(exploded_pop) + + return _f_post_stream + + @neon.Container.factory(name="finest_fused_pull") + def finest_fused_pull( + level: int, + f_0_fd: Any, + f_1_fd: Any, + bc_mask_fd: Any, + missing_mask_fd: Any, + omega: Any, + timestep: Any, + is_f1_the_explosion_src_field: bool, + ): + if level != 0: + raise Exception("Only the finest level is supported for now") + num_levels = f_0_fd.get_grid().num_levels + + def ll(loader: neon.Loader): + loader.set_mres_grid(bc_mask_fd.get_grid(), level) + if level + 1 < f_0_fd.get_grid().num_levels: + f_0_pn = loader.get_mres_write_handle(f_0_fd, neon.Loader.Operation.stencil_up) + f_1_pn = loader.get_mres_write_handle(f_1_fd, neon.Loader.Operation.stencil_up) + else: + f_0_pn = loader.get_mres_read_handle(f_0_fd) + f_1_pn = loader.get_mres_write_handle(f_1_fd) + bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd) + missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd) + explosion_src_pn = f_1_pn if is_f1_the_explosion_src_field else f_0_pn + accumulation_pn = f_1_pn if is_f1_the_explosion_src_field else f_0_pn + + @wp.func + def device(index: Any): + _boundary_id = wp.neon_read(bc_mask_pn, index, 0) + if _boundary_id == wp.uint8(BC_SOLID): + return + if wp.neon_has_child(f_0_pn, index): + return + + _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index) + _f_post_collision = _f0_thread + _f_post_stream = neon_stream_finest_with_explosion(index, f_0_pn, explosion_src_pn) + + _f_post_stream = apply_bc_post_streaming( + index, timestep, _boundary_id, _missing_mask, f_0_pn, f_1_pn, _f_post_collision, _f_post_stream + ) + + collide_bc_accum( + index, + timestep, + _boundary_id, + _missing_mask, + f_0_pn, + f_1_pn, + _f_post_stream, + omega, + num_levels, + level, + accumulation_pn, + ) + + loader.declare_kernel(device) + + return ll + + @neon.Container.factory(name="CFV_finest_fused_pull") + def CFV_finest_fused_pull( + level: int, + f_0_fd: Any, + f_1_fd: 
Any, + bc_mask_fd: Any, + missing_mask_fd: Any, + omega: Any, + timestep: Any, + is_f1_the_explosion_src_field: bool, + ): + """Fused stream+collide on CFV voxels at the finest level — skips SFV and solid.""" + if level != 0: + raise Exception("Only the finest level is supported for now") + num_levels = f_0_fd.get_grid().num_levels + + def ll(loader: neon.Loader): + loader.set_mres_grid(bc_mask_fd.get_grid(), level) + if level + 1 < f_0_fd.get_grid().num_levels: + f_0_pn = loader.get_mres_write_handle(f_0_fd, neon.Loader.Operation.stencil_up) + f_1_pn = loader.get_mres_write_handle(f_1_fd, neon.Loader.Operation.stencil_up) + else: + f_0_pn = loader.get_mres_read_handle(f_0_fd) + f_1_pn = loader.get_mres_write_handle(f_1_fd) + bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd) + missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd) + explosion_src_pn = f_1_pn if is_f1_the_explosion_src_field else f_0_pn + accumulation_pn = f_1_pn if is_f1_the_explosion_src_field else f_0_pn + + @wp.func + def device(index: Any): + _boundary_id = wp.neon_read(bc_mask_pn, index, 0) + if _boundary_id == wp.uint8(BC_SOLID): + return + if _boundary_id == wp.uint8(BC_SFV): + return + if wp.neon_has_child(f_0_pn, index): + return + + _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index) + _f_post_collision = _f0_thread + _f_post_stream = neon_stream_finest_with_explosion(index, f_0_pn, explosion_src_pn) + + _f_post_stream = apply_bc_post_streaming( + index, timestep, _boundary_id, _missing_mask, f_0_pn, f_1_pn, _f_post_collision, _f_post_stream + ) + + collide_bc_accum( + index, + timestep, + _boundary_id, + _missing_mask, + f_0_pn, + f_1_pn, + _f_post_stream, + omega, + num_levels, + level, + accumulation_pn, + ) + + loader.declare_kernel(device) + + return ll + + @neon.Container.factory(name="SFV_finest_fused_pull") + def SFV_finest_fused_pull(level: int, f_0_fd: Any, f_1_fd: Any, bc_mask_fd: Any, missing_mask_fd: Any, omega: Any): + """Fused 
stream+collide on SFV voxels at the finest level — no BCs, no explosion.""" + if level != 0: + raise Exception("Only the finest level is supported for now") + + def ll(loader: neon.Loader): + loader.set_mres_grid(bc_mask_fd.get_grid(), level) + f_0_pn = loader.get_mres_read_handle(f_0_fd) + f_1_pn = loader.get_mres_write_handle(f_1_fd) + bc_mask_pn = loader.get_mres_read_handle(bc_mask_fd) + missing_mask_pn = loader.get_mres_read_handle(missing_mask_fd) + + @wp.func + def device(index: Any): + _boundary_id = wp.neon_read(bc_mask_pn, index, 0) + if _boundary_id != wp.uint8(BC_SFV): + return + _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index) + _f_post_stream = self.stream.neon_functional(f_0_pn, index) + collide_simple( + index, + 0, + _boundary_id, + _missing_mask, + f_0_pn, + f_1_pn, + _f_post_stream, + omega, + 0, + 0, + f_1_pn, + ) + + loader.declare_kernel(device) + + return ll + + return None, { + "collide_coarse": collide_coarse, + "stream_coarse_step_ABC": stream_coarse_step_ABC, + "finest_fused_pull": finest_fused_pull, + "CFV_finest_fused_pull": CFV_finest_fused_pull, + "SFV_finest_fused_pull": SFV_finest_fused_pull, + "SFV_reset_bc_mask": SFV_reset_bc_mask, + "CFV_collide_coarse": CFV_collide_coarse, + "SFV_collide_coarse": SFV_collide_coarse, + "SFV_stream_coarse_step_ABC": SFV_stream_coarse_step_ABC, + "SFV_stream_coarse_step": SFV_stream_coarse_step, + } + + def launch_container(self, streamId, op_name, mres_level, f_0, f_1, bc_mask, missing_mask, omega, timestep): + """Immediately launch a single Neon container by name. + + Parameters + ---------- + streamId : int + CUDA stream index. + op_name : str + Key into the container dictionary returned by ``_construct_neon``. + mres_level : int + Grid level to execute on. + f_0, f_1 : field + Double-buffered distribution-function fields. + bc_mask, missing_mask : field + Boundary condition and missing-population masks. + omega : float + Relaxation parameter at this level. 
+ timestep : int + Current simulation timestep. + """ + self.neon_container[op_name](mres_level, f_0, f_1, bc_mask, missing_mask, omega, timestep).run(0) + + def add_to_app(self, **kwargs): + """Append a container invocation to the Neon skeleton application list. + + Required keyword arguments are ``op_name`` (str) and ``app`` (list). + All remaining keyword arguments are forwarded to the container + factory for the given ``op_name``. Argument validation is performed + before the call, and a ``ValueError`` is raised on mismatch. + """ + import inspect + + def validate_kwargs_forward(func, kwargs): + """ + Check whether `func(**kwargs)` would be valid, + and return *all* the issues instead of raising on the first one. + + Returns a dict; empty dict means "everything is OK". + """ + sig = inspect.signature(func) + params = sig.parameters + + errors = {} + + # --- 1. Positional-only required params (cannot be given via kwargs) --- + pos_only_required = [name for name, p in params.items() if p.kind == inspect.Parameter.POSITIONAL_ONLY and p.default is inspect._empty] + if pos_only_required: + errors["positional_only_required"] = pos_only_required + + # --- 2. Unexpected kwargs (if no **kwargs in target) --- + has_var_kw = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) + if not has_var_kw: + allowed_kw = { + name + for name, p in params.items() + if p.kind + in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + } + unexpected = sorted(set(kwargs) - allowed_kw) + if unexpected: + errors["unexpected_kwargs"] = unexpected + + # --- 3. 
Missing required keyword-passable params --- + missing_required = [ + name + for name, p in params.items() + if p.kind + in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + and p.default is inspect._empty # no default + and name not in kwargs # not provided + ] + if missing_required: + errors["missing_required"] = missing_required + + return errors + + container_generator = None + try: + op_name = kwargs.pop("op_name") + app = kwargs.pop("app") + except: + raise ValueError("op_name and app must be provided as keyword arguments") + + try: + container_generator = self.neon_container[op_name] + except KeyError: + raise ValueError(f"Operator {op_name} not found in neon container. Available operators: {list(self.neon_container.keys())}") + + errors = validate_kwargs_forward(container_generator, kwargs) + if errors: + raise ValueError(f"Cannot forward kwargs to target: {errors}") + + nvtx.push_range(f"New Container {op_name}", color="yellow") + app.append(container_generator(**kwargs)) + nvtx.pop_range() + + @Operator.register_backend(ComputeBackend.NEON) + def neon_launch(self, f_0, f_1, bc_mask, missing_mask, omega, timestep): + """Execute a single LBM step through the Neon backend (direct launch).""" + c = self.neon_container(f_0, f_1, bc_mask, missing_mask, omega, timestep) + c.run(0) + return f_0, f_1 diff --git a/xlb/operator/stepper/nse_stepper.py b/xlb/operator/stepper/nse_stepper.py index ef0ff412..475958ac 100644 --- a/xlb/operator/stepper/nse_stepper.py +++ b/xlb/operator/stepper/nse_stepper.py @@ -1,6 +1,13 @@ -# Base class for all stepper operators +""" +Single-resolution incompressible Navier-Stokes stepper. + +Implements the full LBM step (stream, collide, apply BCs) for a single- +resolution grid. Supports pull and push streaming schemes on JAX, a +pull-only fused kernel on Warp, and a pull-only Neon container. 
+""" from functools import partial + from jax import jit import warp as wp from typing import Any @@ -17,12 +24,44 @@ from xlb.operator.boundary_condition.boundary_condition import ImplementationStep from xlb.operator.boundary_condition.boundary_condition_registry import boundary_condition_registry from xlb.operator.collision import ForcedCollision -from xlb.operator.boundary_masker import IndicesBoundaryMasker, MeshBoundaryMasker +from xlb.operator.boundary_masker import ( + IndicesBoundaryMasker, + MeshVoxelizationMethod, + MeshMaskerAABB, + MeshMaskerRay, + MeshMaskerWinding, + MeshMaskerAABBClose, +) from xlb.helper import check_bc_overlaps -from xlb.helper.nse_solver import create_nse_fields +from xlb.helper.nse_fields import create_nse_fields +from xlb.operator.boundary_condition.helper_functions_bc import EncodeAuxiliaryData +from xlb.cell_type import BC_SOLID class IncompressibleNavierStokesStepper(Stepper): + """Single-resolution incompressible Navier-Stokes LBM stepper. + + Composes streaming, collision, equilibrium, macroscopic, and boundary- + condition operators into a complete timestep. + + Parameters + ---------- + grid : Grid + Computational grid. + boundary_conditions : list of BoundaryCondition + Boundary conditions to apply each step. + collision_type : str + ``"BGK"``, ``"KBC"``, or ``"SmagorinskyLESBGK"``. + streaming_scheme : str + ``"pull"`` (default) or ``"push"`` (JAX only). + forcing_scheme : str + Forcing scheme name (used when *force_vector* is given). + force_vector : array-like, optional + External body force vector. + backend_config : dict + Backend-specific options (e.g. Neon OCC configuration). 
+ """ + def __init__( self, grid, @@ -31,8 +70,10 @@ def __init__( streaming_scheme="pull", forcing_scheme="exact_difference", force_vector=None, + backend_config={}, ): super().__init__(grid, boundary_conditions) + self.backend_config = backend_config # Construct the collision operator if collision_type == "BGK": @@ -76,63 +117,112 @@ def prepare_fields(self, initializer=None): grid=self.grid, velocity_set=self.velocity_set, compute_backend=self.compute_backend, precision_policy=self.precision_policy ) - # Initialize distribution functions if initializer is provided - if initializer is not None: - f_0 = initializer(self.grid, self.velocity_set, self.precision_policy, self.compute_backend) - else: - from xlb.helper.initializers import initialize_eq - - f_0 = initialize_eq(f_0, self.grid, self.velocity_set, self.precision_policy, self.compute_backend) - # Copy f_0 using backend-specific copy to f_1 if self.compute_backend == ComputeBackend.JAX: f_1 = f_0.copy() - else: + if self.compute_backend == ComputeBackend.WARP: wp.copy(f_1, f_0) + if self.compute_backend == ComputeBackend.NEON: + f_1.copy_from_run(f_0, 0) + # Important note: XLB uses f_1 buffer (center index and missing directions) to store auxiliary data for boundary conditions. 
# Process boundary conditions and update masks - bc_mask, missing_mask = self._process_boundary_conditions(self.boundary_conditions, bc_mask, missing_mask) + f_1, bc_mask, missing_mask = self._process_boundary_conditions(self.boundary_conditions, f_1, bc_mask, missing_mask) + # Initialize auxiliary data if needed - f_0, f_1 = self._initialize_auxiliary_data(self.boundary_conditions, f_0, f_1, bc_mask, missing_mask) + f_1 = self._initialize_auxiliary_data(self.boundary_conditions, f_1, bc_mask, missing_mask) + # bc_mask.update_host(0) + # missing_mask.update_host(0) + wp.synchronize() + # bc_mask.export_vti("bc_mask.vti", 'bc_mask') + # missing_mask.export_vti("missing_mask.vti", 'missing_mask') + + # Initialize distribution functions if initializer is provided + if initializer is not None: + f_0 = initializer(bc_mask, f_0) + else: + from xlb.helper.initializers import initialize_eq + + f_0 = initialize_eq(f_0, self.grid, self.velocity_set, self.precision_policy, self.compute_backend) return f_0, f_1, bc_mask, missing_mask - @classmethod - def _process_boundary_conditions(cls, boundary_conditions, bc_mask, missing_mask): + def _process_boundary_conditions(self, boundary_conditions, f_1, bc_mask, missing_mask): """Process boundary conditions and update boundary masks.""" + # Check for boundary condition overlaps check_bc_overlaps(boundary_conditions, DefaultConfig.velocity_set.d, DefaultConfig.default_backend) + # Create boundary maskers indices_masker = IndicesBoundaryMasker( velocity_set=DefaultConfig.velocity_set, precision_policy=DefaultConfig.default_precision_policy, compute_backend=DefaultConfig.default_backend, + grid=self.grid, ) + # Split boundary conditions by type bc_with_vertices = [bc for bc in boundary_conditions if bc.mesh_vertices is not None] bc_with_indices = [bc for bc in boundary_conditions if bc.indices is not None] + # Process indices-based boundary conditions if bc_with_indices: bc_mask, missing_mask = indices_masker(bc_with_indices, bc_mask, 
missing_mask) + # Process mesh-based boundary conditions for 3D if DefaultConfig.velocity_set.d == 3 and bc_with_vertices: - mesh_masker = MeshBoundaryMasker( - velocity_set=DefaultConfig.velocity_set, - precision_policy=DefaultConfig.default_precision_policy, - compute_backend=DefaultConfig.default_backend, - ) for bc in bc_with_vertices: - bc_mask, missing_mask = mesh_masker(bc, bc_mask, missing_mask) + if bc.voxelization_method.id is MeshVoxelizationMethod("AABB").id: + mesh_masker = MeshMaskerAABB( + velocity_set=DefaultConfig.velocity_set, + precision_policy=DefaultConfig.default_precision_policy, + compute_backend=DefaultConfig.default_backend, + ) + elif bc.voxelization_method.id is MeshVoxelizationMethod("RAY").id: + mesh_masker = MeshMaskerRay( + velocity_set=DefaultConfig.velocity_set, + precision_policy=DefaultConfig.default_precision_policy, + compute_backend=DefaultConfig.default_backend, + ) + elif bc.voxelization_method.id is MeshVoxelizationMethod("WINDING").id: + mesh_masker = MeshMaskerWinding( + velocity_set=DefaultConfig.velocity_set, + precision_policy=DefaultConfig.default_precision_policy, + compute_backend=DefaultConfig.default_backend, + ) + elif bc.voxelization_method.id is MeshVoxelizationMethod("AABB_CLOSE").id: + mesh_masker = MeshMaskerAABBClose( + velocity_set=DefaultConfig.velocity_set, + precision_policy=DefaultConfig.default_precision_policy, + compute_backend=DefaultConfig.default_backend, + close_voxels=bc.voxelization_method.options.get("close_voxels"), + ) + else: + raise ValueError(f"Unsupported voxelization method: {bc.voxelization_method}") + # Apply the mesh masker to the boundary condition + f_1, bc_mask, missing_mask = mesh_masker(bc, f_1, bc_mask, missing_mask) - return bc_mask, missing_mask + return f_1, bc_mask, missing_mask @staticmethod - def _initialize_auxiliary_data(boundary_conditions, f_0, f_1, bc_mask, missing_mask): + def _initialize_auxiliary_data(boundary_conditions, f_1, bc_mask, missing_mask): 
"""Initialize auxiliary data for boundary conditions that require it.""" for bc in boundary_conditions: if bc.needs_aux_init and not bc.is_initialized_with_aux_data: - f_0, f_1 = bc.aux_data_init(f_0, f_1, bc_mask, missing_mask) - return f_0, f_1 + # Create the encoder operator for storing the auxiliary data + encode_auxiliary_data = EncodeAuxiliaryData( + bc.id, + bc.num_of_aux_data, + bc.profile, + velocity_set=bc.velocity_set, + precision_policy=bc.precision_policy, + compute_backend=bc.compute_backend, + ) + + # Encode the auxiliary data in f_1 + f_1 = encode_auxiliary_data(f_1, bc_mask, missing_mask) + bc.is_initialized_with_aux_data = True + return f_1 @Operator.register_backend(ComputeBackend.JAX) @partial(jit, static_argnums=(0,)) @@ -173,11 +263,11 @@ def jax_implementation_pull(self, f_0, f_1, bc_mask, missing_mask, omega, timest feq = self.equilibrium(rho, u) # Apply collision - f_post_collision = self.collision(f_post_stream, feq, rho, u, omega) + f_post_collision = self.collision(f_post_stream, feq, omega) # Apply collision type boundary conditions for bc in self.boundary_conditions: - f_post_collision = bc.update_bc_auxilary_data(f_post_stream, f_post_collision, bc_mask, missing_mask) + f_post_collision = bc.assemble_auxiliary_data(f_post_stream, f_post_collision, bc_mask, missing_mask) if bc.implementation_step == ImplementationStep.COLLISION: f_post_collision = bc( f_post_stream, @@ -210,11 +300,11 @@ def jax_implementation_push(self, f_0, f_1, bc_mask, missing_mask, omega, timest feq = self.equilibrium(rho, u) # Apply collision - f_post_collision = self.collision(f_post_stream, feq, rho, u, omega) + f_post_collision = self.collision(f_post_stream, feq, omega) # Apply collision type boundary conditions for bc in self.boundary_conditions: - f_post_collision = bc.update_bc_auxilary_data(f_post_stream, f_post_collision, bc_mask, missing_mask) + f_post_collision = bc.update_bc_auxiliary_data(f_post_stream, f_post_collision, bc_mask, missing_mask) if 
bc.implementation_step == ImplementationStep.COLLISION: f_post_collision = bc( f_post_stream, @@ -247,27 +337,23 @@ def _construct_warp(self): _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) _opp_indices = self.velocity_set.opp_indices + lattice_central_index = self.velocity_set.center_index # Read the list of bc_to_id created upon instantiation bc_to_id = boundary_condition_registry.bc_to_id - id_to_bc = boundary_condition_registry.id_to_bc # Gather IDs of ExtrapolationOutflowBC boundary conditions extrapolation_outflow_bc_ids = [] for bc_name, bc_id in bc_to_id.items(): if bc_name.startswith("ExtrapolationOutflowBC"): extrapolation_outflow_bc_ids.append(bc_id) - # Group active boundary conditions - active_bcs = set(boundary_condition_registry.id_to_bc[bc.id] for bc in self.boundary_conditions) - - _opp_indices = self.velocity_set.opp_indices @wp.func def apply_bc( index: Any, timestep: Any, _boundary_id: Any, - missing_mask: Any, + _missing_mask: Any, f_0: Any, f_1: Any, f_pre: Any, @@ -281,39 +367,33 @@ def apply_bc( if is_post_streaming: if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.STREAMING): if _boundary_id == wp.static(self.boundary_conditions[i].id): - f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, _missing_mask, f_0, f_1, f_pre, f_post) else: if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.COLLISION): if _boundary_id == wp.static(self.boundary_conditions[i].id): - f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, missing_mask, f_0, f_1, f_pre, f_post) + f_result = wp.static(self.boundary_conditions[i].warp_functional)(index, timestep, _missing_mask, f_0, f_1, f_pre, f_post) if wp.static(self.boundary_conditions[i].id 
in extrapolation_outflow_bc_ids): if _boundary_id == wp.static(self.boundary_conditions[i].id): - f_result = wp.static(self.boundary_conditions[i].update_bc_auxilary_data)( - index, timestep, missing_mask, f_0, f_1, f_pre, f_post + f_result = wp.static(self.boundary_conditions[i].assemble_auxiliary_data)( + index, timestep, _missing_mask, f_0, f_1, f_pre, f_post ) return f_result @wp.func def get_thread_data( f0_buffer: wp.array4d(dtype=Any), - f1_buffer: wp.array4d(dtype=Any), missing_mask: wp.array4d(dtype=Any), index: Any, ): # Read thread data for populations _f0_thread = _f_vec() - _f1_thread = _f_vec() _missing_mask = _missing_mask_vec() for l in range(self.velocity_set.q): # q-sized vector of pre-streaming populations _f0_thread[l] = self.compute_dtype(f0_buffer[l, index[0], index[1], index[2]]) - _f1_thread[l] = self.compute_dtype(f1_buffer[l, index[0], index[1], index[2]]) - if missing_mask[l, index[0], index[1], index[2]]: - _missing_mask[l] = wp.uint8(1) - else: - _missing_mask[l] = wp.uint8(0) + _missing_mask[l] = missing_mask[l, index[0], index[1], index[2]] - return _f0_thread, _f1_thread, _missing_mask + return _f0_thread, _missing_mask @wp.func def apply_aux_recovery_bc( @@ -321,25 +401,28 @@ def apply_aux_recovery_bc( _boundary_id: Any, _missing_mask: Any, f_0: Any, - _f1_thread: Any, + f_1: Any, ): # Note: # In XLB, the BC auxiliary data (e.g. prescribed values of pressure or normal velocity) are stored in (i) central index of f_1 and/or # (ii) missing directions of f_1. Some BCs may or may not need all these available storage space. This function checks whether # the BC needs recovery of auxiliary data and then recovers the information for the next iteration (due to buffer swapping) by - # writting the thread values of f_1 (i.e._f1_thread) into f_0. + # writting the values of f_1 into f_0. 
# Unroll the loop over boundary conditions for i in range(wp.static(len(self.boundary_conditions))): if wp.static(self.boundary_conditions[i].needs_aux_recovery): if _boundary_id == wp.static(self.boundary_conditions[i].id): - # Perform the swapping of data - # (i) Recover the values stored in the central index of f_1 - f_0[0, index[0], index[1], index[2]] = self.store_dtype(_f1_thread[0]) - # (ii) Recover the values stored in the missing directions of f_1 - for l in range(1, self.velocity_set.q): - if _missing_mask[l] == wp.uint8(1): - f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype(_f1_thread[_opp_indices[l]]) + for l in range(self.velocity_set.q): + # Perform the swapping of data + if l == lattice_central_index: + # (i) Recover the values stored in the central index of f_1 + f_0[l, index[0], index[1], index[2]] = self.store_dtype(f_1[l, index[0], index[1], index[2]]) + elif _missing_mask[l] == wp.uint8(1): + # (ii) Recover the values stored in the missing directions of f_1 + f_0[_opp_indices[l], index[0], index[1], index[2]] = self.store_dtype( + f_1[_opp_indices[l], index[0], index[1], index[2]] + ) @wp.kernel def kernel( @@ -354,13 +437,13 @@ def kernel( index = wp.vec3i(i, j, k) _boundary_id = bc_mask[0, index[0], index[1], index[2]] - if _boundary_id == wp.uint8(255): + if _boundary_id == wp.uint8(BC_SOLID): return # Apply streaming _f_post_stream = self.stream.warp_functional(f_0, index) - _f0_thread, _f1_thread, _missing_mask = get_thread_data(f_0, f_1, missing_mask, index) + _f0_thread, _missing_mask = get_thread_data(f_0, missing_mask, index) _f_post_collision = _f0_thread # Apply post-streaming boundary conditions @@ -368,13 +451,13 @@ def kernel( _rho, _u = self.macroscopic.warp_functional(_f_post_stream) _feq = self.equilibrium.warp_functional(_rho, _u) - _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, _rho, _u, omega) + _f_post_collision = self.collision.warp_functional(_f_post_stream, _feq, omega) # 
Apply post-collision boundary conditions _f_post_collision = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0, f_1, _f_post_stream, _f_post_collision, False) # Apply auxiliary recovery for boundary conditions (swapping) - apply_aux_recovery_bc(index, _boundary_id, _missing_mask, f_0, _f1_thread) + apply_aux_recovery_bc(index, _boundary_id, _missing_mask, f_0, f_1) # Store the result in f_1 for l in range(self.velocity_set.q): @@ -390,3 +473,188 @@ def warp_implementation(self, f_0, f_1, bc_mask, missing_mask, omega, timestep): dim=f_0.shape[1:], ) return f_0, f_1 + + def _construct_neon(self): + import neon + + # Set local constants + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + _missing_mask_vec = wp.vec(self.velocity_set.q, dtype=wp.uint8) + _opp_indices = self.velocity_set.opp_indices + lattice_central_index = self.velocity_set.center_index + + # Read the list of bc_to_id created upon instantiation + bc_to_id = boundary_condition_registry.bc_to_id + + # Gather IDs of ExtrapolationOutflowBC boundary conditions + extrapolation_outflow_bc_ids = [] + for bc_name, bc_id in bc_to_id.items(): + if bc_name.startswith("ExtrapolationOutflowBC"): + extrapolation_outflow_bc_ids.append(bc_id) + + @wp.func + def apply_bc( + index: Any, + timestep: Any, + _boundary_id: Any, + _missing_mask: Any, + f_0: Any, + f_1: Any, + f_pre: Any, + f_post: Any, + is_post_streaming: bool, + ): + f_result = f_post + + # Unroll the loop over boundary conditions + for i in range(wp.static(len(self.boundary_conditions))): + if is_post_streaming: + if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.STREAMING): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + f_result = wp.static(self.boundary_conditions[i].neon_functional)(index, timestep, _missing_mask, f_0, f_1, f_pre, f_post) + else: + if wp.static(self.boundary_conditions[i].implementation_step == ImplementationStep.COLLISION): + if _boundary_id == 
wp.static(self.boundary_conditions[i].id): + f_result = wp.static(self.boundary_conditions[i].neon_functional)(index, timestep, _missing_mask, f_0, f_1, f_pre, f_post) + if wp.static(self.boundary_conditions[i].id in extrapolation_outflow_bc_ids): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + f_result = wp.static(self.boundary_conditions[i].assemble_auxiliary_data)( + index, timestep, _missing_mask, f_0, f_1, f_pre, f_post + ) + return f_result + + @wp.func + def neon_get_thread_data( + f0_pn: Any, + missing_mask_pn: Any, + index: Any, + ): + # Read thread data for populations + _f0_thread = _f_vec() + _missing_mask = _missing_mask_vec() + for l in range(self.velocity_set.q): + # q-sized vector of pre-streaming populations + _f0_thread[l] = self.compute_dtype(wp.neon_read(f0_pn, index, l)) + _missing_mask[l] = wp.neon_read(missing_mask_pn, index, l) + + return _f0_thread, _missing_mask + + @wp.func + def neon_apply_aux_recovery_bc( + index: Any, + _boundary_id: Any, + _missing_mask: Any, + f_0_pn: Any, + f_1_pn: Any, + ): + # Note: + # In XLB, the BC auxiliary data (e.g. prescribed values of pressure or normal velocity) are stored in (i) central index of f_1 and/or + # (ii) missing directions of f_1. Some BCs may or may not need all these available storage space. This function checks whether + # the BC needs recovery of auxiliary data and then recovers the information for the next iteration (due to buffer swapping) by + # writting the values of f_1 into f_0. 
+ + # Unroll the loop over boundary conditions + for i in range(wp.static(len(self.boundary_conditions))): + if wp.static(self.boundary_conditions[i].needs_aux_recovery): + if _boundary_id == wp.static(self.boundary_conditions[i].id): + for l in range(self.velocity_set.q): + # Perform the swapping of data + if l == lattice_central_index: + # (i) Recover the values stored in the central index of f_1 + _f1_thread = wp.neon_read(f_1_pn, index, l) + wp.neon_write(f_0_pn, index, l, self.store_dtype(_f1_thread)) + elif _missing_mask[l] == wp.uint8(1): + # (ii) Recover the values stored in the missing directions of f_1 + _f1_thread = wp.neon_read(f_1_pn, index, _opp_indices[l]) + wp.neon_write(f_0_pn, index, _opp_indices[l], self.store_dtype(_f1_thread)) + + @neon.Container.factory(name="nse_stepper") + def container( + f_0_fd: Any, + f_1_fd: Any, + bc_mask_fd: Any, + missing_mask_fd: Any, + omega: Any, + timestep: int, + ): + def nse_stepper_ll(loader: neon.Loader): + loader.set_grid(bc_mask_fd.get_grid()) + + f_0_pn = loader.get_read_handle( + f_0_fd, + operation=neon.Loader.Operation.stencil, + discretization=neon.Loader.Discretization.lattice, + ) + bc_mask_pn = loader.get_read_handle(bc_mask_fd) + missing_mask_pn = loader.get_read_handle(missing_mask_fd) + + f_1_pn = loader.get_write_handle(f_1_fd) + + @wp.func + def nse_stepper_cl(index: Any): + _boundary_id = wp.neon_read(bc_mask_pn, index, 0) + if _boundary_id == wp.uint8(BC_SOLID): + return + # Apply streaming + _f_post_stream = self.stream.neon_functional(f_0_pn, index) + + _f0_thread, _missing_mask = neon_get_thread_data(f_0_pn, missing_mask_pn, index) + _f_post_collision = _f0_thread + + # Apply post-streaming boundary conditions + _f_post_stream = apply_bc(index, timestep, _boundary_id, _missing_mask, f_0_pn, f_1_pn, _f_post_collision, _f_post_stream, True) + + _rho, _u = self.macroscopic.neon_functional(_f_post_stream) + _feq = self.equilibrium.neon_functional(_rho, _u) + _f_post_collision = 
self.collision.neon_functional(_f_post_stream, _feq, omega) + + # Apply post-collision boundary conditions + _f_post_collision = apply_bc( + index, timestep, _boundary_id, _missing_mask, f_0_pn, f_1_pn, _f_post_stream, _f_post_collision, False + ) + + # Apply auxiliary recovery for boundary conditions (swapping) + neon_apply_aux_recovery_bc(index, _boundary_id, _missing_mask, f_0_pn, f_1_pn) + + # Store the result in f_1 + for l in range(self.velocity_set.q): + wp.neon_write(f_1_pn, index, l, self.store_dtype(_f_post_collision[l])) + + loader.declare_kernel(nse_stepper_cl) + + return nse_stepper_ll + + return None, container + + @Operator.register_backend(ComputeBackend.NEON) + def neon_launch(self, f_0, f_1, bc_mask, missing_mask, omega, timestep): + if timestep == 0: + self.prepare_skeleton(f_0, f_1, bc_mask, missing_mask, omega) + self.sk[self.sk_iter].run() + self.sk_iter = (self.sk_iter + 1) % 2 + return f_0, f_1 + + def prepare_skeleton(self, f_0, f_1, bc_mask, missing_mask, omega): + """Build the Neon odd/even skeletons for double-buffered time stepping.""" + grid = f_0.get_grid() + bk = grid.backend + self.neon_skeleton = {"odd": {}, "even": {}} + self.neon_skeleton["odd"]["container"] = self.neon_container(f_0, f_1, bc_mask, missing_mask, omega, 0) + self.neon_skeleton["even"]["container"] = self.neon_container(f_1, f_0, bc_mask, missing_mask, omega, 1) + # check if 'occ' is a valid key + if "occ" not in self.backend_config: + occ = neon.SkeletonConfig.OCC.none() + else: + occ = self.backend_config["occ"] + # check that occ is of type neon.SkeletonConfig.OCC + if not isinstance(occ, neon.SkeletonConfig.OCC): + print(type(occ)) + raise ValueError("occ must be of type neon.SkeletonConfig.OCC") + + for key in self.neon_skeleton: + self.neon_skeleton[key]["app"] = [self.neon_skeleton[key]["container"]] + self.neon_skeleton[key]["skeleton"] = neon.Skeleton(backend=bk) + self.neon_skeleton[key]["skeleton"].sequence(name="mres_nse_stepper", 
containers=self.neon_skeleton[key]["app"], occ=occ) + + self.sk = [self.neon_skeleton["odd"]["skeleton"], self.neon_skeleton["even"]["skeleton"]] + self.sk_iter = 0 diff --git a/xlb/operator/stream/stream.py b/xlb/operator/stream/stream.py index 247fa5a9..7cfe2d19 100644 --- a/xlb/operator/stream/stream.py +++ b/xlb/operator/stream/stream.py @@ -1,4 +1,9 @@ -# Base class for all streaming operators +""" +Streaming operator for the Lattice Boltzmann Method. + +Implements the pull-scheme propagation step: each voxel reads populations +from its lattice neighbours according to the velocity-set directions. +""" from functools import partial import jax.numpy as jnp @@ -11,8 +16,14 @@ class Stream(Operator): - """ - Base class for all streaming operators. This is used for pulling the distribution + """Pull-scheme streaming operator. + + Propagates distribution functions by reading each population from the + upstream neighbour along the corresponding lattice direction. Periodic + boundaries are applied automatically when a pull index falls outside + the domain (Warp backend only; JAX uses ``jnp.roll``). + + Supports JAX, Warp, and Neon backends. 
""" @Operator.register_backend(ComputeBackend.JAX) @@ -112,3 +123,32 @@ def warp_implementation(self, f_0, f_1): dim=f_0.shape[1:], ) return f_1 + + def _construct_neon(self): + # Set local constants TODO: This is a hack and should be fixed with warp update + _c = self.velocity_set.c + _f_vec = wp.vec(self.velocity_set.q, dtype=self.compute_dtype) + + # Construct the funcional to get streamed indices + @wp.func + def functional( + f: Any, + index: Any, + ): + # Pull the distribution function + _f = _f_vec() + for l in range(self.velocity_set.q): + # Get pull offset + ngh = wp.neon_ngh_idx(wp.int8(-_c[0, l]), wp.int8(-_c[1, l]), wp.int8(-_c[2, l])) + unused_is_valid = wp.bool(False) + + # Read the distribution function from the neighboring cell in the pull direction + _f[l] = self.compute_dtype(wp.neon_read_ngh(f, index, ngh, l, self.store_dtype(0), unused_is_valid)) + return _f + + return functional, None + + @Operator.register_backend(ComputeBackend.NEON) + def neon_implementation(self, f_0, f_1): + # raise exception as this feature is not implemented yet + raise NotImplementedError("This feature is not implemented in XLB with the NEON backend yet.") diff --git a/xlb/precision_policy.py b/xlb/precision_policy.py index 7d31c8a3..32a6d567 100644 --- a/xlb/precision_policy.py +++ b/xlb/precision_policy.py @@ -1,11 +1,18 @@ -# Enum for precision policy +""" +Precision and precision-policy enumerations for XLB. + +:class:`Precision` maps symbolic precisions to Warp and JAX dtypes. +:class:`PrecisionPolicy` pairs a *compute* precision (used during +arithmetic) with a *store* precision (used in memory), enabling +mixed-precision simulations. 
+""" from enum import Enum, auto -import jax.numpy as jnp -import warp as wp class Precision(Enum): + """Scalar precision levels with Warp and JAX dtype accessors.""" + FP64 = auto() FP32 = auto() FP16 = auto() @@ -14,6 +21,8 @@ class Precision(Enum): @property def wp_dtype(self): + import warp as wp + if self == Precision.FP64: return wp.float64 elif self == Precision.FP32: @@ -29,6 +38,8 @@ def wp_dtype(self): @property def jax_dtype(self): + import jax.numpy as jnp + if self == Precision.FP64: return jnp.float64 elif self == Precision.FP32: @@ -44,6 +55,12 @@ def jax_dtype(self): class PrecisionPolicy(Enum): + """Mixed-precision policy pairing compute and store precisions. + + The naming convention is ``{COMPUTE}{STORE}``, e.g. ``FP32FP16`` + computes in FP32 and stores results in FP16. + """ + FP64FP64 = auto() FP64FP32 = auto() FP64FP16 = auto() @@ -81,9 +98,13 @@ def store_precision(self): raise ValueError("Invalid precision policy") def cast_to_compute_jax(self, array): + import jax.numpy as jnp + compute_precision = self.compute_precision return jnp.array(array, dtype=compute_precision.jax_dtype) def cast_to_store_jax(self, array): + import jax.numpy as jnp + store_precision = self.store_precision return jnp.array(array, dtype=store_precision.jax_dtype) diff --git a/xlb/utils/__init__.py b/xlb/utils/__init__.py index 735cfad8..b7bdbf9a 100644 --- a/xlb/utils/__init__.py +++ b/xlb/utils/__init__.py @@ -6,9 +6,12 @@ rotate_geometry, voxelize_stl, axangle2mat, + ToJAX, + UnitConvertor, save_usd_vorticity, save_usd_q_criterion, update_usd_lagrangian_parts, plot_object_placement, colorize_scalars, ) +from .mesher import make_cuboid_mesh, MultiresIO diff --git a/xlb/utils/mesher.py b/xlb/utils/mesher.py new file mode 100644 index 00000000..8b849b87 --- /dev/null +++ b/xlb/utils/mesher.py @@ -0,0 +1,941 @@ +""" +Multi-resolution mesh utilities. 
def adjust_bbox(cuboid_max, cuboid_min, voxel_size_up):
    """Snap a bounding box onto the nodes of a lattice with spacing ``voxel_size_up``.

    Mind the asymmetric ordering: the *maximum* corner is the first argument,
    whereas the result is returned as ``(min, max)``.

    Args:
        cuboid_max (array_like): Desired maximum corner of the bounding box.
        cuboid_min (array_like): Desired minimum corner of the bounding box.
        voxel_size_up (float): Spacing of the lattice both corners are rounded to.

    Returns:
        tuple: ``(adjusted_min, adjusted_max)``; every component is rounded to
        the nearest integer multiple of ``voxel_size_up``.
    """
    # np.asarray lets callers pass plain lists/tuples as well as ndarrays.
    lower = np.asarray(cuboid_min)
    upper = np.asarray(cuboid_max)
    adjusted_min = np.round(lower / voxel_size_up) * voxel_size_up
    adjusted_max = np.round(upper / voxel_size_up) * voxel_size_up
    return adjusted_min, adjusted_max


def prepare_sparsity_pattern(level_data):
    """Split level data into the per-level mask and origin lists.

    ``level_data`` is expected to be formatted like the output of
    ``make_cuboid_mesh``: one entry per level whose element 0 is the level
    mask and whose element 2 is the level origin.

    Returns:
        tuple: ``(sparsity_pattern, level_origins)`` where ``sparsity_pattern``
        holds each mask converted to a C-contiguous ``int32`` array and
        ``level_origins`` holds the origins unchanged, both in input order.
    """
    # The masks are only converted (dtype/layout); no padding is applied here.
    sparsity_pattern = [np.ascontiguousarray(entry[0], dtype=np.int32) for entry in level_data]
    level_origins = [entry[2] for entry in level_data]
    return sparsity_pattern, level_origins
def make_cuboid_mesh(voxel_size, cuboids, stl_filename):
    """Build a multi-level cuboid mesh hierarchy around the part stored in an STL file.

    Each level is a box obtained by growing the part's bounding box; the mask
    of a level is True only where the next finer level does not cover it.

    Args:
        voxel_size (float): Voxel size of the finest grid.
        cuboids (list): One entry per level, ``cuboids[0]`` being the coarsest.
            Every entry holds six multipliers of the part size, ordered
            ``(-x, +x, -y, +y, -z, +z)``, by which the part's bounding box is
            extended on the corresponding side.
        stl_filename (str): Path to the STL file.

    Returns:
        list: One tuple ``(mask, voxel_size_multiple, origin, level)`` per level,
        ordered from finest to coarsest. ``voxel_size_multiple`` is the level's
        voxel size divided by ``voxel_size``, ``origin`` is the box origin in
        units of the level's own voxel size and ``level`` is 0 for the finest
        level.
    """
    # Bounding box of the part
    mesh = trimesh.load_mesh(stl_filename, process=False)
    assert not mesh.is_empty, "Loaded mesh is empty or invalid."

    vertices = mesh.vertices
    min_bound = vertices.min(axis=0)
    max_bound = vertices.max(axis=0)
    part_size = max_bound - min_bound

    num_boxes = len(cuboids)
    max_voxel_size = voxel_size * pow(2, (num_boxes - 1))

    level_data = []
    adjusted_bboxes = []

    # Step 1: create a fully active box for every level
    for level, multipliers in enumerate(cuboids):
        # Requested box: part bounds extended by the per-side multipliers
        cuboid_min = np.array(
            [min_bound[axis] - multipliers[2 * axis] * part_size[axis] for axis in range(3)],
            dtype=float,
        )
        cuboid_max = np.array(
            [max_bound[axis] + multipliers[2 * axis + 1] * part_size[axis] for axis in range(3)],
            dtype=float,
        )

        voxel_size_level = max_voxel_size / pow(2, level)

        # Snap to the lattice of the previous entry in ``cuboids`` (twice this
        # level's voxel size); the first level snaps to its own lattice.
        voxel_level_up = max_voxel_size / pow(2, level - 1) if level > 0 else voxel_size_level
        adjusted_min, adjusted_max = adjust_bbox(cuboid_max, cuboid_min, voxel_level_up)

        # Voxel counts follow from the level's own voxel size
        extent = adjusted_max - adjusted_min
        nx, ny, nz = (int(np.round(length / voxel_size_level)) for length in extent)
        print(f"Domain {nx}, {ny}, {nz} Origin {adjusted_min} Voxel Size {voxel_size_level} Voxel Level Up {voxel_level_up}")

        level_data.append((np.ones((nx, ny, nz), dtype=bool), voxel_size_level, adjusted_min, level))
        adjusted_bboxes.append((adjusted_min, adjusted_max))

    # Step 2: deactivate, in every level but the last, the region covered by the next level
    eps = 1e-10  # absorbs floating-point noise in the index computation
    for k in range(len(level_data) - 1):
        mask_k, voxel_size_k, origin_k, _ = level_data[k]
        finer_min, finer_max = adjusted_bboxes[k + 1]

        starts = [max(0, int(np.ceil((finer_min[axis] - origin_k[axis] - eps) / voxel_size_k))) for axis in range(3)]
        ends = [
            min(mask_k.shape[axis], int(np.floor((finer_max[axis] - origin_k[axis] + eps) / voxel_size_k)))
            for axis in range(3)
        ]
        mask_k[starts[0] : ends[0], starts[1] : ends[1], starts[2] : ends[2]] = False

    # Step 3: express voxel sizes and origins as indices instead of STL units
    converted = [
        (mask, int(size / voxel_size), np.round(origin / size).astype(int), num_boxes - 1 - lvl)
        for mask, size, origin, lvl in level_data
    ]
    converted.reverse()
    return converted
def __init__(
    self,
    field_name_cardinality_dict,
    levels_data,
    unit_convertor: UnitConvertor = None,
    offset: Optional[tuple] = (0.0, 0.0, 0.0),
    store_precision=None,
):
    """
    Initialize the MultiresIO object.

    Parameters
    ----------
    field_name_cardinality_dict : dict
        A dictionary mapping field names to their cardinalities.
        Example: {'velocity_x': 1, 'velocity_y': 1, 'velocity': 3, 'density': 1}
    levels_data : list of tuples
        Each tuple contains (data, voxel_size, origin, level).
    unit_convertor : UnitConvertor
        An instance of the UnitConvertor class for unit conversions.
    offset : tuple, optional
        Offset to be applied to the coordinates.
    store_precision : optional
        Precision used for the staging buffers. It must expose a ``wp_dtype``
        attribute; when omitted, the store precision of
        ``DefaultConfig.default_precision_policy`` is used.
        NOTE(review): presumably an ``xlb`` ``Precision`` member - confirm.
    """
    # Set the unit convertor object
    self.unit_convertor = unit_convertor

    # Process the multires geometry and extract coordinates and connectivity in the coordinate system of the finest level
    coordinates, connectivity, level_id_field, total_cells = self.process_geometry(levels_data)

    # Ensure that coordinates and connectivity are not empty
    assert coordinates.size != 0, "Error: No valid data to process. Check the input levels_data."

    # Merge duplicate points
    coordinates, connectivity = self._merge_duplicates(coordinates, connectivity, levels_data)

    # Transform coordinates to physical units and apply offset if provided
    coordinates = self._transform_coordinates(coordinates, offset)

    # Assign to self
    self.field_name_cardinality_dict = field_name_cardinality_dict
    self.levels_data = levels_data
    self.coordinates = coordinates
    self.connectivity = connectivity
    self.level_id_field = level_id_field
    self.total_cells = total_cells
    # One centroid per cell: mean of the cell's eight vertices
    self.centroids = np.mean(coordinates[connectivity], axis=1)

    # Resolve the storage precision. Both attributes must be set whether or not the
    # caller supplied a value, because the staging buffers and the export kernel read them.
    if store_precision is None:
        from xlb import DefaultConfig

        store_precision = DefaultConfig.default_precision_policy.store_precision
    self.store_precision = store_precision
    self.store_dtype = store_precision.wp_dtype

    # Prepare and allocate the inputs for the NEON container
    self.field_warp_dict, self.origin_list = self._prepare_container_inputs()

    # Construct the NEON container for exporting multi-resolution data
    self.container = self._construct_neon_container()
the NEON container for exporting multi-resolution data + self.container = self._construct_neon_container() + + def process_geometry(self, levels_data): + """Build merged coordinates and connectivity from all levels. + + Returns + ------- + coordinates : np.ndarray, shape (N, 3) + Vertex positions (8 per active voxel, before deduplication). + connectivity : np.ndarray, shape (M, 8) + Hexahedral connectivity (one row per active voxel). + level_id_field : np.ndarray, shape (M,) + Grid level index for each cell. + total_cells : int + Total number of active voxels across all levels. + """ + num_voxels_per_level = [np.sum(data) for data, _, _, _ in levels_data] + num_points_per_level = [8 * nv for nv in num_voxels_per_level] + point_id_offsets = np.cumsum([0] + num_points_per_level[:-1]) + + all_corners = [] + all_connectivity = [] + level_id_field = [] + total_cells = 0 + + for level_idx, (data, voxel_size, origin, level) in enumerate(levels_data): + origin = origin * voxel_size + corners_list, conn_list = self._process_level(data, voxel_size, origin, point_id_offsets[level_idx]) + + if corners_list: + print(f"\tProcessing level {level}: Voxel size {voxel_size}, Origin {origin}, Shape {data.shape}") + all_corners.extend(corners_list) + all_connectivity.extend(conn_list) + num_cells = sum(c.shape[0] for c in conn_list) + level_id_field.extend([level] * num_cells) + total_cells += num_cells + else: + print(f"\tSkipping level {level} (no unique data)") + + # Stacking coordinates and connectivity + coordinates = np.concatenate(all_corners, axis=0).astype(np.float32) + connectivity = np.concatenate(all_connectivity, axis=0).astype(np.int32) + level_id_field = np.array(level_id_field, dtype=np.uint8) + + return coordinates, connectivity, level_id_field, total_cells + + def _process_level(self, data, voxel_size, origin, point_id_offset): + """ + Given a voxel grid, returns all corners and connectivity in NumPy for this resolution level. 
+ """ + true_indices = np.argwhere(data) + if true_indices.size == 0: + return [], [] + + max_voxels_per_chunk = 268_435_450 + chunks = np.array_split(true_indices, max(1, (len(true_indices) + max_voxels_per_chunk - 1) // max_voxels_per_chunk)) + + all_corners = [] + all_connectivity = [] + pid_offset = point_id_offset + + for chunk in chunks: + if chunk.size == 0: + continue + corners, connectivity = self._process_voxel_chunk(chunk, np.asarray(origin, dtype=np.float32), voxel_size, pid_offset) + all_corners.append(corners) + all_connectivity.append(connectivity) + pid_offset += len(chunk) * 8 + + return all_corners, all_connectivity + + def _process_voxel_chunk(self, true_indices, origin, voxel_size, point_id_offset): + """ + Given a set of voxel indices, returns 8 corners and connectivity for each cube using NumPy. + """ + true_indices = np.asarray(true_indices, dtype=np.float32) + mins = origin + true_indices * voxel_size + offsets = np.array( + [ + [0, 0, 0], + [1, 0, 0], + [1, 1, 0], + [0, 1, 0], + [0, 0, 1], + [1, 0, 1], + [1, 1, 1], + [0, 1, 1], + ], + dtype=np.float32, + ) + + corners = (mins[:, None, :] + offsets[None, :, :] * voxel_size).reshape(-1, 3).astype(np.float32) + base_ids = point_id_offset + np.arange(len(true_indices), dtype=np.int32) * 8 + connectivity = (base_ids[:, None] + np.arange(8, dtype=np.int32)).astype(np.int32) + + return corners, connectivity + + def save_xdmf(self, h5_filename, xmf_filename, total_cells, num_points, fields={}): + """Write an XDMF descriptor that references the companion HDF5 file.""" + # Generate an XDMF file to accompany the HDF5 file + print(f"\tGenerating XDMF file: {xmf_filename}") + hdf5_rel_path = h5_filename.split("/")[-1] + with open(xmf_filename, "w") as xmf: + xmf.write(f''' + + + + + + + {hdf5_rel_path}:/Mesh/Connectivity + + + + + {hdf5_rel_path}:/Mesh/Points + + + + + {hdf5_rel_path}:/Mesh/Level + + + ''') + for field_name in fields.keys(): + xmf.write(f''' + + + {hdf5_rel_path}:/Fields/{field_name} + 
+ + ''') + xmf.write(""" + + + + """) + print("\tXDMF file written successfully") + return + + def save_hdf5_file(self, filename, coordinates, connectivity, level_id_field, fields_data, compression="gzip", compression_opts=0): + """Write the processed mesh data to an HDF5 file. + Parameters + ---------- + filename : str + The name of the output HDF5 file. + coordinates : numpy.ndarray + An array of all coordinates. + connectivity : numpy.ndarray + An array of all connectivity data. + level_id_field : numpy.ndarray + An array of all level data. + fields_data : dict + A dictionary of all field data. + compression : str, optional + The compression method to use for the HDF5 file. + compression_opts : int, optional + The compression options to use for the HDF5 file. + """ + import h5py + + with h5py.File(filename + ".h5", "w") as f: + f.create_dataset("/Mesh/Points", data=coordinates, compression=compression, compression_opts=compression_opts, chunks=True) + f.create_dataset( + "/Mesh/Connectivity", + data=connectivity, + compression=compression, + compression_opts=compression_opts, + chunks=True, + ) + f.create_dataset("/Mesh/Level", data=level_id_field, compression=compression, compression_opts=compression_opts) + fg = f.create_group("/Fields") + for fname, fdata in fields_data.items(): + fg.create_dataset(fname, data=fdata.astype(np.float32), compression=compression, compression_opts=compression_opts, chunks=True) + + def _merge_duplicates(self, coordinates, connectivity, levels_data): + """Deduplicate vertices shared between adjacent voxels. + + Uses spatial hashing (grid-snapped coordinates) processed in + chunks to keep memory bounded. 
+ """ + # Merging duplicate points + tolerance = 0.01 + chunk_size = 10_000_000 # Adjust based on GPU memory + num_points = coordinates.shape[0] + unique_points = [] + mapping = np.zeros(num_points, dtype=np.int32) + unique_idx = 0 + + # Get the grid shape of computational box at the finest level from the levels_data + num_levels = len(levels_data) + grid_shape_finest = np.array(levels_data[-1][0].shape) * 2 ** (num_levels - 1) + + for start in range(0, num_points, chunk_size): + end = min(start + chunk_size, num_points) + coords_chunk = coordinates[start:end] + + # Simple hashing: grid coordinates as tuple keys + grid_coords = np.round(coords_chunk / tolerance).astype(np.int64) + hash_keys = grid_coords[:, 0] + grid_coords[:, 1] * grid_shape_finest[0] + grid_coords[:, 2] * grid_shape_finest[0] * grid_shape_finest[1] + unique_hash, inverse = np.unique(hash_keys, return_inverse=True) + unique_hash, unique_indices, inverse = np.unique(hash_keys, return_index=True, return_inverse=True) + unique_chunk = coords_chunk[unique_indices] + + unique_points.append(unique_chunk) + mapping[start:end] = inverse + unique_idx + unique_idx += len(unique_hash) + + coordinates = np.concatenate(unique_points) + connectivity = mapping[connectivity] + return coordinates, connectivity + + def _transform_coordinates(self, coordinates, offset): + """Convert lattice coordinates to physical units and apply offset.""" + offset = np.array(offset, dtype=np.float32) + if self.unit_convertor is not None: + coordinates = self.unit_convertor.length_to_physical(coordinates) + return coordinates + offset + + def _prepare_container_inputs(self): + """Allocate dense Warp fields used as staging buffers for Neon-to-NumPy transfer.""" + # load necessary modules + from xlb.compute_backend import ComputeBackend + from xlb.grid import grid_factory + + # Get the number of levels from the levels_data + num_levels = len(self.levels_data) + + # Prepare lists to hold warp fields and origins allocated for each 
def _construct_neon_container(self):
    """
    Constructs a NEON container for exporting multi-resolution data to HDF5.
    This container will be used to transfer multi-resolution NEON fields into stacked warp fields.

    The returned factory takes ``(field_neon, field_warp, origin, level)`` and
    copies every component of ``field_neon`` at ``level`` into ``field_warp``.
    """
    import neon

    @neon.Container.factory(name="HDF5MultiresExporter")
    def container(
        field_neon: Any,
        field_warp: Any,
        origin: Any,
        level: Any,
    ):
        def launcher(loader: neon.Loader):
            loader.set_mres_grid(field_neon.get_grid(), level)
            field_neon_hdl = loader.get_mres_read_handle(field_neon)
            refinement = 2**level

            @wp.func
            def kernel(index: Any):
                cIdx = wp.neon_global_idx(field_neon_hdl, index)
                # Get local indices by dividing the global indices (associated with the finest level) by 2^level
                # Subtract the origin to get the local indices in the warp field
                lx = wp.neon_get_x(cIdx) // refinement - origin[0]
                ly = wp.neon_get_y(cIdx) // refinement - origin[1]
                lz = wp.neon_get_z(cIdx) // refinement - origin[2]

                # write the values to the warp field
                cardinality = field_warp.shape[0]
                for card in range(cardinality):
                    field_warp[card, lx, ly, lz] = self.store_dtype(wp.neon_read(field_neon_hdl, index, card))

            loader.declare_kernel(kernel)

        return launcher

    return container


def get_fields_data(self, field_neon_dict):
    """
    Extracts and prepares the fields data from the NEON fields for export.

    Parameters
    ----------
    field_neon_dict : dict
        Maps field names (a subset of the names registered at construction)
        to multi-resolution NEON fields.

    Returns
    -------
    dict
        Maps ``"<field_name>_<component>"`` to a 1-D NumPy array holding the
        masked values of all levels concatenated in ``levels_data`` order.
        Entries whose name contains "velocity", "density" or "pressure"
        (case-insensitive) are converted to physical units when a unit
        convertor is set. An empty input yields an empty dict.
    """
    # Check if the field_neon_dict is empty
    if not field_neon_dict:
        return {}

    # ``neon`` is not imported at module level, so bring it into scope here;
    # the container runtime below needs it.
    import neon

    # Ensure that this operator is called on multires grids
    grid_mres = next(iter(field_neon_dict.values())).get_grid()
    assert grid_mres.name == "mGrid", f"Operation {self.__class__.__name__} is only applicable to multi-resolution cases!"

    for field_name in field_neon_dict.keys():
        assert field_name in self.field_name_cardinality_dict.keys(), (
            f"Field {field_name} is not provided in the instantiation of the MultiresIO class!"
        )

    # number of levels
    num_levels = grid_mres.num_levels
    assert num_levels == len(self.levels_data), "Error: Inconsistent number of levels!"

    # Prepare the fields dictionary to be written by transferring multi-res NEON fields into stacked warp fields and then numpy arrays
    fields_data = {}
    for field_name, cardinality in self.field_name_cardinality_dict.items():
        if field_name not in field_neon_dict:
            continue
        for card in range(cardinality):
            fields_data[f"{field_name}_{card}"] = []

    # Iterate over each field and level to fill the dictionary with numpy fields
    for field_name, cardinality in self.field_name_cardinality_dict.items():
        if field_name not in field_neon_dict:
            continue
        for level in range(num_levels):
            # Create the container and run it to fill the warp fields
            c = self.container(field_neon_dict[field_name], self.field_warp_dict[field_name][level], self.origin_list[level], level)
            c.run(0, container_runtime=neon.Container.ContainerRuntime.neon)

            # Ensure all operations are complete before converting to Numpy arrays
            wp.synchronize()

            # Convert the warp fields to numpy arrays and use level's mask to filter the data
            mask = self.levels_data[level][0]
            field_np = self.field_warp_dict[field_name][level].numpy()
            for card in range(cardinality):
                field_np_card = field_np[card][mask]
                fields_data[f"{field_name}_{card}"].append(field_np_card)

    # Concatenate all field data
    for field_name in fields_data.keys():
        fields_data[field_name] = np.concatenate(fields_data[field_name])
        assert fields_data[field_name].size == self.total_cells, f"Error: Field {field_name} size mismatch!"

        # Unit conversion if applicable (selected by substring of the entry name)
        if self.unit_convertor is not None:
            if "velocity" in field_name.lower():
                fields_data[field_name] = self.unit_convertor.velocity_to_physical(fields_data[field_name])
            elif "density" in field_name.lower():
                fields_data[field_name] = self.unit_convertor.density_to_physical(fields_data[field_name])
            elif "pressure" in field_name.lower():
                fields_data[field_name] = self.unit_convertor.pressure_to_physical(fields_data[field_name])
            # Add more physical quantities as needed

    return fields_data
def to_hdf5(self, output_filename, field_neon_dict, compression="gzip", compression_opts=0):
    """
    Write the multi-resolution mesh and fields to ``<output_filename>.h5`` plus a companion ``.xmf``.

    Parameters
    ----------
    output_filename : str
        Base name of the output files (without extension).
    field_neon_dict : dict of neon mGrid Fields
        The NEON fields to export, e.g. {"velocity": velocity_neon, "density": density_neon}.
    compression : str, optional
        Compression method forwarded to the HDF5 writer.
    compression_opts : int, optional
        Compression options forwarded to the HDF5 writer.
    """
    from time import perf_counter

    # Pull the field values out of the NEON fields
    exported_fields = self.get_fields_data(field_neon_dict)

    # The XDMF descriptor is written first and refers to the HDF5 file by name
    h5_name = output_filename + ".h5"
    xmf_name = output_filename + ".xmf"
    self.save_xdmf(h5_name, xmf_name, self.total_cells, len(self.coordinates), exported_fields)

    # Write the heavy data and report how long it took
    print("\tWriting HDF5 file")
    started = perf_counter()
    self.save_hdf5_file(
        output_filename,
        self.coordinates,
        self.connectivity,
        self.level_id_field,
        exported_fields,
        compression,
        compression_opts,
    )
    elapsed = perf_counter() - started
    print(f"\tHDF5 file written in {elapsed:0.1f} seconds")
def to_slice_image(
    self,
    output_filename,
    field_neon_dict,
    plane_point,
    plane_normal,
    slice_thickness=1.0,
    bounds=None,
    grid_res=512,
    cmap=None,
    component=None,
    show_axes=False,
    show_colorbar=False,
    **kwargs,
):
    """
    Export an arbitrary-plane slice from unstructured point data to PNG.

    Parameters
    ----------
    output_filename : str
        Output PNG filename (without extension); the plotted field name is
        appended to it.
    field_neon_dict : dict
        A dictionary holding exactly one NEON field to be plotted.
        Example: {"velocity": velocity_neon}
    plane_point : array_like
        A point [x, y, z] on the plane.
    plane_normal : array_like
        Plane normal vector [nx, ny, nz].
    slice_thickness : float
        How thick (in units of the coordinate system) the slice should be.
    bounds : sequence of four floats, optional
        Fractions ``[u_min, u_max, v_min, v_max]`` of the slice extent to plot.
        Defaults to ``[0, 1, 0, 1]`` (the whole slice).
    grid_res : int
        Horizontal resolution of the output image in pixels.
    cmap : optional
        Matplotlib colormap.
    component : int, optional
        Index of the field component to plot. When omitted, the magnitude
        over all components is plotted.
    show_axes, show_colorbar : bool
        Forwarded to the plotting helper.
    **kwargs
        Forwarded to the plotting helper.

    Raises
    ------
    ValueError
        If ``component`` does not index one of the exported components.
    """
    # Get the fields data from the NEON fields
    assert len(field_neon_dict.keys()) == 1, "Error: This function is designed to plot a single field at a time."
    if bounds is None:
        # Fresh list per call instead of a shared mutable default
        bounds = [0, 1, 0, 1]
    fields_data = self.get_fields_data(field_neon_dict)
    component_names = list(fields_data)

    if component is None:
        print("\tCreating slice image of the field magnitude!")
        cell_data = np.sqrt(sum(values**2 for values in fields_data.values()))
        # Strip only the trailing "_<component>" suffix so that field names which
        # themselves contain underscores (e.g. "velocity_x") stay intact.
        field_name = component_names[0].rsplit("_", 1)[0] + "_magnitude"
    else:
        # Validate against the components actually exported for this field, not against
        # the largest cardinality registered for any field.
        if not -len(component_names) <= component < len(component_names):
            raise ValueError(f"Error: Component {component} is out of range for the provided fields.")
        print(f"\tCreating slice image for component {component} of the input field!")
        field_name = component_names[component]
        cell_data = fields_data[field_name]

    # Plot the selected data
    self._to_slice_image_single_field(
        f"{output_filename}_{field_name}",
        cell_data,
        plane_point,
        plane_normal,
        slice_thickness=slice_thickness,
        bounds=bounds,
        grid_res=grid_res,
        cmap=cmap,
        show_axes=show_axes,
        show_colorbar=show_colorbar,
        **kwargs,
    )
    # NOTE(review): the helper receives "<output_filename>_<field_name>" while this message
    # reports "<output_filename>.png" - confirm which name the helper actually writes.
    print(f"\tSlice image for field {field_name} saved as {output_filename}.png")
+ fields_data = self.get_fields_data(field_neon_dict) + + # Check if the component is within the valid range + if component is None: + print("\tCreating slice image of the field magnitude!") + cell_data = list(fields_data.values()) + squared = [comp**2 for comp in cell_data] + cell_data = np.sqrt(sum(squared)) + field_name = list(fields_data.keys())[0].split("_")[0] + "_magnitude" + else: + assert component < max(self.field_name_cardinality_dict.values()), ( + f"Error: Component {component} is out of range for the provided fields." + ) + print(f"\tCreating slice image for component {component} of the input field!") + field_name = list(fields_data.keys())[component] + cell_data = fields_data[field_name] + + # Plot each field in the dictionary + self._to_slice_image_single_field( + f"{output_filename}_{field_name}", + cell_data, + plane_point, + plane_normal, + slice_thickness=slice_thickness, + bounds=bounds, + grid_res=grid_res, + cmap=cmap, + show_axes=show_axes, + show_colorbar=show_colorbar, + **kwargs, + ) + print(f"\tSlice image for field {field_name} saved as {output_filename}.png") + + def _to_slice_image_single_field( + self, + output_filename, + field_data, + plane_point, + plane_normal, + slice_thickness, + bounds, + grid_res, + cmap, + show_axes, + show_colorbar, + **kwargs, + ): + """ + Helper function to create a slice image for a single field. 
+ """ + from matplotlib import cm + import numpy as np + import matplotlib.pyplot as plt + from scipy.spatial import cKDTree + + # field data are associated with the cells centers + cell_values = field_data + + # get the normalized plane normal + plane_normal = np.asarray(np.abs(plane_normal)) + n = plane_normal / np.linalg.norm(plane_normal) + + # Compute signed distances of each cell center to the plane + plane_point *= plane_normal + sdf = np.dot(self.centroids - plane_point, n) + + # Filter: cells with centroid near plane + mask = np.abs(sdf) <= slice_thickness / 2 + if not np.any(mask): + raise ValueError("No cells intersect the plane within thickness.") + + # Project centroids to plane + centroids_slice = self.centroids[mask] + sdf_slice = sdf[mask] + proj = centroids_slice - np.outer(sdf_slice, n) + + values = cell_values[mask] + + # Build in-plane basis + if np.allclose(n, [1, 0, 0]): + u1 = np.array([0, 1, 0]) + else: + u1 = np.array([1, 0, 0]) + u2 = np.abs(np.cross(n, u1)) + + local_x = np.dot(proj - plane_point, u1) + local_y = np.dot(proj - plane_point, u2) + + # Define extent of the plot + xmin, xmax, ymin, ymax = local_x.min(), local_x.max(), local_y.min(), local_y.max() + Lx = xmax - xmin + Ly = ymax - ymin + extent = np.array([xmin + bounds[0] * Lx, xmin + bounds[1] * Lx, ymin + bounds[2] * Ly, ymin + bounds[3] * Ly]) + mask_bounds = (extent[0] <= local_x) & (local_x <= extent[1]) & (extent[2] <= local_y) & (local_y <= extent[3]) + + if cmap is None: + cmap = cm.nipy_spectral + + # Adjust vertical resolution based on bounds + bounded_x_min = local_x[mask_bounds].min() + bounded_x_max = local_x[mask_bounds].max() + bounded_y_min = local_y[mask_bounds].min() + bounded_y_max = local_y[mask_bounds].max() + width_x = bounded_x_max - bounded_x_min + height_y = bounded_y_max - bounded_y_min + aspect_ratio = height_y / width_x + grid_resY = max(1, int(np.round(grid_res * aspect_ratio))) + + # Create grid + grid_x = np.linspace(bounded_x_min, bounded_x_max, 
grid_res) + grid_y = np.linspace(bounded_y_min, bounded_y_max, grid_resY) + xv, yv = np.meshgrid(grid_x, grid_y, indexing="xy") + + # Fast KDTree-based interpolation + points = np.column_stack((local_x[mask_bounds], local_y[mask_bounds])) + tree = cKDTree(points) + + # Query points + query_points = np.column_stack((xv.ravel(), yv.ravel())) + + # Find k nearest neighbors for smoother interpolation + k = min(4, len(points)) # Use 4 neighbors or less if not enough points + distances, indices = tree.query(query_points, k=k, workers=-1) # -1 uses all cores + + # Inverse distance weighting + epsilon = 1e-10 + weights = 1.0 / (distances + epsilon) + weights /= weights.sum(axis=1, keepdims=True) + + # Interpolate values + neighbor_values = values[mask_bounds][indices] + grid_field = (neighbor_values * weights).sum(axis=1).reshape(grid_resY, grid_res) + + # Plot + if show_colorbar or show_axes: + dpi = 300 + plt.imshow( + grid_field, + extent=[bounded_x_min, bounded_x_max, bounded_y_min, bounded_y_max], + cmap=cmap, + origin="lower", + aspect="equal", + **kwargs, + ) + if show_colorbar: + plt.colorbar() + if not show_axes: + plt.axis("off") + plt.savefig(output_filename + ".png", dpi=dpi, bbox_inches="tight", pad_inches=0) + plt.close() + else: + plt.imsave(output_filename + ".png", grid_field, cmap=cmap, origin="lower") + + def to_line( + self, + output_filename, + field_neon_dict, + start_point, + end_point, + resolution, + component=None, + radius=1.0, + **kwargs, + ): + """ + Extract field data along a line between start_point and end_point and save to a CSV file. + + This function performs two main steps: + 1. Extracts field data from field_neon_dict, handling components or computing magnitude. + 2. Interpolates the field values along a line defined by start_point and end_point, + then saves the results (coordinates and field values) to a CSV file. + + Parameters + ---------- + output_filename : str + The name of the output CSV file (without extension). 
Example: "velocity_profile". + field_neon_dict : dict + A dictionary containing the field data to extract, with a single key-value pair. + The key is the field name (e.g., "velocity"), and the value is the NEON data object + containing the field values. Example: {"velocity": velocity_neon}. + start_point : array_like + The starting point of the line in 3D space (e.g., [x0, y0, z0]). + Units must match the coordinate system used in the class (voxel units if untransformed, + or model units if scale/offset are applied). + end_point : array_like + The ending point of the line in 3D space (e.g., [x1, y1, z1]). + Units must match the coordinate system used in the class. + resolution : int + The number of points along the line where the field will be interpolated. + Example: 100 for 100 evenly spaced points. + component : int, optional + The specific component of the field to extract (e.g., 0 for x-component, 1 for y-component). + If None, the magnitude of the field is computed. Default is None. + radius : int + The specified distance (in units of the coordinate system) to prefilter and query for line plot + + Returns + ------- + None + The function writes the output to a CSV file and prints a confirmation message. + + Notes + ----- + - The output CSV file will contain columns: 'x', 'y', 'z', and the value of the field name (e.g., 'velocity_x' or 'velocity_magnitude'). + """ + + # Get the fields data from the NEON fields + assert len(field_neon_dict.keys()) == 1, "Error: This function is designed to plot a single field at a time." 
+ fields_data = self.get_fields_data(field_neon_dict) + + # Check if the component is within the valid range + if component is None: + print("\tCreating csv plot of the field magnitude!") + cell_data = list(fields_data.values()) + squared = [comp**2 for comp in cell_data] + cell_data = np.sqrt(sum(squared)) + field_name = list(fields_data.keys())[0].split("_")[0] + "_magnitude" + + else: + assert component < max(self.field_name_cardinality_dict.values()), ( + f"Error: Component {component} is out of range for the provided fields." + ) + print(f"\tCreating csv plot for component {component} of the input field!") + field_name = list(fields_data.keys())[component] + cell_data = fields_data[field_name] + + # Plot each field in the dictionary + self._to_line_field( + f"{output_filename}_{field_name}", + cell_data, + start_point, + end_point, + resolution, + radius=radius, + **kwargs, + ) + print(f"\tLine Plot for field {field_name} saved as {output_filename}.csv") + + def _to_line_field( + self, + output_filename, + cell_data, + start_point, + end_point, + resolution, + radius, + **kwargs, + ): + """ + Helper function to create a line plot for a single field. 
+ """ + import numpy as np + + # cell_points = self.coordinates[self.connectivity] # Shape: (M, K, 3), where M is num cells, K is nodes per cell + # centroids = np.mean(cell_points, axis=1) # Shape: (M, 3) + centroids = self.centroids + p0 = np.array(start_point, dtype=np.float32) + p1 = np.array(end_point, dtype=np.float32) + + # direction and parameter t for each centroid + d = p1 - p0 + L = np.linalg.norm(d) + d_unit = d / L + v = centroids - p0 + t = v.dot(d_unit) + closest = p0 + np.outer(t, d_unit) + perp_dist = np.linalg.norm(centroids - closest, axis=1) + + # optionally mask to [0,L] or a small perp-radius + mask = (t >= 0) & (t <= L) & (perp_dist <= radius) + t, data = t[mask], cell_data[mask] + + # sort by t + idx = np.argsort(t) + t_sorted = t[idx] + data_sorted = data[idx] + + # target samples + t_line = np.linspace(0, L, resolution) + + # 1D linear interpolation + vals_line = np.interp(t_line, t_sorted, data_sorted, left=np.nan, right=np.nan) + + # reconstruct (x,y,z) + line_xyz = p0[None, :] + t_line[:, None] * d_unit[None, :] + + # vectorized CSV dump + out = np.hstack([line_xyz, vals_line[:, None]]) + np.savetxt(output_filename + ".csv", out, delimiter=",", header="x,y,z,value", comments="") diff --git a/xlb/utils/utils.py b/xlb/utils/utils.py index a42dea98..d707fb5c 100644 --- a/xlb/utils/utils.py +++ b/xlb/utils/utils.py @@ -1,3 +1,11 @@ +""" +General-purpose utilities for XLB. + +Includes helpers for field downsampling, VTK/image/USD I/O, geometry +rotation, STL voxelization, Neon-to-JAX field transfer, and +physical-to-lattice unit conversion. +""" + import numpy as np import matplotlib.pylab as plt from matplotlib import cm @@ -83,7 +91,7 @@ def save_image(fld, timestep=None, prefix=None, **kwargs): if len(fld.shape) > 3: raise ValueError("The input field should be 2D!") if len(fld.shape) == 3: - fld = np.sqrt(fld[0, ...] ** 2 + fld[0, ...] ** 2) + fld = np.sqrt(fld[0, ...] ** 2 + fld[1, ...] ** 2 + fld[2, ...] 
class ToJAX(object):
    """Convert a Neon field to a JAX array via an intermediate Warp grid."""

    def __init__(self, field_name, field_cardinality, grid_shape, store_precision=None):
        """Initialise the Neon-to-JAX converter.

        Parameters
        ----------
        field_name : str
            The name of the field to be converted.
        field_cardinality : int
            The cardinality of the field to be converted.
        grid_shape : tuple
            The shape of the grid on which the field is defined.
        store_precision : Precision, optional
            Storage precision. Defaults to the global config value.
        """
        from xlb.compute_backend import ComputeBackend
        from xlb.grid import grid_factory
        from xlb import DefaultConfig

        self.field_name = field_name
        self.field_cardinality = field_cardinality
        self.grid_shape = grid_shape
        self.compute_backend = DefaultConfig.default_backend
        self.velocity_set = DefaultConfig.velocity_set

        # Honour an explicitly requested precision; previously the attribute was only
        # set when the argument was omitted, and the dtype always came from the default.
        if store_precision is None:
            store_precision = DefaultConfig.default_precision_policy.store_precision
        self.store_precision = store_precision
        self.store_dtype = store_precision.wp_dtype

        if self.compute_backend == ComputeBackend.NEON:
            # Use the Warp backend to allocate a dense field that Neon dGrid data is copied into
            grid_dense = grid_factory(grid_shape, compute_backend=ComputeBackend.WARP)
            self.warp_field = grid_dense.create_field(cardinality=self.field_cardinality, dtype=self.store_precision)

    def copy_neon_to_warp(self, neon_field):
        """Copy a dense Neon field into the pre-allocated Warp field and return the latter."""
        import warp as wp
        import neon
        from typing import Any

        assert neon_field.get_grid().name == "dGrid", "to_warp only supports dense grids"
        _d = self.velocity_set.d

        @neon.Container.factory("to_warp")
        def container(src_field: Any, dst_field: Any, cardinality: wp.int32):
            def loading_step(loader: neon.Loader):
                loader.set_grid(src_field.get_grid())
                src_pn = loader.get_read_handle(src_field)

                @wp.func
                def cloning(gridIdx: Any):
                    cIdx = wp.neon_global_idx(src_pn, gridIdx)
                    gx = wp.neon_get_x(cIdx)
                    gy = wp.neon_get_y(cIdx)
                    gz = wp.neon_get_z(cIdx)

                    # XLB is flattening the z dimension in 3D, while neon uses the y dimension
                    if _d == 2:
                        gy, gz = gz, gy

                    for card in range(cardinality):
                        value = wp.neon_read(src_pn, gridIdx, card)
                        dst_field[card, gx, gy, gz] = value

                loader.declare_kernel(cloning)

            return loading_step

        cardinality = neon_field.cardinality
        c = container(neon_field, self.warp_field, cardinality)
        c.run(0)
        wp.synchronize()
        return self.warp_field

    def __call__(self, field):
        """Return ``field`` as a JAX array for the backend selected at construction time.

        Raises
        ------
        ValueError
            If the configured compute backend is not JAX, WARP or NEON.
        """
        from xlb.compute_backend import ComputeBackend
        import warp as wp

        if self.compute_backend == ComputeBackend.JAX:
            return field
        elif self.compute_backend == ComputeBackend.WARP:
            return wp.to_jax(field)
        elif self.compute_backend == ComputeBackend.NEON:
            assert field.cardinality == self.field_cardinality, (
                f"Field cardinality mismatch! Expected {self.field_cardinality}, got {field.cardinality}!"
            )
            return wp.to_jax(self.copy_neon_to_warp(field))

        else:
            raise ValueError("Unsupported compute backend!")


class UnitConvertor(object):
    """Convert quantities between physical units and lattice Boltzmann (LBM) units.

    The reference length is the voxel size and the reference time is the
    physical duration of one lattice time step, so lengths and times in LBM
    units are counts of voxels and time steps.
    """

    def __init__(
        self,
        velocity_lbm_unit: float,
        velocity_physical_unit: float,
        voxel_size_physical_unit: float,
        density_physical_unit: float = 1.2041,
        pressure_physical_unit: float = 1.101325e5,
    ):
        """
        Initialize the UnitConvertor object.

        Parameters
        ----------
        velocity_lbm_unit : float
            The reference velocity in lattice Boltzmann units.
        velocity_physical_unit : float
            The reference velocity in physical units (e.g., m/s).
        voxel_size_physical_unit : float
            The size of a voxel in physical units (e.g., meters).
        density_physical_unit : float, optional
            The reference density in physical units (e.g., kg/m^3). Default is 1.2041.
        pressure_physical_unit : float, optional
            The reference pressure in physical units (e.g., Pascals). Default is 1.101325e5.
        """
        # NOTE(review): standard atmospheric pressure is 1.01325e5 Pa; the default of
        # 1.101325e5 looks like a typo. It is left unchanged here because changing a
        # default alters results for existing callers - please confirm.

        self.voxel_size = voxel_size_physical_unit
        self.velocity_lbm_unit = velocity_lbm_unit
        self.velocity_phys_unit = velocity_physical_unit

        # Reference density and pressure in physical units
        self.reference_density = density_physical_unit
        self.reference_pressure = pressure_physical_unit

    @property
    def referece_pressure(self):
        # Backward-compatible alias for the originally misspelled attribute name
        return self.reference_pressure

    @referece_pressure.setter
    def referece_pressure(self, value):
        self.reference_pressure = value

    @property
    def time_step_physical(self):
        """Physical duration of one lattice time step."""
        return self.voxel_size * self.velocity_lbm_unit / self.velocity_phys_unit

    @property
    def reference_length(self):
        """Physical length of one lattice spacing (the voxel size)."""
        return self.voxel_size

    @property
    def reference_time(self):
        """Physical duration of one lattice time step."""
        return self.time_step_physical

    @property
    def reference_velocity(self):
        """Physical velocity corresponding to one lattice spacing per time step."""
        return self.reference_length / self.reference_time

    def length_to_lbm(self, length_phys):
        return length_phys / self.reference_length

    def length_to_physical(self, length_lbm):
        return length_lbm * self.reference_length

    def time_to_lbm(self, time_phys):
        return time_phys / self.reference_time

    def time_to_physical(self, time_lbm):
        return time_lbm * self.reference_time

    def density_to_lbm(self, rho_phys):
        return rho_phys / self.reference_density

    def density_to_physical(self, rho_lbm):
        return rho_lbm * self.reference_density

    def velocity_to_lbm(self, velocity_phys):
        return velocity_phys / self.reference_velocity

    def velocity_to_physical(self, velocity_lbm):
        return velocity_lbm * self.reference_velocity

    def viscosity_to_lbm(self, viscosity_phys):
        return viscosity_phys * (self.reference_time / (self.reference_length**2))

    def viscosity_to_physical(self, viscosity_lbm):
        return viscosity_lbm * (self.reference_length**2 / self.reference_time)

    def pressure_to_lbm(self, pressure_phys):
        """Convert a physical pressure to lattice units (exact inverse of ``pressure_to_physical``)."""
        pressure_perturbation = pressure_phys - self.reference_pressure
        # The reference pressure maps to the lattice value 1/3, matching the offset
        # removed in pressure_to_physical.
        return 1.0 / 3.0 + pressure_perturbation / self.reference_density / self.reference_velocity**2

    def pressure_to_physical(self, pressure_lbm):
        """Convert a lattice pressure to physical units; the lattice value 1/3 maps to the reference pressure."""
        pressure_perturbation = pressure_lbm - 1.0 / 3.0
        return self.reference_pressure + pressure_perturbation * self.reference_density * (self.reference_velocity**2)
def _get_center_index(self):
    """
    Return the index of the rest (zero) lattice velocity.

    Returns
    -------
    int
        Index ``i`` such that the lattice direction ``self._c[:, i]`` is all zeros.

    Raises
    ------
    ValueError
        If the velocity set contains no zero velocity.
    """
    # Each row of ``_c.T`` is one lattice direction; comparing against 0 works for any
    # dimension, so no per-dimension target vector is needed.
    is_rest = np.all(self._c.T == 0, axis=1)
    rest_indices = np.nonzero(is_rest)[0]
    if rest_indices.size == 0:
        raise ValueError("Velocity set has no zero (rest) lattice velocity.")
    return int(rest_indices[0])