diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index efd03089bbf8..4c30630090c3 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -65,6 +65,7 @@ ) if TYPE_CHECKING: + import numpy as np import pyarrow as pa import pandas as pd @@ -1661,6 +1662,35 @@ def convert_legacy( ) return converter(ser) + @staticmethod + def _ndarray_to_list(v: "np.ndarray") -> list: + """Recursively convert numpy ndarrays to Python lists.""" + import numpy as np + + return [ + ArrowArrayToPandasConversion._ndarray_to_list(x) if isinstance(x, np.ndarray) else x + for x in v + ] + + @staticmethod + def _contains_conversion_type(data_type: DataType) -> bool: + """ + Check if data type tree contains types that require post-processing conversion. + + Returns True if the type contains UserDefinedType, VariantType, GeographyType, + GeometryType, MapType, or StructType at any nesting level. + MapType and StructType require conversion because PyArrow's to_pandas() produces + maps as lists of tuples (not dicts) and structs as dicts (not Rows). + """ + if isinstance( + data_type, + (UserDefinedType, VariantType, GeographyType, GeometryType, MapType, StructType), + ): + return True + elif isinstance(data_type, ArrayType): + return ArrowArrayToPandasConversion._contains_conversion_type(data_type.elementType) + return False + @classmethod def _prefer_convert_numpy( cls, @@ -1688,8 +1718,14 @@ def _prefer_convert_numpy( ) if df_for_struct and isinstance(spark_type, StructType): return all(isinstance(f.dataType, supported_types) for f in spark_type.fields) + elif isinstance(spark_type, supported_types): + return True + elif isinstance(spark_type, ArrayType): + return not cls._contains_conversion_type(spark_type) + # elif isinstance(spark_type, (MapType, StructType)): + # TODO: Support MapType, StructType else: - return isinstance(spark_type, supported_types) + return False @classmethod def convert_numpy( @@ -1808,15 +1844,14 @@ def convert_numpy( series = series.map( lambda v: Geometry.fromWKB(v["wkb"], v["srid"]) if v is not None else None ) - # elif isinstance( - # spark_type, - # ( - # ArrayType, - # MapType, - # StructType, - # ), - # ): - # TODO(SPARK-55324): Support complex types + elif isinstance(spark_type, ArrayType): + if ndarray_as_list: + series = arr.to_pandas(integer_object_nulls=True) + series = series.map(lambda x: cls._ndarray_to_list(x) if x is not None else None) + else: + series = arr.to_pandas() + # elif isinstance(spark_type, (MapType, StructType)): + # TODO: Support MapType, StructType else: # pragma: no cover assert False, f"Need converter for {spark_type} but failed to find one." diff --git a/python/pyspark/sql/tests/test_conversion.py b/python/pyspark/sql/tests/test_conversion.py index 9ac6bcbd0537..38145a69d5dd 100644 --- a/python/pyspark/sql/tests/test_conversion.py +++ b/python/pyspark/sql/tests/test_conversion.py @@ -44,6 +44,7 @@ StringType, StructField, StructType, + TimestampNTZType, TimestampType, UserDefinedType, VariantType, @@ -738,6 +739,86 @@ def test_geometry_convert_numpy(self): ) self.assertEqual(len(result), 0) + def test_array_convert_numpy(self): + import pyarrow as pa + import numpy as np + + arr = pa.array([[1, 2, 3], [4, 5]], type=pa.list_(pa.int64())) + result = ArrowArrayToPandasConversion.convert_numpy(arr, ArrayType(IntegerType())) + self.assertIsInstance(result.iloc[0], np.ndarray) + self.assertEqual(list(result.iloc[0]), [1, 2, 3]) + self.assertEqual(list(result.iloc[1]), [4, 5]) + + # empty inner arrays + arr = pa.array([[], [1, 2], []], type=pa.list_(pa.int64())) + result = ArrowArrayToPandasConversion.convert_numpy(arr, ArrayType(IntegerType())) + self.assertEqual(len(result.iloc[0]), 0) + self.assertEqual(list(result.iloc[1]), [1, 2]) + + # nulls: inner nulls become NaN (float64) to preserve numeric ndarray dtype + arr = pa.array([[1, None, 3], None, [4, 5]], type=pa.list_(pa.int64())) + result = ArrowArrayToPandasConversion.convert_numpy(arr, ArrayType(IntegerType())) + self.assertTrue(np.isnan(result.iloc[0][1])) + self.assertIsNone(result.iloc[1]) + + # nested arrays + arr = pa.array([[[1, 2], [3]], [[4, 5]]], type=pa.list_(pa.list_(pa.int64()))) + result = ArrowArrayToPandasConversion.convert_numpy( + arr, ArrayType(ArrayType(IntegerType())) + ) + self.assertIsInstance(result.iloc[0], np.ndarray) + self.assertEqual(list(result.iloc[0][0]), [1, 2]) + self.assertEqual(list(result.iloc[0][1]), [3]) + + def test_array_with_timestamps(self): + import pyarrow as pa + import numpy as np + + # tz-aware timestamps: preprocess_time strips tz and coerces to ns + ts1 = datetime.datetime(2024, 1, 1, 12, 0, tzinfo=ZoneInfo("UTC")) + ts2 = datetime.datetime(2024, 6, 15, 8, 30, tzinfo=ZoneInfo("UTC")) + arr = pa.array([[ts1, ts2]], type=pa.list_(pa.timestamp("us", tz="UTC"))) + result = ArrowArrayToPandasConversion.convert_numpy(arr, ArrayType(TimestampType())) + self.assertIsInstance(result.iloc[0], np.ndarray) + self.assertEqual(result.iloc[0][0], np.datetime64("2024-01-01T12:00:00", "ns")) + self.assertEqual(result.iloc[0][1], np.datetime64("2024-06-15T08:30:00", "ns")) + + # tz-naive timestamps + arr = pa.array( + [[datetime.datetime(2024, 1, 1), datetime.datetime(2024, 6, 15)]], + type=pa.list_(pa.timestamp("us")), + ) + result = ArrowArrayToPandasConversion.convert_numpy(arr, ArrayType(TimestampNTZType())) + self.assertEqual(result.iloc[0][0], np.datetime64("2024-01-01T00:00:00", "ns")) + + def test_array_ndarray_as_list(self): + import pyarrow as pa + + arr = pa.array([[1, 2, 3], [4, 5]], type=pa.list_(pa.int64())) + result = ArrowArrayToPandasConversion.convert_numpy( + arr, ArrayType(IntegerType()), ndarray_as_list=True + ) + self.assertIsInstance(result.iloc[0], list) + self.assertEqual(result.iloc[0], [1, 2, 3]) + + # nulls preserved as None (not NaN) + arr = pa.array([[1, None, 3], None], type=pa.list_(pa.int64())) + result = ArrowArrayToPandasConversion.convert_numpy( + arr, ArrayType(IntegerType()), ndarray_as_list=True + ) + self.assertIsInstance(result.iloc[0], list) + self.assertIsNone(result.iloc[0][1]) + self.assertIsNone(result.iloc[1]) + + # nested arrays recursively converted to lists + arr = pa.array([[[1, 2], [3]]], type=pa.list_(pa.list_(pa.int64()))) + result = ArrowArrayToPandasConversion.convert_numpy( + arr, ArrayType(ArrayType(IntegerType())), ndarray_as_list=True + ) + self.assertIsInstance(result.iloc[0], list) + self.assertIsInstance(result.iloc[0][0], list) + self.assertEqual(result.iloc[0][0], [1, 2]) + if __name__ == "__main__": from pyspark.testing import main