@@ -41,3 +41,61 @@ def test_pytorch_CumOp(axis, dtype):
4141 out = pt .cumprod (a , axis = axis )
4242 fgraph = FunctionGraph ([a ], [out ])
4343 compare_pytorch_and_py (fgraph , [test_value ])
44+
45+
46+ @pytest .mark .parametrize (
47+ "axis, repeats" ,
48+ [
49+ (0 , (1 , 2 , 3 )),
50+ (1 , (3 , 3 )),
51+ pytest .param (
52+ None ,
53+ 3 ,
54+ marks = pytest .mark .xfail (reason = "Reshape not implemented" ),
55+ ),
56+ ],
57+ )
58+ def test_pytorch_Repeat (axis , repeats ):
59+ a = pt .matrix ("a" , dtype = "float64" )
60+
61+ test_value = np .arange (6 , dtype = "float64" ).reshape ((3 , 2 ))
62+
63+ out = pt .repeat (a , repeats , axis = axis )
64+ fgraph = FunctionGraph ([a ], [out ])
65+ compare_pytorch_and_py (fgraph , [test_value ])
66+
67+
68+ @pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
69+ def test_pytorch_Unique_axis (axis ):
70+ a = pt .matrix ("a" , dtype = "float64" )
71+
72+ test_value = np .array (
73+ [[1.0 , 1.0 , 2.0 ], [1.0 , 1.0 , 2.0 ], [3.0 , 3.0 , 0.0 ]], dtype = "float64"
74+ )
75+
76+ out = pt .unique (a , axis = axis )
77+ fgraph = FunctionGraph ([a ], [out ])
78+ compare_pytorch_and_py (fgraph , [test_value ])
79+
80+
81+ @pytest .mark .parametrize ("return_inverse" , [False , True ])
82+ @pytest .mark .parametrize ("return_counts" , [False , True ])
83+ @pytest .mark .parametrize (
84+ "return_index" ,
85+ (False , pytest .param (True , marks = pytest .mark .xfail (raises = NotImplementedError ))),
86+ )
87+ def test_pytorch_Unique_params (return_index , return_inverse , return_counts ):
88+ a = pt .matrix ("a" , dtype = "float64" )
89+ test_value = np .array (
90+ [[1.0 , 1.0 , 2.0 ], [1.0 , 1.0 , 2.0 ], [3.0 , 3.0 , 0.0 ]], dtype = "float64"
91+ )
92+
93+ out = pt .unique (
94+ a ,
95+ return_index = return_index ,
96+ return_inverse = return_inverse ,
97+ return_counts = return_counts ,
98+ axis = 0 ,
99+ )
100+ fgraph = FunctionGraph ([a ], [out [0 ] if isinstance (out , list ) else out ])
101+ compare_pytorch_and_py (fgraph , [test_value ])
0 commit comments