Skip to content

Commit 9ac6579

Browse files
authored
Added PyTorch clip dispatch (#1797)
1 parent 6319fac commit 9ac6579

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
66
from pytensor.scalar.basic import (
77
Cast,
8+
Clip,
89
Invert,
910
ScalarOp,
1011
)
@@ -71,6 +72,14 @@ def pytorch_funcify_Softplus(op, node, **kwargs):
7172
return torch.nn.Softplus()
7273

7374

75+
@pytorch_funcify.register(Clip)
def pytorch_funcify_Clip(op, node, **kwargs):
    """Dispatch PyTensor's scalar ``Clip`` op to a PyTorch callable."""

    def clip(x, min_val, max_val):
        # Two-step torch.where rather than torch.clamp: when min_val > max_val,
        # clamp would resolve every element to max_val, whereas this form (like
        # the original nested-where expression) gives min_val precedence for
        # elements below min_val.
        capped_above = torch.where(x > max_val, max_val, x)
        return torch.where(x < min_val, min_val, capped_above)

    return clip
81+
82+
7483
@pytorch_funcify.register(ScalarLoop)
7584
def pytorch_funicify_ScalarLoop(op, node, **kwargs):
7685
update = pytorch_funcify(op.fgraph, **kwargs)

tests/link/pytorch/test_elemwise.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,19 @@ def test_cast():
151151
assert res.dtype == np.int32
152152

153153

154+
@pytest.mark.parametrize(
    "x_val, min_val, max_val",
    [
        # Value inside [min_val, max_val]: clip is a no-op.
        (np.array([5.0], dtype=config.floatX), 0.0, 10.0),
        # Value below min_val: clipped up to min_val.
        (np.array([-5.0], dtype=config.floatX), 0.0, 10.0),
        # Value above max_val: clipped down to max_val. Previously untested —
        # the original cases never exercised the upper-bound branch of the
        # PyTorch Clip dispatch.
        (np.array([15.0], dtype=config.floatX), 0.0, 10.0),
    ],
)
def test_clip(x_val, min_val, max_val):
    """Check that ``pt.clip`` compiled through the PyTorch backend agrees
    with the reference Python backend for values below, inside, and above
    the clipping interval."""
    x = pt.tensor("x", shape=x_val.shape, dtype=config.floatX)
    out = pt.clip(x, min_val, max_val)
    compare_pytorch_and_py([x], [out], [x_val])
165+
166+
154167
def test_vmap_elemwise():
155168
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
156169

0 commit comments

Comments
 (0)