Skip to content

Commit 58cfc3d

Browse files
Merge pull request #334 from RedisAI/dagro
Add DAGRUNRO command
2 parents 56e67de + f3c03ec commit 58cfc3d

File tree

2 files changed

+199
-4
lines changed

2 files changed

+199
-4
lines changed

src/redisai.c

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,115 @@ int RedisAI_DagRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
986986
return REDISMODULE_OK;
987987
}
988988

989+
/**
990+
* AI.DAGRUNRO [LOAD <nkeys> key1 key2... ] |> [COMMAND1] |> [COMMAND2] |> [COMMANDN]
991+
*
992+
* Read-only (no PERSIST) DAG execution.
993+
* The request is queued and evaded asynchronously from a separate thread. The
994+
* client blocks until the computation finishes.
995+
*/
996+
int RedisAI_DagRunRO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
997+
int argc) {
998+
if (argc < 4) return RedisModule_WrongArity(ctx);
999+
1000+
RedisAI_RunInfo *rinfo = NULL;
1001+
if (RAI_InitRunInfo(&rinfo) == REDISMODULE_ERR) {
1002+
return RedisModule_ReplyWithError(ctx, "ERR Unable to allocate the memory and initialise the RedisAI_RunInfo structure");
1003+
}
1004+
rinfo->use_local_context = 1;
1005+
RAI_DagOp* currentDagOp = NULL;
1006+
RAI_InitDagOp(&currentDagOp);
1007+
array_append(rinfo->dagOps,currentDagOp);
1008+
1009+
int loadFlag=0;
1010+
int chainingOpCount=0;
1011+
const char* deviceStr = NULL;
1012+
1013+
for (size_t argpos = 1; argpos <= argc - 1; argpos++) {
1014+
const char *arg_string = RedisModule_StringPtrLen(argv[argpos], NULL);
1015+
if (!strcasecmp(arg_string, "LOAD")) {
1016+
loadFlag=1;
1017+
const int parse_result = RAI_parseDAGLoadArgs(
1018+
ctx, &argv[argpos], argc - argpos,&(rinfo->dagTensorsLoadedContext), &(rinfo->dagTensorsContext), "|>");
1019+
if (parse_result > 0) {
1020+
argpos += parse_result - 1;
1021+
} else {
1022+
RAI_FreeRunInfo(ctx,rinfo);
1023+
return REDISMODULE_ERR;
1024+
}
1025+
} else if (!strcasecmp(arg_string, "PERSIST")) {
1026+
return RedisModule_ReplyWithError(ctx,"ERR PERSIST cannot be specified in a read-only DAG");;
1027+
} else if (!strcasecmp(arg_string, "|>")) {
1028+
// on the first pipe operator, if LOAD or PERSIST were used, we've already
1029+
// allocated memory
1030+
if (!(loadFlag == 1 && chainingOpCount == 0)) {
1031+
rinfo->dagNumberCommands++;
1032+
RAI_DagOp *currentDagOp = NULL;
1033+
RAI_InitDagOp(&currentDagOp);
1034+
array_append(rinfo->dagOps, currentDagOp);
1035+
}
1036+
chainingOpCount++;
1037+
} else {
1038+
if (!strcasecmp(arg_string, "AI.TENSORGET")) {
1039+
rinfo->dagOps[rinfo->dagNumberCommands]->commandType = REDISAI_DAG_CMD_TENSORGET;
1040+
}
1041+
if (!strcasecmp(arg_string, "AI.TENSORSET")) {
1042+
rinfo->dagOps[rinfo->dagNumberCommands]->commandType = REDISAI_DAG_CMD_TENSORSET;
1043+
}
1044+
if (!strcasecmp(arg_string, "AI.MODELRUN")) {
1045+
if (argc - 2 < argpos) {
1046+
return RedisModule_WrongArity(ctx);
1047+
}
1048+
rinfo->dagOps[rinfo->dagNumberCommands]->commandType = REDISAI_DAG_CMD_MODELRUN;
1049+
RAI_Model *mto;
1050+
RedisModuleKey *modelKey;
1051+
const int status = RAI_GetModelFromKeyspace(ctx, argv[argpos+1], &modelKey,
1052+
&mto, REDISMODULE_READ);
1053+
if (status == REDISMODULE_ERR) {
1054+
RAI_FreeRunInfo(ctx,rinfo);
1055+
return REDISMODULE_ERR;
1056+
}
1057+
if (deviceStr==NULL){
1058+
deviceStr=mto->devicestr;
1059+
}else{
1060+
// If the device strings are not equivalent, reply with error ( for now )
1061+
if(strcasecmp(mto->devicestr, deviceStr)!=0){
1062+
RAI_FreeRunInfo(ctx,rinfo);
1063+
return RedisModule_ReplyWithError(ctx,"ERR multi-device DAGs not supported yet");;
1064+
}
1065+
}
1066+
rinfo->dagOps[rinfo->dagNumberCommands]->runkey = argv[argpos];
1067+
rinfo->dagOps[rinfo->dagNumberCommands]->mctx =
1068+
RAI_ModelRunCtxCreate(mto);
1069+
}
1070+
RedisModule_RetainString(NULL, argv[argpos]);
1071+
array_append(rinfo->dagOps[rinfo->dagNumberCommands]->argv, argv[argpos]);
1072+
rinfo->dagOps[rinfo->dagNumberCommands]->argc++;
1073+
}
1074+
}
1075+
1076+
RunQueueInfo *run_queue_info = NULL;
1077+
// If there was no MODELRUN on the DAG, we default all ops to CPU
1078+
if(deviceStr==NULL){
1079+
deviceStr="CPU";
1080+
}
1081+
// If the queue does not exist, initialize it
1082+
if (ensureRunQueue(deviceStr,&run_queue_info) == REDISMODULE_ERR) {
1083+
RAI_FreeRunInfo(ctx,rinfo);
1084+
return RedisModule_ReplyWithError(
1085+
ctx, "ERR Queue not initialized for device");
1086+
}
1087+
1088+
rinfo->client = RedisModule_BlockClient(ctx, RedisAI_DagRun_Reply, NULL,
1089+
NULL, 0);
1090+
1091+
pthread_mutex_lock(&run_queue_info->run_queue_mutex);
1092+
queuePush(run_queue_info->run_queue, rinfo);
1093+
pthread_cond_signal(&run_queue_info->queue_condition_var);
1094+
pthread_mutex_unlock(&run_queue_info->run_queue_mutex);
1095+
1096+
return REDISMODULE_OK;
1097+
}
9891098
#define EXECUTION_PLAN_FREE_MSG 100
9901099

