Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/gfloat/round.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def round_float(
should_round_away = delta + 0.5 >= 1.0
case RoundMode.TiesToEven:
should_round_away = delta > 0.5 or (delta == 0.5 and code_is_odd)
case RoundMode.ToOdd:
should_round_away = delta > 0 and not code_is_odd

case RoundMode.StochasticFastest:
assert srbits is not None
Expand Down
3 changes: 3 additions & 0 deletions src/gfloat/round_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ def round_ndarray(
case RoundMode.TiesToEven:
should_round_away = (delta > 0.5) | ((delta == 0.5) & code_is_odd)

case RoundMode.ToOdd:
should_round_away = (delta > 0) & ~code_is_odd

case RoundMode.StochasticFastest:
assert srbits is not None
exp2r = 2**srnumbits
Expand Down
21 changes: 11 additions & 10 deletions src/gfloat/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.

from dataclasses import dataclass
from enum import Enum
from enum import Enum, auto
import math


Expand All @@ -25,15 +25,16 @@ class RoundMode(Enum):

"""

TowardZero = 1 #: Return the largest :math:`r` such that :math:`|r| \le |v|`
TowardNegative = 2 #: Return the largest :math:`r` such that :math:`r \le v`
TowardPositive = 3 #: Return the smallest :math:`r` such that :math:`r \ge v`
TiesToEven = 4 #: Round to nearest, ties to even
TiesToAway = 5 #: Round to nearest, ties away from zero
Stochastic = 6 #: Stochastic rounding, RTNE before comparison
StochasticOdd = 7 #: Stochastic rounding, RTNO before comparison
StochasticFast = 8 #: Stochastic rounding - faster, but biased
StochasticFastest = 9 #: Stochastic rounding - even faster, but more biased
TowardZero = auto() #: Return the largest :math:`r` such that :math:`|r| \le |v|`
TowardNegative = auto() #: Return the largest :math:`r` such that :math:`r \le v`
TowardPositive = auto() #: Return the smallest :math:`r` such that :math:`r \ge v`
TiesToEven = auto() #: Round to nearest, ties to even
TiesToAway = auto() #: Round to nearest, ties away from zero
ToOdd = auto() #: Round to odd
Stochastic = auto() #: Stochastic rounding, RTNE before comparison
StochasticOdd = auto() #: Stochastic rounding, RTNO before comparison
StochasticFast = auto() #: Stochastic rounding - faster, but biased
StochasticFastest = auto() #: Stochastic rounding - even faster, but more biased


class Domain(Enum):
Expand Down
89 changes: 89 additions & 0 deletions test/test_round.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def rnd_array(
@pytest.mark.parametrize("round_float", (rnd_scalar, rnd_array))
def test_round_p3109(round_float: Callable) -> None:
fi = format_info_p3109(8, 4)

assert fi.max == 224.0
assert round_float(fi, 0.0068359375) == 0.0068359375
assert round_float(fi, 0.0029296875) == 0.0029296875
assert round_float(fi, 0.0078125) == 0.0078125
Expand All @@ -73,6 +75,47 @@ def test_round_p3109(round_float: Callable) -> None:

assert round_float(fi, 232.1) == np.inf

def _isodd(v: int) -> bool:
return v & 0x1 == 1

assert round_float(fi, 224.1, RoundMode.ToOdd) == np.inf

fi_binary4p2se = format_info_p3109(4, 2, Signedness.Signed, Domain.Extended)
# Top three values are 1.0, 1.5, 2.0; even, odd, even
assert fi_binary4p2se.max == 2.0
assert not _isodd(fi_binary4p2se.code_of_max)

assert round_float(fi_binary4p2se, 1.0, RoundMode.ToOdd) == 1.0
assert round_float(fi_binary4p2se, 1.01, RoundMode.ToOdd) == 1.5
assert round_float(fi_binary4p2se, 1.5, RoundMode.ToOdd) == 1.5
assert round_float(fi_binary4p2se, 1.99, RoundMode.ToOdd) == 1.5
assert round_float(fi_binary4p2se, 2.0, RoundMode.ToOdd) == 2.0 # max
assert round_float(fi_binary4p2se, 2.01, RoundMode.ToOdd) == np.inf

fi_binary4p2ue = format_info_p3109(4, 2, Signedness.Unsigned, Domain.Extended)
# Top two values are 4.0, 6.0, then would be 8.0; even, odd, even
assert fi_binary4p2ue.max == 6.0
assert _isodd(fi_binary4p2ue.code_of_max)

assert round_float(fi_binary4p2ue, 4.00, RoundMode.ToOdd) == 4.0
assert round_float(fi_binary4p2ue, 4.01, RoundMode.ToOdd) == 6.0
assert round_float(fi_binary4p2ue, 6.0, RoundMode.ToOdd) == 6.0 # max
assert round_float(fi_binary4p2ue, 6.01, RoundMode.ToOdd) == 6.0
assert round_float(fi_binary4p2ue, 7.99, RoundMode.ToOdd) == 6.0
assert round_float(fi_binary4p2ue, 8.0, RoundMode.ToOdd) == np.inf # max + 1 ulp

fi_binary4p2uf = format_info_p3109(4, 2, Signedness.Unsigned, Domain.Finite)
# top 3 are 4.0, 6.0, 8.0; even, odd, even; sat must be true
assert fi_binary4p2uf.max == 8.0
assert not _isodd(fi_binary4p2uf.code_of_max)

assert round_float(fi_binary4p2uf, 3.99, RoundMode.ToOdd, True) == 3.0
assert round_float(fi_binary4p2uf, 4.01, RoundMode.ToOdd, True) == 6.0
assert round_float(fi_binary4p2uf, 6.01, RoundMode.ToOdd, True) == 6.0
assert round_float(fi_binary4p2uf, 7.99, RoundMode.ToOdd, True) == 6.0
assert round_float(fi_binary4p2uf, 8.0, RoundMode.ToOdd, True) == 8.0
assert round_float(fi_binary4p2uf, 8.01, RoundMode.ToOdd, True) == 8.0


p4min = 2**-10 # smallest subnormal in 8p4

Expand Down Expand Up @@ -171,6 +214,26 @@ def test_round_p3109(round_float: Callable) -> None:
(-58.0, -60.0),
),
),
(
RoundMode.ToOdd,
(
(p4min, p4min),
(p4min / 4, p4min),
(p4min / 2, p4min),
(-p4min, -p4min),
(-p4min / 4, -p4min),
(-p4min / 2, -p4min),
(64.0, 64.0),
(63.0, 60.0),
(62.0, 60.0),
(61.0, 60.0),
(-64.0, -64.0),
(-63.0, -60.0),
(-62.0, -60.0),
(-61.0, -60.0),
(-58.0, -60.0),
),
),
),
)
@pytest.mark.parametrize("round_float", (rnd_scalar, rnd_array))
Expand Down Expand Up @@ -320,6 +383,32 @@ def test_round_p3109b(round_float: Callable, mode: RoundMode, vals: list) -> Non
(-np.inf, -np.inf),
),
),
(
(RoundMode.ToOdd, True),
(
(p4max, p4max),
(p4maxhalfup, p4max),
(p4maxup, p4max),
(np.inf, p4max),
(-p4max, -p4max),
(-p4maxhalfup, -p4max),
(-p4maxup, -p4max),
(-np.inf, -p4max),
),
),
(
(RoundMode.ToOdd, False),
(
(p4max, p4max),
(p4maxhalfup, np.inf),
(p4maxup, np.inf),
(np.inf, np.inf),
(-p4max, -p4max),
(-p4maxhalfup, -np.inf),
(-p4maxup, -np.inf),
(-np.inf, -np.inf),
),
),
),
ids=lambda x: f"{str(x[0])}-{'Sat' if x[1] else 'Inf'}" if len(x) == 2 else None,
)
Expand Down