Skip to content

Commit 59b3651

Browse files
brianwa84tensorflower-gardener
authored andcommitted
Update tensor_shape.py for JAX/NumPy backend.
PiperOrigin-RevId: 681829342
1 parent 4b8182b commit 59b3651

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,11 +1435,12 @@ def is_compatible_with(self, other):
14351435
14361436
"""
14371437
other = as_shape(other)
1438-
if self.dims is not None and other.dims is not None:
1438+
if self._dims is not None and other._dims is not None: # pylint: disable=protected-access
14391439
if self.rank != other.rank:
14401440
return False
1441-
for x_dim, y_dim in zip(self.dims, other.dims):
1442-
if not x_dim.is_compatible_with(y_dim):
1441+
for x_dim, y_dim in zip(self._dims, other._dims): # pylint: disable=protected-access
1442+
# Inline TensorShape.dims logic for performance in tight loops.
1443+
if x_dim is not None and y_dim is not None and x_dim != y_dim:
14431444
return False
14441445
return True
14451446

0 commit comments

Comments
 (0)