Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/pcxarray/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .query import pc_query, get_pc_collections
from .processing import prepare_timeseries, prepare_data, query_and_prepare
from .processing import prepare_timeseries, prepare_data, query_and_prepare, lazy_merge_arrays

try:
from importlib.metadata import version
Expand Down
2 changes: 2 additions & 0 deletions src/pcxarray/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
from typing import Union
import xarray as xr
from pandas import notna



Expand Down Expand Up @@ -168,6 +169,7 @@ def read_single_item(
col for col in item_gs.index \
if col.startswith('assets.') \
and col.endswith('.href') \
and notna(item_gs[col.replace('.href', '.type')]) \
and 'image/tiff; application=geotiff; profile=cloud-optimized' in item_gs[col.replace('.href', '.type')] \
and col not in ignored_assets
]
Expand Down
149 changes: 82 additions & 67 deletions src/pcxarray/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from warnings import warn
import geopandas as gpd
import pandas as pd
from pyproj import Transformer, CRS, transform
from pyproj import CRS
import xarray as xr
from rioxarray.merge import merge_arrays
from rasterio.enums import Resampling
from tqdm import tqdm
from shapely.ops import transform, unary_union
from shapely.geometry import box
from shapely.geometry.base import BaseGeometry
import numpy as np
from odc.geo.geobox import GeoBox
Expand All @@ -24,11 +25,11 @@
def lazy_merge_arrays(
arrays: List[xr.DataArray],
method: Literal['last', 'first', 'min', 'max', 'mean', 'sum', 'median'] = 'last',
geom: Optional[BaseGeometry] = None,
crs: Optional[Union[CRS, str]] = None,
geometry: Optional[BaseGeometry] = None,
crs: Optional[Union[CRS, str, int]] = None,
resolution: Optional[Union[float, int]] = None,
resampling_method: Union[Resampling, str] = 'nearest',
nodata: Optional[float] = None
nodata: Optional[float] = None,
) -> xr.DataArray:
"""
Merge multiple xarray DataArrays lazily.
Expand All @@ -45,18 +46,15 @@ def lazy_merge_arrays(
List of georeferenced DataArrays to merge from rioxarray.
method : {'last', 'first', 'min', 'max', 'mean', 'sum', 'median'}, default='last'
Method for merging overlapping pixels.
geom : shapely.geometry.base.BaseGeometry, optional
Target geometry for the merged array. If None, computed as union of all
input array bounds. Must be provided together with crs and resolution, or
all three must be None.
crs : pyproj.CRS or str, optional
Target coordinate reference system. If None, uses CRS from first input array.
Must be provided together with geom and resolution, or all three must be
None.
geometry : shapely.geometry.base.BaseGeometry, optional
Target geometry for the merged array. If None, computed as the unary union
of all input array bounds.
crs : pyproj.CRS, str, or int, optional
Target coordinate reference system. If None, uses the CRS from the first
input array.
resolution : float or int, optional
Target pixel resolution in CRS units. If None, uses minimum resolution from
input arrays. Must be provided together with geom and crs, or all three
must be None.
Target pixel resolution in CRS units. If None, uses the minimum resolution
from all input arrays.
resampling_method : rasterio.enums.Resampling or str, default 'nearest'
Resampling method for reprojection. Can be Resampling enum or string name
(e.g., 'nearest', 'bilinear', 'cubic', etc.).
Expand All @@ -73,37 +71,33 @@ def lazy_merge_arrays(
Raises
------
ValueError
If only some of geom, crs, or resolution are provided (must be all or none),
or if an unknown merge method is specified.
If the input arrays do not all share the same dtype, if an unknown merge
method is specified, or if geometry and CRS do not produce a valid geobox.
UserWarning
If multiple CRS are found in input arrays (uses first one found).
If reprojection fails for any input array (the array is skipped and
processing continues).

Notes
-----
- TODO: At some point, consider documenting edge cases around nodata handling.
This seems to be the most fragile part of this package, though it's difficult
to really nail down all the edge cases.
"""

