File tree Expand file tree Collapse file tree 2 files changed +8
-17
lines changed
Expand file tree Collapse file tree 2 files changed +8
-17
lines changed Original file line number Diff line number Diff line change @@ -3081,6 +3081,10 @@ def flatten(x, ndim=1):
30813081 else :
30823082 dims = (- 1 ,)
30833083
3084+ if len (dims ) == _x .ndim :
3085+ # Nothing to ravel
3086+ return _x
3087+
30843088 x_reshaped = _x .reshape (dims )
30853089 shape_kept_dims = _x .type .shape [: ndim - 1 ]
30863090 bcast_new_dim = builtins .all (s == 1 for s in _x .type .shape [ndim - 1 :])
Original file line number Diff line number Diff line change @@ -3867,35 +3867,22 @@ class TestInferShape(utt.InferShapeTester):
38673867 def test_Flatten (self ):
38683868 atens3 = tensor3 ()
38693869 atens3_val = random (4 , 5 , 3 )
3870- for ndim in (3 , 2 , 1 ):
3870+ for ndim in (2 , 1 ):
38713871 self ._compile_and_check (
38723872 [atens3 ],
38733873 [flatten (atens3 , ndim )],
38743874 [atens3_val ],
38753875 Reshape ,
3876- excluding = ["local_useless_reshape" ],
38773876 )
38783877
38793878 amat = matrix ()
38803879 amat_val = random (4 , 5 )
3881- for ndim in (2 , 1 ):
3882- self ._compile_and_check (
3883- [amat ],
3884- [flatten (amat , ndim )],
3885- [amat_val ],
3886- Reshape ,
3887- excluding = ["local_useless_reshape" ],
3888- )
3889-
3890- avec = vector ()
3891- avec_val = random (4 )
38923880 ndim = 1
38933881 self ._compile_and_check (
3894- [avec ],
3895- [flatten (avec , ndim )],
3896- [avec_val ],
3882+ [amat ],
3883+ [flatten (amat , ndim )],
3884+ [amat_val ],
38973885 Reshape ,
3898- excluding = ["local_useless_reshape" ],
38993886 )
39003887
39013888 def test_Eye (self ):
You can’t perform that action at this time.
0 commit comments