diff --git a/src/gfloat/round.py b/src/gfloat/round.py index c833edb..a48fdc2 100644 --- a/src/gfloat/round.py +++ b/src/gfloat/round.py @@ -128,6 +128,8 @@ def round_float( should_round_away = delta + 0.5 >= 1.0 case RoundMode.TiesToEven: should_round_away = delta > 0.5 or (delta == 0.5 and code_is_odd) + case RoundMode.ToOdd: + should_round_away = delta > 0 and not code_is_odd case RoundMode.StochasticFastest: assert srbits is not None diff --git a/src/gfloat/round_ndarray.py b/src/gfloat/round_ndarray.py index 026c077..957fb77 100644 --- a/src/gfloat/round_ndarray.py +++ b/src/gfloat/round_ndarray.py @@ -161,6 +161,9 @@ def round_ndarray( case RoundMode.TiesToEven: should_round_away = (delta > 0.5) | ((delta == 0.5) & code_is_odd) + case RoundMode.ToOdd: + should_round_away = (delta > 0) & ~code_is_odd + case RoundMode.StochasticFastest: assert srbits is not None exp2r = 2**srnumbits diff --git a/src/gfloat/types.py b/src/gfloat/types.py index 316cb2a..e465d92 100644 --- a/src/gfloat/types.py +++ b/src/gfloat/types.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Graphcore Ltd. All rights reserved. from dataclasses import dataclass -from enum import Enum +from enum import Enum, auto import math @@ -25,15 +25,16 @@ class RoundMode(Enum): """ - TowardZero = 1 #: Return the largest :math:`r` such that :math:`|r| \le |v|` - TowardNegative = 2 #: Return the largest :math:`r` such that :math:`r \le v` - TowardPositive = 3 #: Return the smallest :math:`r` such that :math:`r \ge v` - TiesToEven = 4 #: Round to nearest, ties to even - TiesToAway = 5 #: Round to nearest, ties away from zero - Stochastic = 6 #: Stochastic rounding, RTNE before comparison - StochasticOdd = 7 #: Stochastic rounding, RTNO before comparison - StochasticFast = 8 #: Stochastic rounding - faster, but biased - StochasticFastest = 9 #: Stochastic rounding - even faster, but more biased + TowardZero = auto() #: Return the largest :math:`r` such that :math:`|r| \le |v|` + TowardNegative = auto() #: Return the largest :math:`r` such that :math:`r \le v` + TowardPositive = auto() #: Return the smallest :math:`r` such that :math:`r \ge v` + TiesToEven = auto() #: Round to nearest, ties to even + TiesToAway = auto() #: Round to nearest, ties away from zero + ToOdd = auto() #: Round to odd + Stochastic = auto() #: Stochastic rounding, RTNE before comparison + StochasticOdd = auto() #: Stochastic rounding, RTNO before comparison + StochasticFast = auto() #: Stochastic rounding - faster, but biased + StochasticFastest = auto() #: Stochastic rounding - even faster, but more biased class Domain(Enum): diff --git a/test/test_round.py b/test/test_round.py index 27c798e..b6e9ea1 100644 --- a/test/test_round.py +++ b/test/test_round.py @@ -50,6 +50,8 @@ def rnd_array( @pytest.mark.parametrize("round_float", (rnd_scalar, rnd_array)) def test_round_p3109(round_float: Callable) -> None: fi = format_info_p3109(8, 4) + + assert fi.max == 224.0 assert round_float(fi, 0.0068359375) == 0.0068359375 assert round_float(fi, 0.0029296875) == 0.0029296875 assert round_float(fi, 0.0078125) == 0.0078125 @@ -73,6 +75,47 @@ def test_round_p3109(round_float: Callable) -> None: assert round_float(fi, 232.1) == np.inf + def _isodd(v: int) -> bool: + return v & 0x1 == 1 + + assert round_float(fi, 224.1, RoundMode.ToOdd) == np.inf + + fi_binary4p2se = format_info_p3109(4, 2, Signedness.Signed, Domain.Extended) + # Top three values are 1.0, 1.5, 2.0; even, odd, even + assert fi_binary4p2se.max == 2.0 + assert not _isodd(fi_binary4p2se.code_of_max) + + assert round_float(fi_binary4p2se, 1.0, RoundMode.ToOdd) == 1.0 + assert round_float(fi_binary4p2se, 1.01, RoundMode.ToOdd) == 1.5 + assert round_float(fi_binary4p2se, 1.5, RoundMode.ToOdd) == 1.5 + assert round_float(fi_binary4p2se, 1.99, RoundMode.ToOdd) == 1.5 + assert round_float(fi_binary4p2se, 2.0, RoundMode.ToOdd) == 2.0 # max + assert round_float(fi_binary4p2se, 2.01, RoundMode.ToOdd) == np.inf + + fi_binary4p2ue = format_info_p3109(4, 2, Signedness.Unsigned, Domain.Extended) + # Top two values are 4.0, 6.0, then would be 8.0; even, odd, even + assert fi_binary4p2ue.max == 6.0 + assert _isodd(fi_binary4p2ue.code_of_max) + + assert round_float(fi_binary4p2ue, 4.00, RoundMode.ToOdd) == 4.0 + assert round_float(fi_binary4p2ue, 4.01, RoundMode.ToOdd) == 6.0 + assert round_float(fi_binary4p2ue, 6.0, RoundMode.ToOdd) == 6.0 # max + assert round_float(fi_binary4p2ue, 6.01, RoundMode.ToOdd) == 6.0 + assert round_float(fi_binary4p2ue, 7.99, RoundMode.ToOdd) == 6.0 + assert round_float(fi_binary4p2ue, 8.0, RoundMode.ToOdd) == np.inf # max + 1 ulp + + fi_binary4p2uf = format_info_p3109(4, 2, Signedness.Unsigned, Domain.Finite) + # top 3 are 4.0, 6.0, 8.0; even, odd, even; sat must be true + assert fi_binary4p2uf.max == 8.0 + assert not _isodd(fi_binary4p2uf.code_of_max) + + assert round_float(fi_binary4p2uf, 3.99, RoundMode.ToOdd, True) == 3.0 + assert round_float(fi_binary4p2uf, 4.01, RoundMode.ToOdd, True) == 6.0 + assert round_float(fi_binary4p2uf, 6.01, RoundMode.ToOdd, True) == 6.0 + assert round_float(fi_binary4p2uf, 7.99, RoundMode.ToOdd, True) == 6.0 + assert round_float(fi_binary4p2uf, 8.0, RoundMode.ToOdd, True) == 8.0 + assert round_float(fi_binary4p2uf, 8.01, RoundMode.ToOdd, True) == 8.0 + p4min = 2**-10 # smallest subnormal in 8p4 @@ -171,6 +214,26 @@ def test_round_p3109(round_float: Callable) -> None: (-58.0, -60.0), ), ), + ( + RoundMode.ToOdd, + ( + (p4min, p4min), + (p4min / 4, p4min), + (p4min / 2, p4min), + (-p4min, -p4min), + (-p4min / 4, -p4min), + (-p4min / 2, -p4min), + (64.0, 64.0), + (63.0, 60.0), + (62.0, 60.0), + (61.0, 60.0), + (-64.0, -64.0), + (-63.0, -60.0), + (-62.0, -60.0), + (-61.0, -60.0), + (-58.0, -60.0), + ), + ), ), ) @pytest.mark.parametrize("round_float", (rnd_scalar, rnd_array)) @@ -320,6 +383,32 @@ def test_round_p3109b(round_float: Callable, mode: RoundMode, vals: list) -> Non (-np.inf, -np.inf), ), ), + ( + (RoundMode.ToOdd, True), + ( + (p4max, p4max), + (p4maxhalfup, p4max), + (p4maxup, p4max), + (np.inf, p4max), + (-p4max, -p4max), + (-p4maxhalfup, -p4max), + (-p4maxup, -p4max), + (-np.inf, -p4max), + ), + ), + ( + (RoundMode.ToOdd, False), + ( + (p4max, p4max), + (p4maxhalfup, np.inf), + (p4maxup, np.inf), + (np.inf, np.inf), + (-p4max, -p4max), + (-p4maxhalfup, -np.inf), + (-p4maxup, -np.inf), + (-np.inf, -np.inf), + ), + ), ), ids=lambda x: f"{str(x[0])}-{'Sat' if x[1] else 'Inf'}" if len(x) == 2 else None, )