# determine the common geobox for reprojection if args not provided
if geom is None or crs is None or resolution is None:
if sum([geom is None, crs is None, resolution is None]) > 1:
raise ValueError("If one of geom, crs, or resolution is None, all must be provided.")

geoms = [da.rio.transform_bounds() for da in arrays]
geom = unary_union(geoms)

crs_list = [da.rio.crs for da in arrays]
if len(set(crs_list)) > 1:
warn(f"Multiple CRSs found in input arrays: {set(crs_list)}. Using the first raster's CRS ({crs_list[0]}).")
crs = crs_list[0]

resolution = min([min(abs(da.rio.resolution()[0]), abs(da.rio.resolution()[1])) for da in arrays])

geobox = GeoBox.from_geopolygon(
Geometry(geom, crs=crs),
resolution=resolution
input_dtypes = [da.dtype for da in arrays]
if len(set(input_dtypes)) > 1:
raise ValueError(
"All input arrays must have the same dtype for merging. "
f"Found differing dtypes: {set(input_dtypes)}"
)
input_dtype = input_dtypes[0]

else:
geobox = GeoBox.from_geopolygon(
Geometry(geom, crs=crs),
resolution=resolution
)
if geometry is None:
geometry = unary_union([box(*da.rio.bounds()) for da in arrays])
if crs is None:
crs = arrays[0].rio.crs
if resolution is None:
resolution = min([min(abs(r) for r in da.rio.resolution()) for da in arrays])

if isinstance(resampling_method, Resampling):
resampling_method = resampling_method.name.lower()
Expand All @@ -113,6 +107,11 @@ def lazy_merge_arrays(
if nodata is not None and not np.isnan(nodata):
arrays = [da.where(da != nodata) for da in arrays] # mask nodata values

geobox = GeoBox.from_geopolygon(
Geometry(geometry, crs=crs),
resolution=resolution,
)

reprojected_arrays = []
for da in arrays:
try:
Expand All @@ -123,7 +122,7 @@ def lazy_merge_arrays(
except Exception as e:
warn(f"Reprojection failed for {da.name} with error: {e}. Skipping this array.")
continue

arrays = xr.align(*reprojected_arrays, join='exact')
stacked = xr.concat(arrays, dim='merge_dim')
# if nodata is not None and not np.isnan(nodata):
Expand All @@ -149,14 +148,21 @@ def lazy_merge_arrays(
raise ValueError(f"Unknown merge method: {method}")

result = result.rio.write_nodata(nodata)

if nodata is not None and not np.isnan(nodata):
result = result.fillna(nodata)

if result.dtype != input_dtype:
result = result.astype(input_dtype)

return result



def prepare_data(
items_gdf: gpd.GeoDataFrame,
geometry: BaseGeometry,
crs: Union[CRS, str, int] = 4326,
geometry: Optional[BaseGeometry] = None,
crs: Optional[Union[CRS, str, int]] = None,
bands: Optional[List[Union[str, int]]] = None,
target_resolution: Optional[Union[float, int]] = None,
all_touched: bool = False,
Expand All @@ -183,10 +189,11 @@ def prepare_data(
----------
items_gdf : geopandas.GeoDataFrame
GeoDataFrame of STAC items to process.
geometry : shapely.geometry.base.BaseGeometry
Area of interest geometry in the target CRS.
crs : pyproj.CRS, str or int, default=4326
Coordinate reference system for the output.
geometry : shapely.geometry.base.BaseGeometry, optional
Area of interest geometry in the target CRS. If None, uses the union of
all geometries in items_gdf.
crs : pyproj.CRS, str or int, optional
Coordinate reference system for the output. If None, uses the CRS from items_gdf.
bands : list of str or int, optional
List of band names or indices to select; if None, all valid bands are loaded.
target_resolution : float or int, optional
Expand Down Expand Up @@ -230,17 +237,20 @@ def prepare_data(
if isinstance(resampling_method, Resampling):
resampling_method = resampling_method.name.lower()

transformer = Transformer.from_crs(
crs,
CRS.from_epsg(4326),
always_xy=True,
)
geom_84 = transform(
transformer.transform,
geometry
)
if crs is not None and items_gdf.crs != crs:
items_gdf = items_gdf.to_crs(crs)
elif crs is None and items_gdf.crs is not None:
crs = items_gdf.crs
elif crs is None:
raise ValueError("CRS must be provided if items_gdf has no CRS.")

if not isinstance(crs, CRS):
crs = CRS.from_user_input(crs)

if geometry is None:
geometry = items_gdf.union_all()

items_gdf['percent_overlap'] = items_gdf.geometry.apply(lambda x: x.intersection(geom_84).area / geom_84.area)
items_gdf['percent_overlap'] = items_gdf.geometry.apply(lambda x: x.intersection(geometry).area / geometry.area)
items_full_overlap = items_gdf[items_gdf['percent_overlap'] == 1.0]

selected_items = []
Expand All @@ -250,7 +260,7 @@ def prepare_data(
try:
da = read_single_item(
item_gs=item,
geometry=geometry,
geometry=geometry.buffer(1) if crs.is_projected else geometry, # buffer slightly to avoid edge cases
bands=bands,
chunks=chunks,
all_touched=True,
Expand All @@ -269,15 +279,15 @@ def prepare_data(
da = da.odc.reproject(
how=GeoBox.from_geopolygon(
Geometry(geometry, crs=crs),
resolution=target_resolution
resolution=target_resolution if target_resolution is not None else abs(min(da.rio.resolution())),
),
resampling=resampling_method,
)

else: # multiple items, need to merge and reproject.
items_gdf = items_gdf.sort_values(by='percent_overlap', ascending=False)

remaining_geom = geom_84
remaining_geom = geometry
remaining_area = 1.0
selected_items = []
while remaining_area > 0:
Expand All @@ -286,7 +296,7 @@ def prepare_data(

intersection = item_series.geometry.intersection(remaining_geom)
remaining_geom = remaining_geom.difference(intersection)
remaining_area = remaining_geom.area / geom_84.area
remaining_area = remaining_geom.area / geometry.area
if remaining_area == 0:
break

Expand All @@ -305,7 +315,7 @@ def safe_read_item(item_series):
try:
return read_single_item(
item_gs=item_series,
geometry=geometry,
geometry=geometry.buffer(1) if crs.is_projected else geometry,
bands=bands,
chunks=chunks,
all_touched=True,
Expand Down Expand Up @@ -338,13 +348,18 @@ def safe_read_item(item_series):
da = lazy_merge_arrays(
das,
method=merge_method,
geom=geometry,
crs=crs if crs is not None else das[0].rio.crs,
resolution=target_resolution if target_resolution is not None else das[0].rio.resolution()[0], # assuming square pixels
geometry=geometry,
crs=crs,
resolution=target_resolution,
resampling_method=resampling_method,
)

da = da.odc.crop(Geometry(geometry, crs=crs), apply_mask=True, all_touched=all_touched)
da = da.odc.crop(
Geometry(geometry, crs=crs),
apply_mask=True,
all_touched=all_touched,
)

# return da
if chunks is not None and da.chunks != chunks:
da = da.chunk(chunks)
Expand Down
2 changes: 1 addition & 1 deletion src/pcxarray/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def pc_query(
if retries >= max_retries:
raise RuntimeError(f"STAC search failed after {max_retries} retries: {type(e).__str__}: {e}") from e

warn(f"STAC search failed: {type(e).__str__}: {e}. Retrying ({retries + 1}/{max_retries})...")
warn(f"STAC search failed: {type(e).__str__()}: {e}. Retrying ({retries + 1}/{max_retries})...")
sleep(2 ** (retries + 1)) # Exponential backoff
retries += 1

Expand Down