diff --git a/.gitignore b/.gitignore index f42eb1e..10f4f8a 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,3 @@ test_stats_csv.csv test_stats_csv.dat tmp_tolerance_None.csv tolerance.csv -util/differences.csv diff --git a/tests/util/test_fof_utils.py b/tests/util/test_fof_utils.py index 9564311..7794869 100644 --- a/tests/util/test_fof_utils.py +++ b/tests/util/test_fof_utils.py @@ -11,9 +11,9 @@ clean_value, compare_arrays, compare_var_and_attr_ds, - fill_nans_for_float32, get_observation_variables, get_report_variables, + replace_nan_with_sentinel_float64, split_feedback_dataset, ) from util.log_handler import initialize_detailed_logger @@ -194,10 +194,10 @@ def fixture_arr(): def test_fill_nans_for_float32_nan(arr_nan): """ Test that if an array containing nan is given, these values are replaced - by -9.99999e05. + by -999999.0. """ - array = fill_nans_for_float32(arr_nan) - expected = np.array([1.0, -9.99999e05, 3.0, 4.0, -9.99999e05], dtype=np.float32) + array = replace_nan_with_sentinel_float64(arr_nan) + expected = np.array([1.0, -999999.0, 3.0, 4.0, -999999.0], dtype=np.float64) assert np.array_equal(array, expected) @@ -206,7 +206,7 @@ def test_fill_nans_for_float32(arr1): Test that if an array without nan is given, the output of the function is the same as the input. """ - array = fill_nans_for_float32(arr1) + array = replace_nan_with_sentinel_float64(arr1) assert np.array_equal(array, arr1) diff --git a/util/fof_utils.py b/util/fof_utils.py index 687d23d..de7eae9 100644 --- a/util/fof_utils.py +++ b/util/fof_utils.py @@ -104,13 +104,18 @@ def compare_arrays(arr1, arr2, var_name): return total, equal, diff -def fill_nans_for_float32(arr): +def replace_nan_with_sentinel_float64(arr): """ - To make sure nan values are recognised. + If the input array has a floating dtype, it is cast to float64 + and all NaN values are replaced with the sentinel value -999999.0 + If the array does not have a floating dtype, it is returned unchanged. """ - if arr.dtype == np.float32 and np.isnan(arr).any(): - return np.where(np.isnan(arr), -999999, arr) - return arr + if not np.issubdtype(arr.dtype, np.floating): + return arr + + arr = arr.astype(np.float64, copy=False) + + return np.where(np.isnan(arr), -999999.0, arr) def clean_value(x): @@ -203,8 +208,8 @@ def process_var(ds1, ds2, var, detailed_logger): number of matching elements. """ - arr1 = fill_nans_for_float32(ds1[var].values) - arr2 = fill_nans_for_float32(ds2[var].values) + arr1 = replace_nan_with_sentinel_float64(ds1[var].values) + arr2 = replace_nan_with_sentinel_float64(ds2[var].values) if arr1.size == arr2.size: t, e, diff = compare_arrays(arr1, arr2, var) if diff.size != 0: