diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py
index 74d36190..c5da8f8d 100644
--- a/xrspatial/geotiff/__init__.py
+++ b/xrspatial/geotiff/__init__.py
@@ -213,8 +213,12 @@ def open_geotiff(source: str, *, window=None,
         # Adjust coordinates for windowed read
         r0, c0, r1, c1 = window
         t = geo_info.transform
-        full_x = np.arange(c0, c1, dtype=np.float64) * t.pixel_width + t.origin_x + t.pixel_width * 0.5
-        full_y = np.arange(r0, r1, dtype=np.float64) * t.pixel_height + t.origin_y + t.pixel_height * 0.5
+        if geo_info.raster_type == RASTER_PIXEL_IS_POINT:
+            full_x = np.arange(c0, c1, dtype=np.float64) * t.pixel_width + t.origin_x
+            full_y = np.arange(r0, r1, dtype=np.float64) * t.pixel_height + t.origin_y
+        else:
+            full_x = np.arange(c0, c1, dtype=np.float64) * t.pixel_width + t.origin_x + t.pixel_width * 0.5
+            full_y = np.arange(r0, r1, dtype=np.float64) * t.pixel_height + t.origin_y + t.pixel_height * 0.5
         coords = {'y': full_y, 'x': full_x}
 
         if name is None:
@@ -402,6 +406,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
 
     geo_transform = None
     epsg = None
+    wkt_fallback = None  # WKT string when EPSG is not available
     raster_type = RASTER_PIXEL_IS_AREA
     x_res = None
     y_res = None
@@ -414,6 +419,8 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
         epsg = crs
     elif isinstance(crs, str):
         epsg = _wkt_to_epsg(crs)  # try to extract EPSG from WKT/PROJ
+        if epsg is None:
+            wkt_fallback = crs
 
     if isinstance(data, xr.DataArray):
         # Handle CuPy-backed DataArrays: convert to numpy for CPU write
@@ -436,12 +443,16 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
             if isinstance(crs_attr, str):
                 # WKT string from reproject() or other source
                 epsg = _wkt_to_epsg(crs_attr)
+                if epsg is None and wkt_fallback is None:
+                    wkt_fallback = crs_attr
             elif crs_attr is not None:
                 epsg = int(crs_attr)
         if epsg is None:
             wkt = data.attrs.get('crs_wkt')
             if isinstance(wkt, str):
                 epsg = _wkt_to_epsg(wkt)
+                if epsg is None and wkt_fallback is None:
+                    wkt_fallback = wkt
         if nodata is None:
             nodata = data.attrs.get('nodata')
         if data.attrs.get('raster_type') == 'point':
@@ -477,10 +488,19 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path: str, *,
     elif arr.dtype == np.bool_:
         arr = arr.astype(np.uint8)
 
+    # Restore NaN pixels to the nodata sentinel value so the written file
+    # has sentinel values matching the GDAL_NODATA tag.
+    if nodata is not None and arr.dtype.kind == 'f' and not np.isnan(nodata):
+        nan_mask = np.isnan(arr)
+        if nan_mask.any():
+            arr = arr.copy()
+            arr[nan_mask] = arr.dtype.type(nodata)
+
     write(
         arr, path,
         geo_transform=geo_transform,
         crs_epsg=epsg,
+        crs_wkt=wkt_fallback if epsg is None else None,
         nodata=nodata,
         compression=compression,
         tiled=tiled,
diff --git a/xrspatial/geotiff/_geotags.py b/xrspatial/geotiff/_geotags.py
index d3352819..e7f95c82 100644
--- a/xrspatial/geotiff/_geotags.py
+++ b/xrspatial/geotiff/_geotags.py
@@ -522,9 +522,33 @@ def extract_geo_info(ifd: IFD, data: bytes | memoryview,
     )
 
 
+def _model_type_from_wkt(wkt: str) -> int:
+    """Guess the GeoTIFF ModelType from a WKT or PROJ CRS string.
+
+    Only the leading WKT keyword (or, for PROJ strings, the ``+proj``
+    parameter) is inspected.  Anything not recognised as geographic is
+    reported as projected.
+    """
+    upper = wkt.strip().upper()
+    # WKT1 (GEOGCS) and WKT2 (GEOGCRS / GEOGRAPHICCRS) keywords.
+    if upper.startswith(('GEOGCS', 'GEOGCRS', 'GEOGRAPHICCRS')):
+        return MODEL_TYPE_GEOGRAPHIC
+    # WKT2:2015 spells geographic CRSs GEODCRS/GEODETICCRS with an
+    # ellipsoidal coordinate system (a Cartesian one is geocentric).
+    if (upper.startswith(('GEODCRS', 'GEODETICCRS'))
+            and 'CS[ELLIPSOIDAL' in upper.replace(' ', '')):
+        return MODEL_TYPE_GEOGRAPHIC
+    # PROJ strings have no WKT keyword; lon/lat is the geographic case.
+    if any(f'+PROJ={name}' in upper
+           for name in ('LONGLAT', 'LATLONG', 'LONLAT', 'LATLON')):
+        return MODEL_TYPE_GEOGRAPHIC
+    return MODEL_TYPE_PROJECTED
+
+
 def build_geo_tags(transform: GeoTransform, crs_epsg: int | None = None,
                    nodata=None,
-                   raster_type: int = RASTER_PIXEL_IS_AREA) -> dict[int, tuple]:
+                   raster_type: int = RASTER_PIXEL_IS_AREA,
+                   crs_wkt: str | None = None) -> dict[int, tuple]:
     """Build GeoTIFF IFD tag entries for writing.
 
     Parameters
@@ -537,6 +561,11 @@ def build_geo_tags(transform: GeoTransform, crs_epsg: int | None = None,
         NoData value.
     raster_type : int
         RASTER_PIXEL_IS_AREA (1) or RASTER_PIXEL_IS_POINT (2).
+    crs_wkt : str or None
+        WKT or PROJ string for the CRS. Used only when *crs_epsg* is
+        None so that custom (non-EPSG) coordinate systems survive
+        round-trips. Stored in the GeoAsciiParamsTag and referenced
+        from GTCitationGeoKey.
 
     Returns
     -------
@@ -562,6 +591,10 @@ def build_geo_tags(transform: GeoTransform, crs_epsg: int | None = None,
     num_keys = 1  # at least RasterType
     key_entries = []
 
+    # Collect ASCII params strings (pipe-delimited in GeoAsciiParamsTag)
+    ascii_parts = []
+    ascii_offset = 0
+
     # ModelType
     if crs_epsg is not None:
         # Guess model type from EPSG (simple heuristic)
@@ -571,6 +604,10 @@ def build_geo_tags(transform: GeoTransform, crs_epsg: int | None = None,
             model_type = MODEL_TYPE_PROJECTED
         key_entries.append((GEOKEY_MODEL_TYPE, 0, 1, model_type))
         num_keys += 1
+    elif crs_wkt is not None:
+        model_type = _model_type_from_wkt(crs_wkt)
+        key_entries.append((GEOKEY_MODEL_TYPE, 0, 1, model_type))
+        num_keys += 1
 
     # RasterType
     key_entries.append((GEOKEY_RASTER_TYPE, 0, 1, raster_type))
@@ -582,6 +619,25 @@ def build_geo_tags(transform: GeoTransform, crs_epsg: int | None = None,
         else:
             key_entries.append((GEOKEY_PROJECTED_CS_TYPE, 0, 1, crs_epsg))
         num_keys += 1
+    elif crs_wkt is not None:
+        # User-defined CRS: store 32767 and write WKT to GeoAsciiParams
+        if model_type == MODEL_TYPE_GEOGRAPHIC:
+            key_entries.append((GEOKEY_GEOGRAPHIC_TYPE, 0, 1, 32767))
+        else:
+            key_entries.append((GEOKEY_PROJECTED_CS_TYPE, 0, 1, 32767))
+        num_keys += 1
+        # GTCitationGeoKey -> GeoAsciiParams
+        wkt_with_pipe = crs_wkt + '|'
+        key_entries.append((
+            GEOKEY_CITATION, TAG_GEO_ASCII_PARAMS,
+            len(wkt_with_pipe), ascii_offset,
+        ))
+        ascii_parts.append(wkt_with_pipe)
+        ascii_offset += len(wkt_with_pipe)
+        num_keys += 1
+        # Keep entries in ascending key-ID order as the GeoTIFF spec
+        # requires; the citation key is appended last above.
+        key_entries.sort(key=lambda entry: entry[0])
 
     num_keys = len(key_entries)
     header = [1, 1, 0, num_keys]
@@ -591,6 +647,10 @@ def build_geo_tags(transform: GeoTransform, crs_epsg: int | None = None,
 
     tags[TAG_GEO_KEY_DIRECTORY] = tuple(flat)
 
+    # GeoAsciiParamsTag (34737)
+    if ascii_parts:
+        tags[TAG_GEO_ASCII_PARAMS] = ''.join(ascii_parts)
+
     # GDAL_NODATA
     if nodata is not None:
         tags[TAG_GDAL_NODATA] = str(nodata)
diff --git a/xrspatial/geotiff/_writer.py b/xrspatial/geotiff/_writer.py
index 6768c6e6..eabe8695 100644
--- a/xrspatial/geotiff/_writer.py
+++ b/xrspatial/geotiff/_writer.py
@@ -32,6 +32,7 @@ from ._geotags import (
     GeoTransform,
     build_geo_tags,
+    TAG_GEO_ASCII_PARAMS,
     TAG_GEO_KEY_DIRECTORY,
     TAG_GDAL_NODATA,
     TAG_MODEL_PIXEL_SCALE,
@@ -525,6 +526,7 @@ def _assemble_tiff(width: int, height: int, dtype: np.dtype,
                    nodata,
                    is_cog: bool = False,
                    raster_type: int = 1,
+                   crs_wkt: str | None = None,
                    gdal_metadata_xml: str | None = None,
                    extra_tags: list | None = None,
                    x_resolution: float | None = None,
@@ -557,12 +559,14 @@ def _assemble_tiff(width: int, height: int, dtype: np.dtype,
     geo_tags_dict = {}
     if geo_transform is not None:
         geo_tags_dict = build_geo_tags(
-            geo_transform, crs_epsg, nodata, raster_type=raster_type)
+            geo_transform, crs_epsg, nodata, raster_type=raster_type,
+            crs_wkt=crs_wkt)
     else:
         # No spatial reference -- still write CRS and nodata if provided
-        if crs_epsg is not None or nodata is not None:
+        if crs_epsg is not None or crs_wkt is not None or nodata is not None:
             geo_tags_dict = build_geo_tags(
                 GeoTransform(), crs_epsg, nodata, raster_type=raster_type,
+                crs_wkt=crs_wkt,
             )
             # Remove the default pixel scale / tiepoint tags since we
             # have no real transform -- keep only GeoKeys and NODATA.
@@ -641,6 +645,8 @@ def _assemble_tiff(width: int, height: int, dtype: np.dtype,
             tags.append((gtag, DOUBLE, 6, list(gval)))
         elif gtag == TAG_GEO_KEY_DIRECTORY:
             tags.append((gtag, SHORT, len(gval), list(gval)))
+        elif gtag == TAG_GEO_ASCII_PARAMS:
+            tags.append((gtag, ASCII, len(str(gval)) + 1, str(gval)))
         elif gtag == TAG_GDAL_NODATA:
             tags.append((gtag, ASCII, len(str(gval)) + 1, str(gval)))
 
@@ -846,6 +852,7 @@ def _assemble_cog_layout(header_size: int,
 def write(data: np.ndarray, path: str, *,
           geo_transform: GeoTransform | None = None,
           crs_epsg: int | None = None,
+          crs_wkt: str | None = None,
           nodata=None,
           compression: str = 'zstd',
           tiled: bool = True,
@@ -939,7 +946,7 @@ def write(data: np.ndarray, path: str, *,
     file_bytes = _assemble_tiff(
         w, h, data.dtype, comp_tag, predictor, tiled, tile_size,
         parts, geo_transform, crs_epsg, nodata, is_cog=cog,
-        raster_type=raster_type,
+        raster_type=raster_type, crs_wkt=crs_wkt,
         gdal_metadata_xml=gdal_metadata_xml,
         extra_tags=extra_tags,
         x_resolution=x_resolution, y_resolution=y_resolution,
diff --git a/xrspatial/geotiff/tests/test_accuracy_1081.py b/xrspatial/geotiff/tests/test_accuracy_1081.py
new file mode 100644
index 00000000..626119eb
--- /dev/null
+++ b/xrspatial/geotiff/tests/test_accuracy_1081.py
@@ -0,0 +1,299 @@
+"""Tests for accuracy bugs fixed in #1081.
+
+Bug 1: Windowed read ignores PixelIsPoint raster type
+Bug 2: CRS WKT silently lost on write for non-EPSG CRS
+Bug 3: NaN not restored to nodata sentinel on write
+"""
+from __future__ import annotations
+
+import struct
+
+import numpy as np
+import pytest
+import xarray as xr
+
+from xrspatial.geotiff import open_geotiff, to_geotiff
+from xrspatial.geotiff._geotags import (
+    RASTER_PIXEL_IS_POINT,
+    TAG_GEO_ASCII_PARAMS,
+    TAG_GEO_KEY_DIRECTORY,
+    extract_geo_info,
+)
+from xrspatial.geotiff._header import parse_header, parse_all_ifds
+from xrspatial.geotiff._reader import read_to_array
+from xrspatial.geotiff._writer import write
+
+
+def _make_pixel_is_point_tiff(tmp_path, width=8, height=8):
+    """Create a GeoTIFF with PixelIsPoint raster type via the writer."""
+    from xrspatial.geotiff._geotags import GeoTransform
+
+    arr = np.arange(width * height, dtype=np.float32).reshape(height, width)
+    path = str(tmp_path / 'point_1081.tif')
+    write(
+        arr, path,
+        geo_transform=GeoTransform(
+            origin_x=10.0, origin_y=50.0,
+            pixel_width=0.001, pixel_height=-0.001,
+        ),
+        crs_epsg=4326,
+        compression='none',
+        tiled=False,
+        raster_type=RASTER_PIXEL_IS_POINT,
+    )
+    return path
+
+
+# -----------------------------------------------------------------------
+# Bug 1: Windowed read + PixelIsPoint
+# -----------------------------------------------------------------------
+
+class TestWindowedReadPixelIsPoint:
+
+    def test_full_read_pixel_is_point_no_offset(self, tmp_path):
+        """Full read of PixelIsPoint file should NOT add half-pixel offset."""
+        path = _make_pixel_is_point_tiff(tmp_path)
+        da = open_geotiff(path)
+        # For PixelIsPoint, coordinates should be exactly at the tiepoint
+        # origin (10.0) without any 0.5*pixel_width offset.
+        assert da.attrs.get('raster_type') == 'point'
+        assert float(da.coords['x'].values[0]) == pytest.approx(10.0)
+        assert float(da.coords['y'].values[0]) == pytest.approx(50.0)
+
+    def test_windowed_read_pixel_is_point_no_offset(self, tmp_path):
+        """Windowed read of PixelIsPoint file should match full-read coords."""
+        path = _make_pixel_is_point_tiff(tmp_path)
+        da_full = open_geotiff(path)
+        da_win = open_geotiff(path, window=(2, 2, 6, 6))
+
+        # The windowed-read x/y should match the corresponding slice
+        # of the full-read coordinates.
+        np.testing.assert_allclose(
+            da_win.coords['x'].values,
+            da_full.coords['x'].values[2:6],
+        )
+        np.testing.assert_allclose(
+            da_win.coords['y'].values,
+            da_full.coords['y'].values[2:6],
+        )
+
+    def test_windowed_read_pixel_is_area_has_offset(self, tmp_path):
+        """Windowed read of PixelIsArea should still apply half-pixel offset."""
+        from xrspatial.geotiff._geotags import GeoTransform
+
+        arr = np.ones((8, 8), dtype=np.float32)
+        path = str(tmp_path / 'area_1081.tif')
+        write(
+            arr, path,
+            geo_transform=GeoTransform(
+                origin_x=10.0, origin_y=50.0,
+                pixel_width=0.001, pixel_height=-0.001,
+            ),
+            crs_epsg=4326,
+            compression='none',
+            tiled=False,
+        )
+        da_full = open_geotiff(path)
+        da_win = open_geotiff(path, window=(2, 2, 6, 6))
+
+        np.testing.assert_allclose(
+            da_win.coords['x'].values,
+            da_full.coords['x'].values[2:6],
+        )
+        np.testing.assert_allclose(
+            da_win.coords['y'].values,
+            da_full.coords['y'].values[2:6],
+        )
+
+
+# -----------------------------------------------------------------------
+# Bug 2: CRS WKT loss on write
+# -----------------------------------------------------------------------
+
+# A custom WKT that has no EPSG code -- represents a local engineering grid
+_CUSTOM_WKT = (
+    'LOCAL_CS["Local Grid",'
+    'LOCAL_DATUM["Local",10000],'
+    'UNIT["metre",1],'
+    'AXIS["Easting",EAST],'
+    'AXIS["Northing",NORTH]]'
+)
+
+
+class TestCrsWktRoundTrip:
+
+    def test_wkt_survives_round_trip(self, tmp_path):
+        """Custom WKT CRS should be preserved in GeoAsciiParamsTag."""
+        arr = np.ones((4, 4), dtype=np.float32)
+        da = xr.DataArray(
+            arr,
+            dims=['y', 'x'],
+            coords={
+                'y': np.arange(4, dtype=np.float64),
+                'x': np.arange(4, dtype=np.float64),
+            },
+            attrs={'crs_wkt': _CUSTOM_WKT},
+        )
+        path = str(tmp_path / 'wkt_1081.tif')
+        to_geotiff(da, path)
+
+        # Read back the raw tags and check GeoAsciiParamsTag
+        import mmap
+        with open(path, 'rb') as f:
+            data = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
+        try:
+            header = parse_header(data)
+            ifds = parse_all_ifds(data, header)
+            geo_info = extract_geo_info(ifds[0], data, header.byte_order)
+            has_ascii_params = TAG_GEO_ASCII_PARAMS in ifds[0].entries
+            # Assumes the writer stores ASCII tag payloads as plain bytes,
+            # so the WKT text can be searched for in the raw file.
+            wkt_in_file = data.find(_CUSTOM_WKT.encode('ascii')) != -1
+        finally:
+            data.close()
+
+        # The GeoKey directory should have a user-defined CRS (32767)
+        assert geo_info.crs_epsg is None or geo_info.crs_epsg == 32767
+        # ... and the WKT itself must actually be stored in the file.
+        assert has_ascii_params, "GeoAsciiParamsTag should be written"
+        assert wkt_in_file, "WKT text should be present in the file"
+
+    def test_wkt_crs_param_survives(self, tmp_path):
+        """crs= param with WKT string should be written when no EPSG."""
+        arr = np.ones((4, 4), dtype=np.float32)
+        da = xr.DataArray(
+            arr,
+            dims=['y', 'x'],
+            coords={
+                'y': np.arange(4, dtype=np.float64),
+                'x': np.arange(4, dtype=np.float64),
+            },
+        )
+        path = str(tmp_path / 'wkt_param_1081.tif')
+        to_geotiff(da, path, crs=_CUSTOM_WKT)
+
+        # Verify the GeoAsciiParams tag was written
+        import mmap
+        with open(path, 'rb') as f:
+            data = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
+        try:
+            header = parse_header(data)
+            ifds = parse_all_ifds(data, header)
+            ifd = ifds[0]
+            # Check for TAG_GEO_ASCII_PARAMS (34737) in IFD entries
+            has_ascii_params = TAG_GEO_ASCII_PARAMS in ifd.entries
+        finally:
+            data.close()
+
+        assert has_ascii_params, "GeoAsciiParamsTag should contain WKT"
+
+    def test_epsg_crs_still_works(self, tmp_path):
+        """EPSG CRS should still work as before (no WKT fallback)."""
+        arr = np.ones((4, 4), dtype=np.float32)
+        da = xr.DataArray(
+            arr,
+            dims=['y', 'x'],
+            coords={
+                'y': np.arange(4, dtype=np.float64),
+                'x': np.arange(4, dtype=np.float64),
+            },
+        )
+        path = str(tmp_path / 'epsg_1081.tif')
+        to_geotiff(da, path, crs=4326)
+
+        da_back = open_geotiff(path)
+        assert da_back.attrs.get('crs') == 4326
+
+
+# -----------------------------------------------------------------------
+# Bug 3: NaN not restored to nodata sentinel on write
+# -----------------------------------------------------------------------
+
+class TestNodataRestore:
+
+    def test_nan_restored_to_sentinel_float(self, tmp_path):
+        """NaN pixels should be written as the nodata sentinel, not NaN."""
+        arr = np.array([[1.0, 2.0], [np.nan, 4.0]], dtype=np.float32)
+        da = xr.DataArray(
+            arr,
+            dims=['y', 'x'],
+            coords={
+                'y': np.arange(2, dtype=np.float64),
+                'x': np.arange(2, dtype=np.float64),
+            },
+            attrs={'nodata': -9999.0},
+        )
+        path = str(tmp_path / 'nodata_restore_1081.tif')
+        to_geotiff(da, path)
+
+        # Read raw pixel data (before nodata masking) to verify sentinel
+        raw_arr, geo_info = read_to_array(path)
+        # The pixel that was NaN should now be -9999.0
+        assert raw_arr[1, 0] == pytest.approx(-9999.0)
+        assert not np.isnan(raw_arr[1, 0])
+
+    def test_nan_nodata_sentinel_is_nan(self, tmp_path):
+        """When nodata is NaN, pixels should stay as NaN (no conversion)."""
+        arr = np.array([[1.0, np.nan], [3.0, 4.0]], dtype=np.float32)
+        da = xr.DataArray(
+            arr,
+            dims=['y', 'x'],
+            coords={
+                'y': np.arange(2, dtype=np.float64),
+                'x': np.arange(2, dtype=np.float64),
+            },
+            attrs={'nodata': float('nan')},
+        )
+        path = str(tmp_path / 'nan_nodata_1081.tif')
+        to_geotiff(da, path)
+
+        raw_arr, _ = read_to_array(path)
+        assert np.isnan(raw_arr[0, 1])
+
+    def test_full_round_trip_preserves_nodata(self, tmp_path):
+        """open_geotiff -> to_geotiff round-trip should preserve nodata."""
+        from xrspatial.geotiff._geotags import GeoTransform
+
+        # Write a file with integer nodata sentinel
+        arr = np.array([[1, 2], [0, 4]], dtype=np.int16)
+        path1 = str(tmp_path / 'src_1081.tif')
+        write(
+            arr, path1,
+            geo_transform=GeoTransform(0.0, 0.0, 1.0, -1.0),
+            crs_epsg=4326,
+            nodata=0,
+            compression='none',
+            tiled=False,
+        )
+
+        # Read it (nodata=0 -> NaN)
+        da = open_geotiff(path1)
+        assert np.isnan(da.values[1, 0])
+        assert da.attrs['nodata'] == 0
+
+        # Write it back
+        path2 = str(tmp_path / 'dst_1081.tif')
+        to_geotiff(da, path2)
+
+        # Read raw data and check sentinel is restored
+        # Note: the array was promoted to float64, so nodata=0 becomes 0.0
+        raw, geo = read_to_array(path2)
+        assert raw[1, 0] == pytest.approx(0.0)
+        assert not np.isnan(raw[1, 0])
+
+    def test_no_nodata_attr_no_conversion(self, tmp_path):
+        """Arrays without nodata attr should not have NaN converted."""
+        arr = np.array([[1.0, np.nan], [3.0, 4.0]], dtype=np.float32)
+        da = xr.DataArray(
+            arr,
+            dims=['y', 'x'],
+            coords={
+                'y': np.arange(2, dtype=np.float64),
+                'x': np.arange(2, dtype=np.float64),
+            },
+        )
+        path = str(tmp_path / 'no_nodata_1081.tif')
+        to_geotiff(da, path)
+
+        raw_arr, _ = read_to_array(path)
+        assert np.isnan(raw_arr[0, 1])