Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/examples/circle.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ def plot_head_ugrid(head, cbc, workspace):
)

# Build the xugrid Ugrid2d mesh from the Disv package.
# `disv.to_grid()` returns a `VertexGrid`; `.ugrid` converts it to an
# `xu.Ugrid2d` object suitable for xugrid operations.
grid = disv.to_grid().ugrid
# `disv.to_grid()` returns a `VertexGrid`; `.to_xarray()` returns an
# `xu.UgridDataset`; `.grids[0]` extracts the `xu.Ugrid2d` mesh object.
grid = disv.to_grid().to_xarray().grids[0]

# `dims` captures array shapes needed by packages that pre-allocate xarray storage.
dims = {"nper": nper, "nlay": nlay, "ncpl": ncpl, "nvert": len(vertices), "nodes": nlay * ncpl}
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/frenchman-flat.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def plot_head_ugrid(head, cbc, grid, workspace):
import xarray as xr
import xugrid as xu

ugrid = grid.ugrid
ugrid = grid.to_xarray(netcdf_format=NetCDFFormat.LAYERED_MESH).grids[0]
facedim = ugrid.face_dimension

# Select first timestep and first layer; flatten (y, x) -> face dimension
Expand Down Expand Up @@ -113,7 +113,7 @@ def plot_head_ugrid(head, cbc, grid, workspace):
ds = ds.ugrid.assign_face_coords()

fig, ax = plt.subplots(figsize=(10, 8))
ds.plot.quiver(x="mesh2d_face_x", y="mesh2d_face_y", u="u", v="v", color="black", scale=100)
ds.plot.quiver(x="mesh_face_x", y="mesh_face_y", u="u", v="v", color="black", scale=100)
xu.plot.line(ugrid, ax=ax, color="black", linewidth=0.2)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
ax.set_title("Frenchman Flat flow vectors overlaid on Mesh")
Expand Down
45 changes: 38 additions & 7 deletions flopy4/mf6/netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import xarray as xr
import xugrid as xu
from pydantic import (
BaseModel,
ConfigDict,
Expand Down Expand Up @@ -102,7 +103,7 @@ def from_dict(cls, meta: dict, context: dict | None):
pass

@abc.abstractmethod
def to_xarray(self) -> xr.Dataset:
def to_xarray(self) -> "xr.Dataset | xu.UgridDataset":
"""create xarray dataset."""
pass

Expand Down Expand Up @@ -228,11 +229,11 @@ def from_model(
nc_model.time = time
return nc_model

def to_xarray(self) -> xr.Dataset:
def to_xarray(self) -> "xr.Dataset | xu.UgridDataset":
import datetime

dss = []
meta = self.model_dump(by_alias=True)
grid_result = None

if self._grid is not None and self._time is not None: # type: ignore
conventions = "CF-1.11" # type: ignore
Expand All @@ -243,14 +244,20 @@ def to_xarray(self) -> xr.Dataset:
if meta["attrs"]["mesh"] is not None
else NetCDFFormat.STRUCTURED
)
dss.append(self._grid.to_xarray(netcdf_format=_fmt, modeltime=self._time))
grid_result = self._grid.to_xarray(netcdf_format=_fmt, modeltime=self._time)
meta["attrs"]["Conventions"] = conventions

pkg_dss = []
for p in self.packages:
p._context["grid"] = self.grid
dss.append(p.to_xarray())
pkg_dss.append(p.to_xarray())

ds = xr.merge(dss)
if isinstance(grid_result, xu.UgridDataset):
merged = xr.merge([grid_result.obj, *pkg_dss])
ds = xu.UgridDataset(merged, grids=grid_result.grids)
else:
all_dss = ([grid_result] if grid_result is not None else []) + pkg_dss
ds = xr.merge(all_dss) if all_dss else xr.Dataset()

dt = datetime.datetime.now()
timestamp = dt.strftime("%m/%d/%Y %H:%M:%S")
Expand All @@ -263,7 +270,31 @@ def to_xarray(self) -> xr.Dataset:
return ds

def to_netcdf(self, path: str | PathLike) -> None:
self.to_xarray().to_netcdf(path)
result = self.to_xarray()
if isinstance(result, xu.UgridDataset):
grid = result.grids[0]
topo = grid.assign_face_coords(grid.to_dataset())
# Rename xugrid's generated dimension names to MF6's UGRID convention
# so input files are consistent with MODFLOW 6 output files.
_dim_rename = {
k: v
for k, v in {
"mesh_nFaces": "nmesh_face",
"mesh_nNodes": "nmesh_node",
"mesh_nMax_face_nodes": "max_nmesh_face_nodes",
}.items()
if k in topo.dims
}
if _dim_rename:
topo = topo.rename(_dim_rename)
for attr in ("face_dimension", "node_dimension", "max_face_nodes_dimension"):
if topo["mesh"].attrs.get(attr) in _dim_rename:
topo["mesh"].attrs[attr] = _dim_rename[topo["mesh"].attrs[attr]]
merged = topo.merge(result.obj)
merged.attrs = result.obj.attrs
merged.to_netcdf(path)
else:
result.to_netcdf(path)

@property
def meta(self):
Expand Down
3 changes: 1 addition & 2 deletions flopy4/mf6/utils/codegen/make.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,7 @@ def _expand_record_field(
spec_call="",
generatable=False,
skip_reason=(
f"positional sub-fields not yet supported: "
f"{', '.join(unexpandable_optional)}"
f"positional sub-fields not yet supported: {', '.join(unexpandable_optional)}"
),
)
)
Expand Down
Loading
Loading