Skip to content

Commit 12ed36c

Browse files
AI.DAGRUN and AI.DAGRUN_RO run stats (#336)
* [fix] dagrun now adding statistics to runstats * [add] added ai.info testing for dagrun * [add] running dagrun tests on gpu * [add] ensuring that dag tests with TF respect the WITH_TF rules. excluding util and rmutil from coverage report * [fix] ensuring sync prior to save on tests
1 parent 6626021 commit 12ed36c

File tree

12 files changed

+129
-54
lines changed

12 files changed

+129
-54
lines changed

opt/Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ SRCDIR=..
8585
BINDIR=$(BINROOT)/src
8686
DEPS_DIR=$(ROOT)/deps/$(OS)-$(ARCH)-$(DEVICE)
8787
INSTALL_DIR=$(BINROOT)/install-$(DEVICE)
88+
COV_EXCLUDE=\
89+
'./rmutil/*'\
90+
'./util/*'
8891

8992
TARGET=$(BINDIR)/redisai.so
9093
INSTALLED_TARGET=$(INSTALL_DIR)/redisai.so

src/dag.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,15 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv,
144144

145145
case REDISAI_DAG_CMD_MODELRUN: {
146146
rinfo->dagReplyLength++;
147+
struct RedisAI_RunStats *rstats = NULL;
148+
const char *runkey =
149+
RedisModule_StringPtrLen(currentOp->runkey, NULL);
150+
RAI_GetRunStats(runkey,&rstats);
147151
if (currentOp->result == REDISMODULE_ERR) {
152+
RAI_SafeAddDataPoint(rstats,0,1,1,0);
148153
RedisModule_ReplyWithError(ctx, currentOp->err->detail_oneline);
149154
} else {
155+
RAI_SafeAddDataPoint(rstats,currentOp->duration_us,1,0,0);
150156
RedisModule_ReplyWithSimpleString(ctx, "OK");
151157
}
152158
break;

src/model.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,6 @@ int RAI_ModelSerialize(RAI_Model *model, char **buffer, size_t *len, RAI_Error *
494494

495495
int RedisAI_Parse_ModelRun_RedisCommand(RedisModuleCtx *ctx,
496496
RedisModuleString **argv, int argc,
497-
// RedisAI_RunInfo **rinfo,
498497
RAI_ModelRunCtx **mctx,
499498
RedisModuleString ***outkeys,
500499
RAI_Model **mto, int useLocalContext,

src/model_script_run_session.c

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -86,19 +86,12 @@ int RAI_ModelRunScriptRunReply(RedisModuleCtx *ctx, RedisModuleString **argv,
8686
struct RedisAI_RunInfo *rinfo = RedisModule_GetBlockedClientPrivateData(ctx);
8787

8888
const char *runkey = RedisModule_StringPtrLen(rinfo->runkey, NULL);
89-
AI_dictEntry *stats_entry = AI_dictFind(run_stats, runkey);
90-
9189
struct RedisAI_RunStats *rstats = NULL;
92-
if (stats_entry) {
93-
rstats = AI_dictGetVal(stats_entry);
94-
}
90+
RAI_GetRunStats(runkey, &rstats);
9591

9692
if (rinfo->result == REDISMODULE_ERR) {
9793
RedisModule_Log(ctx, "warning", "ERR %s", rinfo->err->detail);
98-
if (rstats) {
99-
rstats->calls += 1;
100-
rstats->nerrors += 1;
101-
}
94+
RAI_SafeAddDataPoint(rstats,0,1,1,0);
10295
int ret = RedisModule_ReplyWithError(ctx, rinfo->err->detail_oneline);
10396
RAI_FreeRunInfo(ctx, rinfo);
10497
return ret;
@@ -119,10 +112,7 @@ int RAI_ModelRunScriptRunReply(RedisModuleCtx *ctx, RedisModuleString **argv,
119112
REDISMODULE_READ | REDISMODULE_WRITE);
120113
if (status == REDISMODULE_ERR) {
121114
RAI_FreeRunInfo(ctx, rinfo);
122-
if (rstats) {
123-
rstats->calls += 1;
124-
rstats->nerrors += 1;
125-
}
115+
RAI_SafeAddDataPoint(rstats,0,1,1,0);
126116
return REDISMODULE_ERR;
127117
}
128118
RAI_Tensor *t = NULL;
@@ -144,16 +134,7 @@ int RAI_ModelRunScriptRunReply(RedisModuleCtx *ctx, RedisModuleString **argv,
144134
RedisAI_ReplicateTensorSet(ctx, rinfo->outkeys[i], t);
145135
}
146136
}
147-
148-
if (rstats) {
149-
rstats->duration_us += rinfo->duration_us;
150-
rstats->calls += 1;
151-
152-
if (rinfo->mctx) {
153-
rstats->samples += batch_size;
154-
}
155-
}
156-
137+
RAI_SafeAddDataPoint(rstats,rinfo->duration_us,1,0,batch_size);
157138
RAI_FreeRunInfo(ctx, rinfo);
158139
return RedisModule_ReplyWithSimpleString(ctx, "OK");
159140
}

src/redisai.c

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -772,37 +772,21 @@ int RedisAI_ScriptScan_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg
772772
return REDISMODULE_OK;
773773
}
774774

775-
/**
775+
/**
776776
* AI.INFO <model_or_script_key> [RESETSTAT]
777777
*/
778778
int RedisAI_Info_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
779-
RedisModule_AutoMemory(ctx);
780-
781779
if (argc != 2 && argc != 3) return RedisModule_WrongArity(ctx);
782-
783-
ArgsCursor ac;
784-
ArgsCursor_InitRString(&ac, argv+1, argc-1);
785-
786-
const char* runkey;
787-
AC_GetString(&ac, &runkey, NULL, 0);
788-
789-
AI_dictEntry *stats_entry = AI_dictFind(run_stats, runkey);
790-
791-
if (!stats_entry) {
780+
const char *runkey = RedisModule_StringPtrLen(argv[1], NULL);
781+
struct RedisAI_RunStats *rstats = NULL;
782+
if (RAI_GetRunStats(runkey, &rstats) == REDISMODULE_ERR) {
792783
return RedisModule_ReplyWithError(ctx, "ERR cannot find run info for key");
793784
}
794785

795-
struct RedisAI_RunStats *rstats = AI_dictGetVal(stats_entry);
796-
797-
if (!AC_IsAtEnd(&ac)) {
798-
const char* opt;
799-
AC_GetString(&ac, &opt, NULL, 0);
800-
801-
if (strcasecmp(opt, "RESETSTAT") == 0) {
802-
rstats->duration_us = 0;
803-
rstats->samples = 0;
804-
rstats->calls = 0;
805-
rstats->nerrors = 0;
786+
if(argc==3){
787+
const char *subcommand = RedisModule_StringPtrLen(argv[2], NULL);
788+
if (!strcasecmp(subcommand, "RESETSTAT")) {
789+
RAI_ResetRunStats(rstats);
806790
RedisModule_ReplyWithSimpleString(ctx, "OK");
807791
return REDISMODULE_OK;
808792
}
@@ -953,7 +937,7 @@ int RedisAI_DagRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
953937
return RedisModule_ReplyWithError(ctx,"ERR multi-device DAGs not supported yet");;
954938
}
955939
}
956-
rinfo->dagOps[rinfo->dagNumberCommands]->runkey = argv[argpos];
940+
rinfo->dagOps[rinfo->dagNumberCommands]->runkey = argv[argpos+1];
957941
rinfo->dagOps[rinfo->dagNumberCommands]->mctx =
958942
RAI_ModelRunCtxCreate(mto);
959943
}
@@ -1063,7 +1047,7 @@ int RedisAI_DagRunRO_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
10631047
return RedisModule_ReplyWithError(ctx,"ERR multi-device DAGs not supported yet");;
10641048
}
10651049
}
1066-
rinfo->dagOps[rinfo->dagNumberCommands]->runkey = argv[argpos];
1050+
rinfo->dagOps[rinfo->dagNumberCommands]->runkey = argv[argpos+1];
10671051
rinfo->dagOps[rinfo->dagNumberCommands]->mctx =
10681052
RAI_ModelRunCtxCreate(mto);
10691053
}

src/stats.c

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,47 @@ void RAI_RemoveStatsEntry(void* infokey) {
7373
}
7474
}
7575

76+
int RAI_ResetRunStats(struct RedisAI_RunStats* rstats) {
77+
rstats->duration_us = 0;
78+
rstats->samples = 0;
79+
rstats->calls = 0;
80+
rstats->nerrors = 0;
81+
return 0;
82+
}
83+
84+
int RAI_SafeAddDataPoint(struct RedisAI_RunStats* rstats, long long duration,
85+
long long calls, long long errors, long long samples) {
86+
int result = 1;
87+
if (rstats == NULL) {
88+
return result;
89+
} else {
90+
rstats->duration_us += duration;
91+
rstats->calls += calls;
92+
rstats->nerrors += errors;
93+
rstats->samples += samples;
94+
result = 0;
95+
}
96+
return result;
97+
}
98+
7699
void RAI_FreeRunStats(struct RedisAI_RunStats* rstats) {
77100
RedisModule_Free(rstats->devicestr);
78101
RedisModule_Free(rstats->tag);
79102
}
80103

104+
int RAI_GetRunStats(const char* runkey, struct RedisAI_RunStats** rstats) {
105+
int result = 1;
106+
if (run_stats == NULL) {
107+
return result;
108+
}
109+
AI_dictEntry* entry = AI_dictFind(run_stats, runkey);
110+
if (entry) {
111+
*rstats = AI_dictGetVal(entry);
112+
result = 0;
113+
}
114+
return result;
115+
}
116+
81117
void RedisAI_FreeRunStats(RedisModuleCtx* ctx,
82118
struct RedisAI_RunStats* rstats) {
83119
RedisModule_FreeString(ctx, rstats->key);

src/stats.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,35 @@ void RAI_RemoveStatsEntry(void* infokey);
3333
void RAI_ListStatsEntries(RAI_RunType type, long long* nkeys,
3434
RedisModuleString*** keys, const char*** tags);
3535

36+
/**
37+
*
38+
* @param rstats
39+
* @return 0 on success, or 1 if the reset failed
40+
*/
41+
int RAI_ResetRunStats(struct RedisAI_RunStats *rstats);
42+
43+
/**
44+
* Safely add datapoint to the run stats. Protected against null pointer runstats
45+
* @param rstats
46+
* @param duration
47+
* @param calls
48+
* @param errors
49+
* @param samples
50+
* @return 0 on success, or 1 if the addition failed
51+
*/
52+
int RAI_SafeAddDataPoint(struct RedisAI_RunStats* rstats, long long duration, long long calls, long long errors, long long samples );
53+
3654
void RAI_FreeRunStats(struct RedisAI_RunStats* rstats);
3755

56+
57+
/**
58+
*
59+
* @param runkey
60+
* @param rstats
61+
* @return 0 on success, or 1 if the the run stats with runkey does not exist
62+
*/
63+
int RAI_GetRunStats(const char *runkey,struct RedisAI_RunStats **rstats);
64+
3865
void RedisAI_FreeRunStats(RedisModuleCtx* ctx, struct RedisAI_RunStats* rstats);
3966

4067
#endif /* SRC_SATTS_H_ */

test/tests_dag.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import redis
2+
from functools import wraps
3+
import multiprocessing as mp
24

35
from includes import *
46

@@ -135,6 +137,8 @@ def test_dagro_common_errors(env):
135137

136138

137139
def test_dag_modelrun_financialNet_errors(env):
140+
if not TEST_TF:
141+
return
138142
con = env.getConnection()
139143

140144
model_pb, creditcard_transactions, creditcard_referencedata = load_creditcardfraud_data(
@@ -379,6 +383,8 @@ def test_dag_keyspace_and_localcontext_tensorget(env):
379383

380384

381385
def test_dag_modelrun_financialNet_separate_tensorget(env):
386+
if not TEST_TF:
387+
return
382388
con = env.getConnection()
383389

384390
model_pb, creditcard_transactions, creditcard_referencedata = load_creditcardfraud_data(
@@ -419,6 +425,8 @@ def test_dag_modelrun_financialNet_separate_tensorget(env):
419425

420426

421427
def test_dag_modelrun_financialNet(env):
428+
if not TEST_TF:
429+
return
422430
con = env.getConnection()
423431

424432
model_pb, creditcard_transactions, creditcard_referencedata = load_creditcardfraud_data(
@@ -456,6 +464,8 @@ def test_dag_modelrun_financialNet(env):
456464

457465

458466
def test_dag_modelrun_financialNet_no_writes(env):
467+
if not TEST_TF:
468+
return
459469
con = env.getConnection()
460470

461471
model_pb, creditcard_transactions, creditcard_referencedata = load_creditcardfraud_data(
@@ -505,11 +515,13 @@ def test_dag_modelrun_financialNet_no_writes(env):
505515

506516

507517
def test_dagro_modelrun_financialNet_no_writes_multiple_modelruns(env):
518+
if not TEST_TF:
519+
return
508520
con = env.getConnection()
509521

510522
model_pb, creditcard_transactions, creditcard_referencedata = load_creditcardfraud_data(
511523
env)
512-
ret = con.execute_command('AI.MODELSET', 'financialNet', 'TF', "CPU",
524+
ret = con.execute_command('AI.MODELSET', 'financialNet', 'TF', DEVICE,
513525
'INPUTS', 'transaction', 'reference', 'OUTPUTS', 'output', model_pb)
514526
env.assertEqual(ret, b'OK')
515527

@@ -555,3 +567,28 @@ def test_dagro_modelrun_financialNet_no_writes_multiple_modelruns(env):
555567
tensor_number))
556568
env.assertEqual(ret, 0)
557569
tensor_number = tensor_number + 1
570+
571+
info = con.execute_command('AI.INFO', 'financialNet')
572+
financialNetRunInfo = info_to_dict(info)
573+
574+
env.assertEqual('financialNet', financialNetRunInfo['key'])
575+
env.assertEqual('MODEL', financialNetRunInfo['type'])
576+
env.assertEqual('TF', financialNetRunInfo['backend'])
577+
env.assertEqual(DEVICE, financialNetRunInfo['device'])
578+
env.assertTrue(financialNetRunInfo['duration'] > 0)
579+
env.assertEqual(0, financialNetRunInfo['samples'])
580+
env.assertEqual(2*len(creditcard_transactions), financialNetRunInfo['calls'])
581+
env.assertEqual(0, financialNetRunInfo['errors'])
582+
583+
con.execute_command('AI.INFO', 'financialNet', 'RESETSTAT')
584+
info = con.execute_command('AI.INFO', 'financialNet')
585+
financialNetRunInfo = info_to_dict(info)
586+
587+
env.assertEqual('financialNet', financialNetRunInfo['key'])
588+
env.assertEqual('MODEL', financialNetRunInfo['type'])
589+
env.assertEqual('TF', financialNetRunInfo['backend'])
590+
env.assertEqual(DEVICE, financialNetRunInfo['device'])
591+
env.assertEqual(0, financialNetRunInfo['duration'])
592+
env.assertEqual(0, financialNetRunInfo['samples'])
593+
env.assertEqual(0, financialNetRunInfo['calls'])
594+
env.assertEqual(0, financialNetRunInfo['errors'])

test/tests_onnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def test_onnx_modelrun_mnist(env):
147147

148148

149149
def test_onnx_modelrun_mnist_autobatch(env):
150-
if not TEST_PT:
150+
if not TEST_ONNX:
151151
return
152152

153153
con = env.getConnection()

test/tests_pytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,7 @@ def test_pytorch_model_rdb_save_load(env):
658658
con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'a', 'b', 'OUTPUTS', 'c')
659659
_, dtype_memory, _, shape_memory, _, data_memory = con.execute_command('AI.TENSORGET', 'c', 'META', 'VALUES')
660660

661+
ensureSlaveSynced(con, env)
661662
ret = con.execute_command('SAVE')
662663
env.assertEqual(ret, True)
663664

0 commit comments

Comments
 (0)