9911100
#define REGISTER_API(name, ctx) \
@@ -1135,6 +1244,10 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
11351244
== REDISMODULE_ERR)
11361245
return REDISMODULE_ERR;
11371246

1247+
if (RedisModule_CreateCommand(ctx, "ai.dagrunro", RedisAI_DagRunRO_RedisCommand, "readonly", 3, 3, 1)
1248+
== REDISMODULE_ERR)
1249+
return REDISMODULE_ERR;
1250+
11381251
// Default configs
11391252
RAI_BackendsPath = NULL;
11401253
perqueueThreadPoolSize = REDISAI_DEFAULT_THREADS_PER_QUEUE;

test/tests_dag.py

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,42 @@ def test_dag_common_errors(env):
9898
env.assertEqual("invalid or negative value found in number of keys to LOAD",exception.__str__())
9999

100100

101+
def test_dagro_common_errors(env):
102+
con = env.getConnection()
103+
104+
# ERR unsupported command within DAG
105+
try:
106+
command = "AI.DAGRUNRO |> "\
107+
"AI.DONTEXIST tensor1 FLOAT 1 2 VALUES 5 10"
108+
109+
ret = con.execute_command(command)
110+
except Exception as e:
111+
exception = e
112+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
113+
env.assertEqual("ERR unsupported command within DAG",exception.__str__())
114+
115+
# ERR wrong number of arguments for 'AI.DAGRUN' command
116+
try:
117+
command = "AI.DAGRUNRO "
118+
119+
ret = con.execute_command(command)
120+
except Exception as e:
121+
exception = e
122+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
123+
env.assertEqual("wrong number of arguments for 'AI.DAGRUNRO' command",exception.__str__())
124+
125+
# ERR invalid or negative value found in number of keys to LOAD
126+
try:
127+
command = "AI.DAGRUNRO LOAD notnumber |> "\
128+
"AI.TENSORSET tensor1 FLOAT 1 2 VALUES 5 10"
129+
130+
ret = con.execute_command(command)
131+
except Exception as e:
132+
exception = e
133+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
134+
env.assertEqual("invalid or negative value found in number of keys to LOAD",exception.__str__())
135+
136+
101137
def test_dag_modelrun_financialNet_errors(env):
102138
con = env.getConnection()
103139

@@ -112,7 +148,6 @@ def test_dag_modelrun_financialNet_errors(env):
112148
'FLOAT', 1, 256,
113149
'BLOB', creditcard_referencedata[0].tobytes())
114150
env.assertEqual(ret, b'OK')
115-
116151

117152
# ERR wrong number of inputs
118153
try:
@@ -131,7 +166,6 @@ def test_dag_modelrun_financialNet_errors(env):
131166
env.assertEqual(type(exception), redis.exceptions.ResponseError)
132167
env.assertEqual("ERR unsupported command within DAG",exception.__str__())
133168

134-
135169

136170
def test_dag_local_tensorset(env):
137171
con = env.getConnection()
@@ -147,6 +181,22 @@ def test_dag_local_tensorset(env):
147181
ret = con.execute_command("EXISTS volatile_tensor")
148182
env.assertEqual(ret, 0 )
149183

