3030def test_passing_no_broadcast (
3131 spec : list [tuple [int | str , ...]], actual : list [tuple [int | str , ...]]
3232):
33- assert Desc . check_shapes (
34- * [( s , Desc ( dtype = float , shape = a )) for s , a in zip (spec , actual )]
35- )
33+ spec = { var : shape for var , shape in zip ( "abcdefg" , spec )}
34+ actual = { var : shape for var , shape in zip ("abcdefg" , actual )}
35+ Desc . validate_shapes ( spec , actual )
3636
3737
3838@pytest .mark .parametrize (
@@ -55,9 +55,10 @@ def test_passing_no_broadcast(
5555def test_failing_no_broadcast (
5656 spec : list [tuple [int | str , ...]], actual : list [tuple [int | str , ...]]
5757):
58- assert not Desc .check_shapes (
59- * [(s , Desc (dtype = float , shape = a )) for s , a in zip (spec , actual )]
60- )
58+ spec = {var : shape for var , shape in zip ("abcdefg" , spec )}
59+ actual = {var : shape for var , shape in zip ("abcdefg" , actual )}
60+ with pytest .raises (ValueError ):
61+ Desc .validate_shapes (spec , actual )
6162
6263
6364@pytest .mark .parametrize (
@@ -90,9 +91,9 @@ def test_failing_no_broadcast(
9091def test_passing_broadcast (
9192 spec : list [tuple [int | str , ...]], actual : list [tuple [int | str , ...]]
9293):
93- assert Desc . check_shapes (
94- * [( s , Desc ( dtype = float , shape = a )) for s , a in zip (spec , actual )], broadcast = True
95- )
94+ spec = { var : shape for var , shape in zip ( "abcdefg" , spec )}
95+ actual = { var : shape for var , shape in zip ("abcdefg" , actual )}
96+ Desc . validate_shapes ( spec , actual , broadcast = True )
9697
9798
9899@pytest .mark .parametrize (
@@ -113,6 +114,20 @@ def test_passing_broadcast(
113114def test_failing_broadcast (
114115 spec : list [tuple [int | str , ...]], actual : list [tuple [int | str , ...]]
115116):
116- assert not Desc .check_shapes (
117- * [(s , Desc (dtype = float , shape = a )) for s , a in zip (spec , actual )], broadcast = True
118- )
117+ spec = {var : shape for var , shape in zip ("abcdefg" , spec )}
118+ actual = {var : shape for var , shape in zip ("abcdefg" , actual )}
119+ with pytest .raises (ValueError ):
120+ Desc .validate_shapes (spec , actual , broadcast = True )
121+
122+
123+ def test_desc_object ():
124+ spec = {"a" : Desc (("N" ,), float ), "b" : Desc (("N+1" ,), float )}
125+ actual = {"a" : Desc ((3 ,), float ), "b" : Desc ((4 ,), float )}
126+ Desc .validate_shapes (spec , actual )
127+
128+
129+ def test_missing_key ():
130+ spec = {"a" : Desc (("N" ,), float ), "b" : Desc (("N+1" ,), float )}
131+ actual = {"a" : Desc ((3 ,), float )}
132+ with pytest .raises (KeyError ):
133+ Desc .validate_shapes (spec , actual )
0 commit comments