@@ -44,25 +44,24 @@ def test_debugprint_sitsot():
4444 │ │ │ │ │ │ └─ 1.0 [id O]
4545 │ │ │ │ │ └─ 0 [id P]
4646 │ │ │ │ └─ Subtensor{i} [id Q]
47- │ │ │ │ ├─ Shape [id R]
48- │ │ │ │ │ └─ Unbroadcast{0} [id J]
49- │ │ │ │ │ └─ ···
50- │ │ │ │ └─ 1 [id S]
47+ │ │ │ │ ├─ Shape [id I]
48+ │ │ │ │ │ └─ ···
49+ │ │ │ │ └─ 1 [id R]
5150 │ │ │ ├─ Unbroadcast{0} [id J]
5251 │ │ │ │ └─ ···
53- │ │ │ └─ ScalarFromTensor [id T ]
52+ │ │ │ └─ ScalarFromTensor [id S ]
5453 │ │ │ └─ Subtensor{i} [id H]
5554 │ │ │ └─ ···
5655 │ │ └─ A [id M] (outer_in_non_seqs-0)
57- │ └─ 1 [id U ]
58- └─ -1 [id V ]
56+ │ └─ 1 [id T ]
57+ └─ -1 [id U ]
5958
6059 Inner graphs:
6160
6261 Scan{scan_fn, while_loop=False, inplace=none} [id C]
63- ← Mul [id W ] (inner_out_sit_sot-0)
64- ├─ *0-<Vector(float64, shape=(?,))> [id X ] -> [id E] (inner_in_sit_sot-0)
65- └─ *1-<Vector(float64, shape=(?,))> [id Y ] -> [id M] (inner_in_non_seqs-0)
62+ ← Mul [id V ] (inner_out_sit_sot-0)
63+ ├─ *0-<Vector(float64, shape=(?,))> [id W ] -> [id E] (inner_in_sit_sot-0)
64+ └─ *1-<Vector(float64, shape=(?,))> [id X ] -> [id M] (inner_in_non_seqs-0)
6665 """
6766
6867 for truth , out in zip (expected_output .split ("\n " ), lines , strict = True ):
@@ -103,25 +102,24 @@ def test_debugprint_sitsot_no_extra_info():
103102 │ │ │ │ │ │ └─ 1.0 [id O]
104103 │ │ │ │ │ └─ 0 [id P]
105104 │ │ │ │ └─ Subtensor{i} [id Q]
106- │ │ │ │ ├─ Shape [id R]
107- │ │ │ │ │ └─ Unbroadcast{0} [id J]
108- │ │ │ │ │ └─ ···
109- │ │ │ │ └─ 1 [id S]
105+ │ │ │ │ ├─ Shape [id I]
106+ │ │ │ │ │ └─ ···
107+ │ │ │ │ └─ 1 [id R]
110108 │ │ │ ├─ Unbroadcast{0} [id J]
111109 │ │ │ │ └─ ···
112- │ │ │ └─ ScalarFromTensor [id T ]
110+ │ │ │ └─ ScalarFromTensor [id S ]
113111 │ │ │ └─ Subtensor{i} [id H]
114112 │ │ │ └─ ···
115113 │ │ └─ A [id M]
116- │ └─ 1 [id U ]
117- └─ -1 [id V ]
114+ │ └─ 1 [id T ]
115+ └─ -1 [id U ]
118116
119117 Inner graphs:
120118
121119 Scan{scan_fn, while_loop=False, inplace=none} [id C]
122- ← Mul [id W ]
123- ├─ *0-<Vector(float64, shape=(?,))> [id X ] -> [id E]
124- └─ *1-<Vector(float64, shape=(?,))> [id Y ] -> [id M]
120+ ← Mul [id V ]
121+ ├─ *0-<Vector(float64, shape=(?,))> [id W ] -> [id E]
122+ └─ *1-<Vector(float64, shape=(?,))> [id X ] -> [id M]
125123 """
126124
127125 for truth , out in zip (expected_output .split ("\n " ), lines , strict = True ):
@@ -288,25 +286,24 @@ def compute_A_k(A, k):
288286 │ │ │ │ │ │ │ └─ 1.0 [id BQ]
289287 │ │ │ │ │ │ └─ 0 [id BR]
290288 │ │ │ │ │ └─ Subtensor{i} [id BS]
291- │ │ │ │ │ ├─ Shape [id BT]
292- │ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
293- │ │ │ │ │ │ └─ ···
294- │ │ │ │ │ └─ 1 [id BU]
289+ │ │ │ │ │ ├─ Shape [id BK]
290+ │ │ │ │ │ │ └─ ···
291+ │ │ │ │ │ └─ 1 [id BT]
295292 │ │ │ │ ├─ Unbroadcast{0} [id BL]
296293 │ │ │ │ │ └─ ···
297- │ │ │ │ └─ ScalarFromTensor [id BV ]
294+ │ │ │ │ └─ ScalarFromTensor [id BU ]
298295 │ │ │ │ └─ Subtensor{i} [id BJ]
299296 │ │ │ │ └─ ···
300297 │ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
301- │ │ └─ 1 [id BW ]
302- │ └─ -1 [id BX ]
303- └─ ExpandDims{axis=0} [id BY ]
304- └─ *1-<Scalar(int64, shape=())> [id BZ ] -> [id U] (inner_in_seqs-1)
298+ │ │ └─ 1 [id BV ]
299+ │ └─ -1 [id BW ]
300+ └─ ExpandDims{axis=0} [id BX ]
301+ └─ *1-<Scalar(int64, shape=())> [id BY ] -> [id U] (inner_in_seqs-1)
305302
306303 Scan{scan_fn, while_loop=False, inplace=none} [id BE]
307- ← Mul [id CA ] (inner_out_sit_sot-0)
308- ├─ *0-<Vector(float64, shape=(?,))> [id CB ] -> [id BG] (inner_in_sit_sot-0)
309- └─ *1-<Vector(float64, shape=(?,))> [id CC ] -> [id BO] (inner_in_non_seqs-0)
304+ ← Mul [id BZ ] (inner_out_sit_sot-0)
305+ ├─ *0-<Vector(float64, shape=(?,))> [id CA ] -> [id BG] (inner_in_sit_sot-0)
306+ └─ *1-<Vector(float64, shape=(?,))> [id CB ] -> [id BO] (inner_in_non_seqs-0)
310307 """
311308
312309 for truth , out in zip (expected_output .split ("\n " ), lines , strict = True ):
@@ -386,27 +383,26 @@ def compute_A_k(A, k):
386383 │ │ │ │ │ │ │ └─ 1.0 [id BR]
387384 │ │ │ │ │ │ └─ 0 [id BS]
388385 │ │ │ │ │ └─ Subtensor{i} [id BT]
389- │ │ │ │ │ ├─ Shape [id BU]
390- │ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
391- │ │ │ │ │ │ └─ ···
392- │ │ │ │ │ └─ 1 [id BV]
386+ │ │ │ │ │ ├─ Shape [id BM]
387+ │ │ │ │ │ │ └─ ···
388+ │ │ │ │ │ └─ 1 [id BU]
393389 │ │ │ │ ├─ Unbroadcast{0} [id BN]
394390 │ │ │ │ │ └─ ···
395- │ │ │ │ └─ ScalarFromTensor [id BW ]
391+ │ │ │ │ └─ ScalarFromTensor [id BV ]
396392 │ │ │ │ └─ Subtensor{i} [id BL]
397393 │ │ │ │ └─ ···
398394 │ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
399- │ │ └─ 1 [id BX ]
400- │ └─ -1 [id BY ]
401- └─ ExpandDims{axis=0} [id BZ ]
395+ │ │ └─ 1 [id BW ]
396+ │ └─ -1 [id BX ]
397+ └─ ExpandDims{axis=0} [id BY ]
402398 └─ *1-<Scalar(int64, shape=())> [id Z] (inner_in_seqs-1)
403399
404400 Scan{scan_fn, while_loop=False, inplace=none} [id BH]
405- → *0-<Vector(float64, shape=(?,))> [id CA ] -> [id BI] (inner_in_sit_sot-0)
406- → *1-<Vector(float64, shape=(?,))> [id CB ] -> [id BA] (inner_in_non_seqs-0)
407- ← Mul [id CC ] (inner_out_sit_sot-0)
408- ├─ *0-<Vector(float64, shape=(?,))> [id CA ] (inner_in_sit_sot-0)
409- └─ *1-<Vector(float64, shape=(?,))> [id CB ] (inner_in_non_seqs-0)
401+ → *0-<Vector(float64, shape=(?,))> [id BZ ] -> [id BI] (inner_in_sit_sot-0)
402+ → *1-<Vector(float64, shape=(?,))> [id CA ] -> [id BA] (inner_in_non_seqs-0)
403+ ← Mul [id CB ] (inner_out_sit_sot-0)
404+ ├─ *0-<Vector(float64, shape=(?,))> [id BZ ] (inner_in_sit_sot-0)
405+ └─ *1-<Vector(float64, shape=(?,))> [id CA ] (inner_in_non_seqs-0)
410406 """
411407
412408 for truth , out in zip (expected_output .split ("\n " ), lines , strict = True ):
@@ -528,98 +524,97 @@ def test_debugprint_mitmot():
528524 │ │ │ │ │ │ │ │ └─ 1.0 [id R]
529525 │ │ │ │ │ │ │ └─ 0 [id S]
530526 │ │ │ │ │ │ └─ Subtensor{i} [id T]
531- │ │ │ │ │ │ ├─ Shape [id U]
532- │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M]
533- │ │ │ │ │ │ │ └─ ···
534- │ │ │ │ │ │ └─ 1 [id V]
527+ │ │ │ │ │ │ ├─ Shape [id L]
528+ │ │ │ │ │ │ │ └─ ···
529+ │ │ │ │ │ │ └─ 1 [id U]
535530 │ │ │ │ │ ├─ Unbroadcast{0} [id M]
536531 │ │ │ │ │ │ └─ ···
537- │ │ │ │ │ └─ ScalarFromTensor [id W ]
532+ │ │ │ │ │ └─ ScalarFromTensor [id V ]
538533 │ │ │ │ │ └─ Subtensor{i} [id K]
539534 │ │ │ │ │ └─ ···
540535 │ │ │ │ └─ A [id P] (outer_in_non_seqs-0)
541- │ │ │ └─ 0 [id X ]
542- │ │ └─ 1 [id Y ]
543- │ ├─ Subtensor{:stop} [id Z ] (outer_in_seqs-0)
544- │ │ ├─ Subtensor{::step} [id BA ]
545- │ │ │ ├─ Subtensor{:stop} [id BB ]
536+ │ │ │ └─ 0 [id W ]
537+ │ │ └─ 1 [id X ]
538+ │ ├─ Subtensor{:stop} [id Y ] (outer_in_seqs-0)
539+ │ │ ├─ Subtensor{::step} [id Z ]
540+ │ │ │ ├─ Subtensor{:stop} [id BA ]
546541 │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
547542 │ │ │ │ │ └─ ···
548- │ │ │ │ └─ -1 [id BC ]
549- │ │ │ └─ -1 [id BD ]
550- │ │ └─ ScalarFromTensor [id BE ]
543+ │ │ │ │ └─ -1 [id BB ]
544+ │ │ │ └─ -1 [id BC ]
545+ │ │ └─ ScalarFromTensor [id BD ]
551546 │ │ └─ Sub [id C]
552547 │ │ └─ ···
553- │ ├─ Subtensor{:stop} [id BF ] (outer_in_seqs-1)
554- │ │ ├─ Subtensor{:stop} [id BG ]
555- │ │ │ ├─ Subtensor{::step} [id BH ]
548+ │ ├─ Subtensor{:stop} [id BE ] (outer_in_seqs-1)
549+ │ │ ├─ Subtensor{:stop} [id BF ]
550+ │ │ │ ├─ Subtensor{::step} [id BG ]
556551 │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
557552 │ │ │ │ │ └─ ···
558- │ │ │ │ └─ -1 [id BI ]
559- │ │ │ └─ -1 [id BJ ]
560- │ │ └─ ScalarFromTensor [id BK ]
553+ │ │ │ │ └─ -1 [id BH ]
554+ │ │ │ └─ -1 [id BI ]
555+ │ │ └─ ScalarFromTensor [id BJ ]
561556 │ │ └─ Sub [id C]
562557 │ │ └─ ···
563- │ ├─ Subtensor{::step} [id BL ] (outer_in_mit_mot-0)
564- │ │ ├─ IncSubtensor{start:} [id BM ]
565- │ │ │ ├─ Second [id BN ]
558+ │ ├─ Subtensor{::step} [id BK ] (outer_in_mit_mot-0)
559+ │ │ ├─ IncSubtensor{start:} [id BL ]
560+ │ │ │ ├─ Second [id BM ]
566561 │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
567562 │ │ │ │ │ └─ ···
568- │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO ]
569- │ │ │ │ └─ 0.0 [id BP ]
570- │ │ │ ├─ IncSubtensor{i} [id BQ ]
571- │ │ │ │ ├─ Second [id BR ]
572- │ │ │ │ │ ├─ Subtensor{start:} [id BS ]
563+ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BN ]
564+ │ │ │ │ └─ 0.0 [id BO ]
565+ │ │ │ ├─ IncSubtensor{i} [id BP ]
566+ │ │ │ │ ├─ Second [id BQ ]
567+ │ │ │ │ │ ├─ Subtensor{start:} [id BR ]
573568 │ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
574569 │ │ │ │ │ │ │ └─ ···
575- │ │ │ │ │ │ └─ 1 [id BT ]
576- │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU ]
577- │ │ │ │ │ └─ 0.0 [id BV ]
578- │ │ │ │ ├─ Second [id BW ]
579- │ │ │ │ │ ├─ Subtensor{i} [id BX ]
580- │ │ │ │ │ │ ├─ Subtensor{start:} [id BS ]
570+ │ │ │ │ │ │ └─ 1 [id BS ]
571+ │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BT ]
572+ │ │ │ │ │ └─ 0.0 [id BU ]
573+ │ │ │ │ ├─ Second [id BV ]
574+ │ │ │ │ │ ├─ Subtensor{i} [id BW ]
575+ │ │ │ │ │ │ ├─ Subtensor{start:} [id BR ]
581576 │ │ │ │ │ │ │ └─ ···
582- │ │ │ │ │ │ └─ -1 [id BY ]
583- │ │ │ │ │ └─ ExpandDims{axis=0} [id BZ ]
584- │ │ │ │ │ └─ Second [id CA ]
585- │ │ │ │ │ ├─ Sum{axes=None} [id CB ]
586- │ │ │ │ │ │ └─ Subtensor{i} [id BX ]
577+ │ │ │ │ │ │ └─ -1 [id BX ]
578+ │ │ │ │ │ └─ ExpandDims{axis=0} [id BY ]
579+ │ │ │ │ │ └─ Second [id BZ ]
580+ │ │ │ │ │ ├─ Sum{axes=None} [id CA ]
581+ │ │ │ │ │ │ └─ Subtensor{i} [id BW ]
587582 │ │ │ │ │ │ └─ ···
588- │ │ │ │ │ └─ 1.0 [id CC ]
589- │ │ │ │ └─ -1 [id BY ]
590- │ │ │ └─ 1 [id BT ]
591- │ │ └─ -1 [id CD ]
592- │ ├─ Alloc [id CE ] (outer_in_sit_sot-0)
593- │ │ ├─ 0.0 [id CF ]
594- │ │ ├─ Add [id CG ]
583+ │ │ │ │ │ └─ 1.0 [id CB ]
584+ │ │ │ │ └─ -1 [id BX ]
585+ │ │ │ └─ 1 [id BS ]
586+ │ │ └─ -1 [id CC ]
587+ │ ├─ Alloc [id CD ] (outer_in_sit_sot-0)
588+ │ │ ├─ 0.0 [id CE ]
589+ │ │ ├─ Add [id CF ]
595590 │ │ │ ├─ Sub [id C]
596591 │ │ │ │ └─ ···
597- │ │ │ └─ 1 [id CH ]
598- │ │ └─ Subtensor{i} [id CI ]
599- │ │ ├─ Shape [id CJ ]
592+ │ │ │ └─ 1 [id CG ]
593+ │ │ └─ Subtensor{i} [id CH ]
594+ │ │ ├─ Shape [id CI ]
600595 │ │ │ └─ A [id P]
601- │ │ └─ 0 [id CK ]
596+ │ │ └─ 0 [id CJ ]
602597 │ └─ A [id P] (outer_in_non_seqs-0)
603- └─ -1 [id CL ]
598+ └─ -1 [id CK ]
604599
605600 Inner graphs:
606601
607602 Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B]
608- ← Add [id CM ] (inner_out_mit_mot-0-0)
609- ├─ Mul [id CN ]
610- │ ├─ *2-<Vector(float64, shape=(?,))> [id CO ] -> [id BL ] (inner_in_mit_mot-0-0)
611- │ └─ *5-<Vector(float64, shape=(?,))> [id CP ] -> [id P] (inner_in_non_seqs-0)
612- └─ *3-<Vector(float64, shape=(?,))> [id CQ ] -> [id BL ] (inner_in_mit_mot-0-1)
613- ← Add [id CR ] (inner_out_sit_sot-0)
614- ├─ Mul [id CS ]
615- │ ├─ *2-<Vector(float64, shape=(?,))> [id CO ] -> [id BL ] (inner_in_mit_mot-0-0)
616- │ └─ *0-<Vector(float64, shape=(?,))> [id CT ] -> [id Z ] (inner_in_seqs-0)
617- └─ *4-<Vector(float64, shape=(?,))> [id CU ] -> [id CE ] (inner_in_sit_sot-0)
603+ ← Add [id CL ] (inner_out_mit_mot-0-0)
604+ ├─ Mul [id CM ]
605+ │ ├─ *2-<Vector(float64, shape=(?,))> [id CN ] -> [id BK ] (inner_in_mit_mot-0-0)
606+ │ └─ *5-<Vector(float64, shape=(?,))> [id CO ] -> [id P] (inner_in_non_seqs-0)
607+ └─ *3-<Vector(float64, shape=(?,))> [id CP ] -> [id BK ] (inner_in_mit_mot-0-1)
608+ ← Add [id CQ ] (inner_out_sit_sot-0)
609+ ├─ Mul [id CR ]
610+ │ ├─ *2-<Vector(float64, shape=(?,))> [id CN ] -> [id BK ] (inner_in_mit_mot-0-0)
611+ │ └─ *0-<Vector(float64, shape=(?,))> [id CS ] -> [id Y ] (inner_in_seqs-0)
612+ └─ *4-<Vector(float64, shape=(?,))> [id CT ] -> [id CD ] (inner_in_sit_sot-0)
618613
619614 Scan{scan_fn, while_loop=False, inplace=none} [id F]
620- ← Mul [id CV ] (inner_out_sit_sot-0)
621- ├─ *0-<Vector(float64, shape=(?,))> [id CT ] -> [id H] (inner_in_sit_sot-0)
622- └─ *1-<Vector(float64, shape=(?,))> [id CW ] -> [id P] (inner_in_non_seqs-0)
615+ ← Mul [id CU ] (inner_out_sit_sot-0)
616+ ├─ *0-<Vector(float64, shape=(?,))> [id CS ] -> [id H] (inner_in_sit_sot-0)
617+ └─ *1-<Vector(float64, shape=(?,))> [id CV ] -> [id P] (inner_in_non_seqs-0)
623618 """
624619
625620 for truth , out in zip (expected_output .split ("\n " ), lines , strict = True ):
0 commit comments