184+
185+
def test_dagro_local_tensorset(env):
186+
con = env.getConnection()
187+
188+
command = "AI.DAGRUNRO "\
189+
"AI.TENSORSET volatile_tensor1 FLOAT 1 2 VALUES 5 10 |> "\
190+
"AI.TENSORSET volatile_tensor2 FLOAT 1 2 VALUES 5 10 "
191+
192+
ret = con.execute_command(command)
193+
env.assertEqual(ret, [b'OK',b'OK'])
194+
195+
# assert that transaction tensor does not exist
196+
ret = con.execute_command("EXISTS volatile_tensor")
197+
env.assertEqual(ret, 0 )
198+
199+
150200
def test_dag_local_tensorset_persist(env):
151201
con = env.getConnection()
152202

@@ -165,6 +215,21 @@ def test_dag_local_tensorset_persist(env):
165215
env.assertEqual(ret, [b'dtype', b'FLOAT', b'shape', [1, 2], b'values', [b'5', b'10']])
166216

167217

218+
def test_dagro_local_tensorset_persist(env):
219+
con = env.getConnection()
220+
221+
command = "AI.DAGRUNRO "\
222+
"PERSIST 1 tensor1 |> "\
223+
"AI.TENSORSET tensor1 FLOAT 1 2 VALUES 5 10"
224+
225+
try:
226+
con.execute_command(command)
227+
except Exception as e:
228+
exception = e
229+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
230+
env.assertEqual("PERSIST cannot be specified in a read-only DAG", exception.__str__())
231+
232+
168233
def test_dag_multilocal_tensorset_persist(env):
169234
con = env.getConnection()
170235

@@ -197,6 +262,7 @@ def test_dag_multilocal_tensorset_persist(env):
197262
ret = con.execute_command("AI.TENSORGET tensor3 META VALUES")
198263
env.assertEqual(ret, [b'dtype', b'FLOAT', b'shape', [1, 2], b'values', [b'5', b'10']])
199264

265+
200266
def test_dag_local_tensorset_tensorget_persist(env):
201267
con = env.getConnection()
202268

@@ -282,6 +348,20 @@ def test_dag_keyspace_tensorget(env):
282348
env.assertEqual(ret, [[b'5', b'10']])
283349

284350

351+
def test_dagro_keyspace_tensorget(env):
352+
con = env.getConnection()
353+
354+
ret = con.execute_command(
355+
"AI.TENSORSET persisted_tensor FLOAT 1 2 VALUES 5 10")
356+
env.assertEqual(ret, b'OK')
357+
358+
command = "AI.DAGRUNRO LOAD 1 persisted_tensor |> "\
359+
"AI.TENSORGET persisted_tensor VALUES"
360+
361+
ret = con.execute_command(command)
362+
env.assertEqual(ret, [[b'5', b'10']])
363+
364+
285365
def test_dag_keyspace_and_localcontext_tensorget(env):
286366
con = env.getConnection()
287367

@@ -337,6 +417,7 @@ def test_dag_modelrun_financialNet_separate_tensorget(env):
337417
env.assertEqual(ret, 0 )
338418
tensor_number = tensor_number + 1
339419

420+
340421
def test_dag_modelrun_financialNet(env):
341422
con = env.getConnection()
342423

@@ -373,6 +454,7 @@ def test_dag_modelrun_financialNet(env):
373454
env.assertEqual(ret, 0 )
374455
tensor_number = tensor_number + 1
375456

457+
376458
def test_dag_modelrun_financialNet_no_writes(env):
377459
con = env.getConnection()
378460

@@ -422,7 +504,7 @@ def test_dag_modelrun_financialNet_no_writes(env):
422504
tensor_number = tensor_number + 1
423505

424506

425-
def test_dag_modelrun_financialNet_no_writes_multiple_modelruns(env):
507+
def test_dagro_modelrun_financialNet_no_writes_multiple_modelruns(env):
426508
con = env.getConnection()
427509

428510
model_pb, creditcard_transactions, creditcard_referencedata = load_creditcardfraud_data(
@@ -442,7 +524,7 @@ def test_dag_modelrun_financialNet_no_writes_multiple_modelruns(env):
442524
tensor_number = 1
443525
for transaction_tensor in creditcard_transactions:
444526
ret = con.execute_command(
445-
'AI.DAGRUN', 'LOAD', '1', 'referenceTensor:{}'.format(tensor_number), '|>',
527+
'AI.DAGRUNRO', 'LOAD', '1', 'referenceTensor:{}'.format(tensor_number), '|>',
446528
'AI.TENSORSET', 'transactionTensor:{}'.format(tensor_number), 'FLOAT', 1, 30,'BLOB', transaction_tensor.tobytes(), '|>',
447529
'AI.MODELRUN', 'financialNet',
448530
'INPUTS', 'transactionTensor:{}'.format(tensor_number), 'referenceTensor:{}'.format(tensor_number),

0 commit comments

Comments
 (0)