@@ -98,6 +98,42 @@ def test_dag_common_errors(env):
9898 env .assertEqual ("invalid or negative value found in number of keys to LOAD" ,exception .__str__ ())
9999
100100
101+ def test_dagro_common_errors (env ):
102+ con = env .getConnection ()
103+
104+ # ERR unsupported command within DAG
105+ try :
106+ command = "AI.DAGRUNRO |> " \
107+ "AI.DONTEXIST tensor1 FLOAT 1 2 VALUES 5 10"
108+
109+ ret = con .execute_command (command )
110+ except Exception as e :
111+ exception = e
112+ env .assertEqual (type (exception ), redis .exceptions .ResponseError )
113+ env .assertEqual ("ERR unsupported command within DAG" ,exception .__str__ ())
114+
115+ # ERR wrong number of arguments for 'AI.DAGRUN' command
116+ try :
117+ command = "AI.DAGRUNRO "
118+
119+ ret = con .execute_command (command )
120+ except Exception as e :
121+ exception = e
122+ env .assertEqual (type (exception ), redis .exceptions .ResponseError )
123+ env .assertEqual ("wrong number of arguments for 'AI.DAGRUNRO' command" ,exception .__str__ ())
124+
125+ # ERR invalid or negative value found in number of keys to LOAD
126+ try :
127+ command = "AI.DAGRUNRO LOAD notnumber |> " \
128+ "AI.TENSORSET tensor1 FLOAT 1 2 VALUES 5 10"
129+
130+ ret = con .execute_command (command )
131+ except Exception as e :
132+ exception = e
133+ env .assertEqual (type (exception ), redis .exceptions .ResponseError )
134+ env .assertEqual ("invalid or negative value found in number of keys to LOAD" ,exception .__str__ ())
135+
136+
101137def test_dag_modelrun_financialNet_errors (env ):
102138 con = env .getConnection ()
103139
@@ -112,7 +148,6 @@ def test_dag_modelrun_financialNet_errors(env):
112148 'FLOAT' , 1 , 256 ,
113149 'BLOB' , creditcard_referencedata [0 ].tobytes ())
114150 env .assertEqual (ret , b'OK' )
115-
116151
117152 # ERR wrong number of inputs
118153 try :
@@ -131,7 +166,6 @@ def test_dag_modelrun_financialNet_errors(env):
131166 env .assertEqual (type (exception ), redis .exceptions .ResponseError )
132167 env .assertEqual ("ERR unsupported command within DAG" ,exception .__str__ ())
133168
134-
135169
136170def test_dag_local_tensorset (env ):
137171 con = env .getConnection ()
@@ -147,6 +181,22 @@ def test_dag_local_tensorset(env):
147181 ret = con .execute_command ("EXISTS volatile_tensor" )
148182 env .assertEqual (ret , 0 )
149183
184+
185+ def test_dagro_local_tensorset (env ):
186+ con = env .getConnection ()
187+
188+ command = "AI.DAGRUNRO " \
189+ "AI.TENSORSET volatile_tensor1 FLOAT 1 2 VALUES 5 10 |> " \
190+ "AI.TENSORSET volatile_tensor2 FLOAT 1 2 VALUES 5 10 "
191+
192+ ret = con .execute_command (command )
193+ env .assertEqual (ret , [b'OK' ,b'OK' ])
194+
195+ # assert that transaction tensor does not exist
196+ ret = con .execute_command ("EXISTS volatile_tensor" )
197+ env .assertEqual (ret , 0 )
198+
199+
150200def test_dag_local_tensorset_persist (env ):
151201 con = env .getConnection ()
152202
@@ -165,6 +215,21 @@ def test_dag_local_tensorset_persist(env):
165215 env .assertEqual (ret , [b'dtype' , b'FLOAT' , b'shape' , [1 , 2 ], b'values' , [b'5' , b'10' ]])
166216
167217
218+ def test_dagro_local_tensorset_persist (env ):
219+ con = env .getConnection ()
220+
221+ command = "AI.DAGRUNRO " \
222+ "PERSIST 1 tensor1 |> " \
223+ "AI.TENSORSET tensor1 FLOAT 1 2 VALUES 5 10"
224+
225+ try :
226+ con .execute_command (command )
227+ except Exception as e :
228+ exception = e
229+ env .assertEqual (type (exception ), redis .exceptions .ResponseError )
230+ env .assertEqual ("PERSIST cannot be specified in a read-only DAG" , exception .__str__ ())
231+
232+
168233def test_dag_multilocal_tensorset_persist (env ):
169234 con = env .getConnection ()
170235
@@ -197,6 +262,7 @@ def test_dag_multilocal_tensorset_persist(env):
197262 ret = con .execute_command ("AI.TENSORGET tensor3 META VALUES" )
198263 env .assertEqual (ret , [b'dtype' , b'FLOAT' , b'shape' , [1 , 2 ], b'values' , [b'5' , b'10' ]])
199264
265+
200266def test_dag_local_tensorset_tensorget_persist (env ):
201267 con = env .getConnection ()
202268
@@ -282,6 +348,20 @@ def test_dag_keyspace_tensorget(env):
282348 env .assertEqual (ret , [[b'5' , b'10' ]])
283349
284350
351+ def test_dagro_keyspace_tensorget (env ):
352+ con = env .getConnection ()
353+
354+ ret = con .execute_command (
355+ "AI.TENSORSET persisted_tensor FLOAT 1 2 VALUES 5 10" )
356+ env .assertEqual (ret , b'OK' )
357+
358+ command = "AI.DAGRUNRO LOAD 1 persisted_tensor |> " \
359+ "AI.TENSORGET persisted_tensor VALUES"
360+
361+ ret = con .execute_command (command )
362+ env .assertEqual (ret , [[b'5' , b'10' ]])
363+
364+
285365def test_dag_keyspace_and_localcontext_tensorget (env ):
286366 con = env .getConnection ()
287367
@@ -337,6 +417,7 @@ def test_dag_modelrun_financialNet_separate_tensorget(env):
337417 env .assertEqual (ret , 0 )
338418 tensor_number = tensor_number + 1
339419
420+
340421def test_dag_modelrun_financialNet (env ):
341422 con = env .getConnection ()
342423
@@ -373,6 +454,7 @@ def test_dag_modelrun_financialNet(env):
373454 env .assertEqual (ret , 0 )
374455 tensor_number = tensor_number + 1
375456
457+
376458def test_dag_modelrun_financialNet_no_writes (env ):
377459 con = env .getConnection ()
378460
@@ -422,7 +504,7 @@ def test_dag_modelrun_financialNet_no_writes(env):
422504 tensor_number = tensor_number + 1
423505
424506
425- def test_dag_modelrun_financialNet_no_writes_multiple_modelruns (env ):
507+ def test_dagro_modelrun_financialNet_no_writes_multiple_modelruns (env ):
426508 con = env .getConnection ()
427509
428510 model_pb , creditcard_transactions , creditcard_referencedata = load_creditcardfraud_data (
@@ -442,7 +524,7 @@ def test_dag_modelrun_financialNet_no_writes_multiple_modelruns(env):
442524 tensor_number = 1
443525 for transaction_tensor in creditcard_transactions :
444526 ret = con .execute_command (
445- 'AI.DAGRUN ' , 'LOAD' , '1' , 'referenceTensor:{}' .format (tensor_number ), '|>' ,
527+ 'AI.DAGRUNRO ' , 'LOAD' , '1' , 'referenceTensor:{}' .format (tensor_number ), '|>' ,
446528 'AI.TENSORSET' , 'transactionTensor:{}' .format (tensor_number ), 'FLOAT' , 1 , 30 ,'BLOB' , transaction_tensor .tobytes (), '|>' ,
447529 'AI.MODELRUN' , 'financialNet' ,
448530 'INPUTS' , 'transactionTensor:{}' .format (tensor_number ), 'referenceTensor:{}' .format (tensor_number ),
0 commit comments