File tree Expand file tree Collapse file tree 2 files changed +20
-7
lines changed
Expand file tree Collapse file tree 2 files changed +20
-7
lines changed Original file line number Diff line number Diff line change 99from pytensor .tensor .variable import _tensor_py_operators
1010
1111
12+ def __getattr__ (name ):
13+ if name == "ScalarSharedVariable" :
14+ warnings .warn (
15+ "The class `ScalarSharedVariable` has been deprecated. "
16+ "Use `TensorSharedVariable` instead and check for `ndim==0`." ,
17+ FutureWarning ,
18+ )
19+ return TensorSharedVariable
20+
21+ raise AttributeError (f"module { __name__ !r} has no attribute { name !r} " )
22+
23+
1224def load_shared_variable (val ):
1325 """
1426 This function is only here to keep some pickles loading
@@ -94,10 +106,6 @@ def tensor_constructor(
94106 )
95107
96108
97- class ScalarSharedVariable (TensorSharedVariable ):
98- pass
99-
100-
101109@shared_constructor .register (np .number )
102110@shared_constructor .register (float )
103111@shared_constructor .register (int )
@@ -132,7 +140,7 @@ def scalar_constructor(
132140
133141 # Do not pass the dtype to asarray because we want this to fail if
134142 # strict is True and the types do not match.
135- rval = ScalarSharedVariable (
143+ rval = TensorSharedVariable (
136144 type = tensor_type ,
137145 value = np .array (value , copy = True ),
138146 name = name ,
Original file line number Diff line number Diff line change 1010from pytensor .tensor import get_vector_length
1111from pytensor .tensor .basic import MakeVector
1212from pytensor .tensor .shape import Shape_i , specify_shape
13- from pytensor .tensor .sharedvar import ScalarSharedVariable , TensorSharedVariable
13+ from pytensor .tensor .sharedvar import TensorSharedVariable
1414from tests import unittest_tools as utt
1515
1616
@@ -679,12 +679,17 @@ def test_tensor_shared_zero():
679679
680680def test_scalar_shared_options ():
681681 res = pytensor .shared (value = np .float32 (0.0 ), name = "lk" , borrow = True )
682- assert isinstance (res , ScalarSharedVariable )
682+ assert isinstance (res , TensorSharedVariable ) and res . type . ndim == 0
683683 assert res .type .dtype == "float32"
684684 assert res .name == "lk"
685685 assert res .type .shape == ()
686686
687687
688+ def test_scalar_shared_deprecated ():
689+ with pytest .warns (FutureWarning , match = ".*deprecated.*" ):
690+ pytensor .tensor .sharedvar .ScalarSharedVariable
691+
692+
688693def test_get_vector_length ():
689694 x = pytensor .shared (np .array ((2 , 3 , 4 , 5 )))
690695 assert get_vector_length (x ) == 4
You can’t perform that action at this time.
0 commit comments